| """VLAlert-X v2 Phase 2 — dual-stream cache extractor (leak-free). |
| |
| For each (video, 8-frame) tick, build a prompt that contains the per-frame |
| BELIEF reasoning text but NO action tokens (this is the key: GT actions |
| never enter causal attention so neither stream leaks). |
| |
| Scene: ... (optional, from manifest) |
| Critical: ... (optional) |
| <|BELIEF|> {belief_text_0} </|BELIEF|> |
| <|BELIEF|> {belief_text_1} </|BELIEF|> |
| ... |
| <|BELIEF|> {belief_text_7} </|BELIEF|> |
| |
| Forward through Qwen3-VL-4B (SFT'd, `checkpoints/sft_x_v2/best`) with |
| `output_hidden_states=True`, then extract two complementary features per frame: |
| |
| (A) BELIEF_CONTENT[f] "perception/risk-cue register" |
| = mean-pool hidden states over tokens BETWEEN |
| the f-th `<|BELIEF|>` and the matching `</|BELIEF|>`, |
| EXCLUDING the two tags themselves. |
| Concat hidden_states from layers {20, 24, 28, 32}. |
| shape: [8, 4 × 2560] = [8, 10240] |
| |
| (B) POLICY_POSITION[f] "decision-time register" |
| = hidden state AT the position of the f-th `</|BELIEF|>` closing tag. |
| Single layer 33. |
| shape: [8, 2560] |
| |
| Note that (B) is read AT the `</|BELIEF|>` token itself, not one step later: |
| in a causal LM the hidden state at the closing tag is the one that produces |
| the next-token prediction, and in the SFT format that next token is the |
| action. At that position the model has just finished reading the belief |
| reasoning and is about to emit the action; the hidden state encodes its |
| commitment state. |
| |
| Output cache: |
| data/belief_cache_v2/{tag}__{split}.pt = { |
| "ids": list[str] (N,) |
| "belief_content": tensor [N, 8, 10240] fp16 |
| "policy_position": tensor [N, 8, 2560] fp16 |
| "valid_frames": tensor [N, 8] bool |
| "actions_pf": tensor [N, 8] long |
| "danger_pf": tensor [N, 8] fp32 |
| "tta_pf": tensor [N, 8] fp32 |
| "tick_action": tensor [N] long |
| "tick_tta_raw": tensor [N] fp32 |
| "category": list[str] |
| "source": list[str] |
| "video_id": list[str] |
| "schema": "vlalert_x_v2_dual_pool" |
| "belief_layers": [20, 24, 28, 32] |
| "policy_layer": 33 |
| } |
| |
| Usage: |
| python tools/make_cache_x_v2.py --split train |
| python tools/make_cache_x_v2.py --split val |
| """ |
| from __future__ import annotations |
|
|
| |
| import sys |
| sys.path.insert(0, ".") |
| from tools import run_train_cot_belief_fast |
|
|
| import argparse |
| import json |
| import logging |
| import re |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import torch |
| from tqdm import tqdm |
|
|
# Two directory levels above this file; default checkpoint/data paths are
# resolved against it.
ROOT = Path(__file__).resolve().parent.parent

# Process-wide logging setup: timestamp, level and message at INFO verbosity.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("make_cache_x_v2")

