VLAlert / tools /verify_patch_embed_correctness.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
"""Rigorous correctness check for the Conv3d → Linear replacement.

Tests three things:
1. fp32 equivalence: should be < 1e-6 (proves math is identical); the pass
   threshold enforced by the script is 1e-5.
2. bf16 numerical error: max abs + max relative + mean relative
3. Downstream belief output diff: full vision tower forward, Conv3d vs Linear

If the fp32 diff is < 1e-6, the math is provably equivalent.
If the downstream belief cosine similarity is > 0.9999, the head will see no
difference; the pass threshold enforced by the script is a mean-pooled cosine
similarity > 0.99.
"""
import sys
sys.path.insert(0, ".")
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
from training.Policy.policy_dataset import PolicyDataset, _load_frames
from training.Policy import make_cot_belief_cache as M
def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear:
    """Build a Linear layer equivalent to ``conv`` applied to single patches.

    The convolution weight is reshaped from ``(out, in, kT, kH, kW)`` to
    ``(out, in * kT * kH * kW)``. The result therefore matches ``conv`` when
    every input is exactly one kernel-sized patch, flattened in
    ``(C, T, H, W)`` order.

    Args:
        conv: Source convolution. It must use no padding, no dilation and a
            single group.

    Returns:
        A new ``nn.Linear`` on the same device and with the same dtype as
        ``conv.weight``, holding copies of the weight (and the bias, when the
        convolution has one).

    Raises:
        ValueError: If ``conv`` uses padding, dilation or groups, because the
            flattened kernel is then not an equivalent projection.
    """
    # Padding and dilation change which input elements each output sees, and
    # groups change the weight layout; a plain reshape of the kernel would no
    # longer be equivalent (or would fail with an opaque shape error).
    padding = conv.padding
    if isinstance(padding, str):
        is_padded = padding != "valid"
    else:
        is_padded = any(p != 0 for p in padding)
    is_dilated = any(d != 1 for d in conv.dilation)
    if is_padded or is_dilated or conv.groups != 1:
        raise ValueError(
            "conv3d_to_linear requires padding=0, dilation=1 and groups=1; "
            f"got padding={conv.padding}, dilation={conv.dilation}, "
            f"groups={conv.groups}"
        )

    k_t, k_h, k_w = conv.kernel_size
    out_dim = conv.out_channels
    in_dim = conv.in_channels * k_t * k_h * k_w
    has_bias = conv.bias is not None

    linear = nn.Linear(in_dim, out_dim, bias=has_bias)
    # Copy under no_grad instead of writing through ``.data``; copy_ handles
    # the device/dtype difference between the fresh layer and ``conv``.
    with torch.no_grad():
        linear.weight.copy_(conv.weight.reshape(out_dim, in_dim))
        if has_bias:
            linear.bias.copy_(conv.bias)
    return linear.to(device=conv.weight.device, dtype=conv.weight.dtype)
def test_fp32_equivalence(conv: nn.Conv3d) -> bool:
    """Check that ``conv`` and its Linear replacement agree in fp32 on CPU.

    Feeds 100 random single-patch inputs through an fp32 CPU copy of ``conv``
    and through the layer returned by ``conv3d_to_linear``, then prints the
    max/mean absolute and max relative differences.

    Note: calls ``torch.manual_seed(0)``, so the global RNG state is reset.

    Args:
        conv: Convolution to verify. It is deep-copied first, so the caller's
            module keeps its own device and dtype.

    Returns:
        True when the max absolute difference is below 1e-5, else False.
    """
    import copy  # local import keeps this function self-contained

    print("\n[Test 1] fp32 equivalence (math correctness)")
    # Module.float()/.cpu() convert in place and return self; working on a
    # copy avoids silently changing the module the caller passed in.
    conv_fp32 = copy.deepcopy(conv).float().cpu()
    lin_fp32 = conv3d_to_linear(conv_fp32)

    # Build identical 5D and flat input: one kernel-sized patch per row.
    torch.manual_seed(0)
    N = 100
    C = conv_fp32.in_channels
    T, H, W = conv_fp32.kernel_size  # H and W may differ
    x_5d = torch.randn(N, C, T, H, W, dtype=torch.float32)
    x_flat = x_5d.reshape(N, -1).contiguous()

    # No gradients are needed for a pure numerical comparison.
    with torch.no_grad():
        out_conv = conv_fp32(x_5d).reshape(N, -1)
        out_lin = lin_fp32(x_flat)

    abs_diff = (out_conv - out_lin).abs()
    # The tiny epsilon only guards against division by zero, so the max
    # relative diff can look large where the reference output is near zero.
    rel_diff = abs_diff / (out_conv.abs() + 1e-9)
    max_abs = abs_diff.max().item()
    print(f" max abs diff: {max_abs:.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff: {rel_diff.max().item():.2e}")

    if max_abs < 1e-5:
        print(" ✓ MATH CORRECT (Conv3d ≑ Linear in fp32)")
        return True
    print(" ✗ math diff > 1e-5 — flatten order may be wrong")
    return False
