"""
Transition-path interpolation utilities for conformational induction.

Provides:
  - Kabsch-aligned backbone interpolation between two conformational states
  - Gaussian Schrödinger Bridge (DSB) stochastic interpolation
  - Precomputed frame loading (for AlphaFlow / AFsample2)
  - Unified dispatcher: generate_path_frames()
  - Per-residue displacement computation (for allosteric hinge weighting)
  - Monotonically increasing path weight generation

Used by the path-aware training, guidance, and refinement modules.
"""

import os
import logging

import numpy as np

logger = logging.getLogger(__name__)

# Minimum number of residues that must be valid in every input before a
# Kabsch fit is attempted; below this the functions bail out / fall back.
_MIN_COMMON_RESIDUES = 5

# Distance (Angstroms) from C at which the placeholder O atom is placed.
_CO_BOND_LENGTH = 1.24


def _common_valid_mask(n_common, *masks):
    """Return the logical AND of the first ``n_common`` entries of each mask.

    Every mask is coerced to ``bool`` first, so 0/1 integer masks behave like
    boolean masks instead of triggering integer fancy indexing downstream.
    """
    common = np.ones(n_common, dtype=bool)
    for mask in masks:
        common &= np.asarray(mask, dtype=bool)[:n_common]
    return common


def _apply_rigid_transform(coords, R, t_src, t_dst):
    """Apply ``(x - t_src) @ R.T + t_dst`` to every atom of ``coords``.

    The result is a new array with the same shape as ``coords``.
    """
    flat = coords.reshape(-1, 3)
    return ((flat - t_src) @ R.T + t_dst).reshape(coords.shape)


def _reconstruct_oxygen(coords):
    """Overwrite atom slot 3 (O) of ``coords`` in place and return ``coords``.

    O is placed ``_CO_BOND_LENGTH`` Angstroms from C along the CA -> C
    direction. This is a collinear approximation (it only fixes the C-O
    distance), not a true carbonyl geometry.
    """
    c_pos = coords[:, 2, :]
    ca_to_c = c_pos - coords[:, 1, :]
    length = np.linalg.norm(ca_to_c, axis=-1, keepdims=True)
    length = np.maximum(length, 1e-8)  # guard against coincident CA and C
    coords[:, 3, :] = c_pos + (ca_to_c / length) * _CO_BOND_LENGTH
    return coords


def _interior_taus(n_frames):
    """Return ``n_frames`` equally spaced tau values strictly inside (0, 1)."""
    return np.linspace(0, 1, n_frames + 2)[1:-1]


def _kabsch_align(mobile_ca, ref_ca):
    """
    Kabsch alignment of mobile onto ref (CA atoms only).

    Args:
        mobile_ca: [N, 3] array
        ref_ca: [N, 3] array

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: [3] mobile centroid
        t_ref: [3] ref centroid

    Such that: aligned = (mobile - t_mobile) @ R.T + t_ref
    """
    t_mobile = mobile_ca.mean(axis=0)
    t_ref = ref_ca.mean(axis=0)

    # Cross-covariance of the centred point sets.
    H = (mobile_ca - t_mobile).T @ (ref_ca - t_ref)
    U, _, Vt = np.linalg.svd(H)

    # Flip the last singular direction when needed so that R is a proper
    # rotation (det = +1) rather than a reflection.
    d = np.linalg.det(Vt.T @ U.T)
    sign = np.array([1.0, 1.0, np.sign(d)])
    R = Vt.T @ np.diag(sign) @ U.T
    return R, t_mobile, t_ref


def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
    """
    Generate intermediate backbone conformations along the X0 -> X1 path.

    1. Truncate both states to their common length and find residues valid
       in both
    2. Kabsch-align X0 onto X1 using the valid CA atoms
    3. Linearly interpolate backbone coords at n_frames equally-spaced tau
       values
    4. Re-place O at 1.24 A from C along the CA -> C direction (a collinear
       approximation, not a full carbonyl geometry)

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool
        n_frames: number of intermediate frames (excluding endpoints)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
            coords_tau: [N_common, 4, 3] float32 interpolated backbone coords
            mask_tau: [N_common] bool
            tau: float in (0, 1) exclusive
        An empty list is returned when fewer than 5 residues are valid in
        both states.
    """
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = np.asarray(coords_x0)[:n_common]
    c1 = np.asarray(coords_x1)[:n_common]

    common_mask = _common_valid_mask(n_common, mask_x0, mask_x1)
    if common_mask.sum() < _MIN_COMMON_RESIDUES:
        return []

    # Fit on valid CA atoms only, then move every X0 backbone atom.
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])
    c0_aligned = _apply_rigid_transform(c0, R, t_mobile, t_ref)

    path_frames = []
    for tau in _interior_taus(n_frames):
        # X_tau = (1 - tau) * X0_aligned + tau * X1
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1
        _reconstruct_oxygen(coords_tau)
        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames


def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
    """
    Per-residue CA displacement between X0 and X1 after Kabsch alignment.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool

    Returns:
        displacements: [N_common] float32 array of per-residue CA distances
            between aligned X0 and X1; 0 for residues not valid in both
            states. All zeros when fewer than 5 residues are valid in both.
        common_mask: [N_common] bool — which residues are valid
    """
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = np.asarray(coords_x0)[:n_common]
    c1 = np.asarray(coords_x1)[:n_common]

    common_mask = _common_valid_mask(n_common, mask_x0, mask_x1)
    if common_mask.sum() < _MIN_COMMON_RESIDUES:
        # Same dtype as the regular return path.
        return np.zeros(n_common, dtype=np.float32), common_mask

    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])

    # Align every CA of X0 (valid or not) with the fitted transform.
    aligned_ca0 = (c0[:, 1, :] - t_mobile) @ R.T + t_ref
    displacements = np.linalg.norm(aligned_ca0 - c1[:, 1, :], axis=-1)

    # Residues missing in either state carry no meaningful displacement.
    displacements[~common_mask] = 0.0
    return displacements.astype(np.float32), common_mask


def generate_path_weights(n_frames, mode='linear'):
    """
    Generate monotonically increasing weights for path frames.

    The weights increase toward tau=1 (the goal state), so that intermediate
    conformations closer to X1 are weighted more heavily.

    Args:
        n_frames: number of intermediate frames
        mode: weight schedule
            'linear': w_tau = tau
            'quadratic': w_tau = tau^2
            'exponential': w_tau = (exp(tau) - 1) / (e - 1)
            'uniform': w_tau = 1/n_frames (equal weighting)

    Returns:
        weights: [n_frames] float32 numpy array, normalized to sum to 1

    Raises:
        ValueError: if mode is not one of the schedules above (only checked
            when n_frames != 0).
    """
    if n_frames == 0:
        return np.array([], dtype=np.float32)

    taus = _interior_taus(n_frames)  # same tau grid as the interpolators

    if mode == 'linear':
        weights = taus.copy()
    elif mode == 'quadratic':
        weights = taus ** 2
    elif mode == 'exponential':
        weights = (np.exp(taus) - 1.0) / (np.e - 1.0)
    elif mode == 'uniform':
        weights = np.ones(n_frames, dtype=np.float32)
    else:
        raise ValueError(f"Unknown weight mode: {mode}")

    # Normalize to sum to 1
    total = weights.sum()
    if total > 0:
        weights = weights / total
    return weights.astype(np.float32)


# ---------------------------------------------------------------------------
# Gaussian Schrödinger Bridge (AlignDSB) interpolation
# ---------------------------------------------------------------------------

