# AlloGen / code / utils / anm.py
# chq1155's picture
# AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# ad9572d
"""
Anisotropic Network Model (ANM) for conformational path interpolation.
From-scratch implementation using scipy eigendecomposition.
Projects the apo→holo displacement onto low-frequency normal modes
to create physically motivated interpolation paths.
"""
import numpy as np
from scipy.linalg import eigh
def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
    """
    Build the elastic-network Hessian and return its lowest non-trivial modes.

    Every pair of atoms whose distance lies in [1e-6, cutoff] is joined by a
    spring of unit force constant. The 3N x 3N Hessian is assembled from the
    3x3 super-elements -(u_ij ⊗ u_ij), where u_ij is the unit vector between
    atoms i and j, and each diagonal super-element is minus the sum of the
    off-diagonal super-elements in its row.

    Args:
        ca_coords: [N, 3] CA atom coordinates (array or nested sequence)
        cutoff: distance cutoff for spring connections (Angstroms)
        n_modes: number of non-trivial modes to return

    Returns:
        eigenvalues: [n_modes] array of eigenvalues, ascending; zero-padded
            when fewer than n_modes non-trivial modes exist
        eigenvectors: [n_modes, N, 3] mode displacement vectors, each of unit
            norm over all 3N components; zero-padded like the eigenvalues

    Raises:
        ValueError: if n_modes is negative.

    Notes:
        Structures with fewer than 4 atoms yield all-zero outputs.
        The six lowest modes are always discarded as rigid-body motion
        (3 translations + 3 rotations). NOTE(review): a disconnected network
        (cutoff too small) has more than six zero modes and the extra ones
        are NOT filtered out here.
        Peak memory is O(N^2) because pairwise arrays are built densely.
    """
    if n_modes < 0:
        raise ValueError("n_modes must be non-negative")

    coords = np.asarray(ca_coords, dtype=np.float64)
    n_atoms = len(coords)
    if n_atoms < 4:
        return np.zeros(n_modes), np.zeros((n_modes, n_atoms, 3))

    # Pairwise geometry: diff[i, j] = r_j - r_i, dist[i, j] = |r_j - r_i|.
    diff = coords[np.newaxis, :, :] - coords[:, np.newaxis, :]
    dist = np.linalg.norm(diff, axis=-1)

    # Springs connect pairs within the cutoff; the lower bound drops the
    # self-pairs on the diagonal as well as (near-)coincident atoms.
    connected = (dist >= 1e-6) & (dist <= cutoff)
    rows, cols = np.nonzero(connected)
    unit = diff[rows, cols] / dist[rows, cols, np.newaxis]

    # Assemble as [N, N, 3, 3] super-elements, then flatten to [3N, 3N].
    hessian = np.zeros((n_atoms, n_atoms, 3, 3), dtype=np.float64)
    hessian[rows, cols] = -(unit[:, :, np.newaxis] * unit[:, np.newaxis, :])
    # Diagonal super-elements are still zero here, so the row sum contains
    # only the off-diagonal spring terms.
    atom_idx = np.arange(n_atoms)
    hessian[atom_idx, atom_idx] = -hessian.sum(axis=1)
    hessian = hessian.transpose(0, 2, 1, 3).reshape(3 * n_atoms, 3 * n_atoms)

    # Only the lowest 6 + n_modes eigenpairs are needed.
    n_total = min(6 + n_modes, 3 * n_atoms)
    eigenvalues, eigvecs = eigh(hessian, subset_by_index=[0, n_total - 1])

    # Skip the six trivial modes; keep whatever remains, up to n_modes.
    n_return = max(0, min(n_modes, n_total - 6))

    evals = np.zeros(n_modes)
    evals[:n_return] = eigenvalues[6:6 + n_return]

    # Columns of eigvecs are modes; each column unpacks row-major to [N, 3].
    mode_vectors = np.zeros((n_modes, n_atoms, 3))
    mode_vectors[:n_return] = (
        eigvecs[:, 6:6 + n_return].T.reshape(n_return, n_atoms, 3)
    )
    return evals, mode_vectors
def _kabsch_align(mobile_ca, ref_ca):
    """
    Compute the Kabsch superposition of `mobile_ca` onto `ref_ca`.

    Both point sets are centred on their centroids and the optimal proper
    rotation (determinant +1) is extracted from the SVD of their covariance.

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: centroid of mobile_ca
        t_ref: centroid of ref_ca

    A row-vector point set X is superposed as (X - t_mobile) @ R.T + t_ref.
    """
    centroid_mobile = mobile_ca.mean(axis=0)
    centroid_ref = ref_ca.mean(axis=0)

    # 3x3 covariance between the centred point sets.
    covariance = (mobile_ca - centroid_mobile).T @ (ref_ca - centroid_ref)
    left, _, right_t = np.linalg.svd(covariance)
    right = right_t.T

    # Flip the last singular direction if the raw solution is a reflection.
    handedness = np.sign(np.linalg.det(right @ left.T))
    correction = np.diag(np.array([1.0, 1.0, handedness]))

    rotation = right @ correction @ left.T
    return rotation, centroid_mobile, centroid_ref
