| """ |
| Differentiable feature extraction for Q_theta guidance. |
| |
| This module re-implements the key feature extraction functions from features.py |
| and pdb_utils.py using PyTorch operations, enabling gradient computation through |
| Q_theta with respect to backbone coordinates. |
| |
| The differentiable path: |
| coords (N,4,3) → backbone frames → torsions, distances, directions, rotations |
| → node_feats, edge_feats → Q_theta → score → backward() → ∇coords |
| |
| Non-differentiable features (AA one-hot, chain_id, seq_sep, same_chain) are |
| treated as constants. |
| """ |
|
|
| import os |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
| |
|
|
def compute_backbone_frames_torch(coords, mask):
    """
    Compute per-residue backbone frames from the N, CA and C atoms.

    Differentiable w.r.t. ``coords``. The frame axes are:
        z = unit(C - CA)
        y = unit((N - CA) with its z-component removed)
        x = y × z
    which form an orthonormal right-handed basis whenever N, CA and C are
    not collinear.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O); may require grad
            (e.g. the binder). Only atoms 0-2 are used.
        mask: [N] bool tensor; residues with mask False get the identity
            rotation.

    Returns:
        origins: [N, 3] CA positions (not masked).
        rotations: [N, 3, 3] rotation matrices whose columns are the
            x, y, z axes.
    """
    n_atom = coords[:, 0, :]
    ca = coords[:, 1, :]
    c_atom = coords[:, 2, :]

    # z axis: unit vector CA -> C. The clamped norm avoids dividing by zero
    # for degenerate (e.g. all-zero padded) residues.
    z = c_atom - ca
    z = z / torch.norm(z, dim=-1, keepdim=True).clamp(min=1e-6)

    # y axis: CA -> N, Gram-Schmidt orthogonalised against z.
    y = n_atom - ca
    y = y - (y * z).sum(dim=-1, keepdim=True) * z
    y = y / torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-6)

    # x axis completes the right-handed basis.
    x = torch.cross(y, z, dim=-1)

    rot = torch.stack([x, y, z], dim=-1)  # columns are the x, y, z axes

    # Blend with a 0/1 mask so masked residues get the identity frame.
    mask_f = mask.to(coords.dtype)[:, None, None]
    eye = torch.eye(3, device=coords.device, dtype=coords.dtype).unsqueeze(0)
    rotations = rot * mask_f + eye * (1 - mask_f)

    return ca, rotations
