"""
Differentiable feature extraction for Q_theta guidance.

This module re-implements the key feature extraction functions from features.py
and pdb_utils.py using PyTorch operations, enabling gradient computation through
Q_theta with respect to backbone coordinates.

The differentiable path:
    coords (N,4,3) → backbone frames → torsions, distances, directions, rotations
    → node_feats, edge_feats → Q_theta → score → backward() → ∇coords

Non-differentiable features (AA one-hot, chain_id, seq_sep, same_chain) are
treated as constants.
"""

import os
import sys
import warnings

import torch
import torch.nn.functional as F
import numpy as np


def _ensure_code_dir_on_path():
    """Put the parent of this file's directory at the front of sys.path (once).

    The scorer and PDB utilities are imported lazily as ``models.scorer`` and
    ``utils.pdb_utils`` relative to that directory.
    """
    code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    if code_dir not in sys.path:
        sys.path.insert(0, code_dir)


# ── Differentiable backbone frame computation ────────────────────────────────

def compute_backbone_frames_torch(coords, mask):
    """
    Compute per-residue backbone frames from the N, CA and C atoms.

    Differentiable w.r.t. coords.

    Frame convention: z points CA→C, y is CA→N orthogonalised against z, and
    x = y × z, so (x, y, z) is a right-handed orthonormal basis stored as the
    columns of each rotation matrix.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O); may require grad.
        mask: [N] bool tensor; masked-out residues get the identity rotation.

    Returns:
        origins: [N, 3] CA positions.
        rotations: [N, 3, 3] rotation matrices (columns are the x, y, z axes).
    """
    n_atom = coords[:, 0, :]
    ca = coords[:, 1, :]
    c_atom = coords[:, 2, :]
    origins = ca

    # z-axis: CA -> C
    z = c_atom - ca
    z = z / torch.norm(z, dim=-1, keepdim=True).clamp(min=1e-6)

    # y-axis: CA -> N with its z component removed (Gram-Schmidt)
    y = n_atom - ca
    y = y - (y * z).sum(dim=-1, keepdim=True) * z
    y = y / torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-6)

    # x-axis completes the right-handed basis
    x = torch.cross(y, z, dim=-1)

    rot = torch.stack([x, y, z], dim=-1)  # [N, 3, 3], columns x, y, z

    # Blend towards identity for masked residues (keeps the op differentiable).
    mask_f = mask.float().unsqueeze(-1).unsqueeze(-1)  # [N, 1, 1]
    eye = torch.eye(3, device=coords.device, dtype=coords.dtype).unsqueeze(0)
    rotations = rot * mask_f + eye * (1 - mask_f)
    return origins, rotations


# ── Differentiable torsion angle computation ─────────────────────────────────

def _dihedral_torch(p0, p1, p2, p3):
    """
    Compute the dihedral angle of batches of 4 points as (sin, cos).

    Differentiable w.r.t. all inputs.

    Args:
        p0, p1, p2, p3: [N, 3] tensors.

    Returns:
        sin_angle: [N]
        cos_angle: [N]
    """
    b1 = p1 - p0
    b2 = p2 - p1
    b3 = p3 - p2

    n1 = torch.cross(b1, b2, dim=-1)
    n2 = torch.cross(b2, b3, dim=-1)
    n1_norm = torch.norm(n1, dim=-1, keepdim=True).clamp(min=1e-8)
    n2_norm = torch.norm(n2, dim=-1, keepdim=True).clamp(min=1e-8)
    n1 = n1 / n1_norm
    n2 = n2 / n2_norm

    b2_norm = torch.norm(b2, dim=-1, keepdim=True).clamp(min=1e-8)
    m1 = torch.cross(n1, b2 / b2_norm, dim=-1)

    cos_angle = (n1 * n2).sum(dim=-1)
    # NOTE(review): m1·n2 = -b̂2·(n1×n2), which is the opposite sign of the
    # IUPAC / Wikipedia atan2 sine term. Left unchanged because these features
    # must match the training-time features.py (not visible here) — verify.
    sin_angle = (m1 * n2).sum(dim=-1)
    return sin_angle, cos_angle


