File size: 7,366 Bytes
ad9572d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | #!/usr/bin/env python3
"""
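Re-score binder PDB designs with a Q_theta checkpoint.
Walks a fixed set of design-method directories (vanilla, Langevin, classifier,
SMC round 3, plus TDS / PXDesign outputs when present). Each design PDB holds a
receptor chain and a binder chain. The binder is superposed onto the holo CaM
reference (3CLN) using the CA atoms of receptor residues shared by residue
number. It is then scored with DifferentiableQTheta against the fixed holo (3CLN)
and apo (1CFD) receptors. Per-design S = Q_theta(holo) - Q_theta(apo), the raw
holo/apo scores, and a per-method summary are written to
results/v2_strict_holdout/scoring/rescore_v2_all.json.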
Usage:
python code/scripts/rescore.py \\
--checkpoint checkpoints/Q_theta_phase2.pt \\
--gpu 0
"""
import os, sys, json, argparse, glob, logging
import numpy as np
import torch
from pathlib import Path
# Script-wide logging: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)
# Project root: three directory levels above this file. Assuming the script
# lives at code/scripts/ (as in the usage line), this is the directory that
# contains `code/`. Every data/result path below is anchored here.
BASE = str(Path(__file__).resolve().parent.parent.parent)
# Make the project packages importable. BASE is inserted last, so it ends up
# first on sys.path and wins over BASE/code on any module-name clash.
sys.path.insert(0, os.path.join(BASE, 'code'))
sys.path.insert(0, BASE)
# These project imports only resolve after the sys.path edits above; keep them here.
from models.differentiable_features import DifferentiableQTheta
from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures
# Fixed CaM receptors every design is scored against: holo (3CLN) and apo (1CFD).
# The holo structure also serves as the superposition reference frame in main().
HOLO_PDB = os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb')
APO_PDB = os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb')
def _design_name(pdb_path):
    """Return the basename of ``pdb_path`` with one trailing ".pdb" removed."""
    base = os.path.basename(pdb_path)
    # Strip only the extension. str.replace would also drop ".pdb" occurring
    # inside the stem (e.g. "x.pdb_v2.pdb" -> "x_v2").
    return base[:-len(".pdb")] if base.endswith(".pdb") else base


def _select_chains(chains):
    """Pick ``(receptor_chain, binder_chain)`` from a design's chain ids.

    The receptor is chain 'A' when present, otherwise the first chain. The
    binder is chain 'B' when present AND distinct from the receptor, otherwise
    the first chain that is not the receptor.

    Raises:
        ValueError: if the structure does not have two distinct chains.
    """
    if not chains:
        raise ValueError("structure has no chains")
    rec_chain = 'A' if 'A' in chains else chains[0]
    others = [c for c in chains if c != rec_chain]
    if not others:
        raise ValueError(f"no binder chain besides receptor chain {rec_chain!r}")
    # Choosing 'B' from `others` (not from `chains`) guarantees the binder is
    # never the receptor itself, e.g. for chains ['B', 'C'].
    binder_chain = 'B' if 'B' in others else others[0]
    return rec_chain, binder_chain


def _superpose_onto_reference(coords, mobile_ca, ref_ca):
    """Map ``coords`` into the reference frame using a CA-based superposition.

    The points are centred on the ``mobile_ca`` centroid, rotated by ``R`` from
    ``align_structures(mobile_ca, ref_ca)``, and shifted onto the ``ref_ca``
    centroid. The result is reshaped to ``(-1, 4, 3)``.
    """
    mobile_center = mobile_ca.mean(0)
    ref_center = ref_ca.mean(0)
    # NOTE(review): assumes align_structures returns (_, R) with R rotating
    # centred mobile points onto centred reference points when applied as
    # ``x @ R.T``. Confirm against utils.pdb_utils.
    _, R = align_structures(mobile_ca, ref_ca)
    flat = coords.reshape(-1, 3) - mobile_center
    return (flat @ R.T + ref_center).reshape(-1, 4, 3)


def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device, min_common=10):
    """Score design PDB files against the 'holo' and 'apo' receptors of ``dq``.

    For each design, the binder chain is superposed onto the reference frame
    using the CA atoms of receptor residues whose residue numbers also appear
    in the reference. It is then scored with ``dq.score`` for both receptor
    labels.

    Args:
        dq: Scorer exposing ``score(coords, mask, binder_aa_idx=..., receptor_label=...)``
            with receptors registered as 'holo' and 'apo' (a DifferentiableQTheta
            in this script).
        pdb_list: Paths of design PDB files.
        ref_resnums: Map from reference residue number to row index in ``ref_coords``.
        ref_coords: Reference backbone coordinates indexed ``[row, atom]``.
            Atom index 1 is used as CA (presumably N, CA, C, O order; confirm).
        device: Torch device the scoring tensors are moved to.
        min_common: Minimum number of shared receptor residue numbers required
            for superposition. Designs with fewer are skipped.

    Returns:
        A list with one dict per successfully scored design, with keys "design",
        "Q_holo", "Q_apo" and "S" (= Q_holo - Q_apo). Designs that fail to load,
        align or score are logged and skipped.
    """
    results = []
    for pdb_path in pdb_list:
        name = _design_name(pdb_path)
        try:
            design_model = load_structure(pdb_path)
            chains = [c.id for c in design_model.get_chains()]
            rec_chain, binder_chain = _select_chains(chains)
            rec_res = get_residues(design_model[rec_chain])
            binder_res = get_residues(design_model[binder_chain])
            # NOTE(review): the receptor atom mask is discarded, so receptor
            # residues with missing backbone atoms may feed placeholder
            # coordinates into the superposition. Confirm how
            # get_backbone_coords fills missing atoms.
            rec_coords_d, _ = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)
            binder_aa_idx = get_aa_indices(binder_res)

            # Pair receptor residues with the reference by residue number only
            # (insertion codes are ignored, so such residues would collide).
            design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
            common = sorted(set(design_resnums.keys()) & set(ref_resnums.keys()))
            if len(common) < min_common:
                logger.warning(f" Skip {name}: <{min_common} common residues")
                continue
            d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
            r_ca = ref_coords[[ref_resnums[r] for r in common], 1]
            aligned_binder = _superpose_onto_reference(binder_coords, d_ca, r_ca)

            coords_t = torch.from_numpy(aligned_binder).float().to(device)
            mask_t = torch.from_numpy(binder_mask).bool().to(device)
            aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
            with torch.no_grad():
                q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                  receptor_label='holo').item()
                q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
                                 receptor_label='apo').item()
            S = q_holo - q_apo
            results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo, "S": S})
        except Exception as e:
            # Deliberate best-effort batch: one bad design must not abort the
            # whole run. The traceback is kept at DEBUG level for diagnosis.
            logger.warning(f" Skip {name}: {e}")
            logger.debug(f"Traceback for {name}", exc_info=True)
    return results
def summarize(results, label):
    """Collapse per-design score dicts into a single summary row for ``label``.

    Returns ``{}`` when there are no results. Otherwise the row contains:
    the method label, the design count, the mean and population standard
    deviation (numpy default, ddof=0) of S, the percentage of designs with
    S > 0, and the mean holo / apo scores. Every statistic is a plain Python
    float, so the row is JSON-serializable.
    """
    if not results:
        return {}
    s_values = np.array([row["S"] for row in results])
    holo_scores = [row["Q_holo"] for row in results]
    apo_scores = [row["Q_apo"] for row in results]

    summary = {"method": label, "n": len(s_values)}
    summary["S_mean"] = float(s_values.mean())
    summary["S_std"] = float(s_values.std())
    summary["S_pos_pct"] = float((s_values > 0).mean() * 100)
    summary["Q_holo_mean"] = float(np.mean(holo_scores))
    summary["Q_apo_mean"] = float(np.mean(apo_scores))
    return summary
def _list_pdbs(pdb_dir):
    """Return the sorted ``*.pdb`` paths directly inside ``pdb_dir``.

    The directory part is glob-escaped so metacharacters in the path
    (e.g. '[' in BASE) are matched literally instead of as patterns.
    """
    return sorted(glob.glob(os.path.join(glob.escape(pdb_dir), "*.pdb")))


def _discover_design_sets():
    """Return an insertion-ordered ``{method: directory}`` map of design sets.

    The four core methods are always listed; the caller reports any whose
    directory is missing. TDS is added when its design directory exists. Each
    PXDesign variant prefers results_familysplit/design_bd30/<method>, falls
    back to results/<method>, and is added only if that directory holds at
    least one *.pdb file.
    """
    design_sets = {
        "vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
        "langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
        "classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
        "smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
    }
    # Optional TDS designs. Use a plain existence check: globbing this literal
    # path would misfire if BASE contained glob metacharacters.
    tds_dir = os.path.join(BASE, "results/tds_guidance/cam/designs")
    if os.path.exists(tds_dir):
        design_sets["tds"] = tds_dir
    # Optional PXDesign variants.
    for px_method in ["pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
                      "pxdesign_smc", "pxdesign_langevin"]:
        px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
        if not os.path.exists(px_dir):
            px_dir = os.path.join(BASE, f"results/{px_method}")
        if os.path.exists(px_dir) and _list_pdbs(px_dir):
            design_sets[px_method] = px_dir
    return design_sets


def _resolve_checkpoint(path):
    """Resolve the ``--checkpoint`` argument to a usable path.

    Absolute paths, and relative paths that exist from the current working
    directory, are returned unchanged. Otherwise the path is tried relative to
    BASE, the anchor for every other path in this script. If neither exists,
    the original value is returned so the loader reports the error.
    """
    if os.path.isabs(path) or os.path.exists(path):
        return path
    candidate = os.path.join(BASE, path)
    return candidate if os.path.exists(candidate) else path


def _print_summary_table(summaries):
    """Print the per-method summaries as a table, highest mean S first."""
    print("\n" + "=" * 70)
    print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
    print("=" * 70)
    print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
    print("-" * 70)
    for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
        print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
              f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")


def main():
    """Rescore every discovered design set, then save and print the results.

    Steps:
      1. Parse ``--gpu`` / ``--checkpoint`` and restrict CUDA to the chosen GPU,
         falling back to CPU with a warning if no CUDA device is visible.
      2. Load Q_theta with the holo and apo receptors.
      3. Score each design directory with score_pdb_list.
      4. Write all per-design results and summaries to
         results/v2_strict_holdout/scoring/rescore_v2_all.json under BASE.
      5. Print a summary table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=7)
    parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt")
    args = parser.parse_args()

    # Must be set before CUDA is first initialised to take effect; the chosen
    # GPU then appears to torch as cuda:0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        logger.warning(f"CUDA not available with CUDA_VISIBLE_DEVICES={args.gpu}; "
                       f"falling back to CPU")
        device = "cpu"

    checkpoint = _resolve_checkpoint(args.checkpoint)
    logger.info(f"Loading Q_theta from {checkpoint}")
    dq = DifferentiableQTheta(checkpoint_path=checkpoint, device=device,
                              esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
    dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
    dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')

    # Superposition reference: chain A of the holo receptor.
    ref_model = load_structure(HOLO_PDB)
    ref_res = get_residues(ref_model['A'])
    ref_coords, _ = get_backbone_coords(ref_res)
    ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}

    output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
    os.makedirs(output_dir, exist_ok=True)

    all_results = {}
    summaries = []
    for method, pdb_dir in _discover_design_sets().items():
        if not os.path.exists(pdb_dir):
            logger.warning(f" {method}: directory not found ({pdb_dir})")
            continue
        pdbs = _list_pdbs(pdb_dir)
        if not pdbs:
            logger.warning(f" {method}: no PDB files")
            continue
        logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
        results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
        s = summarize(results, method)
        if s:
            summaries.append(s)
            logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
                        f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
        # Methods whose designs all failed are still recorded, with an empty summary.
        all_results[method] = {"results": results, "summary": s}

    with open(os.path.join(output_dir, "rescore_v2_all.json"), "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    _print_summary_table(summaries)


if __name__ == "__main__":
    main()
|