# AlloGen/code/utils/path_utils.py
# (page-header residue from the repository listing: "chq1155's picture")
# AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# ad9572d
"""
Transition-path interpolation utilities for conformational induction.

Provides:
    - Kabsch-aligned linear backbone interpolation between two
      conformational states
    - Gaussian Schrödinger Bridge (DSB) stochastic interpolation
    - Precomputed frame loading (for AlphaFlow / AFsample2)
    - Unified dispatcher: generate_path_frames() (also forwards the 'anm'
      method to utils.anm)
    - Per-residue displacement computation (for allosteric hinge weighting)
    - Monotonically increasing path weight generation

Used by the path-aware training, guidance, and refinement modules.
"""
import os
import logging
import numpy as np
logger = logging.getLogger(__name__)
def _kabsch_align(mobile_ca, ref_ca):
    """
    Compute the rigid superposition of ``mobile_ca`` onto ``ref_ca``.

    Implements the Kabsch algorithm: both point sets are centered, the
    cross-covariance matrix is decomposed by SVD, and a determinant
    correction guarantees a proper rotation (never a reflection).

    Args:
        mobile_ca: [N, 3] array of points to be moved (CA atoms)
        ref_ca: [N, 3] array of target points, paired row-by-row with
            mobile_ca

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: [3] centroid of mobile_ca
        t_ref: [3] centroid of ref_ca

        Apply the transform as: aligned = (mobile - t_mobile) @ R.T + t_ref
    """
    centroid_mobile = mobile_ca.mean(axis=0)
    centroid_ref = ref_ca.mean(axis=0)

    # Cross-covariance of the two centered point clouds.
    centered_mobile = mobile_ca - centroid_mobile
    centered_ref = ref_ca - centroid_ref
    covariance = centered_mobile.T @ centered_ref

    left, _, right_t = np.linalg.svd(covariance)

    # If the unconstrained optimum is a reflection (determinant -1), flip the
    # last axis so the returned matrix is a proper rotation.
    handedness = np.linalg.det(right_t.T @ left.T)
    correction = np.array([1.0, 1.0, np.sign(handedness)])

    rotation = right_t.T @ np.diag(correction) @ left.T
    return rotation, centroid_mobile, centroid_ref
