"""
PyTorch Dataset for two-state complex scoring.

Loads preprocessed graph data and provides batched tensors with padding for
variable-sized interface graphs.
"""

import glob
import json
import os
import pickle
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, WeightedRandomSampler

# Width of the zero block substituted when a sample has no ESM-2 embeddings.
# Real embeddings keep the width stored in their .pt files; the two must agree
# for batches that mix available and missing embeddings to stack.
ESM_DIM = 1280

# Per-graph keys carried by every prepared sample / path frame.
_GRAPH_KEYS = ('node_feats', 'edge_feats', 'node_mask')
# Keys stacked for each side of a (positive, negative) pair.
_PAIR_KEYS = _GRAPH_KEYS + ('label', 'contact_energy')

# Global ESM embedding cache: {file_path: tensor}
_ESM_CACHE = {}


# ---------------------------------------------------------------------------
# ESM embeddings
# ---------------------------------------------------------------------------

def preload_esm_cache(esm_dir, targets):
    """Preload all ESM ``.pt`` files for ``targets`` into the global cache.

    Meant to run before DataLoader workers are started so that, with a
    fork-based start method, workers inherit the populated cache instead of
    re-reading every file.

    Args:
        esm_dir: root directory holding one sub-directory per target.
        targets: iterable of target names.

    Returns:
        Number of files newly loaded into the cache by this call.
    """
    n_loaded = 0
    for target in targets:
        target_dir = os.path.join(esm_dir, target)
        if not os.path.isdir(target_dir):
            continue
        for pt_file in glob.glob(os.path.join(target_dir, '*.pt')):
            if pt_file not in _ESM_CACHE:
                _ESM_CACHE[pt_file] = torch.load(pt_file, map_location='cpu', weights_only=True)
                n_loaded += 1
    return n_loaded


def _load_esm_cached(path):
    """Return the tensor stored at ``path`` (memoized in ``_ESM_CACHE``), or None if absent."""
    tensor = _ESM_CACHE.get(path)
    if tensor is None:
        if not os.path.exists(path):
            return None
        tensor = torch.load(path, map_location='cpu', weights_only=True)
        _ESM_CACHE[path] = tensor
    return tensor


def _base_pdb_id(pdb_id):
    """Strip a chain suffix from a PDB id ("2G1T_AE" -> "2G1T")."""
    return pdb_id.split('_')[0]


def load_esm_for_sample(sample, esm_dir, target_name, max_nodes=128):
    """Load and index ESM-2 embeddings for a sample's interface residues.

    Embeddings are read from ``<esm_dir>/<target_name>/<PDB>_<chain>.pt`` for
    the receptor and binder chains. Rows are laid out receptor-first and then
    binder, following ``graph['rec_iface_idx']`` and
    ``graph['binder_iface_idx']``. Out-of-range residue indices are clamped to
    the last embedded residue. Rows beyond ``max_nodes`` are dropped, matching
    the truncation applied to the graph features.

    Returns:
        ``np.ndarray`` of shape [max_nodes, D] (float32, zero-padded), or None
        if the interface indices or either embedding file are unavailable or
        empty.
    """
    graph = sample['graph']
    rec_idx = graph.get('rec_iface_idx')
    binder_idx = graph.get('binder_iface_idx')
    if rec_idx is None or binder_idx is None:
        return None

    base_pdb = _base_pdb_id(sample.get('pdb', ''))
    rec_chain = sample.get('rec_chain_id', 'A')
    binder_chain = sample.get('binder_chain_id', 'B')
    target_dir = os.path.join(esm_dir, target_name)

    rec_esm = _load_esm_cached(os.path.join(target_dir, f'{base_pdb}_{rec_chain}.pt'))
    binder_esm = _load_esm_cached(os.path.join(target_dir, f'{base_pdb}_{binder_chain}.pt'))
    if rec_esm is None or binder_esm is None:
        return None
    if len(rec_esm) == 0 or len(binder_esm) == 0:
        return None  # nothing to index into; caller zero-fills

    esm_dim = rec_esm.shape[-1]
    # Explicit int64 so empty index lists stay integer (np.clip([]) is float64).
    rec_idx = np.clip(np.asarray(rec_idx, dtype=np.int64).reshape(-1), 0, len(rec_esm) - 1)
    binder_idx = np.clip(np.asarray(binder_idx, dtype=np.int64).reshape(-1), 0, len(binder_esm) - 1)

    # Truncate to max_nodes (receptor rows first) instead of failing to broadcast.
    n_rec = min(len(rec_idx), max_nodes)
    n_binder = min(len(binder_idx), max_nodes - n_rec)

    esm_feats = np.zeros((max_nodes, esm_dim), dtype=np.float32)
    # .float() so half/bfloat16 checkpoints convert to numpy cleanly.
    esm_feats[:n_rec] = rec_esm[torch.from_numpy(rec_idx[:n_rec])].float().numpy()
    esm_feats[n_rec:n_rec + n_binder] = (
        binder_esm[torch.from_numpy(binder_idx[:n_binder])].float().numpy()
    )
    return esm_feats


def _esm_features(sample, esm_dir, target_name, max_nodes, drop_binder, n_nodes):
    """ESM tensor for one sample; zero-filled when missing, binder rows zeroed on dropout."""
    esm_feats = load_esm_for_sample(sample, esm_dir, target_name, max_nodes)
    if esm_feats is None:
        esm_feats = np.zeros((max_nodes, ESM_DIM), dtype=np.float32)
    if drop_binder:
        graph = sample['graph']
        n_rec = graph.get('n_rec', n_nodes // 2)
        n_binder = graph.get('n_binder', n_nodes - n_rec)
        esm_feats[n_rec:n_rec + n_binder] = 0.0
    return torch.from_numpy(esm_feats)


# ---------------------------------------------------------------------------
# Rosetta labels
# ---------------------------------------------------------------------------

def load_rosetta_labels(rosetta_dir, target, tau=15.0, dg_min=-500.0, dg_max=500.0):
    """Load Rosetta dG labels for a target and normalize them to [0, 1].

    Reads ``<rosetta_dir>/<target>_rosetta.json`` (a mapping pdb_id -> metrics
    dict) and maps ``dG_separated`` through ``1 / (1 + exp(dG / tau))``.
    More negative dG (better binding) gives a higher score:
    with tau=15, dG=-30 -> 0.88, dG=-15 -> 0.73, dG=0 -> 0.5.

    Entries whose dG is missing, non-numeric, non-finite or outside
    [dg_min, dg_max] are treated as failed Rosetta runs and skipped.

    Args:
        rosetta_dir: directory holding the per-target JSON files.
        target: target name.
        tau: sigmoid temperature.
        dg_min, dg_max: accepted dG range.

    Returns:
        dict mapping pdb_id (as stored, upper- and lower-cased) to a float
        score, or None if the file is missing or empty.
    """
    path = os.path.join(rosetta_dir, f'{target}_rosetta.json')
    if not os.path.exists(path):
        return None
    with open(path, encoding='utf-8') as f:
        raw = json.load(f)
    if not raw:
        return None

    labels = {}
    for pdb_id, metrics in raw.items():
        try:
            dG = float(metrics.get('dG_separated', 0.0))
        except (AttributeError, TypeError, ValueError):
            continue  # null / non-numeric dG: treat like a failed run
        if not np.isfinite(dG) or not dg_min <= dG <= dg_max:
            continue  # skip failed Rosetta runs
        score = float(1.0 / (1.0 + np.exp(dG / tau)))
        labels[pdb_id] = score
        labels[pdb_id.upper()] = score
        labels[pdb_id.lower()] = score
    return labels


def apply_rosetta_labels(samples, rosetta_labels, label_source='rosetta', alpha=0.5):
    """Replace or blend sample labels in place with Rosetta-derived labels.

    Positives receive the Rosetta score directly. Decoys receive their existing
    label scaled by the Rosetta score. Negatives (apo mismatches) and unknown
    types are left untouched.

    Args:
        samples: list of sample dicts (mutated in place).
        rosetta_labels: mapping from ``load_rosetta_labels``; None/empty is a no-op.
        label_source: 'rosetta' replaces the label; 'combined' blends
            ``alpha * old + (1 - alpha) * rosetta``; anything else changes nothing.
        alpha: blend weight for 'combined'.

    Returns:
        Number of samples whose label was updated.
    """
    if not rosetta_labels:
        return 0

    n_replaced = 0
    for s in samples:
        base_pdb = _base_pdb_id(s.get('pdb', ''))
        rosetta_val = rosetta_labels.get(base_pdb)
        if rosetta_val is None:
            rosetta_val = rosetta_labels.get(base_pdb.upper())
        if rosetta_val is None:
            continue

        sample_type = s['type']
        if sample_type == 'positive':
            new_label = rosetta_val
        elif sample_type.startswith('decoy'):
            # Scale Rosetta label by DockQ-proxy quality
            new_label = s['label'] * rosetta_val
        else:
            continue  # negatives stay 0; unknown types untouched

        if label_source == 'rosetta':
            s['label'] = float(new_label)
        elif label_source == 'combined':
            s['label'] = float(alpha * s['label'] + (1 - alpha) * new_label)
        else:
            continue
        n_replaced += 1
    return n_replaced


# ---------------------------------------------------------------------------
# Shared per-sample helpers
# ---------------------------------------------------------------------------

def _pad_graph(graph, max_nodes):
    """Zero-pad (truncating if needed) one interface graph to ``max_nodes`` nodes.

    Returns:
        (node_feats [max_nodes, node_dim] float32,
         edge_feats [max_nodes, max_nodes, edge_dim] float32,
         node_mask [max_nodes] bool,
         n = number of real nodes kept, i.e. min(N, max_nodes))
    """
    node_feats = graph['node_feats']
    edge_feats = graph['edge_feats']
    node_mask = graph['node_mask']
    n = min(len(node_feats), max_nodes)

    node_pad = np.zeros((max_nodes, node_feats.shape[-1]), dtype=np.float32)
    edge_pad = np.zeros((max_nodes, max_nodes, edge_feats.shape[-1]), dtype=np.float32)
    mask_pad = np.zeros(max_nodes, dtype=bool)
    node_pad[:n] = node_feats[:n]
    edge_pad[:n, :n] = edge_feats[:n, :n]
    mask_pad[:n] = node_mask[:n]
    return node_pad, edge_pad, mask_pad, n


def _add_feature_noise(node_feats, n, scale=0.01):
    """Add Gaussian noise in place to the ``n`` real node rows (padding stays zero)."""
    noise = np.random.randn(n, node_feats.shape[-1]) * scale
    node_feats[:n] += noise.astype(np.float32)


def _sample_binder_drop(p):
    """Bernoulli(p) draw deciding whether binder-dropout applies."""
    return p > 0 and np.random.rand() < p


def _apply_binder_dropout(node_feats, graph, n_nodes):
    """Simulate a backbone-only design by masking binder sequence features in place.

    Binder rows are ``[n_rec, n_nodes)``. The receptor count comes from
    ``graph['n_rec']`` and defaults to half the nodes. Per the feature layout
    assumed here: the AA one-hot (dims 0-20) becomes UNK (dim 20), and the chi
    angles (dims 27-30) are zeroed. Backbone torsions (21-26) and the chain
    indicator (31) are kept.
    """
    n_rec = graph.get('n_rec', n_nodes // 2)
    node_feats[n_rec:n_nodes, :21] = 0.0
    node_feats[n_rec:n_nodes, 20] = 1.0  # UNK
    node_feats[n_rec:n_nodes, 27:31] = 0.0


def _stack_keys(items, keys):
    """Stack the tensors under each of ``keys`` across ``items`` (new leading batch dim)."""
    return {k: torch.stack([it[k] for it in items]) for k in keys}


def _stack_esm(items, ref=None):
    """Stack ``esm_feats`` across ``items``, zero-filling items that lack them.

    ``ref`` provides the shape/dtype for zero-fill. If it is not given, it is
    taken from the first item that has ESM features. Returns None when there
    is no reference.
    """
    if ref is None:
        ref = next((it['esm_feats'] for it in items if 'esm_feats' in it), None)
        if ref is None:
            return None
    return torch.stack([
        it['esm_feats'] if 'esm_feats' in it else torch.zeros_like(ref) for it in items
    ])


def _decoy_rmsd(sample_type):
    """Parse the RMSD from a ``decoy_rmsd<value>`` type string; None if absent/unparsable."""
    if 'rmsd' not in sample_type:
        return None
    try:
        return float(sample_type.replace('decoy_rmsd', ''))
    except ValueError:
        return None


def _build_pairs(samples, min_decoy_rmsd=4.0, max_decoys_per_positive=3):
    """Group samples by PDB id and build (positive, negative) training pairs.

    For every PDB that has positives, this emits every (positive, negative_*)
    combination. It then emits (positive, decoy) pairs using up to
    ``max_decoys_per_positive`` decoys whose parsed RMSD exceeds
    ``min_decoy_rmsd``.
    """
    by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
    for s in samples:
        t = s['type']
        if t == 'positive':
            by_pdb[s['pdb']]['positive'].append(s)
        elif t.startswith('negative'):
            by_pdb[s['pdb']]['negative'].append(s)
        elif t.startswith('decoy'):
            by_pdb[s['pdb']]['decoy'].append(s)

    pairs = []
    for groups in by_pdb.values():
        positives = groups['positive']
        if not positives:
            continue
        for pos in positives:
            for neg in groups['negative']:
                pairs.append((pos, neg))

        hard_decoys = []
        for decoy in groups['decoy']:
            rmsd = _decoy_rmsd(decoy['type'])
            if rmsd is not None and rmsd > min_decoy_rmsd:
                hard_decoys.append(decoy)
        hard_decoys = hard_decoys[:max_decoys_per_positive]
        for pos in positives:
            for neg in hard_decoys:
                pairs.append((pos, neg))
    return pairs


# ---------------------------------------------------------------------------
# Datasets and collate functions
# ---------------------------------------------------------------------------

class TwoStateComplexDataset(Dataset):
    """
    Dataset of protein complex interface graphs with two-state labels.

    Each item is a dict with:
        node_feats: [max_nodes, node_dim] float32 interface residue features
        edge_feats: [max_nodes, max_nodes, edge_dim] float32 pairwise
                    SE(3)-invariant features
        node_mask:  [max_nodes] bool
        label:      scalar float32 in [0, 1] (DockQ proxy / selectivity label)
        type:       str (positive / negative_apo / decoy_*)
        pdb:        str
        esm_feats:  [max_nodes, D] float32, only when ``esm_dir`` is set

    Graphs with more than ``max_nodes`` nodes are truncated to their first
    ``max_nodes`` nodes, as in the other datasets of this module.
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 rosetta_labels: dict = None, label_source: str = 'dockq',
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        # NOTE: pickle can execute arbitrary code; only load trusted, locally generated files.
        with open(data_path, 'rb') as f:
            self.samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout

        if label_source != 'dockq' and rosetta_labels:
            apply_rosetta_labels(self.samples, rosetta_labels, label_source)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']
        n_raw = len(graph['node_feats'])
        node_feats, edge_feats, node_mask, n = _pad_graph(graph, self.max_nodes)

        # Optional: small Gaussian feature-noise augmentation on real nodes
        if self.augment:
            _add_feature_noise(node_feats, n)

        drop = _sample_binder_drop(self.binder_dropout)
        if drop:
            _apply_binder_dropout(node_feats, graph, n_raw)

        result = {
            'node_feats': torch.from_numpy(node_feats),   # [max_nodes, node_dim]
            'edge_feats': torch.from_numpy(edge_feats),   # [max_nodes, max_nodes, edge_dim]
            'node_mask': torch.from_numpy(node_mask),     # [max_nodes]
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
        }
        # ESM-2 features (lazy load; zero-fill if unavailable)
        if self.esm_dir:
            result['esm_feats'] = _esm_features(
                sample, self.esm_dir, self.target_name or '', self.max_nodes, drop, n_raw)
        return result


def collate_fn(batch):
    """Collate single samples into batched tensors.

    Graph tensors and labels are stacked to [B, ...]. ``type`` and ``pdb`` are
    returned as lists. ``esm_feats`` is stacked when any sample carries it,
    and samples without it are zero-filled.
    """
    result = _stack_keys(batch, _GRAPH_KEYS + ('label',))
    result['type'] = [s['type'] for s in batch]
    result['pdb'] = [s['pdb'] for s in batch]
    esm = _stack_esm(batch)
    if esm is not None:
        result['esm_feats'] = esm
    return result


class TwoStateDatasetPaired(Dataset):
    """
    Paired dataset: returns (positive, negative) pairs for selectivity training.

    Samples are grouped by PDB id. Each positive (holo) is paired with every
    negative (apo) of the same PDB, and with up to 3 decoys whose RMSD
    exceeds 4.0.
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        # NOTE: pickle can execute arbitrary code; only load trusted, locally generated files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        # NOTE(review): stored for API symmetry but not applied by _prepare.
        self.augment = augment
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout
        self.pairs = _build_pairs(samples)

    def __len__(self):
        return len(self.pairs)

    def _prepare(self, sample, apply_binder_drop=False):
        """Pad one side of a pair and attach label, contact energy and optional ESM."""
        graph = sample['graph']
        node_feats, edge_feats, node_mask, n = _pad_graph(graph, self.max_nodes)
        if apply_binder_drop:
            _apply_binder_dropout(node_feats, graph, n)

        result = {
            'node_feats': torch.from_numpy(node_feats),
            'edge_feats': torch.from_numpy(edge_feats),
            'node_mask': torch.from_numpy(node_mask),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'contact_energy': torch.tensor(
                sample.get('contact_energy', 0.5), dtype=torch.float32),
        }
        if self.esm_dir:
            result['esm_feats'] = _esm_features(
                sample, self.esm_dir, self.target_name or '', self.max_nodes,
                apply_binder_drop, n)
        return result

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        # Same dropout decision for both pos and neg in a pair
        drop = _sample_binder_drop(self.binder_dropout)
        return {
            'pos': self._prepare(pos_sample, apply_binder_drop=drop),
            'neg': self._prepare(neg_sample, apply_binder_drop=drop),
        }


def collate_paired_fn(batch):
    """Collate (positive, negative) pairs into ``{'pos': {...}, 'neg': {...}}``.

    If either side of any pair carries ``esm_feats``, both sides get a stacked
    ``esm_feats`` tensor, and missing entries are zero-filled.
    """
    pos_items = [s['pos'] for s in batch]
    neg_items = [s['neg'] for s in batch]
    pos_batch = _stack_keys(pos_items, _PAIR_KEYS)
    neg_batch = _stack_keys(neg_items, _PAIR_KEYS)

    ref = next((it['esm_feats'] for it in pos_items + neg_items if 'esm_feats' in it), None)
    if ref is not None:
        pos_batch['esm_feats'] = _stack_esm(pos_items, ref)
        neg_batch['esm_feats'] = _stack_esm(neg_items, ref)
    return {'pos': pos_batch, 'neg': neg_batch}


class PathAwareDatasetPaired(Dataset):
    """
    Paired dataset with transition-path frames for path-aware Phase 2 training.

    Pairs are built like TwoStateDatasetPaired. Each item additionally carries
    ``path`` (prepared graph dicts for the intermediate conformations stored
    in the positive sample's ``path_graphs`` field) and ``path_taus`` (each
    frame's ``tau``).
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False):
        # NOTE: pickle can execute arbitrary code; only load trusted, locally generated files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        # NOTE(review): stored for API symmetry but not applied to any sample.
        self.augment = augment
        self.pairs = _build_pairs(samples)

    def _prepare(self, sample):
        """Pad one side of a pair and attach label (default 0.0) and contact energy."""
        result = self._prepare_graph_only(sample)
        result['label'] = torch.tensor(sample.get('label', 0.0), dtype=torch.float32)
        result['contact_energy'] = torch.tensor(
            sample.get('contact_energy', 0.5), dtype=torch.float32)
        return result

    def _prepare_graph_only(self, path_entry):
        """Prepare a path frame graph (no label/contact_energy needed)."""
        node_feats, edge_feats, node_mask, _ = _pad_graph(path_entry['graph'], self.max_nodes)
        return {
            'node_feats': torch.from_numpy(node_feats),
            'edge_feats': torch.from_numpy(edge_feats),
            'node_mask': torch.from_numpy(node_mask),
        }

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        path_graphs = pos_sample.get('path_graphs', [])
        return {
            'pos': self._prepare(pos_sample),
            'neg': self._prepare(neg_sample),
            'path': [self._prepare_graph_only(pg) for pg in path_graphs],
            'path_taus': [pg['tau'] for pg in path_graphs],
        }


def collate_path_paired_fn(batch):
    """Collate paired samples with variable-length path frames.

    Path frames are aligned by index k. Samples with fewer than K frames are
    padded with all-zero graphs (``node_mask`` all False).

    Output keys (the ``path*`` keys only when some sample has frames):
        pos, neg:         stacked pair sides, as in ``collate_paired_fn``
                          (without ESM)
        path:             list of K dicts of stacked graph tensors [B, ...]
        path_taus:        list of K scalars; the tau of frame k taken from the
                          first sample that really has frame k
        path_taus_batch:  float32 tensor [K, B] of per-sample taus
                          (1.0 at padding)
        path_valid:       bool tensor [K, B]; True where the frame is real
    """
    result = {
        'pos': _stack_keys([s['pos'] for s in batch], _PAIR_KEYS),
        'neg': _stack_keys([s['neg'] for s in batch], _PAIR_KEYS),
    }

    max_k = max((len(s['path']) for s in batch), default=0)
    if max_k == 0:
        return result

    # Zero-filled placeholder for padding (graph-only keys)
    ref = batch[0]['path'][0] if batch[0]['path'] else batch[0]['pos']
    placeholder = {k: torch.zeros_like(ref[k]) for k in _GRAPH_KEYS}

    path_batches, path_taus, taus_grid, valid_grid = [], [], [], []
    for k in range(max_k):
        frames, taus, valid = [], [], []
        for s in batch:
            if k < len(s['path']):
                frames.append(s['path'][k])
                taus.append(s['path_taus'][k])
                valid.append(True)
            else:
                frames.append(placeholder)
                taus.append(1.0)
                valid.append(False)
        path_batches.append(_stack_keys(frames, _GRAPH_KEYS))
        # Never report a padding tau: some sample has frame k because k < max_k.
        path_taus.append(next(t for t, ok in zip(taus, valid) if ok))
        taus_grid.append([float(t) for t in taus])
        valid_grid.append(valid)

    result['path'] = path_batches
    result['path_taus'] = path_taus
    result['path_taus_batch'] = torch.tensor(taus_grid, dtype=torch.float32)
    result['path_valid'] = torch.tensor(valid_grid, dtype=torch.bool)
    return result


class MultiTargetDataset(Dataset):
    """
    Pooled dataset combining samples from multiple targets.

    Supports balanced sampling across targets via ``self.weights``
    (inverse-size per-sample weights normalized to sum to 1, or None).
    Items match TwoStateComplexDataset plus a ``target`` name.
    """

    def __init__(self, data_paths: list, max_nodes: int = 128, augment: bool = False,
                 balance: bool = True, rosetta_dir: str = None,
                 label_source: str = 'dockq', esm_dir: str = None,
                 binder_dropout: float = 0.0):
        """
        Args:
            data_paths: list of (target_name, pkl_path) tuples; missing paths are skipped
            max_nodes: max interface graph size
            augment: apply noise augmentation
            balance: if True, oversample smaller targets to balance
            rosetta_dir: directory containing Rosetta label JSONs
            label_source: 'dockq', 'rosetta', or 'combined'
            esm_dir: root of per-target ESM embedding files (optional)
            binder_dropout: probability of masking binder sequence features
        """
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.binder_dropout = binder_dropout

        self.samples = []
        self.target_indices = {}  # target_name -> list of indices into self.samples
        for target_name, path in data_paths:
            if not os.path.exists(path):
                continue
            # NOTE: pickle can execute arbitrary code; only load trusted, locally generated files.
            with open(path, 'rb') as f:
                target_samples = pickle.load(f)

            if label_source != 'dockq' and rosetta_dir:
                rl = load_rosetta_labels(rosetta_dir, target_name)
                if rl:
                    apply_rosetta_labels(target_samples, rl, label_source)

            start = len(self.samples)
            for s in target_samples:
                s['_target'] = target_name
            self.samples.extend(target_samples)
            # extend (not overwrite) so a target listed twice keeps all its samples weighted
            self.target_indices.setdefault(target_name, []).extend(
                range(start, len(self.samples)))

        self.weights = None
        if balance and len(self.target_indices) > 1:
            sizes = [len(idxs) for idxs in self.target_indices.values() if idxs]
            max_count = max(sizes) if sizes else 1
            weights = np.zeros(len(self.samples))
            for idxs in self.target_indices.values():
                if idxs:
                    weights[idxs] = max_count / len(idxs)
            total = weights.sum()
            self.weights = weights / total if total > 0 else weights

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']
        n_raw = len(graph['node_feats'])
        node_feats, edge_feats, node_mask, n = _pad_graph(graph, self.max_nodes)

        if self.augment:
            _add_feature_noise(node_feats, n)

        drop = _sample_binder_drop(self.binder_dropout)
        if drop:
            _apply_binder_dropout(node_feats, graph, n_raw)

        target_name = sample.get('_target', 'unknown')
        result = {
            'node_feats': torch.from_numpy(node_feats),
            'edge_feats': torch.from_numpy(edge_feats),
            'node_mask': torch.from_numpy(node_mask),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
            'target': target_name,
        }
        if self.esm_dir:
            result['esm_feats'] = _esm_features(
                sample, self.esm_dir, target_name, self.max_nodes, drop, n_raw)
        return result

    @staticmethod
    def get_pooled_dataloaders(data_dir, targets, batch_size=16, max_nodes=128,
                               num_workers=4, paired=False, rosetta_dir=None,
                               label_source='dockq', esm_dir=None,
                               binder_dropout=0.0):
        """Build pooled dataloaders from multiple targets.

        Args:
            data_dir: root data directory (expects <data_dir>/<target>/<split>.pkl)
            targets: list of target names
            batch_size: batch size
            max_nodes: max interface nodes
            num_workers: dataloader workers
            paired: if True, build paired dataloaders for Phase 2
            rosetta_dir: directory with Rosetta label JSONs (unpaired mode only)
            label_source: 'dockq', 'rosetta', or 'combined' (unpaired mode only)
            esm_dir: root of per-target ESM embedding files (optional)
            binder_dropout: binder-dropout probability, applied to 'train' only

        Returns:
            dict split -> DataLoader for each of 'train'/'val'/'test' with data.
        """
        if esm_dir:
            # Populate the cache in the parent process before workers start.
            preload_esm_cache(esm_dir, targets)

        loaders = {}
        for split in ('train', 'val', 'test'):
            data_paths = []
            for target in targets:
                path = os.path.join(data_dir, target, f"{split}.pkl")
                if os.path.exists(path):
                    data_paths.append((target, path))
            if not data_paths:
                continue

            is_train = split == 'train'
            bd = binder_dropout if is_train else 0.0

            if paired:
                # NOTE: rosetta_dir / label_source are not applied in paired mode.
                concat_ds = ConcatDataset([
                    TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=is_train,
                                          esm_dir=esm_dir, target_name=target,
                                          binder_dropout=bd)
                    for target, path in data_paths
                ])
                loaders[split] = DataLoader(
                    concat_ds,
                    batch_size=min(batch_size, max(1, len(concat_ds) // 2)),
                    shuffle=is_train,
                    num_workers=num_workers,
                    collate_fn=collate_paired_fn,
                    pin_memory=True,
                )
                continue

            dataset = MultiTargetDataset(data_paths, max_nodes=max_nodes,
                                         augment=is_train, balance=is_train,
                                         rosetta_dir=rosetta_dir,
                                         label_source=label_source,
                                         esm_dir=esm_dir, binder_dropout=bd)
            sampler = None
            shuffle = is_train
            if is_train and dataset.weights is not None:
                sampler = WeightedRandomSampler(
                    weights=dataset.weights, num_samples=len(dataset), replacement=True)
                shuffle = False  # mutually exclusive with a sampler
            loaders[split] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                sampler=sampler,
                num_workers=num_workers,
                collate_fn=collate_fn,
                pin_memory=True,
                drop_last=(is_train and len(dataset) > batch_size),
            )
        return loaders


def get_dataloaders(data_dir: str, target: str, batch_size: int = 16,
                    max_nodes: int = 128, num_workers: int = 4,
                    paired: bool = False, esm_dir: str = None,
                    binder_dropout: float = 0.0):
    """Build train/val/test dataloaders for a single target.

    Reads ``<data_dir>/<target>/<split>.pkl`` for each split that exists. When
    ``paired`` is set, only the train split uses the paired dataset; val and
    test stay single-sample. Augmentation and binder-dropout apply to train
    only.
    """
    loaders = {}
    for split in ('train', 'val', 'test'):
        path = os.path.join(data_dir, target, f"{split}.pkl")
        if not os.path.exists(path):
            continue
        is_train = split == 'train'
        bd = binder_dropout if is_train else 0.0

        if paired and is_train:
            dataset = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=is_train,
                                            esm_dir=esm_dir, target_name=target,
                                            binder_dropout=bd)
            collate = collate_paired_fn
        else:
            dataset = TwoStateComplexDataset(path, max_nodes=max_nodes, augment=is_train,
                                             esm_dir=esm_dir, target_name=target,
                                             binder_dropout=bd)
            collate = collate_fn

        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=True,
            drop_last=(is_train and len(dataset) > batch_size),
        )
    return loaders