"""Rigorous correctness check for Conv3d → Linear replacement. Tests three things: 1. fp32 equivalence: should be < 1e-6 (proves math is identical) 2. bf16 numerical error: max abs + max relative + mean relative 3. Downstream belief output diff: full vision tower forward, Conv3d vs Linear If fp32 diff is < 1e-6, the math is provably equivalent. If downstream belief cosine similarity > 0.9999, the head will see no difference. """ import sys sys.path.insert(0, ".") import torch import torch.nn as nn from peft import PeftModel from transformers import AutoModelForImageTextToText, AutoProcessor from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed from training.Policy.policy_dataset import PolicyDataset, _load_frames from training.Policy import make_cot_belief_cache as M def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear: """Build mathematically equivalent Linear layer.""" out_dim = conv.out_channels in_dim = (conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1] * conv.kernel_size[2]) w_flat = conv.weight.detach().reshape(out_dim, in_dim).contiguous() bias = conv.bias.detach().clone() if conv.bias is not None else None new = nn.Linear(in_dim, out_dim, bias=bias is not None) new.weight.data.copy_(w_flat) if bias is not None: new.bias.data.copy_(bias) return new.to(device=conv.weight.device, dtype=conv.weight.dtype) def test_fp32_equivalence(conv: nn.Conv3d): """In fp32, Conv3d with stride=kernel ≡ Linear. Diff should be ~0.""" print("\n[Test 1] fp32 equivalence (math correctness)") conv_fp32 = conv.float().cpu() lin_fp32 = conv3d_to_linear(conv_fp32) # Build identical 5D and flat input torch.manual_seed(0) N = 100 C, T, P = conv.in_channels, conv.kernel_size[0], conv.kernel_size[1] x_5d = torch.randn(N, C, T, P, P, dtype=torch.float32) x_flat = x_5d.reshape(N, -1).contiguous() out_conv = conv_fp32(x_5d).view(N, -1) out_lin = lin_fp32(x_flat) abs_diff = (out_conv - out_lin).abs() rel_diff = abs_diff / (out_conv.abs() + 1e-9) print(f" max abs diff: {abs_diff.max().item():.2e}") print(f" mean abs diff: {abs_diff.mean().item():.2e}") print(f" max rel diff: {rel_diff.max().item():.2e}") if abs_diff.max().item() < 1e-5: print(f" ✓ MATH CORRECT (Conv3d ≡ Linear in fp32)") return True else: print(f" ✗ math diff > 1e-5 — flatten order may be wrong") return False def test_bf16_relative(conv: nn.Conv3d): """In bf16, accumulated error is expected ~sqrt(1536)·eps ≈ 4e-2.""" print("\n[Test 2] bf16 numerical error (rounding only)") conv_bf16 = conv.cuda().to(torch.bfloat16) lin_bf16 = conv3d_to_linear(conv_bf16) torch.manual_seed(0) N = 100 C, T, P = conv.in_channels, conv.kernel_size[0], conv.kernel_size[1] x_5d = torch.randn(N, C, T, P, P, dtype=torch.bfloat16, device="cuda") x_flat = x_5d.reshape(N, -1).contiguous() with torch.no_grad(): out_conv = conv_bf16(x_5d).view(N, -1).float() out_lin = lin_bf16(x_flat).float() abs_diff = (out_conv - out_lin).abs() rel_diff = abs_diff / (out_conv.abs().clamp_min(1e-3)) cos_sim = torch.nn.functional.cosine_similarity( out_conv.flatten().unsqueeze(0), out_lin.flatten().unsqueeze(0)).item() print(f" max abs diff: {abs_diff.max().item():.2e}") print(f" mean abs diff: {abs_diff.mean().item():.2e}") print(f" max rel diff (where |out|>1e-3): {rel_diff.max().item():.2%}") print(f" mean rel diff: {rel_diff.mean().item():.2%}") print(f" COSINE SIMILARITY (whole output): {cos_sim:.6f}") if cos_sim > 0.999: print(f" ✓ outputs are essentially identical (cos > 0.999)") return True print(f" ✗ unexpected — cosine similarity < 0.999") return False def test_downstream_belief_diff(): """Run the FULL vision tower forward via Conv3d path vs Linear path on real ADAS-TO frames. Compare per-sample belief vectors (this is what the head actually consumes).""" print("\n[Test 3] Full vision tower forward, Conv3d vs Linear") proc = AutoProcessor.from_pretrained( "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best") ds = PolicyDataset( manifests=["data/policy_labels/val.json"], split="val", n_frames=8, sampling="last_biased", source_filter="all", ) all_imgs = [ _load_frames(ds.samples[i]["source_dir"], ds.samples[i]["frame_indices"], n_frames=8) for i in range(8) ] # ── Path A: original Conv3d ──────────────────────────────── print("\n loading model A (Conv3d, original)...") model_a = AutoModelForImageTextToText.from_pretrained( "models/Qwen3-VL-4B-Instruct", dtype=torch.bfloat16, attn_implementation="sdpa", ) model_a.resize_token_embeddings(151674) model_a = PeftModel.from_pretrained( model_a, "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best" ).merge_and_unload() model_a.cuda().eval() inputs = M._build_inputs(proc, all_imgs[:4], [{}]*4, resize_short=336) inputs_g = {k: (v.cuda() if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()} inputs_g["pixel_values"] = inputs_g["pixel_values"].to(torch.bfloat16) keys = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw") args = {k: inputs_g[k] for k in keys if k in inputs_g} print(" running Conv3d forward (will be slow ~70s)...") with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16): out_a = model_a.model(**args, use_cache=False, return_dict=True) h_a = out_a.last_hidden_state.float().cpu() print(f" Conv3d hidden shape: {tuple(h_a.shape)}") del model_a; torch.cuda.empty_cache() # ── Path B: patched Linear ───────────────────────────────── print("\n loading model B (Linear, patched)...") # Apply lazy patch def _fast_forward(self, hidden_states): target_dtype = self.proj.weight.dtype if isinstance(self.proj, nn.Conv3d): self.proj = conv3d_to_linear(self.proj) print(f" [patched] Conv3d → Linear at first call") if hidden_states.dim() > 2 or hidden_states.shape[-1] != self.proj.in_features: hidden_states = hidden_states.reshape(-1, self.proj.in_features) return self.proj(hidden_states.to(dtype=target_dtype)) Qwen3VLVisionPatchEmbed.forward = _fast_forward model_b = AutoModelForImageTextToText.from_pretrained( "models/Qwen3-VL-4B-Instruct", dtype=torch.bfloat16, attn_implementation="sdpa", ) model_b.resize_token_embeddings(151674) model_b = PeftModel.from_pretrained( model_b, "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best" ).merge_and_unload() model_b.cuda().eval() print(" running Linear forward (fast)...") with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16): out_b = model_b.model(**args, use_cache=False, return_dict=True) h_b = out_b.last_hidden_state.float().cpu() print(f" Linear hidden shape: {tuple(h_b.shape)}") del model_b; torch.cuda.empty_cache() assert h_a.shape == h_b.shape, "shapes differ!" abs_diff = (h_a - h_b).abs() rel_diff = abs_diff / (h_a.abs().clamp_min(1e-3)) print(f"\n per-token hidden state diff:") print(f" max abs: {abs_diff.max().item():.2e}") print(f" mean abs: {abs_diff.mean().item():.2e}") print(f" mean rel: {rel_diff.mean().item():.2%}") # cosine similarity per (batch, token) — most relevant for head h_a_flat = h_a.reshape(-1, h_a.shape[-1]) h_b_flat = h_b.reshape(-1, h_b.shape[-1]) cos = torch.nn.functional.cosine_similarity(h_a_flat, h_b_flat, dim=-1) print(f"\n per-token cosine similarity:") print(f" mean: {cos.mean().item():.6f}") print(f" min: {cos.min().item():.6f}") print(f" median: {cos.median().item():.6f}") # mean-pool per sample (the actual belief feature consumed by head) h_a_pool = h_a.mean(dim=1) # (B, D) h_b_pool = h_b.mean(dim=1) pool_cos = torch.nn.functional.cosine_similarity(h_a_pool, h_b_pool, dim=-1) print(f"\n per-sample MEAN-POOLED belief cosine similarity:") for i, c in enumerate(pool_cos.tolist()): print(f" sample {i}: {c:.8f}") print(f" mean: {pool_cos.mean().item():.8f}") if pool_cos.min().item() > 0.99: print(f"\n ✓ DOWNSTREAM IMPACT NEGLIGIBLE (pooled cos > 0.99)") return True else: print(f"\n ⚠️ pooled cosine < 0.99 — investigate before using") return False def main(): print("=" * 70) print("Verify Conv3d → Linear correctness for Qwen3VLVisionPatchEmbed") print("=" * 70) # Build a fresh Conv3d with same shape as Qwen3-VL-4B's patch_embed conv = nn.Conv3d( in_channels=3, out_channels=1024, kernel_size=(2, 16, 16), stride=(2, 16, 16), bias=True, ) ok1 = test_fp32_equivalence(conv) ok2 = test_bf16_relative(conv) ok3 = test_downstream_belief_diff() print("\n" + "=" * 70) print(f"SUMMARY:") print(f" Test 1 (fp32 math equivalence): " f"{'PASS' if ok1 else 'FAIL'}") print(f" Test 2 (bf16 cosine sim): " f"{'PASS' if ok2 else 'FAIL'}") print(f" Test 3 (downstream belief sim): " f"{'PASS' if ok3 else 'FAIL'}") if ok1 and ok2 and ok3: print(f"\n ✓✓✓ Linear replacement is SAFE for inference.") else: print(f"\n ⚠️ at least one check failed; review before using.") if __name__ == "__main__": main()