# AlloGen/code/models/differentiable_features.py
# Repository page header (user: chq1155)
# AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# Commit: ad9572d
"""
Differentiable feature extraction for Q_theta guidance.
This module re-implements the key feature extraction functions from features.py
and pdb_utils.py using PyTorch operations, enabling gradient computation through
Q_theta with respect to backbone coordinates.
The differentiable path:
coords (N,4,3) → backbone frames → torsions, distances, directions, rotations
→ node_feats, edge_feats → Q_theta → score → backward() → ∇coords
Non-differentiable features (AA one-hot, chain_id, seq_sep, same_chain) are
treated as constants.
"""
import os
import torch
import torch.nn.functional as F
import numpy as np
# ── Differentiable backbone frame computation ────────────────────────────────
def compute_backbone_frames_torch(coords, mask):
    """
    Compute per-residue backbone frames from the N, CA and C atoms.

    Differentiable w.r.t. coords.

    Frame construction (Gram-Schmidt on the two bonds leaving CA):
        z = unit(C - CA)
        y = unit((N - CA) - ((N - CA) . z) z)
        x = y cross z

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O) — requires_grad=True for binder
        mask: [N] bool tensor; False marks residues whose frame must not be used

    Returns:
        origins: [N, 3] = CA positions, returned exactly as stored in coords
            (also for masked residues)
        rotations: [N, 3, 3] = rotation matrices (columns are x, y, z axes);
            identity for masked residues
    """
    valid = mask.bool()[:, None, None]  # [N, 1, 1]
    origins = coords[:, 1, :]  # CA positions [N, 3]

    # Zero the coordinates of masked residues *before* any arithmetic. Masking
    # by multiplication afterwards (value * 0) would keep NaN/inf placeholders
    # alive in both the forward result and the gradients.
    safe = torch.where(valid, coords, torch.zeros_like(coords))
    n_atom = safe[:, 0, :]  # [N, 3]
    ca = safe[:, 1, :]      # [N, 3]
    c_atom = safe[:, 2, :]  # [N, 3]

    # z-axis: CA -> C
    z = c_atom - ca
    z = z / torch.norm(z, dim=-1, keepdim=True).clamp(min=1e-6)

    # y-axis: CA -> N with its z-component removed
    y = n_atom - ca
    y = y - (y * z).sum(dim=-1, keepdim=True) * z
    y = y / torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-6)

    # x-axis completes the right-handed triad
    x = torch.cross(y, z, dim=-1)

    # Columns of each matrix are the x, y, z axes.
    rot = torch.stack([x, y, z], dim=-1)  # [N, 3, 3]

    eye = torch.eye(3, device=coords.device, dtype=coords.dtype).expand_as(rot)
    rotations = torch.where(valid, rot, eye)
    return origins, rotations
# ── Differentiable torsion angle computation ─────────────────────────────────
def _dihedral_torch(p0, p1, p2, p3):
    """
    Batched dihedral angle of the point quadruples (p0, p1, p2, p3),
    returned as a (sin, cos) pair so the result stays smooth across +/-180 deg.

    Differentiable w.r.t. all inputs.

    Args:
        p0, p1, p2, p3: [N, 3] tensors

    Returns:
        sin_angle: [N]
        cos_angle: [N]
    """
    def _unit(vec):
        # Clamp the length so degenerate (zero-length) vectors divide safely.
        return vec / vec.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    bond_a = p1 - p0  # [N, 3]
    bond_b = p2 - p1
    bond_c = p3 - p2

    # Unit normals of the planes (p0, p1, p2) and (p1, p2, p3).
    normal_ab = _unit(torch.cross(bond_a, bond_b, dim=-1))
    normal_bc = _unit(torch.cross(bond_b, bond_c, dim=-1))

    # Third axis of the frame (normal_ab, in_plane, bond_b); projecting
    # normal_bc onto it yields the sine component.
    in_plane = torch.cross(normal_ab, _unit(bond_b), dim=-1)

    sin_angle = (in_plane * normal_bc).sum(dim=-1)
    cos_angle = (normal_ab * normal_bc).sum(dim=-1)
    return sin_angle, cos_angle