# Action label -> class index used for the cached per-frame / per-tick targets.
ACTION_NAME_TO_IDX = dict(zip(("SILENT", "OBSERVE", "ALERT"), range(3)))
|
|
|
|
def build_extraction_assistant(beliefs_per_frame: List[str],
                               scene: str = "",
                               critical: str = "",
                               n_frames: int = 8) -> str:
    """Build the assistant text used at cache time, with NO action tokens.

    This is the key leak-mitigation: the text holds only the belief
    reasoning content (perception, not decision), each frame wrapped as
    ``BELIEF_OPEN ... BELIEF_CLOSE``, and no action tokens anywhere.
    Causal attention cannot leak GT actions because they are never present.

    Args:
        beliefs_per_frame: One belief string per frame. ``None`` or empty
            entries are rendered as an empty belief block.
        scene: Optional scene description, emitted as a ``Scene:`` line.
        critical: Optional critical-object description, emitted as a
            ``Critical:`` line.
        n_frames: Required number of belief entries (default 8).

    Returns:
        The newline-joined assistant text: optional ``Scene:`` and
        ``Critical:`` header lines, a blank separator line if either header
        is present, then one belief line per frame.

    Raises:
        ValueError: If ``beliefs_per_frame`` does not have exactly
            ``n_frames`` entries.
    """
    # Validate with a real exception (asserts vanish under ``python -O``) and
    # do it before the project import so bad input fails immediately.
    if len(beliefs_per_frame) != n_frames:
        raise ValueError(
            f"expected {n_frames} beliefs, got {len(beliefs_per_frame)}")

    from training.VLA.cot_belief_dataset_v2 import BELIEF_OPEN, BELIEF_CLOSE

    max_belief_words = 25  # per-frame belief text is capped at 25 words

    lines: List[str] = []
    scene = (scene or "").strip()
    critical = (critical or "").strip()
    if scene:
        lines.append(f"Scene: {scene}")
    if critical:
        lines.append(f"Critical: {critical}")
    if lines:
        # Blank line separates the header from the belief blocks.
        lines.append("")

    for belief in beliefs_per_frame:
        # split() already discards newlines and leading/trailing whitespace,
        # so joining the first words yields a single clean line.
        words = (belief or "").split()[:max_belief_words]
        lines.append(f"{BELIEF_OPEN} {' '.join(words)} {BELIEF_CLOSE}")
    return "\n".join(lines)
|
|
|
|
# Pooling strategies accepted by extract_split(pool_mode=...).
_POOL_MODES = ("range", "open", "token_mean", "random_span")


