| """ |
| PyTorch Dataset for two-state complex scoring. |
| |
| Loads preprocessed graph data and provides batched tensors |
| with padding for variable-sized interface graphs. |
| """ |
|
|
| import os |
| import json |
| import pickle |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
|
|
| |
# Module-level cache of loaded ESM embedding files, keyed by the .pt file path
# and holding whatever torch.load returned for that path. Filled by
# preload_esm_cache and lazily by load_esm_for_sample.
_ESM_CACHE = {}
|
|
|
|
def preload_esm_cache(esm_dir, targets):
    """Eagerly load every ``<esm_dir>/<target>/*.pt`` file into ``_ESM_CACHE``.

    Intended to be called before DataLoader workers are started, so that
    forked workers inherit the already-populated cache instead of each
    re-reading the files (this relies on the fork start method).

    Args:
        esm_dir: root directory holding one sub-directory per target.
        targets: iterable of target names; targets without a directory
            are skipped silently.

    Returns:
        Number of files newly loaded by this call (paths already present
        in the cache are not reloaded and not counted).
    """
    import glob as glob_mod

    loaded = 0
    candidate_dirs = (os.path.join(esm_dir, name) for name in targets)
    for target_dir in (d for d in candidate_dirs if os.path.isdir(d)):
        pattern = os.path.join(target_dir, '*.pt')
        for pt_file in glob_mod.glob(pattern):
            if pt_file in _ESM_CACHE:
                continue
            _ESM_CACHE[pt_file] = torch.load(pt_file, map_location='cpu', weights_only=True)
            loaded += 1
    return loaded
|
|
|
|
def load_esm_for_sample(sample, esm_dir, target_name, max_nodes=128):
    """Load and index per-residue ESM embeddings for a sample's interface.

    Embeddings are read from ``<esm_dir>/<target_name>/<base_pdb>_<chain>.pt``
    (one file per chain) and memoised in ``_ESM_CACHE``. ``base_pdb`` is the
    part of ``sample['pdb']`` before the first underscore.

    Rows of the result are receptor interface residues first, then binder
    interface residues, zero-padded up to ``max_nodes``. If there are more
    interface residues than ``max_nodes`` the surplus rows are dropped, which
    matches how the paired / multi-target datasets truncate their graphs.

    Args:
        sample: dict with a 'graph' dict holding 'rec_iface_idx' and
            'binder_iface_idx', and optionally 'pdb', 'rec_chain_id'
            (default 'A') and 'binder_chain_id' (default 'B').
        esm_dir: root directory of the embedding files.
        target_name: sub-directory of ``esm_dir`` for this target.
        max_nodes: number of rows in the returned array.

    Returns:
        float32 array of shape [max_nodes, esm_dim], or None when the index
        lists or an embedding file are missing, or when a needed embedding
        has no rows.
    """
    graph = sample['graph']
    rec_idx = graph.get('rec_iface_idx')
    binder_idx = graph.get('binder_iface_idx')
    if rec_idx is None or binder_idx is None:
        return None

    # split() returns the whole string when there is no underscore.
    base_pdb = sample.get('pdb', '').split('_')[0]
    rec_chain = sample.get('rec_chain_id', 'A')
    binder_chain = sample.get('binder_chain_id', 'B')

    rec_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{rec_chain}.pt')
    binder_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{binder_chain}.pt')

    def _load_cached(path):
        # Only successful loads are cached; a missing file is re-checked next time.
        if path not in _ESM_CACHE:
            if not os.path.exists(path):
                return None
            _ESM_CACHE[path] = torch.load(path, map_location='cpu', weights_only=True)
        return _ESM_CACHE[path]

    rec_esm = _load_cached(rec_path)
    binder_esm = _load_cached(binder_path)
    if rec_esm is None or binder_esm is None:
        return None

    esm_dim = rec_esm.shape[-1]
    esm_feats = np.zeros((max_nodes, esm_dim), dtype=np.float32)

    row = 0
    for esm, idx in ((rec_esm, rec_idx), (binder_esm, binder_idx)):
        # Force an integer dtype: an empty list would otherwise become a
        # float64 array, which cannot be used to index a tensor.
        idx = np.asarray(idx, dtype=np.int64).reshape(-1)
        take = min(len(idx), max_nodes - row)
        if take <= 0:
            continue
        if len(esm) == 0:
            # Nothing to index into; treat the embedding as unavailable.
            return None
        # Out-of-range residue indices are clamped to the nearest valid row.
        safe = np.clip(idx[:take], 0, len(esm) - 1)
        esm_feats[row:row + take] = esm[torch.from_numpy(safe)].numpy()
        row += take

    return esm_feats
