File size: 15,239 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
"""
Transition-path interpolation utilities for conformational induction.

Provides:
  - Kabsch-aligned backbone interpolation between two conformational states
  - Gaussian Schrödinger Bridge (DSB) stochastic interpolation
  - Precomputed frame loading (for AlphaFlow / AFsample2)
  - Unified dispatcher: generate_path_frames()
  - Per-residue displacement computation (for allosteric hinge weighting)
  - Monotonically increasing path weight generation

Used by the path-aware training, guidance, and refinement modules.
"""

import os
import logging
import numpy as np

logger = logging.getLogger(__name__)


def _kabsch_align(mobile_ca, ref_ca):
    """
    Compute the rigid-body superposition of one point set onto another.

    Uses the SVD-based Kabsch construction on the centred point sets and
    corrects for an improper rotation (reflection) via the determinant sign.

    Args:
        mobile_ca: [N, 3] array of points to be moved
        ref_ca: [N, 3] array of target points (same ordering as mobile_ca)

    Returns:
        R: [3, 3] rotation matrix
        t_mobile: [3] centroid of mobile_ca
        t_ref: [3] centroid of ref_ca
        Apply as: aligned = (mobile - t_mobile) @ R.T + t_ref
    """
    centroid_mobile = mobile_ca.mean(axis=0)
    centroid_ref = ref_ca.mean(axis=0)

    # Cross-covariance of the centred coordinates
    covariance = (mobile_ca - centroid_mobile).T @ (ref_ca - centroid_ref)
    left, _singular_values, right_t = np.linalg.svd(covariance)

    # A negative determinant means the raw solution contains a reflection;
    # flipping the last axis turns it back into a proper rotation.
    handedness = np.sign(np.linalg.det(right_t.T @ left.T))
    reflection_fix = np.diag(np.array([1.0, 1.0, handedness]))
    rotation = right_t.T @ reflection_fix @ left.T

    return rotation, centroid_mobile, centroid_ref


def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
    """
    Build linearly interpolated backbone frames between states X0 and X1.

    Steps:
      1. Truncate both states to the shorter length and keep residues that
         are valid in both masks.
      2. Superpose X0 onto X1 with a Kabsch fit on the commonly valid CA atoms
         (the transform is applied to every X0 backbone atom).
      3. Blend the aligned X0 and X1 at n_frames equally spaced tau values,
         endpoints excluded.
      4. Overwrite the O atom of every residue with a point 1.24 A beyond C
         along the CA->C direction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool
        n_frames: number of intermediate frames (endpoints not included)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples
            coords_tau: [N_common, 4, 3] float32 interpolated coords
            mask_tau: [N_common] bool, residues valid in both states
            tau: float strictly between 0 and 1
        An empty list is returned when fewer than 5 residues are valid in
        both states.
    """
    n_common = min(len(coords_x0), len(coords_x1))
    start = coords_x0[:n_common].copy()
    goal = coords_x1[:n_common].copy()

    # Residues usable in both conformations
    common_mask = mask_x0[:n_common] & mask_x1[:n_common]
    if common_mask.sum() < 5:
        return []

    # Fit on shared CA atoms (atom index 1), then move the whole X0 backbone
    R, t_mobile, t_ref = _kabsch_align(start[common_mask, 1, :],
                                       goal[common_mask, 1, :])
    moved = (start.reshape(-1, 3) - t_mobile) @ R.T + t_ref
    start_aligned = moved.reshape(n_common, 4, 3)

    path_frames = []
    # linspace yields n_frames + 2 points; dropping first/last removes tau=0, 1
    for tau in np.linspace(0, 1, n_frames + 2)[1:-1]:
        blended = (1.0 - tau) * start_aligned + tau * goal

        # Rebuild O: step 1.24 A from C along the normalised CA->C vector
        carbonyl_c = blended[:, 2, :]
        ca_to_c = carbonyl_c - blended[:, 1, :]
        length = np.linalg.norm(ca_to_c, axis=-1, keepdims=True)
        length = np.maximum(length, 1e-8)  # guard against zero-length vectors
        blended[:, 3, :] = carbonyl_c + (ca_to_c / length) * 1.24

        frame = (blended.astype(np.float32), common_mask.copy(), float(tau))
        path_frames.append(frame)

    return path_frames