def test_bf16_relative(conv: nn.Conv3d) -> bool:
    """Measure the bf16 rounding gap between ``conv`` and its Linear twin.

    Runs 100 random single-patch inputs through a bfloat16 copy of ``conv``
    and through the layer returned by ``conv3d_to_linear``, then prints
    absolute/relative differences and the cosine similarity of the flattened
    outputs. The original author's estimate for the accumulated bf16 error is
    ~sqrt(1536)·eps ≈ 4e-2 (1536 = 3·2·16·16, the flattened patch size of the
    convolution this script builds).

    Note: calls ``torch.manual_seed(0)``, so the global RNG state is reset.

    Args:
        conv: Convolution to verify. It is deep-copied first, so the caller's
            module is neither moved to another device nor down-cast.

    Returns:
        True when the whole-output cosine similarity exceeds 0.999, else False.
    """
    import copy  # local import keeps this function self-contained

    print("\n[Test 2] bf16 numerical error (rounding only)")
    # Prefer CUDA as before, but do not crash on CPU-only machines.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        print(" (CUDA unavailable: running the bf16 check on CPU)")

    # Module.cuda()/.to() convert in place; converting a copy keeps the
    # caller's fp32 weights intact (a bf16 round-trip would lose precision).
    conv_bf16 = copy.deepcopy(conv).to(device=device, dtype=torch.bfloat16)
    lin_bf16 = conv3d_to_linear(conv_bf16)

    torch.manual_seed(0)
    N = 100
    C = conv_bf16.in_channels
    T, H, W = conv_bf16.kernel_size  # H and W may differ
    x_5d = torch.randn(N, C, T, H, W, dtype=torch.bfloat16, device=device)
    x_flat = x_5d.reshape(N, -1).contiguous()

    with torch.no_grad():
        # Compare in fp32 so the comparison itself adds no bf16 rounding.
        out_conv = conv_bf16(x_5d).reshape(N, -1).float()
        out_lin = lin_bf16(x_flat).float()

    abs_diff = (out_conv - out_lin).abs()
    # The denominator is clamped to 1e-3 (not masked), so near-zero reference
    # values are compared against 1e-3 rather than excluded.
    rel_diff = abs_diff / (out_conv.abs().clamp_min(1e-3))
    cos_sim = torch.nn.functional.cosine_similarity(
        out_conv.flatten().unsqueeze(0),
        out_lin.flatten().unsqueeze(0),
    ).item()

    print(f" max abs diff: {abs_diff.max().item():.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff (where |out|>1e-3): {rel_diff.max().item():.2%}")
    print(f" mean rel diff: {rel_diff.mean().item():.2%}")
    print(f" COSINE SIMILARITY (whole output): {cos_sim:.6f}")

    if cos_sim > 0.999:
        print(" ✓ outputs are essentially identical (cos > 0.999)")
        return True
    print(" ✗ unexpected — cosine similarity < 0.999")
    return False
def _load_merged_model(base_model: str, adapter: str, vocab_size: int = 151674):
    """Load the base model, merge the PEFT adapter into it, move to CUDA, eval."""
    model = AutoModelForImageTextToText.from_pretrained(
        base_model,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )
    # Resize before attaching the adapter, as the original script did.
    model.resize_token_embeddings(vocab_size)
    model = PeftModel.from_pretrained(model, adapter).merge_and_unload()
    model.cuda().eval()
    return model


