# AlloGen/code/utils/pdb_utils.py
# chq1155 -- AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# (commit ad9572d)
"""
PDB parsing utilities for Allo-Designer.
Extracts backbone geometry, computes local frames, and identifies interface residues.
"""
import numpy as np
from Bio import PDB
from Bio.PDB import PDBParser, MMCIFParser, PDBIO
from Bio.PDB.Polypeptide import is_aa
import warnings
warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)
# Integer index for each three-letter residue code. The 20 standard amino
# acids come first in the order listed; 'UNK' takes the final slot.
AA3_TO_IDX = {
    code: position
    for position, code in enumerate((
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
        'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO',
        'SER', 'THR', 'TRP', 'TYR', 'VAL',
        'UNK',
    ))
}
NUM_AA = len(AA3_TO_IDX)  # 21 = 20 standard + UNK
def load_structure(pdb_path: str, model_id: int = 0):
    """Parse a PDB or mmCIF file and return one of its models.

    Args:
        pdb_path: Location of the structure file, as a string or any
            ``os.PathLike``. Names ending in ``.cif`` or ``.mmcif``
            (compared case-insensitively) are read with ``MMCIFParser``;
            every other name is read with ``PDBParser``.
        model_id: Zero-based position of the wanted model in file order.
            The default, 0, selects the first model.

    Returns:
        The model object found at position ``model_id``.

    Raises:
        IndexError: If the file holds no model at position ``model_id``.
    """
    import os

    # Normalise once so pathlib.Path and str inputs behave the same.
    path = os.fspath(pdb_path)
    is_mmcif = path.lower().endswith(('.cif', '.mmcif'))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)
    struct = parser.get_structure("protein", path)
    models = list(struct.get_models())
    try:
        return models[model_id]
    except IndexError:
        raise IndexError(
            f"model_id {model_id} is out of range: "
            f"{path!r} contains {len(models)} model(s)"
        ) from None
def get_residues(chain, only_standard: bool = True):
    """Collect the non-hetero residues of a chain, preserving chain order.

    Args:
        chain: Object whose ``get_residues()`` yields residue objects.
        only_standard: When true, residues rejected by
            ``is_aa(res, standard=True)`` are dropped as well.

    Returns:
        A list of the residues that passed every filter.
    """
    def keep(res):
        if only_standard and not is_aa(res, standard=True):
            return False
        # A non-blank first id field marks a HETATM record; those are
        # always skipped, even when only_standard is False.
        return res.get_id()[0] == ' '

    return [res for res in chain.get_residues() if keep(res)]
def get_backbone_coords(residues):
    """Gather N, CA, C and O coordinates for every residue.

    Args:
        residues: Sequence of residues that support ``res[name]`` atom
            lookup (raising ``KeyError`` when absent) and ``name in res``.

    Returns:
        coords: float32 array [N_res, 4, 3]; atom slots are N, CA, C, O.
        mask: bool array [N_res]; True when N, CA and C were all found.

    A missing O does not invalidate a residue: its slot receives a copy of
    the C coordinate. When N, CA or C is missing the residue is masked out;
    atoms read before the missing one stay in ``coords``, the rest stay zero.
    """
    n_res = len(residues)
    coords = np.zeros((n_res, 4, 3), dtype=np.float32)
    mask = np.zeros(n_res, dtype=bool)

    def xyz(residue, atom_name):
        return residue[atom_name].get_vector().get_array()

    for idx, residue in enumerate(residues):
        slots = coords[idx]  # view: writes land directly in `coords`
        try:
            for slot, atom_name in enumerate(('N', 'CA', 'C')):
                slots[slot] = xyz(residue, atom_name)
            # No real estimate is made for a missing O; C is reused.
            slots[3] = xyz(residue, 'O') if 'O' in residue else slots[2]
        except KeyError:
            continue
        mask[idx] = True
    return coords, mask