def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, sigma=0.5, n_samples=20, seed=42):
    """
    Gaussian Schrödinger Bridge with t*(1-t) variance schedule.

    Analytic formula (no neural network):
        X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z

    Variance peaks at t=0.5 and vanishes at the endpoints. sigma scales the
    noise amplitude.

    For each tau, n_samples noisy interpolations are drawn and the one whose
    CA RMSD to the sample mean is the median is kept. O is then re-placed at
    1.24 A from C along the CA -> C direction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool
        n_frames: number of intermediate frames
        sigma: noise amplitude (Angstroms)
        n_samples: number of samples per frame for median selection (>= 1)
        seed: random seed

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples; empty when
        fewer than 5 residues are valid in both states.

    Raises:
        ValueError: if n_samples < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    rng = np.random.RandomState(seed)

    n_common = min(len(coords_x0), len(coords_x1))
    c0 = np.asarray(coords_x0)[:n_common]
    c1 = np.asarray(coords_x1)[:n_common]

    common_mask = _common_valid_mask(n_common, mask_x0, mask_x1)
    n_valid = common_mask.sum()
    if n_valid < _MIN_COMMON_RESIDUES:
        return []

    # Kabsch-align X0 onto X1
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])
    c0_aligned = _apply_rigid_transform(c0, R, t_mobile, t_ref)

    path_frames = []
    for tau in _interior_taus(n_frames):
        noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
        bridge_mean = (1.0 - tau) * c0_aligned + tau * c1

        # One draw for all samples of this frame: [n_samples, N, 4, 3]
        noise = rng.randn(n_samples, n_common, 4, 3)
        samples = bridge_mean + noise_scale * noise
        mean_sample = samples.mean(axis=0)  # [N, 4, 3]

        # Rank samples by CA RMSD to the sample mean and keep the middle one.
        mean_ca = mean_sample[common_mask, 1, :]
        rmsds = [
            np.sqrt(((s[common_mask, 1, :] - mean_ca) ** 2).sum() / n_valid)
            for s in samples
        ]
        median_idx = np.argsort(rmsds)[len(rmsds) // 2]
        coords_tau = _reconstruct_oxygen(samples[median_idx])

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames


# ---------------------------------------------------------------------------
# Precomputed frame loading (for AlphaFlow / AFsample2)
# ---------------------------------------------------------------------------

def load_precomputed_frames(target, method, precomputed_dir,
                            coords_x0, coords_x1, mask_x0, mask_x1,
                            n_frames=5):
    """
    Load pre-generated frames from .npz and Kabsch-align them onto X1.

    Expected file: {precomputed_dir}/{target}/{method}/frames.npz
    with keys: 'frames' [K, N_ref, 4, 3], 'taus' [K], 'mask' [N_ref]

    Falls back to interpolate_backbone_path() (with a logged warning) when
    the file does not exist or fewer than 5 residues are valid in X0, X1 and
    the stored mask simultaneously.

    Args:
        target: target name (e.g. 'cam')
        method: method name ('alphaflow' or 'afsample2')
        precomputed_dir: root directory for precomputed frames
        coords_x0, coords_x1: apo/holo backbone coords; X1 is the alignment
            reference, X0 contributes its length and is used by the fallback
        mask_x0, mask_x1: residue masks
        n_frames: maximum number of frames to return; when the file holds
            more, n_frames evenly spaced ones are selected

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
    """
    npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')

    if not os.path.exists(npz_path):
        logger.warning(
            "Precomputed frames not found: %s, "
            "falling back to linear interpolation", npz_path)
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    # NpzFile keeps the archive open until closed; read what is needed and
    # release the handle immediately.
    with np.load(npz_path) as data:
        pre_frames = data['frames']  # [K, N_ref, 4, 3]
        pre_taus = data['taus']      # [K]
        pre_mask = data['mask']      # [N_ref]

    # Also bound by the residue dimension of the stored frames so the
    # per-frame slice below always has n_common residues.
    n_common = min(len(coords_x0), len(coords_x1),
                   len(pre_mask), pre_frames.shape[1])
    common_mask = _common_valid_mask(n_common, mask_x0, mask_x1, pre_mask)

    if common_mask.sum() < _MIN_COMMON_RESIDUES:
        logger.warning(
            "Too few common residues for %s/%s, "
            "falling back to linear", target, method)
        return interpolate_backbone_path(coords_x0, coords_x1,
                                         mask_x0, mask_x1, n_frames)

    # Precomputed frames may live in a different coordinate frame, so each
    # one is superposed onto the X1 CA atoms.
    ref_ca = np.asarray(coords_x1)[:n_common][common_mask, 1, :]

    n_available = len(pre_frames)
    if n_available > n_frames:
        # Select n_frames evenly spaced from the available frames.
        indices = np.linspace(0, n_available - 1, n_frames).astype(int)
    else:
        indices = np.arange(n_available)

    path_frames = []
    for idx in indices:
        frame = pre_frames[idx, :n_common]  # [N_common, 4, 3]
        R, t_frame, t_ref = _kabsch_align(frame[common_mask, 1, :], ref_ca)
        frame_aligned = _apply_rigid_transform(frame, R, t_frame, t_ref)
        _reconstruct_oxygen(frame_aligned)

        path_frames.append((
            frame_aligned.astype(np.float32),
            common_mask.copy(),
            float(pre_taus[idx]),
        ))
    return path_frames


# ---------------------------------------------------------------------------
# Unified dispatcher
# ---------------------------------------------------------------------------

def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
                         method='linear', n_frames=5,
                         precomputed_dir=None, target=None, **kwargs):
    """
    Dispatch to method-specific frame generation.

    Args:
        coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
        mask_x0, mask_x1: [N] bool masks
        method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
        n_frames: number of intermediate frames
        precomputed_dir: directory for precomputed frames (alphaflow/afsample2)
        target: target name (needed for precomputed methods)
        **kwargs: method-specific parameters. Read here: 'sigma',
            'n_samples', 'seed' for 'dsb'; 'n_modes', 'cutoff' for 'anm'.
            Other keys are ignored.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples

    Raises:
        ValueError: for an unknown method, or when precomputed_dir / target
            is missing for 'alphaflow' / 'afsample2'.
    """
    if method == 'linear':
        return interpolate_backbone_path(
            coords_x0, coords_x1, mask_x0, mask_x1, n_frames)

    if method in ('alphaflow', 'afsample2'):
        if precomputed_dir is None:
            raise ValueError(f"precomputed_dir required for method '{method}'")
        if target is None:
            raise ValueError(f"target name required for method '{method}'")
        return load_precomputed_frames(
            target, method, precomputed_dir,
            coords_x0, coords_x1, mask_x0, mask_x1, n_frames)

    if method == 'dsb':
        return dsb_backbone_path(
            coords_x0, coords_x1, mask_x0, mask_x1,
            n_frames=n_frames,
            sigma=kwargs.get('sigma', 0.5),
            n_samples=kwargs.get('n_samples', 20),
            seed=kwargs.get('seed', 42))

    if method == 'anm':
        # Imported lazily so the other methods work without utils.anm.
        from utils.anm import anm_backbone_path
        return anm_backbone_path(
            coords_x0, coords_x1, mask_x0, mask_x1,
            n_frames=n_frames,
            n_modes=kwargs.get('n_modes', 10),
            cutoff=kwargs.get('cutoff', 15.0))

    raise ValueError(f"Unknown path method: '{method}'. "
                     f"Choose from: linear, alphaflow, afsample2, dsb, anm")