def compute_torsion_angles_torch(coords, mask):
    """
    Compute backbone torsion angles (phi, psi, omega) as sin/cos pairs.

    Differentiable w.r.t. coords.

    Row i and row i+1 of ``coords`` are treated as consecutive residues of one
    chain: every dihedral mixes atoms of two adjacent rows. Pass a whole chain
    and index the result afterwards; calling this on an arbitrary subset of
    residues produces dihedrals between unrelated residues.

    Definitions:
        phi_i   = C_{i-1} - N_i    - CA_i - C_i       (stored at row i,   i >= 1)
        psi_i   = N_i     - CA_i   - C_i  - N_{i+1}   (stored at row i,   i < N-1)
        omega_i = CA_{i-1} - C_{i-1} - N_i - CA_i     (stored at row i,   i >= 1)

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O)
        mask: [N] bool tensor

    Returns:
        torsions: [N, 6] (sin_phi, cos_phi, sin_psi, cos_psi, sin_omega, cos_omega).
            Entries whose dihedral touches a masked residue, and the chain-end
            entries that have no neighbour, are 0.
    """
    n_res = coords.shape[0]
    torsions = torch.zeros(n_res, 6, device=coords.device, dtype=coords.dtype)
    if n_res < 2:
        return torsions

    valid = mask.bool()
    # Zero masked residues up front: multiplying a NaN dihedral by 0 afterwards
    # would still give NaN, and would poison the gradient of the valid
    # neighbour that shares the dihedral.
    safe = torch.where(valid[:, None, None], coords, torch.zeros_like(coords))
    n_atoms = safe[:, 0, :]   # [N, 3]
    ca_atoms = safe[:, 1, :]
    c_atoms = safe[:, 2, :]

    # All three torsions couple residue i with residue i+1, so one pair mask
    # serves phi, psi and omega alike.
    pair_ok = (valid[:-1] & valid[1:]).to(coords.dtype)  # [N-1]

    sin_phi, cos_phi = _dihedral_torch(
        c_atoms[:-1], n_atoms[1:], ca_atoms[1:], c_atoms[1:]
    )
    torsions[1:, 0] = sin_phi * pair_ok
    torsions[1:, 1] = cos_phi * pair_ok

    sin_psi, cos_psi = _dihedral_torch(
        n_atoms[:-1], ca_atoms[:-1], c_atoms[:-1], n_atoms[1:]
    )
    torsions[:-1, 2] = sin_psi * pair_ok
    torsions[:-1, 3] = cos_psi * pair_ok

    sin_omega, cos_omega = _dihedral_torch(
        ca_atoms[:-1], c_atoms[:-1], n_atoms[1:], ca_atoms[1:]
    )
    torsions[1:, 4] = sin_omega * pair_ok
    torsions[1:, 5] = cos_omega * pair_ok

    return torsions
