#!/usr/bin/env python3
"""
Re-score binder PDB designs with a Q_theta checkpoint.

For every known design directory, each ``*.pdb`` file is loaded, its binder
chain is moved into the frame of the holo reference receptor (by superposing
the design's receptor chain onto the reference), and the binder is scored by
DifferentiableQTheta against the fixed holo and apo receptors (HOLO_PDB /
APO_PDB). Per-design S = Q_theta(holo) - Q_theta(apo) plus the raw holo/apo
scores are written to JSON, and a summary table is printed.

Usage:
    python code/scripts/rescore.py \\
        --checkpoint checkpoints/Q_theta_phase2.pt \\
        --gpu 0
"""

import os, sys, json, argparse, glob, logging
import numpy as np
import torch
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Repository root: three directory levels above this script.
BASE = str(Path(__file__).resolve().parent.parent.parent)
# Project packages are imported from <BASE>/code and <BASE>, so the path must
# be extended before the project imports below.
sys.path.insert(0, os.path.join(BASE, 'code'))
sys.path.insert(0, BASE)

from models.differentiable_features import DifferentiableQTheta
from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures

HOLO_PDB = os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb')
APO_PDB = os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb')

# Designs sharing fewer residue numbers than this with the reference receptor
# are skipped (the superposition would rest on too few points).
MIN_COMMON_RESIDUES = 10


def _select_chains(chain_ids):
    """Return ``(receptor_chain, binder_chain)`` ids for a design.

    The receptor is chain 'A' when present, otherwise the first chain. The
    binder is chain 'B' when present *among the remaining chains*, otherwise
    the first remaining chain. Restricting the binder to the remaining chains
    guarantees the two ids differ (a design whose first chain is 'B' and which
    has no chain 'A' would otherwise use 'B' for both roles).

    Raises:
        ValueError: if there is no chain other than the receptor chain.
    """
    if not chain_ids:
        raise ValueError("structure contains no chains")
    rec_chain = 'A' if 'A' in chain_ids else chain_ids[0]
    candidates = [c for c in chain_ids if c != rec_chain]
    if not candidates:
        raise ValueError(f"no binder chain found besides receptor chain {rec_chain!r}")
    binder_chain = 'B' if 'B' in candidates else candidates[0]
    return rec_chain, binder_chain


def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device):
    """Score a list of design PDB files.

    Args:
        dq: scorer exposing ``score(coords, mask, binder_aa_idx=...,
            receptor_label=...)`` with receptors labelled 'holo' and 'apo'
            already loaded.
        pdb_list: iterable of paths to design PDB files.
        ref_resnums: mapping from reference-receptor residue number to its
            row index in ``ref_coords``.
        ref_coords: backbone coordinates of the reference receptor, as
            returned by ``get_backbone_coords``.
        device: torch device the input tensors are moved to.

    Returns:
        A list of dicts with keys "design", "Q_holo", "Q_apo" and "S"
        (S = Q_holo - Q_apo), one per successfully scored design. Designs
        that fail for any reason are logged and skipped (best effort).
    """
    results = []
    for pdb_path in pdb_list:
        # Strip only the final extension; a blanket replace(".pdb", "") would
        # also remove ".pdb" occurring elsewhere in the file name.
        name = os.path.splitext(os.path.basename(pdb_path))[0]
        try:
            design_model = load_structure(pdb_path)
            chains = [c.id for c in design_model.get_chains()]
            rec_chain, binder_chain = _select_chains(chains)

            rec_res = get_residues(design_model[rec_chain])
            binder_res = get_residues(design_model[binder_chain])

            rec_coords_d, _ = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)
            binder_aa_idx = get_aa_indices(binder_res)

            # Match receptor residues between design and reference by
            # residue number.
            design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
            common = sorted(set(design_resnums) & set(ref_resnums))
            if len(common) < MIN_COMMON_RESIDUES:
                logger.warning(f" Skip {name}: <10 common residues")
                continue

            # Backbone atom index 1 is used as the superposition point
            # (presumably CA, going by the original variable names -- confirm
            # against get_backbone_coords).
            d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
            r_ca = ref_coords[[ref_resnums[r] for r in common], 1]

            mobile_center = d_ca.mean(0)
            ref_center = r_ca.mean(0)
            # NOTE(review): assumes align_structures returns a rotation R that
            # maps centered design coordinates onto centered reference
            # coordinates via `x @ R.T` -- verify against utils.pdb_utils.
            _, R = align_structures(d_ca, r_ca)

            # Apply the receptor-derived rigid transform to the binder:
            # center on the design receptor, rotate, shift to the reference.
            flat = binder_coords.reshape(-1, 3) - mobile_center
            aligned_binder = (flat @ R.T + ref_center).reshape(-1, 4, 3)

            coords_t = torch.from_numpy(aligned_binder).float().to(device)
            mask_t = torch.from_numpy(binder_mask).bool().to(device)
            aa_t = torch.from_numpy(binder_aa_idx).long().to(device)

            with torch.no_grad():
                q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t, receptor_label='holo').item()
                q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t, receptor_label='apo').item()

            S = q_holo - q_apo
            results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo, "S": S})
        except Exception as e:
            # Deliberate best-effort: one unreadable or malformed design must
            # not abort the whole rescoring run.
            logger.warning(f" Skip {name}: {e}")
    return results


def summarize(results, label):
    """Aggregate per-design scores into summary statistics.

    Args:
        results: list of dicts as produced by ``score_pdb_list``.
        label: method name stored under the "method" key.

    Returns:
        An empty dict when ``results`` is empty; otherwise a dict with the
        method label, the design count, mean and (population) standard
        deviation of S, the percentage of designs with S > 0, and the mean
        holo and apo scores.
    """
    if not results:
        return {}
    S = [r["S"] for r in results]
    return {
        "method": label,
        "n": len(S),
        "S_mean": float(np.mean(S)),
        "S_std": float(np.std(S)),
        "S_pos_pct": float(np.mean([s > 0 for s in S]) * 100),
        "Q_holo_mean": float(np.mean([r["Q_holo"] for r in results])),
        "Q_apo_mean": float(np.mean([r["Q_apo"] for r in results])),
    }


def main():
    """Load the scorer, rescore every available design set, save and print results."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=7)
    parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt")
    args = parser.parse_args()

    # Restrict visibility to the requested GPU; it then appears as cuda:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = "cuda:0"

    logger.info(f"Loading Q_theta from {args.checkpoint}")
    dq = DifferentiableQTheta(checkpoint_path=args.checkpoint, device=device,
                              esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
    dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
    dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')

    # The holo receptor (chain A) defines the frame binders are aligned into.
    ref_model = load_structure(HOLO_PDB)
    ref_res = get_residues(ref_model['A'])
    ref_coords, _ = get_backbone_coords(ref_res)
    ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}

    output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
    os.makedirs(output_dir, exist_ok=True)

    # Define design directories
    design_sets = {
        "vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
        "langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
        "classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
        "smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
    }

    # Also check for TDS and PXDesign
    tds_dirs = glob.glob(os.path.join(BASE, "results/tds_guidance/cam/designs"))
    if tds_dirs:
        design_sets["tds"] = tds_dirs[0]

    # PXDesign directories: prefer the family-split location, fall back to
    # results/<method>; only register directories that contain PDB files.
    for px_method in ["pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
                      "pxdesign_smc", "pxdesign_langevin"]:
        px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
        if not os.path.exists(px_dir):
            px_dir = os.path.join(BASE, f"results/{px_method}")
        if os.path.exists(px_dir):
            pdbs = glob.glob(os.path.join(px_dir, "*.pdb"))
            if pdbs:
                design_sets[px_method] = px_dir

    all_results = {}
    summaries = []

    for method, pdb_dir in design_sets.items():
        if not os.path.exists(pdb_dir):
            logger.warning(f" {method}: directory not found ({pdb_dir})")
            continue

        # Sorted so the scoring order (and JSON output) is reproducible.
        pdbs = sorted(glob.glob(os.path.join(pdb_dir, "*.pdb")))
        if not pdbs:
            logger.warning(f" {method}: no PDB files")
            continue

        logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
        results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
        s = summarize(results, method)
        if s:
            summaries.append(s)
            logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
                        f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
        all_results[method] = {"results": results, "summary": s}

    # Save
    with open(os.path.join(output_dir, "rescore_v2_all.json"), "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    # Print summary table, best mean S first
    print("\n" + "=" * 70)
    print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
    print("=" * 70)
    print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
    print("-" * 70)
    for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
        print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
              f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")


if __name__ == "__main__":
    main()