def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
    """
    Per-residue CA displacement between X0 and X1 after Kabsch alignment.

    Both states are truncated to the shorter length. X0 is superposed onto
    X1 using the CA atoms of residues valid in both masks, and the Euclidean
    distance between matching CA atoms is reported for every residue.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool

    Returns:
        displacements: [N_common] float32 array of per-residue CA distances;
            residues not valid in both states are set to 0. All zeros when
            fewer than 5 residues are valid in both states.
        common_mask: [N_common] bool — which residues are valid in both states
    """
    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common]
    c1 = coords_x1[:n_common]
    common_mask = mask_x0[:n_common] & mask_x1[:n_common]

    if common_mask.sum() < 5:
        # Too few points for a meaningful fit; keep the dtype identical to
        # the normal return path so callers always get float32.
        return np.zeros(n_common, dtype=np.float32), common_mask

    # Fit the superposition on commonly valid CA atoms (atom index 1)
    ca0 = c0[common_mask, 1, :]
    ca1 = c1[common_mask, 1, :]
    R, t_mobile, t_ref = _kabsch_align(ca0, ca1)

    # Apply the transform to every CA of X0, valid or not
    aligned_ca0 = (c0[:, 1, :] - t_mobile) @ R.T + t_ref

    # Distance between matching CA atoms
    displacements = np.linalg.norm(aligned_ca0 - c1[:, 1, :], axis=-1)

    # Residues missing in either state carry no information
    displacements[~common_mask] = 0.0

    return displacements.astype(np.float32), common_mask


def generate_path_weights(n_frames, mode='linear'):
    """
    Build normalized per-frame weights for a transition path.

    The tau grid matches the one used for frame interpolation (n_frames
    equally spaced values strictly between 0 and 1). Every schedule except
    'uniform' grows with tau, so frames nearer the goal state X1 get a
    larger share of the total weight.

    Args:
        n_frames: number of intermediate frames
        mode: weight schedule
            'linear': w_tau = tau
            'quadratic': w_tau = tau^2
            'exponential': w_tau = (exp(tau) - 1) / (e - 1)
            'uniform': equal weight for every frame

    Returns:
        weights: [n_frames] float32 numpy array, normalized to sum to 1
            (empty when n_frames is 0)

    Raises:
        ValueError: if mode is not one of the schedules above
    """
    if n_frames == 0:
        return np.array([], dtype=np.float32)

    taus = np.linspace(0, 1, n_frames + 2)[1:-1]  # endpoints excluded

    # Lazily evaluated so only the requested schedule is computed
    schedules = {
        'linear': lambda: taus.copy(),
        'quadratic': lambda: taus ** 2,
        'exponential': lambda: (np.exp(taus) - 1.0) / (np.e - 1.0),
        'uniform': lambda: np.ones(n_frames, dtype=np.float32),
    }
    try:
        build = schedules[mode]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown weight mode: {mode}") from None

    raw = build()

    # Rescale so the weights sum to one (skipped if the sum is not positive)
    norm = raw.sum()
    if norm > 0:
        raw = raw / norm

    return raw.astype(np.float32)


# ---------------------------------------------------------------------------
# Gaussian Schrödinger Bridge (AlignDSB) interpolation
# ---------------------------------------------------------------------------