|
|
|
|
def load_rosetta_labels(rosetta_dir, target, tau=15.0, dg_min=-500.0, dg_max=500.0):
    """Load Rosetta dG values for a target and squash them into (0, 1).

    Reads ``<rosetta_dir>/<target>_rosetta.json``, expected to map a PDB id
    to a dict of metrics. Each entry's 'dG_separated' value (0.0 when the key
    is absent) is mapped through ``1 / (1 + exp(dG / tau))``, so more negative
    dG gives a label closer to 1.

    Entries are skipped when the metrics are not a dict, when dG is not a
    real number (e.g. JSON ``null``), is non-finite, or lies outside
    ``[dg_min, dg_max]``.

    Args:
        rosetta_dir: directory containing the Rosetta JSON files.
        target: target name used to build the file name.
        tau: temperature of the sigmoid; must be positive.
        dg_min: lowest dG accepted.
        dg_max: highest dG accepted.

    Returns:
        dict mapping PDB id (as stored, upper-cased and lower-cased) to a
        float label, or None when the file is missing or empty.

    Raises:
        ValueError: if ``tau`` is not positive.
    """
    if tau <= 0:
        raise ValueError(f'tau must be positive, got {tau}')

    path = os.path.join(rosetta_dir, f'{target}_rosetta.json')
    if not os.path.exists(path):
        return None
    with open(path, encoding='utf-8') as f:
        raw = json.load(f)
    if not raw:
        return None

    labels = {}
    for pdb_id, metrics in raw.items():
        if not isinstance(metrics, dict):
            continue
        dG = metrics.get('dG_separated', 0.0)
        # JSON null / strings / booleans would make np.isfinite raise or lie.
        if isinstance(dG, bool) or not isinstance(dG, (int, float)):
            continue
        if not np.isfinite(dG) or dG < dg_min or dG > dg_max:
            continue
        score = float(1.0 / (1.0 + np.exp(dG / tau)))
        # Register case variants so lookups are tolerant of PDB id casing.
        for key in (pdb_id, pdb_id.upper(), pdb_id.lower()):
            labels[key] = score
    return labels
|
|
|
|
def apply_rosetta_labels(samples, rosetta_labels, label_source='rosetta', alpha=0.5):
    """Replace or blend sample labels with Rosetta-derived labels, in place.

    The Rosetta value is looked up by the part of ``sample['pdb']`` before
    the first underscore, trying the id as given, then upper-cased, then
    lower-cased. Per sample type:

    * 'positive'   -> candidate label is the Rosetta value;
    * 'decoy*'     -> candidate label is the current label times the Rosetta
                      value (so calling this twice scales decoys twice);
    * anything else (including 'negative*') -> left untouched.

    Args:
        samples: list of sample dicts with 'type' and 'label' (and usually 'pdb').
        rosetta_labels: dict from PDB id to label, or None to do nothing.
        label_source: 'rosetta' overwrites the label with the candidate;
            'combined' stores ``alpha * old + (1 - alpha) * candidate``;
            any other value changes nothing.
        alpha: weight of the existing label in 'combined' mode.

    Returns:
        Number of samples whose label was rewritten (0 when
        ``rosetta_labels`` is None).
    """
    if rosetta_labels is None:
        return 0

    n_replaced = 0
    for s in samples:
        base_pdb = s.get('pdb', '').split('_')[0]

        # Compare against None explicitly: 0.0 is a legitimate label and must
        # not be mistaken for "not found".
        rosetta_val = None
        for key in (base_pdb, base_pdb.upper(), base_pdb.lower()):
            rosetta_val = rosetta_labels.get(key)
            if rosetta_val is not None:
                break
        if rosetta_val is None:
            continue

        sample_type = s['type']
        if sample_type == 'positive':
            new_label = rosetta_val
        elif sample_type.startswith('decoy'):
            new_label = s['label'] * rosetta_val
        else:
            # Negatives and unknown types keep their existing label.
            continue

        if label_source == 'rosetta':
            s['label'] = float(new_label)
        elif label_source == 'combined':
            s['label'] = float(alpha * s['label'] + (1 - alpha) * new_label)
        else:
            continue
        n_replaced += 1
    return n_replaced
