# Source: VLAlert/tools/make_cache_x_v2.py
# Uploaded by AsianPlayer — commit 1e05592 "Add VLAlert code" (verified)
"""VLAlert-X v2 Phase 2 — dual-stream cache extractor (leak-free).
For each (video, 8-frame) tick, build a prompt that contains the per-frame
BELIEF reasoning text but NO action tokens (this is the key: GT actions
never enter causal attention, so neither stream can leak them). The
assistant turn is laid out as:
Scene: ... (optional, from manifest)
Critical: ... (optional)
<|BELIEF|> {belief_text_0} </|BELIEF|>
<|BELIEF|> {belief_text_1} </|BELIEF|>
...
<|BELIEF|> {belief_text_7} </|BELIEF|>
Forward through Qwen3-VL-4B (SFT'd, `checkpoints/sft_x_v2/best`) with
`output_hidden_states=True`, then extract two complementary features per frame:
(A) BELIEF_CONTENT[f] "perception/risk-cue register"
= mean-pool hidden states over tokens BETWEEN
the f-th `<|BELIEF|>` and the matching `</|BELIEF|>`,
EXCLUDING the two tags themselves
(this is the default `--pool_mode range`; the other modes are ablations).
Concat hidden_states from layers {20, 24, 28, 32}.
shape: [8, 4 × 2560] = [8, 10240]
(B) POLICY_POSITION[f] "decision-time register"
= hidden state AT the position of the f-th `</|BELIEF|>` closing tag.
Single layer 33.
shape: [8, 2560]
In the SFT format the action follows `</|BELIEF|>`, so the hidden state AT
the closing tag is the one the SFT model used for the next-token prediction
(=action). At that position the model has just finished reading the belief
reasoning and is about to emit the action; the hidden state encodes its
commitment state.
Output cache:
data/belief_cache_v2/{tag}__{split}.pt = {
"ids": list[str] (N,)
"belief_content": tensor [N, 8, 10240] fp16
"policy_position": tensor [N, 8, 2560] fp16
"valid_frames": tensor [N, 8] bool
"actions_pf": tensor [N, 8] long
"danger_pf": tensor [N, 8] fp32
"tta_pf": tensor [N, 8] fp32
"tick_action": tensor [N] long
"tick_tta_raw": tensor [N] fp32
"category": list[str]
"source": list[str]
"video_id": list[str]
"schema": "vlalert_x_v2_dual_pool"
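"belief_layers": [20, 24, 28, 32]
"policy_layer": 33
"pool_mode": str (the --pool_mode used to build belief_content)
"ckpt": str (checkpoint directory the features were extracted with)
}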
Usage:
python tools/make_cache_x_v2.py --split train
python tools/make_cache_x_v2.py --split val
"""
from __future__ import annotations
# PR patch must run BEFORE Qwen3-VL import
import sys
sys.path.insert(0, ".")
from tools import run_train_cot_belief_fast # noqa: F401
import argparse
import json
import logging
import re
import time
from pathlib import Path
from typing import Dict, List, Tuple
import torch
from tqdm import tqdm
# Repository root: this script lives in <ROOT>/tools/.
ROOT = Path(__file__).resolve().parent.parent

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("make_cache_x_v2")

# Action label -> class index stored in the cache's actions_pf / tick_action.
ACTION_NAME_TO_IDX = {
    name: idx for idx, name in enumerate(("SILENT", "OBSERVE", "ALERT"))
}
def build_extraction_assistant(beliefs_per_frame: List[str],
scene: str = "",
critical: str = "") -> str:
"""Same as SFT format_assistant_v2 but ACTION TOKENS REMOVED.
This is the key leak-mitigation: at cache time the prompt has the
belief reasoning content (perception, not decision) wrapped by
`<|BELIEF|>...</|BELIEF|>` and NO `<|ACTION|>` tokens anywhere.
Causal attention cannot leak GT actions because they don't exist.
"""
from training.VLA.cot_belief_dataset_v2 import BELIEF_OPEN, BELIEF_CLOSE
assert len(beliefs_per_frame) == 8
lines: List[str] = []
scene = (scene or "").strip()
critical = (critical or "").strip()
if scene:
lines.append(f"Scene: {scene}")
if critical:
lines.append(f"Critical: {critical}")
if lines:
lines.append("")
for b in beliefs_per_frame:
b_clean = (b or "").strip().replace("\n", " ")
b_clean = " ".join(b_clean.split()[:25])
lines.append(f"{BELIEF_OPEN} {b_clean} {BELIEF_CLOSE}")
return "\n".join(lines)
# Pooling strategies accepted by extract_split (mirrors the --pool_mode CLI).
_POOL_MODES = ("range", "open", "token_mean", "random_span")


def _load_manifest(manifest_path: Path, n_frames: int,
                   limit: int) -> List[Dict]:
    """Return manifest records with ``n_frames`` beliefs and a video path.

    Blank lines, invalid JSON and non-object JSON values are skipped silently.
    When ``limit > 0`` only the first ``limit`` usable records are returned.
    """
    records: List[Dict] = []
    with open(manifest_path, encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            try:
                r = json.loads(ln)
            except json.JSONDecodeError:
                continue
            if not isinstance(r, dict):
                continue
            beliefs = r.get("beliefs_per_frame")
            if (isinstance(beliefs, list) and len(beliefs) == n_frames
                    and r.get("video_path")):
                records.append(r)
    return records[:limit] if limit > 0 else records


def _pool_spans(pool_mode: str, pairs: List[Tuple[int, int]],
                rng_seed: int) -> List[Tuple[int, int]]:
    """Return one half-open token span ``(start, end)`` per BELIEF tag pair.

    ``pairs`` holds ``(open_pos, close_pos)`` token indices for one sample
    (at least one pair); the spans index the same token stream.
    """
    if pool_mode == "range":
        # Tokens strictly between the tags (tags themselves excluded).
        return [(o + 1, c) for o, c in pairs]
    if pool_mode == "open":
        # Single-token pool at the <|BELIEF|> open position.
        return [(o, o + 1) for o, _ in pairs]
    # Response span: first OPEN tag through the last paired CLOSE tag.
    resp_start = pairs[0][0]
    resp_end = pairs[-1][1] + 1
    if pool_mode == "token_mean":
        # Format-agnostic baseline: the same response-wide mean for every frame.
        return [(resp_start, resp_end)] * len(pairs)
    if pool_mode == "random_span":
        # Control: spans of the average inner-BELIEF length (min 3) placed at
        # random positions inside the response, seeded per sample.
        import random
        rng = random.Random(rng_seed)
        inner_lens = [c - (o + 1) for o, c in pairs if c > o + 1]
        span_len = max(3, int(round(sum(inner_lens) / max(len(inner_lens), 1))))
        spans: List[Tuple[int, int]] = []
        for _ in pairs:
            if resp_end - resp_start <= span_len:
                spans.append((resp_start, resp_end))
            else:
                start = rng.randint(resp_start, resp_end - span_len)
                spans.append((start, start + span_len))
        return spans
    raise ValueError(f"unknown pool_mode={pool_mode}")


@torch.no_grad()
def extract_split(ckpt_dir: Path, base_model: Path,
                  manifest_path: Path, out_path: Path,
                  belief_layers: Tuple[int, ...] = (20, 24, 28, 32),
                  policy_layer: int = 33,
                  n_frames: int = 8,
                  limit: int = 0,
                  batch_size: int = 4,
                  pool_mode: str = "range",
                  random_span_seed: int = 0):
    """Run the SFT'd VLM over one manifest split and save the dual-stream cache.

    Returns immediately if ``out_path`` already exists. Otherwise it:

    1. Loads the processor and model, plus a PEFT adapter when
       ``ckpt_dir/adapter_config.json`` exists.
    2. Forwards each record's frames and BELIEF-only assistant text in batches
       with ``output_hidden_states=True``.
    3. Writes the result dict (keys listed in the module docstring) to
       ``out_path`` via a ``.tmp`` file and an atomic rename.

    For the f-th BELIEF tag pair of a sample:
      * ``belief_content[f]``: mean hidden state over a span chosen by
        ``pool_mode`` (see ``_pool_spans``), concatenated over
        ``belief_layers``.
      * ``policy_position[f]``: ``policy_layer`` hidden state AT the f-th
        ``</|BELIEF|>`` token, independent of ``pool_mode``.

    Records whose frame decoding, prompt building or forward pass fails are
    counted as ``failed`` and dropped from the saved cache. Records whose
    tokenised prompt contains no BELIEF tag pair (possibly cut by the
    ``max_length=4096`` truncation) are kept with ``valid_frames`` all False
    but with their real labels.

    Raises:
        ValueError: unknown ``pool_mode`` or ``batch_size < 1``. Both are
            checked before the model is loaded.
    """
    if out_path.exists():
        logger.info(f"[skip] {out_path} exists — delete to re-extract")
        return
    # Validate cheap arguments before the slow model load (a bad pool_mode
    # used to surface only after the first forward pass).
    if pool_mode not in _POOL_MODES:
        raise ValueError(f"unknown pool_mode={pool_mode}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel
    from training.VLA.cot_belief_dataset_v2 import (
        ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE, build_chat_v2,
    )
    from training.VLA.frame_utils import sample_frames

    logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}")
    logger.info(f" belief_layers={belief_layers} policy_layer={policy_layer} "
                f"batch_size={batch_size}")
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    processor.tokenizer.add_special_tokens(
        {"additional_special_tokens": ALL_SPECIAL})
    # Right padding keeps every sample's real tokens at their unbatched
    # positions, so BELIEF tag indices are unaffected by padding.
    processor.tokenizer.padding_side = "right"
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True)
    model.resize_token_embeddings(len(processor.tokenizer))
    if (ckpt_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(model, ckpt_dir)
    model.eval()
    tok = processor.tokenizer
    belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN)
    belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE)
    logger.info(f"[tok] BELIEF_OPEN={belief_open_id} BELIEF_CLOSE={belief_close_id}")

    records = _load_manifest(manifest_path, n_frames, limit)
    logger.info(f"[load] {manifest_path} n={len(records)}")

    # Feature tensors are allocated lazily after the first forward (need D).
    N = len(records)
    n_belief_layers = len(belief_layers)
    out_belief: torch.Tensor | None = None  # [N, n_frames, n_belief_layers * D]
    out_policy: torch.Tensor | None = None  # [N, n_frames, D]
    out_valid = torch.zeros(N, n_frames, dtype=torch.bool)
    out_actions = torch.zeros(N, n_frames, dtype=torch.long)
    out_danger = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tta = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tick_action = torch.zeros(N, dtype=torch.long)
    out_tick_tta = torch.zeros(N, dtype=torch.float32)
    # ids_list[i] stays None for records that failed prep or forward; those
    # rows are dropped at save time.
    ids_list: List[str | None] = [None] * N
    cat_list: List[str] = [""] * N
    src_list: List[str] = [""] * N
    vid_list: List[str] = [""] * N
    n_failed = 0
    n_pool_fallback = 0
    t0 = time.time()

    def _prepare_one(rec):
        """Decode frames and build the chat text for one record.

        Returns ``(frames, full_text)``; exceptions propagate to the caller.
        """
        frames = sample_frames(rec["video_path"], n_frames=n_frames,
                               resize_short=336,
                               frame_indices=rec["frame_indices"])
        assistant_text = build_extraction_assistant(
            rec["beliefs_per_frame"],
            scene=rec.get("scene", ""),
            critical=rec.get("critical", ""),
        )
        full_msgs = build_chat_v2(frames, assistant_text=assistant_text)
        full_text = processor.apply_chat_template(
            full_msgs, tokenize=False, add_generation_prompt=False)
        return frames, full_text

    def _store_metadata(global_i, rec):
        """Fill id/category/source/video_id and per-frame + tick labels."""
        ids_list[global_i] = rec["id"]
        cat_list[global_i] = rec.get("category", "")
        src_list[global_i] = rec.get("source", "")
        vid_list[global_i] = rec.get("video_id", rec["id"])
        out_actions[global_i] = torch.tensor(
            [ACTION_NAME_TO_IDX.get(a, 0) for a in rec["actions_per_frame"]],
            dtype=torch.long)
        out_danger[global_i] = torch.tensor(rec["danger_per_frame"],
                                            dtype=torch.float32)
        out_tta[global_i] = torch.tensor(rec["tta_per_frame"],
                                         dtype=torch.float32)
        out_tick_action[global_i] = ACTION_NAME_TO_IDX.get(
            rec.get("tick_action", "SILENT"), 0)
        out_tick_tta[global_i] = float(rec.get("tick_tta_raw", -1.0))

    # Batched forward for GPU utilisation. Author's note: batch_size=4 on
    # Qwen3-VL-4B + Conv3d->Linear patch gave ~3-4x the batch=1 throughput
    # on an RTX 5090 with <=30 GB VRAM.
    for batch_start in tqdm(range(0, N, batch_size), ncols=80, desc="cache_v2"):
        batch_end = min(N, batch_start + batch_size)
        batch_recs = records[batch_start:batch_end]

        # ── prepare batch (CPU: decode frames + build text) ──
        batch_frames = []
        batch_texts = []
        keep_idx = []  # indices within this batch whose prep succeeded
        for j, rec in enumerate(batch_recs):
            try:
                frames, full_text = _prepare_one(rec)
            except Exception as e:
                # ids_list entry stays None -> row dropped at save time.
                n_failed += 1
                logger.warning(f"[skip] {rec.get('id')}: {e}")
                continue
            batch_frames.append(frames)
            batch_texts.append(full_text)
            keep_idx.append(j)
        if not keep_idx:
            continue

        # ── forward ──
        err_msg = None
        is_oom = False
        try:
            inputs = processor(text=batch_texts, images=batch_frames,
                               return_tensors="pt", padding=True,
                               truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            out = model(**inputs, output_hidden_states=True, return_dict=True)
            hs_tuple = out.hidden_states          # tuple of [B, T, D]
            ids_b_all = inputs["input_ids"]       # [B, T]
            attn_b_all = inputs["attention_mask"]  # [B, T]
            D = hs_tuple[-1].shape[-1]
        except torch.cuda.OutOfMemoryError as e:
            err_msg, is_oom = f"[OOM] batch {batch_start}..{batch_end}: {e}", True
        except Exception as e:
            err_msg = f"[fwd-err] batch {batch_start}..{batch_end}: {e}"
        if err_msg is not None:
            logger.error(err_msg)
            n_failed += len(keep_idx)  # these rows are dropped at save time
            # Release partial GPU tensors; empty_cache() runs after the except
            # block so the exception's traceback no longer pins them.
            inputs = out = hs_tuple = ids_b_all = attn_b_all = None
            if is_oom:
                torch.cuda.empty_cache()
            continue

        if out_belief is None:
            out_belief = torch.zeros(N, n_frames, n_belief_layers * D,
                                     dtype=torch.float16)
            out_policy = torch.zeros(N, n_frames, D, dtype=torch.float16)
            logger.info(f"[alloc] belief shape={tuple(out_belief.shape)} "
                        f"policy shape={tuple(out_policy.shape)}")

        # ── per-sample extraction ──
        for b, j in enumerate(keep_idx):
            global_i = batch_start + j
            rec = batch_recs[j]
            # Labels/metadata are stored for every forwarded record (including
            # pool fallbacks) so no saved row carries placeholder labels.
            _store_metadata(global_i, rec)

            ids_t = ids_b_all[b]
            valid_mask = attn_b_all[b].bool()  # ignore padding positions
            open_pos = ((ids_t == belief_open_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            close_pos = ((ids_t == belief_close_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            n_blocks = min(len(open_pos), len(close_pos), n_frames)
            if n_blocks == 0:
                n_pool_fallback += 1
                continue

            pairs = list(zip(open_pos[:n_blocks], close_pos[:n_blocks]))
            spans = _pool_spans(pool_mode, pairs,
                                int(random_span_seed) * 100003 + global_i)
            belief_concat = torch.zeros(n_blocks, n_belief_layers * D,
                                        dtype=torch.float16)
            policy_vec = torch.zeros(n_blocks, D, dtype=torch.float16)
            for f, ((_, c), (s, e)) in enumerate(zip(pairs, spans)):
                if e <= s:  # empty span (e.g. adjacent tags): leave invalid
                    n_pool_fallback += 1
                    continue
                belief_concat[f] = torch.cat(
                    [hs_tuple[L][b, s:e].mean(dim=0).to(torch.float16)
                     for L in belief_layers],
                    dim=-1).cpu()
                # policy_position is always the hidden state AT the f-th close
                # tag, so a pool_mode ablation only affects belief_content.
                policy_vec[f] = hs_tuple[policy_layer][b, c].to(torch.float16).cpu()
                out_valid[global_i, f] = True
            out_belief[global_i, :n_blocks] = belief_concat
            out_policy[global_i, :n_blocks] = policy_vec

        # Free this batch's activations before the next forward; otherwise
        # every layer's hidden states stay resident while the next batch runs.
        del out, hs_tuple, inputs, ids_b_all, attn_b_all

    # ── keep only successful rows (ids_list entry set) ──
    # MEMORY-SAFE: when every record succeeded (the typical case) the tensors
    # are passed through without copying; otherwise index_select the kept rows.
    keep = [k for k, x in enumerate(ids_list) if x is not None]
    if len(keep) == N:
        belief_save, policy_save = out_belief, out_policy
        valid_save, actions_save = out_valid, out_actions
        danger_save, tta_save = out_danger, out_tta
        tick_action_save, tick_tta_save = out_tick_action, out_tick_tta
    else:
        keep_t = torch.tensor(keep, dtype=torch.long)

        def _rows(t):
            return None if t is None else t.index_select(0, keep_t)

        belief_save, policy_save = _rows(out_belief), _rows(out_policy)
        valid_save, actions_save = _rows(out_valid), _rows(out_actions)
        danger_save, tta_save = _rows(out_danger), _rows(out_tta)
        tick_action_save = _rows(out_tick_action)
        tick_tta_save = _rows(out_tick_tta)
    # Drop the full-size originals before torch.save (avoid 2x peak RAM).
    out_belief = out_policy = None
    out_valid = out_actions = out_danger = out_tta = None
    out_tick_action = out_tick_tta = None
    import gc
    gc.collect()

    out_dict = {
        "ids": [ids_list[k] for k in keep],
        "belief_content": belief_save,
        "policy_position": policy_save,
        "valid_frames": valid_save,
        "actions_pf": actions_save,
        "danger_pf": danger_save,
        "tta_pf": tta_save,
        "tick_action": tick_action_save,
        "tick_tta_raw": tick_tta_save,
        "category": [cat_list[k] for k in keep],
        "source": [src_list[k] for k in keep],
        "video_id": [vid_list[k] for k in keep],
        "schema": "vlalert_x_v2_dual_pool",
        "belief_layers": list(belief_layers),
        "policy_layer": policy_layer,
        "pool_mode": pool_mode,
        "ckpt": str(ckpt_dir),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"[save] writing → {out_path} "
                f"(belief {tuple(belief_save.shape) if belief_save is not None else None}, "
                f"policy {tuple(policy_save.shape) if policy_save is not None else None})")
    # Atomic write: save to .tmp, then rename (no partial file on crash).
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(out_dict, tmp_path)
    tmp_path.replace(out_path)
    dt = time.time() - t0
    logger.info(f"[save] DONE → {out_path}")
    if belief_save is not None:
        logger.info(f" belief_content shape={tuple(belief_save.shape)}")
        logger.info(f" policy_position shape={tuple(policy_save.shape)}")
    logger.info(f" n={len(keep)} failed={n_failed} fallback={n_pool_fallback} "
                f"elapsed={dt:.0f}s ({len(keep)/max(dt,1):.2f} it/s)")
def main():
    """CLI entry point: parse arguments and extract the cache for one split.

    When --manifest is omitted, it defaults to
    ``ROOT/data/cot_corpus_v2/vlalert_x_perframe_v2_{split}.jsonl``.
    The output path is ``{out_dir}/{tag}__{split}.pt``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--split", required=True,
        help="Tag for output filename. Common: train|val|"
             "multisrc_val_full|adasto_val|nexar_test|...")
    add("--manifest", type=Path)
    add("--ckpt", type=Path, default=ROOT / "checkpoints/sft_x_v2/best")
    add("--base_model", type=Path,
        default=ROOT / "models/Qwen3-VL-4B-Instruct")
    add("--tag", default="sft_x_v2")
    add("--out_dir", type=Path, default=ROOT / "data/belief_cache_v2")
    add("--belief_layers", nargs="+", type=int, default=[20, 24, 28, 32])
    add("--policy_layer", type=int, default=33)
    add("--limit", type=int, default=0)
    add("--batch_size", type=int, default=4,
        help="Forward batch size. 4 fits in ~30 GB on RTX 5090 "
             "with Qwen3-VL-4B + Conv3d patch + bf16.")
    # There is no "action" pooling mode: the extraction prompt contains only
    # <|BELIEF|>...</|BELIEF|> spans and never feeds action tokens to the
    # model. Action-position pooling would need a separate extraction prompt.
    add("--pool_mode",
        choices=["range", "open", "token_mean", "random_span"],
        default="range",
        help="How to pool hidden states to form belief_content: "
             "range=mean inside <|BELIEF|>...</|BELIEF|> span (default); "
             "open=hidden at <|BELIEF|> open token; "
             "token_mean=mean over the whole response (format-agnostic); "
             "random_span=same-length span at random positions (control).")
    add("--random_span_seed", type=int, default=0)
    opts = parser.parse_args()

    manifest = opts.manifest
    if manifest is None:
        manifest = ROOT / f"data/cot_corpus_v2/vlalert_x_perframe_v2_{opts.split}.jsonl"

    extract_split(
        ckpt_dir=opts.ckpt,
        base_model=opts.base_model,
        manifest_path=manifest,
        out_path=opts.out_dir / f"{opts.tag}__{opts.split}.pt",
        belief_layers=tuple(opts.belief_layers),
        policy_layer=opts.policy_layer,
        limit=opts.limit,
        batch_size=opts.batch_size,
        pool_mode=opts.pool_mode,
        random_span_seed=opts.random_span_seed,
    )


if __name__ == "__main__":
    main()