| |
| """ |
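Re-score binder PDB designs with a Q_theta checkpoint.

For each configured design set (vanilla, Langevin, classifier-guided, SMC
round 3, plus TDS and PXDesign variants when found on disk), walks a directory
of design PDBs. Each design holds a receptor chain ('A', else the first chain)
and a binder chain ('B', else the first other chain). The binder is superposed
onto the fixed CaM holo reference frame (3CLN, chain A) by aligning C-alpha
atoms of receptor residues shared by number. It is then scored with
DifferentiableQTheta against the fixed holo (3CLN) and apo (1CFD) receptors.

Per design, S = Q_theta(holo) - Q_theta(apo) and the raw holo/apo scores are
recorded. These are written, together with per-set summaries, to
results/v2_strict_holdout/scoring/rescore_v2_all.json.

Usage:
    python code/scripts/rescore.py \\
        --checkpoint checkpoints/Q_theta_phase2.pt \\
        --gpu 0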
| """ |
| import os, sys, json, argparse, glob, logging |
| import numpy as np |
| import torch |
| from pathlib import Path |
|
|
# Root logging: INFO and above, prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
# Module logger used for all progress and skip messages in this script.
logger = logging.getLogger(__name__)
|
|
# Repository root: three directory levels above this script's resolved path.
BASE = str(Path(__file__).resolve().parent.parent.parent)
# Prepend both import roots in one step so that BASE ends up at sys.path[0]
# and BASE/code at sys.path[1].
sys.path[:0] = [BASE, os.path.join(BASE, 'code')]
|
|
| from models.differentiable_features import DifferentiableQTheta |
| from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures |
|
|
# Fixed CaM receptor structures used for every design: holo (3CLN) and apo (1CFD).
HOLO_PDB, APO_PDB = (
    os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb'),
    os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb'),
)
|
|
|
|
def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device, min_common=10):
    """Score a list of design PDB files against the holo and apo receptors.

    Each design must contain a receptor chain ('A', or else the first chain)
    and a distinct binder chain ('B', or else the first other chain). The
    binder backbone is superposed onto the reference receptor frame. The
    superposition uses column 1 (treated as C-alpha) of receptor residues
    matched by residue number in both the design and the reference. The
    binder is then scored with ``dq`` for the 'holo' and 'apo' receptor
    labels.

    Args:
        dq: Q_theta scorer; its ``score`` method is called with
            ``receptor_label='holo'`` and ``receptor_label='apo'``.
        pdb_list: Iterable of design PDB paths.
        ref_resnums: Mapping residue number -> row index into ``ref_coords``.
        ref_coords: Reference receptor backbone coordinates as returned by
            ``get_backbone_coords``.
        device: Torch device the binder tensors are moved to.
        min_common: Minimum number of shared receptor residues required for
            the superposition; designs with fewer are skipped. Defaults to 10.

    Returns:
        List of dicts with keys "design", "Q_holo", "Q_apo" and
        "S" (= Q_holo - Q_apo), one per successfully scored design. Designs
        that fail to load, align or score are logged and skipped.
    """
    results = []
    for pdb_path in pdb_list:
        # splitext strips only the trailing extension; str.replace would also
        # remove ".pdb" occurring elsewhere in the file name.
        name = os.path.splitext(os.path.basename(pdb_path))[0]
        try:
            design_model = load_structure(pdb_path)
            chains = [c.id for c in design_model.get_chains()]
            rec_chain = 'A' if 'A' in chains else next(iter(chains), None)
            # Choose the binder only among chains other than the receptor, so
            # a design lacking chain 'A' but starting with 'B' does not score
            # the receptor against itself.
            others = [c for c in chains if c != rec_chain]
            if rec_chain is None or not others:
                logger.warning(f" Skip {name}: need separate receptor and binder chains, found {chains}")
                continue
            binder_chain = 'B' if 'B' in others else others[0]

            rec_res = get_residues(design_model[rec_chain])
            binder_res = get_residues(design_model[binder_chain])
            # NOTE(review): the receptor mask is discarded, so residues with a
            # missing column-1 atom still enter the alignment -- confirm this
            # cannot happen for these inputs.
            rec_coords_d, _ = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)
            binder_aa_idx = get_aa_indices(binder_res)

            # Pair receptor residues by residue number, not sequence position,
            # so designs with trimmed or offset receptors still align.
            design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
            common = sorted(set(design_resnums) & set(ref_resnums))
            if len(common) < min_common:
                logger.warning(f" Skip {name}: <{min_common} common residues")
                continue

            d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
            r_ca = ref_coords[[ref_resnums[r] for r in common], 1]
            mobile_center = d_ca.mean(0)
            ref_center = r_ca.mean(0)
            # Only the rotation is taken from align_structures; translation is
            # applied explicitly through the two centroids. NOTE(review):
            # assumes R acts as ``x @ R.T`` on centred coordinates -- confirm
            # against utils.pdb_utils.align_structures.
            _, R = align_structures(d_ca, r_ca)

            flat = binder_coords.reshape(-1, 3) - mobile_center
            # Restore the binder's original (residue, atom, xyz) layout rather
            # than assuming a fixed number of backbone atoms.
            aligned_binder = (flat @ R.T + ref_center).reshape(binder_coords.shape)

            coords_t = torch.from_numpy(aligned_binder).float().to(device)
            mask_t = torch.from_numpy(binder_mask).bool().to(device)
            aa_t = torch.from_numpy(binder_aa_idx).long().to(device)

            with torch.no_grad():
                q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                  receptor_label='holo').item()
                q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                 receptor_label='apo').item()
        except Exception as e:
            # Best-effort batch scoring: one malformed design must not abort
            # the whole set, but the failure type is kept in the log.
            logger.warning(f" Skip {name}: {type(e).__name__}: {e}")
            continue
        results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo,
                        "S": q_holo - q_apo})
    return results
|
|
|
|
def summarize(results, label):
    """Collapse per-design score records into a single summary row.

    Args:
        results: List of dicts with at least the keys "S", "Q_holo" and
            "Q_apo" (as produced by ``score_pdb_list``).
        label: Method name stored under "method".

    Returns:
        ``{}`` when ``results`` is empty. Otherwise a dict (in this key order)
        with the method label, the design count, the mean and numpy-default
        (ddof=0) standard deviation of S, the percentage of designs with
        S > 0, and the mean holo and apo scores.
    """
    if not results:
        return {}
    s_arr = np.asarray([rec["S"] for rec in results])
    mean_holo = np.mean([rec["Q_holo"] for rec in results])
    mean_apo = np.mean([rec["Q_apo"] for rec in results])
    return dict(
        method=label,
        n=len(s_arr),
        S_mean=float(s_arr.mean()),
        S_std=float(s_arr.std()),
        S_pos_pct=float((s_arr > 0).mean() * 100),
        Q_holo_mean=float(mean_holo),
        Q_apo_mean=float(mean_apo),
    )
|
|
|
|
def _resolve_checkpoint(path):
    """Return a usable checkpoint path, falling back to the repository root.

    The path is used unchanged when it is absolute or exists relative to the
    current working directory. Otherwise it is tried relative to BASE, where
    every other input/output path in this script is anchored. If neither
    exists, the original string is returned and the loader reports the error.
    """
    if os.path.isabs(path) or os.path.exists(path):
        return path
    anchored = os.path.join(BASE, path)
    return anchored if os.path.exists(anchored) else path


def _collect_design_sets():
    """Return an insertion-ordered mapping of method label -> design PDB directory.

    The four fixed sets are always listed; the caller reports any missing
    directory. TDS and PXDesign sets are added only when found on disk, and
    PXDesign only if the directory contains at least one ``*.pdb`` file.
    """
    design_sets = {
        "vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
        "langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
        "classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
        "smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
    }

    # glob on a literal path acts as an existence check.
    tds_dirs = glob.glob(os.path.join(BASE, "results/tds_guidance/cam/designs"))
    if tds_dirs:
        design_sets["tds"] = tds_dirs[0]

    # PXDesign outputs: prefer the family-split location, else fall back to results/.
    for px_method in ["pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
                      "pxdesign_smc", "pxdesign_langevin"]:
        px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
        if not os.path.exists(px_dir):
            px_dir = os.path.join(BASE, f"results/{px_method}")
        if os.path.exists(px_dir) and glob.glob(os.path.join(px_dir, "*.pdb")):
            design_sets[px_method] = px_dir
    return design_sets


def _print_summary_table(summaries):
    """Print one row per method summary to stdout, highest mean S first."""
    print("\n" + "=" * 70)
    print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
    print("=" * 70)
    print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
    print("-" * 70)
    for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
        print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
              f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")


def main():
    """CLI entry point: load Q_theta, score every design set, write JSON, print a table."""
    parser = argparse.ArgumentParser(
        description="Re-score binder PDB designs with a Q_theta checkpoint.")
    parser.add_argument("--gpu", type=int, default=7,
                        help="GPU index exported as CUDA_VISIBLE_DEVICES")
    parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt",
                        help="Q_theta checkpoint; relative paths are tried against "
                             "the working directory, then the repository root")
    args = parser.parse_args()

    # CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised
    # yet, so this must stay ahead of any torch.cuda call.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    if torch.cuda.is_available():
        device = "cuda:0"  # the single GPU made visible above
    else:
        # Previously a hard-coded "cuda:0" failed opaquely on hosts without a
        # usable GPU; degrade to CPU with an explicit warning instead.
        logger.warning(f"CUDA not available for --gpu {args.gpu}; falling back to CPU")
        device = "cpu"

    checkpoint = _resolve_checkpoint(args.checkpoint)
    logger.info(f"Loading Q_theta from {checkpoint}")
    dq = DifferentiableQTheta(checkpoint_path=checkpoint, device=device,
                              esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
    dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
    dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')

    # Reference frame for superposing binders: chain A of the holo structure.
    ref_model = load_structure(HOLO_PDB)
    ref_res = get_residues(ref_model['A'])
    ref_coords, _ = get_backbone_coords(ref_res)
    ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}

    output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
    os.makedirs(output_dir, exist_ok=True)

    all_results = {}
    summaries = []
    for method, pdb_dir in _collect_design_sets().items():
        if not os.path.exists(pdb_dir):
            logger.warning(f" {method}: directory not found ({pdb_dir})")
            continue
        pdbs = sorted(glob.glob(os.path.join(pdb_dir, "*.pdb")))
        if not pdbs:
            logger.warning(f" {method}: no PDB files")
            continue

        logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
        results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
        s = summarize(results, method)
        if s:
            summaries.append(s)
            logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
                        f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
        # Record every attempted set, including ones where nothing scored.
        all_results[method] = {"results": results, "summary": s}

    out_path = os.path.join(output_dir, "rescore_v2_all.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    logger.info(f"Wrote {out_path}")

    _print_summary_table(summaries)
|
|
|
|
# Run the rescoring CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|