|
|
|
|
| |
|
|
| def _dihedral_torch(p0, p1, p2, p3): |
| """ |
| Compute dihedral angle for batches of 4 points. Returns sin, cos. |
| Differentiable w.r.t. all inputs. |
| |
| Args: |
| p0, p1, p2, p3: [N, 3] tensors |
| |
| Returns: |
| sin_angle: [N] |
| cos_angle: [N] |
| """ |
| b1 = p1 - p0 |
| b2 = p2 - p1 |
| b3 = p3 - p2 |
|
|
| n1 = torch.cross(b1, b2, dim=-1) |
| n2 = torch.cross(b2, b3, dim=-1) |
|
|
| n1_norm = torch.norm(n1, dim=-1, keepdim=True).clamp(min=1e-8) |
| n2_norm = torch.norm(n2, dim=-1, keepdim=True).clamp(min=1e-8) |
| n1 = n1 / n1_norm |
| n2 = n2 / n2_norm |
|
|
| b2_norm = torch.norm(b2, dim=-1, keepdim=True).clamp(min=1e-8) |
| m1 = torch.cross(n1, b2 / b2_norm, dim=-1) |
|
|
| cos_angle = (n1 * n2).sum(dim=-1) |
| sin_angle = (m1 * n2).sum(dim=-1) |
|
|
| return sin_angle, cos_angle |
|
|
|
|
def compute_torsion_angles_torch(coords, mask):
    """
    Compute backbone torsion angles (phi, psi, omega) as sin/cos pairs.

    Differentiable w.r.t. ``coords``. Consecutive rows of ``coords`` are
    treated as sequence neighbours, so pass a whole chain rather than an
    arbitrary subset of residues:
        phi_i   = dihedral(C_{i-1}, N_i, CA_i, C_i)       (zero for i = 0)
        psi_i   = dihedral(N_i, CA_i, C_i, N_{i+1})       (zero for i = N-1)
        omega_i = dihedral(CA_{i-1}, C_{i-1}, N_i, CA_i)  (zero for i = 0)
    Every angle spans a residue pair (i, i+1) and is zeroed unless both
    residues of that pair have mask True.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O)
        mask: [N] bool tensor

    Returns:
        torsions: [N, 6] (sin_phi, cos_phi, sin_psi, cos_psi,
                          sin_omega, cos_omega)
    """
    n_res = coords.shape[0]
    if n_res < 2:
        return torch.zeros(n_res, 6, device=coords.device, dtype=coords.dtype)

    n_atoms = coords[:, 0, :]
    ca_atoms = coords[:, 1, :]
    c_atoms = coords[:, 2, :]

    # One mask serves all three angles: each needs residues i and i+1.
    pair_ok = (mask[1:] & mask[:-1]).to(coords.dtype)

    sin_phi, cos_phi = _dihedral_torch(
        c_atoms[:-1], n_atoms[1:], ca_atoms[1:], c_atoms[1:])
    sin_psi, cos_psi = _dihedral_torch(
        n_atoms[:-1], ca_atoms[:-1], c_atoms[:-1], n_atoms[1:])
    sin_omega, cos_omega = _dihedral_torch(
        ca_atoms[:-1], c_atoms[:-1], n_atoms[1:], ca_atoms[1:])

    # Pad back to length N with zeros instead of writing in place.
    # phi/omega belong to the second residue of each pair (pad in front);
    # psi belongs to the first residue (pad at the end).
    def _on_second(v):
        return F.pad(v * pair_ok, (1, 0))

    def _on_first(v):
        return F.pad(v * pair_ok, (0, 1))

    return torch.stack([
        _on_second(sin_phi), _on_second(cos_phi),
        _on_first(sin_psi), _on_first(cos_psi),
        _on_second(sin_omega), _on_second(cos_omega),
    ], dim=-1)
