File size: 7,190 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
Anisotropic Network Model (ANM) for conformational path interpolation.

From-scratch implementation using scipy eigendecomposition.
Projects the apo→holo displacement onto low-frequency normal modes
to create physically motivated interpolation paths.
"""

import numpy as np
from scipy.linalg import eigh


def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
    """
    Build the elastic-network Hessian and return its lowest non-trivial modes.

    Every pair of nodes whose separation lies in [1e-6, cutoff] is joined by
    a spring of unit force constant. The resulting 3N x 3N Hessian is
    diagonalised and the six lowest eigenpairs are discarded as rigid-body
    motions (3 translations + 3 rotations).

    Args:
        ca_coords: [N, 3] CA atom coordinates (any array-like; converted to
            float64 internally so float32 input does not degrade the Hessian)
        cutoff: distance cutoff for spring connections (Angstroms)
        n_modes: number of non-trivial modes to return (must be >= 0)

    Returns:
        eigenvalues: [n_modes] array of eigenvalues in ascending order,
            zero-padded when fewer than n_modes non-trivial modes exist
        eigenvectors: [n_modes, N, 3] mode displacement vectors, zero-padded
            in the same way. All zeros when N < 4.

    Raises:
        ValueError: if n_modes is negative, or (from scipy) if the
            coordinates contain NaN/inf.

    NOTE: exactly six modes are always skipped. If the contact network is
    disconnected or under-constrained the Hessian has more than six zero
    modes and the extra ones are returned as if they were internal motions.
    """
    if n_modes < 0:
        raise ValueError("n_modes must be non-negative")

    coords = np.asarray(ca_coords, dtype=np.float64)
    N = len(coords)
    if N < 4:
        return np.zeros(n_modes), np.zeros((n_modes, N, 3))

    # Pairwise separation vectors and lengths: diff[i, j] = r_j - r_i
    diff = coords[np.newaxis, :, :] - coords[:, np.newaxis, :]
    dist = np.linalg.norm(diff, axis=-1)

    # A pair is skipped when it is out of range or (near-)coincident. The
    # test is phrased as "not skipped" so NaN distances are kept and surface
    # as an error in eigh rather than being silently dropped.
    skipped = (dist > cutoff) | (dist < 1e-6)
    i_idx, j_idx = np.nonzero(np.triu(~skipped, k=1))  # each pair once, i < j

    # Per-contact 3x3 super-element: outer product of the unit bond vector
    unit = diff[i_idx, j_idx] / dist[i_idx, j_idx][:, np.newaxis]   # [P, 3]
    blocks = unit[:, :, np.newaxis] * unit[:, np.newaxis, :]        # [P, 3, 3]

    # Assemble as [N, 3, N, 3] so H4[i, a, j, b] == H[3i + a, 3j + b]
    H4 = np.zeros((N, 3, N, 3), dtype=np.float64)
    # Off-diagonal blocks: -gamma * (u ⊗ u) with gamma = 1 (block is symmetric)
    H4[i_idx, :, j_idx, :] = -blocks
    H4[j_idx, :, i_idx, :] = -blocks

    # Diagonal blocks: sum of the super-elements of every contact of the node.
    # np.add.at is required because node indices repeat.
    diag = np.zeros((N, 3, 3), dtype=np.float64)
    np.add.at(diag, i_idx, blocks)
    np.add.at(diag, j_idx, blocks)
    nodes = np.arange(N)
    H4[nodes, :, nodes, :] = diag

    H = H4.reshape(3 * N, 3 * N)

    # Only the lowest (6 trivial + requested) eigenpairs are needed
    n_total = min(6 + n_modes, 3 * N)
    eigenvalues, eigvecs = eigh(H, subset_by_index=[0, n_total - 1])

    # With N >= 4 there are always at least 6 eigenpairs to discard
    n_return = n_total - 6

    # Allocate at the requested size so missing modes stay zero-padded
    evals = np.zeros(n_modes)
    evals[:n_return] = eigenvalues[6:6 + n_return]

    mode_vectors = np.zeros((n_modes, N, 3))
    # Columns of eigvecs are modes; each 3N-vector is laid out as N rows of xyz
    mode_vectors[:n_return] = eigvecs[:, 6:6 + n_return].T.reshape(n_return, N, 3)

    return evals, mode_vectors


def _kabsch_align(mobile_ca, ref_ca):
    """
    Compute the Kabsch superposition of ``mobile_ca`` onto ``ref_ca``.

    Args:
        mobile_ca: [M, 3] points to be moved
        ref_ca: [M, 3] target points, paired row-by-row with mobile_ca

    Returns:
        (R, t_mobile, t_ref): 3x3 matrix and the two centroids. A point set
        ``x`` is superposed with ``(x - t_mobile) @ R.T + t_ref``.
    """
    mobile_centroid = mobile_ca.mean(axis=0)
    ref_centroid = ref_ca.mean(axis=0)

    # 3x3 cross-covariance of the two centred point clouds
    covariance = (mobile_ca - mobile_centroid).T @ (ref_ca - ref_centroid)
    left, _, right_t = np.linalg.svd(covariance)

    # Negate the last singular direction when the raw solution would be a
    # reflection, so that the returned matrix is a rotation.
    handedness = np.sign(np.linalg.det(right_t.T @ left.T))
    correction = np.diag(np.array([1.0, 1.0, handedness]))
    rotation = right_t.T @ correction @ left.T

    return rotation, mobile_centroid, ref_centroid


def _reconstruct_oxygen(coords):
    """
    Overwrite the O slot of each residue with an approximate carbonyl oxygen.

    Atom slots along axis 1 are read as (N, CA, C, O). The oxygen is placed
    1.24 units from C along the CA→C direction. The N slot is not used.

    NOTE(review): this puts CA, C and O on one straight line, which is a
    placeholder rather than true carbonyl geometry.

    Args:
        coords: [L, 4, 3] backbone coordinates; modified in place

    Returns:
        The same ``coords`` array, with slot 3 rewritten.
    """
    ca_to_c = coords[:, 2, :] - coords[:, 1, :]
    # Floor the bond length so coincident CA/C atoms do not divide by zero
    bond_len = np.maximum(
        np.linalg.norm(ca_to_c, axis=-1, keepdims=True), 1e-8
    )
    coords[:, 3, :] = coords[:, 2, :] + (ca_to_c / bond_len) * 1.24
    return coords


def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                       n_frames=5, n_modes=10, cutoff=15.0,
                       mode_schedule=None):
    """
    Interpolate a backbone path from X0 toward X1 using an ANM decomposition.

    X0 is superposed onto X1 (Kabsch, on CA atoms valid in both states). The
    CA displacement is then split into the part spanned by the lowest ANM
    modes of X0 and a residual. Each frame applies the mode part with weight
    ``mode_schedule(tau)`` and the residual with weight ``tau``.

    NOTE: with the default schedule both parts get the same weight ``tau``,
    so they recombine into the full displacement and the CA path is a plain
    linear interpolation (the N/C shift below is zero up to rounding). Pass a
    schedule that runs ahead of tau, e.g.
    ``lambda t: np.sin(0.5 * np.pi * t)``, to let the collective mode motion
    precede the residual adjustments.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
        coords_x1: [N1, 4, 3] backbone coords for holo state
        mask_x0: [N0] residue validity mask (coerced to bool)
        mask_x1: [N1] residue validity mask (coerced to bool)
        n_frames: number of intermediate frames (excluding endpoints)
        n_modes: number of ANM modes to use for projection
        cutoff: ANM spring cutoff in Angstroms
        mode_schedule: optional callable mapping tau in (0, 1) to the weight
            of the mode-projected displacement. It should rise from 0 to 1
            so the path still connects the two endpoints. ``None`` uses tau.

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, where
            coords_tau is float32 [n_common, 4, 3], mask_tau is the bool mask
            of residues valid in both states, and tau is a float. Only the
            first min(N0, N1) residues are used, paired by index. Returns an
            empty list when fewer than 5 residues are valid in both states.
            The tuple layout is intended to match interpolate_backbone_path
            (defined elsewhere; not verified here).
    """
    n_common = min(len(coords_x0), len(coords_x1))
    # No defensive copies needed: the inputs are only read below
    c0 = np.asarray(coords_x0)[:n_common]
    c1 = np.asarray(coords_x1)[:n_common]

    # Force boolean masks: an integer 0/1 mask would otherwise be treated as
    # row indices by the fancy indexing below.
    common_mask = (np.asarray(mask_x0, dtype=bool)[:n_common]
                   & np.asarray(mask_x1, dtype=bool)[:n_common])
    if common_mask.sum() < 5:
        return []

    # Kabsch-align X0 onto X1 using CA atoms valid in both states
    ca1_valid = c1[common_mask, 1, :]
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :], ca1_valid)

    # Apply the same rigid transform to every X0 backbone atom
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(
        n_common, 4, 3
    )

    # Apo→holo CA displacement over the valid residues
    ca0_aligned = c0_aligned[common_mask, 1, :]  # [N_valid, 3]
    displacement = ca1_valid - ca0_aligned       # [N_valid, 3]

    # ANM modes of the aligned apo structure: [n_modes, N_valid, 3]
    _, mode_vectors = compute_anm_modes(
        ca0_aligned, cutoff=cutoff, n_modes=n_modes
    )

    # Coefficient of the displacement along each mode, then the part of the
    # displacement those modes reproduce. Zero-padded modes contribute nothing.
    projections = np.einsum('kia,ia->k', mode_vectors, displacement)
    mode_displacement = np.einsum('k,kia->ia', projections, mode_vectors)

    # Whatever the selected modes cannot represent
    residual = displacement - mode_displacement

    # Interior interpolation parameters only; the endpoints are excluded
    taus = np.linspace(0, 1, n_frames + 2)[1:-1]
    path_frames = []

    for tau in taus:
        if mode_schedule is None:
            mode_weight = tau
        else:
            mode_weight = float(mode_schedule(float(tau)))

        ca_interp = (ca0_aligned
                     + mode_weight * mode_displacement
                     + tau * residual)

        # Start from a linear blend of all four atom slots
        coords_tau = (1.0 - tau) * c0_aligned + tau * c1

        # How far the ANM-scheduled CA deviates from the linear CA. N and C
        # are translated by the same amount so each residue moves with its CA.
        ca_shift = ca_interp - ((1.0 - tau) * ca0_aligned + tau * ca1_valid)
        coords_tau[common_mask, 1, :] = ca_interp
        coords_tau[common_mask, 0, :] += ca_shift  # N atoms
        coords_tau[common_mask, 2, :] += ca_shift  # C atoms

        # O is regenerated from the updated CA and C positions
        coords_tau = _reconstruct_oxygen(coords_tau)

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames