AlloGen / code /scripts /rescore.py
chq1155's picture
AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
ad9572d
#!/usr/bin/env python3
"""
Re-score binder PDB designs with a Q_theta checkpoint.
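Scores every design PDB found in the known result directories under the
repository root (each file is expected to hold a receptor chain and a binder
chain). The design's receptor is superposed onto the CaM holo reference
(3CLN), the binder is run through DifferentiableQTheta against the holo and
apo (1CFD) receptors, and the script writes per-design
S = Q_theta(holo) - Q_theta(apo) plus the raw holo/apo scores to JSON.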
Usage:
python code/scripts/rescore.py \\
--checkpoint checkpoints/Q_theta_phase2.pt \\
--gpu 0
"""
import os, sys, json, argparse, glob, logging
import numpy as np
import torch
from pathlib import Path
# Configure root logging once for the whole script (timestamp, level, message).
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)
# Repository root: three directory levels above this file's resolved location.
BASE = str(Path(__file__).resolve().parent.parent.parent)
# Put <BASE>/code and <BASE> at the front of sys.path *before* the project
# imports below. After both inserts BASE is searched first, then BASE/code.
# NOTE(review): presumably `models` and `utils` live under <BASE>/code -- confirm.
sys.path.insert(0, os.path.join(BASE, 'code'))
sys.path.insert(0, BASE)
from models.differentiable_features import DifferentiableQTheta
from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures
# Reference receptor structures, loaded under the 'holo' and 'apo' labels in main().
HOLO_PDB = os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb')
APO_PDB = os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb')
def _select_chains(chain_ids):
    """Pick the (receptor, binder) chain IDs from a design's chain ID list.

    The receptor is chain 'A' when present, otherwise the first chain. The
    binder is chain 'B' when present *and distinct from the receptor*,
    otherwise the first chain that is not the receptor.

    Raises:
        ValueError: if there are no chains, or no chain other than the receptor.
    """
    if not chain_ids:
        raise ValueError("structure has no chains")
    rec_chain = 'A' if 'A' in chain_ids else chain_ids[0]
    # Exclude the receptor first so that a file whose first chain is 'B' (and
    # has no 'A') cannot end up with the same chain as receptor and binder.
    candidates = [c for c in chain_ids if c != rec_chain]
    if not candidates:
        raise ValueError(f"no binder chain besides receptor chain {rec_chain!r}")
    binder_chain = 'B' if 'B' in candidates else candidates[0]
    return rec_chain, binder_chain


def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device):
    """Score a list of design PDB files.

    For each file the receptor chain is matched to the reference by residue
    number, the resulting rigid transform is applied to the binder backbone,
    and the binder is scored with ``dq.score`` under the 'holo' and 'apo'
    receptor labels.

    Args:
        dq: Scorer exposing ``score(coords, mask, binder_aa_idx=..., receptor_label=...)``
            whose result supports ``.item()``.
        pdb_list: Iterable of paths to design PDB files.
        ref_resnums: Mapping of reference residue number -> row index into
            ``ref_coords``.
        ref_coords: Reference backbone coordinates, indexed as ``[rows, 1]``
            to obtain the alignment atoms.
            NOTE(review): presumably (n_res, 4, 3) with CA at atom index 1 -- confirm
            against ``get_backbone_coords``.
        device: Torch device on which the binder tensors are placed.

    Returns:
        A list of dicts with keys ``"design"``, ``"Q_holo"``, ``"Q_apo"`` and
        ``"S"`` (``Q_holo - Q_apo``), one per successfully scored design.
        Designs that fail for any reason are logged and skipped.
    """
    results = []
    for pdb_path in pdb_list:
        base_name = os.path.basename(pdb_path)
        # Strip only the trailing extension; a blanket replace would also
        # remove any ".pdb" that happens to occur inside the file name.
        name = base_name[:-len(".pdb")] if base_name.endswith(".pdb") else base_name
        try:
            design_model = load_structure(pdb_path)
            chains = [c.id for c in design_model.get_chains()]
            rec_chain, binder_chain = _select_chains(chains)

            rec_res = get_residues(design_model[rec_chain])
            binder_res = get_residues(design_model[binder_chain])
            rec_coords_d, _ = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)
            binder_aa_idx = get_aa_indices(binder_res)

            # Match design receptor residues to the reference by residue number.
            design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
            common = sorted(set(design_resnums.keys()) & set(ref_resnums.keys()))
            if len(common) < 10:
                logger.warning(f" Skip {name}: <10 common residues")
                continue

            # Alignment atoms (atom index 1) of the shared residues.
            d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
            r_ca = ref_coords[[ref_resnums[r] for r in common], 1]
            mobile_center = d_ca.mean(0)
            ref_center = r_ca.mean(0)
            _, R = align_structures(d_ca, r_ca)

            # Move the binder into the reference frame: centre on the design
            # receptor, rotate, then shift to the reference centroid.
            flat = binder_coords.reshape(-1, 3) - mobile_center
            aligned_binder = (flat @ R.T + ref_center).reshape(-1, 4, 3)

            coords_t = torch.from_numpy(aligned_binder).float().to(device)
            mask_t = torch.from_numpy(binder_mask).bool().to(device)
            aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
            with torch.no_grad():
                q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                  receptor_label='holo').item()
                q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                 receptor_label='apo').item()
            S = q_holo - q_apo
            results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo, "S": S})
        except Exception as e:
            # Deliberate best-effort: one unreadable or malformed design must
            # not abort the whole batch, so log the reason and move on.
            logger.warning(f" Skip {name}: {e}")
    return results
def summarize(results, label):
    """Collapse per-design score records into one summary row.

    Args:
        results: Sequence of dicts carrying the keys ``"S"``, ``"Q_holo"``
            and ``"Q_apo"``.
        label: Name stored under ``"method"`` in the returned row.

    Returns:
        An empty dict when ``results`` is empty; otherwise a dict with the
        method label, the number of designs, mean and (population) standard
        deviation of S, the percentage of designs with S > 0, and the mean
        holo / apo scores.
    """
    if not results:
        return {}

    def column(key):
        # Pull one field out of every record, in input order.
        return [record[key] for record in results]

    deltas = column("S")
    summary = {"method": label, "n": len(deltas)}
    summary["S_mean"] = float(np.mean(deltas))
    summary["S_std"] = float(np.std(deltas))
    is_positive = [delta > 0 for delta in deltas]
    summary["S_pos_pct"] = float(np.mean(is_positive) * 100)
    summary["Q_holo_mean"] = float(np.mean(column("Q_holo")))
    summary["Q_apo_mean"] = float(np.mean(column("Q_apo")))
    return summary
def _collect_design_sets():
    """Return a ``{method: directory}`` mapping of candidate design directories."""
    design_sets = {
        "vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
        "langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
        "classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
        "smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
    }
    # TDS designs are optional: only register the directory when it exists.
    tds_dir = os.path.join(BASE, "results/tds_guidance/cam/designs")
    if os.path.exists(tds_dir):
        design_sets["tds"] = tds_dir
    # PXDesign outputs: prefer the family-split location, fall back to results/,
    # and only register a directory that actually contains PDB files.
    for px_method in ("pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
                      "pxdesign_smc", "pxdesign_langevin"):
        px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
        if not os.path.exists(px_dir):
            px_dir = os.path.join(BASE, f"results/{px_method}")
        if os.path.exists(px_dir) and glob.glob(os.path.join(px_dir, "*.pdb")):
            design_sets[px_method] = px_dir
    return design_sets


def _print_summary_table(summaries):
    """Print one row per method, sorted by mean S in descending order."""
    print("\n" + "=" * 70)
    print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
    print("=" * 70)
    print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
    print("-" * 70)
    for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
        print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
              f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")


def main():
    """Re-score every known design set and write the results to JSON.

    Parses ``--gpu`` and ``--checkpoint``, loads the scorer with the holo and
    apo reference receptors, scores each design directory that exists and
    contains PDB files, writes ``rescore_v2_all.json`` under
    ``<BASE>/results/v2_strict_holdout/scoring`` and prints a summary table.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): default GPU index 7 is kept for backward compatibility; on
    # machines with fewer GPUs pass --gpu explicitly (see the CPU fallback below).
    parser.add_argument("--gpu", type=int, default=7,
                        help="GPU index exported via CUDA_VISIBLE_DEVICES")
    parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt",
                        help="path to the Q_theta checkpoint")
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # With the mask above, the selected GPU (if it exists) is device 0.
    # Fall back to CPU instead of crashing when no device is visible.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        logger.warning(f"CUDA unavailable with CUDA_VISIBLE_DEVICES={args.gpu}; "
                       f"falling back to CPU (use --gpu to select an existing device)")
        device = "cpu"

    # Every other path in this script is anchored at BASE, so also try a
    # relative checkpoint path there when it is not found from the CWD.
    checkpoint = args.checkpoint
    if not os.path.isabs(checkpoint) and not os.path.exists(checkpoint):
        base_candidate = os.path.join(BASE, checkpoint)
        if os.path.exists(base_candidate):
            checkpoint = base_candidate

    logger.info(f"Loading Q_theta from {checkpoint}")
    dq = DifferentiableQTheta(checkpoint_path=checkpoint, device=device,
                              esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
    dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
    dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')

    # Reference frame for alignment: chain A of the holo structure.
    ref_model = load_structure(HOLO_PDB)
    ref_res = get_residues(ref_model['A'])
    ref_coords, _ = get_backbone_coords(ref_res)
    ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}

    output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
    os.makedirs(output_dir, exist_ok=True)

    all_results = {}
    summaries = []
    for method, pdb_dir in _collect_design_sets().items():
        if not os.path.exists(pdb_dir):
            logger.warning(f" {method}: directory not found ({pdb_dir})")
            continue
        pdbs = sorted(glob.glob(os.path.join(pdb_dir, "*.pdb")))
        if not pdbs:
            logger.warning(f" {method}: no PDB files")
            continue
        logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
        results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
        s = summarize(results, method)
        if s:
            summaries.append(s)
            logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
                        f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
        # Recorded even when every design was skipped, so the JSON shows the
        # method was attempted (empty results, empty summary).
        all_results[method] = {"results": results, "summary": s}

    # Save
    with open(os.path.join(output_dir, "rescore_v2_all.json"), "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    _print_summary_table(summaries)
# Run the rescoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()