def get_aa_indices(residues):
    """Translate residue names into integer amino-acid indices.

    Each residue's ``get_resname()`` is looked up in ``AA3_TO_IDX``; names
    not present there map to the 'UNK' index.

    Returns:
        int64 array with one entry per residue.
    """
    unknown_idx = AA3_TO_IDX['UNK']
    names = (res.get_resname() for res in residues)
    return np.array(
        [AA3_TO_IDX.get(name, unknown_idx) for name in names],
        dtype=np.int64,
    )
def compute_backbone_frames(coords, mask):
    """Build a local orthonormal frame per residue from its N, CA, C atoms.

    Axes: z is the unit vector CA->C; y is the part of CA->N perpendicular
    to z, normalised; x = y cross z.

    Args:
        coords: array [N, >=3, 3] with atom slots 0, 1, 2 = N, CA, C.
        mask: per-residue validity flags, indexable by residue position.

    Returns:
        origins: [N, 3] CA positions (a slice of ``coords``, not a copy).
        rotations: float32 [N, 3, 3]; columns are the x, y, z axes.
            Masked-out residues and degenerate geometries (an axis norm
            below 1e-6) receive the identity matrix.
    """
    eps = 1e-6

    def residue_frame(n_xyz, ca_xyz, c_xyz):
        """Return the 3x3 frame, or None when an axis cannot be defined."""
        z_axis = c_xyz - ca_xyz
        z_len = np.linalg.norm(z_axis)
        if z_len < eps:
            return None
        z_axis = z_axis / z_len

        # Gram-Schmidt: strip the z component from CA->N.
        y_axis = n_xyz - ca_xyz
        y_axis = y_axis - np.dot(y_axis, z_axis) * z_axis
        y_len = np.linalg.norm(y_axis)
        if y_len < eps:
            return None
        y_axis = y_axis / y_len

        x_axis = np.cross(y_axis, z_axis)
        return np.stack([x_axis, y_axis, z_axis], axis=-1)

    n_res = coords.shape[0]
    origins = coords[:, 1, :]
    rotations = np.zeros((n_res, 3, 3), dtype=np.float32)
    for idx in range(n_res):
        frame = None
        if mask[idx]:
            frame = residue_frame(coords[idx, 0], coords[idx, 1], coords[idx, 2])
        rotations[idx] = np.eye(3) if frame is None else frame
    return origins, rotations
def compute_torsion_angles(coords, mask):
    """Compute sin/cos features of the backbone torsions phi, psi, omega.

    Args:
        coords: array [N, >=3, 3] with atom slots 0, 1, 2 = N, CA, C.
        mask: per-residue validity flags, indexable by residue position.

    Returns:
        float32 array [N, 6] laid out as
        (sin phi, cos phi, sin psi, cos psi, sin omega, cos omega).

    Row i uses:
        phi:   C(i-1) - N(i)  - CA(i) - C(i)
        psi:   N(i)   - CA(i) - C(i)  - N(i+1)
        omega: CA(i-1) - C(i-1) - N(i) - CA(i)
    An angle is filled only when residue i and the neighbour it needs are
    both valid; otherwise its two columns stay 0. Neighbours are taken by
    array position only.

    NOTE(review): for p0=(1,0,0), p1=(0,0,0), p2=(0,0,1), p3=(0,1,1) the
    inner dihedral evaluates to -pi/2, which looks like the opposite sign
    to the IUPAC convention. Left unchanged because downstream consumers
    may depend on it -- confirm.
    """
    def dihedral(p0, p1, p2, p3):
        """Signed dihedral (radians) of four points; 0.0 when degenerate."""
        bond_a = p1 - p0
        bond_b = p2 - p1
        bond_c = p3 - p2
        normal_ab = np.cross(bond_a, bond_b)
        normal_bc = np.cross(bond_b, bond_c)
        len_ab = np.linalg.norm(normal_ab)
        len_bc = np.linalg.norm(normal_bc)
        # Collinear or coincident points leave a plane normal undefined.
        if len_ab < 1e-6 or len_bc < 1e-6:
            return 0.0
        normal_ab = normal_ab / len_ab
        normal_bc = normal_bc / len_bc
        in_plane = np.cross(normal_ab, bond_b / (np.linalg.norm(bond_b) + 1e-8))
        cos_a = np.clip(np.dot(normal_ab, normal_bc), -1, 1)
        sin_a = np.dot(in_plane, normal_bc)
        return np.arctan2(sin_a, cos_a)

    def sin_cos(angle):
        return np.sin(angle), np.cos(angle)

    n_res = len(coords)
    angles = np.zeros((n_res, 6), dtype=np.float32)
    last = n_res - 1
    for idx in range(n_res):
        if not mask[idx]:
            continue
        n_i, ca_i, c_i = coords[idx, 0], coords[idx, 1], coords[idx, 2]

        # phi and omega both need the preceding residue.
        if idx > 0 and mask[idx - 1]:
            ca_prev, c_prev = coords[idx - 1, 1], coords[idx - 1, 2]
            angles[idx, 0:2] = sin_cos(dihedral(c_prev, n_i, ca_i, c_i))
            angles[idx, 4:6] = sin_cos(dihedral(ca_prev, c_prev, n_i, ca_i))

        # psi needs the following residue.
        if idx < last and mask[idx + 1]:
            n_next = coords[idx + 1, 0]
            angles[idx, 2:4] = sin_cos(dihedral(n_i, ca_i, c_i, n_next))
    return angles