def _load_manifest_records(manifest_path: Path, n_frames: int,
                           limit: int = 0) -> List[Dict]:
    """Read the JSONL manifest and return the records usable for extraction.

    A record is kept when it is a JSON object with a ``beliefs_per_frame``
    list of exactly ``n_frames`` entries and a truthy ``video_path``.
    Undecodable lines and unusable records are skipped, counted and reported
    once. When ``limit > 0`` reading stops after ``limit`` usable records.
    """
    records: List[Dict] = []
    n_bad_json = 0
    n_unusable = 0
    with open(manifest_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                n_bad_json += 1
                continue
            # A line may decode to a non-object (list, number, ...); treat it
            # as unusable instead of crashing on ``.get``.
            beliefs = rec.get("beliefs_per_frame") if isinstance(rec, dict) else None
            if (isinstance(beliefs, list) and len(beliefs) == n_frames
                    and rec.get("video_path")):
                records.append(rec)
                if 0 < limit <= len(records):
                    break
            else:
                n_unusable += 1
    if n_bad_json or n_unusable:
        logger.warning(f"[load] {manifest_path}: skipped {n_bad_json} undecodable "
                       f"line(s) and {n_unusable} unusable record(s)")
    return records


def _pool_spans(pool_mode: str, pairs: List[Tuple[int, int]],
                span_seed: int) -> List[Tuple[int, int]]:
    """Return one half-open ``[start, end)`` pooling span per tag pair.

    ``pairs`` is the non-empty list of ``(open_pos, close_pos)`` token
    positions of the belief blocks in one sequence.

    * ``range``       tokens strictly between the open and close tags.
    * ``open``        the open-tag token only.
    * ``token_mean``  the whole response (first open tag .. last close tag),
                      repeated for every block.
    * ``random_span`` a span of the mean in-block length (at least 3) placed
                      at a random offset inside the response; seeded by
                      ``span_seed`` so results are reproducible.

    Raises:
        ValueError: If ``pool_mode`` is not one of the modes above.
    """
    if pool_mode == "range":
        return [(o + 1, c) for (o, c) in pairs]
    if pool_mode == "open":
        return [(o, o + 1) for (o, _c) in pairs]

    resp_start = pairs[0][0]
    resp_end = pairs[-1][1] + 1
    if pool_mode == "token_mean":
        return [(resp_start, resp_end)] * len(pairs)
    if pool_mode == "random_span":
        import random

        rng = random.Random(span_seed)
        span_lens = [c - (o + 1) for (o, c) in pairs if c > o + 1]
        L_span = max(3, int(round(sum(span_lens) / max(len(span_lens), 1))))
        spans: List[Tuple[int, int]] = []
        for _ in pairs:
            if resp_end - resp_start <= L_span:
                # Response shorter than the target length: use all of it.
                spans.append((resp_start, resp_end))
            else:
                s = rng.randint(resp_start, resp_end - L_span)
                spans.append((s, s + L_span))
        return spans
    raise ValueError(f"unknown pool_mode={pool_mode}")


@torch.no_grad()
def extract_split(ckpt_dir: Path, base_model: Path,
                  manifest_path: Path, out_path: Path,
                  belief_layers: Tuple[int, ...] = (20, 24, 28, 32),
                  policy_layer: int = 33,
                  n_frames: int = 8,
                  limit: int = 0,
                  batch_size: int = 4,
                  pool_mode: str = "range",
                  random_span_seed: int = 0):
    """Extract the dual-stream feature cache for one manifest and save it.

    For every usable manifest record the frames are decoded, an action-free
    prompt is built, and the model is run with ``output_hidden_states=True``.
    Per belief block two features are stored:

    * ``belief_content``: mean of the hidden states over the pooling span
      (see ``pool_mode``), concatenated across ``belief_layers``.
    * ``policy_position``: hidden state of ``policy_layer`` at the closing
      belief tag.

    Records that fail (frame decoding, forward pass, no belief tags found,
    malformed labels) are kept as rows with ``valid_frames`` all ``False`` so
    the row order matches the manifest order.

    Args:
        ckpt_dir: Directory of the fine-tuned checkpoint. A PEFT adapter is
            loaded from it only if it contains ``adapter_config.json``.
        base_model: Path of the base model / processor.
        manifest_path: JSONL manifest, one record per line.
        out_path: Destination ``.pt`` file. If it already exists nothing is
            done.
        belief_layers: Indices into the returned ``hidden_states`` tuple used
            for ``belief_content``.
        policy_layer: Index into ``hidden_states`` used for
            ``policy_position``.
        n_frames: Frames (belief blocks) per record.
            NOTE: ``build_extraction_assistant`` is called without a frame
            count and requires 8 beliefs by default, so other values make
            every record fail at prepare time.
        limit: If > 0, use only the first ``limit`` usable records.
        batch_size: Number of records per forward pass (>= 1).
        pool_mode: One of ``range``, ``open``, ``token_mean``,
            ``random_span``.
        random_span_seed: Base seed for ``pool_mode="random_span"``.

    Raises:
        ValueError: On an unknown ``pool_mode`` or ``batch_size < 1``.
        IndexError: If a requested layer index is outside the hidden-state
            tuple returned by the model.
    """
    if out_path.exists():
        logger.info(f"[skip] {out_path} exists — delete to re-extract")
        return

    # Fail fast on bad arguments and on an unreadable manifest, before the
    # expensive model load.
    if pool_mode not in _POOL_MODES:
        raise ValueError(f"unknown pool_mode={pool_mode}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    records = _load_manifest_records(manifest_path, n_frames, limit)
    logger.info(f"[load] {manifest_path} n={len(records)}")

    import gc
    import os

    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel
    from training.VLA.cot_belief_dataset_v2 import (
        ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE, build_chat_v2,
    )
    from training.VLA.frame_utils import sample_frames

    logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}")
    logger.info(f" belief_layers={belief_layers} policy_layer={policy_layer} "
                f"batch_size={batch_size}")
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL})
    # Right padding keeps token positions of real tokens identical across the
    # batch, so tag positions found below index the hidden states directly.
    processor.tokenizer.padding_side = "right"
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True)
    model.resize_token_embeddings(len(processor.tokenizer))
    if (ckpt_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(model, ckpt_dir)
    model.eval()

    tok = processor.tokenizer
    belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN)
    belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE)
    logger.info(f"[tok] BELIEF_OPEN={belief_open_id} BELIEF_CLOSE={belief_close_id}")

    # Output buffers. The two feature tensors are allocated lazily, once the
    # hidden size D is known from the first successful forward pass.
    N = len(records)
    n_belief_layers = len(belief_layers)
    out_belief = None
    out_policy = None
    out_valid = torch.zeros(N, n_frames, dtype=torch.bool)
    out_actions = torch.zeros(N, n_frames, dtype=torch.long)
    out_danger = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tta = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tick_action = torch.zeros(N, dtype=torch.long)
    out_tick_tta = torch.zeros(N, dtype=torch.float32)
    ids_list = [None] * N
    cat_list: List[str] = [""] * N
    src_list: List[str] = [""] * N
    vid_list: List[str] = [""] * N

    n_failed = 0
    n_pool_fallback = 0
    t0 = time.time()

    def _prepare_one(rec):
        """Decode frames and build the chat text for a single record.

        Returns ``(frames, full_text)``; any failure propagates as an
        exception and is handled by the caller.
        """
        frames = sample_frames(rec["video_path"], n_frames=n_frames,
                               resize_short=336,
                               frame_indices=rec["frame_indices"])
        assistant_text = build_extraction_assistant(
            rec["beliefs_per_frame"],
            scene=rec.get("scene", ""),
            critical=rec.get("critical", ""),
        )
        full_msgs = build_chat_v2(frames, assistant_text=assistant_text)
        full_text = processor.apply_chat_template(
            full_msgs, tokenize=False, add_generation_prompt=False)
        return frames, full_text

    for batch_start in tqdm(range(0, N, batch_size), ncols=80, desc="cache_v2"):
        batch_end = min(N, batch_start + batch_size)
        batch_recs = records[batch_start:batch_end]

        # --- per-record preparation; failures only drop that record -------
        batch_frames = []
        batch_texts = []
        keep_idx = []
        for j, rec in enumerate(batch_recs):
            try:
                frames, full_text = _prepare_one(rec)
                batch_frames.append(frames)
                batch_texts.append(full_text)
                keep_idx.append(j)
            except Exception as e:
                n_failed += 1
                logger.warning(f"[skip] {rec.get('id')}: {e}")
                global_i = batch_start + j
                ids_list[global_i] = rec.get("id", str(global_i))

        if not keep_idx:
            continue

        # --- batched forward pass ------------------------------------------
        try:
            inputs = processor(text=batch_texts, images=batch_frames,
                               return_tensors="pt", padding=True,
                               truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            out = model(**inputs, output_hidden_states=True, return_dict=True)
            hs_tuple = out.hidden_states
            ids_b_all = inputs["input_ids"]
            attn_b_all = inputs["attention_mask"]
            D = hs_tuple[-1].shape[-1]
        except Exception as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                logger.error(f"[OOM] batch {batch_start}..{batch_end}: {e}")
                # Drop our reference to the batch inputs before asking the
                # allocator to release cached blocks.
                inputs = None
                torch.cuda.empty_cache()
            else:
                logger.error(f"[fwd-err] batch {batch_start}..{batch_end}: {e}")
            n_failed += len(keep_idx)
            for j in keep_idx:
                global_i = batch_start + j
                ids_list[global_i] = batch_recs[j].get("id", str(global_i))
            continue

        if out_belief is None:
            # Validate the requested layers once, with a readable message
            # instead of a bare "tuple index out of range".
            n_states = len(hs_tuple)
            for layer in (*belief_layers, policy_layer):
                if not -n_states <= layer < n_states:
                    raise IndexError(
                        f"layer index {layer} out of range: model returned "
                        f"{n_states} hidden states")
            out_belief = torch.zeros(N, n_frames, n_belief_layers * D,
                                     dtype=torch.float16)
            out_policy = torch.zeros(N, n_frames, D, dtype=torch.float16)
            logger.info(f"[alloc] belief shape={tuple(out_belief.shape)} "
                        f"policy shape={tuple(out_policy.shape)}")

        # --- per-record feature extraction ----------------------------------
        for b, j in enumerate(keep_idx):
            global_i = batch_start + j
            rec = batch_recs[j]

            # Same id fallback as the failure paths above, so a record
            # without an "id" no longer aborts the whole run.
            rec_id = rec.get("id", str(global_i))
            ids_list[global_i] = rec_id
            cat_list[global_i] = rec.get("category", "")
            src_list[global_i] = rec.get("source", "")
            vid_list[global_i] = rec.get("video_id", rec_id)

            # Locate belief tags among non-padding tokens only.
            ids_t = ids_b_all[b]
            valid_mask = attn_b_all[b].bool()
            open_pos = ((ids_t == belief_open_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            close_pos = ((ids_t == belief_close_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            n_blocks = min(len(open_pos), len(close_pos), n_frames)

            if n_blocks == 0:
                # No complete belief block (e.g. tags lost to truncation):
                # keep the row, but with no valid frames and zero labels.
                n_pool_fallback += 1
                continue

            # Tags are paired in order of appearance.
            pairs = list(zip(open_pos[:n_blocks], close_pos[:n_blocks]))
            pool_spans = _pool_spans(
                pool_mode, pairs, int(random_span_seed) * 100003 + global_i)

            belief_concat = torch.zeros(n_blocks, n_belief_layers * D,
                                        dtype=torch.float16)
            policy_vec = torch.zeros(n_blocks, D, dtype=torch.float16)
            for f, ((_o, c), (s, e)) in enumerate(zip(pairs, pool_spans)):
                if e <= s:
                    # Empty span: leave this frame invalid.
                    n_pool_fallback += 1
                    continue
                parts = [hs_tuple[layer][b, s:e].mean(dim=0).to(torch.float16)
                         for layer in belief_layers]
                belief_concat[f] = torch.cat(parts, dim=-1).cpu()
                # Decision-time register: hidden state AT the closing tag.
                policy_vec[f] = hs_tuple[policy_layer][b, c].to(torch.float16).cpu()
                out_valid[global_i, f] = True

            out_belief[global_i, :n_blocks] = belief_concat
            out_policy[global_i, :n_blocks] = policy_vec

            # Labels. A malformed record is handled like every other
            # per-record failure: logged, counted, row flagged invalid.
            try:
                actions_t = torch.tensor(
                    [ACTION_NAME_TO_IDX.get(a, 0) for a in rec["actions_per_frame"]],
                    dtype=torch.long)
                danger_t = torch.tensor(rec["danger_per_frame"],
                                        dtype=torch.float32)
                tta_t = torch.tensor(rec["tta_per_frame"],
                                     dtype=torch.float32)
                tick_action_idx = ACTION_NAME_TO_IDX.get(
                    rec.get("tick_action", "SILENT"), 0)
                tick_tta = float(rec.get("tick_tta_raw", -1.0))
                out_actions[global_i] = actions_t
                out_danger[global_i] = danger_t
                out_tta[global_i] = tta_t
            except (KeyError, TypeError, ValueError, RuntimeError) as e:
                n_failed += 1
                logger.warning(f"[skip] {rec_id}: bad labels ({e!r})")
                out_valid[global_i] = False
                out_actions[global_i] = 0
                out_danger[global_i] = 0
                out_tta[global_i] = 0
                continue
            out_tick_action[global_i] = tick_action_idx
            out_tick_tta[global_i] = tick_tta

        # Drop references to this batch's activations so they are not kept
        # alive while the next forward pass allocates its own.
        out = hs_tuple = inputs = ids_b_all = attn_b_all = None

    # Rows whose id was never assigned are dropped. Every path above assigns
    # an id, so this is a safety net that normally keeps all N rows.
    keep = [k for k, x in enumerate(ids_list) if x is not None]
    all_valid = (len(keep) == N)

    if all_valid:
        belief_save = out_belief
        policy_save = out_policy
        valid_save = out_valid
        actions_save = out_actions
        danger_save = out_danger
        tta_save = out_tta
        tick_action_save = out_tick_action
        tick_tta_save = out_tick_tta
    else:
        keep_t = torch.tensor(keep, dtype=torch.long)
        belief_save = (out_belief.index_select(0, keep_t)
                       if out_belief is not None else None)
        policy_save = (out_policy.index_select(0, keep_t)
                       if out_policy is not None else None)
        valid_save = out_valid.index_select(0, keep_t)
        actions_save = out_actions.index_select(0, keep_t)
        danger_save = out_danger.index_select(0, keep_t)
        tta_save = out_tta.index_select(0, keep_t)
        tick_action_save = out_tick_action.index_select(0, keep_t)
        tick_tta_save = out_tick_tta.index_select(0, keep_t)
        # index_select made copies; release the full-size originals.
        out_belief = out_policy = None
        out_valid = out_actions = out_danger = out_tta = None
        out_tick_action = out_tick_tta = None
        gc.collect()

    out_dict = {
        "ids": [ids_list[k] for k in keep],
        "belief_content": belief_save,
        "policy_position": policy_save,
        "valid_frames": valid_save,
        "actions_pf": actions_save,
        "danger_pf": danger_save,
        "tta_pf": tta_save,
        "tick_action": tick_action_save,
        "tick_tta_raw": tick_tta_save,
        "category": [cat_list[k] for k in keep],
        "source": [src_list[k] for k in keep],
        "video_id": [vid_list[k] for k in keep],
        "schema": "vlalert_x_v2_dual_pool",
        "belief_layers": list(belief_layers),
        "policy_layer": policy_layer,
        "pool_mode": pool_mode,
        "ckpt": str(ckpt_dir),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"[save] writing → {out_path} "
                f"(belief {tuple(belief_save.shape) if belief_save is not None else None}, "
                f"policy {tuple(policy_save.shape) if policy_save is not None else None})")
    # Write to a temporary file and rename, so an interrupted save never
    # leaves a partial file at out_path (which would be skipped next run).
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(out_dict, tmp_path)
    os.replace(str(tmp_path), str(out_path))
    dt = time.time() - t0
    logger.info(f"[save] DONE → {out_path}")
    if belief_save is not None:
        logger.info(f" belief_content shape={tuple(belief_save.shape)}")
        logger.info(f" policy_position shape={tuple(policy_save.shape)}")
    logger.info(f" n={len(keep)} failed={n_failed} fallback={n_pool_fallback} "
                f"elapsed={dt:.0f}s ({len(keep)/max(dt,1):.2f} it/s)")
|
|
|
|
def main():
    """Command-line entry point: parse options and extract one split's cache.

    The manifest defaults to the per-split corpus file under ``ROOT`` and the
    output file is ``<out_dir>/<tag>__<split>.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--split", required=True,
        help="Tag for output filename. Common: train|val|"
             "multisrc_val_full|adasto_val|nexar_test|...")
    parser.add_argument("--manifest", type=Path)
    parser.add_argument(
        "--ckpt", type=Path,
        default=ROOT / "checkpoints/sft_x_v2/best")
    parser.add_argument(
        "--base_model", type=Path,
        default=ROOT / "models/Qwen3-VL-4B-Instruct")
    parser.add_argument("--tag", default="sft_x_v2")
    parser.add_argument(
        "--out_dir", type=Path,
        default=ROOT / "data/belief_cache_v2")
    parser.add_argument(
        "--belief_layers", nargs="+", type=int,
        default=[20, 24, 28, 32])
    parser.add_argument("--policy_layer", type=int, default=33)
    parser.add_argument("--limit", type=int, default=0)
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Forward batch size. 4 fits in ~30 GB on RTX 5090 "
             "with Qwen3-VL-4B + Conv3d patch + bf16.")
    parser.add_argument(
        "--pool_mode",
        choices=["range", "open", "token_mean", "random_span"],
        default="range",
        help="How to pool hidden states to form belief_content: "
             "range=mean inside <|BELIEF|>...</|BELIEF|> span (default); "
             "open=hidden at <|BELIEF|> open token; "
             "token_mean=mean over the whole response (format-agnostic); "
             "random_span=same-length span at random positions (control).")
    parser.add_argument("--random_span_seed", type=int, default=0)
    opts = parser.parse_args()

    # Fall back to the conventional per-split manifest when none is given.
    manifest = opts.manifest
    if manifest is None:
        manifest = ROOT / f"data/cot_corpus_v2/vlalert_x_perframe_v2_{opts.split}.jsonl"
    destination = opts.out_dir / f"{opts.tag}__{opts.split}.pt"

    extract_split(
        ckpt_dir=opts.ckpt,
        base_model=opts.base_model,
        manifest_path=manifest,
        out_path=destination,
        belief_layers=tuple(opts.belief_layers),
        policy_layer=opts.policy_layer,
        limit=opts.limit,
        batch_size=opts.batch_size,
        pool_mode=opts.pool_mode,
        random_span_seed=opts.random_span_seed,
    )


# Run the extractor only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|