File size: 12,621 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
Evaluation script for the trained Q_theta scorer.

Computes:
  1. Selectivity metrics (gap, ranking accuracy, AUC)
  2. DockQ correlation (Spearman/Pearson)
  3. Score distributions (violin plots)
  4. Best-of-K analysis (as function of K)
  5. Per-target breakdown

Usage:
    python code/scripts/evaluate.py \
        --target cam \
        --checkpoint checkpoints/Q_theta_phase2.pt \
        --data_dir data/processed \
        --gpu 7
"""

import os
import sys
import argparse
import logging
import json
import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import roc_auc_score, roc_curve

# Put the directory one level above this script (``code/`` per the usage
# example in the module docstring) at the front of sys.path so the
# project-local ``models`` and ``data`` packages resolve when the file is run
# directly as a script.
_CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if _CODE_DIR not in sys.path:
    sys.path.insert(0, _CODE_DIR)

# Project-local imports (resolved through the sys.path tweak above), plus the
# PyTorch DataLoader used to batch the evaluation dataset.
from models.scorer import build_model
from data.dataset import TwoStateComplexDataset, collate_fn
from torch.utils.data import DataLoader

# Configure root logging at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)


def compute_best_of_k(pos_scores, K_values=None, threshold=0.7):
    """
    Best-of-K analysis: probability that K candidates drawn without
    replacement from ``pos_scores`` contain at least one good binder, i.e. at
    least one score >= ``threshold``.

    Assumes ``pos_scores`` are scores of candidate binders for goal state X+.

    The probability is computed exactly from the hypergeometric distribution
    instead of by random sampling, so results are deterministic and free of
    Monte Carlo noise. With ``n`` candidates of which ``g`` pass the threshold
    and ``k = min(K, n)`` draws:

        P(success) = 1 - C(n - g, k) / C(n, k)

    Args:
        pos_scores: 1-D array-like of candidate scores (lists are accepted).
        K_values: iterable of K values to evaluate; defaults to
            ``[1, 2, 5, 10, 20, 50, 100]``. K larger than the number of
            candidates is clamped to it (every candidate is drawn).
        threshold: a candidate counts as good when its score >= threshold.

    Returns:
        dict mapping each K to its success probability in [0, 1]. An empty
        ``pos_scores`` or a non-positive K yields 0.0 (no candidate can pass).
    """
    if K_values is None:
        K_values = [1, 2, 5, 10, 20, 50, 100]

    scores = np.asarray(pos_scores, dtype=float).ravel()
    n = scores.size
    n_good = int((scores >= threshold).sum())
    n_bad = n - n_good

    results = {}
    for K in K_values:
        k = min(int(K), n)
        if k <= 0 or n_good == 0:
            results[K] = 0.0
            continue
        if k > n_bad:
            # Pigeonhole: any k-subset must include at least one good candidate.
            results[K] = 1.0
            continue
        # C(n_bad, k) / C(n, k) == prod_{i<k} (n_bad - i) / (n - i); every
        # factor is positive here because k <= n_bad.
        i = np.arange(k, dtype=float)
        p_all_bad = float(np.prod((n_bad - i) / (n - i)))
        results[K] = 1.0 - p_all_bad

    return results


def compute_selectivity_margin(pos_scores, neg_scores):
    """
    Per-sample selectivity margin S_theta.

    Both inputs are NumPy arrays of scores (they must support ``.clip``) that
    are combined elementwise. Each score is mapped to a logit,
    ``log(p / (1 - p))``, after clipping so the ratio stays finite. The margin
    is then ``logit(pos) - log(exp(logit(neg)) + 1e-8)``.

    The 1e-8 offset only noticeably shifts the negative term when the
    negative score is already near the clipping floor.
    """
    eps = 1e-6

    def _clipped_logit(p):
        # Numerator clipped to [eps, 1-eps]; denominator only floored at eps.
        odds = p.clip(eps, 1-eps) / (1-p).clip(eps)
        return np.log(odds)

    neg_term = np.log(np.exp(_clipped_logit(neg_scores)) + 1e-8)
    return _clipped_logit(pos_scores) - neg_term


def plot_score_distributions(pos_scores, neg_scores, decoy_scores=None,
                              title='Score Distributions', outpath=None):
    """
    Violin plot of Q_theta score distributions for different complex types.

    Args:
        pos_scores: scores of positive complexes (X+, Y).
        neg_scores: scores of negative apo complexes (X0, Y).
        decoy_scores: optional scores of decoys (X+, Y~).
        title: axes title.
        outpath: if truthy, the figure is saved there at 150 dpi.

    Groups that are None or empty are skipped, because ``violinplot`` cannot
    summarise a zero-sample dataset and raises on one. If every group is
    empty, a warning is logged and no figure is produced.
    """
    candidates = [
        (pos_scores, 'Positive\n(X+, Y)', '#2196F3'),
        (neg_scores, 'Negative\n(X0, Y)', '#F44336'),
        (decoy_scores, 'Decoys\n(X+, Y~)', '#FF9800'),
    ]
    groups = [(np.asarray(d, dtype=float), label, color)
              for d, label, color in candidates
              if d is not None and len(d) > 0]
    if not groups:
        logger.warning(f"No scores to plot for '{title}'; skipping figure")
        return

    data = [d for d, _, _ in groups]
    labels = [label for _, label, _ in groups]
    colors = [color for _, _, color in groups]
    positions = list(range(len(data)))

    fig, ax = plt.subplots(figsize=(8, 6))
    parts = ax.violinplot(data, positions=positions, showmedians=True)
    for body, color in zip(parts['bodies'], colors):
        body.set_facecolor(color)
        body.set_alpha(0.7)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_ylabel('Q_theta Score', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_ylim(0, 1)
    ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Decision boundary')
    ax.legend()

    # Annotate each violin with its mean and standard deviation.
    for pos, d, color in zip(positions, data, colors):
        ax.text(pos, 0.02, f'μ={d.mean():.2f}\nσ={d.std():.2f}',
                ha='center', fontsize=9, color=color)

    fig.tight_layout()
    if outpath:
        fig.savefig(outpath, dpi=150, bbox_inches='tight')
        logger.info(f"Saved plot to {outpath}")
    plt.close(fig)


def plot_roc_curve(labels, scores, title='ROC Curve', outpath=None):
    """
    Draw the ROC curve of a binary scoring problem and return its AUC.

    Args:
        labels: binary ground truth, 1 for the positive class.
        scores: classifier scores; higher means "more positive".
        title: axes title.
        outpath: if truthy, the figure is written there at 150 dpi.

    Returns:
        ROC AUC as computed by ``sklearn.metrics.roc_auc_score``.
    """
    fpr, tpr, _thresholds = roc_curve(labels, scores)
    auc = roc_auc_score(labels, scores)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(fpr, tpr, 'b-', lw=2, label=f'AUC = {auc:.3f}')
    # Dashed diagonal: the chance-level reference.
    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title=title)
    ax.legend()
    fig.tight_layout()

    if outpath:
        fig.savefig(outpath, dpi=150, bbox_inches='tight')
    plt.close(fig)
    return auc


def plot_best_of_k(results, outpath=None, threshold=0.7):
    """
    Plot best-of-K success rate as a function of K (log-scaled x axis).

    Args:
        results: mapping K -> success rate in [0, 1], e.g. the dict returned
            by ``compute_best_of_k``.
        outpath: if truthy, the figure is saved there at 150 dpi.
        threshold: score cutoff that defined "success" when ``results`` was
            computed. It is used only for the y-axis label, so the figure
            states the cutoff actually applied. ``compute_best_of_k`` counts a
            draw as successful when its best score is >= threshold. Defaults
            to 0.7, ``compute_best_of_k``'s own default.
    """
    Ks = sorted(results.keys())
    success_rates = [results[K] for K in Ks]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.semilogx(Ks, success_rates, 'b-o', lw=2, markersize=8)
    ax.set_xlabel('K (number of candidates)', fontsize=12)
    ax.set_ylabel(f'Success rate (best score ≥ {threshold:g})', fontsize=12)
    ax.set_title('Best-of-K Analysis', fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='80% success')
    ax.legend()
    fig.tight_layout()
    if outpath:
        fig.savefig(outpath, dpi=150, bbox_inches='tight')
        logger.info(f"Saved plot to {outpath}")
    plt.close(fig)


@torch.no_grad()
def evaluate(model, loader, device):
    """
    Run ``model`` over every batch in ``loader`` and gather its predictions.

    The model is put into eval mode (and left there). Gradient tracking is
    disabled for the whole call by the ``torch.no_grad`` decorator.

    Each batch must provide ``node_feats``, ``edge_feats``, ``node_mask`` and
    ``label`` tensors plus ``type`` and ``pdb`` sequences. ``esm_feats`` is
    optional and is forwarded to the model only when present.

    Returns:
        Tuple ``(scores, labels, types, pdbs)`` of NumPy arrays, concatenated
        across batches in loader order.
    """
    model.eval()
    collected = {'score': [], 'label': [], 'type': [], 'pdb': []}

    for batch in loader:
        esm = batch['esm_feats'].to(device) if 'esm_feats' in batch else None
        inputs = [batch[key].to(device) for key in ('node_feats', 'edge_feats', 'node_mask')]
        preds = model(*inputs, esm_feats=esm)
        collected['score'] += preds.cpu().numpy().tolist()
        collected['label'] += batch['label'].numpy().tolist()
        collected['type'] += batch['type']
        collected['pdb'] += batch['pdb']

    return tuple(np.array(collected[key]) for key in ('score', 'label', 'type', 'pdb'))


def _log_summary_table(metrics, header='SUMMARY TABLE'):
    """Log every float-valued entry of ``metrics`` as an aligned two-column table."""
    logger.info(f"\n{'='*50}")
    logger.info(header)
    logger.info(f"{'='*50}")
    logger.info(f"{'Metric':<30} {'Value':>10}")
    logger.info(f"{'-'*42}")
    for k, v in metrics.items():
        if isinstance(v, float):
            logger.info(f"{k:<30} {v:>10.4f}")
    logger.info(f"{'='*50}")


def _discover_targets(data_dir, split):
    """Return the sorted names of subdirectories of ``data_dir`` containing ``<split>.pkl``."""
    if not os.path.isdir(data_dir):
        return []
    return sorted(
        name for name in os.listdir(data_dir)
        if os.path.isfile(os.path.join(data_dir, name, f'{split}.pkl'))
    )


def _evaluate_target(args, model, device, target):
    """
    Evaluate ``model`` on one target's ``args.split`` data and write its outputs.

    Figures (ROC, best-of-K, score distributions) go under
    ``<outdir>/figures`` and the metrics JSON under ``<outdir>/tables``.

    Returns:
        The metrics dict. Returns None when ``<data_dir>/<target>/<split>.pkl``
        is missing or yields no samples; an error is logged in both cases.
    """
    data_path = os.path.join(args.data_dir, target, f'{args.split}.pkl')
    if not os.path.exists(data_path):
        logger.error(f"Data not found: {data_path}")
        return None

    dataset = TwoStateComplexDataset(data_path, max_nodes=128,
                                     esm_dir=args.esm_dir, target_name=target)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=2, collate_fn=collate_fn
    )

    logger.info(f"Evaluating on {len(dataset)} samples...")
    scores, labels, types, _pdbs = evaluate(model, loader, device)
    if len(scores) == 0:
        logger.error(f"No samples produced for {target} ({args.split})")
        return None

    # Partition by complex type. dtype=bool guarantees the decoy mask is a
    # valid boolean index (np.array([]) would default to float).
    pos_mask = (types == 'positive')
    neg_apo_mask = (types == 'negative_apo')
    decoy_mask = np.array(['decoy' in t for t in types], dtype=bool)

    pos_scores = scores[pos_mask]
    neg_scores = scores[neg_apo_mask]
    decoy_scores = scores[decoy_mask]
    has_pos_and_neg = len(pos_scores) > 0 and len(neg_scores) > 0
    fig_tag = f'{target}_{args.split}'

    logger.info(f"\n{'='*50}")
    logger.info(f"Results for {target} ({args.split})")
    logger.info(f"{'='*50}")
    logger.info(f"Positive samples: {pos_mask.sum()}")
    logger.info(f"Negative (apo) samples: {neg_apo_mask.sum()}")
    logger.info(f"Decoy samples: {decoy_mask.sum()}")

    metrics = {}

    # 1. Spearman correlation with DockQ labels
    sp, p_val = spearmanr(scores, labels)
    metrics['spearman_all'] = float(sp)
    metrics['spearman_pval'] = float(p_val)
    logger.info(f"\nSpearman(Q_theta, DockQ): {sp:.3f} (p={p_val:.3e})")

    # 2-3. Selectivity (positive vs negative_apo): gap, ranking accuracy, AUC
    if has_pos_and_neg:
        gap = float(pos_scores.mean() - neg_scores.mean())
        # Pairwise ranking accuracy: fraction of (pos, neg) pairs in which the
        # positive scores strictly higher. For each positive, searchsorted with
        # side='left' on the sorted negatives counts the negatives strictly
        # below it, avoiding an explicit P x N comparison matrix.
        n_neg_below = np.searchsorted(np.sort(neg_scores), pos_scores, side='left')
        ranking_acc = float(n_neg_below.sum() / (len(pos_scores) * len(neg_scores)))
        metrics['selectivity_gap'] = gap
        metrics['ranking_acc'] = ranking_acc
        metrics['pos_score_mean'] = float(pos_scores.mean())
        metrics['neg_score_mean'] = float(neg_scores.mean())
        metrics['pos_score_std'] = float(pos_scores.std())
        metrics['neg_score_std'] = float(neg_scores.std())
        logger.info(f"Selectivity gap (pos - neg): {gap:.3f}")
        logger.info(f"  Pos: {pos_scores.mean():.3f} ± {pos_scores.std():.3f}")
        logger.info(f"  Neg: {neg_scores.mean():.3f} ± {neg_scores.std():.3f}")
        logger.info(f"Ranking accuracy P(pos > neg): {ranking_acc:.3f}")

        pn_scores = np.concatenate([pos_scores, neg_scores])
        pn_labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
        auc = roc_auc_score(pn_labels, pn_scores)
        metrics['auc_pos_vs_neg'] = float(auc)
        logger.info(f"AUC (pos vs neg_apo): {auc:.3f}")

        plot_roc_curve(
            pn_labels, pn_scores,
            title=f'ROC: Positive vs Negative Apo ({target.upper()})',
            outpath=f'{args.outdir}/figures/roc_{fig_tag}.png'
        )

    # 4. AUC for quality classification (DockQ > 0.5); needs both classes.
    binary = (labels > 0.5).astype(int)
    if 0 < binary.sum() < len(binary):
        auc_quality = roc_auc_score(binary, scores)
        metrics['auc_quality'] = float(auc_quality)
        logger.info(f"AUC (quality>0.5): {auc_quality:.3f}")

    # 5. Best-of-K analysis
    if len(pos_scores) > 0:
        bok_results = compute_best_of_k(pos_scores, K_values=[1, 2, 5, 10, 20, 50],
                                         threshold=args.bok_threshold)
        metrics['best_of_k'] = {str(K): float(v) for K, v in bok_results.items()}
        logger.info(f"\nBest-of-K success rates (threshold={args.bok_threshold}):")
        for K, rate in bok_results.items():
            logger.info(f"  K={K:3d}: {rate:.3f}")
        plot_best_of_k(
            bok_results,
            outpath=f'{args.outdir}/figures/best_of_k_{fig_tag}.png'
        )

    # 6. Score distributions plot (violin plots cannot handle empty groups)
    if has_pos_and_neg:
        plot_score_distributions(
            pos_scores,
            neg_scores,
            decoy_scores if len(decoy_scores) > 0 else None,
            title=f'Q_theta Score Distributions ({target.upper()})',
            outpath=f'{args.outdir}/figures/score_dist_{fig_tag}.png'
        )
    else:
        logger.warning("Skipping score-distribution plot: need both "
                       "positive and negative_apo samples")

    out_json = f'{args.outdir}/tables/eval_{fig_tag}.json'
    with open(out_json, 'w') as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"\nSaved metrics to {out_json}")

    _log_summary_table(metrics)
    return metrics


def _write_aggregate(args, per_target):
    """
    Write all per-target metrics plus their across-target means to
    ``<outdir>/tables/eval_all_<split>.json``, and log the means.

    Means are unweighted over targets and cover only float-valued metrics.
    p-values are excluded because averaging them is not meaningful.
    """
    scalar_keys = sorted({
        key for m in per_target.values() for key, val in m.items()
        if isinstance(val, float) and not key.endswith('_pval')
    })
    means = {
        key: float(np.mean([m[key] for m in per_target.values() if key in m]))
        for key in scalar_keys
    }
    out_json = f'{args.outdir}/tables/eval_all_{args.split}.json'
    with open(out_json, 'w') as f:
        json.dump({'targets': sorted(per_target),
                   'mean_over_targets': means,
                   'per_target': per_target}, f, indent=2)
    logger.info(f"\nSaved aggregated metrics to {out_json}")
    _log_summary_table(means, header=f'MEAN OVER {len(per_target)} TARGETS')


def main():
    """
    CLI entry point: load a Q_theta checkpoint and evaluate it on one target
    (``--target``) or on every target found under ``--data_dir``
    (``--all_targets``).

    Exits with status 1 when no target could be evaluated.
    """
    parser = argparse.ArgumentParser(description='Evaluate Allo-Designer Q_theta scorer')
    parser.add_argument('--target', default='cam',
                        help='Target name (cam, abl, era, or any custom target with data in data/processed/)')
    parser.add_argument('--all_targets', action='store_true',
                        help='Evaluate on all available targets and produce aggregated results')
    parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
    parser.add_argument('--data_dir', default='data/processed')
    parser.add_argument('--split', choices=['val', 'test'], default='test')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--gpu', type=int, default=7)
    parser.add_argument('--outdir', default='results')
    parser.add_argument('--bok_threshold', type=float, default=0.7,
                        help='Score threshold for best-of-K (default 0.7; use per-target value for calibrated results)')
    parser.add_argument('--esm_dir', default=None,
                        help='Path to ESM-2 embedding cache (auto-detected at <data_dir>/esm2_embeddings if omitted)')
    parser.add_argument('--no_wandb', action='store_true', help='(ignored; here for CLI compatibility)')
    args = parser.parse_args()

    # Auto-detect ESM dir under data_dir
    if args.esm_dir is None:
        cand = os.path.join(args.data_dir, 'esm2_embeddings')
        if os.path.isdir(cand):
            args.esm_dir = cand

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(f'{args.outdir}/figures', exist_ok=True)
    os.makedirs(f'{args.outdir}/tables', exist_ok=True)

    # Load model.
    # NOTE(review): torch.load unpickles the checkpoint file (depending on the
    # torch version's weights_only default); only load trusted checkpoints.
    state = torch.load(args.checkpoint, map_location=device)
    config = state.get('config', {})
    model = build_model(config).to(device)
    model.load_state_dict(state['model_state'])
    logger.info(f"Loaded model from {args.checkpoint}")

    if args.all_targets:
        targets = _discover_targets(args.data_dir, args.split)
        if not targets:
            logger.error(f"No targets with {args.split}.pkl found under {args.data_dir}")
            sys.exit(1)
        logger.info(f"Evaluating {len(targets)} targets: {', '.join(targets)}")
    else:
        targets = [args.target]

    per_target = {}
    for target in targets:
        target_metrics = _evaluate_target(args, model, device, target)
        if target_metrics is not None:
            per_target[target] = target_metrics

    if not per_target:
        sys.exit(1)
    if args.all_targets:
        _write_aggregate(args, per_target)


if __name__ == '__main__':
    main()