def get_interface_residues(rec_coords, binder_coords, rec_mask, binder_mask, cutoff: float = 8.0):
    """Flag residues on each chain that lie near the other chain.

    A receptor residue is at the interface when its CA is closer than
    ``cutoff`` to the CA of at least one valid binder residue, and
    symmetrically for binder residues.

    Args:
        rec_coords: [N_rec, >=2, 3] backbone coordinates (slot 1 = CA).
        binder_coords: [N_binder, >=2, 3] backbone coordinates.
        rec_mask: bool array [N_rec]; False rows never count as contacts.
        binder_mask: bool array [N_binder]; same meaning for the binder.
        cutoff: strict upper bound on the CA-CA distance.

    Returns:
        rec_interface: bool array [N_rec]
        binder_interface: bool array [N_binder]
    """
    rec_ca = rec_coords[:, 1, :]
    binder_ca = binder_coords[:, 1, :]

    # All-against-all CA separations, shape [N_rec, N_binder].
    delta = rec_ca[:, None, :] - binder_ca[None, :, :]
    dist = np.sqrt((delta ** 2).sum(axis=-1))

    # A pair is unusable as soon as either partner lacks coordinates.
    unusable = ~rec_mask[:, None] | ~binder_mask[None, :]
    dist = np.where(unusable, np.inf, dist)

    close = dist < cutoff
    return close.any(axis=1), close.any(axis=0)
def align_structures(mobile_ca, ref_ca, mobile_coords=None):
    """Superpose ``mobile_ca`` onto ``ref_ca`` with the Kabsch algorithm.

    Args:
        mobile_ca: [N, 3] CA positions to be moved.
        ref_ca: [N, 3] CA positions of the target; same shape required.
        mobile_coords: optional [N_res, N_atoms, 3] array that receives
            the same rigid transform.

    Returns:
        ``(aligned_ca, R)`` when ``mobile_coords`` is None, otherwise
        ``(aligned_ca, R, aligned_coords)``. ``R`` is the 3x3 rotation
        applied as ``(x - mobile_centroid) @ R.T + ref_centroid``.

    Raises:
        AssertionError: if the two CA arrays differ in shape.
    """
    assert mobile_ca.shape == ref_ca.shape, "Must have same number of residues"

    mobile_centroid = mobile_ca.mean(axis=0)
    ref_centroid = ref_ca.mean(axis=0)
    centered_mobile = mobile_ca - mobile_centroid
    centered_ref = ref_ca - ref_centroid

    # SVD of the 3x3 covariance between the centred point sets.
    U, _, Vt = np.linalg.svd(centered_mobile.T @ centered_ref)
    # Flip the last axis when needed so R is a proper rotation (det = +1)
    # rather than a reflection.
    handedness = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1, 1, handedness]) @ U.T

    aligned_ca = (centered_mobile @ R.T) + ref_centroid
    if mobile_coords is None:
        return aligned_ca, R

    n_res, n_atoms, _ = mobile_coords.shape
    shifted = mobile_coords.reshape(-1, 3) - mobile_centroid
    aligned_coords = ((shifted @ R.T) + ref_centroid).reshape(n_res, n_atoms, 3)
    return aligned_ca, R, aligned_coords
