"""
Anisotropic Network Model (ANM) for conformational path interpolation.

From-scratch implementation using scipy eigendecomposition. Projects the
apo→holo displacement onto low-frequency normal modes to create physically
motivated interpolation paths.
"""

import numpy as np
from scipy.linalg import eigh

# Residue pairs closer than this (Angstroms) are treated as coincident and
# get no spring, which also avoids dividing by a (near-)zero distance.
_MIN_PAIR_DISTANCE = 1e-6

# Number of lowest eigenpairs skipped as rigid-body motion
# (3 translations + 3 rotations).
_N_TRIVIAL_MODES = 6

# Distance (Angstroms) at which the O atom is placed from C.
_CO_BOND_LENGTH = 1.24


def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
    """
    Build elastic network Hessian and compute normal modes via
    eigendecomposition.

    Every pair of CA atoms whose distance d satisfies
    ``1e-6 <= d <= cutoff`` is connected by a spring of unit force constant.

    Args:
        ca_coords: [N, 3] CA atom coordinates (array or nested sequence)
        cutoff: distance cutoff for spring connections (Angstroms)
        n_modes: number of non-trivial modes to return

    Returns:
        eigenvalues: [n_modes] array of eigenvalues (force constants)
        eigenvectors: [n_modes, N, 3] mode displacement vectors

        Both are zero-padded when fewer than ``n_modes`` modes exist, and
        are all zeros when N < 4.

    Note:
        Exactly the six lowest eigenpairs are discarded as rigid-body
        motion. If the spring network has additional zero-frequency modes
        (e.g. it is disconnected at the chosen cutoff), those are returned
        as if they were regular modes.
    """
    ca = np.asarray(ca_coords, dtype=np.float64)
    n_res = len(ca)
    if n_res < 4:
        return np.zeros(n_modes), np.zeros((n_modes, n_res, 3))

    # Pairwise geometry: diff[i, j] = r_j - r_i
    diff = ca[np.newaxis, :, :] - ca[:, np.newaxis, :]
    dist = np.linalg.norm(diff, axis=-1)
    # The diagonal (dist == 0) is excluded by the lower bound.
    no_spring = (dist > cutoff) | (dist < _MIN_PAIR_DISTANCE)
    safe_dist = np.where(no_spring, 1.0, dist)
    unit = np.where(
        no_spring[..., np.newaxis], 0.0, diff / safe_dist[..., np.newaxis]
    )

    # Off-diagonal 3x3 blocks: H[i, j] = -gamma * (u_ij ⊗ u_ij), gamma = 1.
    # Laid out as [i, a, j, b] so that a plain reshape yields the 3N x 3N
    # matrix with H[3i + a, 3j + b].
    hess4 = np.einsum("ija,ijb->iajb", -unit, unit)
    # Diagonal blocks: H[i, i] = -sum_{j != i} H[i, j]. The diagonal blocks
    # are still zero at this point, so the sum over j excludes i itself.
    res_idx = np.arange(n_res)
    hess4[res_idx, :, res_idx, :] = -hess4.sum(axis=2)
    hessian = hess4.reshape(3 * n_res, 3 * n_res)

    # Eigendecompose — only the lowest (6 trivial + n_modes) eigenpairs
    n_total = min(_N_TRIVIAL_MODES + n_modes, 3 * n_res)
    eigenvalues, eigvecs = eigh(hessian, subset_by_index=[0, n_total - 1])

    # Skip the 6 trivial zero-frequency modes
    start = min(_N_TRIVIAL_MODES, len(eigenvalues) - 1)
    n_return = min(n_modes, len(eigenvalues) - start)

    evals = eigenvalues[start:start + n_return]
    # Columns of eigvecs are modes; put the mode index first and unflatten
    # each 3N vector to [N, 3].
    mode_vectors = (
        eigvecs[:, start:start + n_return].T.reshape(n_return, n_res, 3)
    )

    # Pad if fewer modes available than requested
    if n_return < n_modes:
        pad_evals = np.zeros(n_modes)
        pad_evals[:n_return] = evals
        pad_modes = np.zeros((n_modes, n_res, 3))
        pad_modes[:n_return] = mode_vectors
        return pad_evals, pad_modes

    return evals, mode_vectors


def _kabsch_align(mobile_ca, ref_ca):
    """
    Kabsch alignment of mobile onto ref (CA atoms only).

    Args:
        mobile_ca: [M, 3] coordinates to be moved
        ref_ca: [M, 3] target coordinates, paired row-by-row with mobile_ca

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: [3] centroid of mobile_ca
        t_ref: [3] centroid of ref_ca

        Apply to row-vector coordinates as ``(x - t_mobile) @ R.T + t_ref``.
    """
    t_mobile = mobile_ca.mean(axis=0)
    t_ref = ref_ca.mean(axis=0)

    m = mobile_ca - t_mobile
    r = ref_ca - t_ref

    # SVD of the 3x3 cross-covariance between the centered point sets
    H = m.T @ r
    U, S, Vt = np.linalg.svd(H)
    # Flip the last singular direction when needed so that det(R) = +1
    # (a proper rotation, not a reflection).
    d = np.linalg.det(Vt.T @ U.T)
    sign = np.array([1.0, 1.0, np.sign(d)])
    R = Vt.T @ np.diag(sign) @ U.T

    return R, t_mobile, t_ref


def _reconstruct_oxygen(coords):
    """
    Place the O atom 1.24 Angstroms from C along the CA→C direction.

    Only the CA (index 1) and C (index 2) positions are read; this is a
    simple extension of the CA–C bond, not a full peptide-plane geometry.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O). Modified in place:
            the O slot (index 3) is overwritten.

    Returns:
        The same ``coords`` array that was passed in.
    """
    C_pos = coords[:, 2, :]
    CA_pos = coords[:, 1, :]

    C_CA = C_pos - CA_pos
    C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
    # Guard against division by zero when CA and C coincide
    C_CA_norm = np.maximum(C_CA_norm, 1e-8)
    O_pos = C_pos + (C_CA / C_CA_norm) * _CO_BOND_LENGTH

    coords[:, 3, :] = O_pos
    return coords


def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, n_modes=10, cutoff=15.0, mode_lead=0.0):
    """
    Interpolate backbone along dominant ANM modes from X0 toward X1.

    X0 is Kabsch-aligned onto X1, the CA displacement is split into the part
    spanned by the lowest ``n_modes`` ANM modes of the aligned X0 structure
    and a residual, and intermediate frames are generated by advancing both
    parts from 0 to 1.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
        coords_x1: [N1, 4, 3] backbone coords for holo state
        mask_x0: [N0] bool (0/1 integer masks are also accepted)
        mask_x1: [N1] bool (0/1 integer masks are also accepted)
        n_frames: number of intermediate frames (excluding endpoints)
        n_modes: number of ANM modes to use for projection
        cutoff: ANM spring cutoff in Angstroms
        mode_lead: value in [0, 1] controlling how far the mode-projected
            motion runs ahead of the residual. At time tau the mode part is
            scaled by ``tau + mode_lead * tau * (1 - tau)`` and the residual
            by ``tau - mode_lead * tau * (1 - tau)``; both schedules are
            monotonic and go from 0 to 1. With the default 0.0 both parts
            advance linearly, so the CA path is identical to straight
            linear interpolation between the aligned X0 and X1 (the
            historical behavior of this function).

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, where
            coords_tau is a float32 [n_common, 4, 3] array, mask_tau is the
            bool mask of residues valid in both inputs, and tau is a float.
            Empty when fewer than 5 residues are valid in both inputs.
            Same interface as interpolate_backbone_path

    Raises:
        ValueError: if ``mode_lead`` is outside [0, 1].
    """
    if not 0.0 <= mode_lead <= 1.0:
        raise ValueError("mode_lead must be in [0, 1]")

    coords_x0 = np.asarray(coords_x0)
    coords_x1 = np.asarray(coords_x1)

    # Only the leading residues present in both structures are used
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common].copy()
    c1 = coords_x1[:n_common].copy()
    # Coerce to bool so integer masks act as masks, not as fancy indices
    m0 = np.asarray(mask_x0, dtype=bool)[:n_common]
    m1 = np.asarray(mask_x1, dtype=bool)[:n_common]
    common_mask = m0 & m1

    if common_mask.sum() < 5:
        return []

    # Kabsch-align X0 onto X1 using valid CA atoms
    ca0 = c0[common_mask, 1, :]
    ca1_valid = c1[common_mask, 1, :]
    R, t_mobile, t_ref = _kabsch_align(ca0, ca1_valid)

    # Apply alignment to all X0 backbone atoms
    aligned0 = (c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref
    c0_aligned = aligned0.reshape(n_common, 4, 3)

    # Compute apo→holo displacement (CA atoms, valid residues only)
    ca0_aligned = c0_aligned[common_mask, 1, :]  # [N_valid, 3]
    displacement = ca1_valid - ca0_aligned       # [N_valid, 3]

    # Compute ANM modes of the aligned apo structure
    _, mode_vectors = compute_anm_modes(
        ca0_aligned, cutoff=cutoff, n_modes=n_modes
    )  # mode_vectors: [n_modes, N_valid, 3]

    # Project displacement onto each mode:
    # d_k = sum_i mode_k[i] . displacement[i]
    projections = np.tensordot(
        mode_vectors, displacement, axes=([1, 2], [0, 1])
    )  # [n_modes]
    # Reconstruct mode-projected displacement: d_mode = sum_k d_k * mode_k
    mode_displacement = np.tensordot(projections, mode_vectors, axes=1)
    # Residual displacement not captured by modes
    residual = displacement - mode_displacement

    # Generate intermediate frames (endpoints tau=0 and tau=1 are dropped)
    path_frames = []
    for tau in np.linspace(0, 1, n_frames + 2)[1:-1]:
        # lead vanishes at tau = 0 and tau = 1, so both schedules still
        # start at 0 and end at 1; with mode_lead == 0 both equal tau.
        lead = mode_lead * tau * (1.0 - tau)
        ca_interp = (
            ca0_aligned
            + (tau + lead) * mode_displacement
            + (tau - lead) * residual
        )

        # Build full backbone by interpolating all 4 atom types
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1

        # Override CA positions with ANM-interpolated values
        coords_tau[common_mask, 1, :] = ca_interp

        # Translate N and C by the same offset the CA received relative to
        # plain linear interpolation, so each residue's N/CA/C triangle
        # moves together.
        ca_shift = ca_interp - ((1.0 - tau) * ca0_aligned + tau * ca1_valid)
        coords_tau[common_mask, 0, :] += ca_shift  # N atoms
        coords_tau[common_mask, 2, :] += ca_shift  # C atoms

        # O is not interpolated; it is rebuilt from the final CA/C positions
        coords_tau = _reconstruct_oxygen(coords_tau)

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames