# AlloGen / code/data/dataset.py
# chq1155 — AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# (revision ad9572d)
"""
PyTorch Dataset for two-state complex scoring.
Loads preprocessed graph data and provides batched tensors
with padding for variable-sized interface graphs.
"""
import os
import json
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
# Module-level ESM embedding cache mapping a .pt file path to its loaded tensor.
# Filled eagerly by preload_esm_cache() and lazily by load_esm_for_sample().
_ESM_CACHE = dict()
def preload_esm_cache(esm_dir, targets):
    """Eagerly read every ``*.pt`` file under ``esm_dir/<target>`` into ``_ESM_CACHE``.

    Intended to run before DataLoader workers fork, so forked workers inherit
    the populated cache via copy-on-write instead of each re-reading the files.

    Args:
        esm_dir: root directory with one sub-directory per target.
        targets: target names; a target without a sub-directory is skipped.

    Returns:
        How many files this call loaded (paths already cached are not counted).
    """
    import glob as glob_mod

    candidate_dirs = (os.path.join(esm_dir, name) for name in targets)
    newly_loaded = 0
    for target_dir in filter(os.path.isdir, candidate_dirs):
        pattern = os.path.join(target_dir, '*.pt')
        for pt_file in glob_mod.glob(pattern):
            if pt_file in _ESM_CACHE:
                continue
            _ESM_CACHE[pt_file] = torch.load(pt_file, map_location='cpu', weights_only=True)
            newly_loaded += 1
    return newly_loaded
def load_esm_for_sample(sample, esm_dir, target_name, max_nodes=128):
    """Gather per-residue ESM-2 embeddings for a sample's interface residues.

    Embeddings are read (and cached in ``_ESM_CACHE``) from
    ``{esm_dir}/{target_name}/{PDB}_{chain}.pt`` for the receptor and binder
    chains. Output rows are the receptor interface residues first, then the
    binder interface residues, then zero padding up to ``max_nodes``. Residue
    indices outside an embedding's range are clamped into it, and rows that
    would fall beyond ``max_nodes`` are dropped.

    Args:
        sample: dict with ``'graph'`` (holding ``'rec_iface_idx'`` and
            ``'binder_iface_idx'``) and optionally ``'pdb'``,
            ``'rec_chain_id'`` (default ``'A'``) and ``'binder_chain_id'``
            (default ``'B'``).
        esm_dir: root directory of ESM embedding files.
        target_name: sub-directory of ``esm_dir`` for this target.
        max_nodes: number of output rows.

    Returns:
        float32 array ``[max_nodes, esm_dim]`` (esm_dim is the embeddings' last
        dimension), or ``None`` if either index list or either embedding file
        is unavailable.
    """
    graph = sample['graph']
    rec_idx = graph.get('rec_iface_idx')
    binder_idx = graph.get('binder_iface_idx')
    if rec_idx is None or binder_idx is None:
        return None

    # Strip the chain suffix: "2G1T_AE" -> "2G1T" (IDs without "_" are unchanged).
    base_pdb = sample.get('pdb', '').split('_')[0]
    rec_chain = sample.get('rec_chain_id', 'A')
    binder_chain = sample.get('binder_chain_id', 'B')
    rec_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{rec_chain}.pt')
    binder_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{binder_chain}.pt')

    def _load_cached(path):
        # Missing files are not cached, so they are re-checked on the next call.
        if path not in _ESM_CACHE:
            if not os.path.exists(path):
                return None
            _ESM_CACHE[path] = torch.load(path, map_location='cpu', weights_only=True)
        return _ESM_CACHE[path]

    rec_esm = _load_cached(rec_path)
    binder_esm = _load_cached(binder_path)
    if rec_esm is None or binder_esm is None:
        return None

    # Force integer index arrays: np.clip([]) would produce float64, which is
    # not a valid tensor index.
    rec_idx = np.asarray(rec_idx, dtype=np.int64).reshape(-1)
    binder_idx = np.asarray(binder_idx, dtype=np.int64).reshape(-1)

    esm_dim = rec_esm.shape[-1]
    esm_feats = np.zeros((max_nodes, esm_dim), dtype=np.float32)

    def _fill(esm, idx, start):
        # Copy esm[idx] into rows [start, start + len(idx)), truncated at
        # max_nodes instead of raising a shape-mismatch error.
        stop = min(start + len(idx), max_nodes)
        if stop <= start or len(esm) == 0:
            return
        safe = np.clip(idx[:stop - start], 0, len(esm) - 1)
        esm_feats[start:stop] = esm[torch.from_numpy(safe)].float().numpy()

    _fill(rec_esm, rec_idx, 0)
    # Binder rows start after all receptor interface rows (even truncated ones).
    _fill(binder_esm, binder_idx, len(rec_idx))
    return esm_feats
def load_rosetta_labels(rosetta_dir, target):
    """Load Rosetta ``dG_separated`` values for ``target`` and map them to [0, 1].

    Reads ``{rosetta_dir}/{target}_rosetta.json`` (PDB id -> metrics dict) and
    converts each dG with ``sigmoid(-dG / tau)`` (tau = 15), so more negative
    (better) binding energies give higher scores. Entries whose dG is not a
    number, is non-finite, or lies outside [-500, 500] are treated as failed
    Rosetta runs and skipped. A missing ``dG_separated`` key defaults to 0.0,
    i.e. label 0.5.

    Args:
        rosetta_dir: directory holding ``<target>_rosetta.json`` files.
        target: target name.

    Returns:
        dict mapping each PDB id (as written, upper-cased and lower-cased) to
        its label, or ``None`` if the file is missing or empty.
    """
    path = os.path.join(rosetta_dir, f'{target}_rosetta.json')
    if not os.path.exists(path):
        return None
    with open(path, encoding='utf-8') as f:
        raw = json.load(f)
    if not raw:
        return None

    # dG values outside this window are failed Rosetta runs.
    dG_MIN, dG_MAX = -500.0, 500.0
    tau = 15.0  # temperature; dG=-30 -> 0.88, dG=-15 -> 0.73, dG=0 -> 0.5

    labels = {}
    for pdb_id, metrics in raw.items():
        if not isinstance(metrics, dict):
            continue
        try:
            # JSON null / non-numeric strings would crash np.isfinite below.
            dG = float(metrics.get('dG_separated', 0.0))
        except (TypeError, ValueError):
            continue
        if not np.isfinite(dG) or not dG_MIN <= dG <= dG_MAX:
            continue
        value = 1.0 / (1.0 + np.exp(dG / tau))
        # Register case variants so callers can look up either spelling.
        for key in (pdb_id, pdb_id.upper(), pdb_id.lower()):
            labels[key] = value
    return labels
def apply_rosetta_labels(samples, rosetta_labels, label_source='rosetta', alpha=0.5):
    """Replace or blend sample labels (in place) with Rosetta-derived labels.

    Samples are matched by their PDB id with any ``_CHAIN`` suffix stripped,
    trying the id as written and then upper-cased. For matched samples:
      * ``positive``: target label is the Rosetta value;
      * ``decoy*``: target label is the current label times the Rosetta value;
      * ``negative*`` and any other type: left unchanged.
    With ``label_source='rosetta'`` the target label replaces the old one;
    with ``'combined'`` the result is ``alpha * old + (1 - alpha) * target``.
    Any other ``label_source`` leaves every sample untouched.

    Returns:
        Number of samples whose label was written (0 when ``rosetta_labels``
        is None or ``label_source`` is not recognised).
    """
    if rosetta_labels is None or label_source not in ('rosetta', 'combined'):
        return 0

    n_replaced = 0
    for s in samples:
        # Strip chain suffixes: "2G1T_AE" -> "2G1T"
        base_pdb = s.get('pdb', '').split('_')[0]
        # Explicit None checks: a legitimate 0.0 label must not fall through.
        rosetta_val = rosetta_labels.get(base_pdb)
        if rosetta_val is None:
            rosetta_val = rosetta_labels.get(base_pdb.upper())
        if rosetta_val is None:
            continue

        sample_type = s['type']
        if sample_type == 'positive':
            new_label = rosetta_val
        elif sample_type.startswith('decoy'):
            # Scale Rosetta label by DockQ-proxy quality
            new_label = s['label'] * rosetta_val
        else:
            # Negatives (apo mismatch) and unknown types keep their label.
            continue

        if label_source == 'rosetta':
            s['label'] = float(new_label)
        else:  # 'combined'
            s['label'] = float(alpha * s['label'] + (1 - alpha) * new_label)
        n_replaced += 1
    return n_replaced