def compute_ca_rmsd(coords1, coords2, mask=None):
    """Root-mean-square deviation between the CA atoms of two backbones.

    No superposition is performed; the inputs are compared as given.

    Args:
        coords1: [N, >=2, 3] backbone coordinates (slot 1 = CA).
        coords2: backbone coordinates with matching layout.
        mask: optional index or boolean selector applied to both CA sets
            before the comparison.

    Returns:
        The RMSD as a NumPy floating scalar.
    """
    ca_a, ca_b = coords1[:, 1, :], coords2[:, 1, :]
    if mask is not None:
        ca_a, ca_b = ca_a[mask], ca_b[mask]
    sq_dist = np.sum(np.square(ca_a - ca_b), axis=-1)
    return np.sqrt(np.mean(sq_dist))
def compute_fraction_native_contacts(
    native_rec_ca, native_binder_ca,
    model_rec_ca=None, model_binder_ca=None,
    cutoff=8.0,
    # Legacy 2-arg signature support
    mask=None, delta=1.0,
):
    """
    Fraction of native inter-chain contacts reproduced by a model (fNAT).

    fNAT = |recovered inter-chain contacts| / |native inter-chain contacts|

    A (receptor_i, binder_j) pair is a native contact when its CA-CA
    distance in the native complex is below ``cutoff``; it is recovered
    when the same pair is also below ``cutoff`` in the model.

    Args:
        native_rec_ca: [N_rec, 3] receptor CA coords in the native complex
        native_binder_ca: [N_bind, 3] binder CA coords in the native complex
        model_rec_ca: [N_rec, 3] receptor CA in the model; None reuses the
            native receptor coordinates
        model_binder_ca: [N_bind, 3] binder CA in the model; None reuses
            the native binder coordinates
        cutoff: contact distance threshold (default 8.0)
        mask: accepted for the legacy call signature; never read
        delta: accepted for the legacy call signature; never read

    Returns:
        fNAT as a Python float in [0, 1]; 0.0 when the native complex has
        no contacts at all.
    """
    def contact_map(rec_ca, binder_ca):
        """Boolean [N_rec, N_bind] matrix of pairs closer than the cutoff."""
        separation = rec_ca[:, None, :] - binder_ca[None, :, :]
        return np.sqrt((separation ** 2).sum(-1)) < cutoff

    model_rec = native_rec_ca if model_rec_ca is None else model_rec_ca
    model_binder = native_binder_ca if model_binder_ca is None else model_binder_ca

    native_contacts = contact_map(native_rec_ca, native_binder_ca)
    model_contacts = contact_map(model_rec, model_binder)

    n_native = native_contacts.sum()
    if n_native == 0:
        return 0.0
    n_recovered = (native_contacts & model_contacts).sum()
    return float(n_recovered) / float(n_native)
def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """Expand distances into Gaussian radial-basis features.

    ``n_bins`` Gaussian centres are spaced evenly over [d_min, d_max] and
    every Gaussian uses the centre spacing as its standard deviation.

    Args:
        distances: array-like of any shape (a scalar or list also works).
        d_min: position of the first centre.
        d_max: position of the last centre.
        n_bins: number of basis functions. With a single bin the lone
            centre sits at ``d_min`` and the width is ``d_max - d_min``.

    Returns:
        float32 array of shape [*distances.shape, n_bins].
    """
    distances = np.asarray(distances)
    centers = np.linspace(d_min, d_max, n_bins)
    # n_bins == 1 has no spacing to speak of; dividing by max(..., 1)
    # avoids the ZeroDivisionError and falls back to the full range.
    sigma = (d_max - d_min) / max(n_bins - 1, 1)
    encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
    return encoded.astype(np.float32)
