# VLAlert / training /Policy /make_cot_cache.py
# AsianPlayer's picture
# Add VLAlert code
# 1e05592 verified
#!/usr/bin/env python3
"""
make_cot_cache.py
═══════════════════════════════════════════════════════════════════════════════
Generate K hazard-candidate Chain-of-Thought (CoT) texts per policy
window, using the base Qwen2.5-VL-Instruct model (no SFT LoRA) as the
language reasoner.
Why
───
Phase 0b of the CoT-Pool plan. The CoT-Pool aggregator (M8–M14) needs
text-grounded queries. Generating them on-the-fly during training would
add ~K × VLM-decode per step. We pre-generate ONCE and cache.
Design choices (intentionally aggressive, per the safety-bias rule)
───────────────────────────────────────────────────────────────────
• K = 8 candidates per window (paranoid: list every hazard, even unlikely)
• temperature = 0.9, top_p = 0.95 → diversity, not repetition
• Structured prompt asking for one line per hazard:
HAZARD: entity="…" | location="…" | motion="…" | risk="…" | reason="…"
(or exactly "HAZARD: none" when nothing is found)
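• The BASE (non-LoRA) Qwen2.5-VL-Instruct is used on purpose: the SFT
LoRA was fine-tuned on TTA-scalar regression and degrades free-form
generation, whereas the base model keeps its instruction-following
capability, so prompted hazard listing works.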
• This builder STORES raw generation text only. The three gates
(G1 self-consistency, G2 attn-entropy, G3 OVD cross-check) live in a
separate `verify_cot_cache.py` so we can iterate on filtering cheaply
without re-running expensive VLM generation.
Storage
───────
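data/cot_cache/{split}.jsonl.gz — one header line, then one JSON line per window
header (first line):
{"__header__": {"schema_version", "k_candidates", "temperature",
"top_p", "max_new_tokens", "n_samples"}}
record schema:
{
"idx": int,
"video_id": str,
"category": str,
"action_label": int,
"candidates": [str, str, ..., str] # length K
}
Records align by position with PolicyDataset(...).samples[idx].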
Usage
─────
python -m training.Policy.make_cot_cache \\
--base_model PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct \\
--label_dir data/policy_labels \\
--out_dir data/cot_cache \\
--k 8 \\
--temperature 0.9 \\
--top_p 0.95 \\
--max_new_tokens 96 \\
--batch_size 4 \\
--splits val
"""
from __future__ import annotations
import argparse
import gzip
import json
import logging
from pathlib import Path
from typing import Any, Dict, List
import torch
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from transformers import AutoModelForImageTextToText, AutoProcessor
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
# Process-wide logging: INFO level, one timestamped line per message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("Policy.make_cot_cache")

# Stored in each cache file's header line so readers can detect format changes.
SCHEMA_VERSION = 2

# System turn: frames the model as a deliberately over-cautious analyst
# (safety-bias rule — a false alarm is preferred to a missed hazard).
SYSTEM = (
    "You are a defensive driving safety analyst. Your job is to enumerate "
    "EVERY potentially dangerous element you can detect in a dashcam window "
    "— err on the side of MORE hazards, never fewer. A missed hazard is "
    "much worse than a false alarm."
)

# The single-line shape the model must emit once per hazard.
_HAZARD_LINE_FORMAT = (
    'HAZARD: entity="<noun>" | location="<L|C|R>-<near|mid|far>" | '
    'motion="<approaching|crossing|braking|static>" | '
    'risk="<low|med|high>" | reason="<short why>"'
)

# User turn; rendered with str.format(n=<number of frames>, ctx=<context>).
USER_TEMPLATE = (
    "Look at this {n}-frame dashcam window.\n"
    "Context: {ctx}\n\n"
    "List up to 4 distinct potential collision hazards. Be paranoid; if a "
    "pedestrian, cyclist, vehicle, or unusual road condition could become "
    "dangerous in the next ~3 seconds, list it.\n\n"
    "Return ONE hazard per line, in this exact format:\n"
    + _HAZARD_LINE_FORMAT
    + "\n\n"
    "If no hazards exist write exactly: HAZARD: none"
)
def _ctx(meta: Dict[str, Any]) -> str:
    """
    Render the short scene-context string inserted into the user prompt.

    Looks at the ``weather``, ``road_type`` and ``time_of_day`` keys of
    ``meta`` (in that order) and emits ``label=value`` pairs for every key
    whose value is truthy, joined by ", ". Falls back to "urban driving"
    when none of them is present.
    """
    labelled_keys = (
        ("weather", "weather"),
        ("road_type", "road"),
        ("time_of_day", "time"),
    )
    pieces = [
        f"{label}={meta[key]}"
        for key, label in labelled_keys
        if meta.get(key)
    ]
    return ", ".join(pieces) if pieces else "urban driving"
class _CoTGenerator:
    """
    Wraps a base VLM (no LoRA) + processor for hazard-listing generation.

    We deliberately do NOT reuse PolicyModel / SFTModel here: the SFT-Qwen
    LoRA was fine-tuned on TTA-scalar regression and has degraded language
    ability — generation produces token soup. The BASE Qwen2.5-VL-Instruct
    retains its instruction-following capability and is what we want for
    offline CoT generation.

    Args:
        model_name: HF model id or local path of the base VLM.
        use_bf16: Load weights in bfloat16 (True) or float32 (False). The
            same dtype is exposed as ``amp_dtype`` for autocast.
        max_pixels: Upper bound on pixels per image given to the processor
            (lower bound is fixed at 256 * 28 * 28).
        device_map: Forwarded to ``from_pretrained``; defaults to the
            single GPU "cuda:0" as before.
        attn_implementation: Attention backend passed to ``from_pretrained``.
            ``None`` (default) picks "flash_attention_2" for bf16 and "sdpa"
            for fp32, because FlashAttention-2 only supports fp16/bf16.

    Attributes:
        model: The loaded VLM, in eval mode with the KV cache enabled.
        processor: Matching processor; its tokenizer pads on the LEFT.
        device: Device of the model's first parameter (where inputs go).
        dtype: dtype of the model's first parameter.
        amp_dtype: dtype to use for autocast during generation.
    """

    def __init__(
        self,
        model_name: str = "PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
        use_bf16: bool = True,
        max_pixels: int = 768 * 28 * 28,
        device_map: Any = "cuda:0",
        attn_implementation: str | None = None,
    ):
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        self.amp_dtype = dtype
        if attn_implementation is None:
            # FlashAttention-2 kernels only run in half precision; fall back
            # to PyTorch SDPA when loading in full precision.
            attn_implementation = "flash_attention_2" if use_bf16 else "sdpa"

        logger.info(f" Loading BASE VLM (no LoRA) for CoT gen: {model_name}")
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
            attn_implementation=attn_implementation,
        )
        self.model.eval()
        self.model.config.use_cache = True

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            min_pixels=256 * 28 * 28,
            max_pixels=max_pixels,
        )
        tokenizer = self.processor.tokenizer
        # Decoder-only generation requires LEFT padding so every prompt ends
        # right where generation starts (needed for the input_len slice).
        tokenizer.padding_side = "left"
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        first_param = next(self.model.parameters())
        self.device = first_param.device
        self.dtype = first_param.dtype
def _build_generation_inputs(gen: "_CoTGenerator", batch: Dict[str, Any]):
    """
    Build chat-templated processor inputs for hazard-listing generation.

    Each sample becomes a system turn (``SYSTEM``) plus a user turn with one
    image placeholder per frame followed by ``USER_TEMPLATE`` rendered with
    the frame count and ``_ctx(metadata)``. The assistant generation prompt
    is appended so decoding continues in the assistant role.

    Args:
        gen: Generator holding the processor (left-padding tokenizer).
        batch: Collated batch. Reads ``batch["images"]`` (one sequence of
            frames per sample) and ``batch["metadata"]`` (one mapping per
            sample, consumed by ``_ctx``).

    Returns:
        The processor output for the whole batch (``return_tensors="pt"``,
        padded to the longest prompt).

    Raises:
        ValueError: if ``images`` and ``metadata`` have different lengths.
    """
    proc = gen.processor
    # Multimodal processors expose apply_chat_template directly; older ones
    # only through their tokenizer.
    apply_chat = (
        proc.apply_chat_template
        if hasattr(proc, "apply_chat_template")
        else proc.tokenizer.apply_chat_template
    )

    images_b = batch["images"]
    metas = batch["metadata"]
    if len(images_b) != len(metas):
        raise ValueError(
            f"batch has {len(images_b)} image sequences but "
            f"{len(metas)} metadata entries"
        )

    texts: List[str] = []
    for frames, meta in zip(images_b, metas):
        content: List[Dict[str, Any]] = [{"type": "image"} for _ in frames]
        content.append({
            "type": "text",
            "text": USER_TEMPLATE.format(n=len(frames), ctx=_ctx(meta)),
        })
        msgs = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": content},
        ]
        # add_generation_prompt=True so the model continues as the assistant.
        texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=True))

    # No truncation: the processor expands every image placeholder into many
    # vision tokens before tokenising, and truncating that sequence could
    # cut vision tokens and desynchronise them from pixel_values.
    return proc(
        text=texts,
        images=images_b,
        return_tensors="pt",
        padding=True,
    )
@torch.no_grad()
def _generate_k(
gen: "_CoTGenerator",
enc: Dict[str, torch.Tensor],
k: int,
temperature: float,
top_p: float,
max_new_tokens: int,
) -> List[List[str]]:
"""
Generate K candidates for each sample in `enc`. Returns a [B][K] list of
decoded strings (assistant-only, special tokens stripped).
"""
moved: Dict[str, torch.Tensor] = {}
for kk, vv in enc.items():
if not isinstance(vv, torch.Tensor):
moved[kk] = vv
continue
if kk == "pixel_values":
moved[kk] = vv.to(gen.device, dtype=gen.dtype, non_blocking=True)
else:
moved[kk] = vv.to(gen.device, non_blocking=True)
proc = gen.processor
pad_id = proc.tokenizer.pad_token_id
eos_id = proc.tokenizer.eos_token_id
input_len = moved["input_ids"].shape[1]
B = moved["input_ids"].shape[0]
gen_kwargs = dict(
do_sample = True,
temperature = float(temperature),
top_p = float(top_p),
max_new_tokens = int(max_new_tokens),
pad_token_id = pad_id,
eos_token_id = eos_id,
num_return_sequences= int(k),
use_cache = True,
)
with autocast(device_type="cuda", dtype=gen.amp_dtype, enabled=True):
out = gen.model.generate(**moved, **gen_kwargs)
# out shape: [B*K, in_len + new]
new_tokens = out[:, input_len:]
decoded = proc.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
# regroup B*K → B groups of K
grouped: List[List[str]] = []
for b in range(B):
grouped.append([decoded[b * k + j].strip() for j in range(k)])
return grouped
def _short_clean(s: str, max_chars: int = 400) -> str:
    """
    Normalise one generated candidate before it is written to the cache.

    Carriage returns are removed and surrounding whitespace is stripped.
    Text longer than ``max_chars`` is cut to its first ``max_chars``
    characters and suffixed with "…" (so the result can be
    ``max_chars + 1`` characters long).
    """
    text = s.translate({ord("\r"): None}).strip()
    if len(text) <= max_chars:
        return text
    return f"{text[:max_chars]}…"
def build_split_cache(
    gen: "_CoTGenerator",
    loader: DataLoader,
    out_path: Path,
    k: int,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    samples_meta: List[Dict[str, Any]],
):
    """
    Generate K candidates for every window in ``loader`` and write the cache.

    Output is gzip-compressed JSONL. The first line is ``{"__header__": {...}}``
    holding the schema version and generation settings; each following line
    is one window with keys ``idx``, ``video_id``, ``category``,
    ``action_label`` and ``candidates`` (K cleaned strings).

    Records are paired with ``samples_meta`` purely by position, so the
    loader must iterate the dataset in order and yield exactly one window
    per metadata entry.

    The data is written to ``<out_path>.tmp`` and moved onto ``out_path``
    only after every batch succeeded; on any error (including Ctrl-C) the
    partial temp file is deleted so a half-written cache is never left.

    Raises:
        RuntimeError: if a batch yields a different number of candidate
            groups than it has samples, or the loader yields more windows
            than ``samples_meta`` describes.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    n_expected = len(samples_meta)
    header = {
        "schema_version": SCHEMA_VERSION,
        "k_candidates": k,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "n_samples": n_expected,
    }

    sample_idx = 0
    try:
        with gzip.open(tmp_path, "wt", encoding="utf-8") as fout:
            fout.write(json.dumps({"__header__": header}) + "\n")
            for batch in tqdm(loader, desc=f" cot-gen {out_path.name}", ncols=100):
                n_batch = len(batch["images"])
                enc = _build_generation_inputs(gen, batch)
                cand_b = _generate_k(gen, enc, k, temperature, top_p, max_new_tokens)
                if len(cand_b) != n_batch:
                    raise RuntimeError(
                        f"generated {len(cand_b)} candidate groups for a "
                        f"batch of {n_batch} windows"
                    )
                for cands in cand_b:
                    if sample_idx >= n_expected:
                        raise RuntimeError(
                            f"loader yielded more than {n_expected} windows; "
                            "records would no longer align with samples_meta"
                        )
                    meta = samples_meta[sample_idx]
                    rec = {
                        "idx": sample_idx,
                        "video_id": meta["video_id"],
                        "category": meta["category"],
                        "action_label": int(meta["action_label"]),
                        "candidates": [_short_clean(c) for c in cands],
                    }
                    fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    sample_idx += 1
    except BaseException:
        # Never leave a partial temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise

    if sample_idx != n_expected:
        logger.warning(
            f" {out_path.name}: wrote {sample_idx} records but header "
            f"n_samples={n_expected}"
        )
    # Path.replace overwrites an existing destination on every platform;
    # Path.rename raises on Windows when out_path already exists
    # (the --overwrite case).
    tmp_path.replace(out_path)
    logger.info(f" wrote {sample_idx} CoT records → {out_path}")
def main():
    """
    CLI entry point: build the CoT cache for each requested split.

    For every split in ``--splits``, the manifest ``{label_dir}/{split}.json``
    is loaded through PolicyDataset (in order, no shuffling), K candidates
    are generated per window and written to ``{out_dir}/{split}.jsonl.gz``.
    Splits whose manifest is missing, or whose cache already exists without
    ``--overwrite``, are skipped. The VLM is loaded lazily, only once a split
    actually needs generating, so a fully cached run skips the model load.
    Invalid sampling/batching arguments are rejected before anything is
    loaded.
    """
    ap = argparse.ArgumentParser("make_cot_cache")
    ap.add_argument("--base_model", default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
                    help="Base VLM (no LoRA) — preserves instruction-following.")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--out_dir", default="data/cot_cache")
    ap.add_argument("--k", type=int, default=8,
                    help="Candidates per window (paranoid setting: 8)")
    ap.add_argument("--temperature", type=float, default=0.9)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--max_new_tokens", type=int, default=96)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--splits", nargs="+", default=["val", "train"])
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=8)
    ap.add_argument("--overwrite", action="store_true")
    args = ap.parse_args()

    # Fail fast on values that would otherwise only error inside generate()
    # or the DataLoader, after the model has already been loaded.
    if args.k < 1:
        ap.error("--k must be >= 1")
    if args.temperature <= 0:
        ap.error("--temperature must be > 0 (sampling is always enabled)")
    if not 0.0 <= args.top_p <= 1.0:
        ap.error("--top_p must be in [0, 1]")
    if args.max_new_tokens < 1:
        ap.error("--max_new_tokens must be >= 1")
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    logger.info(
        f" CoT generation: K={args.k} T={args.temperature} "
        f"top_p={args.top_p} max_new={args.max_new_tokens}"
    )

    gen = None  # loaded on first split that needs generation
    for split in args.splits:
        label_path = Path(args.label_dir) / f"{split}.json"
        if not label_path.exists():
            logger.warning(f" {label_path} missing — skip")
            continue
        out_path = out_dir / f"{split}.jsonl.gz"
        if out_path.exists() and not args.overwrite:
            logger.info(f" Cache exists: {out_path} — skip (use --overwrite)")
            continue

        if gen is None:
            gen = _CoTGenerator(model_name=args.base_model, use_bf16=True)

        ds = PolicyDataset(
            manifests=[label_path],
            split=split,
            debug=args.debug,
            debug_samples=args.debug_samples,
        )
        # shuffle=False is required: records are aligned to ds.samples by
        # position inside build_split_cache.
        loader = DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=policy_collate_fn,
            pin_memory=True,
        )
        build_split_cache(
            gen, loader, out_path,
            k=args.k,
            temperature=args.temperature,
            top_p=args.top_p,
            max_new_tokens=args.max_new_tokens,
            samples_meta=ds.samples,
        )
    logger.info("\ncot_cache complete.")


if __name__ == "__main__":
    main()