File size: 13,490 Bytes
1e05592 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 | #!/usr/bin/env python3
"""
make_cot_cache.py
───────────────────────────────────────────────────────────────────────────────
Generate K hazard-candidate Chain-of-Thought (CoT) sentences per policy
window, using the BASE Qwen2.5-VL-Instruct (no LoRA) as the language reasoner.
Why
───
Phase 0b of the CoT-Pool plan. The CoT-Pool aggregator (M8–M14) needs
text-grounded queries. Generating them on-the-fly during training would
add roughly K VLM decodes per training step. We pre-generate ONCE and cache.
Design choices (intentionally aggressive, per the safety-bias rule)
───────────────────────────────────────────────────────────────────
  • K = 8 candidates per window (paranoid: list every hazard, even unlikely)
  • temperature = 0.9, top_p = 0.95 → diversity, not repetition
  • Structured prompt asking for up to 4 hazards, one per line:
      HAZARD: entity="…" | location="…" | motion="…" | risk="…" | reason="…"
    or exactly "HAZARD: none" when nothing is hazardous.
  • Generation uses the BASE Qwen2.5-VL-Instruct, not the SFT-Qwen LoRA:
    the LoRA was fine-tuned on TTA-scalar regression and its generations
    degrade into token soup, whereas the base model keeps its
    instruction-following ability.
  • This builder STORES raw generation text only. The three gates
    (G1 self-consistency, G2 attn-entropy, G3 OVD cross-check) live in a
    separate `verify_cot_cache.py` so we can iterate on filtering cheaply
    without re-running expensive VLM generation.
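Storage
───────
  data/cot_cache/{split}.jsonl.gz — one JSON line per window
  The first line is a header record:
    {"__header__": {"schema_version", "k_candidates", "temperature",
                    "top_p", "max_new_tokens", "n_samples"}}
  Every following line is one window record:
    {
      "idx":          int,
      "video_id":     str,
      "category":     str,
      "action_label": int,
      "candidates":   [str, str, ..., str]   # length K
    }
  `idx` aligns with PolicyDataset(manifests=[...]).samples[idx] (the loader
  is not shuffled).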
Usage
─────
  python -m training.Policy.make_cot_cache \\
      --base_model PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct \\
      --label_dir data/policy_labels \\
      --out_dir data/cot_cache \\
      --k 8 \\
      --temperature 0.9 \\
      --top_p 0.95 \\
      --max_new_tokens 96 \\
      --batch_size 4 \\
      --splits val
"""
from __future__ import annotations
import argparse
import gzip
import json
import logging
from pathlib import Path
from typing import Any, Dict, List
import torch
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from transformers import AutoModelForImageTextToText, AutoProcessor
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("Policy.make_cot_cache")

# Written into the header line of every cache file; presumably lets readers
# (e.g. verify_cot_cache.py) detect caches built with an older record layout.
SCHEMA_VERSION = 2

# System prompt. Deliberately biased toward recall: a missed hazard is framed
# as far worse than a false alarm. The em dash is spelled as an escape so a
# source-encoding round trip cannot garble the text the model actually sees.
SYSTEM = (
    "You are a defensive driving safety analyst. Your job is to enumerate "
    "EVERY potentially dangerous element you can detect in a dashcam window "
    "\u2014 err on the side of MORE hazards, never fewer. A missed hazard is "
    "much worse than a false alarm."
)

# User prompt. ``{n}`` is the number of frames in the window and ``{ctx}`` a
# short scene-context string. Asks for at most 4 hazards, one ``HAZARD:`` line
# each, or the literal sentinel ``HAZARD: none``.
USER_TEMPLATE = (
    "Look at this {n}-frame dashcam window.\n"
    "Context: {ctx}\n\n"
    "List up to 4 distinct potential collision hazards. Be paranoid; if a "
    "pedestrian, cyclist, vehicle, or unusual road condition could become "
    "dangerous in the next ~3 seconds, list it.\n\n"
    "Return ONE hazard per line, in this exact format:\n"
    'HAZARD: entity="<noun>" | location="<L|C|R>-<near|mid|far>" | '
    'motion="<approaching|crossing|braking|static>" | '
    'risk="<low|med|high>" | reason="<short why>"\n\n'
    "If no hazards exist write exactly: HAZARD: none"
)
def _ctx(meta: Dict[str, Any]) -> str:
    """Render optional scene metadata as a short ``label=value`` context string.

    Only truthy ``weather``, ``road_type`` and ``time_of_day`` entries are
    used, in that order, labelled ``weather``/``road``/``time`` and joined by
    ", ". Returns ``"urban driving"`` when none of them is present.
    """
    labelled_keys = (
        ("weather", "weather"),
        ("road_type", "road"),
        ("time_of_day", "time"),
    )
    pieces = [f"{label}={meta[key]}" for key, label in labelled_keys if meta.get(key)]
    if not pieces:
        return "urban driving"
    return ", ".join(pieces)
class _CoTGenerator:
    """
    Wrap a base VLM (no LoRA) plus its processor for hazard-listing generation.

    We deliberately do NOT reuse PolicyModel / SFTModel here: the SFT-Qwen
    LoRA was fine-tuned on TTA-scalar regression and has degraded language
    ability — generation produces token soup. The BASE Qwen2.5-VL-Instruct
    retains its instruction-following capability and is what we want for
    offline CoT generation.

    NOTE(review): earlier notes mention a ``--use_sft_lora`` switch to restore
    the legacy (LoRA) behaviour for A/B comparison. This module's CLI does not
    define it, so the base model is always used.

    Attributes set by ``__init__``:
        model:     the loaded model, in eval mode with the KV cache enabled.
        processor: matching processor; its tokenizer pads on the LEFT and
                   falls back to the EOS id as pad id when none is defined.
        amp_dtype: dtype requested for autocast during generation.
        device:    device of the model's first parameter.
        dtype:     dtype of the model's first parameter.
    """

    def __init__(
        self,
        model_name: str = "PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
        use_bf16: bool = True,
        max_pixels: int = 768 * 28 * 28,
        min_pixels: int = 256 * 28 * 28,
        device_map: Any = "cuda:0",
        attn_implementation: str | None = None,
    ):
        """Load the model and processor.

        Args:
            model_name: local path or hub id of the base VLM.
            use_bf16:   load weights in bfloat16 (otherwise float32).
            max_pixels: forwarded to ``AutoProcessor`` (upper per-image pixel
                        budget used when resizing).
            min_pixels: forwarded to ``AutoProcessor`` (lower per-image pixel
                        budget used when resizing).
            device_map: forwarded to ``from_pretrained``; defaults to GPU 0.
            attn_implementation: forwarded to ``from_pretrained``. ``None``
                picks FlashAttention-2 for half-precision weights and SDPA
                otherwise, because FlashAttention-2 only supports fp16/bf16.
        """
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        self.amp_dtype = dtype
        if attn_implementation is None:
            attn_implementation = self._default_attn_impl(dtype)
        logger.info(f" Loading BASE VLM (no LoRA) for CoT gen: {model_name}")
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
            attn_implementation=attn_implementation,
        )
        self.model.eval()
        self.model.config.use_cache = True
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        # Decoder-only generation requires LEFT padding for correct results.
        self.processor.tokenizer.padding_side = "left"
        if self.processor.tokenizer.pad_token_id is None:
            self.processor.tokenizer.pad_token_id = self.processor.tokenizer.eos_token_id
        self.device = next(self.model.parameters()).device
        self.dtype = next(self.model.parameters()).dtype

    @staticmethod
    def _default_attn_impl(dtype: torch.dtype) -> str:
        """Return an attention backend that supports weights of ``dtype``.

        FlashAttention-2 only runs in fp16/bf16, so full-precision weights
        fall back to PyTorch SDPA instead of failing at load time.
        """
        if dtype in (torch.float16, torch.bfloat16):
            return "flash_attention_2"
        return "sdpa"
def _build_generation_inputs(gen: "_CoTGenerator", batch: Dict[str, Any]):
    """Turn a collated policy batch into processor inputs for hazard listing.

    Each sample becomes a two-turn chat: the ``SYSTEM`` prompt, then a user
    turn holding one image placeholder per frame followed by ``USER_TEMPLATE``
    filled with the frame count and the sample's metadata context. The chat
    template is rendered with ``add_generation_prompt=True`` so decoding
    starts inside the assistant turn.

    Args:
        gen:   generator whose ``processor`` renders and tokenises the chats.
        batch: collated batch; ``batch["images"][i]`` holds the frames of
               sample ``i`` and ``batch["metadata"][i]`` its metadata dict.

    Returns:
        The processor output for all rendered prompts plus their images
        (padded, truncated, as PyTorch tensors).
    """
    processor = gen.processor
    # Multimodal processors carry their own chat template; otherwise fall
    # back to the tokenizer's.
    if hasattr(processor, "apply_chat_template"):
        render_chat = processor.apply_chat_template
    else:
        render_chat = processor.tokenizer.apply_chat_template

    all_frames = batch["images"]
    all_meta = batch["metadata"]

    prompts: List[str] = []
    for sample_no, frames in enumerate(all_frames):
        n_frames = len(frames)
        image_slots = [{"type": "image"} for _ in range(n_frames)]
        prompt_text = USER_TEMPLATE.format(n=n_frames, ctx=_ctx(all_meta[sample_no]))
        user_turn = image_slots + [{"type": "text", "text": prompt_text}]
        conversation = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": user_turn},
        ]
        prompts.append(
            render_chat(conversation, tokenize=False, add_generation_prompt=True)
        )

    return processor(
        text=prompts,
        images=all_frames,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
@torch.no_grad()
def _generate_k(
    gen: "_CoTGenerator",
    enc: Dict[str, torch.Tensor],
    k: int,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
) -> List[List[str]]:
    """Sample ``k`` completions per prompt and return them grouped by sample.

    Tensor entries of ``enc`` are moved to the model's device (``pixel_values``
    is also cast to the model's dtype); non-tensor entries pass through
    unchanged. Generation samples with ``temperature`` / ``top_p`` and returns
    ``k`` sequences per prompt. Only the newly generated tokens are decoded
    (special tokens removed, whitespace stripped).

    Returns:
        A list with one entry per input row, each a list of ``k`` strings.
    """
    def _place(name, value):
        # Leave non-tensors (e.g. processor bookkeeping) untouched.
        if not isinstance(value, torch.Tensor):
            return value
        if name == "pixel_values":
            return value.to(gen.device, dtype=gen.dtype, non_blocking=True)
        return value.to(gen.device, non_blocking=True)

    inputs = {name: _place(name, value) for name, value in enc.items()}
    tokenizer = gen.processor.tokenizer
    prompt_len = inputs["input_ids"].shape[1]
    n_rows = inputs["input_ids"].shape[0]

    with autocast(device_type="cuda", dtype=gen.amp_dtype, enabled=True):
        sequences = gen.model.generate(
            **inputs,
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            num_return_sequences=int(k),
            use_cache=True,
        )

    # With left padding every prompt ends at prompt_len, so the tail of each
    # row is exactly the generated continuation. Rows come back ordered
    # sample-major: [s0c0, s0c1, ..., s1c0, ...].
    texts = tokenizer.batch_decode(sequences[:, prompt_len:], skip_special_tokens=True)
    return [
        [texts[row * k + cand].strip() for cand in range(k)]
        for row in range(n_rows)
    ]
def _short_clean(s: str, max_chars: int = 400) -> str:
    """Lightly normalise generated text for storage.

    Removes carriage returns and surrounding whitespace. If the result is
    longer than ``max_chars`` it is cut to ``max_chars`` characters and a
    single ellipsis character is appended, so the stored string can be
    ``max_chars + 1`` characters long.
    """
    s = s.replace("\r", "").strip()
    if len(s) > max_chars:
        # U+2026 HORIZONTAL ELLIPSIS, spelled as an escape so the marker
        # written into the cache cannot be garbled by the source encoding.
        s = s[:max_chars] + "\u2026"
    return s
def build_split_cache(
    gen: "_CoTGenerator",
    loader: DataLoader,
    out_path: Path,
    k: int,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    samples_meta: List[Dict[str, Any]],
):
    """Generate ``k`` hazard candidates per window and write a gzip'd JSONL cache.

    Line 1 of the output is ``{"__header__": {...}}`` with the schema version
    and generation settings. Every following line is one window record with
    ``idx``, ``video_id``, ``category``, ``action_label`` and ``candidates``.

    Records are paired with ``samples_meta`` purely by iteration order, so
    ``loader`` must visit the dataset sequentially and unshuffled.

    Output goes to ``<out_path>.tmp`` first and is moved onto ``out_path``
    only after every batch succeeded. On any failure (including Ctrl-C) the
    partial temp file is deleted and the exception re-raised, so a crash never
    leaves a truncated cache behind.

    Args:
        gen:            loaded CoT generator.
        loader:         DataLoader yielding collated policy batches.
        out_path:       final ``.jsonl.gz`` path; replaced if it exists.
        k, temperature, top_p, max_new_tokens: sampling settings, also
                        recorded in the header.
        samples_meta:   per-sample metadata dicts, in loader order.

    Raises:
        IndexError: if ``loader`` yields more samples than ``samples_meta``.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    header = {
        "schema_version": SCHEMA_VERSION,
        "k_candidates": k,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "n_samples": len(samples_meta),
    }

    sample_idx = 0
    try:
        with gzip.open(tmp_path, "wt", encoding="utf-8") as fout:
            fout.write(json.dumps({"__header__": header}) + "\n")
            for batch in tqdm(loader, desc=f" cot-gen {out_path.name}", ncols=100):
                B = len(batch["images"])
                enc = _build_generation_inputs(gen, batch)
                cand_b = _generate_k(gen, enc, k, temperature, top_p, max_new_tokens)
                for b in range(B):
                    if sample_idx >= len(samples_meta):
                        raise IndexError(
                            f"loader yielded more samples than samples_meta "
                            f"({len(samples_meta)}); dataset and metadata are out of sync"
                        )
                    meta = samples_meta[sample_idx]
                    rec = {
                        "idx": sample_idx,
                        "video_id": meta["video_id"],
                        "category": meta["category"],
                        "action_label": int(meta["action_label"]),
                        "candidates": [_short_clean(c) for c in cand_b[b]],
                    }
                    fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
                    sample_idx += 1
    except BaseException:
        # The gzip handle is already closed here; drop the partial file.
        tmp_path.unlink(missing_ok=True)
        raise

    # Path.replace overwrites an existing cache on every platform, whereas
    # Path.rename raises on Windows when the destination exists (--overwrite).
    tmp_path.replace(out_path)

    if sample_idx != len(samples_meta):
        logger.warning(
            f" {out_path.name}: header says n_samples={len(samples_meta)} "
            f"but {sample_idx} records were written"
        )
    logger.info(f" wrote {sample_idx} CoT records \u2192 {out_path}")
def main():
    """Command-line entry point: build one CoT cache file per requested split.

    Splits whose label manifest is missing, or whose cache already exists
    (without ``--overwrite``), are skipped. The VLM is only loaded when at
    least one split actually needs generating.
    """
    ap = argparse.ArgumentParser("make_cot_cache")
    ap.add_argument("--base_model", default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct",
                    help="Base VLM (no LoRA) \u2014 preserves instruction-following.")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--out_dir", default="data/cot_cache")
    ap.add_argument("--k", type=int, default=8,
                    help="Candidates per window (paranoid setting: 8)")
    ap.add_argument("--temperature", type=float, default=0.9)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--max_new_tokens", type=int, default=96)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--splits", nargs="+", default=["val", "train"])
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=8)
    ap.add_argument("--overwrite", action="store_true")
    args = ap.parse_args()

    # Reject values that sampling / DataLoader cannot use now, instead of
    # after the (slow) model load.
    if args.k < 1:
        ap.error("--k must be >= 1")
    if args.batch_size < 1:
        ap.error("--batch_size must be >= 1")
    if args.max_new_tokens < 1:
        ap.error("--max_new_tokens must be >= 1")
    if args.temperature <= 0:
        ap.error("--temperature must be > 0 (sampling is always enabled)")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Decide which splits need work before paying for the model load.
    pending = []
    for split in args.splits:
        label_path = Path(args.label_dir) / f"{split}.json"
        if not label_path.exists():
            logger.warning(f" {label_path} missing \u2014 skip")
            continue
        out_path = out_dir / f"{split}.jsonl.gz"
        if out_path.exists() and not args.overwrite:
            logger.info(f" Cache exists: {out_path} \u2014 skip (use --overwrite)")
            continue
        pending.append((split, label_path, out_path))

    if not pending:
        logger.info(" No split needs generation; model not loaded.")
    else:
        gen = _CoTGenerator(model_name=args.base_model, use_bf16=True)
        logger.info(
            f" CoT generation: K={args.k} T={args.temperature} "
            f"top_p={args.top_p} max_new={args.max_new_tokens}"
        )
        for split, label_path, out_path in pending:
            ds = PolicyDataset(
                manifests=[label_path],
                split=split,
                debug=args.debug,
                debug_samples=args.debug_samples,
            )
            # shuffle=False is required: records are matched to ds.samples
            # by iteration order.
            loader = DataLoader(
                ds,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.num_workers,
                collate_fn=policy_collate_fn,
                pin_memory=True,
            )
            build_split_cache(
                gen, loader, out_path,
                k=args.k,
                temperature=args.temperature,
                top_p=args.top_p,
                max_new_tokens=args.max_new_tokens,
                samples_meta=ds.samples,
            )
    logger.info("\ncot_cache complete.")


if __name__ == "__main__":
    main()
|