def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
    """
    Generate intermediate backbone conformations along the X0 -> X1 path.

    Steps:
        1. Truncate both states to their common length and find residues
           that are valid in both masks
        2. Kabsch-align X0 onto X1 using the CA atoms of those residues
        3. Linearly interpolate all backbone coords at n_frames
           equally-spaced tau values (endpoints excluded)
        4. Overwrite O with a position derived from CA and C

    Note on step 4: O is placed 1.24 A beyond C along the CA -> C direction.
    This is a collinear approximation, not a full carbonyl reconstruction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (any array-like, coerced to bool)
        mask_x1: [N1] residue validity mask (any array-like, coerced to bool)
        n_frames: number of intermediate frames (excluding endpoints)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
            coords_tau: [N_common, 4, 3] float32 interpolated backbone coords
            mask_tau: [N_common] bool
            tau: float in (0, 1) exclusive
        An empty list is returned when fewer than 5 residues are valid in
        both states.
    """
    coords_x0 = np.asarray(coords_x0)
    coords_x1 = np.asarray(coords_x1)

    # Only the leading residues present in both arrays are considered.
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common]
    c1 = coords_x1[:n_common]

    # Coerce masks to genuine boolean arrays: list masks do not support `&`,
    # and 0/1 integer masks would be treated as fancy indices below.
    m0 = np.asarray(mask_x0)[:n_common].astype(bool)
    m1 = np.asarray(mask_x1)[:n_common].astype(bool)
    common_mask = m0 & m1
    if common_mask.sum() < 5:
        return []

    # Fit the superposition on CA atoms (atom index 1) valid in both states.
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])

    # Apply that transform to every backbone atom of X0.
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(
        n_common, 4, 3)

    co_bond_length = 1.24  # C=O bond length in Angstroms
    taus = np.linspace(0, 1, n_frames + 2)[1:-1]  # exclude endpoints

    path_frames = []
    for tau in taus:
        # X_tau = (1 - tau) * X0_aligned + tau * X1
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1

        # Replace the interpolated O with one extended from C along CA -> C.
        c_pos = coords_tau[:, 2, :]
        ca_to_c = c_pos - coords_tau[:, 1, :]
        # Floor the norm so coincident CA/C atoms do not divide by zero.
        ca_to_c_norm = np.maximum(
            np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        coords_tau[:, 3, :] = c_pos + (ca_to_c / ca_to_c_norm) * co_bond_length

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames
def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
    """
    Per-residue CA displacement between X0 and X1 after Kabsch alignment.

    Both states are truncated to their common length. The superposition is
    fitted on CA atoms valid in both masks and then applied to every CA of
    X0 before distances are measured.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (any array-like, coerced to bool)
        mask_x1: [N1] residue validity mask (any array-like, coerced to bool)

    Returns:
        displacements: [N_common] float32 array of per-residue CA-CA
            Euclidean distances; entries for residues not valid in both
            states are 0. All zeros when fewer than 5 residues are valid in
            both states.
        common_mask: [N_common] bool, which residues are valid in both
    """
    coords_x0 = np.asarray(coords_x0)
    coords_x1 = np.asarray(coords_x1)
    n_common = min(len(coords_x0), len(coords_x1))

    # Coerce masks to genuine boolean arrays: list masks do not support `&`,
    # and 0/1 integer masks would be treated as fancy indices below.
    m0 = np.asarray(mask_x0)[:n_common].astype(bool)
    m1 = np.asarray(mask_x1)[:n_common].astype(bool)
    common_mask = m0 & m1
    if common_mask.sum() < 5:
        # Same dtype as the regular return path below.
        return np.zeros(n_common, dtype=np.float32), common_mask

    # CA is atom index 1 of each residue.
    all_ca0 = coords_x0[:n_common, 1, :]
    all_ca1 = coords_x1[:n_common, 1, :]

    # Fit on commonly valid residues only, then move every CA of X0.
    R, t_mobile, t_ref = _kabsch_align(all_ca0[common_mask],
                                       all_ca1[common_mask])
    aligned_ca0 = (all_ca0 - t_mobile) @ R.T + t_ref

    displacements = np.linalg.norm(aligned_ca0 - all_ca1, axis=-1)
    # Residues missing in either state carry no meaningful displacement.
    displacements[~common_mask] = 0.0
    return displacements.astype(np.float32), common_mask
def generate_path_weights(n_frames, mode='linear'):
    """
    Generate normalized weights for path frames.

    The tau values match those used by the interpolation routines
    (n_frames equally spaced points strictly inside (0, 1)). For every mode
    except 'uniform' the weights increase toward tau=1 (the goal state), so
    intermediate conformations closer to X1 are weighted more heavily.

    Args:
        n_frames: number of intermediate frames; any value <= 0 yields an
            empty weight array
        mode: weight schedule
            'linear': w_tau = tau
            'quadratic': w_tau = tau^2
            'exponential': w_tau = (exp(tau) - 1) / (e - 1)
            'uniform': w_tau = 1/n_frames (equal weighting)

    Returns:
        weights: [n_frames] float32 numpy array, normalized to sum to 1

    Raises:
        ValueError: if mode is not one of the schedules above (only checked
            when n_frames > 0)
    """
    # Guard non-positive counts in one place; otherwise negative values give
    # an empty array or a ValueError depending on the value and mode.
    if n_frames <= 0:
        return np.array([], dtype=np.float32)

    taus = np.linspace(0, 1, n_frames + 2)[1:-1]  # same as interpolation

    if mode == 'linear':
        weights = taus.copy()
    elif mode == 'quadratic':
        weights = taus ** 2
    elif mode == 'exponential':
        weights = (np.exp(taus) - 1.0) / (np.e - 1.0)
    elif mode == 'uniform':
        weights = np.ones(n_frames, dtype=np.float32)
    else:
        raise ValueError(f"Unknown weight mode: {mode}")

    # Normalize to sum to 1.
    total = weights.sum()
    if total > 0:
        weights = weights / total
    return weights.astype(np.float32)
# ---------------------------------------------------------------------------
# Gaussian Schrödinger Bridge (AlignDSB) interpolation
# ---------------------------------------------------------------------------
def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, sigma=0.5, n_samples=20, seed=42):
    """
    Gaussian Schrödinger Bridge with t*(1-t) variance schedule.

    Analytic formula (no neural network):
        X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z

    Variance peaks at t=0.5 (maximum uncertainty mid-transition) and
    vanishes at the endpoints. sigma controls noise amplitude in Angstroms.

    For each tau, n_samples noisy interpolations are drawn and the one whose
    CA RMSD to the sample mean is the median is kept, for robustness. O is
    then overwritten with a point 1.24 A beyond C along the CA -> C
    direction (a collinear approximation).

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (any array-like, coerced to bool)
        mask_x1: [N1] residue validity mask (any array-like, coerced to bool)
        n_frames: number of intermediate frames
        sigma: noise amplitude (Angstroms)
        n_samples: number of samples per frame for median selection (>= 1)
        seed: random seed; identical seeds reproduce identical paths

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples; empty when
        fewer than 5 residues are valid in both states

    Raises:
        ValueError: if n_samples < 1
    """
    # Without at least one sample there is nothing to pick a median from.
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    rng = np.random.RandomState(seed)

    coords_x0 = np.asarray(coords_x0)
    coords_x1 = np.asarray(coords_x1)
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common]
    c1 = coords_x1[:n_common]

    # Coerce masks to genuine boolean arrays: list masks do not support `&`,
    # and 0/1 integer masks would be treated as fancy indices below.
    m0 = np.asarray(mask_x0)[:n_common].astype(bool)
    m1 = np.asarray(mask_x1)[:n_common].astype(bool)
    common_mask = m0 & m1
    n_valid = common_mask.sum()
    if n_valid < 5:
        return []

    # Kabsch-align X0 onto X1 using CA atoms valid in both states.
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(
        n_common, 4, 3)

    co_bond_length = 1.24  # C=O bond length in Angstroms
    taus = np.linspace(0, 1, n_frames + 2)[1:-1]

    path_frames = []
    for tau in taus:
        noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
        # The deterministic part of the bridge is shared by all samples.
        bridge_mean = (1.0 - tau) * c0_aligned + tau * c1

        # Draw the samples one by one so the random stream is consumed in
        # the same order regardless of n_samples.
        samples = np.array([
            bridge_mean + noise_scale * rng.randn(n_common, 4, 3)
            for _ in range(n_samples)
        ])  # [n_samples, N, 4, 3]
        mean_sample = samples.mean(axis=0)  # [N, 4, 3]

        # Rank samples by CA RMSD to the sample mean and keep the median one.
        rmsds = []
        for sample in samples:
            diff = sample[common_mask, 1, :] - mean_sample[common_mask, 1, :]
            rmsds.append(np.sqrt((diff ** 2).sum() / n_valid))
        median_idx = np.argsort(rmsds)[len(rmsds) // 2]
        coords_tau = samples[median_idx]

        # Replace the noisy O with one extended from C along CA -> C.
        c_pos = coords_tau[:, 2, :]
        ca_to_c = c_pos - coords_tau[:, 1, :]
        # Floor the norm so coincident CA/C atoms do not divide by zero.
        ca_to_c_norm = np.maximum(
            np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        coords_tau[:, 3, :] = c_pos + (ca_to_c / ca_to_c_norm) * co_bond_length

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames
# ---------------------------------------------------------------------------
# Precomputed frame loading (for AlphaFlow / AFsample2)
# ---------------------------------------------------------------------------
def load_precomputed_frames(target, method, precomputed_dir,
                            coords_x0, coords_x1, mask_x0, mask_x1,
                            n_frames=5):
    """
    Load pre-generated frames from .npz and Kabsch-align them to this
    complex's receptor coordinate frame (the X1 / holo state).

    Expected file: {precomputed_dir}/{target}/{method}/frames.npz
    with keys: 'frames' [K, N_ref, 4, 3], 'taus' [K], 'mask' [N_ref]

    When the file is missing, or fewer than 5 residues are valid in X0, X1
    and the stored mask simultaneously, a warning is logged and the result
    of interpolate_backbone_path() is returned instead.

    If more than n_frames frames are stored, n_frames of them are picked at
    evenly spaced indices; otherwise all stored frames are used.

    Args:
        target: target name (e.g. 'cam')
        method: method name ('alphaflow' or 'afsample2')
        precomputed_dir: root directory for precomputed frames
        coords_x0, coords_x1: apo/holo backbone coords for alignment
        mask_x0, mask_x1: residue masks (any array-like, coerced to bool)
        n_frames: number of frames to return

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
    """
    npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')
    if not os.path.exists(npz_path):
        logger.warning(f"Precomputed frames not found: {npz_path}, "
                       f"falling back to linear interpolation")
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    # np.load on an .npz returns a lazy archive that holds the file open;
    # read the arrays needed and close it deterministically.
    with np.load(npz_path) as data:
        pre_frames = data['frames']  # [K, N_ref, 4, 3]
        pre_taus = data['taus']      # [K]
        pre_mask = data['mask']      # [N_ref]

    coords_x1 = np.asarray(coords_x1)
    n_common = min(len(coords_x0), len(coords_x1), len(pre_mask))

    # Coerce every mask to bool so `&` yields a boolean selector even if the
    # archive (or the caller) stored the mask as integers or a list.
    m0 = np.asarray(mask_x0)[:n_common].astype(bool)
    m1 = np.asarray(mask_x1)[:n_common].astype(bool)
    pm = np.asarray(pre_mask)[:n_common].astype(bool)
    common_mask = m0 & m1 & pm
    if common_mask.sum() < 5:
        logger.warning(f"Too few common residues for {target}/{method}, "
                       f"falling back to linear")
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    # Align precomputed frames to the holo receptor (X1) coordinate frame.
    # The precomputed frames were generated from the reference apo sequence
    # and may be in a different coordinate frame.
    ref_ca = coords_x1[:n_common][common_mask, 1, :]  # holo CA as reference

    # Select n_frames evenly spaced from the available frames, or take them
    # all when there are not more than n_frames.
    n_available = len(pre_frames)
    if n_available > n_frames:
        indices = np.linspace(0, n_available - 1, n_frames).astype(int)
    else:
        indices = np.arange(n_available)

    co_bond_length = 1.24  # C=O bond length in Angstroms
    path_frames = []
    for idx in indices:
        frame = pre_frames[idx, :n_common]  # [N_common, 4, 3]
        tau = float(pre_taus[idx])

        # Kabsch-align this frame's CA atoms onto the holo CA atoms, then
        # move every backbone atom with the same transform.
        R, t_frame, t_ref = _kabsch_align(frame[common_mask, 1, :], ref_ca)
        frame_aligned = ((frame.reshape(-1, 3) - t_frame) @ R.T
                         + t_ref).reshape(n_common, 4, 3)

        # NOTE: this overwrites the stored O coordinates with a point
        # extended from C along CA -> C, as the other path generators do.
        c_pos = frame_aligned[:, 2, :]
        ca_to_c = c_pos - frame_aligned[:, 1, :]
        # Floor the norm so coincident CA/C atoms do not divide by zero.
        ca_to_c_norm = np.maximum(
            np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        frame_aligned[:, 3, :] = (
            c_pos + (ca_to_c / ca_to_c_norm) * co_bond_length)

        path_frames.append((
            frame_aligned.astype(np.float32),
            common_mask.copy(),
            tau,
        ))
    return path_frames
# ---------------------------------------------------------------------------
# Unified dispatcher
# ---------------------------------------------------------------------------
def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
                         method='linear', n_frames=5,
                         precomputed_dir=None, target=None, **kwargs):
    """
    Route a path-generation request to the implementation for ``method``.

    Args:
        coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
        mask_x0, mask_x1: [N] bool masks
        method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
        n_frames: number of intermediate frames
        precomputed_dir: directory for precomputed frames; required for
            'alphaflow' and 'afsample2'
        target: target name; required for 'alphaflow' and 'afsample2'
        **kwargs: method-specific parameters. Read here:
            'dsb': sigma (default 0.5), n_samples (default 20),
                seed (default 42)
            'anm': n_modes (default 10), cutoff (default 15.0)
            Any other keys are ignored.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples

    Raises:
        ValueError: for an unknown method, or when a precomputed method is
            requested without precomputed_dir or target
    """
    endpoints = (coords_x0, coords_x1, mask_x0, mask_x1)

    if method == 'linear':
        return interpolate_backbone_path(*endpoints, n_frames)

    if method in ('alphaflow', 'afsample2'):
        if precomputed_dir is None:
            raise ValueError(f"precomputed_dir required for method '{method}'")
        if target is None:
            raise ValueError(f"target name required for method '{method}'")
        return load_precomputed_frames(
            target, method, precomputed_dir, *endpoints, n_frames)

    if method == 'dsb':
        dsb_options = {
            'sigma': kwargs.get('sigma', 0.5),
            'n_samples': kwargs.get('n_samples', 20),
            'seed': kwargs.get('seed', 42),
        }
        return dsb_backbone_path(*endpoints, n_frames=n_frames, **dsb_options)

    if method == 'anm':
        # Imported on demand so the ANM module is only needed for this method.
        from utils.anm import anm_backbone_path
        anm_options = {
            'n_modes': kwargs.get('n_modes', 10),
            'cutoff': kwargs.get('cutoff', 15.0),
        }
        return anm_backbone_path(*endpoints, n_frames=n_frames, **anm_options)

    raise ValueError(f"Unknown path method: '{method}'. "
                     f"Choose from: linear, alphaflow, afsample2, dsb, anm")