|
|
|
|
class TwoStateComplexDataset(Dataset):
    """
    Dataset of protein complex interface graphs with two-state labels.

    Each item returned by ``__getitem__`` is a dict with:
        node_feats: [max_nodes, node_dim] float32, zero-padded
        edge_feats: [max_nodes, max_nodes, edge_dim] float32, zero-padded
        node_mask:  [max_nodes] bool, False on padded rows
        label:      scalar float32 (DockQ proxy / selectivity label)
        type:       str (positive / negative_apo / decoy_*)
        pdb:        str
        esm_feats:  [max_nodes, esm_dim] float32, only when ``esm_dir`` is set

    Graphs with more than ``max_nodes`` nodes are truncated to the first
    ``max_nodes`` nodes, as the other datasets in this module do.
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 rosetta_labels: dict = None, label_source: str = 'dockq',
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        """
        Args:
            data_path: pickle file holding a list of sample dicts.
            max_nodes: padded graph size.
            augment: add Gaussian noise (std 0.01) to node features.
            rosetta_labels: optional PDB-id -> label dict; applied in place
                to the loaded samples when ``label_source`` is not 'dockq'.
            label_source: 'dockq', 'rosetta', or 'combined'.
            esm_dir: root of ESM embedding files; enables 'esm_feats'.
            target_name: sub-directory of ``esm_dir`` for this dataset.
            binder_dropout: per-item probability of masking binder features.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted, locally produced files.
        with open(data_path, 'rb') as f:
            self.samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout
        if label_source != 'dockq' and rosetta_labels:
            apply_rosetta_labels(self.samples, rosetta_labels, label_source)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']

        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']

        N = len(node_feats)
        # Truncate instead of asserting: an assert disappears under
        # ``python -O`` and would then surface as an opaque broadcast error.
        n = min(N, self.max_nodes)

        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]

        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)

        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        if self.augment:
            # Noise is drawn for the whole padded array; padded rows are
            # still flagged invalid by node_mask.
            noise = np.random.randn(*node_feats_pad.shape) * 0.01
            node_feats_pad = node_feats_pad + noise.astype(np.float32)

        # The random draw only happens when dropout is enabled.
        apply_binder_drop = (self.binder_dropout > 0
                             and np.random.rand() < self.binder_dropout)
        # Receptor nodes come first; without 'n_rec' the graph is assumed to
        # be split in half.
        n_rec = graph.get('n_rec', N // 2)
        if apply_binder_drop:
            # Column layout is hard-coded here: presumably columns 0-20 are a
            # residue-type one-hot with 20 as the "unknown" class, and 27-30
            # are further binder-specific features -- TODO confirm against
            # the featuriser.
            node_feats_pad[n_rec:n, :21] = 0.0
            node_feats_pad[n_rec:n, 20] = 1.0
            node_feats_pad[n_rec:n, 27:31] = 0.0

        result = {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
        }

        if self.esm_dir:
            esm_feats = load_esm_for_sample(sample, self.esm_dir,
                                            self.target_name or '', self.max_nodes)
            if esm_feats is None:
                # Fallback width is hard-coded to 1280.
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            if apply_binder_drop:
                n_binder = graph.get('n_binder', N - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)

        return result
|
|
|
|
def collate_fn(batch):
    """Collate a list of single-graph samples into batched tensors.

    Args:
        batch: list of dicts as produced by the datasets' ``__getitem__``;
            each needs 'node_feats', 'edge_feats', 'node_mask', 'label',
            'type' and 'pdb', and may carry 'esm_feats' and 'target'.

    Returns:
        dict with the tensor fields stacked along a new leading batch
        dimension and 'type' / 'pdb' as plain lists. 'esm_feats' is included
        when at least one sample has it (samples without it contribute
        zeros). 'target' is included as a list when at least one sample has
        it (samples without it contribute 'unknown').
    """
    result = {
        'node_feats': torch.stack([s['node_feats'] for s in batch]),
        'edge_feats': torch.stack([s['edge_feats'] for s in batch]),
        'node_mask': torch.stack([s['node_mask'] for s in batch]),
        'label': torch.stack([s['label'] for s in batch]),
        'type': [s['type'] for s in batch],
        'pdb': [s['pdb'] for s in batch],
    }

    # Keep the per-sample target name that MultiTargetDataset attaches.
    if any('target' in s for s in batch):
        result['target'] = [s.get('target', 'unknown') for s in batch]

    # Look up the zero-fill template once instead of once per missing sample.
    ref = next((s['esm_feats'] for s in batch if 'esm_feats' in s), None)
    if ref is not None:
        result['esm_feats'] = torch.stack([
            s['esm_feats'] if 'esm_feats' in s else torch.zeros_like(ref)
            for s in batch
        ])

    return result
|
|
|
|
class TwoStateDatasetPaired(Dataset):
    """
    Paired dataset: returns (positive, negative) pairs for selectivity training.

    Samples are grouped by their 'pdb' field. Within a group every positive
    is paired with every 'negative*' sample, and additionally with up to
    three 'decoy*' samples whose type encodes an RMSD above 4.0
    (``decoy_rmsd<value>``). Decoy types whose RMSD cannot be parsed are
    ignored rather than aborting dataset construction.

    Each item is ``{'pos': prepared, 'neg': prepared}`` where ``prepared``
    holds padded 'node_feats', 'edge_feats', 'node_mask', plus 'label',
    'contact_energy' (default 0.5) and, when ``esm_dir`` is set, 'esm_feats'.
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        """
        Args:
            data_path: pickle file holding a list of sample dicts.
            max_nodes: padded graph size; larger graphs are truncated.
            augment: stored on the instance; ``_prepare`` does not use it.
            esm_dir: root of ESM embedding files; enables 'esm_feats'.
            target_name: sub-directory of ``esm_dir`` for this dataset.
            binder_dropout: per-pair probability of masking binder features.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted, locally produced files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout

        from collections import defaultdict
        by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
        for s in samples:
            t = s['type']
            if t == 'positive':
                by_pdb[s['pdb']]['positive'].append(s)
            elif t.startswith('negative'):
                by_pdb[s['pdb']]['negative'].append(s)
            elif t.startswith('decoy'):
                by_pdb[s['pdb']]['decoy'].append(s)

        self.pairs = []
        for groups in by_pdb.values():
            positives = groups['positive']
            if not positives:
                continue
            # positive x negative pairs first ...
            for pos in positives:
                for neg in groups['negative']:
                    self.pairs.append((pos, neg))
            # ... then positive x (up to three) far-away decoys.
            large_decoys = []
            for s in groups['decoy']:
                rmsd = self._decoy_rmsd(s['type'])
                if rmsd is not None and rmsd > 4.0:
                    large_decoys.append(s)
            for pos in positives:
                for neg in large_decoys[:3]:
                    self.pairs.append((pos, neg))

    @staticmethod
    def _decoy_rmsd(sample_type):
        """Return the RMSD encoded in a ``decoy_rmsd<value>`` type, else None."""
        if 'rmsd' not in sample_type:
            return None
        try:
            return float(sample_type.replace('decoy_rmsd', ''))
        except ValueError:
            return None

    def __len__(self):
        return len(self.pairs)

    def _prepare(self, sample, apply_binder_drop=False):
        """Pad one sample to ``max_nodes`` and convert it to tensors."""
        graph = sample['graph']
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        N = len(node_feats)
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]

        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)

        n = min(N, self.max_nodes)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        # Receptor nodes come first; without 'n_rec' the (truncated) graph is
        # assumed to be split in half.
        n_rec = graph.get('n_rec', n // 2)
        if apply_binder_drop:
            # Hard-coded column layout -- presumably a residue-type one-hot in
            # columns 0-20 (20 = unknown) and binder-specific features in
            # 27-30; TODO confirm against the featuriser.
            node_feats_pad[n_rec:n, :21] = 0.0
            node_feats_pad[n_rec:n, 20] = 1.0
            node_feats_pad[n_rec:n, 27:31] = 0.0

        result = {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'contact_energy': torch.tensor(
                sample.get('contact_energy', 0.5), dtype=torch.float32
            ),
        }

        if self.esm_dir:
            esm_feats = load_esm_for_sample(sample, self.esm_dir,
                                            self.target_name or '', self.max_nodes)
            if esm_feats is None:
                # Fallback width is hard-coded to 1280.
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            if apply_binder_drop:
                n_binder = graph.get('n_binder', n - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)

        return result

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        # One draw per pair so both members are masked (or not) together;
        # no random number is consumed when dropout is disabled.
        drop = (self.binder_dropout > 0
                and np.random.rand() < self.binder_dropout)
        return {
            'pos': self._prepare(pos_sample, apply_binder_drop=drop),
            'neg': self._prepare(neg_sample, apply_binder_drop=drop),
        }
|
|
|
|
def collate_paired_fn(batch):
    """Collate paired (positive, negative) samples.

    Args:
        batch: list of ``{'pos': sample, 'neg': sample}`` dicts, where each
            sample holds 'node_feats', 'edge_feats', 'node_mask', 'label',
            'contact_energy' and optionally 'esm_feats'.

    Returns:
        ``{'pos': pos_batch, 'neg': neg_batch}`` with every field stacked
        along a new leading batch dimension. When any sample on either side
        carries 'esm_feats', both sides get an 'esm_feats' tensor and
        samples lacking it contribute zeros of the same shape.
    """
    tensor_keys = ('node_feats', 'edge_feats', 'node_mask', 'label', 'contact_energy')
    sides = ('pos', 'neg')

    out = {
        side: {key: torch.stack([s[side][key] for s in batch]) for key in tensor_keys}
        for side in sides
    }

    # Take the zero-fill template from whichever side has an embedding, so a
    # side with no embeddings at all is zero-filled instead of crashing.
    ref = next(
        (s[side]['esm_feats'] for s in batch for side in sides if 'esm_feats' in s[side]),
        None,
    )
    if ref is not None:
        for side in sides:
            out[side]['esm_feats'] = torch.stack([
                s[side]['esm_feats'] if 'esm_feats' in s[side] else torch.zeros_like(ref)
                for s in batch
            ])

    return {'pos': out['pos'], 'neg': out['neg']}
|
|
|
|
class PathAwareDatasetPaired(Dataset):
    """
    Paired dataset with transition-path frames for path-aware Phase 2 training.

    Pairs are built like in TwoStateDatasetPaired: within each 'pdb' group
    every positive is paired with every 'negative*' sample and with up to
    three 'decoy*' samples whose type encodes an RMSD above 4.0
    (``decoy_rmsd<value>``); decoy types with an unparseable RMSD are ignored.

    Each item is a dict with:
        pos, neg:   prepared samples (padded graph tensors, 'label',
                    'contact_energy')
        path:       list of padded graph dicts, one per entry of the positive
                    sample's 'path_graphs' (empty when absent)
        path_taus:  list of the corresponding ``entry['tau']`` values
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False):
        """
        Args:
            data_path: pickle file holding a list of sample dicts.
            max_nodes: padded graph size; larger graphs are truncated.
            augment: stored on the instance; not used by the preparation code.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted, locally produced files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment

        from collections import defaultdict
        by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
        for s in samples:
            t = s['type']
            if t == 'positive':
                by_pdb[s['pdb']]['positive'].append(s)
            elif t.startswith('negative'):
                by_pdb[s['pdb']]['negative'].append(s)
            elif t.startswith('decoy'):
                by_pdb[s['pdb']]['decoy'].append(s)

        self.pairs = []
        for groups in by_pdb.values():
            positives = groups['positive']
            if not positives:
                continue
            for pos in positives:
                for neg in groups['negative']:
                    self.pairs.append((pos, neg))
            large_decoys = []
            for s in groups['decoy']:
                rmsd = self._decoy_rmsd(s['type'])
                if rmsd is not None and rmsd > 4.0:
                    large_decoys.append(s)
            for pos in positives:
                for neg in large_decoys[:3]:
                    self.pairs.append((pos, neg))

    @staticmethod
    def _decoy_rmsd(sample_type):
        """Return the RMSD encoded in a ``decoy_rmsd<value>`` type, else None."""
        if 'rmsd' not in sample_type:
            return None
        try:
            return float(sample_type.replace('decoy_rmsd', ''))
        except ValueError:
            return None

    def _pad_graph(self, graph):
        """Zero-pad (or truncate) one graph dict to ``max_nodes`` tensors."""
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]

        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)

        n = min(len(node_feats), self.max_nodes)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        return {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
        }

    def _prepare(self, sample):
        """Prepare a labelled sample: padded graph plus label and contact energy."""
        prepared = self._pad_graph(sample['graph'])
        prepared['label'] = torch.tensor(sample.get('label', 0.0), dtype=torch.float32)
        prepared['contact_energy'] = torch.tensor(
            sample.get('contact_energy', 0.5), dtype=torch.float32
        )
        return prepared

    def _prepare_graph_only(self, path_entry):
        """Prepare a path frame graph (no label/contact_energy needed)."""
        return self._pad_graph(path_entry['graph'])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        # Path frames hang off the positive sample only.
        path_graphs = pos_sample.get('path_graphs', [])
        return {
            'pos': self._prepare(pos_sample),
            'neg': self._prepare(neg_sample),
            'path': [self._prepare_graph_only(pg) for pg in path_graphs],
            'path_taus': [pg['tau'] for pg in path_graphs],
        }