def _forward_hidden(model, args: dict) -> torch.Tensor:
    """Run ``model.model`` under bf16 autocast; return fp32 CPU hidden states."""
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        out = model.model(**args, use_cache=False, return_dict=True)
    # Only the CPU copy escapes this function, so the GPU outputs can be
    # freed as soon as the caller drops the model.
    return out.last_hidden_state.float().cpu()


def test_downstream_belief_diff(
    base_model: str = "models/Qwen3-VL-4B-Instruct",
    adapter: str = "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best",
    manifest: str = "data/policy_labels/val.json",
    n_samples: int = 4,
) -> bool:
    """Compare model hidden states with the Conv3d vs the Linear patch embed.

    Loads the merged model twice: once unmodified, and once with
    ``Qwen3VLVisionPatchEmbed.forward`` temporarily replaced by a version that
    swaps its ``proj`` Conv3d for the layer from ``conv3d_to_linear`` on first
    call. Both models run ``model.model(...)`` on the same inputs built from
    the first ``n_samples`` validation samples. The function prints per-token
    differences and cosine similarities of ``last_hidden_state``, plus the
    cosine similarity of the per-sample mean over dim 1 (described by the
    original author as the belief feature the head consumes).

    Requires CUDA, the checkpoints and the dataset on disk.

    Args:
        base_model: Path or id of the base model to load.
        adapter: Path of the PEFT adapter; also used to load the processor.
        manifest: Validation manifest passed to ``PolicyDataset``.
        n_samples: Number of leading dataset samples to run.

    Returns:
        True when every sample's mean-pooled cosine similarity exceeds 0.99,
        otherwise False.

    Raises:
        AssertionError: If the two hidden-state tensors differ in shape.
    """
    print("\n[Test 3] Full vision tower forward, Conv3d vs Linear")
    proc = AutoProcessor.from_pretrained(adapter)
    ds = PolicyDataset(
        manifests=[manifest],
        split="val", n_frames=8, sampling="last_biased", source_filter="all",
    )
    # Load only the samples that are actually fed to the model.
    frames = [
        _load_frames(sample["source_dir"], sample["frame_indices"], n_frames=8)
        for sample in (ds.samples[i] for i in range(n_samples))
    ]

    # ── Path A: original Conv3d ────────────────────────────────
    print("\n loading model A (Conv3d, original)...")
    model_a = _load_merged_model(base_model, adapter)
    inputs = M._build_inputs(proc, frames, [{}] * n_samples, resize_short=336)
    inputs_g = {
        k: (v.cuda() if isinstance(v, torch.Tensor) else v)
        for k, v in inputs.items()
    }
    inputs_g["pixel_values"] = inputs_g["pixel_values"].to(torch.bfloat16)
    keys = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    args = {k: inputs_g[k] for k in keys if k in inputs_g}
    print(" running Conv3d forward (will be slow ~70s)...")
    h_a = _forward_hidden(model_a, args)
    print(f" Conv3d hidden shape: {tuple(h_a.shape)}")
    del model_a
    torch.cuda.empty_cache()

    # ── Path B: patched Linear ─────────────────────────────────
    print("\n loading model B (Linear, patched)...")

    def _fast_forward(self, hidden_states):
        # Lazy patch: replace the Conv3d with its Linear twin on first use.
        target_dtype = self.proj.weight.dtype
        if isinstance(self.proj, nn.Conv3d):
            self.proj = conv3d_to_linear(self.proj)
            print(" [patched] Conv3d → Linear at first call")
        needs_flatten = (
            hidden_states.dim() > 2
            or hidden_states.shape[-1] != self.proj.in_features
        )
        if needs_flatten:
            hidden_states = hidden_states.reshape(-1, self.proj.in_features)
        return self.proj(hidden_states.to(dtype=target_dtype))

    # The patch is class-wide, so restore the original forward afterwards;
    # otherwise every later model in this process would keep the patch.
    original_forward = Qwen3VLVisionPatchEmbed.forward
    Qwen3VLVisionPatchEmbed.forward = _fast_forward
    try:
        model_b = _load_merged_model(base_model, adapter)
        print(" running Linear forward (fast)...")
        h_b = _forward_hidden(model_b, args)
    finally:
        Qwen3VLVisionPatchEmbed.forward = original_forward
    print(f" Linear hidden shape: {tuple(h_b.shape)}")
    del model_b
    torch.cuda.empty_cache()

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if h_a.shape != h_b.shape:
        raise AssertionError("shapes differ!")

    abs_diff = (h_a - h_b).abs()
    rel_diff = abs_diff / (h_a.abs().clamp_min(1e-3))
    print("\n per-token hidden state diff:")
    print(f" max abs: {abs_diff.max().item():.2e}")
    print(f" mean abs: {abs_diff.mean().item():.2e}")
    print(f" mean rel: {rel_diff.mean().item():.2%}")

    # Cosine similarity per (batch, token) position.
    h_a_flat = h_a.reshape(-1, h_a.shape[-1])
    h_b_flat = h_b.reshape(-1, h_b.shape[-1])
    cos = torch.nn.functional.cosine_similarity(h_a_flat, h_b_flat, dim=-1)
    print("\n per-token cosine similarity:")
    print(f" mean: {cos.mean().item():.6f}")
    print(f" min: {cos.min().item():.6f}")
    print(f" median: {cos.median().item():.6f}")

    # Mean over dim 1 per sample; per the original author this is the belief
    # feature consumed by the head.
    h_a_pool = h_a.mean(dim=1)
    h_b_pool = h_b.mean(dim=1)
    pool_cos = torch.nn.functional.cosine_similarity(h_a_pool, h_b_pool, dim=-1)
    print("\n per-sample MEAN-POOLED belief cosine similarity:")
    for i, c in enumerate(pool_cos.tolist()):
        print(f" sample {i}: {c:.8f}")
    print(f" mean: {pool_cos.mean().item():.8f}")

    if pool_cos.min().item() > 0.99:
        print("\n ✓ DOWNSTREAM IMPACT NEGLIGIBLE (pooled cos > 0.99)")
        return True
    print("\n ⚠️ pooled cosine < 0.99 — investigate before using")
    return False