def compute_torsion_angles_torch(coords, mask):
    """
    Compute backbone torsion angles (phi, psi, omega) as sin/cos pairs.

    Differentiable w.r.t. coords. Rows of ``coords`` must be consecutive
    residues of one chain: each torsion couples row i with row i-1 or i+1.
    Undefined torsions (chain ends, or a pair with a masked residue) are 0.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O).
        mask: [N] bool tensor.

    Returns:
        torsions: [N, 6] (sin_phi, cos_phi, sin_psi, cos_psi, sin_omega, cos_omega).
    """
    N = coords.shape[0]
    torsions = torch.zeros(N, 6, device=coords.device, dtype=coords.dtype)
    if N < 2:
        return torsions

    n_atoms = coords[:, 0, :]
    ca_atoms = coords[:, 1, :]
    c_atoms = coords[:, 2, :]

    # Every torsion spans one neighbouring residue pair, so one mask serves all.
    pair_mask = (mask[1:] & mask[:-1]).float()  # [N-1]

    # Phi_i: C_{i-1} - N_i - CA_i - C_i (defined for i >= 1)
    sin_phi, cos_phi = _dihedral_torch(c_atoms[:-1], n_atoms[1:], ca_atoms[1:], c_atoms[1:])
    torsions[1:, 0] = sin_phi * pair_mask
    torsions[1:, 1] = cos_phi * pair_mask

    # Psi_i: N_i - CA_i - C_i - N_{i+1} (defined for i < N-1)
    sin_psi, cos_psi = _dihedral_torch(n_atoms[:-1], ca_atoms[:-1], c_atoms[:-1], n_atoms[1:])
    torsions[:-1, 2] = sin_psi * pair_mask
    torsions[:-1, 3] = cos_psi * pair_mask

    # Omega_i: CA_{i-1} - C_{i-1} - N_i - CA_i (defined for i >= 1)
    sin_omega, cos_omega = _dihedral_torch(ca_atoms[:-1], c_atoms[:-1], n_atoms[1:], ca_atoms[1:])
    torsions[1:, 4] = sin_omega * pair_mask
    torsions[1:, 5] = cos_omega * pair_mask

    return torsions


# ── Differentiable RBF distance encoding ─────────────────────────────────────