|
|
|
|
def collate_path_paired_fn(batch):
    """Collate paired samples with variable-length path frames.

    Args:
        batch: list of dicts with 'pos' and 'neg' samples (each holding
            'node_feats', 'edge_feats', 'node_mask', 'label',
            'contact_energy'), a 'path' list of graph dicts and a parallel
            'path_taus' list.

    Returns:
        dict with 'pos' and 'neg' batches (fields stacked along a new
        leading dimension). When at least one sample has path frames it
        also contains:
            path:      list of length max_k; entry k stacks frame k of every
                       sample, using an all-zero, fully masked placeholder
                       for samples with fewer than k+1 frames
            path_taus: list of length max_k with one tau per frame index,
                       taken from the first sample that really has frame k
                       (samples are assumed to share taus per frame index)
    """
    tensor_keys = ('node_feats', 'edge_feats', 'node_mask', 'label', 'contact_energy')
    pos_batch = {k: torch.stack([s['pos'][k] for s in batch]) for k in tensor_keys}
    neg_batch = {k: torch.stack([s['neg'][k] for s in batch]) for k in tensor_keys}

    graph_keys = ('node_feats', 'edge_feats', 'node_mask')
    max_k = max((len(s['path']) for s in batch), default=0)
    path_batches = []
    path_taus = []

    if max_k > 0:
        # Any padded graph gives the right shapes/dtypes for the placeholder.
        ref = batch[0]['path'][0] if batch[0]['path'] else batch[0]['pos']
        zero_placeholder = {k: torch.zeros_like(ref[k]) for k in graph_keys}

        for k_idx in range(max_k):
            frames_at_k = [
                s['path'][k_idx] if k_idx < len(s['path']) else zero_placeholder
                for s in batch
            ]
            path_batches.append(
                {k: torch.stack([f[k] for f in frames_at_k]) for k in graph_keys}
            )
            # Do not report the placeholder's tau when the first sample's path
            # is shorter; max_k guarantees some sample owns frame k_idx.
            path_taus.append(next(
                s['path_taus'][k_idx] for s in batch if k_idx < len(s['path'])
            ))

    result = {'pos': pos_batch, 'neg': neg_batch}
    if path_batches:
        result['path'] = path_batches
        result['path_taus'] = path_taus
    return result
