#!/usr/bin/env python3 """ make_cot_cache.py ═══════════════════════════════════════════════════════════════════════════════ Generate K hazard-candidate Chain-of-Thought (CoT) sentences per policy window, using the SFT-Qwen as the language reasoner. Why ─── Phase 0b of the CoT-Pool plan. The CoT-Pool aggregator (M8–M14) needs text-grounded queries. Generating them on-the-fly during training would add ~K × VLM-decode per step. We pre-generate ONCE and cache. Design choices (intentionally aggressive, per the safety-bias rule) ─────────────────────────────────────────────────────────────────── • K = 8 candidates per window (paranoid: list every hazard, even unlikely) • temperature = 0.9, top_p = 0.95 → diversity, not repetition • Structured prompt asking for JSON-ish lines {"entity":"…","location":"…","motion":"…","risk":"…"} • The SFT-Qwen was fine-tuned on TTA only; its base instruction-following capability is unchanged, so prompted hazard listing still works. • This builder STORES raw generation text only. The three gates (G1 self-consistency, G2 attn-entropy, G3 OVD cross-check) live in a separate `verify_cot_cache.py` so we can iterate on filtering cheaply without re-running expensive VLM generation. Storage ─────── data/cot_cache/{split}.jsonl.gz — one JSON line per window schema: { "idx": int, "video_id": str, "category": str, "action_label": int, "candidates": [str, str, ..., str] # length K } Index alignment with PolicyDataset(manifest)["samples"][idx]. Usage ───── python -m training.Policy.make_cot_cache \\ --sft_checkpoint checkpoints/SFT/sft_v2/best \\ --label_dir data/policy_labels \\ --out_dir data/cot_cache \\ --k 8 \\ --temperature 0.9 \\ --top_p 0.95 \\ --max_new_tokens 96 \\ --batch_size 4 \\ --splits val """ from __future__ import annotations import argparse import gzip import json import logging from pathlib import Path from typing import Any, Dict, List import torch from torch.amp import autocast from torch.utils.data import DataLoader from tqdm import tqdm import sys sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from transformers import AutoModelForImageTextToText, AutoProcessor from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("Policy.make_cot_cache") SCHEMA_VERSION = 2 SYSTEM = ( "You are a defensive driving safety analyst. Your job is to enumerate " "EVERY potentially dangerous element you can detect in a dashcam window " "— err on the side of MORE hazards, never fewer. A missed hazard is " "much worse than a false alarm." ) USER_TEMPLATE = ( "Look at this {n}-frame dashcam window.\n" "Context: {ctx}\n\n" "List up to 4 distinct potential collision hazards. Be paranoid; if a " "pedestrian, cyclist, vehicle, or unusual road condition could become " "dangerous in the next ~3 seconds, list it.\n\n" "Return ONE hazard per line, in this exact format:\n" 'HAZARD: entity="" | location="-" | ' 'motion="" | ' 'risk="" | reason=""\n\n' "If no hazards exist write exactly: HAZARD: none" ) def _ctx(meta: Dict[str, Any]) -> str: parts = [] if meta.get("weather"): parts.append(f"weather={meta['weather']}") if meta.get("road_type"): parts.append(f"road={meta['road_type']}") if meta.get("time_of_day"): parts.append(f"time={meta['time_of_day']}") return ", ".join(parts) or "urban driving" class _CoTGenerator: """ Wraps a base VLM (no LoRA) + processor for hazard-listing generation. We deliberately do NOT reuse PolicyModel / SFTModel here: the SFT-Qwen LoRA was fine-tuned on TTA-scalar regression and has degraded language ability — generation produces token soup. The BASE Qwen2.5-VL-Instruct retains its instruction-following capability and is what we want for offline CoT generation. Optional: --use_sft_lora to restore the legacy (broken) behavior for A/B comparison. """ def __init__( self, model_name: str = "PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct", use_bf16: bool = True, max_pixels: int = 768 * 28 * 28, ): dtype = torch.bfloat16 if use_bf16 else torch.float32 self.amp_dtype = dtype logger.info(f" Loading BASE VLM (no LoRA) for CoT gen: {model_name}") self.model = AutoModelForImageTextToText.from_pretrained( model_name, torch_dtype=dtype, device_map="cuda:0", trust_remote_code=True, attn_implementation="flash_attention_2", ) self.model.eval() self.model.config.use_cache = True self.processor = AutoProcessor.from_pretrained( model_name, trust_remote_code=True, min_pixels=256 * 28 * 28, max_pixels=max_pixels, ) # Decoder-only generation requires LEFT padding for correct results. self.processor.tokenizer.padding_side = "left" if self.processor.tokenizer.pad_token_id is None: self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id self.device = next(self.model.parameters()).device self.dtype = next(self.model.parameters()).dtype def _build_generation_inputs(gen: "_CoTGenerator", batch: Dict[str, Any]): """Build chat-template inputs for hazard-listing generation.""" proc = gen.processor apply_chat = ( proc.apply_chat_template if hasattr(proc, "apply_chat_template") else proc.tokenizer.apply_chat_template ) images_b = batch["images"] metas = batch["metadata"] texts: List[str] = [] for i in range(len(images_b)): frames = images_b[i] content = [{"type": "image"} for _ in range(len(frames))] content.append({ "type": "text", "text": USER_TEMPLATE.format(n=len(frames), ctx=_ctx(metas[i])), }) msgs = [ {"role": "system", "content": SYSTEM}, {"role": "user", "content": content}, ] # add_generation_prompt=True so the model continues with assistant role texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=True)) return proc( text=texts, images=images_b, return_tensors="pt", padding=True, truncation=True, ) @torch.no_grad() def _generate_k( gen: "_CoTGenerator", enc: Dict[str, torch.Tensor], k: int, temperature: float, top_p: float, max_new_tokens: int, ) -> List[List[str]]: """ Generate K candidates for each sample in `enc`. Returns a [B][K] list of decoded strings (assistant-only, special tokens stripped). """ moved: Dict[str, torch.Tensor] = {} for kk, vv in enc.items(): if not isinstance(vv, torch.Tensor): moved[kk] = vv continue if kk == "pixel_values": moved[kk] = vv.to(gen.device, dtype=gen.dtype, non_blocking=True) else: moved[kk] = vv.to(gen.device, non_blocking=True) proc = gen.processor pad_id = proc.tokenizer.pad_token_id eos_id = proc.tokenizer.eos_token_id input_len = moved["input_ids"].shape[1] B = moved["input_ids"].shape[0] gen_kwargs = dict( do_sample = True, temperature = float(temperature), top_p = float(top_p), max_new_tokens = int(max_new_tokens), pad_token_id = pad_id, eos_token_id = eos_id, num_return_sequences= int(k), use_cache = True, ) with autocast(device_type="cuda", dtype=gen.amp_dtype, enabled=True): out = gen.model.generate(**moved, **gen_kwargs) # out shape: [B*K, in_len + new] new_tokens = out[:, input_len:] decoded = proc.tokenizer.batch_decode(new_tokens, skip_special_tokens=True) # regroup B*K → B groups of K grouped: List[List[str]] = [] for b in range(B): grouped.append([decoded[b * k + j].strip() for j in range(k)]) return grouped def _short_clean(s: str, max_chars: int = 400) -> str: """Lightly normalise generated text for storage.""" s = s.replace("\r", "").strip() if len(s) > max_chars: s = s[:max_chars] + "…" return s def build_split_cache( gen: "_CoTGenerator", loader: DataLoader, out_path: Path, k: int, temperature: float, top_p: float, max_new_tokens: int, samples_meta: List[Dict[str, Any]], ): out_path.parent.mkdir(parents=True, exist_ok=True) tmp_path = out_path.with_suffix(out_path.suffix + ".tmp") sample_idx = 0 n_written = 0 with gzip.open(tmp_path, "wt", encoding="utf-8") as fout: # First line: header header = { "schema_version": SCHEMA_VERSION, "k_candidates": k, "temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "n_samples": len(samples_meta), } fout.write(json.dumps({"__header__": header}) + "\n") for batch in tqdm(loader, desc=f" cot-gen {out_path.name}", ncols=100): B = len(batch["images"]) enc = _build_generation_inputs(gen, batch) cand_b = _generate_k(gen, enc, k, temperature, top_p, max_new_tokens) for b in range(B): meta = samples_meta[sample_idx] rec = { "idx": sample_idx, "video_id": meta["video_id"], "category": meta["category"], "action_label": int(meta["action_label"]), "candidates": [_short_clean(c) for c in cand_b[b]], } fout.write(json.dumps(rec, ensure_ascii=False) + "\n") sample_idx += 1 n_written += 1 tmp_path.rename(out_path) logger.info(f" wrote {n_written} CoT records → {out_path}") def main(): ap = argparse.ArgumentParser("make_cot_cache") ap.add_argument("--base_model", default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct", help="Base VLM (no LoRA) — preserves instruction-following.") ap.add_argument("--label_dir", default="data/policy_labels") ap.add_argument("--out_dir", default="data/cot_cache") ap.add_argument("--k", type=int, default=8, help="Candidates per window (paranoid setting: 8)") ap.add_argument("--temperature", type=float, default=0.9) ap.add_argument("--top_p", type=float, default=0.95) ap.add_argument("--max_new_tokens", type=int, default=96) ap.add_argument("--batch_size", type=int, default=4) ap.add_argument("--num_workers", type=int, default=0) ap.add_argument("--splits", nargs="+", default=["val", "train"]) ap.add_argument("--debug", action="store_true") ap.add_argument("--debug_samples", type=int, default=8) ap.add_argument("--overwrite", action="store_true") args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) gen = _CoTGenerator(model_name=args.base_model, use_bf16=True) logger.info( f" CoT generation: K={args.k} T={args.temperature} " f"top_p={args.top_p} max_new={args.max_new_tokens}" ) for split in args.splits: label_path = Path(args.label_dir) / f"{split}.json" if not label_path.exists(): logger.warning(f" {label_path} missing — skip") continue out_path = out_dir / f"{split}.jsonl.gz" if out_path.exists() and not args.overwrite: logger.info(f" Cache exists: {out_path} — skip (use --overwrite)") continue ds = PolicyDataset( manifests = [label_path], split = split, debug = args.debug, debug_samples = args.debug_samples, ) loader = DataLoader( ds, batch_size = args.batch_size, shuffle = False, num_workers = args.num_workers, collate_fn = policy_collate_fn, pin_memory = True, ) samples_meta = ds.samples build_split_cache( gen, loader, out_path, k = args.k, temperature = args.temperature, top_p = args.top_p, max_new_tokens = args.max_new_tokens, samples_meta = samples_meta, ) logger.info("\ncot_cache complete.") if __name__ == "__main__": main()