# ── Differentiable RBF distance encoding ─────────────────────────────────────
def rbf_encode_torch(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """
    Expand distances onto a bank of Gaussian radial basis functions.

    The n_bins Gaussian centres are evenly spaced on [d_min, d_max] and each
    Gaussian's width equals the spacing between neighbouring centres.
    Differentiable w.r.t. distances.

    Args:
        distances: [...] tensor
        d_min: position of the first centre
        d_max: position of the last centre
        n_bins: number of basis functions

    Returns:
        encoded: [..., n_bins] tensor
    """
    mu = torch.linspace(
        d_min, d_max, n_bins, device=distances.device, dtype=distances.dtype
    )
    width = (d_max - d_min) / (n_bins - 1)
    offset = distances.unsqueeze(-1) - mu  # [..., n_bins]
    return torch.exp(-(offset ** 2) / (2 * width ** 2))
# ── Differentiable edge feature computation ──────────────────────────────────
def compute_edge_features_torch(origins, rotations, seq_idx, chain_ids, mask,
                                n_bins_rbf=16, n_bins_sep=8, max_sep=32):
    """
    Assemble the dense pairwise edge-feature tensor for a set of residues.

    Differentiable w.r.t. origins and rotations (which derive from coords);
    seq_idx and chain_ids only feed constant one-hot / indicator channels.

    Channel layout along the last axis (defaults: 16 + 3 + 9 + 8 + 1 = 37):
        n_bins_rbf  RBF encoding of the origin-origin distance
        3           unit displacement between the two origins, multiplied by
                    rotations[i]
        9           flattened product of rotations[i] and rotations[j]
        n_bins_sep  one-hot bin of seq_idx[i] - seq_idx[j], clipped to
                    [-max_sep, max_sep]; all zeros for cross-chain pairs
        1           same-chain indicator

    Args:
        origins: [N, 3] CA positions
        rotations: [N, 3, 3] backbone frame rotations
        seq_idx: [N] int tensor — sequence indices (non-differentiable)
        chain_ids: [N] int tensor — chain labels (non-differentiable)
        mask: [N] bool tensor
        n_bins_rbf: number of distance RBF channels
        n_bins_sep: number of sequence-separation bins
        max_sep: clipping bound for the sequence separation

    Returns:
        edge_feats: [N, N, 37] with the default bin counts; every edge that
            touches a masked residue is multiplied by zero
    """
    n_nodes = origins.shape[0]

    # ── Distance channels ──
    # delta[i, j] = origins[i] - origins[j]
    delta = origins[:, None, :] - origins[None, :, :]       # [N, N, 3]
    # The clamp keeps the i == j division below finite.
    pair_dist = delta.norm(dim=-1).clamp(min=1e-8)          # [N, N]
    dist_channels = rbf_encode_torch(
        pair_dist, d_min=0., d_max=20., n_bins=n_bins_rbf
    )

    # ── Direction / relative-orientation channels ──
    # NOTE(review): both einsums contract the *second* matrix index of
    # rotations[i], i.e. they evaluate rotations[i] @ v and
    # rotations[i] @ rotations[j]. With axes stored as matrix columns, a
    # local-frame projection would use the transpose of rotations[i].
    # Presumably this mirrors the non-differentiable reference in features.py
    # that the scorer was trained with — confirm before changing either side.
    direction = delta / pair_dist[..., None]                # [N, N, 3]
    dir_channels = torch.einsum('ikl,ijl->ijk', rotations, direction)
    rot_channels = torch.einsum(
        'ikl,jlm->ijkm', rotations, rotations
    ).reshape(n_nodes, n_nodes, 9)

    # ── Sequence-separation channels (constant) ──
    bin_edges = torch.linspace(
        -max_sep, max_sep, n_bins_sep + 1, device=origins.device
    )
    offset = (seq_idx[:, None] - seq_idx[None, :]).float().clamp(-max_sep, max_sep)
    # NOTE(review): torch.bucketize defaults to right=False, which places a
    # value equal to a bin edge in the *lower* bin — the opposite of
    # numpy.digitize's default. TODO confirm this matches the reference binning.
    bucket = (torch.bucketize(offset, bin_edges) - 1).clamp(0, n_bins_sep - 1)
    same_chain = chain_ids[:, None] == chain_ids[None, :]   # [N, N]
    sep_channels = F.one_hot(bucket.long(), n_bins_sep).to(origins.dtype)
    # Cross-chain pairs carry no separation information: blank their one-hot.
    sep_channels = sep_channels * same_chain[..., None].to(origins.dtype)

    # ── Same-chain indicator (constant) ──
    chain_channel = same_chain.float()[..., None]           # [N, N, 1]

    stacked = torch.cat(
        [dist_channels, dir_channels, rot_channels, sep_channels, chain_channel],
        dim=-1,
    )

    # Silence every edge with at least one masked endpoint.
    pair_valid = mask[:, None] & mask[None, :]              # [N, N]
    return stacked * pair_valid[..., None].float()
# ── Full differentiable interface graph builder ──────────────────────────────
def _interface_torsions(coords, mask, iface_idx):
    """Torsion features of the selected residues, taken from full-chain torsions."""
    valid = mask.bool()
    # Masked residues are zeroed so placeholder values cannot contaminate the
    # dihedrals (or gradients) of the valid residues next to them.
    safe = torch.where(valid[:, None, None], coords, torch.zeros_like(coords))
    return compute_torsion_angles_torch(safe, valid)[iface_idx]


def build_differentiable_interface_graph(
    rec_coords, rec_mask, rec_aa_idx, rec_chi,
    binder_coords, binder_mask, binder_aa_idx, binder_chi,
    cutoff=8.0, max_nodes=128
):
    """
    Build the receptor/binder interface graph with features that are
    differentiable w.r.t. binder_coords.

    Interface residues are those with at least one unmasked CA of the other
    chain closer than ``cutoff``. The selection itself is a hard threshold and
    carries no gradient. Receptor tensors are used as given (nothing is
    detached here); callers are expected to pass them without requires_grad.

    Backbone torsions are computed on each *full* chain and indexed afterwards,
    because phi/psi/omega couple a residue with its sequence neighbours, which
    are generally not part of the interface subset.

    Args:
        rec_coords: [N_rec, 4, 3] — receptor backbone coords (constant, no grad)
        rec_mask: [N_rec] bool
        rec_aa_idx: [N_rec] int — amino acid indices (constant)
        rec_chi: [N_rec, 4] — chi1/chi2 sin/cos (constant)
        binder_coords: [N_binder, 4, 3] — binder backbone coords (requires_grad)
        binder_mask: [N_binder] bool
        binder_aa_idx: [N_binder] int — amino acid indices (constant, UNK for designed)
        binder_chi: [N_binder, 4] — chi1/chi2 sin/cos (zeros for backbone-only)
        cutoff: interface distance cutoff (Å)
        max_nodes: node budget of the graph; each chain keeps at most
            max_nodes // 2 interface residues (the lowest indices win)

    Returns:
        dict with
            'node_feats': [1, N_total, 32] tensor
                (AA one-hot 21 + torsions 6 + chi 4 + chain flag 1)
            'edge_feats': [1, N_total, N_total, 37] tensor
            'node_mask':  [1, N_total] bool tensor
            'n_rec':      int, receptor nodes (they come first)
            'n_binder':   int, binder nodes
        or None if no residue pair is within the cutoff
    """
    device = binder_coords.device
    dtype = binder_coords.dtype
    NUM_AA = 21
    per_chain_cap = max_nodes // 2

    # ── Find interface residues (hard threshold, so no graph is needed) ──
    with torch.no_grad():
        dist_mat = torch.cdist(
            rec_coords[:, 1, :].unsqueeze(0), binder_coords[:, 1, :].unsqueeze(0)
        ).squeeze(0)  # [N_rec, N_binder]
        # Masked residues can never be part of the interface.
        dist_mat[~rec_mask, :] = float('inf')
        dist_mat[:, ~binder_mask] = float('inf')
        close = dist_mat < cutoff
        rec_iface_idx = torch.where(close.any(dim=1))[0][:per_chain_cap]
        binder_iface_idx = torch.where(close.any(dim=0))[0][:per_chain_cap]

    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    if n_rec + n_binder == 0:
        return None

    # ── Extract interface subsets ──
    rec_iface_coords = rec_coords[rec_iface_idx]            # [n_rec, 4, 3]
    binder_iface_coords = binder_coords[binder_iface_idx]   # [n_binder, 4, 3]
    rec_iface_mask = rec_mask[rec_iface_idx]
    binder_iface_mask = binder_mask[binder_iface_idx]

    # ── Backbone frames (per-residue quantity, safe to compute on the subset) ──
    rec_origins, rec_rotations = compute_backbone_frames_torch(
        rec_iface_coords, rec_iface_mask
    )
    binder_origins, binder_rotations = compute_backbone_frames_torch(
        binder_iface_coords, binder_iface_mask
    )

    # ── Torsion angles (need the true chain neighbours, see docstring) ──
    rec_torsion = _interface_torsions(rec_coords, rec_mask, rec_iface_idx)              # [n_rec, 6]
    binder_torsion = _interface_torsions(binder_coords, binder_mask, binder_iface_idx)  # [n_binder, 6]

    # ── Node features: [AA(21) + torsions(6) + chi(4) + chain(1)] = 32 ──
    rec_node = torch.cat([
        F.one_hot(rec_aa_idx[rec_iface_idx].long(), NUM_AA).float(),
        rec_torsion,
        rec_chi[rec_iface_idx],
        torch.zeros(n_rec, 1, device=device, dtype=dtype),      # chain flag: receptor = 0
    ], dim=-1)
    binder_node = torch.cat([
        F.one_hot(binder_aa_idx[binder_iface_idx].long(), NUM_AA).float(),
        binder_torsion,
        binder_chi[binder_iface_idx],
        torch.ones(n_binder, 1, device=device, dtype=dtype),    # chain flag: binder = 1
    ], dim=-1)
    node_feats = torch.cat([rec_node, binder_node], dim=0)                   # [N_total, 32]
    node_mask_flat = torch.cat([rec_iface_mask, binder_iface_mask], dim=0)   # [N_total]

    # ── Edge features (differentiable) ──
    all_origins = torch.cat([rec_origins, binder_origins], dim=0)        # [N_total, 3]
    all_rotations = torch.cat([rec_rotations, binder_rotations], dim=0)  # [N_total, 3, 3]

    # Binder indices are offset by the receptor length so the two chains
    # occupy disjoint index ranges.
    all_seq_idx = torch.cat(
        [rec_iface_idx, binder_iface_idx + rec_coords.shape[0]], dim=0
    )
    all_chain_ids = torch.cat([
        torch.zeros(n_rec, device=device, dtype=torch.long),
        torch.ones(n_binder, device=device, dtype=torch.long),
    ], dim=0)

    edge_feats = compute_edge_features_torch(
        all_origins, all_rotations, all_seq_idx, all_chain_ids, node_mask_flat
    )  # [N_total, N_total, 37]

    # Add batch dimension
    return {
        'node_feats': node_feats.unsqueeze(0),      # [1, N, 32]
        'edge_feats': edge_feats.unsqueeze(0),      # [1, N, N, 37]
        'node_mask': node_mask_flat.unsqueeze(0),   # [1, N]
        'n_rec': n_rec,
        'n_binder': n_binder,
    }
# ── Differentiable Q_theta scoring function ──────────────────────────────────
class DifferentiableQTheta:
    """
    Wraps the Q_theta scorer for differentiable scoring w.r.t. binder backbone
    coordinates. Receptor structures are pre-loaded and cached.

    Usage:
        dq = DifferentiableQTheta(checkpoint_path, device)
        dq.load_receptor(holo_pdb, chain='A', label='holo')
        dq.load_receptor(apo_pdb, chain='A', label='apo')
        binder_coords = torch.tensor(...)  # [N_binder, 4, 3], requires_grad=True
        score_holo = dq.score(binder_coords, binder_mask, binder_aa_idx,
                              receptor_label='holo')
        score_apo = dq.score(binder_coords, binder_mask, binder_aa_idx,
                             receptor_label='apo')
        selectivity = score_holo - score_apo
        selectivity.backward()
        # binder_coords.grad now contains dS/dcoords

    Note that the fourth positional argument of score() is ``binder_chi``;
    always pass the receptor label by keyword.
    """

    @staticmethod
    def _ensure_code_dir_on_path():
        """Put the parent of this file's directory on sys.path so that the
        `models.*` and `utils.*` imports below resolve."""
        import sys
        code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
        if code_dir not in sys.path:
            sys.path.insert(0, code_dir)

    @staticmethod
    def _zero_score(binder_coords, binder_mask):
        """Scalar zero used when receptor and binder share no interface.

        The zero is derived from binder_coords so that backward() leaves an
        all-zero ``binder_coords.grad`` instead of ``None``. If binder_coords
        does not track gradients, a fresh leaf with requires_grad=True is
        returned so that calling backward() on the result still works.
        """
        zero = (binder_coords[binder_mask.bool()] * 0.0).sum()
        if not zero.requires_grad:
            zero = torch.zeros(
                1, device=binder_coords.device, dtype=binder_coords.dtype,
                requires_grad=True,
            ).squeeze()
        return zero

    @staticmethod
    def _interface_indices(rec_coords, rec_mask, binder_coords, binder_mask,
                           cutoff, n_rec, n_binder):
        """Indices of the interface residues of both chains.

        Repeats the CA-distance selection of
        build_differentiable_interface_graph (which does not return its
        indices) and truncates to the node counts that graph reported.
        Pure index computation, hence no gradient tracking.
        """
        with torch.no_grad():
            dist = torch.cdist(
                rec_coords[:, 1, :].unsqueeze(0), binder_coords[:, 1, :].unsqueeze(0)
            ).squeeze(0)
            dist[~rec_mask, :] = float('inf')
            dist[:, ~binder_mask] = float('inf')
            close = dist < cutoff
            rec_idx = torch.where(close.any(dim=1))[0][:n_rec]
            binder_idx = torch.where(close.any(dim=0))[0][:n_binder]
        return rec_idx, binder_idx

    def __init__(self, checkpoint_path, device='cuda:0', esm_dir=None):
        """
        Load the scorer checkpoint and prepare an empty receptor cache.

        Args:
            checkpoint_path: file readable by torch.load holding the keys
                'config' and 'model_state'
            device: device the model and all cached receptors live on
            esm_dir: root directory of ESM embeddings; defaults to
                $ALLOGEN_ROOT/data/esm2_embeddings ('.' if the variable is unset)
        """
        self._ensure_code_dir_on_path()
        from models.scorer import build_model

        self.device = torch.device(device)
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        self.config = ckpt['config']
        self.model = build_model(self.config)
        self.model.load_state_dict(ckpt['model_state'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # ESM feature support
        self.esm_dim = self.config.get('esm_dim', 0)
        self.use_esm = self.esm_dim > 0
        self.esm_dir = esm_dir or os.path.join(
            os.environ.get('ALLOGEN_ROOT', '.'), 'data/esm2_embeddings'
        )

        # Cache receptor data
        self.receptors = {}  # label -> {coords, mask, aa_idx, chi, esm_emb?}

    def load_receptor(self, pdb_path, chain='A', label='holo',
                      esm_target=None, esm_key=None):
        """Pre-load and cache receptor structure, optionally with ESM embeddings.

        Args:
            pdb_path: path to receptor PDB
            chain: chain ID
            label: cache key
            esm_target: target name for ESM dir (e.g., 'abl' for data/esm2_embeddings/abl/)
            esm_key: ESM embedding file key (e.g., '6XR7_A'). If None, it is
                derived as '<pdb file name without .pdb>_<chain>'.
        """
        self._ensure_code_dir_on_path()
        from utils.pdb_utils import (
            load_structure, get_residues, get_backbone_coords,
            get_aa_indices, compute_chi_angles
        )

        structure = load_structure(pdb_path)
        residues = get_residues(structure[chain])
        # The helpers are assumed to return numpy arrays (they are fed to
        # torch.from_numpy below) — shapes per utils.pdb_utils.
        coords, mask = get_backbone_coords(residues)
        aa_idx = get_aa_indices(residues)
        chi = compute_chi_angles(residues, mask)

        rec_data = {
            'coords': torch.from_numpy(coords).float().to(self.device),
            'mask': torch.from_numpy(mask).bool().to(self.device),
            'aa_idx': torch.from_numpy(aa_idx).long().to(self.device),
            'chi': torch.from_numpy(chi).float().to(self.device),
            'residues': residues,
        }

        # Load ESM embeddings if model uses ESM
        if self.use_esm and esm_target:
            if esm_key is None:
                pdb_id = os.path.basename(pdb_path).replace('.pdb', '')
                esm_key = f'{pdb_id}_{chain}'
            esm_path = os.path.join(self.esm_dir, esm_target, f'{esm_key}.pt')
            n_res = len(residues)
            if os.path.exists(esm_path):
                esm_emb = torch.load(esm_path, map_location=self.device, weights_only=True)
                # Truncate / zero-pad along the residue axis to match the structure.
                missing = n_res - esm_emb.shape[0]
                if missing < 0:
                    esm_emb = esm_emb[:n_res]
                elif missing > 0:
                    # new_zeros inherits dtype and device from the embedding,
                    # so the concatenation never mixes dtypes.
                    pad = esm_emb.new_zeros((missing, esm_emb.shape[1]))
                    esm_emb = torch.cat([esm_emb, pad], dim=0)
                rec_data['esm_emb'] = esm_emb.float()
            else:
                # Embedding file missing: fall back to all-zero embeddings.
                rec_data['esm_emb'] = torch.zeros(n_res, self.esm_dim,
                                                  device=self.device)

        self.receptors[label] = rec_data

    def load_receptor_from_coords(self, coords, mask, aa_idx=None, chi=None,
                                  label='path'):
        """
        Load a receptor from raw backbone coords (not from PDB file).
        Used for interpolated path frames that don't have PDB files.
        If aa_idx is None, uses index 0 for every residue (all-ALA).
        If chi is None, uses zeros. No ESM embedding is stored.

        Args:
            coords: [N, 4, 3] numpy or torch backbone coords (N, CA, C, O)
            mask: [N] numpy or torch bool
            aa_idx: [N] numpy or torch int (default: all-ALA = 0)
            chi: [N, 4] numpy or torch float (default: zeros)
            label: str key for caching
        """
        # Convert numpy to torch if needed
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask).bool()

        n_res = coords.shape[0]
        if aa_idx is None:
            aa_idx = torch.zeros(n_res, dtype=torch.long)  # all-ALA
        elif isinstance(aa_idx, np.ndarray):
            aa_idx = torch.from_numpy(aa_idx).long()
        if chi is None:
            chi = torch.zeros(n_res, 4, dtype=coords.dtype)
        elif isinstance(chi, np.ndarray):
            chi = torch.from_numpy(chi).float()

        self.receptors[label] = {
            'coords': coords.to(self.device),
            'mask': mask.to(self.device),
            'aa_idx': aa_idx.to(self.device),
            'chi': chi.to(self.device),
        }

    def score(self, binder_coords, binder_mask, binder_aa_idx=None,
              binder_chi=None, receptor_label='holo', cutoff=8.0):
        """
        Score binder against a cached receptor. Differentiable w.r.t. binder_coords.

        Args:
            binder_coords: [N_binder, 4, 3] tensor (can have requires_grad=True)
            binder_mask: [N_binder] bool tensor
            binder_aa_idx: [N_binder] int tensor (default: all UNK = index 20)
            binder_chi: [N_binder, 4] tensor (default: zeros)
            receptor_label: key into cached receptors
            cutoff: interface distance cutoff

        Returns:
            score: scalar tensor, differentiable w.r.t. binder_coords. When the
                two chains share no interface, a zero is returned whose
                backward pass yields an all-zero gradient for binder_coords.

        Raises:
            KeyError: if no receptor was cached under receptor_label.
        """
        try:
            rec = self.receptors[receptor_label]
        except KeyError:
            raise KeyError(
                f"receptor {receptor_label!r} is not loaded "
                f"(available: {list(self.receptors)}); call load_receptor() or "
                f"load_receptor_from_coords() first"
            ) from None

        n_binder_res = binder_coords.shape[0]
        if binder_aa_idx is None:
            binder_aa_idx = torch.full((n_binder_res,), 20, device=self.device,
                                       dtype=torch.long)  # UNK
        if binder_chi is None:
            binder_chi = torch.zeros(n_binder_res, 4, device=self.device,
                                     dtype=binder_coords.dtype)

        graph = build_differentiable_interface_graph(
            rec_coords=rec['coords'],
            rec_mask=rec['mask'],
            rec_aa_idx=rec['aa_idx'],
            rec_chi=rec['chi'],
            binder_coords=binder_coords,
            binder_mask=binder_mask,
            binder_aa_idx=binder_aa_idx,
            binder_chi=binder_chi,
            cutoff=cutoff,
        )
        if graph is None:
            # No interface — zero score that still back-propagates to binder_coords.
            return self._zero_score(binder_coords, binder_mask)

        esm_feats = None
        if self.use_esm:
            # Receptor ESM: cached embedding if present, else zeros.
            rec_esm_full = rec.get('esm_emb')
            if rec_esm_full is None:
                rec_esm_full = torch.zeros(rec['coords'].shape[0], self.esm_dim,
                                           device=self.device)
            # Binder ESM: zeros (designed backbone, no sequence).
            binder_esm = torch.zeros(n_binder_res, self.esm_dim, device=self.device)

            # Select the same residues, in the same order, as the graph nodes.
            rec_idx, binder_idx = self._interface_indices(
                rec['coords'], rec['mask'], binder_coords, binder_mask,
                cutoff, graph['n_rec'], graph['n_binder'],
            )
            esm_feats = torch.cat(
                [rec_esm_full[rec_idx], binder_esm[binder_idx]], dim=0
            ).unsqueeze(0)  # [1, n_total, esm_dim]

        score = self.model(graph['node_feats'], graph['edge_feats'], graph['node_mask'],
                           esm_feats=esm_feats)
        return score.squeeze()  # scalar

    def selectivity_margin(self, binder_coords, binder_mask,
                           binder_aa_idx=None, binder_chi=None,
                           holo_label='holo', apo_label='apo', cutoff=8.0):
        """
        Compute selectivity margin S = Q(holo, Y) - Q(apo, Y).
        Differentiable w.r.t. binder_coords.

        Returns:
            (S, Q(holo, Y), Q(apo, Y)) as scalar tensors.
        """
        q_holo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
                            holo_label, cutoff)
        q_apo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
                           apo_label, cutoff)
        return q_holo - q_apo, q_holo, q_apo