# Candidate names for the atom that closes chi1 (N - CA - CB - XG).
# compute_chi_angles scans this list in order and uses the first name the
# residue actually has.
_CHI1_ATOMS = ['CG', 'CG1', 'OG', 'OG1', 'SG']
# Candidate names for the atom that closes chi2 (CA - CB - XG - XD), scanned
# with the same first-match rule.
# NOTE(review): 'CE', 'NE' and 'OE1' look like epsilon-position names; they
# can only be selected when none of the earlier entries is present on the
# residue -- confirm that fallback is intended.
_CHI2_ATOMS = ['CD', 'CD1', 'SD', 'OD1', 'ND1', 'CE', 'NE', 'OE1']
def _dihedral_4pts(p0, p1, p2, p3):
"""Compute dihedral angle between four 3D points (radians)."""
b1 = p1 - p0
b2 = p2 - p1
b3 = p3 - p2
n1 = np.cross(b1, b2)
n2 = np.cross(b2, b3)
n1_norm = np.linalg.norm(n1)
n2_norm = np.linalg.norm(n2)
if n1_norm < 1e-6 or n2_norm < 1e-6:
return 0.0
n1 = n1 / n1_norm
n2 = n2 / n2_norm
m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
return np.arctan2(np.dot(m1, n2), np.dot(n1, n2))
def compute_chi_angles(residues, mask):
    """
    Compute sin/cos features of the chi1 and chi2 side-chain torsions.

    Chi1: N - CA - CB - XG, where XG is the first name from _CHI1_ATOMS
          present on the residue.
    Chi2: CA - CB - XG - XD, where XD is the first name from _CHI2_ATOMS
          present on the residue.

    A row stays all zero when the residue is masked out or lacks N, CA, CB
    or any XG atom. Chi2 is filled only when chi1 was and an XD atom exists.

    Returns:
        chi_feats: float32 [N, 4] (sin_chi1, cos_chi1, sin_chi2, cos_chi2)
    """
    # Only these atom names are ever consulted; build the set once.
    wanted = frozenset(('N', 'CA', 'CB', *_CHI1_ATOMS, *_CHI2_ATOMS))

    def first_present(candidates, table):
        """Return the coordinates of the first candidate found, else None."""
        for name in candidates:
            if name in table:
                return table[name]
        return None

    n_res = len(residues)
    chi_feats = np.zeros((n_res, 4), dtype=np.float32)
    for idx, res in enumerate(residues):
        if not mask[idx]:
            continue

        atoms = {}
        for atom in res.get_atoms():
            name = atom.get_name()
            if name in wanted:
                atoms[name] = atom.get_vector().get_array()

        anchors = [atoms.get(name) for name in ('N', 'CA', 'CB')]
        if any(pos is None for pos in anchors):
            continue
        n_pos, ca_pos, cb_pos = (np.array(pos) for pos in anchors)

        xg_pos = first_present(_CHI1_ATOMS, atoms)
        if xg_pos is None:
            continue
        xg_pos = np.array(xg_pos)
        chi1 = _dihedral_4pts(n_pos, ca_pos, cb_pos, xg_pos)
        chi_feats[idx, 0] = np.sin(chi1)
        chi_feats[idx, 1] = np.cos(chi1)

        xd_pos = first_present(_CHI2_ATOMS, atoms)
        if xd_pos is None:
            continue
        chi2 = _dihedral_4pts(ca_pos, cb_pos, xg_pos, np.array(xd_pos))
        chi_feats[idx, 2] = np.sin(chi2)
        chi_feats[idx, 3] = np.cos(chi2)
    return chi_feats
