"""
Transition-path interpolation utilities for conformational induction.

Provides:
- Kabsch-aligned linear backbone interpolation between two conformational states
- Gaussian Schrödinger-bridge stochastic interpolation (dispatcher method 'dsb')
- Loading and alignment of precomputed frames (AlphaFlow / AFsample2)
- Unified dispatcher: generate_path_frames() (linear, alphaflow, afsample2, dsb, anm)
- Per-residue CA displacement computation (for allosteric hinge weighting)
- Normalised, non-decreasing path-frame weight schedules

Used by the path-aware training, guidance, and refinement modules.
"""
|
|
| import os |
| import logging |
| import numpy as np |
|
|
# Module-level logger; load_precomputed_frames() uses it to warn when it falls
# back to linear interpolation.
logger = logging.getLogger(__name__)
|
|
|
|
def _kabsch_align(mobile_ca, ref_ca):
    """
    Compute the optimal rigid superposition of ``mobile_ca`` onto ``ref_ca``.

    Kabsch algorithm: SVD of the 3x3 cross-covariance of the centred point
    sets, with a reflection correction so the result is a proper rotation.

    Args:
        mobile_ca: [N, 3] array of points to move (e.g. CA atoms)
        ref_ca: [N, 3] array of target points, paired row-by-row with mobile_ca

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: [3] centroid of mobile_ca
        t_ref: [3] centroid of ref_ca
        Apply as: aligned = (mobile - t_mobile) @ R.T + t_ref
    """
    centroid_mob = np.mean(mobile_ca, axis=0)
    centroid_ref = np.mean(ref_ca, axis=0)

    # Cross-covariance between the centred point clouds.
    cov = (mobile_ca - centroid_mob).T @ (ref_ca - centroid_ref)
    u, _, vt = np.linalg.svd(cov)
    v = vt.T

    # If the unconstrained optimum is a reflection (det = -1), flip the
    # weakest singular direction so the result is a proper rotation.
    reflect = np.sign(np.linalg.det(v @ u.T))
    rotation = v @ np.diag([1.0, 1.0, reflect]) @ u.T

    return rotation, centroid_mob, centroid_ref
