"""
Anisotropic Network Model (ANM) for conformational path interpolation.

From-scratch implementation using SciPy's symmetric eigensolver
(``scipy.linalg.eigh``). The apo→holo displacement is projected onto
low-frequency normal modes to build physically motivated interpolation paths.
"""
import numpy as np
from scipy.linalg import eigh
def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
    """
    Build the ANM elastic-network Hessian and return its lowest non-trivial modes.

    Every pair of CA atoms whose distance is <= ``cutoff`` (and >= 1e-6, so
    self-pairs and coincident atoms are excluded) is joined by a spring of
    uniform stiffness gamma = 1. The 3N x 3N Hessian is assembled with
    vectorized NumPy operations, and only the lowest ``6 + n_modes`` eigenpairs
    are computed. The first six are discarded as the trivial rigid-body modes
    (3 translations + 3 rotations for a connected network).

    Args:
        ca_coords: [N, 3] CA atom coordinates (array-like).
        cutoff: distance cutoff for spring connections (Angstroms).
        n_modes: number of non-trivial modes to return.

    Returns:
        eigenvalues: [n_modes] array of eigenvalues (force constants), ascending.
        eigenvectors: [n_modes, N, 3] unit-norm mode displacement vectors.
        When N < 4, or when fewer than ``n_modes`` non-trivial modes exist
        (3N - 6 < n_modes), the missing entries are zero-filled.
    """
    coords = np.asarray(ca_coords, dtype=np.float64)
    n_res = len(coords)
    if n_res < 4:
        return np.zeros(n_modes), np.zeros((n_modes, n_res, 3))

    # Pairwise vectors r_j - r_i ([N, N, 3]) and distances ([N, N]).
    diff = coords[None, :, :] - coords[:, None, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Springs connect pairs within the cutoff; the lower bound drops self-pairs
    # and coincident atoms, whose direction is undefined. NaN distances fail
    # both comparisons and are therefore excluded as well.
    contact = (dist <= cutoff) & (dist >= 1e-6)
    i_idx, j_idx = np.nonzero(contact)  # both (i, j) and (j, i) are listed
    unit = diff[i_idx, j_idx] / dist[i_idx, j_idx][:, None]  # [P, 3]
    outer = unit[:, :, None] * unit[:, None, :]  # [P, 3, 3]

    # Store the Hessian as [N, 3, N, 3] so H4[i, :, j, :] is the (i, j)
    # super-block; the reshape to [3N, 3N] then gives H[3i + a, 3j + b].
    H4 = np.zeros((n_res, 3, n_res, 3), dtype=np.float64)
    # Off-diagonal super-blocks: -gamma * (u_ij ⊗ u_ij) with gamma = 1.
    H4[i_idx, :, j_idx, :] = -outer
    # Diagonal super-blocks: minus the row sum of the off-diagonal blocks,
    # i.e. the sum of u_ij ⊗ u_ij over all neighbours j of residue i.
    diag_idx = np.arange(n_res)
    H4[diag_idx, :, diag_idx, :] = -H4.sum(axis=2)
    H = H4.reshape(3 * n_res, 3 * n_res)

    # Only the lowest eigenpairs are needed; the first 6 are rigid-body modes.
    n_total = min(6 + n_modes, 3 * n_res)
    eigenvalues, eigvecs = eigh(H, subset_by_index=[0, n_total - 1])
    evals = eigenvalues[6:]
    evecs = eigvecs[:, 6:]  # [3N, n_return]
    n_return = evals.shape[0]

    # Zero-filled outputs cover the case of fewer modes than requested.
    out_vals = np.zeros(n_modes)
    out_vecs = np.zeros((n_modes, n_res, 3))
    out_vals[:n_return] = evals
    # Column k of evecs is mode k flattened as [N, 3] in row-major order.
    out_vecs[:n_return] = evecs.T.reshape(n_return, n_res, 3)
    return out_vals, out_vecs
def _kabsch_align(mobile_ca, ref_ca):
"""Kabsch alignment of mobile onto ref (CA atoms only)."""
t_mobile = mobile_ca.mean(axis=0)
t_ref = ref_ca.mean(axis=0)
m = mobile_ca - t_mobile
r = ref_ca - t_ref
H = m.T @ r
U, S, Vt = np.linalg.svd(H)
d = np.linalg.det(Vt.T @ U.T)
sign = np.array([1.0, 1.0, np.sign(d)])
R = Vt.T @ np.diag(sign) @ U.T
return R, t_mobile, t_ref
def _reconstruct_oxygen(coords):
"""Reconstruct O atom from N, CA, C with ideal C=O geometry."""
C_pos = coords[:, 2, :]
CA_pos = coords[:, 1, :]
C_CA = C_pos - CA_pos
C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
C_CA_norm = np.maximum(C_CA_norm, 1e-8)
O_pos = C_pos + (C_CA / C_CA_norm) * 1.24
coords[:, 3, :] = O_pos
return coords
def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                      n_frames=5, n_modes=10, cutoff=15.0, mode_lead=2.0):
    """
    Interpolate backbone from X0 toward X1, leading with the dominant ANM modes.

    The Kabsch-aligned apo→holo CA displacement is split into two parts:
    its projection onto the lowest ``n_modes`` non-trivial ANM modes of the
    aligned apo structure, and the residual outside that span. The mode
    component follows the front-loaded schedule ``1 - (1 - tau) ** mode_lead``,
    while the residual follows ``tau`` linearly. This way low-frequency (global)
    motions, e.g. CaM hinge bending, progress ahead of local adjustments. Both
    schedules reach 1 at tau = 1, so the path still heads exactly to the holo
    CA positions.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
        coords_x1: [N1, 4, 3] backbone coords for holo state
        mask_x0: [N0] bool (0/1 integer masks are also accepted)
        mask_x1: [N1] bool (0/1 integer masks are also accepted)
        n_frames: number of intermediate frames (excluding endpoints)
        n_modes: number of ANM modes to use for projection
        cutoff: ANM spring cutoff in Angstroms
        mode_lead: exponent of the mode schedule; must be > 0. Values > 1 make
            the ANM component lead the residual; 1.0 gives plain linear
            interpolation of the CA atoms.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
            Same interface as interpolate_backbone_path. Empty when fewer than
            5 residues are valid in both states.

    Raises:
        ValueError: if ``mode_lead`` is not positive.
    """
    if mode_lead <= 0:
        raise ValueError(f"mode_lead must be positive, got {mode_lead!r}")

    # Residues are paired by index; residues past the shorter chain are dropped.
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = np.array(coords_x0[:n_common], dtype=np.float64)
    c1 = np.array(coords_x1[:n_common], dtype=np.float64)
    # Cast to bool: an integer 0/1 mask would otherwise be used as integer
    # (fancy) indices below rather than as a boolean selection.
    common_mask = (np.asarray(mask_x0[:n_common], dtype=bool)
                   & np.asarray(mask_x1[:n_common], dtype=bool))
    if common_mask.sum() < 5:
        return []

    # Kabsch-align X0 onto X1 using valid CA atoms, then apply the same rigid
    # transform to every X0 backbone atom.
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :], c1[common_mask, 1, :])
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(c0.shape)

    # Apo→holo displacement of the valid CA atoms.
    ca0_aligned = c0_aligned[common_mask, 1, :]  # [N_valid, 3]
    ca1_valid = c1[common_mask, 1, :]
    displacement = ca1_valid - ca0_aligned  # [N_valid, 3]

    # Split the displacement into its component in the span of the (orthonormal,
    # zero-padded) ANM modes of the aligned apo structure and the residual.
    _, mode_vectors = compute_anm_modes(ca0_aligned, cutoff=cutoff, n_modes=n_modes)
    basis = mode_vectors.reshape(mode_vectors.shape[0], displacement.size)
    projections = basis @ displacement.ravel()  # d_k = <mode_k, displacement>
    mode_displacement = (projections @ basis).reshape(displacement.shape)
    residual = displacement - mode_displacement

    taus = np.linspace(0.0, 1.0, n_frames + 2)[1:-1]
    path_frames = []
    for tau in taus:
        # Mode component is front-loaded; residual is linear. Applying both
        # linearly would collapse to plain linear interpolation, because
        # mode_displacement + residual == displacement.
        mode_weight = 1.0 - (1.0 - tau) ** mode_lead
        ca_interp = ca0_aligned + mode_weight * mode_displacement + tau * residual

        # Linear blend of all backbone atoms as the baseline.
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1
        # Boolean indexing returns a copy, so this keeps the linear CA positions.
        linear_ca = coords_tau[common_mask, 1, :]
        ca_shift = ca_interp - linear_ca

        # Place CA on the ANM path and translate N and C by the same offset,
        # preserving the linearly blended N/CA/C triangle.
        coords_tau[common_mask, 1, :] = ca_interp
        coords_tau[common_mask, 0, :] += ca_shift
        coords_tau[common_mask, 2, :] += ca_shift

        # Rebuild O from the updated CA and C positions.
        coords_tau = _reconstruct_oxygen(coords_tau)
        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))
    return path_frames
|