def rbf_encode_torch(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """
    Encode distances with Gaussian radial basis functions.

    Centres are evenly spaced on [d_min, d_max], and the width equals the
    centre spacing. Differentiable w.r.t. distances.

    Args:
        distances: [...] tensor.
        d_min, d_max: range covered by the centres.
        n_bins: number of centres.

    Returns:
        encoded: [..., n_bins] tensor.
    """
    centers = torch.linspace(d_min, d_max, n_bins, device=distances.device, dtype=distances.dtype)
    # max(...) avoids ZeroDivisionError for n_bins == 1 (single centre at d_min).
    sigma = (d_max - d_min) / max(n_bins - 1, 1)
    return torch.exp(-((distances.unsqueeze(-1) - centers) ** 2) / (2 * sigma ** 2))


# ── Differentiable edge feature computation ──────────────────────────────────

def compute_edge_features_torch(origins, rotations, seq_idx, chain_ids, mask,
                                n_bins_rbf=16, n_bins_sep=8, max_sep=32):
    """
    Compute edge features between all residue pairs.

    Differentiable w.r.t. origins and rotations (which derive from coords).
    Distance, separation and chain features are invariant to rigid motions;
    see the NOTE in the body about the direction and rotation terms.

    Feature layout along the last dim (offsets for the default bin counts):
        [0:16]  RBF of the CA-CA distance (n_bins_rbf bins over 0-20 Å)
        [16:19] R_i @ (ca_i - ca_j) / |ca_i - ca_j| (zero for i == j)
        [19:28] R_i @ R_j, flattened
        [28:36] one-hot sequence separation clipped to ±max_sep (n_bins_sep
                bins; all zero for cross-chain pairs)
        [36]    same-chain indicator

    Args:
        origins: [N, 3] CA positions.
        rotations: [N, 3, 3] backbone frame rotations.
        seq_idx: [N] int tensor of sequence indices (non-differentiable).
        chain_ids: [N] int tensor of chain labels (non-differentiable).
        mask: [N] bool tensor.
        n_bins_rbf, n_bins_sep, max_sep: encoding sizes (see layout above).

    Returns:
        edge_feats: [N, N, n_bins_rbf + n_bins_sep + 13] (37 by default).
            Edges that touch a masked residue are zeroed.
    """
    N = origins.shape[0]
    device = origins.device
    dtype = origins.dtype

    # --- Distance features (differentiable) ---
    diff = origins.unsqueeze(1) - origins.unsqueeze(0)  # [N, N, 3]; diff[i, j] = ca_i - ca_j
    dist = torch.norm(diff, dim=-1).clamp(min=1e-8)  # [N, N]
    dist_rbf = rbf_encode_torch(dist, d_min=0., d_max=20., n_bins=n_bins_rbf)

    # --- Direction term (differentiable) ---
    # The self-pair has diff == 0, so its direction is 0 whatever the divisor.
    # Dividing by the 1e-8-clamped distance would scale its gradient by ~1e8.
    # That term cancels mathematically in the backward of `diff`, but float32
    # rounding at that magnitude corrupts the coordinate gradient. Divide the
    # diagonal by 1 instead; forward values are unchanged.
    self_pair = torch.eye(N, dtype=torch.bool, device=device)
    safe_dist = torch.where(self_pair, torch.ones_like(dist), dist)
    unit_diff = diff / safe_dist.unsqueeze(-1)  # [N, N, 3]

    # local_dir[i, j] = R_i @ unit_diff[i, j]
    # NOTE(review): rotations store frame axes as columns, so a vector in
    # residue i's local frame would be R_i^T @ v (einsum 'ilk,ijl->ijk') and
    # the frame-to-frame rotation R_i^T @ R_j ('ilk,jlm->ijkm'). The products
    # computed here change under a global rotation of the complex. They are
    # kept as-is because they must reproduce the features Q_theta was trained
    # on (features.py, not visible here) — confirm before changing.
    local_dir = torch.einsum('ikl,ijl->ijk', rotations, unit_diff)  # [N, N, 3]

    # --- Relative rotation (differentiable): rel_rot[i, j] = R_i @ R_j ---
    rel_rot = torch.einsum('ikl,jlm->ijkm', rotations, rotations)  # [N, N, 3, 3]
    rel_rot_flat = rel_rot.reshape(N, N, 9)

    # --- Sequence separation (non-differentiable, constant) ---
    sep = seq_idx.unsqueeze(1) - seq_idx.unsqueeze(0)  # [N, N]
    bins = torch.linspace(-max_sep, max_sep, n_bins_sep + 1, device=device)
    sep_clipped = sep.float().clamp(-max_sep, max_sep)
    # Hard one-hot binning: bucketize yields the right-edge index, so shift by one.
    bin_idx = (torch.bucketize(sep_clipped, bins) - 1).clamp(0, n_bins_sep - 1)
    sep_enc = torch.zeros(N, N, n_bins_sep, device=device, dtype=dtype)
    sep_enc.scatter_(2, bin_idx.unsqueeze(-1).long(), 1.0)

    # Cross-chain pairs carry no separation encoding.
    same_chain = chain_ids.unsqueeze(1) == chain_ids.unsqueeze(0)  # [N, N]
    sep_enc[~same_chain] = 0.0

    # --- Same chain indicator (non-differentiable, constant) ---
    same_chain_feat = same_chain.float().unsqueeze(-1)  # [N, N, 1]

    edge_feats = torch.cat([
        dist_rbf,         # [N, N, n_bins_rbf]
        local_dir,        # [N, N, 3]
        rel_rot_flat,     # [N, N, 9]
        sep_enc,          # [N, N, n_bins_sep]
        same_chain_feat,  # [N, N, 1]
    ], dim=-1)

    # Zero out edges involving masked residues.
    mask_2d = mask.unsqueeze(1) & mask.unsqueeze(0)
    return edge_feats * mask_2d.unsqueeze(-1).float()


# ── Full differentiable interface graph builder ──────────────────────────────

def _select_interface_indices(rec_coords, rec_mask, binder_coords, binder_mask,
                              cutoff, max_nodes):
    """Return (rec_idx, binder_idx) of residues with a CA within ``cutoff`` of the other chain.

    Masked residues are never selected. Each index list keeps only its first
    ``max_nodes // 2`` entries (in residue order). The selection is a hard
    threshold, so no autograd graph is built for it.
    """
    with torch.no_grad():
        rec_ca = rec_coords[:, 1, :]
        binder_ca = binder_coords[:, 1, :]
        dist_mat = torch.cdist(rec_ca.unsqueeze(0), binder_ca.unsqueeze(0)).squeeze(0)
        dist_mat[~rec_mask, :] = float('inf')
        dist_mat[:, ~binder_mask] = float('inf')
        contact = dist_mat < cutoff  # [N_rec, N_binder]
        rec_idx = torch.where(contact.any(dim=1))[0][:max_nodes // 2]
        binder_idx = torch.where(contact.any(dim=0))[0][:max_nodes // 2]
    return rec_idx, binder_idx


def build_differentiable_interface_graph(
    rec_coords, rec_mask, rec_aa_idx, rec_chi,
    binder_coords, binder_mask, binder_aa_idx, binder_chi,
    cutoff=8.0, max_nodes=128
):
    """
    Build an interface graph whose features are differentiable w.r.t. binder_coords.

    Receptor tensors are used as given and are not detached here. Receptors
    cached by ``DifferentiableQTheta.load_receptor`` do not require grad.

    Args:
        rec_coords: [N_rec, 4, 3] receptor backbone coords.
        rec_mask: [N_rec] bool.
        rec_aa_idx: [N_rec] int amino acid indices (constant).
        rec_chi: [N_rec, 4] chi1/chi2 sin/cos (constant).
        binder_coords: [N_binder, 4, 3] binder backbone coords (may require grad).
        binder_mask: [N_binder] bool.
        binder_aa_idx: [N_binder] int amino acid indices (constant).
        binder_chi: [N_binder, 4] chi1/chi2 sin/cos.
        cutoff: interface CA-CA distance cutoff (Å).
        max_nodes: maximum total nodes; each chain keeps at most
            max_nodes // 2 interface residues (lowest indices first).

    Returns:
        None if no interface residue is found. Otherwise a dict with:
            'node_feats': [1, N_total, 32]
                (AA one-hot 21 + torsions 6 + chi 4 + chain flag 1)
            'edge_feats': [1, N_total, N_total, 37]
            'node_mask': [1, N_total] bool
            'n_rec', 'n_binder': int node counts (receptor nodes come first)
            'rec_iface_idx', 'binder_iface_idx': selected residue indices
                into the full receptor / binder chains
    """
    device = binder_coords.device
    dtype = binder_coords.dtype
    NUM_AA = 21

    rec_iface_idx, binder_iface_idx = _select_interface_indices(
        rec_coords, rec_mask, binder_coords, binder_mask, cutoff, max_nodes)
    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    if n_rec + n_binder == 0:
        return None

    # ── Interface subsets ──
    rec_iface_coords = rec_coords[rec_iface_idx]
    binder_iface_coords = binder_coords[binder_iface_idx]
    rec_iface_mask = rec_mask[rec_iface_idx]
    binder_iface_mask = binder_mask[binder_iface_idx]

    # ── Backbone frames (per residue, so computing them on the subset is fine) ──
    rec_origins, rec_rotations = compute_backbone_frames_torch(rec_iface_coords, rec_iface_mask)
    binder_origins, binder_rotations = compute_backbone_frames_torch(binder_iface_coords, binder_iface_mask)

    # ── Torsions couple sequence neighbours: compute them on the full chains,
    # then select. The interface subset is generally not contiguous in sequence.
    # Chi angles are handled the same way (full-chain tensors indexed below).
    rec_torsion = compute_torsion_angles_torch(rec_coords, rec_mask)[rec_iface_idx]
    binder_torsion = compute_torsion_angles_torch(binder_coords, binder_mask)[binder_iface_idx]

    # ── Node features: [AA(21) + torsions(6) + chi(4) + chain(1)] = 32 ──
    rec_aa_onehot = F.one_hot(rec_aa_idx[rec_iface_idx].long(), NUM_AA).float()
    binder_aa_onehot = F.one_hot(binder_aa_idx[binder_iface_idx].long(), NUM_AA).float()
    rec_chi_iface = rec_chi[rec_iface_idx]
    binder_chi_iface = binder_chi[binder_iface_idx]
    rec_chain_feat = torch.zeros(n_rec, 1, device=device, dtype=dtype)
    binder_chain_feat = torch.ones(n_binder, 1, device=device, dtype=dtype)

    rec_node = torch.cat([rec_aa_onehot, rec_torsion, rec_chi_iface, rec_chain_feat], dim=-1)
    binder_node = torch.cat([binder_aa_onehot, binder_torsion, binder_chi_iface, binder_chain_feat], dim=-1)
    node_feats = torch.cat([rec_node, binder_node], dim=0)  # [N_total, 32]
    node_mask_flat = torch.cat([rec_iface_mask, binder_iface_mask], dim=0)

    # ── Edge features (differentiable) ──
    all_origins = torch.cat([rec_origins, binder_origins], dim=0)
    all_rotations = torch.cat([rec_rotations, binder_rotations], dim=0)
    # Binder indices are offset by the receptor length so they never collide.
    all_seq_idx = torch.cat([rec_iface_idx, binder_iface_idx + rec_coords.shape[0]], dim=0)
    all_chain_ids = torch.cat([
        torch.zeros(n_rec, device=device, dtype=torch.long),
        torch.ones(n_binder, device=device, dtype=torch.long),
    ], dim=0)
    edge_feats = compute_edge_features_torch(
        all_origins, all_rotations, all_seq_idx, all_chain_ids, node_mask_flat
    )

    # Add batch dimension
    return {
        'node_feats': node_feats.unsqueeze(0),
        'edge_feats': edge_feats.unsqueeze(0),
        'node_mask': node_mask_flat.unsqueeze(0),
        'n_rec': n_rec,
        'n_binder': n_binder,
        'rec_iface_idx': rec_iface_idx,
        'binder_iface_idx': binder_iface_idx,
    }


# ── Differentiable Q_theta scoring function ──────────────────────────────────

class DifferentiableQTheta:
    """
    Wrap the Q_theta scorer for differentiable scoring w.r.t. binder backbone
    coordinates. Receptor structures are pre-loaded and cached by label.

    Usage:
        dq = DifferentiableQTheta(checkpoint_path, device)
        dq.load_receptor(holo_pdb, chain='A', label='holo')
        dq.load_receptor(apo_pdb, chain='A', label='apo')

        binder_coords = torch.tensor(...)  # [N_binder, 4, 3], requires_grad=True
        score_holo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='holo')
        score_apo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='apo')
        selectivity = score_holo - score_apo
        selectivity.backward()
        # binder_coords.grad now contains ∂S/∂coords
    """

    def __init__(self, checkpoint_path, device='cuda:0', esm_dir=None):
        """Load the scorer checkpoint onto ``device`` in eval mode.

        Args:
            checkpoint_path: checkpoint containing 'config' and 'model_state'.
            device: torch device string.
            esm_dir: root directory of per-target ESM embeddings; defaults to
                $ALLOGEN_ROOT/data/esm2_embeddings.
        """
        _ensure_code_dir_on_path()
        from models.scorer import build_model

        self.device = torch.device(device)
        # weights_only=False unpickles arbitrary objects: load trusted checkpoints only.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        self.config = ckpt['config']
        self.model = build_model(self.config)
        self.model.load_state_dict(ckpt['model_state'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # ESM feature support (enabled when the config declares esm_dim > 0)
        self.esm_dim = self.config.get('esm_dim', 0)
        self.use_esm = self.esm_dim > 0
        self.esm_dir = esm_dir or os.path.join(os.environ.get('ALLOGEN_ROOT', '.'), 'data/esm2_embeddings')

        # label -> {coords, mask, aa_idx, chi, [residues], [esm_emb]}
        self.receptors = {}

    def load_receptor(self, pdb_path, chain='A', label='holo', esm_target=None, esm_key=None):
        """Pre-load and cache a receptor chain, optionally with ESM embeddings.

        Args:
            pdb_path: path to the receptor PDB.
            chain: chain ID.
            label: cache key.
            esm_target: target name for the ESM dir (e.g. 'abl' for
                data/esm2_embeddings/abl/). ESM is loaded only when this is
                given and the model uses ESM.
            esm_key: ESM embedding file key (e.g. '6XR7_A'); defaults to
                '<pdb file stem>_<chain>'.
        """
        _ensure_code_dir_on_path()
        from utils.pdb_utils import (
            load_structure, get_residues, get_backbone_coords,
            get_aa_indices, compute_chi_angles
        )

        structure = load_structure(pdb_path)
        residues = get_residues(structure[chain])
        coords, mask = get_backbone_coords(residues)
        aa_idx = get_aa_indices(residues)
        chi = compute_chi_angles(residues, mask)

        rec_data = {
            'coords': torch.from_numpy(coords).float().to(self.device),
            'mask': torch.from_numpy(mask).bool().to(self.device),
            'aa_idx': torch.from_numpy(aa_idx).long().to(self.device),
            'chi': torch.from_numpy(chi).float().to(self.device),
            'residues': residues,
        }

        if self.use_esm and esm_target:
            pdb_id = os.path.basename(pdb_path).replace('.pdb', '')
            if esm_key is None:
                esm_key = f'{pdb_id}_{chain}'
            esm_path = os.path.join(self.esm_dir, esm_target, f'{esm_key}.pt')
            n_res = len(residues)
            if os.path.exists(esm_path):
                # Cast first so the zero padding below has a matching dtype.
                esm_emb = torch.load(esm_path, map_location=self.device, weights_only=True).float()
                # Truncate/pad to the structure's residue count.
                if esm_emb.shape[0] > n_res:
                    esm_emb = esm_emb[:n_res]
                elif esm_emb.shape[0] < n_res:
                    pad = torch.zeros(n_res - esm_emb.shape[0], esm_emb.shape[1], device=self.device)
                    esm_emb = torch.cat([esm_emb, pad], dim=0)
                rec_data['esm_emb'] = esm_emb
            else:
                # Zero embeddings can change scores substantially; make the fallback visible.
                warnings.warn(f'ESM embedding not found at {esm_path}; using zeros for receptor {label!r}')
                rec_data['esm_emb'] = torch.zeros(n_res, self.esm_dim, device=self.device)

        self.receptors[label] = rec_data

    def load_receptor_from_coords(self, coords, mask, aa_idx=None, chi=None, label='path'):
        """
        Cache a receptor given raw backbone coords instead of a PDB file.

        Used for interpolated path frames that have no PDB file.

        Args:
            coords: [N, 4, 3] numpy or torch backbone coords (N, CA, C, O).
            mask: [N] numpy or torch bool.
            aa_idx: [N] numpy or torch int (default: all zeros, all-ALA = 0).
            chi: [N, 4] numpy or torch float (default: zeros).
            label: str cache key.
        """
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask).bool()

        N = coords.shape[0]
        if aa_idx is None:
            aa_idx = torch.zeros(N, dtype=torch.long)  # all-ALA
        elif isinstance(aa_idx, np.ndarray):
            aa_idx = torch.from_numpy(aa_idx).long()

        if chi is None:
            chi = torch.zeros(N, 4, dtype=coords.dtype)
        elif isinstance(chi, np.ndarray):
            chi = torch.from_numpy(chi).float()

        self.receptors[label] = {
            'coords': coords.to(self.device),
            'mask': mask.to(self.device),
            'aa_idx': aa_idx.to(self.device),
            'chi': chi.to(self.device),
        }

    def _interface_esm_features(self, rec, graph):
        """Build [1, n_rec + n_binder, esm_dim] ESM features aligned with the graph nodes.

        Receptor rows come from the cached per-residue embedding (zeros if none
        was loaded). They are selected with the same interface indices as the
        structural features. Binder rows are zeros (designed backbone, no
        sequence).
        """
        if 'esm_emb' in rec:
            rec_esm_full = rec['esm_emb']
        else:
            rec_esm_full = torch.zeros(rec['coords'].shape[0], self.esm_dim, device=self.device)
        rec_esm_iface = rec_esm_full[graph['rec_iface_idx']]
        binder_esm_iface = torch.zeros(graph['n_binder'], self.esm_dim, device=self.device)
        return torch.cat([rec_esm_iface, binder_esm_iface], dim=0).unsqueeze(0)

    def score(self, binder_coords, binder_mask, binder_aa_idx=None, binder_chi=None,
              receptor_label='holo', cutoff=8.0):
        """
        Score a binder against a cached receptor.

        Differentiable w.r.t. binder_coords.

        Args:
            binder_coords: [N_binder, 4, 3] tensor (may have requires_grad=True).
            binder_mask: [N_binder] bool tensor.
            binder_aa_idx: [N_binder] int tensor (default: all 20 = UNK).
            binder_chi: [N_binder, 4] tensor (default: zeros).
            receptor_label: key into the cached receptors.
            cutoff: interface distance cutoff (Å).

        Returns:
            Scalar tensor: the squeezed model output. It is 0 when there is no
            interface; that zero stays attached to binder_coords, so backward()
            yields zero gradients rather than leaving ``.grad`` as None.
        """
        rec = self.receptors[receptor_label]
        N_binder = binder_coords.shape[0]

        if binder_aa_idx is None:
            binder_aa_idx = torch.full((N_binder,), 20, device=self.device, dtype=torch.long)  # UNK
        if binder_chi is None:
            binder_chi = torch.zeros(N_binder, 4, device=self.device, dtype=binder_coords.dtype)

        graph = build_differentiable_interface_graph(
            rec_coords=rec['coords'], rec_mask=rec['mask'],
            rec_aa_idx=rec['aa_idx'], rec_chi=rec['chi'],
            binder_coords=binder_coords, binder_mask=binder_mask,
            binder_aa_idx=binder_aa_idx, binder_chi=binder_chi,
            cutoff=cutoff,
        )

        if graph is None:
            # No interface: summing an empty slice gives 0 while keeping the
            # result connected to binder_coords in the autograd graph.
            zero = binder_coords.reshape(-1)[:0].sum()
            if not zero.requires_grad:
                zero.requires_grad_()
            return zero

        esm_feats = self._interface_esm_features(rec, graph) if self.use_esm else None

        score = self.model(graph['node_feats'], graph['edge_feats'], graph['node_mask'],
                           esm_feats=esm_feats)
        return score.squeeze()  # scalar

    def selectivity_margin(self, binder_coords, binder_mask, binder_aa_idx=None,
                           binder_chi=None, holo_label='holo', apo_label='apo', cutoff=8.0):
        """
        Compute the selectivity margin S = Q(holo, Y) - Q(apo, Y).

        Differentiable w.r.t. binder_coords.

        Returns:
            (margin, q_holo, q_apo) scalar tensors.
        """
        q_holo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi, holo_label, cutoff)
        q_apo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi, apo_label, cutoff)
        return q_holo - q_apo, q_holo, q_apo