def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
                       n_frames=5, sigma=0.5, n_samples=20, seed=42):
    """
    Gaussian Schrödinger Bridge with t*(1-t) variance schedule.

    Analytic formula (no neural network):
        X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z

    The noise scale sqrt(t * (1-t)) * sigma is largest at t=0.5 and goes to
    zero at the endpoints. sigma controls the noise amplitude.

    For each tau, n_samples noisy interpolations are drawn and the sample
    whose CA RMSD to the sample mean is the median is kept. The O atom of the
    selected sample is then overwritten with a point 1.24 A beyond C along the
    CA->C direction.

    Args:
        coords_x0: [N0, 4, 3] backbone coords for state 0
        coords_x1: [N1, 4, 3] backbone coords for state 1
        mask_x0: [N0] bool
        mask_x1: [N1] bool
        n_frames: number of intermediate frames
        sigma: noise amplitude (Angstroms)
        n_samples: number of samples per frame for median selection (>= 1)
        seed: random seed

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples; empty when
            fewer than 5 residues are valid in both states

    Raises:
        ValueError: if n_samples is smaller than 1
    """
    if n_samples < 1:
        # Without samples there is no median to select; fail with a clear
        # message instead of an IndexError deep inside the loop.
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    rng = np.random.RandomState(seed)

    n_common = min(len(coords_x0), len(coords_x1))
    c0 = coords_x0[:n_common].copy()
    c1 = coords_x1[:n_common].copy()

    common_mask = mask_x0[:n_common] & mask_x1[:n_common]
    n_valid = common_mask.sum()
    if n_valid < 5:
        return []

    # Kabsch-align X0 onto X1 using the commonly valid CA atoms
    R, t_mobile, t_ref = _kabsch_align(c0[common_mask, 1, :],
                                       c1[common_mask, 1, :])
    c0_aligned = ((c0.reshape(-1, 3) - t_mobile) @ R.T + t_ref).reshape(
        n_common, 4, 3)

    taus = np.linspace(0, 1, n_frames + 2)[1:-1]
    path_frames = []

    for tau in taus:
        noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
        bridge_mean = (1.0 - tau) * c0_aligned + tau * c1

        # One draw for all samples; RandomState fills the array sequentially,
        # so this consumes the same stream as n_samples separate draws.
        Z = rng.randn(n_samples, n_common, 4, 3)
        samples = bridge_mean + noise_scale * Z      # [n_samples, N, 4, 3]
        mean_sample = samples.mean(axis=0)           # [N, 4, 3]

        # Rank samples by CA RMSD to the mean and keep the median one
        ca_dev = samples[:, common_mask, 1, :] - mean_sample[common_mask, 1, :]
        rmsds = [np.sqrt((dev ** 2).sum() / n_valid) for dev in ca_dev]
        median_idx = np.argsort(rmsds)[len(rmsds) // 2]
        coords_tau = samples[median_idx]

        # Rebuild O: step 1.24 A from C along the normalised CA->C vector
        C_pos = coords_tau[:, 2, :]
        C_CA = C_pos - coords_tau[:, 1, :]
        C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
        C_CA_norm = np.maximum(C_CA_norm, 1e-8)  # avoid division by zero
        coords_tau[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24

        path_frames.append((
            coords_tau.astype(np.float32),
            common_mask.copy(),
            float(tau),
        ))

    return path_frames


# ---------------------------------------------------------------------------
# Precomputed frame loading (for AlphaFlow / AFsample2)
# ---------------------------------------------------------------------------

def load_precomputed_frames(target, method, precomputed_dir,
                             coords_x0, coords_x1, mask_x0, mask_x1,
                             n_frames=5):
    """
    Load pre-generated frames from .npz and Kabsch-align to this complex's
    receptor coordinate frame.

    Expected file: {precomputed_dir}/{target}/{method}/frames.npz
    with keys: 'frames' [K, N_ref, 4, 3], 'taus' [K],
               'mask' [N_ref] (converted to bool on load)

    Falls back to linear interpolation (with a warning) when the file is
    missing or when fewer than 5 residues are valid in X0, X1 and the
    precomputed mask simultaneously.

    Args:
        target: target name (e.g. 'cam')
        method: method name ('alphaflow' or 'afsample2')
        precomputed_dir: root directory for precomputed frames
        coords_x0, coords_x1: apo/holo backbone coords for alignment
        mask_x0, mask_x1: residue masks
        n_frames: maximum number of frames to return; when more frames are
            stored, n_frames of them are picked at evenly spaced indices

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples, where each
            coords_tau is float32 and superposed onto X1, and tau is the
            value stored alongside the frame
    """
    npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')
    if not os.path.exists(npz_path):
        logger.warning(f"Precomputed frames not found: {npz_path}, "
                       f"falling back to linear interpolation")
        return interpolate_backbone_path(coords_x0, coords_x1,
                                          mask_x0, mask_x1, n_frames)

    # np.load on an .npz returns an archive object holding an open file;
    # read the arrays inside a context manager so the handle is released.
    with np.load(npz_path) as data:
        pre_frames = data['frames']   # [K, N_ref, 4, 3]
        pre_taus = data['taus']       # [K]
        # Force bool so that a mask stored as 0/1 integers is still used for
        # boolean masking rather than integer fancy indexing.
        pre_mask = np.asarray(data['mask'], dtype=bool)  # [N_ref]

    n_common = min(len(coords_x0), len(coords_x1), len(pre_mask))
    common_mask = mask_x0[:n_common] & mask_x1[:n_common] & pre_mask[:n_common]

    if common_mask.sum() < 5:
        logger.warning(f"Too few common residues for {target}/{method}, "
                       f"falling back to linear")
        return interpolate_backbone_path(coords_x0, coords_x1,
                                          mask_x0, mask_x1, n_frames)

    # Align precomputed frames to the holo receptor (X1) coordinate frame.
    # The precomputed frames may be in a different coordinate frame.
    ref_ca = coords_x1[:n_common][common_mask, 1, :]  # holo CA as reference

    # Select n_frames evenly spaced from the available frames, or all of
    # them when no more than n_frames are stored.
    n_available = len(pre_frames)
    if n_available > n_frames:
        indices = np.linspace(0, n_available - 1, n_frames).astype(int)
    else:
        indices = np.arange(n_available)

    path_frames = []
    for idx in indices:
        frame = pre_frames[idx, :n_common].copy()  # [N_common, 4, 3]
        tau = float(pre_taus[idx])

        # Kabsch-align frame CA to holo CA, then move the whole backbone
        R, t_frame, t_ref = _kabsch_align(frame[common_mask, 1, :], ref_ca)
        aligned = (frame.reshape(-1, 3) - t_frame) @ R.T + t_ref
        frame_aligned = aligned.reshape(n_common, 4, 3)

        # Rebuild O: step 1.24 A from C along the normalised CA->C vector
        C_pos = frame_aligned[:, 2, :]
        C_CA = C_pos - frame_aligned[:, 1, :]
        C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
        C_CA_norm = np.maximum(C_CA_norm, 1e-8)  # avoid division by zero
        frame_aligned[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24

        path_frames.append((
            frame_aligned.astype(np.float32),
            common_mask.copy(),
            tau,
        ))

    return path_frames


# ---------------------------------------------------------------------------
# Unified dispatcher
# ---------------------------------------------------------------------------

def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
                          method='linear', n_frames=5,
                          precomputed_dir=None, target=None, **kwargs):
    """
    Route a frame-generation request to the backend selected by `method`.

    Args:
        coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
        mask_x0, mask_x1: [N] bool masks
        method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
        n_frames: number of intermediate frames
        precomputed_dir: directory for precomputed frames (alphaflow/afsample2)
        target: target name (needed for precomputed methods)
        **kwargs: method-specific parameters; the ones read here are
            'sigma', 'n_samples', 'seed' (dsb) and 'n_modes', 'cutoff' (anm)

    Returns:
        path_frames: list of (coords_tau, mask_tau, tau) tuples

    Raises:
        ValueError: if method is unknown, or if a precomputed method is
            requested without precomputed_dir or target
    """
    endpoints = (coords_x0, coords_x1, mask_x0, mask_x1)

    if method == 'linear':
        return interpolate_backbone_path(*endpoints, n_frames)

    if method in ('alphaflow', 'afsample2'):
        # Both precomputed backends need a location and a target name
        if precomputed_dir is None:
            raise ValueError(f"precomputed_dir required for method '{method}'")
        if target is None:
            raise ValueError(f"target name required for method '{method}'")
        return load_precomputed_frames(
            target, method, precomputed_dir, *endpoints, n_frames)

    if method == 'dsb':
        sigma = kwargs.get('sigma', 0.5)
        n_samples = kwargs.get('n_samples', 20)
        seed = kwargs.get('seed', 42)
        return dsb_backbone_path(
            *endpoints,
            n_frames=n_frames, sigma=sigma, n_samples=n_samples, seed=seed)

    if method == 'anm':
        # Imported on demand so the other methods work without this module
        from utils.anm import anm_backbone_path
        n_modes = kwargs.get('n_modes', 10)
        cutoff = kwargs.get('cutoff', 15.0)
        return anm_backbone_path(
            *endpoints,
            n_frames=n_frames, n_modes=n_modes, cutoff=cutoff)

    raise ValueError(f"Unknown path method: '{method}'. "
                     f"Choose from: linear, alphaflow, afsample2, dsb, anm")