| """ |
| Anisotropic Network Model (ANM) for conformational path interpolation. |
| |
| From-scratch implementation using scipy eigendecomposition. |
| Projects the apo→holo displacement onto low-frequency normal modes |
| to create physically motivated interpolation paths. |
| """ |
|
|
| import numpy as np |
| from scipy.linalg import eigh |
|
|
|
|
def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
    """
    Build the ANM Hessian and return its lowest non-trivial normal modes.

    Every pair of CA atoms whose separation d satisfies 1e-6 <= d <= cutoff
    is joined by a spring of unit force constant. The 3x3 super-element for
    a contacting pair (i, j) is -outer(u_ij, u_ij), where u_ij is the unit
    vector from i to j. Each diagonal super-element is the sum of the spring
    blocks touching that residue.

    Args:
        ca_coords: [N, 3] CA atom coordinates (any array-like).
        cutoff: distance cutoff for spring connections (Angstroms).
        n_modes: number of non-trivial modes to return (must be >= 0).

    Returns:
        eigenvalues: [n_modes] eigenvalues in ascending order. Because springs
            have unit strength, these are relative stiffnesses.
        eigenvectors: [n_modes, N, 3] mode displacement vectors (each
            flattened mode has unit norm).

        If N < 4, or fewer than n_modes non-trivial modes exist, the missing
        entries are zero-filled.

    Raises:
        ValueError: if n_modes is negative.

    NOTE: the first six eigenpairs are assumed to be the rigid-body modes.
    If the contact network is disconnected, there are extra zero modes and
    some of them will be returned as if they were non-trivial.
    """
    if n_modes < 0:
        raise ValueError(f"n_modes must be non-negative, got {n_modes}")

    coords = np.asarray(ca_coords, dtype=np.float64)
    N = len(coords)
    if N < 4:
        return np.zeros(n_modes), np.zeros((n_modes, N, 3))

    # diff[i, j] = r_j - r_i for every residue pair (vectorized pair loop).
    diff = coords[np.newaxis, :, :] - coords[:, np.newaxis, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Springs only for 1e-6 <= d <= cutoff. This excludes the diagonal
    # (d == 0) and coincident atoms.
    contact = (dist <= cutoff) & (dist >= 1e-6)
    # Substitute 1.0 for non-contact distances so the division never hits 0.
    safe_dist = np.where(contact, dist, 1.0)
    unit = np.where(contact[..., None], diff / safe_dist[..., None], 0.0)

    # outer[i, a, j, b] = u_ij[a] * u_ij[b]; zero for non-contacting pairs.
    outer = np.einsum("ija,ijb->iajb", unit, unit)
    # Diagonal super-elements: sum of all spring blocks touching residue i.
    diag_blocks = outer.sum(axis=2)

    # Off-diagonal super-elements are the negated spring blocks. outer[i, :, i, :]
    # is zero, so the diagonal blocks are filled in afterwards.
    H = np.negative(outer, out=outer).reshape(3 * N, 3 * N)
    for i in range(N):
        s = 3 * i
        H[s:s + 3, s:s + 3] = diag_blocks[i]

    # Only the lowest 6 + n_modes eigenpairs are needed: 6 rigid-body modes
    # followed by the requested non-trivial ones.
    n_total = min(6 + n_modes, 3 * N)
    eigenvalues, eigvecs = eigh(H, subset_by_index=[0, n_total - 1])

    start = min(6, len(eigenvalues) - 1)
    n_return = min(n_modes, len(eigenvalues) - start)

    # Pre-sized zero arrays double as padding when fewer modes are available.
    evals = np.zeros(n_modes)
    mode_vectors = np.zeros((n_modes, N, 3))
    evals[:n_return] = eigenvalues[start:start + n_return]
    # Column k of eigvecs is a 3N vector laid out as (x0, y0, z0, x1, ...).
    mode_vectors[:n_return] = eigvecs[:, start:start + n_return].T.reshape(
        n_return, N, 3
    )
    return evals, mode_vectors
|
|
|
|
def _kabsch_align(mobile_ca, ref_ca):
    """
    Find the optimal rigid superposition of *mobile_ca* onto *ref_ca*.

    Uses the Kabsch algorithm: SVD of the cross-covariance of the centered
    point sets, with a determinant-sign correction so the result is a proper
    rotation (det = +1), never a reflection.

    Args:
        mobile_ca: [M, 3] coordinates to be moved.
        ref_ca: [M, 3] target coordinates (row-wise correspondence).

    Returns:
        (R, t_mobile, t_ref): the [3, 3] rotation and the two centroids.
        A point x maps onto the reference frame as (x - t_mobile) @ R.T + t_ref.
    """
    centroid_mob = np.mean(mobile_ca, axis=0)
    centroid_ref = np.mean(ref_ca, axis=0)

    cross_cov = (mobile_ca - centroid_mob).T @ (ref_ca - centroid_ref)
    u_mat, _, v_t = np.linalg.svd(cross_cov)
    v_mat = v_t.T

    # Flip the last singular direction if V U^T would be a reflection.
    handedness = np.sign(np.linalg.det(v_mat @ u_mat.T))
    correction = np.diag(np.array([1.0, 1.0, handedness]))
    rotation = v_mat @ correction @ u_mat.T

    return rotation, centroid_mob, centroid_ref
|
|
|
|
def _reconstruct_oxygen(coords):
    """
    Overwrite the O atom (slot 3) of every residue in *coords*.

    O is placed 1.24 Angstrom from C along the CA->C bond direction, i.e.
    collinear with CA and C. Only the C=O bond length is imposed; the
    neighbouring residue's N is not used to orient O in the peptide plane,
    so this is an approximation rather than true carbonyl geometry.

    Args:
        coords: [L, 4, 3] backbone array ordered (N, CA, C, O). It is
            modified in place.

    Returns:
        The same *coords* array object.
    """
    ca_atoms = coords[:, 1, :]
    c_atoms = coords[:, 2, :]
    bond = c_atoms - ca_atoms
    # Floor the length so a degenerate (CA == C) residue puts O on C
    # instead of producing NaNs.
    length = np.maximum(np.linalg.norm(bond, axis=-1, keepdims=True), 1e-8)
    coords[:, 3, :] = c_atoms + (bond / length) * 1.24
    return coords
|
|
|
|
def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, n_modes=10, cutoff=15.0,
                      schedule_power=2.0):
    """
    Interpolate backbone along dominant ANM modes from X0 toward X1.

    The CA displacement X1 - X0 (after rigid superposition) is split into:
      * its projection onto the lowest-frequency ANM modes of X0 (collective,
        global motions such as hinge bending), and
      * the residual (local adjustments).

    The two parts follow different schedules, so large-scale motion leads
    and local adjustment lags. With p = schedule_power:

        w_mode(tau) = 1 - (1 - tau)**p      (ease-out: fast early)
        w_res(tau)  = tau**p                (ease-in: slow early)

    Both weights are 0 at tau = 0 and 1 at tau = 1. p = 1 reduces to plain
    linear interpolation of the CA atoms.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state.
        coords_x1: [N1, 4, 3] backbone coords for holo state.
        mask_x0: [N0] bool.
        mask_x1: [N1] bool.
        n_frames: number of intermediate frames (excluding endpoints).
        n_modes: number of ANM modes to use for projection.
        cutoff: ANM spring cutoff in Angstroms.
        schedule_power: exponent p (> 0) controlling how far the collective
            mode motion leads the residual motion. Values > 1 make the modes
            lead; 1 gives linear interpolation.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, one per
            intermediate tau. coords_tau is a float32 [n_common, 4, 3] array
            and mask_tau is the common mask. The list is empty when fewer
            than 5 residues are valid in both states. Same interface as
            interpolate_backbone_path.

    Raises:
        ValueError: if schedule_power is not positive.
    """
    if schedule_power <= 0:
        raise ValueError(
            f"schedule_power must be positive, got {schedule_power}"
        )

    # Residues are paired by index; only the overlapping prefix is used.
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common].copy()
    c1 = coords_x1[:n_common].copy()
    common_mask = mask_x0[:n_common] & mask_x1[:n_common]
    if common_mask.sum() < 5:
        return []

    # Superimpose X0 onto X1 using the shared CA atoms, then apply the same
    # rigid transform to every backbone atom of X0.
    R, t_mobile, t_ref = _kabsch_align(
        c0[common_mask, 1, :], c1[common_mask, 1, :]
    )
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(
        n_common, 4, 3
    )

    ca0_aligned = c0_aligned[common_mask, 1, :]
    ca1_valid = c1[common_mask, 1, :]
    displacement = ca1_valid - ca0_aligned

    _, mode_vectors = compute_anm_modes(
        ca0_aligned, cutoff=cutoff, n_modes=n_modes
    )

    # Orthogonal projection of the displacement onto the mode subspace
    # (modes are orthonormal eigenvectors; zero-padded modes contribute 0).
    projections = np.einsum("kni,ni->k", mode_vectors, displacement)
    mode_displacement = np.einsum("k,kni->ni", projections, mode_vectors)
    residual = displacement - mode_displacement

    taus = np.linspace(0, 1, n_frames + 2)[1:-1]
    path_frames = []
    for tau in taus:
        # Different schedules for the collective and local parts. Using the
        # same weight for both would collapse to linear interpolation and
        # make the ANM decomposition a no-op.
        w_mode = 1.0 - (1.0 - tau) ** schedule_power
        w_res = tau ** schedule_power
        ca_interp = ca0_aligned + w_mode * mode_displacement + w_res * residual

        # Linear interpolation for all atoms, then override CA with the
        # mode-scheduled positions.
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1
        coords_tau[common_mask, 1, :] = ca_interp

        # Move N and C rigidly with their CA so local backbone geometry
        # follows the CA deviation from the linear path.
        ca_shift = ca_interp - ((1.0 - tau) * ca0_aligned + tau * ca1_valid)
        coords_tau[common_mask, 0, :] += ca_shift
        coords_tau[common_mask, 2, :] += ca_shift

        coords_tau = _reconstruct_oxygen(coords_tau)

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames
|
|