| |
| """ |
| make_cot_cache.py |
| ═══════════════════════════════════════════════════════════════════════════════ |
| Generate K hazard-candidate Chain-of-Thought (CoT) sentences per policy |
| window, using the BASE Qwen2.5-VL-Instruct (no SFT LoRA) as the language |
| reasoner. |
| |
| Why |
| ─── |
| Phase 0b of the CoT-Pool plan. The CoT-Pool aggregator (M8–M14) needs |
| text-grounded queries. Generating them on-the-fly during training would |
| add ~K × VLM-decode per step. We pre-generate ONCE and cache. |
| |
| Design choices (intentionally aggressive, per the safety-bias rule) |
| ─────────────────────────────────────────────────────────────────── |
| • K = 8 candidates per window (paranoid: list every hazard, even unlikely) |
| • temperature = 0.9, top_p = 0.95 → diversity, not repetition |
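| • Structured prompt asking for one pipe-delimited line per hazard: |
| HAZARD: entity="…" | location="…" | motion="…" | risk="…" | reason="…" |
| • The BASE Qwen2.5-VL-Instruct is used rather than the SFT-Qwen: the SFT |
| LoRA was fine-tuned on TTA only and its free-form generation degraded, |
| whereas the base model keeps its instruction-following capability, so |
| prompted hazard listing still works. |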
| • This builder STORES raw generation text only. The three gates |
| (G1 self-consistency, G2 attn-entropy, G3 OVD cross-check) live in a |
| separate `verify_cot_cache.py` so we can iterate on filtering cheaply |
| without re-running expensive VLM generation. |
| |
| Storage |
| ─────── |
| data/cot_cache/{split}.jsonl.gz — first line is a {"__header__": {...}} |
| record with the generation settings, then one JSON line per window |
| per-window schema: |
| { |
| "idx": int, |
| "video_id": str, |
| "category": str, |
| "action_label": int, |
| "candidates": [str, str, ..., str] # length K |
| } |
| Index alignment with PolicyDataset(manifest)["samples"][idx]. |
| |
| Usage |
| ───── |
| python -m training.Policy.make_cot_cache \\ |
| --base_model PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct \\ |
| --label_dir data/policy_labels \\ |
| --out_dir data/cot_cache \\ |
| --k 8 \\ |
| --temperature 0.9 \\ |
| --top_p 0.95 \\ |
| --max_new_tokens 96 \\ |
| --batch_size 4 \\ |
| --splits val |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import gzip |
| import json |
| import logging |
| from pathlib import Path |
| from typing import Any, Dict, List |
|
|
| import torch |
| from torch.amp import autocast |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) |
|
|
| from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
| from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn |
|
|
# Process-wide logging: INFO level, "<timestamp> <LEVEL> <message>" records.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# Module logger, registered under a fixed name rather than __name__.
logger = logging.getLogger("Policy.make_cot_cache")
|
|
|
|
# Version stamp written into the header record of every cache file.
SCHEMA_VERSION = 2


# System turn: biases the model towards over-reporting hazards.
SYSTEM = (
    "You are a defensive driving safety analyst. "
    "Your job is to enumerate EVERY potentially dangerous element you can "
    "detect in a dashcam window — err on the side of MORE hazards, never fewer. "
    "A missed hazard is much worse than a false alarm."
)


# User turn, filled with str.format(n=<frame count>, ctx=<context string>).
# Each list entry below is one line of the final prompt ("" = blank line).
USER_TEMPLATE = "\n".join([
    "Look at this {n}-frame dashcam window.",
    "Context: {ctx}",
    "",
    (
        "List up to 4 distinct potential collision hazards. Be paranoid; "
        "if a pedestrian, cyclist, vehicle, or unusual road condition could "
        "become dangerous in the next ~3 seconds, list it."
    ),
    "",
    "Return ONE hazard per line, in this exact format:",
    (
        'HAZARD: entity="<noun>" | '
        'location="<L|C|R>-<near|mid|far>" | '
        'motion="<approaching|crossing|braking|static>" | '
        'risk="<low|med|high>" | '
        'reason="<short why>"'
    ),
    "",
    "If no hazards exist write exactly: HAZARD: none",
])
|
|
|
|
def _ctx(meta: Dict[str, Any]) -> str:
    """
    Render a sample's scene metadata as a short "label=value" context string.

    Only the `weather`, `road_type` and `time_of_day` keys are consulted, in
    that order; keys that are missing or falsy are skipped. When none of them
    is usable the generic fallback "urban driving" is returned.
    """
    # (label shown in the prompt, key looked up in the metadata)
    labelled_keys = (
        ("weather", "weather"),
        ("road", "road_type"),
        ("time", "time_of_day"),
    )
    described = [
        f"{label}={meta[key]}"
        for label, key in labelled_keys
        if meta.get(key)
    ]
    if not described:
        return "urban driving"
    return ", ".join(described)
|
|
|
|
class _CoTGenerator:
    """
    Wraps a base VLM (no LoRA) + processor for hazard-listing generation.

    We deliberately do NOT reuse PolicyModel / SFTModel here: the SFT-Qwen
    LoRA was fine-tuned on TTA-scalar regression and has degraded language
    ability — generation produces token soup. The BASE Qwen2.5-VL-Instruct
    retains its instruction-following capability and is what we want for
    offline CoT generation.

    NOTE(review): earlier notes mention a `--use_sft_lora` switch to restore
    the legacy (broken) behaviour for A/B comparison. This class has no LoRA
    loading path, so that switch is not implemented here.

    Args:
        model_name: Path or hub id handed to `from_pretrained` for both the
            model and the processor.
        use_bf16: Load weights in bfloat16 when True, float32 otherwise.
        max_pixels: Upper pixel budget forwarded to the processor.
    """

    def __init__(
        self,
        model_name: str = "PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
        use_bf16: bool = True,
        max_pixels: int = 768 * 28 * 28,
    ):
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        # dtype used for the autocast context during generation
        self.amp_dtype = dtype

        # FlashAttention-2 only supports fp16/bf16 weights; requesting it for
        # an fp32 load is rejected, so fall back to PyTorch SDPA in that case.
        attn_impl = "flash_attention_2" if use_bf16 else "sdpa"

        logger.info(f" Loading BASE VLM (no LoRA) for CoT gen: {model_name}")
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="cuda:0",
            trust_remote_code=True,
            attn_implementation=attn_impl,
        )
        self.model.eval()
        self.model.config.use_cache = True
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            min_pixels=256 * 28 * 28,
            max_pixels=max_pixels,
        )

        # Left-pad batches so every prompt ends at the same column; the
        # generation code slices new tokens off by the padded prompt length.
        self.processor.tokenizer.padding_side = "left"
        if self.processor.tokenizer.pad_token_id is None:
            # No dedicated pad token: reuse EOS so padded batches are possible.
            self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id

        # Device / dtype of the loaded weights, read from the first parameter.
        self.device = next(self.model.parameters()).device
        self.dtype = next(self.model.parameters()).dtype
|
|
|
|
def _build_generation_inputs(gen: "_CoTGenerator", batch: Dict[str, Any]):
    """
    Turn a collated batch into processor inputs for hazard-listing generation.

    Reads `batch["images"]` (one sequence of frames per sample) and
    `batch["metadata"]` (one metadata dict per sample). For every sample a
    two-turn chat (system + user) is rendered to text with the generation
    prompt appended; the user turn holds one image placeholder per frame
    followed by the filled-in USER_TEMPLATE. All prompts and images are then
    passed through the processor in a single padded call.

    Returns whatever the processor returns for
    `return_tensors="pt", padding=True, truncation=True`.
    """
    processor = gen.processor

    # Use the processor's chat template when it has one, otherwise fall back
    # to the tokenizer's.
    if hasattr(processor, "apply_chat_template"):
        render_chat = processor.apply_chat_template
    else:
        render_chat = processor.tokenizer.apply_chat_template

    windows = batch["images"]
    metadata = batch["metadata"]

    prompts: List[str] = []
    for pos, frames in enumerate(windows):
        n_frames = len(frames)
        instruction = USER_TEMPLATE.format(n=n_frames, ctx=_ctx(metadata[pos]))

        # One image placeholder per frame, then the textual instruction.
        user_content = [{"type": "image"} for _ in range(n_frames)]
        user_content.append({"type": "text", "text": instruction})

        conversation = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": user_content},
        ]
        prompts.append(
            render_chat(conversation, tokenize=False, add_generation_prompt=True)
        )

    return processor(
        text=prompts,
        images=windows,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
|
|
|
|
| @torch.no_grad() |
| def _generate_k( |
| gen: "_CoTGenerator", |
| enc: Dict[str, torch.Tensor], |
| k: int, |
| temperature: float, |
| top_p: float, |
| max_new_tokens: int, |
| ) -> List[List[str]]: |
| """ |
| Generate K candidates for each sample in `enc`. Returns a [B][K] list of |
| decoded strings (assistant-only, special tokens stripped). |
| """ |
| moved: Dict[str, torch.Tensor] = {} |
| for kk, vv in enc.items(): |
| if not isinstance(vv, torch.Tensor): |
| moved[kk] = vv |
| continue |
| if kk == "pixel_values": |
| moved[kk] = vv.to(gen.device, dtype=gen.dtype, non_blocking=True) |
| else: |
| moved[kk] = vv.to(gen.device, non_blocking=True) |
|
|
| proc = gen.processor |
| pad_id = proc.tokenizer.pad_token_id |
| eos_id = proc.tokenizer.eos_token_id |
| input_len = moved["input_ids"].shape[1] |
| B = moved["input_ids"].shape[0] |
|
|
| gen_kwargs = dict( |
| do_sample = True, |
| temperature = float(temperature), |
| top_p = float(top_p), |
| max_new_tokens = int(max_new_tokens), |
| pad_token_id = pad_id, |
| eos_token_id = eos_id, |
| num_return_sequences= int(k), |
| use_cache = True, |
| ) |
|
|
| with autocast(device_type="cuda", dtype=gen.amp_dtype, enabled=True): |
| out = gen.model.generate(**moved, **gen_kwargs) |
| |
| new_tokens = out[:, input_len:] |
| decoded = proc.tokenizer.batch_decode(new_tokens, skip_special_tokens=True) |
|
|
| |
| grouped: List[List[str]] = [] |
| for b in range(B): |
| grouped.append([decoded[b * k + j].strip() for j in range(k)]) |
| return grouped |
|
|
|
|
def _short_clean(s: str, max_chars: int = 400) -> str:
    """
    Lightly normalise generated text for storage.

    Carriage returns are removed and surrounding whitespace is stripped.
    Text longer than `max_chars` is cut to that many characters and an
    ellipsis ("…") is appended.
    """
    text = s.replace("\r", "").strip()
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + "…"
|
|
|
|
def build_split_cache(
    gen: "_CoTGenerator",
    loader: DataLoader,
    out_path: Path,
    k: int,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    samples_meta: List[Dict[str, Any]],
):
    """
    Generate K candidates for every window in `loader` and write the cache.

    The output is a gzip-compressed JSONL file. Its first line is a
    `{"__header__": {...}}` record holding the schema version and the
    generation settings; each following line is one window record with
    `idx`, `video_id`, `category`, `action_label` and `candidates`.

    Records are matched to `samples_meta` purely by position: the n-th window
    yielded by `loader` is paired with `samples_meta[n]`, so the loader must
    iterate in the same order (no shuffling).

    The file is written to `<out_path>.tmp` first and moved into place only
    after the whole split has been generated, so an interrupted run never
    leaves a truncated cache at `out_path`.

    Raises:
        IndexError: if `loader` yields more windows than `samples_meta` has
            entries.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")

    header = {
        "schema_version": SCHEMA_VERSION,
        "k_candidates": k,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "n_samples": len(samples_meta),
    }

    n_expected = len(samples_meta)
    n_written = 0  # doubles as the running index into samples_meta
    with gzip.open(tmp_path, "wt", encoding="utf-8") as fout:
        fout.write(json.dumps({"__header__": header}) + "\n")

        for batch in tqdm(loader, desc=f" cot-gen {out_path.name}", ncols=100):
            B = len(batch["images"])
            enc = _build_generation_inputs(gen, batch)
            cand_b = _generate_k(gen, enc, k, temperature, top_p, max_new_tokens)

            for b in range(B):
                if n_written >= n_expected:
                    raise IndexError(
                        f"loader yielded more windows than samples_meta has "
                        f"entries ({n_expected}); cache alignment is broken"
                    )
                meta = samples_meta[n_written]
                rec = {
                    "idx": n_written,
                    "video_id": meta["video_id"],
                    "category": meta["category"],
                    "action_label": int(meta["action_label"]),
                    "candidates": [_short_clean(c) for c in cand_b[b]],
                }
                fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
                n_written += 1

    if n_written != n_expected:
        # The header already advertises n_samples == len(samples_meta).
        logger.warning(
            f" wrote {n_written} records but samples_meta has {n_expected} "
            f"entries — header n_samples does not match {out_path.name}"
        )

    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises on Windows when out_path already exists (--overwrite).
    tmp_path.replace(out_path)
    logger.info(f" wrote {n_written} CoT records → {out_path}")
|
|
|
|
def main():
    """
    CLI entry point: build the CoT cache for each requested split.

    For every split named in `--splits` the label manifest
    `<label_dir>/<split>.json` is loaded through PolicyDataset and the cache
    is written to `<out_dir>/<split>.jsonl.gz`. A split is skipped when its
    manifest is missing, or when its cache already exists and `--overwrite`
    was not given.

    The VLM is loaded only if at least one split actually needs generation,
    so a run where everything is already cached returns quickly.

    NOTE(review): `--debug` writes to the same `<split>.jsonl.gz` path as a
    full run, so a debug cache will make a later full run skip that split
    unless `--overwrite` is passed.
    """
    ap = argparse.ArgumentParser("make_cot_cache")
    ap.add_argument("--base_model", default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
                    help="Base VLM (no LoRA) — preserves instruction-following.")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--out_dir", default="data/cot_cache")
    ap.add_argument("--k", type=int, default=8,
                    help="Candidates per window (paranoid setting: 8)")
    ap.add_argument("--temperature", type=float, default=0.9)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--max_new_tokens", type=int, default=96)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--splits", nargs="+", default=["val", "train"])
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=8)
    ap.add_argument("--overwrite", action="store_true")
    args = ap.parse_args()

    # Reject settings that cannot produce a valid cache before paying for
    # the model load.
    if args.k < 1:
        ap.error("--k must be >= 1")
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")
    if args.max_new_tokens < 1:
        ap.error("--max_new_tokens must be >= 1")
    if args.temperature <= 0:
        ap.error("--temperature must be > 0 (sampling is always enabled)")

    label_dir = Path(args.label_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Decide which splits need work first; dict.fromkeys drops duplicate
    # split names while keeping the order given on the command line.
    pending = []
    for split in dict.fromkeys(args.splits):
        label_path = label_dir / f"{split}.json"
        if not label_path.exists():
            logger.warning(f" {label_path} missing — skip")
            continue
        out_path = out_dir / f"{split}.jsonl.gz"
        if out_path.exists() and not args.overwrite:
            logger.info(f" Cache exists: {out_path} — skip (use --overwrite)")
            continue
        pending.append((split, label_path, out_path))

    if not pending:
        logger.info(" No split needs generation — VLM not loaded.")
    else:
        gen = _CoTGenerator(model_name=args.base_model, use_bf16=True)
        logger.info(
            f" CoT generation: K={args.k} T={args.temperature} "
            f"top_p={args.top_p} max_new={args.max_new_tokens}"
        )

        for split, label_path, out_path in pending:
            ds = PolicyDataset(
                manifests=[label_path],
                split=split,
                debug=args.debug,
                debug_samples=args.debug_samples,
            )
            # shuffle=False: cache records are matched to ds.samples by position.
            loader = DataLoader(
                ds,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_workers,
                collate_fn=policy_collate_fn,
                pin_memory=True,
            )

            build_split_cache(
                gen, loader, out_path,
                k=args.k,
                temperature=args.temperature,
                top_p=args.top_p,
                max_new_tokens=args.max_new_tokens,
                samples_meta=ds.samples,
            )

    logger.info("\ncot_cache complete.")


if __name__ == "__main__":
    main()
|
|