File size: 9,621 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
SE(3)-invariant feature extraction for interface graphs.
Node and edge features used by the Q_theta scorer.
"""

import os
import sys
import numpy as np

# Ensure utils is importable (for both direct and package imports)
_CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _CODE_DIR not in sys.path:
    sys.path.insert(0, _CODE_DIR)

from utils.pdb_utils import (
    rbf_encode, compute_backbone_frames, compute_torsion_angles,
    get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA
)

# Feature dimensions
# Node channels, in concatenation order:
#   one-hot AA (NUM_AA) + backbone torsion sin/cos (6) + chi1 sin/cos (2) + chi2 sin/cos (2) + chain indicator (1)
NODE_DIM = NUM_AA + 6 + 4 + 1  # = 32 when NUM_AA == 21
# Edge channels, in concatenation order:
#   RBF distance (16) + direction (3) + relative rotation (9) + sequence separation (8) + same-chain flag (1)
EDGE_DIM = 16 + 3 + 9 + 8 + 1  # = 37

MAX_SEQ_SEP = 32  # separations are clipped to [-MAX_SEQ_SEP, +MAX_SEQ_SEP] before binning


def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
    """
    One-hot encode a signed sequence separation.

    The interval [-max_sep, max_sep] is split into n_bins equal-width bins.
    sep is clipped into that interval first, so out-of-range separations
    fall into the first or the last bin.

    Args:
        sep: signed sequence separation
        n_bins: number of bins in the encoding
        max_sep: clipping bound applied to sep

    Returns:
        [n_bins] float32 vector with 1.0 at the selected bin
    """
    one_hot = np.zeros(n_bins, dtype=np.float32)
    edges = np.linspace(-max_sep, max_sep, n_bins + 1)

    bounded = np.minimum(np.maximum(sep, -max_sep), max_sep)
    # digitize returns 1-based positions relative to the edges
    slot = np.digitize(bounded, edges) - 1
    # sep == max_sep sits on the final edge and would otherwise index past the end
    slot = np.minimum(np.maximum(slot, 0), n_bins - 1)

    one_hot[slot] = 1.0
    return one_hot


def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id):
    """
    Assemble the per-residue node feature matrix.

    Each row is the concatenation of a one-hot amino-acid vector, the
    backbone torsion features, the side-chain chi features and a constant
    chain indicator. Residues whose mask entry is false keep an all-zero
    amino-acid one-hot; their other columns are passed through unchanged.

    Args:
        residues: list of Bio.PDB residues
        coords: [N, 4, 3] backbone coords (accepted but not read here)
        mask: [N] bool
        torsion_angles: [N, 6] sin/cos of phi, psi, omega
        chi_angles: [N, 4] sin/cos of chi1, chi2
        chain_id: 0 = receptor, 1 = binder

    Returns:
        node_feats: [N, NODE_DIM]  (NODE_DIM = 32)
    """
    residue_types = get_aa_indices(residues)
    n_res = len(residues)

    # Amino-acid identity, one-hot; only residues flagged present get a 1
    one_hot = np.zeros((n_res, NUM_AA), dtype=np.float32)
    present_rows = [row for row in range(n_res) if mask[row]]
    for row in present_rows:
        one_hot[row, residue_types[row]] = 1.0

    # Same chain label repeated for every residue of this chain
    chain_column = np.full((n_res, 1), chain_id, dtype=np.float32)

    blocks = (
        one_hot,         # [N, 21]
        torsion_angles,  # [N, 6]
        chi_angles,      # [N, 4]
        chain_column,    # [N, 1]
    )
    return np.concatenate(blocks, axis=-1)  # [N, 32]


def extract_edge_features(coords_i, frames_i, coords_j, frames_j,
                          seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j):
    """
    Compute SE(3)-invariant edge features between residue sets i and j.
    Vectorized over all pairs.

    Channel order of the output (EDGE_DIM = 37):
        RBF distance (16) | local direction (3) | relative rotation (9) |
        sequence-separation one-hot (8) | same-chain flag (1)

    Args:
        coords_i: [N_i, 4, 3] backbone coords of set i (only its length is used)
        frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3])
        coords_j: [N_j, 4, 3]
        frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3])
        seq_idx_i: [N_i] integer sequence indices (for sequence separation)
        seq_idx_j: [N_j] integer sequence indices
        chain_i: int chain id for all of set i, or [N_i] per-residue chain ids
        chain_j: int chain id for all of set j, or [N_j] per-residue chain ids
        mask_i: [N_i] bool
        mask_j: [N_j] bool

    Returns:
        edge_feats: [N_i, N_j, EDGE_DIM] float32; rows/columns belonging to
        masked residues are all zero.
    """
    N_i, N_j = len(coords_i), len(coords_j)
    origins_i, rotations_i = frames_i
    origins_j, rotations_j = frames_j

    # Coerce to bool so that `~mask` below is a logical (not bitwise) negation
    mask_i = np.asarray(mask_i, dtype=bool)
    mask_j = np.asarray(mask_j, dtype=bool)

    # --- Distance features ---
    diff = origins_j[None, :, :] - origins_i[:, None, :]  # [N_i, N_j, 3]
    dist = np.sqrt((diff ** 2).sum(axis=-1))               # [N_i, N_j]
    dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16)  # [N_i, N_j, 16]

    # NOTE(review): the two einsums below assume the COLUMNS of rotations[n]
    # are residue n's local axes (local -> global), the convention under which
    # the formulas R_i^T d and R_i^T R_j are invariant to a global rotation.
    # Confirm against compute_backbone_frames.

    # --- Direction in local frame of i ---
    # Unit vector from i to j in the global frame; epsilon guards the i == j pairs
    unit_diff = diff / (dist[..., None] + 1e-8)  # [N_i, N_j, 3]
    # local_dir[i, j] = R_i^T @ unit_diff[i, j]; summing over the FIRST matrix
    # index of rotations_i is what applies the transpose.
    local_dir = np.einsum('ilk,ijl->ijk', rotations_i, unit_diff)  # [N_i, N_j, 3]

    # --- Relative rotation: R_i^T R_j ---
    # Contract the global axis of both matrices. A plain matmul R_i @ R_j
    # changes when the whole complex is rotated, so it is not an invariant.
    rel_rot = np.einsum('ilk,jlm->ijkm', rotations_i, rotations_j)  # [N_i, N_j, 3, 3]
    rel_rot_flat = rel_rot.reshape(N_i, N_j, 9)                      # [N_i, N_j, 9]

    # --- Sequence separation ---
    # Same binning as seq_sep_encode with its defaults (8 bins over
    # [-MAX_SEQ_SEP, MAX_SEQ_SEP]), done for all pairs in one shot.
    n_sep_bins = 8
    sep = np.asarray(seq_idx_j)[None, :] - np.asarray(seq_idx_i)[:, None]  # [N_i, N_j]
    bin_edges = np.linspace(-MAX_SEQ_SEP, MAX_SEQ_SEP, n_sep_bins + 1)
    bin_idx = np.digitize(np.clip(sep, -MAX_SEQ_SEP, MAX_SEQ_SEP), bin_edges) - 1
    bin_idx = np.clip(bin_idx, 0, n_sep_bins - 1)
    # Row lookup in an identity matrix yields the one-hot vectors (fresh, writable array)
    sep_enc = np.eye(n_sep_bins, dtype=np.float32)[bin_idx]  # [N_i, N_j, 8]

    # --- Same chain indicator ---
    # Scalars broadcast to every pair; per-residue arrays compare element-wise.
    same_chain = np.broadcast_to(
        np.atleast_1d(chain_i)[:, None] == np.atleast_1d(chain_j)[None, :],
        (N_i, N_j),
    )
    # Sequence separation is meaningless across chains: all-zero by convention
    sep_enc[~same_chain] = 0.0
    same_chain_feat = same_chain[..., None].astype(np.float32)  # [N_i, N_j, 1]

    # --- Concatenate ---
    edge_feats = np.concatenate([
        dist_rbf,        # [N_i, N_j, 16]
        local_dir,       # [N_i, N_j, 3]
        rel_rot_flat,    # [N_i, N_j, 9]
        sep_enc,         # [N_i, N_j, 8]
        same_chain_feat  # [N_i, N_j, 1]
    ], axis=-1)          # [N_i, N_j, 37]

    # Zero out edges involving masked residues
    edge_feats[~mask_i, :, :] = 0.0
    edge_feats[:, ~mask_j, :] = 0.0

    return edge_feats.astype(np.float32)


def build_interface_graph(rec_residues, rec_coords, rec_mask,
                          binder_residues, binder_coords, binder_mask,
                          rec_interface_mask, binder_interface_mask,
                          max_nodes: int = 128):
    """
    Build a joint interface graph combining receptor and binder interface residues.

    Receptor nodes come first, binder nodes second. Each chain contributes at
    most max_nodes // 2 residues; when a chain has more interface residues
    than that, the ones with the lowest residue indices are kept.

    Args:
        rec_residues / binder_residues: per-chain lists of residues
        rec_coords / binder_coords: [L, 4, 3] backbone coords of the full chain
        rec_mask / binder_mask: [L] bool validity mask of the full chain
        rec_interface_mask / binder_interface_mask: [L] bool, True for interface residues
        max_nodes: upper bound on the total number of graph nodes

    Returns:
        None if neither chain has an interface residue, otherwise a dict with:
        node_feats: [N_total, NODE_DIM]
        edge_feats: [N_total, N_total, EDGE_DIM]
        node_mask: [N_total] bool
        n_rec: int (number of receptor interface nodes)
        n_binder: int (number of binder interface nodes)
        rec_iface_idx: [n_rec] original receptor residue indices
        binder_iface_idx: [n_binder] original binder residue indices
    """
    # Select interface residues, capped at half the node budget per chain
    per_chain_cap = max_nodes // 2
    rec_iface_idx = np.where(rec_interface_mask)[0][:per_chain_cap]
    binder_iface_idx = np.where(binder_interface_mask)[0][:per_chain_cap]

    n_rec = len(rec_iface_idx)
    n_binder = len(binder_iface_idx)
    n_total = n_rec + n_binder

    if n_total == 0:
        return None

    # Extract coords for interface residues
    rec_iface_coords = rec_coords[rec_iface_idx]            # [n_rec, 4, 3]
    binder_iface_coords = binder_coords[binder_iface_idx]   # [n_binder, 4, 3]
    rec_iface_mask = rec_mask[rec_iface_idx]
    binder_iface_mask = binder_mask[binder_iface_idx]

    # Compute backbone frames
    rec_origins, rec_rotations = compute_backbone_frames(rec_iface_coords, rec_iface_mask)
    binder_origins, binder_rotations = compute_backbone_frames(binder_iface_coords, binder_iface_mask)

    # Compute torsion angles on the FULL chain and then pick the interface rows.
    # Backbone torsions involve the sequence neighbours of a residue; running
    # on the interface subset would pair each residue with the previous/next
    # *interface* residue, which is generally not its sequence neighbour.
    # NOTE(review): assumes compute_torsion_angles returns one row per input residue.
    rec_torsion = np.asarray(compute_torsion_angles(rec_coords, rec_mask))[rec_iface_idx]
    binder_torsion = np.asarray(compute_torsion_angles(binder_coords, binder_mask))[binder_iface_idx]

    # Extract residues
    rec_iface_residues = [rec_residues[i] for i in rec_iface_idx]
    binder_iface_residues = [binder_residues[i] for i in binder_iface_idx]

    # Compute sidechain chi1/chi2 angles
    rec_chi = compute_chi_angles(rec_iface_residues, rec_iface_mask)
    binder_chi = compute_chi_angles(binder_iface_residues, binder_iface_mask)

    # Node features
    rec_node_feats = extract_node_features(
        rec_iface_residues, rec_iface_coords, rec_iface_mask, rec_torsion, rec_chi, chain_id=0
    )  # [n_rec, NODE_DIM]
    binder_node_feats = extract_node_features(
        binder_iface_residues, binder_iface_coords, binder_iface_mask, binder_torsion, binder_chi, chain_id=1
    )  # [n_binder, NODE_DIM]

    node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0)  # [N, NODE_DIM]
    node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0)

    # Edge features over the joint node set (covers RR, RB, BR, BB blocks)
    all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0)
    all_origins = np.concatenate([rec_origins, binder_origins], axis=0)
    all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0)
    # Binder indices are shifted past the receptor so the joint index is unique
    all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0)
    all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)

    # Compute full NxN edge features
    frames_all = (all_origins, all_rotations)
    edge_feats = extract_edge_features(
        all_coords, frames_all,
        all_coords, frames_all,
        all_seq_idx, all_seq_idx,
        -1, -1,  # placeholder chain ids; per-pair chain handling is patched below
        node_mask, node_mask
    )  # [N, N, EDGE_DIM]

    same_chain = all_chain[:, None] == all_chain[None, :]        # [N, N] bool
    valid = np.asarray(node_mask, dtype=bool)
    pair_valid = valid[:, None] & valid[None, :]                 # [N, N] bool

    # Edge layout ends with: ... | seq-sep one-hot (8) | same-chain flag (1).
    # Cross-chain separations come from the artificial index shift above and
    # carry no meaning, so those pairs get an all-zero seq-sep encoding.
    n_sep_bins = 8
    edge_feats[~same_chain, -(n_sep_bins + 1):-1] = 0.0

    # Patch same_chain feature (last dim) using actual chain IDs, keeping
    # edges that touch a masked residue at zero.
    edge_feats[:, :, -1] = (same_chain & pair_valid).astype(np.float32)

    return {
        'node_feats': node_feats.astype(np.float32),    # [N, NODE_DIM]
        'edge_feats': edge_feats.astype(np.float32),    # [N, N, EDGE_DIM]
        'node_mask': node_mask,                          # [N]
        'n_rec': n_rec,
        'n_binder': n_binder,
        'rec_iface_idx': rec_iface_idx,                  # [n_rec] original residue indices
        'binder_iface_idx': binder_iface_idx,            # [n_binder] original residue indices
    }