| """Rigorous correctness check for Conv3d → Linear replacement. |
| |
| Tests three things: |
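| 1. fp32 equivalence: max abs diff should be ~1e-6 or smaller; the check passes below 1e-5 (proves math is identical) |
| 2. bf16 numerical error: max abs + max relative + mean relative + cosine similarity (passes above 0.999) |
| 3. Downstream belief output diff: full vision tower forward, Conv3d vs Linear |
| |
| If the fp32 diff is below the 1e-5 pass threshold, the two layers compute the same function up to float rounding. |
| If the downstream pooled belief cosine similarity is > 0.99 (the pass threshold; ideally > 0.9999), the head should see no meaningful difference. |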
| """ |
| import sys |
| sys.path.insert(0, ".") |
|
|
| import torch |
| import torch.nn as nn |
| from peft import PeftModel |
| from transformers import AutoModelForImageTextToText, AutoProcessor |
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed |
|
|
| from training.Policy.policy_dataset import PolicyDataset, _load_frames |
| from training.Policy import make_cot_belief_cache as M |
|
|
|
|
def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear:
    """Build a Linear layer that matches ``conv`` on single-patch inputs.

    The Conv3d weight of shape ``(out, in, kT, kH, kW)`` is flattened to
    ``(out, in * kT * kH * kW)``, so ``linear(x.reshape(N, -1))`` reproduces
    ``conv(x).reshape(N, -1)`` for inputs ``x`` of shape
    ``(N, in, kT, kH, kW)``, i.e. exactly one kernel-sized patch per row.

    Args:
        conv: Source convolution. It must be ungrouped, unpadded and
            undilated; otherwise a flattened matmul is a different operation.

    Returns:
        A new ``nn.Linear`` on the same device and with the same dtype as
        ``conv.weight``, holding copies of the weight and (if present) bias.

    Raises:
        ValueError: If ``conv`` uses groups, padding or dilation.
    """
    padding = conv.padding
    unpadded = padding == "valid" or (
        not isinstance(padding, str) and all(p == 0 for p in padding)
    )
    undilated = all(d == 1 for d in conv.dilation)
    if conv.groups != 1 or not unpadded or not undilated:
        raise ValueError(
            "conv3d_to_linear requires groups=1, padding=0 and dilation=1; "
            f"got groups={conv.groups}, padding={conv.padding}, "
            f"dilation={conv.dilation}"
        )

    k_t, k_h, k_w = conv.kernel_size
    out_dim = conv.out_channels
    in_dim = conv.in_channels * k_t * k_h * k_w
    has_bias = conv.bias is not None

    # Convert the fresh layer to the conv's device/dtype *before* copying.
    # Copying into a default-dtype (fp32) layer first would round float64
    # weights to fp32 precision and break exact equivalence.
    linear = nn.Linear(in_dim, out_dim, bias=has_bias).to(
        device=conv.weight.device, dtype=conv.weight.dtype
    )
    with torch.no_grad():
        linear.weight.copy_(conv.weight.reshape(out_dim, in_dim))
        if has_bias:
            linear.bias.copy_(conv.bias)
    return linear
|
|
|
|
def test_fp32_equivalence(conv: nn.Conv3d) -> bool:
    """Check that ``conv`` and its Linear replacement agree in fp32 on CPU.

    Feeds 100 random kernel-sized patches through an fp32 copy of ``conv``
    and through ``conv3d_to_linear`` of that copy, prints absolute and
    relative differences, and passes when the maximum absolute difference is
    below 1e-5.

    Args:
        conv: Convolution to verify. It is deep-copied, so the caller's
            module keeps its device and dtype.

    Returns:
        True if the maximum absolute difference is < 1e-5, else False.
    """
    import copy

    print("\n[Test 1] fp32 equivalence (math correctness)")
    # nn.Module.float()/.cpu() convert the module in place and return self,
    # so work on a private copy instead of mutating the caller's conv.
    conv_fp32 = copy.deepcopy(conv).float().cpu()
    lin_fp32 = conv3d_to_linear(conv_fp32)

    torch.manual_seed(0)
    N = 100
    # One kernel-sized patch per row; unpacking kernel_size also covers
    # non-square spatial kernels.
    x_5d = torch.randn(
        N, conv.in_channels, *conv.kernel_size, dtype=torch.float32
    )
    x_flat = x_5d.reshape(N, -1).contiguous()

    with torch.no_grad():
        out_conv = conv_fp32(x_5d).reshape(N, -1)
        out_lin = lin_fp32(x_flat)

    abs_diff = (out_conv - out_lin).abs()
    rel_diff = abs_diff / (out_conv.abs() + 1e-9)
    max_abs = abs_diff.max().item()
    print(f" max abs diff: {max_abs:.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff: {rel_diff.max().item():.2e}")
    if max_abs < 1e-5:
        print(" β MATH CORRECT (Conv3d β‘ Linear in fp32)")
        return True
    print(" β math diff > 1e-5 β flatten order may be wrong")
    return False
|
|
|
|
def test_bf16_relative(conv: nn.Conv3d) -> bool:
    """Measure bf16 disagreement between ``conv`` and its Linear replacement.

    Runs 100 random kernel-sized patches through a bf16 CUDA copy of ``conv``
    and through ``conv3d_to_linear`` of that copy, prints absolute/relative
    differences and the cosine similarity of the flattened outputs, and
    passes when that cosine similarity exceeds 0.999. The original author's
    estimate for the accumulated rounding error is ~sqrt(1536)·eps ≈ 4e-2.

    Requires a CUDA device (the copy is moved with ``.cuda()``).

    Args:
        conv: Convolution to verify. It is deep-copied, so the caller's
            module keeps its device and dtype.

    Returns:
        True if the whole-output cosine similarity is > 0.999, else False.
    """
    import copy

    print("\n[Test 2] bf16 numerical error (rounding only)")
    # .cuda()/.to() convert the module in place and return self, so work on
    # a private copy instead of moving/casting the caller's conv.
    conv_bf16 = copy.deepcopy(conv).cuda().to(torch.bfloat16)
    lin_bf16 = conv3d_to_linear(conv_bf16)

    torch.manual_seed(0)
    N = 100
    # One kernel-sized patch per row; unpacking kernel_size also covers
    # non-square spatial kernels.
    x_5d = torch.randn(
        N, conv.in_channels, *conv.kernel_size,
        dtype=torch.bfloat16, device="cuda",
    )
    x_flat = x_5d.reshape(N, -1).contiguous()

    with torch.no_grad():
        out_conv = conv_bf16(x_5d).reshape(N, -1).float()
        out_lin = lin_bf16(x_flat).float()

    abs_diff = (out_conv - out_lin).abs()
    # Clamp the denominator so near-zero outputs do not dominate the ratio.
    rel_diff = abs_diff / (out_conv.abs().clamp_min(1e-3))
    cos_sim = torch.nn.functional.cosine_similarity(
        out_conv.flatten().unsqueeze(0),
        out_lin.flatten().unsqueeze(0),
    ).item()
    print(f" max abs diff: {abs_diff.max().item():.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff (where |out|>1e-3): {rel_diff.max().item():.2%}")
    print(f" mean rel diff: {rel_diff.mean().item():.2%}")
    print(f" COSINE SIMILARITY (whole output): {cos_sim:.6f}")
    if cos_sim > 0.999:
        print(" β outputs are essentially identical (cos > 0.999)")
        return True
    print(" β unexpected β cosine similarity < 0.999")
    return False
|
|
|
|
def _load_merged_policy_model():
    """Load the base model, merge the PEFT checkpoint, and move it to CUDA."""
    model = AutoModelForImageTextToText.from_pretrained(
        "models/Qwen3-VL-4B-Instruct",
        dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    model.resize_token_embeddings(151674)
    model = PeftModel.from_pretrained(
        model, "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best"
    ).merge_and_unload()
    model.cuda().eval()
    return model


def _backbone_hidden_states(model, args):
    """Run ``model.model`` under bf16 autocast; return fp32 CPU hidden states."""
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        out = model.model(**args, use_cache=False, return_dict=True)
    # Only the CPU copy leaves this function, so the GPU outputs can be freed.
    return out.last_hidden_state.float().cpu()


def _linear_patch_embed_forward(self, hidden_states):
    """Patch-embed forward that swaps ``self.proj`` to Linear on first call."""
    target_dtype = self.proj.weight.dtype
    if isinstance(self.proj, nn.Conv3d):
        self.proj = conv3d_to_linear(self.proj)
        print(" [patched] Conv3d β Linear at first call")
    in_features = self.proj.in_features
    if hidden_states.dim() > 2 or hidden_states.shape[-1] != in_features:
        hidden_states = hidden_states.reshape(-1, in_features)
    return self.proj(hidden_states.to(dtype=target_dtype))


def _report_hidden_state_agreement(h_a, h_b) -> bool:
    """Print diff/cosine statistics; True if every pooled cosine is > 0.99."""
    if h_a.shape != h_b.shape:
        raise AssertionError("shapes differ!")

    abs_diff = (h_a - h_b).abs()
    rel_diff = abs_diff / (h_a.abs().clamp_min(1e-3))
    print("\n per-token hidden state diff:")
    print(f" max abs: {abs_diff.max().item():.2e}")
    print(f" mean abs: {abs_diff.mean().item():.2e}")
    print(f" mean rel: {rel_diff.mean().item():.2%}")

    # Cosine similarity of each token's hidden vector.
    h_a_flat = h_a.reshape(-1, h_a.shape[-1])
    h_b_flat = h_b.reshape(-1, h_b.shape[-1])
    cos = torch.nn.functional.cosine_similarity(h_a_flat, h_b_flat, dim=-1)
    print("\n per-token cosine similarity:")
    print(f" mean: {cos.mean().item():.6f}")
    print(f" min: {cos.min().item():.6f}")
    print(f" median: {cos.median().item():.6f}")

    # Cosine similarity of each sample's mean over dim 1.
    pool_cos = torch.nn.functional.cosine_similarity(
        h_a.mean(dim=1), h_b.mean(dim=1), dim=-1
    )
    print("\n per-sample MEAN-POOLED belief cosine similarity:")
    for i, c in enumerate(pool_cos.tolist()):
        print(f" sample {i}: {c:.8f}")
    print(f" mean: {pool_cos.mean().item():.8f}")

    if pool_cos.min().item() > 0.99:
        print("\n β DOWNSTREAM IMPACT NEGLIGIBLE (pooled cos > 0.99)")
        return True
    print("\n β οΈ pooled cosine < 0.99 β investigate before using")
    return False


def test_downstream_belief_diff() -> bool:
    """Compare backbone hidden states with the Conv3d vs Linear patch embed.

    Loads four samples from the validation manifest (described by the
    original author as real ADAS-TO frames), runs the merged model once with
    the stock ``Qwen3VLVisionPatchEmbed.forward`` and once with a forward
    that replaces ``proj`` by ``conv3d_to_linear(proj)``, then prints
    per-token and mean-pooled agreement statistics.

    The class-level forward patch is applied only while model B runs and is
    always restored afterwards, so later code in the same process sees the
    stock implementation. Requires a CUDA device and the local model,
    checkpoint and manifest paths used below.

    Returns:
        True if the minimum per-sample pooled cosine similarity is > 0.99.

    Raises:
        AssertionError: If the two runs produce differently shaped outputs.
    """
    print("\n[Test 3] Full vision tower forward, Conv3d vs Linear")
    n_samples = 4  # only this many samples are fed to the model
    proc = AutoProcessor.from_pretrained(
        "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best")
    ds = PolicyDataset(
        manifests=["data/policy_labels/val.json"],
        split="val", n_frames=8, sampling="last_biased", source_filter="all",
    )
    imgs = [
        _load_frames(ds.samples[i]["source_dir"],
                     ds.samples[i]["frame_indices"], n_frames=8)
        for i in range(n_samples)
    ]

    # Model A must run BEFORE the class-level patch below is installed.
    print("\n loading model A (Conv3d, original)...")
    model_a = _load_merged_policy_model()

    inputs = M._build_inputs(proc, imgs, [{}] * n_samples, resize_short=336)
    inputs_g = {k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                for k, v in inputs.items()}
    inputs_g["pixel_values"] = inputs_g["pixel_values"].to(torch.bfloat16)

    keys = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    args = {k: inputs_g[k] for k in keys if k in inputs_g}

    print(" running Conv3d forward (will be slow ~70s)...")
    h_a = _backbone_hidden_states(model_a, args)
    print(f" Conv3d hidden shape: {tuple(h_a.shape)}")
    del model_a
    torch.cuda.empty_cache()

    print("\n loading model B (Linear, patched)...")
    original_forward = Qwen3VLVisionPatchEmbed.forward
    Qwen3VLVisionPatchEmbed.forward = _linear_patch_embed_forward
    try:
        model_b = _load_merged_policy_model()
        print(" running Linear forward (fast)...")
        h_b = _backbone_hidden_states(model_b, args)
        print(f" Linear hidden shape: {tuple(h_b.shape)}")
        del model_b
    finally:
        # Undo the monkey-patch even if loading or the forward pass fails.
        Qwen3VLVisionPatchEmbed.forward = original_forward
        torch.cuda.empty_cache()

    return _report_hidden_state_agreement(h_a, h_b)
|
|
|
|
def main():
    """Run the three Conv3d-vs-Linear checks and print a PASS/FAIL summary.

    Builds a stand-alone Conv3d (3 -> 1024 channels, kernel and stride
    (2, 16, 16), with bias) for the two synthetic tests, then runs the
    downstream comparison. All three checks always run, in order.
    """
    rule = "=" * 70
    print(rule)
    print("Verify Conv3d β Linear correctness for Qwen3VLVisionPatchEmbed")
    print(rule)

    patch_conv = nn.Conv3d(
        3, 1024, (2, 16, 16), stride=(2, 16, 16), bias=True,
    )

    # Tuple elements are evaluated left to right, so the tests run in order.
    outcomes = (
        (" Test 1 (fp32 math equivalence): ",
         test_fp32_equivalence(patch_conv)),
        (" Test 2 (bf16 cosine sim): ",
         test_bf16_relative(patch_conv)),
        (" Test 3 (downstream belief sim): ",
         test_downstream_belief_diff()),
    )

    print("\n" + rule)
    print("SUMMARY:")
    for label, passed in outcomes:
        verdict = "PASS" if passed else "FAIL"
        print(label + verdict)

    if all(passed for _, passed in outcomes):
        print("\n βββ Linear replacement is SAFE for inference.")
    else:
        print("\n β οΈ at least one check failed; review before using.")


# Run the checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|