""" SE(3)-invariant feature extraction for interface graphs. Node and edge features used by the Q_theta scorer. """ import os import sys import numpy as np # Ensure utils is importable (for both direct and package imports) _CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if _CODE_DIR not in sys.path: sys.path.insert(0, _CODE_DIR) from utils.pdb_utils import ( rbf_encode, compute_backbone_frames, compute_torsion_angles, get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA ) # Feature dimensions # one-hot AA (21) + backbone torsions (6) + chi1 sin/cos (2) + chi2 sin/cos (2) + chain indicator (1) = 32 NODE_DIM = NUM_AA + 6 + 4 + 1 # = 32 EDGE_DIM = 16 + 3 + 9 + 8 + 1 # RBF dist (16) + direction (3) + rel rotation (9) + seq sep (8) + same chain (1) = 37 MAX_SEQ_SEP = 32 # bins for sequence separation def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP): """Bin-encode sequence separation.""" bins = np.linspace(-max_sep, max_sep, n_bins + 1) sep_clipped = np.clip(sep, -max_sep, max_sep) encoded = np.zeros(n_bins, dtype=np.float32) bin_idx = np.digitize(sep_clipped, bins) - 1 bin_idx = np.clip(bin_idx, 0, n_bins - 1) encoded[bin_idx] = 1.0 return encoded def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id): """ Compute per-residue node features. Args: residues: list of Bio.PDB residues coords: [N, 4, 3] backbone coords mask: [N] bool torsion_angles: [N, 6] sin/cos of phi, psi, omega chi_angles: [N, 4] sin/cos of chi1, chi2 chain_id: 0 = receptor, 1 = binder Returns: node_feats: [N, NODE_DIM] (NODE_DIM = 32) """ N = len(residues) aa_idx = get_aa_indices(residues) # One-hot amino acid aa_onehot = np.zeros((N, NUM_AA), dtype=np.float32) for i in range(N): if mask[i]: aa_onehot[i, aa_idx[i]] = 1.0 # Chain indicator chain_feat = np.full((N, 1), chain_id, dtype=np.float32) # Concatenate node_feats = np.concatenate([ aa_onehot, # [N, 21] torsion_angles, # [N, 6] chi_angles, # [N, 4] chain_feat, # [N, 1] ], axis=-1) return node_feats # [N, 32] def extract_edge_features(coords_i, frames_i, coords_j, frames_j, seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j): """ Compute SE(3)-invariant edge features between residue sets i and j. Vectorized over all pairs. Args: coords_i: [N_i, 4, 3] backbone coords of set i (full interface) frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3]) coords_j: [N_j, 4, 3] frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3]) seq_idx_i: [N_i] integer sequence indices (for sequence separation) seq_idx_j: [N_j] integer sequence indices chain_i: int (0 or 1) chain_j: int (0 or 1) mask_i: [N_i] bool mask_j: [N_j] bool Returns: edge_feats: [N_i, N_j, EDGE_DIM] """ N_i, N_j = len(coords_i), len(coords_j) origins_i, rotations_i = frames_i origins_j, rotations_j = frames_j ca_i = origins_i # [N_i, 3] ca_j = origins_j # [N_j, 3] # --- Distance features --- diff = ca_j[None, :, :] - ca_i[:, None, :] # [N_i, N_j, 3] dist = np.sqrt((diff ** 2).sum(axis=-1)) # [N_i, N_j] dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16) # [N_i, N_j, 16] # --- Direction in local frame of i --- # unit vector from i to j in global frame unit_diff = diff / (dist[..., None] + 1e-8) # [N_i, N_j, 3] # rotate by R_i^T to get local direction # rotations_i: [N_i, 3, 3], unit_diff: [N_i, N_j, 3] # local_dir[i,j] = R_i^T @ (ca_j - ca_i) / dist local_dir = np.einsum('ikl,ijl->ijk', rotations_i, unit_diff) # [N_i, N_j, 3] # --- Relative rotation: R_i^T R_j --- # rotations_i: [N_i, 3, 3], rotations_j: [N_j, 3, 3] # rel_rot[i,j] = R_i^T @ R_j -> [N_i, N_j, 3, 3] -> flatten to [N_i, N_j, 9] rel_rot = np.einsum('ikl,jlm->ijkm', rotations_i, rotations_j) # [N_i, N_j, 3, 3] rel_rot_flat = rel_rot.reshape(N_i, N_j, 9) # [N_i, N_j, 9] # --- Sequence separation --- sep = seq_idx_j[None, :] - seq_idx_i[:, None] # [N_i, N_j] # Encode each pair (loop over all; use vectorized bin assignment) sep_flat = sep.reshape(-1) sep_enc = np.array([seq_sep_encode(s) for s in sep_flat]) # [N_i*N_j, 8] sep_enc = sep_enc.reshape(N_i, N_j, 8) # Cross-chain pairs get sep=0 by convention if different chains if chain_i != chain_j: sep_enc[:] = 0.0 # --- Same chain indicator --- same_chain = float(chain_i == chain_j) same_chain_feat = np.full((N_i, N_j, 1), same_chain, dtype=np.float32) # --- Concatenate --- edge_feats = np.concatenate([ dist_rbf, # [N_i, N_j, 16] local_dir, # [N_i, N_j, 3] rel_rot_flat, # [N_i, N_j, 9] sep_enc, # [N_i, N_j, 8] same_chain_feat # [N_i, N_j, 1] ], axis=-1) # [N_i, N_j, 37] # Zero out edges involving masked residues edge_feats[~mask_i, :, :] = 0.0 edge_feats[:, ~mask_j, :] = 0.0 return edge_feats.astype(np.float32) def build_interface_graph(rec_residues, rec_coords, rec_mask, binder_residues, binder_coords, binder_mask, rec_interface_mask, binder_interface_mask, max_nodes: int = 128): """ Build a joint interface graph combining receptor and binder interface residues. Returns a dict with: node_feats: [N_total, NODE_DIM] edge_feats: [N_total, N_total, EDGE_DIM] node_mask: [N_total] bool n_rec: int (number of receptor interface nodes) n_binder: int (number of binder interface nodes) """ # Select interface residues rec_iface_idx = np.where(rec_interface_mask)[0] binder_iface_idx = np.where(binder_interface_mask)[0] # Truncate if too many if len(rec_iface_idx) > max_nodes // 2: rec_iface_idx = rec_iface_idx[:max_nodes // 2] if len(binder_iface_idx) > max_nodes // 2: binder_iface_idx = binder_iface_idx[:max_nodes // 2] n_rec = len(rec_iface_idx) n_binder = len(binder_iface_idx) n_total = n_rec + n_binder if n_total == 0: return None # Extract coords for interface residues rec_iface_coords = rec_coords[rec_iface_idx] # [n_rec, 4, 3] binder_iface_coords = binder_coords[binder_iface_idx] # [n_binder, 4, 3] rec_iface_mask = rec_mask[rec_iface_idx] binder_iface_mask = binder_mask[binder_iface_idx] # Compute backbone frames rec_origins, rec_rotations = compute_backbone_frames(rec_iface_coords, rec_iface_mask) binder_origins, binder_rotations = compute_backbone_frames(binder_iface_coords, binder_iface_mask) # Compute torsion angles # We need full-chain coords for proper phi/psi computation, but use local approximation here rec_torsion = compute_torsion_angles(rec_iface_coords, rec_iface_mask) binder_torsion = compute_torsion_angles(binder_iface_coords, binder_iface_mask) # Extract residues rec_iface_residues = [rec_residues[i] for i in rec_iface_idx] binder_iface_residues = [binder_residues[i] for i in binder_iface_idx] # Compute sidechain chi1/chi2 angles rec_chi = compute_chi_angles(rec_iface_residues, rec_iface_mask) binder_chi = compute_chi_angles(binder_iface_residues, binder_iface_mask) # Node features rec_node_feats = extract_node_features( rec_iface_residues, rec_iface_coords, rec_iface_mask, rec_torsion, rec_chi, chain_id=0 ) # [n_rec, NODE_DIM] binder_node_feats = extract_node_features( binder_iface_residues, binder_iface_coords, binder_iface_mask, binder_torsion, binder_chi, chain_id=1 ) # [n_binder, NODE_DIM] node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0) # [N, NODE_DIM] node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0) # Edge features (4 blocks: RR, RB, BR, BB) all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0) all_mask = node_mask all_origins = np.concatenate([rec_origins, binder_origins], axis=0) all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0) all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0) all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32) # Compute full NxN edge features frames_all = (all_origins, all_rotations) edge_feats = extract_edge_features( all_coords, frames_all, all_coords, frames_all, all_seq_idx, all_seq_idx, -1, -1, # chain handled via all_chain array below all_mask, all_mask ) # [N, N, EDGE_DIM] # Patch same_chain feature (last dim) using actual chain IDs same_chain_feat = (all_chain[:, None] == all_chain[None, :]).astype(np.float32) edge_feats[:, :, -1] = same_chain_feat return { 'node_feats': node_feats.astype(np.float32), # [N, NODE_DIM] 'edge_feats': edge_feats.astype(np.float32), # [N, N, EDGE_DIM] 'node_mask': node_mask, # [N] 'n_rec': n_rec, 'n_binder': n_binder, 'rec_iface_idx': rec_iface_idx, # [n_rec] original residue indices 'binder_iface_idx': binder_iface_idx, # [n_binder] original residue indices }