| |
| """ |
| Binary-collapse evaluation for apples-to-apples comparison with the Nexar |
| Kaggle winner (MViT-V2-S + 3LC, AP=0.898) and BADAS (V-JEPA2). |
| |
Why this script exists
──────────────────────
The Nexar challenge defines the task as BINARY (collision vs no-collision).
Our 3-class schema (SILENT / OBSERVE / ALERT) is strictly RICHER — OBSERVE
is an extra "heads-up without alarming" layer that did not exist in the
challenge rubric. When we report 0.266 "binary AP" using only P(ALERT), with
ALERT as the sole positive class, OBSERVE samples count as negatives and are
scored against us even though they ARE detected collisions in the Nexar sense.
| |
The correct comparison — and the one we use in the paper — collapses
{ALERT, OBSERVE} → "positive" and uses P(ALERT)+P(OBSERVE) as the score.
Under this collapse:
  • On Nexar-only: MViT is 0.898; we should be close (no OBSERVE there).
  • On DADA-only: new number — MViT has not been reported on DADA.
  • Merged (all validation samples): paper headline.
| |
Usage
─────
| python -m training.Policy.eval_binary_collapse \\ |
| --checkpoints traj_full temporal_long_mono \\ |
| --label_dir data/policy_labels \\ |
| --cache_dir data/belief_cache \\ |
| --output eval_results/binary_collapse.json |
| |
Output: JSON + human-readable table. For each checkpoint × subset
{all, nexar, dada}, reports the sample count, class distribution and:
  strict_ap   — P(ALERT),            label == 2 (ALERT vs. all other classes)
  merged_ap   — P(ALERT)+P(OBSERVE), label ∈ {1, 2}
  observe_ap  — P(OBSERVE),          label == 1 (OBSERVE 1-vs-rest)
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from collections import Counter |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from sklearn.metrics import average_precision_score |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) |
|
|
| from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn |
| from training.Policy.temporal_trainer import ( |
| TemporalPolicyDataset, TemporalPolicyModel, temporal_collate_fn, |
| ) |
| from training.Policy.trajectory_trainer import ( |
| TrajectoryPolicyDataset, TrajectoryPolicyModel, trajectory_collate_fn, |
| ) |
|
|
# Module logger; the root handler is configured at import time so that
# INFO-level progress lines carry a timestamp and level.
logger = logging.getLogger("Policy.eval_binary_collapse")
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
|
|
|
|
| |
| |
| |
|
|
def load_policy_checkpoint(ckpt_dir: Path, hidden_dim: int, seq_len: int):
    """Build the policy model matching a checkpoint directory and load its weights.

    The architecture is chosen from the ``version`` string stored in
    ``policy_meta.json``: a version containing "trajectory" or "v7" gets a
    ``TrajectoryPolicyModel``; anything else (e.g. v6 temporal checkpoints)
    gets a ``TemporalPolicyModel``.

    Args:
        ckpt_dir: Directory containing ``policy_meta.json`` and the weights
            consumed by the model's ``load_policy_checkpoint``.
        hidden_dim: Belief feature dimension passed to the model constructor.
        seq_len: Context length passed to the model constructor.

    Returns:
        ``(model, is_trajectory)`` where ``is_trajectory`` tells the caller
        whether the model returns a ``(logits, extra)`` tuple.

    Raises:
        FileNotFoundError: if ``policy_meta.json`` is missing.
    """
    meta_path = ckpt_dir / "policy_meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"policy_meta.json missing under {ckpt_dir}")
    meta = json.loads(meta_path.read_text())
    # `or ""` also covers an explicit `"version": null`, which would otherwise
    # make the substring tests below raise TypeError.
    version = str(meta.get("version") or "")

    if "trajectory" in version or "v7" in version:
        model = TrajectoryPolicyModel(
            hidden_dim=hidden_dim,
            seq_len=seq_len,
            # GRU flag recorded at training time; defaults to True if absent.
            use_gru=meta.get("use_gru", True),
            # presumably disables training-time belief noise for evaluation
            belief_noise_std=0.0,
        )
        model.load_policy_checkpoint(str(ckpt_dir))
        return model, True

    model = TemporalPolicyModel(hidden_dim=hidden_dim, seq_len=seq_len)
    model.load_policy_checkpoint(str(ckpt_dir))
    return model, False
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def run_inference(model, loader, is_trajectory: bool) -> np.ndarray:
    """Return softmax class probabilities for every sample yielded by *loader*.

    Args:
        model: Policy model called as
            ``model(belief_seqs, tta_mean_seqs, tta_var_seqs)``.
        loader: Iterable of batch dicts with keys "belief_seqs",
            "tta_mean_seqs" and "tta_var_seqs". Rows of the result are matched
            to dataset samples by position, so the loader must not shuffle.
        is_trajectory: True when *model* returns a tuple whose first element
            is the logits; False when it returns the logits directly.

    Returns:
        ``(num_samples, num_classes)`` float array in loader order. An empty
        loader yields a ``(0, 3)`` array (SILENT/OBSERVE/ALERT columns)
        instead of crashing in ``np.concatenate``.
    """
    model.eval()
    prob_chunks = []
    for batch in tqdm(loader, desc="Inference", ncols=80, leave=False):
        out = model(
            batch["belief_seqs"], batch["tta_mean_seqs"], batch["tta_var_seqs"]
        )
        logits = out[0] if is_trajectory else out
        # Softmax in fp32: more stable than half precision, and NumPy cannot
        # represent bfloat16 tensors at all.
        prob_chunks.append(F.softmax(logits.float(), dim=-1).cpu().numpy())
    if not prob_chunks:
        return np.zeros((0, 3), dtype=np.float32)
    return np.concatenate(prob_chunks, axis=0)
|
|
|
|
| |
| |
| |
|
|
def _safe_ap(y_true: np.ndarray, y_score: np.ndarray) -> Optional[float]:
    """Average precision of *y_score* against the binary targets *y_true*.

    AP is undefined when *y_true* holds a single class, so None is returned
    whenever there are no positives or no negatives.
    """
    pos_count = int(y_true.sum())
    neg_count = int(len(y_true) - pos_count)
    if pos_count and neg_count:
        return float(average_precision_score(y_true, y_score))
    return None
|
|
|
|
def compute_subset_metrics(
    probs: np.ndarray,
    labels: np.ndarray,
    mask: np.ndarray,
    name: str,
) -> Dict[str, Any]:
    """Score the rows selected by *mask* under the three AP definitions.

    Args:
        probs: 2-D softmax probabilities; columns 0=SILENT, 1=OBSERVE, 2=ALERT.
        labels: Integer class labels per row, same encoding as the columns.
        mask: Boolean row selector defining the subset.
        name: Subset name stored in the result.

    Returns:
        Dict with "name", "n", "class_dist" and three APs. Each AP is None when
        the subset has no positives or no negatives for that definition:
          strict_ap  — P(ALERT) against label == 2 (OBSERVE counts as negative),
          merged_ap  — P(OBSERVE)+P(ALERT) against label in {1, 2},
          observe_ap — P(OBSERVE) against label == 1.
        An empty subset returns the same keys (n=0, empty class_dist, None
        APs), so callers can index every key without special-casing.
    """
    n = int(mask.sum())
    if n == 0:
        return {
            "name": name,
            "n": 0,
            "class_dist": {},
            "strict_ap": None,
            "merged_ap": None,
            "observe_ap": None,
        }

    p = probs[mask]
    y = labels[mask]

    # Challenge-style scoring: ALERT is the only positive class.
    strict_ap = _safe_ap((y == 2).astype(int), p[:, 2])

    # Binary collapse: {OBSERVE, ALERT} -> positive, scored by summed probability.
    merged_ap = _safe_ap((y >= 1).astype(int), p[:, 1] + p[:, 2])

    # OBSERVE one-vs-rest.
    observe_ap = _safe_ap((y == 1).astype(int), p[:, 1])

    cls_dist = Counter(int(v) for v in y.tolist())

    return {
        "name": name,
        "n": n,
        "class_dist": {int(k): int(v) for k, v in cls_dist.items()},
        "strict_ap": strict_ap,
        "merged_ap": merged_ap,
        "observe_ap": observe_ap,
    }
|
|
|
|
def evaluate_checkpoint(
    ckpt_name: str,
    ckpt_dir: Path,
    val_ds,
    val_loader,
    sources: np.ndarray,
    hidden_dim: int,
    seq_len: int,
) -> Dict[str, Any]:
    """Run one checkpoint over the validation set and score it per data source.

    Args:
        ckpt_name: Display name used in logs and in the returned record.
        ckpt_dir: Directory holding ``policy_meta.json`` and the policy weights.
        val_ds: Validation dataset; its ``samples`` list must be in the same
            order as the batches yielded by ``val_loader``.
        val_loader: Non-shuffled loader over ``val_ds``.
        sources: Per-sample source tags ("nexar", "dada", ...). If it does not
            match the tags in ``val_ds.samples`` (e.g. it was built from a
            dataset with a different seq_len and thus a different sample
            list), a warning is logged and the dataset's own tags are used.
        hidden_dim: Belief feature dimension passed to the model constructor.
        seq_len: Context length passed to the model constructor.

    Returns:
        Dict with checkpoint identity, training-time metadata from
        ``policy_meta.json`` and ``subsets`` -> {"all", "nexar", "dada"}
        metric dicts.
    """
    logger.info(f"─── {ckpt_name} ───")
    logger.info(f"  Checkpoint: {ckpt_dir}")
    model, is_traj = load_policy_checkpoint(ckpt_dir, hidden_dim, seq_len)
    probs = run_inference(model, val_loader, is_traj)
    # Release the model before the next checkpoint is loaded.
    del model
    torch.cuda.empty_cache()

    samples = val_ds.samples
    labels = np.array([s["action_label"] for s in samples], dtype=np.int64)
    # Predictions are matched to samples purely by position.
    assert len(labels) == len(probs), (len(labels), len(probs))

    # Subset masks must index the same sample list that `labels` came from.
    ds_sources = np.array(
        [s.get("source", "unknown") for s in samples], dtype=object
    )
    if sources is None or not np.array_equal(
        np.asarray(sources, dtype=object), ds_sources
    ):
        if sources is not None:
            logger.warning(
                f"  {ckpt_name}: `sources` is not aligned with val_ds.samples "
                f"({len(sources)} vs {len(ds_sources)} entries); "
                "using the dataset's own source tags"
            )
        sources = ds_sources

    all_mask = np.ones_like(labels, dtype=bool)
    nex_mask = sources == "nexar"
    dada_mask = sources == "dada"

    subsets = {
        "all": compute_subset_metrics(probs, labels, all_mask, "all"),
        "nexar": compute_subset_metrics(probs, labels, nex_mask, "nexar"),
        "dada": compute_subset_metrics(probs, labels, dada_mask, "dada"),
    }

    meta = json.loads((ckpt_dir / "policy_meta.json").read_text())

    return {
        "checkpoint": ckpt_name,
        "checkpoint_path": str(ckpt_dir),
        "version": meta.get("version"),
        "seq_len": meta.get("seq_len", seq_len),
        "train_policy_score": meta.get("grid_best_policy_score"),
        "train_binary_ap": meta.get("binary_ap"),
        "subsets": subsets,
    }
|
|
|
|
| |
| |
| |
|
|
def _fmt_ap(v: Optional[float]) -> str:
    """Format an AP value with 4 decimals, or an em dash when it is undefined (None)."""
    return "—" if v is None else f"{v:.4f}"
|
|
|
|
def print_table(results: List[Dict[str, Any]]):
    """Print the per-checkpoint / per-subset AP table and a paper-facing summary.

    Args:
        results: Records as returned by ``evaluate_checkpoint``; each has a
            "checkpoint" name and a "subsets" dict keyed by "all", "nexar"
            and "dada".

    Subsets that are missing or empty (n == 0) are skipped in the table and
    shown as an undefined AP in the summary line, instead of raising KeyError.
    """
    heavy_rule = "═" * 108
    light_rule = "─" * 108

    print("\n" + heavy_rule)
    print("  BINARY-COLLAPSE EVAL — for fair comparison with Nexar winner (MViT AP=0.898)")
    print("  strict_ap : P(ALERT) only    (same scoring rule as challenge; penalises OBSERVE)")
    print("  merged_ap : P(ALERT)+P(OBS)  (collapses 3-class → binary; our paper headline)")
    print(heavy_rule)
    header = (
        f"{'checkpoint':<26}{'subset':<8}{'n':>7}  "
        f"{'strict_AP':>10}  {'merged_AP':>10}  {'observe_AP':>11}  {'class_dist':<20}"
    )
    print(header)
    print(light_rule)

    for r in results:
        subsets = r.get("subsets", {})
        for sub_name in ("all", "nexar", "dada"):
            s = subsets.get(sub_name)
            if not s or not s.get("n"):
                continue
            print(
                f"{r['checkpoint']:<26}{sub_name:<8}{s['n']:>7}  "
                f"{_fmt_ap(s.get('strict_ap')):>10}  {_fmt_ap(s.get('merged_ap')):>10}  "
                f"{_fmt_ap(s.get('observe_ap')):>11}  {str(s.get('class_dist', {})):<20}"
            )
        print(light_rule)

    def _merged(rec: Dict[str, Any], sub_name: str) -> str:
        # An empty subset record may carry only "name" and "n", so every
        # level of the lookup tolerates missing keys.
        return _fmt_ap(rec.get("subsets", {}).get(sub_name, {}).get("merged_ap"))

    print("\n  Paper-facing numbers (merged_AP, i.e. ALERT∪OBSERVE collapse):")
    print("    " + "   ".join(
        f"{r['checkpoint']}={_merged(r, 'nexar')}/nexar, {_merged(r, 'dada')}/dada"
        for r in results
    ))
    print("  External references:")
    print("    Nexar-2025 winner (MViT-V2-S + 3LC)  : strict_AP = 0.898 (nexar)")
    print("    BADAS (V-JEPA2, arXiv 2510.14876)    : AP on DAD/DADA/DoTA (see paper)")
    print(heavy_rule + "\n")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point.

    Resolves each ``--checkpoints`` name to ``<ckpt_root>/<name>/best`` and
    builds one validation dataset/loader per distinct checkpoint seq_len. It
    then evaluates every checkpoint with ``evaluate_checkpoint``, prints the
    summary table and writes all results to ``--output`` as JSON.

    Raises:
        FileNotFoundError: if a checkpoint's ``policy_head.pt`` or
            ``policy_meta.json`` is missing.
    """
    parser = argparse.ArgumentParser("eval_binary_collapse")
    parser.add_argument(
        "--checkpoints", nargs="+", required=True,
        help="Policy checkpoint names under --ckpt_root (each must contain best/).",
    )
    parser.add_argument("--ckpt_root", default="checkpoints/Policy")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--cache_dir", default="data/belief_cache")
    parser.add_argument("--output", default="eval_results/binary_collapse.json")
    parser.add_argument("--seq_len", type=int, default=8,
                        help="Dataset context length — overridden per-ckpt by meta if present.")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--use_trajectory_ds", action="store_true",
                        help="Use TrajectoryPolicyDataset (extra per-timestep fields); "
                             "required if any trajectory checkpoint is in the list.")
    args = parser.parse_args()

    label_dir = Path(args.label_dir)
    cache_dir = Path(args.cache_dir)
    ckpt_root = Path(args.ckpt_root)

    # Resolve every checkpoint before loading any data, so a typo fails fast.
    # Each policy_meta.json is read once; its seq_len selects the dataset the
    # checkpoint is evaluated on.
    ckpt_dirs: Dict[str, Path] = {}
    ckpt_seq_len: Dict[str, int] = {}
    has_traj = False
    for name in args.checkpoints:
        d = ckpt_root / name / "best"
        if not (d / "policy_head.pt").exists():
            raise FileNotFoundError(f"{d}/policy_head.pt not found")
        meta_path = d / "policy_meta.json"
        if not meta_path.exists():
            raise FileNotFoundError(f"policy_meta.json missing under {d}")
        meta = json.loads(meta_path.read_text())
        ckpt_dirs[name] = d
        meta_sl = meta.get("seq_len")
        ckpt_seq_len[name] = args.seq_len if meta_sl is None else int(meta_sl)
        # `or ""` guards against an explicit "version": null in the meta file.
        version = str(meta.get("version") or "")
        if "trajectory" in version or "v7" in version:
            has_traj = True

    # Trajectory checkpoints need the trajectory dataset's batch fields.
    use_traj_ds = args.use_trajectory_ds or has_traj
    ds_cls = TrajectoryPolicyDataset if use_traj_ds else TemporalPolicyDataset
    collate = trajectory_collate_fn if use_traj_ds else temporal_collate_fn

    # One dataset/loader per distinct seq_len. Source tags are kept per
    # dataset: datasets built with different context lengths are not
    # guaranteed to expose the same sample list, so tags from one must not be
    # used to mask another's predictions.
    datasets, loaders, sources_by_sl = {}, {}, {}
    src_dist: Counter = Counter()
    hidden_dim = None
    for sl in sorted(set(ckpt_seq_len.values())):
        ds = ds_cls(
            manifests=[label_dir / "val.json"],
            split="val",
            belief_cache_path=cache_dir / "val.pt",
            seq_len=sl,
        )
        datasets[sl] = ds
        loaders[sl] = DataLoader(
            ds, batch_size=args.batch_size,
            shuffle=False,  # prediction rows are matched to ds.samples by position
            collate_fn=collate, num_workers=2,
            pin_memory=torch.cuda.is_available(),
        )
        sources_by_sl[sl] = np.array(
            [s.get("source", "unknown") for s in ds.samples], dtype=object
        )
        if hidden_dim is None:
            # The first (smallest seq_len) dataset defines the reported source
            # distribution and the belief feature width.
            src_dist = Counter(sources_by_sl[sl].tolist())
            hidden_dim = ds._cache["beliefs"].shape[-1]
            logger.info(f"Source distribution: {dict(src_dist)}")
            logger.info(f"Belief hidden_dim = {hidden_dim}")

    results = []
    for name, d in ckpt_dirs.items():
        sl = ckpt_seq_len[name]
        results.append(
            evaluate_checkpoint(
                name, d, datasets[sl], loaders[sl], sources_by_sl[sl], hidden_dim, sl,
            )
        )

    print_table(results)

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # default=float lets json serialise any stray NumPy scalar values.
    out_path.write_text(json.dumps(
        {"checkpoints": results, "source_dist": dict(src_dist)},
        indent=2, default=float,
    ))
    logger.info(f"Saved -> {out_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script / via `python -m`, not on import.
    main()
|
|