def _reconstruct_oxygen(coords):
    """
    Fill the O slot (atom index 3) from the CA and C positions, in place.

    O is placed 1.24 Angstroms beyond C along the CA -> C direction, i.e.
    collinear with the CA-C bond. The N atom (index 0) is not consulted.
    When CA and C coincide the direction is the zero vector, so O lands on C.

    Args:
        coords: [N, 4, 3] backbone coordinates ordered (N, CA, C, O);
            modified in place.

    Returns:
        The same `coords` array, with its O positions overwritten.
    """
    carbonyl_c = coords[:, 2, :]
    bond = carbonyl_c - coords[:, 1, :]
    # Clamp the length so a degenerate CA == C residue does not divide by 0.
    bond_length = np.maximum(
        np.linalg.norm(bond, axis=-1, keepdims=True), 1e-8
    )
    coords[:, 3, :] = carbonyl_c + (bond / bond_length) * 1.24
    return coords
def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, n_modes=10, cutoff=15.0, mode_lead=1.0):
    """
    Interpolate backbone from X0 toward X1 with ANM-mode motion leading.

    X0 is Kabsch-aligned onto X1 using the CA atoms valid in both masks. The
    CA displacement between the aligned X0 and X1 is split into the part
    spanned by the lowest `n_modes` ANM modes of the aligned X0 and the
    residual. At interpolation parameter tau the two parts are applied with
    different schedules:

        mode part:     tau + mode_lead * tau * (1 - tau)
        residual part: tau

    Both schedules are 0 at tau=0 and 1 at tau=1, so the path still connects
    the two endpoints, but with mode_lead > 0 the low-frequency (global)
    motion runs ahead of the residual (local) adjustments.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
        coords_x1: [N1, 4, 3] backbone coords for holo state
        mask_x0: [N0] residue validity mask (coerced to bool)
        mask_x1: [N1] residue validity mask (coerced to bool)
        n_frames: number of intermediate frames (excluding endpoints)
        n_modes: number of ANM modes to use for projection
        cutoff: ANM spring cutoff in Angstroms
        mode_lead: how strongly the mode component leads the residual.
            0 gives plain linear interpolation; the mode schedule is
            monotone and never exceeds 1 for values in [0, 1].

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, where
            coords_tau is a float32 [n_common, 4, 3] array, mask_tau is the
            mask of residues valid in both inputs and tau is a float in
            (0, 1). n_common is min(N0, N1). The list is empty when fewer
            than 5 residues are valid in both inputs or n_frames < 1.
    """
    if n_frames < 1:
        return []

    # Only the leading residues present in both structures are compared.
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = np.array(coords_x0[:n_common], dtype=np.float64)
    c1 = np.array(coords_x1[:n_common], dtype=np.float64)

    # Force boolean masks: an integer 0/1 mask would otherwise be treated as
    # a list of residue indices when used for indexing below.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool))
    if common_mask.sum() < 5:
        return []

    # Kabsch-align X0 onto X1 using valid CA atoms, then move every X0 atom.
    R, t_mobile, t_ref = _kabsch_align(
        c0[common_mask, 1, :], c1[common_mask, 1, :]
    )
    c0_aligned = (
        (c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref
    ).reshape(n_common, 4, 3)

    # Apo -> holo CA displacement over the valid residues.
    ca0_aligned = c0_aligned[common_mask, 1, :]        # [N_valid, 3]
    displacement = c1[common_mask, 1, :] - ca0_aligned  # [N_valid, 3]

    # ANM modes of the aligned apo structure: [n_modes, N_valid, 3].
    _, mode_vectors = compute_anm_modes(
        ca0_aligned, cutoff=cutoff, n_modes=n_modes
    )

    # Project the displacement onto the modes and rebuild the in-subspace
    # part: d_k = <mode_k, displacement>, d_mode = sum_k d_k * mode_k.
    projections = np.einsum("kij,ij->k", mode_vectors, displacement)
    mode_displacement = np.einsum("k,kij->ij", projections, mode_vectors)

    path_frames = []
    for tau in np.linspace(0, 1, n_frames + 2)[1:-1]:
        # Linear interpolation of all four atom types is the baseline.
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1

        # ca0 + w_mode * d_mode + tau * (d - d_mode) differs from the linear
        # baseline by (w_mode - tau) * d_mode = mode_lead * tau * (1 - tau)
        # * d_mode. Applying that same shift to N, CA and C translates each
        # residue's backbone triangle as a unit.
        lead = mode_lead * tau * (1.0 - tau) * mode_displacement
        coords_tau[common_mask, :3, :] += lead[:, np.newaxis, :]

        # O is rebuilt from the shifted CA and C positions.
        coords_tau = _reconstruct_oxygen(coords_tau)
        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames