"""VLAlert-X v2 Phase 2 — dual-stream cache extractor (leak-free). For each (video, 8-frame) tick, build a prompt that contains the per-frame BELIEF reasoning text but NO action tokens (this is the key: GT actions never enter causal attention so neither stream leaks). Scene: ... (optional, from manifest) Critical: ... (optional) <|BELIEF|> {belief_text_0} <|BELIEF|> {belief_text_1} ... <|BELIEF|> {belief_text_7} Forward through Qwen3-VL-4B (SFT'd, `checkpoints/sft_x_v2/best`) with `output_hidden_states=True`, then extract two complementary features per frame: (A) BELIEF_CONTENT[f] "perception/risk-cue register" = mean-pool hidden states over tokens BETWEEN the f-th `<|BELIEF|>` and the matching ``, EXCLUDING the two tags themselves. Concat hidden_states from layers {20, 24, 28, 32}. shape: [8, 4 × 2560] = [8, 10240] (B) POLICY_POSITION[f] "decision-time register" = hidden state AT the position of the f-th `` closing tag. Single layer 33. shape: [8, 2560] The position right after `` is where the SFT model committed to the next-token prediction (=action). At that position the model has just finished reading the belief reasoning and is about to emit the action; the hidden state encodes its commitment state. Output cache: data/belief_cache_v2/{tag}__{split}.pt = { "ids": list[str] (N,) "belief_content": tensor [N, 8, 10240] fp16 "policy_position": tensor [N, 8, 2560] fp16 "valid_frames": tensor [N, 8] bool "actions_pf": tensor [N, 8] long "danger_pf": tensor [N, 8] fp32 "tta_pf": tensor [N, 8] fp32 "tick_action": tensor [N] long "tick_tta_raw": tensor [N] fp32 "category": list[str] "source": list[str] "video_id": list[str] "schema": "vlalert_x_v2_dual_pool" "belief_layers": [20, 24, 28, 32] "policy_layer": 33 } Usage: python tools/make_cache_x_v2.py --split train python tools/make_cache_x_v2.py --split val """ from __future__ import annotations # PR patch must run BEFORE Qwen3-VL import import sys sys.path.insert(0, ".") from tools import run_train_cot_belief_fast # noqa: F401 import argparse import json import logging import re import time from pathlib import Path from typing import Dict, List, Tuple import torch from tqdm import tqdm ROOT = Path(__file__).resolve().parents[1] logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("make_cache_x_v2") ACTION_NAME_TO_IDX = {"SILENT": 0, "OBSERVE": 1, "ALERT": 2} def build_extraction_assistant(beliefs_per_frame: List[str], scene: str = "", critical: str = "") -> str: """Same as SFT format_assistant_v2 but ACTION TOKENS REMOVED. This is the key leak-mitigation: at cache time the prompt has the belief reasoning content (perception, not decision) wrapped by `<|BELIEF|>...` and NO `<|ACTION|>` tokens anywhere. Causal attention cannot leak GT actions because they don't exist. """ from training.VLA.cot_belief_dataset_v2 import BELIEF_OPEN, BELIEF_CLOSE assert len(beliefs_per_frame) == 8 lines: List[str] = [] scene = (scene or "").strip() critical = (critical or "").strip() if scene: lines.append(f"Scene: {scene}") if critical: lines.append(f"Critical: {critical}") if lines: lines.append("") for b in beliefs_per_frame: b_clean = (b or "").strip().replace("\n", " ") b_clean = " ".join(b_clean.split()[:25]) lines.append(f"{BELIEF_OPEN} {b_clean} {BELIEF_CLOSE}") return "\n".join(lines) @torch.no_grad() def extract_split(ckpt_dir: Path, base_model: Path, manifest_path: Path, out_path: Path, belief_layers: Tuple[int, ...] = (20, 24, 28, 32), policy_layer: int = 33, n_frames: int = 8, limit: int = 0, batch_size: int = 4, pool_mode: str = "range", random_span_seed: int = 0): if out_path.exists(): logger.info(f"[skip] {out_path} exists — delete to re-extract") return from transformers import AutoProcessor, AutoModelForImageTextToText from peft import PeftModel from training.VLA.cot_belief_dataset_v2 import ( ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE, build_chat_v2, ) from training.VLA.frame_utils import sample_frames logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}") logger.info(f" belief_layers={belief_layers} policy_layer={policy_layer} " f"batch_size={batch_size}") processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True) processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL}) # IMPORTANT: right padding so BELIEF token positions stay correct in batched mode processor.tokenizer.padding_side = "right" model = AutoModelForImageTextToText.from_pretrained( base_model, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True) model.resize_token_embeddings(len(processor.tokenizer)) if (ckpt_dir / "adapter_config.json").exists(): model = PeftModel.from_pretrained(model, ckpt_dir) model.eval() tok = processor.tokenizer belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN) belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE) logger.info(f"[tok] BELIEF_OPEN={belief_open_id} BELIEF_CLOSE={belief_close_id}") # ── load manifest ── records: List[Dict] = [] with open(manifest_path) as f: for ln in f: ln = ln.strip() if not ln: continue try: r = json.loads(ln) except json.JSONDecodeError: continue if (isinstance(r.get("beliefs_per_frame"), list) and len(r["beliefs_per_frame"]) == n_frames and r.get("video_path")): records.append(r) if limit > 0: records = records[:limit] logger.info(f"[load] {manifest_path} n={len(records)}") # output tensors (lazy-alloc after first forward to know hidden_dim) N = len(records) n_belief_layers = len(belief_layers) out_belief: torch.Tensor = None # [N, 8, n_belief_layers * D] out_policy: torch.Tensor = None # [N, 8, D] out_valid = torch.zeros(N, n_frames, dtype=torch.bool) out_actions = torch.zeros(N, n_frames, dtype=torch.long) out_danger = torch.zeros(N, n_frames, dtype=torch.float32) out_tta = torch.zeros(N, n_frames, dtype=torch.float32) out_tick_action = torch.zeros(N, dtype=torch.long) out_tick_tta = torch.zeros(N, dtype=torch.float32) ids_list: List[str] = [None] * N cat_list: List[str] = [""] * N src_list: List[str] = [""] * N vid_list: List[str] = [""] * N n_failed = 0 n_pool_fallback = 0 t0 = time.time() def _prepare_one(rec): """Decode frames + build text for a single record. Returns (frames, full_text) or None on failure.""" frames = sample_frames(rec["video_path"], n_frames=n_frames, resize_short=336, frame_indices=rec["frame_indices"]) assistant_text = build_extraction_assistant( rec["beliefs_per_frame"], scene=rec.get("scene", ""), critical=rec.get("critical", ""), ) full_msgs = build_chat_v2(frames, assistant_text=assistant_text) full_text = processor.apply_chat_template( full_msgs, tokenize=False, add_generation_prompt=False) return frames, full_text # Process in batches of `batch_size` for parallel GPU utilisation. # With batch_size=4 on Qwen3-VL-4B + Conv3d→Linear patch, expect ~3-4× the # batch=1 throughput on RTX 5090 with ≤30 GB VRAM. for batch_start in tqdm(range(0, N, batch_size), ncols=80, desc="cache_v2"): batch_end = min(N, batch_start + batch_size) batch_recs = records[batch_start:batch_end] # ── prepare batch (CPU: decode + tokenize text) ── batch_frames = [] batch_texts = [] keep_idx = [] # indices within this batch that succeeded prep for j, rec in enumerate(batch_recs): try: frames, full_text = _prepare_one(rec) batch_frames.append(frames) batch_texts.append(full_text) keep_idx.append(j) except Exception as e: n_failed += 1 logger.warning(f"[skip] {rec.get('id')}: {e}") global_i = batch_start + j ids_list[global_i] = rec.get("id", str(global_i)) if not keep_idx: continue try: # batched tokenisation (right padding, so BELIEF positions stay correct) inputs = processor(text=batch_texts, images=batch_frames, return_tensors="pt", padding=True, truncation=True, max_length=4096) inputs = {k: v.to(model.device) for k, v in inputs.items()} out = model(**inputs, output_hidden_states=True, return_dict=True) hs_tuple = out.hidden_states # tuple of [B, T, D] ids_b_all = inputs["input_ids"] # [B, T] attn_b_all = inputs["attention_mask"] # [B, T] D = hs_tuple[-1].shape[-1] except torch.cuda.OutOfMemoryError as e: logger.error(f"[OOM] batch {batch_start}..{batch_end}: {e}") torch.cuda.empty_cache() n_failed += len(keep_idx) for j in keep_idx: global_i = batch_start + j ids_list[global_i] = batch_recs[j].get("id", str(global_i)) continue except Exception as e: logger.error(f"[fwd-err] batch {batch_start}..{batch_end}: {e}") n_failed += len(keep_idx) for j in keep_idx: global_i = batch_start + j ids_list[global_i] = batch_recs[j].get("id", str(global_i)) continue # ── per-sample extraction ── # lazy-allocate output tensors (need D from first forward) if out_belief is None: out_belief = torch.zeros(N, n_frames, n_belief_layers * D, dtype=torch.float16) out_policy = torch.zeros(N, n_frames, D, dtype=torch.float16) logger.info(f"[alloc] belief shape={tuple(out_belief.shape)} " f"policy shape={tuple(out_policy.shape)}") for b, j in enumerate(keep_idx): global_i = batch_start + j rec = batch_recs[j] ids_t = ids_b_all[b] attn_t = attn_b_all[b] # restrict to valid (non-pad) region valid_mask = attn_t.bool() open_pos = ((ids_t == belief_open_id) & valid_mask).nonzero( as_tuple=False).flatten().tolist() close_pos = ((ids_t == belief_close_id) & valid_mask).nonzero( as_tuple=False).flatten().tolist() n_blocks = min(len(open_pos), len(close_pos), n_frames) if n_blocks == 0: n_pool_fallback += 1 ids_list[global_i] = rec["id"] cat_list[global_i] = rec.get("category", "") src_list[global_i] = rec.get("source", "") vid_list[global_i] = rec.get("video_id", rec["id"]) continue belief_concat = torch.zeros(n_blocks, n_belief_layers * D, dtype=torch.float16) policy_vec = torch.zeros(n_blocks, D, dtype=torch.float16) # Pre-compute pool spans per frame, depending on pool_mode. # For each frame f we need (inner_start, inner_end) on the same # token stream as the original (range) extractor. T_valid = int(valid_mask.sum().item()) pairs_default = list(zip(open_pos[:n_blocks], close_pos[:n_blocks])) if pool_mode == "range": pool_spans = [(o + 1, c) for (o, c) in pairs_default] elif pool_mode == "open": # single-token pool at <|BELIEF|> open position (length-1 span) pool_spans = [(o, o + 1) for (o, c) in pairs_default] elif pool_mode == "token_mean": # Format-agnostic baseline: mean over the assistant-response span # (first OPEN → last CLOSE), replicated across n_blocks frames. resp_start = open_pos[0] resp_end = close_pos[min(len(close_pos), n_blocks) - 1] + 1 pool_spans = [(resp_start, resp_end)] * n_blocks elif pool_mode == "random_span": # Control: spans of same length as the average BELIEF span on # this sample, but at random positions inside the response. import random as _rnd rng = _rnd.Random(int(random_span_seed) * 100003 + global_i) span_lens = [c - (o + 1) for (o, c) in pairs_default if c > o + 1] L_span = max(3, int(round(sum(span_lens) / max(len(span_lens), 1)))) resp_start = open_pos[0] resp_end = close_pos[min(len(close_pos), n_blocks) - 1] + 1 pool_spans = [] for f in range(n_blocks): if resp_end - resp_start <= L_span: pool_spans.append((resp_start, resp_end)) else: s = rng.randint(resp_start, resp_end - L_span) pool_spans.append((s, s + L_span)) else: raise ValueError(f"unknown pool_mode={pool_mode}") for f, ((o, c), (s, e)) in enumerate(zip(pairs_default, pool_spans)): if e <= s: n_pool_fallback += 1 continue parts = [] for L in belief_layers: Lh = hs_tuple[L][b, s:e] parts.append(Lh.mean(dim=0).to(torch.float16)) belief_concat[f] = torch.cat(parts, dim=-1).cpu() # policy_position stays as the hidden state AT the f-th close-tag # so downstream PolicyHead receives the same register regardless # of pool_mode — isolating the ablation to belief_content only. policy_vec[f] = hs_tuple[policy_layer][b, c].to(torch.float16).cpu() out_valid[global_i, f] = True out_belief[global_i, :n_blocks] = belief_concat out_policy[global_i, :n_blocks] = policy_vec ids_list[global_i] = rec["id"] cat_list[global_i] = rec.get("category", "") src_list[global_i] = rec.get("source", "") vid_list[global_i] = rec.get("video_id", rec["id"]) out_actions[global_i] = torch.tensor( [ACTION_NAME_TO_IDX.get(a, 0) for a in rec["actions_per_frame"]], dtype=torch.long) out_danger[global_i] = torch.tensor(rec["danger_per_frame"], dtype=torch.float32) out_tta[global_i] = torch.tensor(rec["tta_per_frame"], dtype=torch.float32) out_tick_action[global_i] = ACTION_NAME_TO_IDX.get( rec.get("tick_action", "SILENT"), 0) out_tick_tta[global_i] = float(rec.get("tick_tta_raw", -1.0)) # keep only successful entries (non-empty id) # MEMORY-SAFE: avoid fancy-index COPY of 30 GB belief tensor that OOM-kills the # process at save time. If all records succeeded (the typical case), pass # tensors through directly. Else use torch.index_select which is memory- # equivalent to fancy indexing but cleaner to free. keep = [k for k, x in enumerate(ids_list) if x is not None] all_valid = (len(keep) == N) if all_valid: belief_save = out_belief policy_save = out_policy valid_save = out_valid actions_save = out_actions danger_save = out_danger tta_save = out_tta tick_action_save = out_tick_action tick_tta_save = out_tick_tta else: keep_t = torch.tensor(keep, dtype=torch.long) belief_save = (out_belief.index_select(0, keep_t) if out_belief is not None else None) policy_save = (out_policy.index_select(0, keep_t) if out_policy is not None else None) valid_save = out_valid.index_select(0, keep_t) actions_save = out_actions.index_select(0, keep_t) danger_save = out_danger.index_select(0, keep_t) tta_save = out_tta.index_select(0, keep_t) tick_action_save = out_tick_action.index_select(0, keep_t) tick_tta_save = out_tick_tta.index_select(0, keep_t) # Free the original full tensors before torch.save (avoid 2x peak RAM) out_belief = out_policy = None out_valid = out_actions = out_danger = out_tta = None out_tick_action = out_tick_tta = None import gc; gc.collect() out_dict = { "ids": [ids_list[k] for k in keep], "belief_content": belief_save, "policy_position": policy_save, "valid_frames": valid_save, "actions_pf": actions_save, "danger_pf": danger_save, "tta_pf": tta_save, "tick_action": tick_action_save, "tick_tta_raw": tick_tta_save, "category": [cat_list[k] for k in keep], "source": [src_list[k] for k in keep], "video_id": [vid_list[k] for k in keep], "schema": "vlalert_x_v2_dual_pool", "belief_layers": list(belief_layers), "policy_layer": policy_layer, "pool_mode": pool_mode, "ckpt": str(ckpt_dir), } out_path.parent.mkdir(parents=True, exist_ok=True) logger.info(f"[save] writing → {out_path} " f"(belief {tuple(belief_save.shape) if belief_save is not None else None}, " f"policy {tuple(policy_save.shape) if policy_save is not None else None})") # Atomic write: save to .tmp then rename (avoids partial files on crash) tmp_path = out_path.with_suffix(out_path.suffix + ".tmp") torch.save(out_dict, tmp_path) import os os.replace(str(tmp_path), str(out_path)) dt = time.time() - t0 logger.info(f"[save] DONE → {out_path}") if belief_save is not None: logger.info(f" belief_content shape={tuple(belief_save.shape)}") logger.info(f" policy_position shape={tuple(policy_save.shape)}") logger.info(f" n={len(keep)} failed={n_failed} fallback={n_pool_fallback} " f"elapsed={dt:.0f}s ({len(keep)/max(dt,1):.2f} it/s)") def main(): ap = argparse.ArgumentParser() ap.add_argument("--split", required=True, help="Tag for output filename. Common: train|val|" "multisrc_val_full|adasto_val|nexar_test|...") ap.add_argument("--manifest", type=Path) ap.add_argument("--ckpt", type=Path, default=ROOT / "checkpoints/sft_x_v2/best") ap.add_argument("--base_model", type=Path, default=ROOT / "models/Qwen3-VL-4B-Instruct") ap.add_argument("--tag", default="sft_x_v2") ap.add_argument("--out_dir", type=Path, default=ROOT / "data/belief_cache_v2") ap.add_argument("--belief_layers", nargs="+", type=int, default=[20, 24, 28, 32]) ap.add_argument("--policy_layer", type=int, default=33) ap.add_argument("--limit", type=int, default=0) ap.add_argument("--batch_size", type=int, default=4, help="Forward batch size. 4 fits in ~30 GB on RTX 5090 " "with Qwen3-VL-4B + Conv3d patch + bf16.") ap.add_argument("--pool_mode", choices=["range", "open", "token_mean", "random_span"], default="range", # Note: "action" mode is not supported here because the # extraction prompt only contains <|BELIEF|>... # spans (no action tokens fed to the model). Add a separate # extraction prompt if you want action-position pooling. help="How to pool hidden states to form belief_content: " "range=mean inside <|BELIEF|>... span (default); " "open=hidden at <|BELIEF|> open token; " "token_mean=mean over the whole response (format-agnostic); " "random_span=same-length span at random positions (control).") ap.add_argument("--random_span_seed", type=int, default=0) args = ap.parse_args() if args.manifest is None: args.manifest = ROOT / f"data/cot_corpus_v2/vlalert_x_perframe_v2_{args.split}.jsonl" out_path = args.out_dir / f"{args.tag}__{args.split}.pt" extract_split(ckpt_dir=args.ckpt, base_model=args.base_model, manifest_path=args.manifest, out_path=out_path, belief_layers=tuple(args.belief_layers), policy_layer=args.policy_layer, limit=args.limit, batch_size=args.batch_size, pool_mode=args.pool_mode, random_span_seed=args.random_span_seed) if __name__ == "__main__": main()