class TwoStateComplexDataset(Dataset):
    """
    Dataset of protein complex interface graphs with two-state labels.

    Each item is a dict with:
        node_feats: [max_nodes, node_dim] float32 interface residue features
        edge_feats: [max_nodes, max_nodes, edge_dim] float32 pairwise features
        node_mask: [max_nodes] bool (False on padding rows)
        label: scalar float32 (DockQ proxy / selectivity label)
        type: str (positive / negative_apo / decoy_*)
        pdb: str
        esm_feats: [max_nodes, esm_dim] float32, only when ``esm_dir`` is set
            (a 1280-wide zero array when embeddings are unavailable)
    Graphs with more than ``max_nodes`` nodes raise ``ValueError``.
    """

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 rosetta_labels: dict = None, label_source: str = 'dockq',
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open(data_path, 'rb') as f:
            self.samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout
        if label_source != 'dockq' and rosetta_labels:
            apply_rosetta_labels(self.samples, rosetta_labels, label_source)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']
        node_feats = graph['node_feats']  # [N, node_dim]
        edge_feats = graph['edge_feats']  # [N, N, edge_dim]
        node_mask = graph['node_mask']  # [N]
        N = len(node_feats)
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, leaving only an opaque broadcast error below.
        if N > self.max_nodes:
            raise ValueError(f"Too many nodes: {N} > {self.max_nodes}")

        # Pad to max_nodes
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]
        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
        node_feats_pad[:N] = node_feats
        edge_feats_pad[:N, :N] = edge_feats
        node_mask_pad[:N] = node_mask

        # Optional Gaussian noise (std 0.01) on all node-feature rows.
        if self.augment:
            noise = np.random.randn(*node_feats_pad.shape) * 0.01
            node_feats_pad = node_feats_pad + noise.astype(np.float32)

        # Binder-dropout: simulate backbone-only designs by masking binder
        # sequence features (AA one-hot -> UNK, chi angles -> 0). Rows from
        # n_rec to N are treated as binder nodes.
        apply_binder_drop = (self.binder_dropout > 0
                             and np.random.rand() < self.binder_dropout)
        if apply_binder_drop:
            n_rec = graph.get('n_rec', N // 2)
            # Zero out binder AA one-hot (dims 0-20), set UNK (dim 20 = 1)
            node_feats_pad[n_rec:N, :21] = 0.0
            node_feats_pad[n_rec:N, 20] = 1.0  # UNK
            # Zero out binder chi angles (dims 27-30)
            node_feats_pad[n_rec:N, 27:31] = 0.0
            # Keep backbone torsions (dims 21-26) and chain indicator (dim 31)

        result = {
            'node_feats': torch.from_numpy(node_feats_pad),  # [max_nodes, node_dim]
            'edge_feats': torch.from_numpy(edge_feats_pad),  # [max_nodes, max_nodes, edge_dim]
            'node_mask': torch.from_numpy(node_mask_pad),  # [max_nodes]
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
        }

        # ESM-2 features (lazy load; zero-fill if unavailable)
        if self.esm_dir:
            esm = load_esm_for_sample(sample, self.esm_dir,
                                      self.target_name or '', self.max_nodes)
            if esm is not None:
                esm_feats = esm
            else:
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            # Zero binder ESM if binder-dropout active
            if apply_binder_drop:
                n_rec = graph.get('n_rec', N // 2)
                n_binder = graph.get('n_binder', N - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)
        return result
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Tensor fields gain a leading batch dimension; ``type`` and ``pdb`` become
    Python lists. When at least one sample carries ``esm_feats``, samples
    lacking it contribute a zero tensor shaped like the first one present.
    """
    def _stacked(key):
        return torch.stack([item[key] for item in batch])

    collated = {
        'node_feats': _stacked('node_feats'),  # [B, N, node_dim]
        'edge_feats': _stacked('edge_feats'),  # [B, N, N, edge_dim]
        'node_mask': _stacked('node_mask'),  # [B, N]
        'label': _stacked('label'),  # [B]
        'type': [item['type'] for item in batch],
        'pdb': [item['pdb'] for item in batch],
    }

    if any('esm_feats' in item for item in batch):
        esm_ref = next(item['esm_feats'] for item in batch if 'esm_feats' in item)
        collated['esm_feats'] = torch.stack([
            item['esm_feats'] if 'esm_feats' in item else torch.zeros_like(esm_ref)
            for item in batch
        ])
    return collated
class TwoStateDatasetPaired(Dataset):
    """
    Paired dataset: returns (positive, negative) pairs for selectivity training.

    Samples are grouped by ``'pdb'``. Within a PDB, every positive (holo)
    sample is paired with every ``negative*`` (apo) sample, and additionally
    with up to ``_MAX_DECOYS_PER_POS`` decoys whose type is
    ``decoy_rmsd<value>`` with value > ``_LARGE_RMSD``. Items are
    ``{'pos': ..., 'neg': ...}`` dicts of tensors padded/truncated to
    ``max_nodes``.
    """

    # Decoys whose parsed RMSD exceeds this are used as hard negatives.
    _LARGE_RMSD = 4.0
    # At most this many large-RMSD decoys are paired with each positive.
    _MAX_DECOYS_PER_POS = 3

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
                 esm_dir: str = None, target_name: str = None,
                 binder_dropout: float = 0.0):
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment  # stored for API parity; not applied by _prepare
        self.esm_dir = esm_dir
        self.target_name = target_name
        self.binder_dropout = binder_dropout

        # Group by PDB
        from collections import defaultdict
        by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
        for s in samples:
            pdb = s['pdb']
            t = s['type']
            if t == 'positive':
                by_pdb[pdb]['positive'].append(s)
            elif t.startswith('negative'):
                by_pdb[pdb]['negative'].append(s)
            elif t.startswith('decoy'):
                by_pdb[pdb]['decoy'].append(s)

        # Build pairs per PDB: (positive, negative) first, then (positive, large decoy).
        self.pairs = []
        for groups in by_pdb.values():
            positives = groups['positive']
            if not positives:
                continue
            for pos in positives:
                for neg in groups['negative']:
                    self.pairs.append((pos, neg))
            large_decoys = [s for s in groups['decoy'] if self._is_large_decoy(s['type'])]
            for pos in positives:
                for neg in large_decoys[:self._MAX_DECOYS_PER_POS]:
                    self.pairs.append((pos, neg))

    @classmethod
    def _is_large_decoy(cls, sample_type):
        """Return True for ``decoy_rmsd<value>`` types with value > ``_LARGE_RMSD``.

        Types mentioning 'rmsd' that do not parse as a number (for example
        'decoy_large_rmsd') count as not large instead of raising ValueError.
        """
        if 'rmsd' not in sample_type:
            return False
        try:
            rmsd = float(sample_type.replace('decoy_rmsd', ''))
        except ValueError:
            return False
        return rmsd > cls._LARGE_RMSD

    def __len__(self):
        return len(self.pairs)

    def _prepare(self, sample, apply_binder_drop=False):
        """Pad/truncate one sample's graph to max_nodes and convert it to tensors."""
        graph = sample['graph']
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        N = len(node_feats)
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]
        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
        n = min(N, self.max_nodes)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        # Binder-dropout: simulate backbone-only designs (AA one-hot -> UNK,
        # chi angles dims 27-30 -> 0) on rows n_rec..n.
        if apply_binder_drop:
            n_rec = graph.get('n_rec', n // 2)
            node_feats_pad[n_rec:n, :21] = 0.0
            node_feats_pad[n_rec:n, 20] = 1.0  # UNK
            node_feats_pad[n_rec:n, 27:31] = 0.0

        result = {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'contact_energy': torch.tensor(
                sample.get('contact_energy', 0.5), dtype=torch.float32
            ),
        }

        # ESM-2 features (zero-fill if unavailable)
        if self.esm_dir:
            esm = load_esm_for_sample(sample, self.esm_dir,
                                      self.target_name or '', self.max_nodes)
            if esm is not None:
                esm_feats = esm
            else:
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            if apply_binder_drop:
                n_rec = graph.get('n_rec', n // 2)
                n_binder = graph.get('n_binder', n - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)
        return result

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        # Same dropout decision for both pos and neg in a pair
        drop = (self.binder_dropout > 0
                and np.random.rand() < self.binder_dropout)
        return {
            'pos': self._prepare(pos_sample, apply_binder_drop=drop),
            'neg': self._prepare(neg_sample, apply_binder_drop=drop),
        }
def collate_paired_fn(batch):
    """Collate ``{'pos': ..., 'neg': ...}`` pair dicts into two batch dicts.

    Each side's ``node_feats``, ``edge_feats``, ``node_mask``, ``label`` and
    ``contact_energy`` are stacked along a new batch dimension. If any sample
    on either side has ``esm_feats``, both sides get a stacked ``esm_feats``
    in which missing entries are zero tensors shaped like the first available
    ESM tensor.

    Returns:
        ``{'pos': pos_batch, 'neg': neg_batch}``.
    """
    keys = ('node_feats', 'edge_feats', 'node_mask', 'label', 'contact_energy')
    sides = ('pos', 'neg')

    def _stack_side(side):
        return {k: torch.stack([s[side][k] for s in batch]) for k in keys}

    out = {side: _stack_side(side) for side in sides}

    # ESM features (handle mixed availability). Checking both sides avoids
    # torch.zeros_like(None) when only one side carries ESM, and avoids
    # silently dropping neg ESM when only the neg side has it.
    if any('esm_feats' in s[side] for s in batch for side in sides):
        ref = next(s[side]['esm_feats'] for s in batch for side in sides
                   if 'esm_feats' in s[side])
        for side in sides:
            out[side]['esm_feats'] = torch.stack([
                s[side]['esm_feats'] if 'esm_feats' in s[side] else torch.zeros_like(ref)
                for s in batch
            ])
    return {'pos': out['pos'], 'neg': out['neg']}
class PathAwareDatasetPaired(Dataset):
    """
    Paired dataset with transition-path frames for path-aware Phase 2 training.

    Pairs are built like TwoStateDatasetPaired: per PDB, each positive is
    paired with every ``negative*`` sample and with up to three decoys whose
    type is ``decoy_rmsd<value>`` with value > 4.0. Each item additionally has
    ``'path'`` (prepared graphs for the positive sample's ``'path_graphs'``
    entries) and ``'path_taus'`` (their ``'tau'`` values). Both lists are
    empty when the positive has no path frames.
    """

    # Decoys whose parsed RMSD exceeds this are used as hard negatives.
    _LARGE_RMSD = 4.0
    # At most this many large-RMSD decoys are paired with each positive.
    _MAX_DECOYS_PER_POS = 3

    def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False):
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open(data_path, 'rb') as f:
            samples = pickle.load(f)
        self.max_nodes = max_nodes
        self.augment = augment  # stored for API parity; not applied by this class

        from collections import defaultdict
        by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
        for s in samples:
            pdb = s['pdb']
            t = s['type']
            if t == 'positive':
                by_pdb[pdb]['positive'].append(s)
            elif t.startswith('negative'):
                by_pdb[pdb]['negative'].append(s)
            elif t.startswith('decoy'):
                by_pdb[pdb]['decoy'].append(s)

        self.pairs = []
        for groups in by_pdb.values():
            positives = groups['positive']
            if not positives:
                continue
            for pos in positives:
                for neg in groups['negative']:
                    self.pairs.append((pos, neg))
            large_decoys = [s for s in groups['decoy'] if self._is_large_decoy(s['type'])]
            for pos in positives:
                for neg in large_decoys[:self._MAX_DECOYS_PER_POS]:
                    self.pairs.append((pos, neg))

    @classmethod
    def _is_large_decoy(cls, sample_type):
        """Return True for ``decoy_rmsd<value>`` types with value > ``_LARGE_RMSD``.

        Types mentioning 'rmsd' that do not parse as a number (for example
        'decoy_large_rmsd') count as not large instead of raising ValueError.
        """
        if 'rmsd' not in sample_type:
            return False
        try:
            rmsd = float(sample_type.replace('decoy_rmsd', ''))
        except ValueError:
            return False
        return rmsd > cls._LARGE_RMSD

    def _pad_graph(self, graph):
        """Zero-pad (or truncate) a graph's features and mask to max_nodes as tensors."""
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        n = min(len(node_feats), self.max_nodes)
        node_feats_pad = np.zeros((self.max_nodes, node_feats.shape[-1]), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_feats.shape[-1]),
                                  dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]
        return {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
        }

    def _prepare(self, sample):
        """Padded graph tensors plus ``label`` (default 0.0) and ``contact_energy`` (default 0.5)."""
        result = self._pad_graph(sample['graph'])
        result['label'] = torch.tensor(sample.get('label', 0.0), dtype=torch.float32)
        result['contact_energy'] = torch.tensor(
            sample.get('contact_energy', 0.5), dtype=torch.float32
        )
        return result

    def _prepare_graph_only(self, path_entry):
        """Prepare a path frame graph (no label/contact_energy needed)."""
        return self._pad_graph(path_entry['graph'])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pos_sample, neg_sample = self.pairs[idx]
        result = {
            'pos': self._prepare(pos_sample),
            'neg': self._prepare(neg_sample),
        }
        # Path frames come from the positive sample; a missing or None entry means none.
        path_graphs = pos_sample.get('path_graphs') or []
        result['path'] = [self._prepare_graph_only(pg) for pg in path_graphs]
        result['path_taus'] = [pg['tau'] for pg in path_graphs]
        return result
def collate_path_paired_fn(batch):
    """Collate paired samples with variable-length path frames.

    ``pos``/``neg`` fields (node_feats, edge_feats, node_mask, label,
    contact_energy) are stacked along a new batch dimension. Path frames are
    regrouped by frame index: ``result['path'][k]`` stacks frame ``k`` of
    every sample, substituting an all-zero frame (``node_mask`` all False)
    for samples with fewer than ``k + 1`` frames. ``result['path_taus'][k]``
    is a single float per frame index, taken from the first sample that
    actually has frame ``k``. ``'path'`` and ``'path_taus'`` are omitted when
    no sample has path frames.
    """
    side_keys = ('node_feats', 'edge_feats', 'node_mask', 'label', 'contact_energy')
    graph_keys = ('node_feats', 'edge_feats', 'node_mask')
    pos_batch = {k: torch.stack([s['pos'][k] for s in batch]) for k in side_keys}
    neg_batch = {k: torch.stack([s['neg'][k] for s in batch]) for k in side_keys}

    # Collate path frames: find max K across batch, pad shorter ones
    max_k = max((len(s['path']) for s in batch), default=0)
    path_batches = []
    path_taus = []
    if max_k > 0:
        # Zero-filled padding frame; assumes every frame shares the padded
        # shape of batch[0]'s first frame (or its 'pos' graph).
        ref = batch[0]['path'][0] if batch[0]['path'] else batch[0]['pos']
        zero_placeholder = {k: torch.zeros_like(ref[k]) for k in graph_keys}
        for k_idx in range(max_k):
            frames_at_k = []
            tau_at_k = None
            for s in batch:
                if k_idx < len(s['path']):
                    frames_at_k.append(s['path'][k_idx])
                    if tau_at_k is None:
                        tau_at_k = s['path_taus'][k_idx]
                else:
                    frames_at_k.append(zero_placeholder)
            path_batches.append({
                k: torch.stack([f[k] for f in frames_at_k]) for k in graph_keys
            })
            # Tau comes from a real frame, never from padding; since
            # k_idx < max_k, at least one sample has frame k_idx.
            path_taus.append(tau_at_k)

    result = {'pos': pos_batch, 'neg': neg_batch}
    if path_batches:
        result['path'] = path_batches
        result['path_taus'] = path_taus
    return result
class MultiTargetDataset(Dataset):
    """
    Pooled dataset combining samples from multiple targets.

    Every sample is tagged with its target name under ``'_target'``. When
    ``balance`` is set and more than one target is listed, ``self.weights``
    holds per-sample sampling probabilities that give each non-empty target
    equal total mass (smaller targets are oversampled); otherwise
    ``self.weights`` is ``None``.
    """

    def __init__(self, data_paths: list, max_nodes: int = 128, augment: bool = False,
                 balance: bool = True, rosetta_dir: str = None, label_source: str = 'dockq',
                 esm_dir: str = None, binder_dropout: float = 0.0):
        """
        Args:
            data_paths: list of (target_name, pkl_path) tuples; missing paths are skipped
            max_nodes: max interface graph size (larger graphs are truncated)
            augment: add Gaussian noise (std 0.01) to node features
            balance: if True, oversample smaller targets to balance
            rosetta_dir: directory containing Rosetta label JSONs
            label_source: 'dockq', 'rosetta', or 'combined'
            esm_dir: root directory of per-target ESM files; None disables ESM features
            binder_dropout: per-item probability of masking binder sequence features
        """
        self.max_nodes = max_nodes
        self.augment = augment
        self.esm_dir = esm_dir
        self.binder_dropout = binder_dropout

        # Load all samples with target labels
        self.samples = []
        self.target_indices = {}  # target_name -> list of indices into self.samples
        for target_name, path in data_paths:
            if not os.path.exists(path):
                continue
            # NOTE: pickle can execute arbitrary code on load; only use trusted files.
            with open(path, 'rb') as f:
                target_samples = pickle.load(f)
            # Apply Rosetta labels if requested
            if label_source != 'dockq' and rosetta_dir:
                rl = load_rosetta_labels(rosetta_dir, target_name)
                if rl:
                    apply_rosetta_labels(target_samples, rl, label_source)
            start_idx = len(self.samples)
            for s in target_samples:
                s['_target'] = target_name
                self.samples.append(s)
            self.target_indices[target_name] = list(range(start_idx, len(self.samples)))

        # Build balanced sampling weights
        self.weights = None
        if balance and len(self.target_indices) > 1:
            non_empty = {k: v for k, v in self.target_indices.items() if v}
            # With no samples at all the weights would sum to 0 and the
            # normalisation below would produce NaN; leave weights as None.
            if non_empty:
                max_count = max(len(idxs) for idxs in non_empty.values())
                self.weights = np.zeros(len(self.samples))
                for idxs in non_empty.values():
                    self.weights[idxs] = max_count / len(idxs)
                self.weights /= self.weights.sum()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        graph = sample['graph']
        node_feats = graph['node_feats']
        edge_feats = graph['edge_feats']
        node_mask = graph['node_mask']
        N = len(node_feats)
        node_dim = node_feats.shape[-1]
        edge_dim = edge_feats.shape[-1]
        node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
        edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
        node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
        n = min(N, self.max_nodes)
        node_feats_pad[:n] = node_feats[:n]
        edge_feats_pad[:n, :n] = edge_feats[:n, :n]
        node_mask_pad[:n] = node_mask[:n]

        if self.augment:
            noise = np.random.randn(*node_feats_pad.shape) * 0.01
            node_feats_pad = node_feats_pad + noise.astype(np.float32)

        # Binder-dropout: simulate backbone-only designs (AA one-hot -> UNK,
        # chi angles dims 27-30 -> 0). Slices past max_nodes are clamped by numpy.
        apply_binder_drop = (self.binder_dropout > 0
                             and np.random.rand() < self.binder_dropout)
        if apply_binder_drop:
            n_rec = graph.get('n_rec', N // 2)
            node_feats_pad[n_rec:N, :21] = 0.0
            node_feats_pad[n_rec:N, 20] = 1.0  # UNK
            node_feats_pad[n_rec:N, 27:31] = 0.0

        result = {
            'node_feats': torch.from_numpy(node_feats_pad),
            'edge_feats': torch.from_numpy(edge_feats_pad),
            'node_mask': torch.from_numpy(node_mask_pad),
            'label': torch.tensor(sample['label'], dtype=torch.float32),
            'type': sample['type'],
            'pdb': sample['pdb'],
            'target': sample.get('_target', 'unknown'),
        }

        # ESM-2 features (zero-fill if unavailable)
        if self.esm_dir:
            target_name = sample.get('_target', 'unknown')
            esm = load_esm_for_sample(sample, self.esm_dir, target_name, self.max_nodes)
            if esm is not None:
                esm_feats = esm
            else:
                esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
            if apply_binder_drop:
                n_rec = graph.get('n_rec', N // 2)
                n_binder = graph.get('n_binder', N - n_rec)
                esm_feats[n_rec:n_rec + n_binder] = 0.0
            result['esm_feats'] = torch.from_numpy(esm_feats)
        return result

    @staticmethod
    def get_pooled_dataloaders(data_dir, targets, batch_size=16, max_nodes=128,
                               num_workers=4, paired=False,
                               rosetta_dir=None, label_source='dockq',
                               esm_dir=None, binder_dropout=0.0):
        """Build pooled dataloaders from multiple targets.

        Args:
            data_dir: root data directory (expects ``{data_dir}/{target}/{split}.pkl``)
            targets: list of target names
            batch_size: batch size
            max_nodes: max interface nodes
            num_workers: dataloader workers
            paired: if True, build paired dataloaders for Phase 2
                (rosetta_dir / label_source are not used in this mode)
            rosetta_dir: directory with Rosetta label JSONs
            label_source: 'dockq', 'rosetta', or 'combined'
            esm_dir: root directory of per-target ESM files, or None
            binder_dropout: binder-dropout probability (train split only)

        Returns:
            dict split -> DataLoader; splits with no data (or, in paired
            mode, no pairs) are omitted.
        """
        from torch.utils.data import ConcatDataset, WeightedRandomSampler

        # Preload ESM embeddings into global cache before creating datasets/workers
        if esm_dir:
            preload_esm_cache(esm_dir, targets)

        loaders = {}
        for split in ['train', 'val', 'test']:
            data_paths = []
            for target in targets:
                path = os.path.join(data_dir, target, f"{split}.pkl")
                if os.path.exists(path):
                    data_paths.append((target, path))
            if not data_paths:
                continue
            is_train = (split == 'train')
            bd = binder_dropout if is_train else 0.0

            if paired:
                all_pairs = []
                for target, path in data_paths:
                    ds = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=is_train,
                                               esm_dir=esm_dir, target_name=target,
                                               binder_dropout=bd)
                    if len(ds) > 0:
                        all_pairs.append(ds)
                # Skip splits without any pairs (a shuffled empty loader
                # would raise in RandomSampler).
                if not all_pairs:
                    continue
                concat_ds = ConcatDataset(all_pairs)
                p_batch = min(batch_size, max(1, len(concat_ds) // 2))
                loaders[split] = DataLoader(
                    concat_ds, batch_size=p_batch,
                    shuffle=is_train,
                    num_workers=num_workers,
                    collate_fn=collate_paired_fn,
                    pin_memory=True,
                )
            else:
                dataset = MultiTargetDataset(data_paths, max_nodes=max_nodes,
                                             augment=is_train, balance=is_train,
                                             rosetta_dir=rosetta_dir, label_source=label_source,
                                             esm_dir=esm_dir, binder_dropout=bd)
                sampler = None
                shuffle = is_train
                if is_train and dataset.weights is not None:
                    sampler = WeightedRandomSampler(
                        weights=dataset.weights,
                        num_samples=len(dataset),
                        replacement=True
                    )
                    shuffle = False  # DataLoader forbids shuffle together with a sampler
                loaders[split] = DataLoader(
                    dataset, batch_size=batch_size,
                    shuffle=shuffle, sampler=sampler,
                    num_workers=num_workers,
                    collate_fn=collate_fn,
                    pin_memory=True,
                    drop_last=(is_train and len(dataset) > batch_size),
                )
        return loaders
def get_dataloaders(data_dir: str, target: str, batch_size: int = 16,
                    max_nodes: int = 128, num_workers: int = 4,
                    paired: bool = False, esm_dir: str = None,
                    binder_dropout: float = 0.0):
    """Create per-split DataLoaders for a single target.

    Looks for ``{data_dir}/{target}/{split}.pkl`` for each of train/val/test
    and omits splits whose file is absent. Only the train split is shuffled,
    augmented and subject to binder-dropout; with ``paired=True`` only the
    train split uses the paired dataset and collate function.

    Returns:
        dict mapping split name -> DataLoader.
    """
    loaders = {}
    for split in ('train', 'val', 'test'):
        pkl_path = os.path.join(data_dir, target, f"{split}.pkl")
        if not os.path.exists(pkl_path):
            continue
        is_train = split == 'train'
        dataset_kwargs = dict(
            max_nodes=max_nodes,
            augment=is_train,
            esm_dir=esm_dir,
            target_name=target,
            binder_dropout=binder_dropout if is_train else 0.0,
        )
        if paired and is_train:
            dataset = TwoStateDatasetPaired(pkl_path, **dataset_kwargs)
            collate = collate_paired_fn
        else:
            dataset = TwoStateComplexDataset(pkl_path, **dataset_kwargs)
            collate = collate_fn
        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=True,
            drop_last=is_train and len(dataset) > batch_size,
        )
    return loaders