|
|
|
|
| |
|
|
def rbf_encode_torch(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """
    Expand distances onto a bank of Gaussian radial basis functions.

    The ``n_bins`` centres are evenly spaced on [d_min, d_max]. Every
    Gaussian's width (sigma) equals the spacing between neighbouring
    centres. Differentiable w.r.t. ``distances``.

    Args:
        distances: [...] tensor
        d_min: first centre
        d_max: last centre
        n_bins: number of centres (n_bins == 1 raises ZeroDivisionError)

    Returns:
        encoded: [..., n_bins] tensor with values in [0, 1]
    """
    spacing = (d_max - d_min) / (n_bins - 1)
    centres = torch.linspace(d_min, d_max, n_bins,
                             device=distances.device, dtype=distances.dtype)
    offsets = distances.unsqueeze(-1) - centres
    return torch.exp(-offsets.pow(2) / (2 * spacing ** 2))
|
|
|
|
| |
|
|
def compute_edge_features_torch(origins, rotations, seq_idx, chain_ids, mask,
                                n_bins_rbf=16, n_bins_sep=8, max_sep=32):
    """
    Compute pairwise edge features for every residue pair (i, j).

    Differentiable w.r.t. ``origins`` and ``rotations``. ``seq_idx``,
    ``chain_ids`` and ``mask`` only produce constant features. The feature
    layout along the last axis is:
        [0:16)   RBF of the CA-CA distance |o_i - o_j| (for n_bins_rbf=16)
        [16:19)  einsum('ikl,ijl->ijk') of R_i with unit(o_i - o_j)
        [19:28)  R_i @ R_j, flattened row-major
        [28:36)  one-hot sequence separation (seq_i - seq_j, clipped to
                 [-max_sep, max_sep]), zeroed for cross-chain pairs
        [36]     same-chain flag
    Pairs with either residue masked out are all-zero. Diagonal (i == i)
    entries are not excluded.

    NOTE(review): the original docstring calls these features SE(3)-invariant.
    Here ``rotations`` holds the frame axes as columns, and under that
    convention R_i @ u and R_i @ R_j change under a global rotation. The
    invariant forms would be R_i^T @ u and R_i^T @ R_j. The code is kept as-is
    to preserve parity with the training-time features. Confirm against
    features.py before changing.

    Args:
        origins: [N, 3] CA positions
        rotations: [N, 3, 3] backbone frame rotations
        seq_idx: [N] int tensor, sequence indices (non-differentiable)
        chain_ids: [N] int tensor, chain labels (non-differentiable)
        mask: [N] bool tensor

    Returns:
        edge_feats: [N, N, n_bins_rbf + 3 + 9 + n_bins_sep + 1] (37 by default)
    """
    n_res = origins.shape[0]
    device, dtype = origins.device, origins.dtype

    # disp[i, j] = origin_i - origin_j. The distance clamp keeps the diagonal
    # finite when it is used as a divisor.
    disp = origins[:, None, :] - origins[None, :, :]
    pair_dist = disp.norm(dim=-1).clamp(min=1e-8)
    rbf_block = rbf_encode_torch(pair_dist, d_min=0., d_max=20., n_bins=n_bins_rbf)

    unit_disp = disp / pair_dist[..., None]
    dir_block = torch.einsum('ikl,ijl->ijk', rotations, unit_disp)

    rot_block = torch.einsum('ikl,jlm->ijkm', rotations, rotations).reshape(n_res, n_res, 9)

    # Sequence separation -> bucket index in [0, n_bins_sep) -> one-hot.
    seq_gap = (seq_idx[:, None] - seq_idx[None, :]).float().clamp(-max_sep, max_sep)
    bucket_edges = torch.linspace(-max_sep, max_sep, n_bins_sep + 1, device=device)
    bucket = (torch.bucketize(seq_gap, bucket_edges) - 1).clamp(0, n_bins_sep - 1)
    sep_block = torch.zeros(n_res, n_res, n_bins_sep, device=device, dtype=dtype)
    sep_block.scatter_(2, bucket.unsqueeze(-1).long(), 1.0)

    # Sequence separation is meaningless across chains, so zero it there.
    same_chain = chain_ids[:, None] == chain_ids[None, :]
    sep_block[~same_chain] = 0.0

    features = torch.cat(
        [rbf_block, dir_block, rot_block, sep_block, same_chain.float().unsqueeze(-1)],
        dim=-1,
    )

    pair_mask = mask[:, None] & mask[None, :]
    return features * pair_mask.unsqueeze(-1).float()
|
|
|
|
| |
|
|
def build_differentiable_interface_graph(
    rec_coords, rec_mask, rec_aa_idx, rec_chi,
    binder_coords, binder_mask, binder_aa_idx, binder_chi,
    cutoff=8.0, max_nodes=128
):
    """
    Build the receptor/binder interface graph with features that are
    differentiable w.r.t. ``binder_coords``.

    A residue is on the interface when its CA is closer than ``cutoff`` to
    the CA of some residue on the other side. Masked-out residues never
    qualify. Each side keeps at most ``max_nodes // 2`` interface residues,
    lowest indices first. Receptor nodes precede binder nodes in the graph.
    ``rec_coords`` is detached, so no gradient flows into the receptor.

    Args:
        rec_coords: [N_rec, 4, 3] receptor backbone coords (N, CA, C, O)
        rec_mask: [N_rec] bool (cast with .bool())
        rec_aa_idx: [N_rec] int amino-acid indices in [0, 21)
        rec_chi: [N_rec, 4] chi1/chi2 sin/cos
        binder_coords: [N_binder, 4, 3] binder backbone coords (may require grad)
        binder_mask: [N_binder] bool (cast with .bool())
        binder_aa_idx: [N_binder] int amino-acid indices in [0, 21)
        binder_chi: [N_binder, 4] chi1/chi2 sin/cos
        cutoff: CA-CA interface distance cutoff (Å)
        max_nodes: node budget; each side gets at most max_nodes // 2

    Returns:
        None if no residue pair is within ``cutoff``; otherwise a dict:
            'node_feats': [1, N_total, 32] (21 AA one-hot, 6 torsion sin/cos,
                          4 chi, 1 chain flag: 0 receptor / 1 binder)
            'edge_feats': [1, N_total, N_total, 37]
            'node_mask': [1, N_total] bool
            'n_rec', 'n_binder': int node counts
            'rec_iface_idx', 'binder_iface_idx': indices of the selected
                residues within the input chains
    """
    device = binder_coords.device
    dtype = binder_coords.dtype
    NUM_AA = 21

    rec_coords = rec_coords.detach()
    # `~` on an integer mask is bitwise NOT, which would silently turn the
    # masking below into integer indexing, so force bool.
    rec_mask = rec_mask.bool()
    binder_mask = binder_mask.bool()

    # The interface is chosen by comparisons only, so no autograd graph is
    # needed here.
    with torch.no_grad():
        dist_mat = torch.cdist(rec_coords[:, 1, :].unsqueeze(0),
                               binder_coords[:, 1, :].unsqueeze(0)).squeeze(0)
        dist_mat = dist_mat.masked_fill(~rec_mask.unsqueeze(1), float('inf'))
        dist_mat = dist_mat.masked_fill(~binder_mask.unsqueeze(0), float('inf'))
        contact = dist_mat < cutoff

    half = max_nodes // 2
    rec_iface_idx = torch.where(contact.any(dim=1))[0][:half]
    binder_iface_idx = torch.where(contact.any(dim=0))[0][:half]

    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    if n_rec + n_binder == 0:
        return None

    rec_iface_coords = rec_coords[rec_iface_idx]
    binder_iface_coords = binder_coords[binder_iface_idx]
    rec_iface_mask = rec_mask[rec_iface_idx]
    binder_iface_mask = binder_mask[binder_iface_idx]

    # Frames are per-residue, so computing them on the subset is exact.
    rec_origins, rec_rotations = compute_backbone_frames_torch(rec_iface_coords, rec_iface_mask)
    binder_origins, binder_rotations = compute_backbone_frames_torch(binder_iface_coords, binder_iface_mask)

    # Torsions depend on each residue's sequence neighbours. They are computed
    # on the full chains and then selected; on the interface subset, adjacent
    # rows need not be bonded residues. This matches how chi is handled.
    rec_torsion = compute_torsion_angles_torch(rec_coords, rec_mask)[rec_iface_idx]
    binder_torsion = compute_torsion_angles_torch(binder_coords, binder_mask)[binder_iface_idx]

    # Constant node features: AA one-hot, chi, chain flag.
    rec_aa_onehot = F.one_hot(rec_aa_idx[rec_iface_idx].long(), NUM_AA).to(dtype)
    binder_aa_onehot = F.one_hot(binder_aa_idx[binder_iface_idx].long(), NUM_AA).to(dtype)
    rec_chi_iface = rec_chi[rec_iface_idx].to(dtype)
    binder_chi_iface = binder_chi[binder_iface_idx].to(dtype)
    rec_chain_feat = torch.zeros(n_rec, 1, device=device, dtype=dtype)
    binder_chain_feat = torch.ones(n_binder, 1, device=device, dtype=dtype)

    rec_node = torch.cat(
        [rec_aa_onehot, rec_torsion.to(dtype), rec_chi_iface, rec_chain_feat], dim=-1)
    binder_node = torch.cat(
        [binder_aa_onehot, binder_torsion.to(dtype), binder_chi_iface, binder_chain_feat], dim=-1)
    node_feats = torch.cat([rec_node, binder_node], dim=0)
    node_mask_flat = torch.cat([rec_iface_mask, binder_iface_mask], dim=0)

    all_origins = torch.cat([rec_origins, binder_origins], dim=0)
    all_rotations = torch.cat([rec_rotations, binder_rotations], dim=0)

    # Binder indices are offset past the receptor. Cross-chain separation is
    # zeroed in the edge features anyway.
    all_seq_idx = torch.cat([rec_iface_idx, binder_iface_idx + rec_coords.shape[0]], dim=0)
    all_chain_ids = torch.cat([
        torch.zeros(n_rec, device=device, dtype=torch.long),
        torch.ones(n_binder, device=device, dtype=torch.long),
    ], dim=0)

    edge_feats = compute_edge_features_torch(
        all_origins, all_rotations, all_seq_idx, all_chain_ids, node_mask_flat
    )

    return {
        'node_feats': node_feats.unsqueeze(0),
        'edge_feats': edge_feats.unsqueeze(0),
        'node_mask': node_mask_flat.unsqueeze(0),
        'n_rec': n_rec,
        'n_binder': n_binder,
        'rec_iface_idx': rec_iface_idx,
        'binder_iface_idx': binder_iface_idx,
    }
|
|
|
|
| |
|
|
class DifferentiableQTheta:
    """
    Wraps the Q_theta scorer for differentiable scoring w.r.t. binder backbone
    coordinates. Receptor structures are pre-loaded and cached by label.

    Usage:
        dq = DifferentiableQTheta(checkpoint_path, device)
        dq.load_receptor(holo_pdb, chain='A', label='holo')
        dq.load_receptor(apo_pdb, chain='A', label='apo')

        binder_coords = torch.tensor(...)  # [N_binder, 4, 3], requires_grad=True
        score_holo = dq.score(binder_coords, binder_mask, binder_aa_idx, 'holo')
        score_apo = dq.score(binder_coords, binder_mask, binder_aa_idx, 'apo')
        selectivity = score_holo - score_apo
        selectivity.backward()
        # binder_coords.grad now contains dS/dcoords
    """

    @staticmethod
    def _ensure_code_dir_on_path():
        """Put the parent of this file's directory on sys.path so that the
        project-local ``models`` / ``utils`` packages can be imported."""
        import sys
        code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
        if code_dir not in sys.path:
            sys.path.insert(0, code_dir)

    def __init__(self, checkpoint_path, device='cuda:0', esm_dir=None):
        """
        Load the Q_theta checkpoint and move the model to ``device`` in eval mode.

        Args:
            checkpoint_path: checkpoint file containing 'config' and 'model_state'.
            device: torch device for the model and cached receptors.
            esm_dir: root of the ESM embedding files. Defaults to
                $ALLOGEN_ROOT/data/esm2_embeddings (ALLOGEN_ROOT defaults to '.').
        """
        self._ensure_code_dir_on_path()
        from models.scorer import build_model

        self.device = torch.device(device)
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects. Only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        self.config = ckpt['config']
        self.model = build_model(self.config)
        self.model.load_state_dict(ckpt['model_state'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # ESM features are enabled when the config declares a positive esm_dim.
        self.esm_dim = self.config.get('esm_dim', 0)
        self.use_esm = self.esm_dim > 0
        self.esm_dir = esm_dir or os.path.join(
            os.environ.get('ALLOGEN_ROOT', '.'), 'data/esm2_embeddings')

        # label -> dict with 'coords', 'mask', 'aa_idx', 'chi' (+ optional extras)
        self.receptors = {}

    def load_receptor(self, pdb_path, chain='A', label='holo',
                      esm_target=None, esm_key=None):
        """Pre-load and cache a receptor structure, optionally with ESM embeddings.

        Args:
            pdb_path: path to receptor PDB
            chain: chain ID
            label: cache key
            esm_target: target name for ESM dir (e.g. 'abl' for data/esm2_embeddings/abl/)
            esm_key: ESM embedding file key (e.g. '6XR7_A'). If None, derived
                as '<pdb file name without .pdb>_<chain>'.
        """
        self._ensure_code_dir_on_path()
        from utils.pdb_utils import (
            load_structure, get_residues, get_backbone_coords,
            get_aa_indices, compute_chi_angles
        )

        model = load_structure(pdb_path)
        residues = get_residues(model[chain])
        coords, mask = get_backbone_coords(residues)
        aa_idx = get_aa_indices(residues)
        chi = compute_chi_angles(residues, mask)

        rec_data = {
            'coords': torch.from_numpy(coords).float().to(self.device),
            'mask': torch.from_numpy(mask).bool().to(self.device),
            'aa_idx': torch.from_numpy(aa_idx).long().to(self.device),
            'chi': torch.from_numpy(chi).float().to(self.device),
            'residues': residues,
        }

        if self.use_esm:
            n_res = len(residues)
            esm_emb = None
            if esm_target:
                if esm_key is None:
                    pdb_id = os.path.basename(pdb_path).replace('.pdb', '')
                    esm_key = f'{pdb_id}_{chain}'
                esm_path = os.path.join(self.esm_dir, esm_target, f'{esm_key}.pt')
                if os.path.exists(esm_path):
                    esm_emb = torch.load(esm_path, map_location=self.device, weights_only=True)
                    # Align the embedding length with the parsed residues.
                    if esm_emb.shape[0] > n_res:
                        esm_emb = esm_emb[:n_res]
                    elif esm_emb.shape[0] < n_res:
                        pad = esm_emb.new_zeros(n_res - esm_emb.shape[0], esm_emb.shape[1])
                        esm_emb = torch.cat([esm_emb, pad], dim=0)
                    esm_emb = esm_emb.float()
                else:
                    import warnings
                    warnings.warn(
                        f'ESM embedding not found at {esm_path}; using zeros.',
                        stacklevel=2,
                    )
            if esm_emb is None:
                esm_emb = torch.zeros(n_res, self.esm_dim, device=self.device)
            rec_data['esm_emb'] = esm_emb

        self.receptors[label] = rec_data

    def load_receptor_from_coords(self, coords, mask, aa_idx=None, chi=None,
                                  label='path'):
        """
        Load a receptor from raw backbone coords (not from a PDB file).

        Used for interpolated path frames that don't have PDB files.
        If aa_idx is None, uses all-ALA (index 0). If chi is None, uses zeros.

        Args:
            coords: [N, 4, 3] numpy or torch backbone coords (N, CA, C, O).
                Numpy input is converted to float32; torch input keeps its dtype.
            mask: [N] numpy or torch mask (cast to bool)
            aa_idx: [N] numpy or torch int (cast to long; default all-ALA = 0)
            chi: [N, 4] numpy or torch float (default: zeros)
            label: str key for caching
        """
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        # Normalise torch inputs too: an integer mask would break `~mask`
        # indexing downstream.
        mask = torch.as_tensor(mask).bool()

        n_res = coords.shape[0]

        if aa_idx is None:
            aa_idx = torch.zeros(n_res, dtype=torch.long)
        else:
            aa_idx = torch.as_tensor(aa_idx).long()

        if chi is None:
            chi = torch.zeros(n_res, 4, dtype=coords.dtype)
        elif isinstance(chi, np.ndarray):
            chi = torch.from_numpy(chi).float()

        self.receptors[label] = {
            'coords': coords.to(self.device),
            'mask': mask.to(self.device),
            'aa_idx': aa_idx.to(self.device),
            'chi': chi.to(self.device),
        }

    def _zero_score(self, binder_coords):
        """Score used when there is no interface.

        Returns 0. When binder_coords requires grad, the result stays connected
        to it, so backward() yields an all-zero gradient instead of leaving
        binder_coords.grad as None.
        """
        if binder_coords.requires_grad:
            # Sum over an empty slice: value 0, gradient all zeros.
            return binder_coords.reshape(-1)[:0].sum().to(self.device)
        return torch.zeros((), device=self.device, dtype=binder_coords.dtype,
                           requires_grad=True)

    def _interface_esm_feats(self, rec, binder_coords, binder_mask, cutoff,
                             n_rec, n_binder):
        """Gather ESM embeddings for the graph's nodes, receptor first.

        Re-derives the interface indices with the rule used by
        build_differentiable_interface_graph: CA-CA distance < cutoff over
        masked-in residues, lowest indices first, truncated to the graph's
        node counts. Binder embeddings are zeros.
        """
        rec_esm = rec.get('esm_emb')
        if rec_esm is None:
            rec_esm = torch.zeros(rec['coords'].shape[0], self.esm_dim, device=self.device)
        binder_esm = torch.zeros(binder_coords.shape[0], self.esm_dim, device=self.device)

        # Index selection only; no gradient needed.
        with torch.no_grad():
            dist_mat = torch.cdist(rec['coords'][:, 1, :].unsqueeze(0),
                                   binder_coords[:, 1, :].unsqueeze(0)).squeeze(0)
            dist_mat[~rec['mask'], :] = float('inf')
            dist_mat[:, ~binder_mask] = float('inf')
            contact = dist_mat < cutoff
        rec_iface_idx = torch.where(contact.any(dim=1))[0][:n_rec]
        binder_iface_idx = torch.where(contact.any(dim=0))[0][:n_binder]

        esm_combined = torch.cat([rec_esm[rec_iface_idx], binder_esm[binder_iface_idx]], dim=0)
        return esm_combined.unsqueeze(0)

    def score(self, binder_coords, binder_mask, binder_aa_idx=None,
              binder_chi=None, receptor_label='holo', cutoff=8.0):
        """
        Score a binder against a cached receptor. Differentiable w.r.t. binder_coords.

        Args:
            binder_coords: [N_binder, 4, 3] tensor (can have requires_grad=True)
            binder_mask: [N_binder] mask tensor (cast to bool)
            binder_aa_idx: [N_binder] int tensor (default: all index 20, UNK)
            binder_chi: [N_binder, 4] tensor (default: zeros)
            receptor_label: key into cached receptors
            cutoff: interface distance cutoff

        Returns:
            score: scalar tensor from the model, differentiable w.r.t.
                binder_coords (documented as lying in (0, 1)). It is 0 when no
                residue pair lies within ``cutoff``.
        """
        rec = self.receptors[receptor_label]
        binder_mask = binder_mask.bool()
        n_binder_res = binder_coords.shape[0]

        if binder_aa_idx is None:
            binder_aa_idx = torch.full((n_binder_res,), 20, device=self.device, dtype=torch.long)
        if binder_chi is None:
            binder_chi = torch.zeros(n_binder_res, 4, device=self.device, dtype=binder_coords.dtype)

        graph = build_differentiable_interface_graph(
            rec_coords=rec['coords'],
            rec_mask=rec['mask'],
            rec_aa_idx=rec['aa_idx'],
            rec_chi=rec['chi'],
            binder_coords=binder_coords,
            binder_mask=binder_mask,
            binder_aa_idx=binder_aa_idx,
            binder_chi=binder_chi,
            cutoff=cutoff,
        )

        if graph is None:
            return self._zero_score(binder_coords)

        esm_feats = None
        if self.use_esm:
            esm_feats = self._interface_esm_feats(
                rec, binder_coords, binder_mask, cutoff,
                graph['n_rec'], graph['n_binder'])

        score = self.model(graph['node_feats'], graph['edge_feats'], graph['node_mask'],
                           esm_feats=esm_feats)
        return score.squeeze()

    def selectivity_margin(self, binder_coords, binder_mask,
                           binder_aa_idx=None, binder_chi=None,
                           holo_label='holo', apo_label='apo', cutoff=8.0):
        """
        Compute the selectivity margin S = Q(holo, Y) - Q(apo, Y).
        Differentiable w.r.t. binder_coords.

        Returns:
            (S, q_holo, q_apo) as scalar tensors.
        """
        q_holo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
                            holo_label, cutoff)
        q_apo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
                           apo_label, cutoff)
        return q_holo - q_apo, q_holo, q_apo
|
|