def main() -> int:
    """Run the three checks, print a PASS/FAIL summary, return an exit code.

    Tests 1 and 2 run on a freshly initialised Conv3d (random weights) built
    with the shape the original author gives for Qwen3-VL-4B's patch embed.
    Test 3 loads the real checkpoints. All three checks always run, even if
    an earlier one fails.

    Returns:
        0 when all three checks pass, 1 otherwise, so callers and CI can
        detect a failed verification from the process exit status.
    """
    banner = "=" * 70
    print(banner)
    print("Verify Conv3d → Linear correctness for Qwen3VLVisionPatchEmbed")
    print(banner)

    # Build a fresh Conv3d with the same shape as Qwen3-VL-4B's patch_embed
    # (shape as stated by the original author).
    conv = nn.Conv3d(
        in_channels=3, out_channels=1024,
        kernel_size=(2, 16, 16), stride=(2, 16, 16), bias=True,
    )
    ok1 = test_fp32_equivalence(conv)
    ok2 = test_bf16_relative(conv)
    ok3 = test_downstream_belief_diff()

    print("\n" + banner)
    print("SUMMARY:")
    print(f" Test 1 (fp32 math equivalence): {'PASS' if ok1 else 'FAIL'}")
    print(f" Test 2 (bf16 cosine sim): {'PASS' if ok2 else 'FAIL'}")
    print(f" Test 3 (downstream belief sim): {'PASS' if ok3 else 'FAIL'}")

    all_ok = bool(ok1 and ok2 and ok3)
    if all_ok:
        print("\n ✓✓✓ Linear replacement is SAFE for inference.")
    else:
        print("\n ⚠️ at least one check failed; review before using.")
    return 0 if all_ok else 1


if __name__ == "__main__":
    # Propagate the verification result as the process exit status.
    sys.exit(main())