|
|
|
|
class MultiTargetDataset(Dataset):
    """
    Pooled dataset combining samples from multiple targets.
    Supports balanced sampling across targets.

    Attributes:
        samples: flat list of sample dicts; each gets a '_target' key.
        target_indices: target name -> list of indices into ``samples``.
        weights: per-sample sampling probabilities (summing to 1) that give
            every non-empty target the same total mass, or None when
            balancing is off, there is at most one target, or there are no
            samples.
    """

    def __init__(self, data_paths: list, max_nodes: int = 128, augment: bool = False,
                 balance: bool = True, rosetta_dir: str = None, label_source: str = 'dockq',
                 esm_dir: str = None, binder_dropout: float = 0.0):
        """
        Args:
            data_paths: list of (target_name, pkl_path) tuples; paths that do
                not exist are skipped
            max_nodes: max interface graph size (larger graphs are truncated)
            augment: apply noise augmentation
            balance: if True, oversample smaller targets to balance
            rosetta_dir: directory containing Rosetta label JSONs
            label_source: 'dockq', 'rosetta', or 'combined'
            esm_dir: root of ESM embedding files; the sample's target name is
                used as sub-directory
            binder_dropout: per-item probability of masking binder features
        """
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.binder_dropout = binder_dropout

        self.samples = []
        self.target_indices = {}

        for target_name, path in data_paths:
            if not os.path.exists(path):
                continue
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only point this at trusted, locally produced files.
            with open(path, 'rb') as f:
                target_samples = pickle.load(f)

            if label_source != 'dockq' and rosetta_dir:
                rl = load_rosetta_labels(rosetta_dir, target_name)
                if rl:
                    apply_rosetta_labels(target_samples, rl, label_source)

            start_idx = len(self.samples)
            for s in target_samples:
                s['_target'] = target_name
                self.samples.append(s)
            self.target_indices[target_name] = list(range(start_idx, len(self.samples)))

        # Inverse-frequency weights; only meaningful with samples from more
        # than one target.
        self.weights = None
        if balance and len(self.target_indices) > 1 and self.samples:
            max_count = max(len(idxs) for idxs in self.target_indices.values())
            weights = np.zeros(len(self.samples))
            for idxs in self.target_indices.values():
                if idxs:
                    weights[idxs] = max_count / len(idxs)
            self.weights = weights / weights.sum()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        N = len(node_feats)
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]

        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)

        n = min(N, self.max_nodes)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        if self.augment:
            # Noise is drawn for the whole padded array; padded rows are
            # still flagged invalid by node_mask.
            noise = np.random.randn(*node_feats_pad.shape) * 0.01
            node_feats_pad = node_feats_pad + noise.astype(np.float32)

        # The random draw only happens when dropout is enabled.
        apply_binder_drop = (self.binder_dropout > 0
                             and np.random.rand() < self.binder_dropout)
        # Receptor nodes come first; without 'n_rec' the graph is assumed to
        # be split in half.
        n_rec = graph.get('n_rec', N // 2)
        if apply_binder_drop:
            # Hard-coded column layout -- presumably a residue-type one-hot in
            # columns 0-20 (20 = unknown) and binder-specific features in
            # 27-30; TODO confirm against the featuriser.
            node_feats_pad[n_rec:n, :21] = 0.0
            node_feats_pad[n_rec:n, 20] = 1.0
            node_feats_pad[n_rec:n, 27:31] = 0.0

        target_name = sample.get('_target', 'unknown')
        result = {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
            'target': target_name,
        }

        if self.esm_dir:
            esm_feats = load_esm_for_sample(sample, self.esm_dir, target_name, self.max_nodes)
            if esm_feats is None:
                # Fallback width is hard-coded to 1280.
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            if apply_binder_drop:
                n_binder = graph.get('n_binder', N - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)

        return result

    @staticmethod
    def get_pooled_dataloaders(data_dir, targets, batch_size=16, max_nodes=128,
                               num_workers=4, paired=False,
                               rosetta_dir=None, label_source='dockq',
                               esm_dir=None, binder_dropout=0.0):
        """Build pooled dataloaders from multiple targets.

        Looks for ``<data_dir>/<target>/<split>.pkl`` for the splits 'train',
        'val' and 'test'; a split with no file for any target is omitted from
        the result. Augmentation, binder dropout, shuffling and target
        balancing are enabled for the 'train' split only.

        Args:
            data_dir: root data directory
            targets: list of target names
            batch_size: batch size
            max_nodes: max interface nodes
            num_workers: dataloader workers
            paired: if True, build paired dataloaders for Phase 2
            rosetta_dir: directory with Rosetta label JSONs (unpaired only)
            label_source: 'dockq', 'rosetta', or 'combined' (unpaired only)
            esm_dir: root of ESM embedding files; preloaded into the module
                cache before any DataLoader is created
            binder_dropout: binder-masking probability for the train split

        Returns:
            dict mapping split name to DataLoader.
        """
        from torch.utils.data import WeightedRandomSampler

        if esm_dir:
            # Warm the cache in the parent process before workers start.
            preload_esm_cache(esm_dir, targets)

        loaders = {}
        for split in ['train', 'val', 'test']:
            data_paths = []
            for target in targets:
                path = os.path.join(data_dir, target, f"{split}.pkl")
                if os.path.exists(path):
                    data_paths.append((target, path))

            if not data_paths:
                continue

            is_train = (split == 'train')
            bd = binder_dropout if is_train else 0.0

            if paired:
                from torch.utils.data import ConcatDataset
                concat_ds = ConcatDataset([
                    TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=is_train,
                                          esm_dir=esm_dir, target_name=target,
                                          binder_dropout=bd)
                    for target, path in data_paths
                ])
                # Keep at least two batches per epoch for tiny pair sets.
                p_batch = min(batch_size, max(1, len(concat_ds) // 2))
                loaders[split] = DataLoader(
                    concat_ds, batch_size=p_batch,
                    # A random sampler cannot be built over zero pairs.
                    shuffle=(is_train and len(concat_ds) > 0),
                    num_workers=num_workers,
                    collate_fn=collate_paired_fn,
                    pin_memory=True,
                )
            else:
                dataset = MultiTargetDataset(data_paths, max_nodes=max_nodes,
                                             augment=is_train, balance=is_train,
                                             rosetta_dir=rosetta_dir, label_source=label_source,
                                             esm_dir=esm_dir, binder_dropout=bd)

                sampler = None
                # A random sampler cannot be built over an empty dataset.
                shuffle = (is_train and len(dataset) > 0)
                if is_train and dataset.weights is not None:
                    sampler = WeightedRandomSampler(
                        weights=dataset.weights,
                        num_samples=len(dataset),
                        replacement=True
                    )
                    # DataLoader forbids shuffle together with a sampler.
                    shuffle = False

                loaders[split] = DataLoader(
                    dataset, batch_size=batch_size,
                    shuffle=shuffle, sampler=sampler,
                    num_workers=num_workers,
                    collate_fn=collate_fn,
                    pin_memory=True,
                    drop_last=(is_train and len(dataset) > batch_size),
                )

        return loaders
|
|
|
|
def get_dataloaders(data_dir: str, target: str, batch_size: int = 16,
                    max_nodes: int = 128, num_workers: int = 4,
                    paired: bool = False, esm_dir: str = None,
                    binder_dropout: float = 0.0):
    """Build train/val/test dataloaders for a given target.

    Looks for ``<data_dir>/<target>/<split>.pkl``; splits without a file are
    omitted from the result. Augmentation, binder dropout, shuffling and
    ``drop_last`` apply to the 'train' split only. With ``paired=True`` the
    train split uses TwoStateDatasetPaired / collate_paired_fn, while val and
    test always use TwoStateComplexDataset / collate_fn.

    Args:
        data_dir: root data directory.
        target: target name (sub-directory of ``data_dir``).
        batch_size: batch size.
        max_nodes: padded graph size.
        num_workers: dataloader workers.
        paired: build a paired train loader.
        esm_dir: root of ESM embedding files, or None to disable them.
        binder_dropout: binder-masking probability for the train split.

    Returns:
        dict mapping split name to DataLoader.
    """
    loaders = {}
    for split in ['train', 'val', 'test']:
        path = os.path.join(data_dir, target, f"{split}.pkl")
        if not os.path.exists(path):
            continue

        is_train = (split == 'train')
        bd = binder_dropout if is_train else 0.0
        if paired and is_train:
            dataset = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=is_train,
                                            esm_dir=esm_dir, target_name=target,
                                            binder_dropout=bd)
            collate = collate_paired_fn
        else:
            dataset = TwoStateComplexDataset(path, max_nodes=max_nodes, augment=is_train,
                                             esm_dir=esm_dir, target_name=target,
                                             binder_dropout=bd)
            collate = collate_fn

        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            # A random sampler cannot be built over an empty dataset (e.g. a
            # paired train split that produced no pairs).
            shuffle=(is_train and len(dataset) > 0),
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=True,
            # Keep the last partial batch when it is the only batch.
            drop_last=(is_train and len(dataset) > batch_size),
        )
    return loaders
|
|