def get_cb_positions(residues, coords, mask):
    """
    Return one side-chain reference point per residue.

    The point is the residue's CB atom when the residue is valid in
    ``mask`` and has a 'CB' entry; otherwise the CA coordinate taken from
    ``coords[:, 1, :]`` is used (e.g. Gly, or a missing CB). ``coords`` is
    not modified.

    Returns:
        cb_pos: float32 [N, 3]
    """
    cb_pos = coords[:, 1, :].copy()  # start from CA everywhere
    valid = (pair for pair in enumerate(residues) if mask[pair[0]])
    for idx, res in valid:
        try:
            cb_xyz = res['CB'].get_vector().get_array()
        except KeyError:
            continue  # no CB on this residue: keep the CA fallback
        cb_pos[idx] = cb_xyz
    return cb_pos.astype(np.float32)
# Simplified hydrophobicity groups for contact energy
_HYDROPHOBIC = {'ALA', 'VAL', 'ILE', 'LEU', 'MET', 'PHE', 'TRP', 'PRO', 'TYR'}
_POS_CHARGED = {'ARG', 'LYS', 'HIS'}
_NEG_CHARGED = {'ASP', 'GLU'}
def _residue_group(resname):
if resname in _HYDROPHOBIC:
return 'H'
if resname in _POS_CHARGED:
return '+'
if resname in _NEG_CHARGED:
return '-'
return 'P' # polar
def compute_contact_energy(rec_residues, binder_residues,
                           rec_cb, binder_cb,
                           rec_mask, binder_mask,
                           cutoff: float = 8.0):
    """
    Compute a simple CB-CB contact energy as a physics-based ddG proxy.

    Every receptor/binder residue pair whose CB-CB distance is below
    ``cutoff`` (and whose residues are both valid in their masks)
    contributes according to a 4-group potential:
        HH:   -1.0 (hydrophobic-hydrophobic, favorable)
        +-:   -0.5 (opposite charges, favorable)
        H+/-: +0.3 (hydrophobic-charged, unfavorable)
        else:  0.0

    The summed energy is squashed with
        score = 1 / (1 + exp(-(energy - 5) / 5))
    so zero contacts give 1 / (1 + e), about 0.269, and a more negative
    (more favorable) energy gives a lower score.

    Args:
        rec_residues: receptor residues exposing ``get_resname()``.
        binder_residues: binder residues exposing ``get_resname()``.
        rec_cb: [n_rec, 3] receptor CB positions.
        binder_cb: [n_binder, 3] binder CB positions.
        rec_mask: bool array [n_rec]; False rows never form contacts.
        binder_mask: bool array [n_binder]; same meaning for the binder.
        cutoff: strict upper bound on the CB-CB contact distance.

    Returns:
        The score as a Python float in [0, 1].
    """
    pair_energy = {
        ('H', 'H'): -1.0,
        ('+', '-'): -0.5, ('-', '+'): -0.5,
        ('H', '+'): 0.3, ('H', '-'): 0.3,
        ('+', 'H'): 0.3, ('-', 'H'): 0.3,
    }

    # Classify every residue once instead of once per contacting pair.
    rec_groups = [_residue_group(res.get_resname()) for res in rec_residues]
    binder_groups = [_residue_group(res.get_resname()) for res in binder_residues]
    n_rec = len(rec_groups)
    n_binder = len(binder_groups)

    # CB-CB distance matrix [n_rec, n_binder]
    diff = rec_cb[:, None, :] - binder_cb[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Mask invalid residues
    dist[~rec_mask, :] = np.inf
    dist[:, ~binder_mask] = np.inf

    # Consider only pairs that have an entry in both residue lists.
    contact_mask = (dist < cutoff)[:n_rec, :n_binder]

    # argwhere yields contacts in row-major order, so the terms are added
    # in the same order as a nested i/j loop and the float sum is unchanged.
    energy = 0.0
    for i, j in np.argwhere(contact_mask):
        energy += pair_energy.get((rec_groups[i], binder_groups[j]), 0.0)

    score = 1.0 / (1.0 + np.exp(-(energy - 5.0) / 5.0))
    return float(score)