|
|
|
|
def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
    """
    Generate intermediate backbone conformations along the X0 -> X1 path.

    1. Truncate both states to their common length and intersect the masks.
    2. Kabsch-align X0 onto X1 using CA atoms of residues valid in both states.
    3. Linearly interpolate all backbone atoms at n_frames equally spaced
       interior tau values.
    4. Re-place each O atom 1.24 A beyond C along the CA->C direction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (bool, or anything convertible
            to bool such as 0/1 ints or a list)
        mask_x1: [N1] residue validity mask
        n_frames: number of intermediate frames (excluding endpoints)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples ordered by
        increasing tau; empty if fewer than 5 residues are valid in both states.
            coords_tau: [N_common, 4, 3] float32 interpolated backbone coords
            mask_tau: [N_common] bool (the same common mask for every frame)
            tau: float in (0, 1) exclusive
    """
    n_common = min(len(coords_x0), len(coords_x1))

    # Coerce to bool: 0/1 integer masks would otherwise be used as *index*
    # arrays below (selecting rows 0/1) instead of as boolean masks.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool))
    if common_mask.sum() < 5:
        return []

    # Neither array is modified in place, so no defensive copy is needed.
    c0 = np.asarray(coords_x0[:n_common])
    c1 = np.asarray(coords_x1[:n_common])

    # Fit the rigid transform on commonly valid CAs, then apply it to every
    # atom of X0 so the whole structure lands in X1's frame.
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :], c1[common_mask, 1, :])
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(n_common, 4, 3)

    # Interior grid only: tau = 0 and tau = 1 are the endpoints themselves.
    taus = np.linspace(0, 1, n_frames + 2)[1:-1]
    path_frames = []
    for tau in taus:
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1

        # NOTE(review): this places O collinear with CA->C (CA-C-O = 180 deg),
        # a crude approximation rather than ideal carbonyl geometry; kept
        # unchanged so all path generators in this module agree.
        ca_to_c = coords_tau[:, 2, :] - coords_tau[:, 1, :]
        length = np.maximum(np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        coords_tau[:, 3, :] = coords_tau[:, 2, :] + (ca_to_c / length) * 1.24

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames
|
|
|
|
def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
    """
    Per-residue CA displacement between X0 and X1 after Kabsch alignment.

    X0 is superposed onto X1 using the CA atoms of residues valid in both
    states. The displacement of residue i is then the Euclidean distance
    between its aligned X0 CA and its X1 CA.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (bool, or convertible to bool)
        mask_x1: [N1] residue validity mask

    Returns:
        displacements: [N_common] float32 per-residue CA distance, in the same
            length unit as the input coords. It is 0 for residues not valid in
            both states, and all zeros if fewer than 5 residues are common.
        common_mask: [N_common] bool, which residues are valid in both states
    """
    n_common = min(len(coords_x0), len(coords_x1))

    # Coerce to bool so 0/1 integer masks are not used as index arrays.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool))

    if common_mask.sum() < 5:
        # Same dtype as the normal return path (previously float64 here).
        return np.zeros(n_common, dtype=np.float32), common_mask

    ca0 = np.asarray(coords_x0[:n_common])[:, 1, :]
    ca1 = np.asarray(coords_x1[:n_common])[:, 1, :]

    # Fit on commonly valid residues, then apply the transform to all CAs.
    R, t_mobile, t_ref = _kabsch_align(ca0[common_mask], ca1[common_mask])
    aligned_ca0 = (ca0 - t_mobile) @ R.T + t_ref

    displacements = np.linalg.norm(aligned_ca0 - ca1, axis=-1)
    # Residues missing from either state carry no meaningful displacement.
    displacements[~common_mask] = 0.0

    return displacements.astype(np.float32), common_mask
|
|
|
|
def generate_path_weights(n_frames, mode='linear'):
    """
    Build normalised, non-decreasing weights for the intermediate path frames.

    Frame k sits at tau_k = k / (n_frames + 1) for k = 1..n_frames. This is
    the same interior grid the path generators in this module use. Weights
    grow with tau, so frames nearer the goal state X1 (tau = 1) count more.

    Args:
        n_frames: number of intermediate frames
        mode: weight schedule
            'linear':      w = tau
            'quadratic':   w = tau ** 2
            'exponential': w = (exp(tau) - 1) / (e - 1)
            'uniform':     w = 1 (every frame weighted equally)

    Returns:
        weights: [n_frames] float32 array normalised to sum to 1
            (an empty array when n_frames == 0)

    Raises:
        ValueError: if mode is not one of the schedules above (only checked
            when n_frames != 0)
    """
    if n_frames == 0:
        return np.array([], dtype=np.float32)

    grid = np.linspace(0, 1, n_frames + 2)[1:-1]

    # Schedules are evaluated lazily, only for the selected mode.
    schedules = (
        ('linear', lambda t: t),
        ('quadratic', lambda t: t ** 2),
        ('exponential', lambda t: (np.exp(t) - 1.0) / (np.e - 1.0)),
        ('uniform', lambda t: np.ones(n_frames, dtype=np.float32)),
    )
    for name, schedule in schedules:
        if mode == name:
            raw = schedule(grid)
            break
    else:
        raise ValueError(f"Unknown weight mode: {mode}")

    total = raw.sum()
    normalised = raw / total if total > 0 else raw
    return normalised.astype(np.float32)
|
|
|
|
| |
| |
| |
|
|
def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, sigma=0.5, n_samples=20, seed=42):
    """
    Gaussian Schrödinger Bridge with t*(1-t) variance schedule.

    Analytic formula (no neural network):
        X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z

    Variance peaks at t=0.5 and vanishes at the endpoints. sigma scales the
    noise in the units of the input coordinates. Z is drawn independently per
    residue, atom and axis, so bond lengths are not preserved by the noise.

    For each tau, n_samples noisy interpolations are drawn. The sample whose
    CA RMSD to the ensemble mean is the median is kept (the upper median when
    n_samples is even). Its O atoms are then re-placed 1.24 A beyond C along
    the CA->C direction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] residue validity mask (bool, or convertible to bool)
        mask_x1: [N1] residue validity mask
        n_frames: number of intermediate frames
        sigma: noise amplitude
        n_samples: number of samples per frame for median selection (>= 1)
        seed: seed for numpy's legacy RandomState (results are reproducible)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, with float32
        coords, ordered by increasing tau. Empty if fewer than 5 residues are
        valid in both states.

    Raises:
        ValueError: if n_samples < 1
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    rng = np.random.RandomState(seed)

    n_common = min(len(coords_x0), len(coords_x1))

    # Coerce to bool so 0/1 integer masks are not used as index arrays.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool))
    if common_mask.sum() < 5:
        return []

    c0 = np.asarray(coords_x0[:n_common])
    c1 = np.asarray(coords_x1[:n_common])

    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :], c1[common_mask, 1, :])
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(n_common, 4, 3)

    n_valid = common_mask.sum()
    taus = np.linspace(0, 1, n_frames + 2)[1:-1]
    path_frames = []
    for tau in taus:
        noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
        mean_path = (1.0 - tau) * c0_aligned + tau * c1

        # One draw for all samples. The legacy RandomState normal stream is
        # identical to drawing each [n_common, 4, 3] sample separately.
        Z = rng.randn(n_samples, n_common, 4, 3)
        samples = mean_path[None] + noise_scale * Z

        # CA RMSD of every sample to the ensemble mean over common residues.
        sample_mean = samples.mean(axis=0)
        ca_dev = samples[:, common_mask, 1, :] - sample_mean[common_mask, 1, :]
        rmsds = np.sqrt((ca_dev ** 2).sum(axis=(1, 2)) / n_valid)
        median_idx = np.argsort(rmsds)[n_samples // 2]
        coords_tau = samples[median_idx]

        # Collinear CA->C->O placement, matching the other generators here.
        ca_to_c = coords_tau[:, 2, :] - coords_tau[:, 1, :]
        length = np.maximum(np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        coords_tau[:, 3, :] = coords_tau[:, 2, :] + (ca_to_c / length) * 1.24

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames
|
|
|
|
| |
| |
| |
|
|
def load_precomputed_frames(target, method, precomputed_dir,
                            coords_x0, coords_x1, mask_x0, mask_x1,
                            n_frames=5):
    """
    Load pre-generated frames from .npz and Kabsch-align them to this
    complex's X1 (coords_x1) coordinate frame.

    Expected file: {precomputed_dir}/{target}/{method}/frames.npz
    with keys: 'frames' [n_frames, N_ref, 4, 3], 'taus' [n_frames],
    'mask' [N_ref] (bool, or any dtype convertible to bool)

    Args:
        target: target name (e.g. 'cam')
        method: method name ('alphaflow' or 'afsample2')
        precomputed_dir: root directory for precomputed frames
        coords_x0, coords_x1: apo/holo backbone coords; frames are aligned
            onto the commonly valid CA atoms of coords_x1
        mask_x0, mask_x1: residue validity masks
        n_frames: maximum number of frames to return. When the file holds
            more, an evenly spaced subset (including first and last) is used.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, in file order.
        Falls back to interpolate_backbone_path() when the file is missing or
        fewer than 5 residues are valid in X0, X1 and the precomputed mask.
    """
    npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')
    if not os.path.exists(npz_path):
        logger.warning("Precomputed frames not found: %s, "
                       "falling back to linear interpolation", npz_path)
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    # Use the NpzFile as a context manager so its zip handle is closed;
    # indexing it reads each array fully into memory.
    with np.load(npz_path) as data:
        pre_frames = data['frames']
        pre_taus = data['taus']
        pre_mask = data['mask']

    # The frames may cover fewer residues than the mask; never index past them.
    n_common = min(len(coords_x0), len(coords_x1), len(pre_mask), pre_frames.shape[1])

    # Coerce to bool: an integer (e.g. uint8) mask read from disk would
    # otherwise turn `common_mask` into an index array instead of a mask.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool)
                   & np.asarray(pre_mask[:n_common], dtype=bool))

    if common_mask.sum() < 5:
        logger.warning("Too few common residues for %s/%s, "
                       "falling back to linear", target, method)
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    ref_ca = np.asarray(coords_x1[:n_common])[common_mask, 1, :]

    # Evenly subsample when more frames are stored than requested.
    if len(pre_frames) > n_frames:
        indices = np.linspace(0, len(pre_frames) - 1, n_frames).astype(int)
    else:
        indices = np.arange(len(pre_frames))

    path_frames = []
    for idx in indices:
        frame = pre_frames[idx, :n_common]
        tau = float(pre_taus[idx])

        # Each frame gets its own superposition onto the X1 reference CAs.
        R, t_frame, t_ref = _kabsch_align(frame[common_mask, 1, :], ref_ca)
        frame_aligned = ((frame.reshape(-1, 3) - t_frame) @ R.T + t_ref).reshape(n_common, 4, 3)

        # Collinear CA->C->O placement, matching the other generators here.
        ca_to_c = frame_aligned[:, 2, :] - frame_aligned[:, 1, :]
        length = np.maximum(np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8)
        frame_aligned[:, 3, :] = frame_aligned[:, 2, :] + (ca_to_c / length) * 1.24

        path_frames.append((
            frame_aligned.astype(np.float32),
            common_mask.copy(),
            tau,
        ))

    return path_frames
|
|
|
|
| |
| |
| |
|
|
def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
                         method='linear', n_frames=5,
                         precomputed_dir=None, target=None, **kwargs):
    """
    Produce intermediate path frames with the requested generator.

    Args:
        coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
        mask_x0, mask_x1: [N] residue validity masks
        method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
        n_frames: number of intermediate frames
        precomputed_dir: directory for precomputed frames (alphaflow/afsample2)
        target: target name (needed for precomputed methods)
        **kwargs: method-specific options. 'dsb' reads sigma (0.5),
            n_samples (20) and seed (42). 'anm' reads n_modes (10) and
            cutoff (15.0). Other keys are ignored.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples

    Raises:
        ValueError: for an unknown method, or when 'alphaflow'/'afsample2'
            is requested without precomputed_dir or target
    """
    endpoints = (coords_x0, coords_x1, mask_x0, mask_x1)

    if method == 'linear':
        return interpolate_backbone_path(*endpoints, n_frames)

    if method in ('alphaflow', 'afsample2'):
        if precomputed_dir is None:
            raise ValueError(f"precomputed_dir required for method '{method}'")
        if target is None:
            raise ValueError(f"target name required for method '{method}'")
        return load_precomputed_frames(target, method, precomputed_dir,
                                       *endpoints, n_frames)

    if method == 'dsb':
        return dsb_backbone_path(
            *endpoints,
            n_frames=n_frames,
            sigma=kwargs.get('sigma', 0.5),
            n_samples=kwargs.get('n_samples', 20),
            seed=kwargs.get('seed', 42),
        )

    if method == 'anm':
        # Imported inside this branch, so utils.anm is only loaded for 'anm'.
        from utils.anm import anm_backbone_path
        return anm_backbone_path(
            *endpoints,
            n_frames=n_frames,
            n_modes=kwargs.get('n_modes', 10),
            cutoff=kwargs.get('cutoff', 15.0),
        )

    raise ValueError(f"Unknown path method: '{method}'. "
                     f"Choose from: linear, alphaflow, afsample2, dsb, anm")
|
|