File size: 9,956 Bytes
"""Rigorous correctness check for the Conv3d → Linear patch-embed replacement.
Tests three things:
1. fp32 equivalence: max abs diff should be < 1e-5 (shows the math is identical)
2. bf16 numerical error: max/mean abs diff, max/mean relative diff, cosine similarity
3. Downstream diff: full `model.model` forward on real frames, Conv3d vs Linear
If the fp32 diff is < 1e-5, the two layers compute the same function up to rounding.
If every sample's mean-pooled hidden-state cosine similarity is > 0.99, the head
should see no meaningful difference.
"""
import sys
sys.path.insert(0, ".")
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
from training.Policy.policy_dataset import PolicyDataset, _load_frames
from training.Policy import make_cot_belief_cache as M
def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear:
    """Build an nn.Linear that is mathematically equivalent to a patchify Conv3d.

    When stride == kernel_size, padding is zero, dilation is 1 and groups is 1,
    each Conv3d output is a dot product between the kernel and one
    (C, T, kH, kW) input patch. Flattening the kernel in that same C-major
    order gives a Linear weight of shape (out_channels, C * T * kH * kW).
    Inputs must therefore be flattened as ``x.reshape(N, -1)`` from
    (N, C, T, kH, kW).

    Args:
        conv: the Conv3d to convert. It is not modified.

    Returns:
        A new nn.Linear on the same device and dtype as ``conv.weight``. Its
        weights and bias are copies of the conv's (bias only if the conv has
        one).

    Raises:
        ValueError: if ``conv`` is not a non-overlapping, unpadded,
            undilated, ungrouped convolution. Only that case maps exactly
            onto a Linear layer.
    """
    if isinstance(conv.padding, str):
        padding_is_zero = conv.padding == "valid"
    else:
        padding_is_zero = all(p == 0 for p in conv.padding)
    if (tuple(conv.stride) != tuple(conv.kernel_size)
            or not padding_is_zero
            or any(d != 1 for d in conv.dilation)
            or conv.groups != 1):
        raise ValueError(
            "conv3d_to_linear requires stride == kernel_size, zero padding, "
            f"dilation 1 and groups 1; got kernel_size={conv.kernel_size}, "
            f"stride={conv.stride}, padding={conv.padding}, "
            f"dilation={conv.dilation}, groups={conv.groups}"
        )

    out_dim = conv.out_channels
    k_t, k_h, k_w = conv.kernel_size
    in_dim = conv.in_channels * k_t * k_h * k_w

    weight = conv.weight.detach()
    # Create the Linear directly on the target device/dtype so no throwaway
    # fp32 CPU initialisation is done.
    linear = nn.Linear(in_dim, out_dim, bias=conv.bias is not None,
                       device=weight.device, dtype=weight.dtype)
    with torch.no_grad():
        # (out, C, T, kH, kW) -> (out, C*T*kH*kW), matching x.reshape(N, -1).
        linear.weight.copy_(weight.reshape(out_dim, in_dim))
        if conv.bias is not None:
            linear.bias.copy_(conv.bias.detach())
    return linear
def test_fp32_equivalence(conv: nn.Conv3d) -> bool:
    """Check that conv3d_to_linear(conv) reproduces conv in fp32 on CPU.

    Random inputs shaped like one (C, T, kH, kW) patch per sample are fed to
    the Conv3d (as 5D) and to the Linear (flattened). The test passes when
    the max absolute difference is below 1e-5.

    The caller's ``conv`` is not modified. The check runs on a deep copy
    converted to fp32/CPU.

    Args:
        conv: patch-embed Conv3d to verify.

    Returns:
        True if the max absolute difference is < 1e-5, else False.
    """
    import copy  # local: only this check needs it

    print("\n[Test 1] fp32 equivalence (math correctness)")
    # nn.Module.float()/.cpu() mutate in place and return self. Copy first
    # so the caller's module keeps its original device/dtype.
    conv_fp32 = copy.deepcopy(conv).float().cpu()
    lin_fp32 = conv3d_to_linear(conv_fp32)

    # Build identical 5D and flat inputs (one patch per sample).
    torch.manual_seed(0)
    n = 100
    c = conv.in_channels
    k_t, k_h, k_w = conv.kernel_size
    x_5d = torch.randn(n, c, k_t, k_h, k_w, dtype=torch.float32)
    x_flat = x_5d.reshape(n, -1).contiguous()
    with torch.no_grad():
        out_conv = conv_fp32(x_5d).view(n, -1)
        out_lin = lin_fp32(x_flat)

    abs_diff = (out_conv - out_lin).abs()
    # Tiny epsilon only guards against division by zero.
    rel_diff = abs_diff / (out_conv.abs() + 1e-9)
    print(f" max abs diff: {abs_diff.max().item():.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff: {rel_diff.max().item():.2e}")
    if abs_diff.max().item() < 1e-5:
        print(" β MATH CORRECT (Conv3d β‘ Linear in fp32)")
        return True
    print(" β math diff > 1e-5 β flatten order may be wrong")
    return False
def test_bf16_relative(conv: nn.Conv3d) -> bool:
    """Compare a Conv3d with its Linear replacement in bf16 on CUDA.

    Both layers run in bf16 on the same random patches, so any difference
    comes from rounding or accumulation order in the two kernels. The test
    reports:
      * max/mean absolute difference;
      * max/mean relative difference, where the denominator |out| is
        clamped to at least 1e-3;
      * cosine similarity over the whole flattened output.
    It passes when the cosine similarity exceeds 0.999.

    The caller's ``conv`` is not modified. A deep copy is moved to
    CUDA/bf16. A CUDA device is required.

    Args:
        conv: patch-embed Conv3d to verify.

    Returns:
        True if the cosine similarity is > 0.999, else False.
    """
    import copy  # local: only this check needs it

    print("\n[Test 2] bf16 numerical error (rounding only)")
    # .cuda()/.to() on a Module mutate it in place, so work on a copy.
    conv_bf16 = copy.deepcopy(conv).cuda().to(torch.bfloat16)
    lin_bf16 = conv3d_to_linear(conv_bf16)

    torch.manual_seed(0)
    n = 100
    c = conv.in_channels
    k_t, k_h, k_w = conv.kernel_size
    x_5d = torch.randn(n, c, k_t, k_h, k_w, dtype=torch.bfloat16, device="cuda")
    x_flat = x_5d.reshape(n, -1).contiguous()
    with torch.no_grad():
        # Upcast outputs so that the error statistics are computed in fp32.
        out_conv = conv_bf16(x_5d).view(n, -1).float()
        out_lin = lin_bf16(x_flat).float()

    abs_diff = (out_conv - out_lin).abs()
    rel_diff = abs_diff / (out_conv.abs().clamp_min(1e-3))
    cos_sim = torch.nn.functional.cosine_similarity(
        out_conv.flatten().unsqueeze(0),
        out_lin.flatten().unsqueeze(0)).item()
    print(f" max abs diff: {abs_diff.max().item():.2e}")
    print(f" mean abs diff: {abs_diff.mean().item():.2e}")
    print(f" max rel diff (where |out|>1e-3): {rel_diff.max().item():.2%}")
    print(f" mean rel diff: {rel_diff.mean().item():.2%}")
    print(f" COSINE SIMILARITY (whole output): {cos_sim:.6f}")
    if cos_sim > 0.999:
        print(" β outputs are essentially identical (cos > 0.999)")
        return True
    print(" β unexpected β cosine similarity < 0.999")
    return False
def _load_merged_vlm():
    """Load Qwen3-VL-4B (bf16, SDPA), merge the belief LoRA checkpoint, and move it to CUDA in eval mode.

    Both comparison paths use this so the two models are built identically.
    Only which patch-embed forward is installed at call time differs.
    """
    model = AutoModelForImageTextToText.from_pretrained(
        "models/Qwen3-VL-4B-Instruct",
        dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    # NOTE(review): presumably the vocab size the LoRA checkpoint was trained
    # with (base vocab + added tokens). Verify against the training config.
    model.resize_token_embeddings(151674)
    model = PeftModel.from_pretrained(
        model, "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best"
    ).merge_and_unload()
    model.cuda().eval()
    return model


def _last_hidden_state(model, args):
    """Run ``model.model`` under no_grad + bf16 autocast; return ``last_hidden_state`` as an fp32 CPU tensor."""
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        out = model.model(**args, use_cache=False, return_dict=True)
    return out.last_hidden_state.float().cpu()


def test_downstream_belief_diff():
    """Compare backbone outputs with the original Conv3d vs the patched Linear patch embed.

    Runs ``model.model`` on the same batch of 4 real ADAS-TO validation
    samples twice:
      * Path A: the stock Conv3d patch embed.
      * Path B: ``Qwen3VLVisionPatchEmbed.forward`` patched to convert
        ``proj`` to a Linear on first call.
    The last hidden states are then compared per token and after mean
    pooling per sample.

    The class-level monkeypatch is always reverted before returning, so
    later code in the process sees the original forward.

    Returns:
        True if every sample's mean-pooled cosine similarity is > 0.99,
        else False.
    """
    n_samples = 4  # number of samples fed through both models

    print("\n[Test 3] Full vision tower forward, Conv3d vs Linear")
    proc = AutoProcessor.from_pretrained(
        "checkpoints/VLA/qwen3vl4b_cot_belief_perframe/best")
    ds = PolicyDataset(
        manifests=["data/policy_labels/val.json"],
        split="val", n_frames=8, sampling="last_biased", source_filter="all",
    )
    # Load only the samples that are actually used.
    imgs = [
        _load_frames(ds.samples[i]["source_dir"],
                     ds.samples[i]["frame_indices"], n_frames=8)
        for i in range(n_samples)
    ]

    # ── Path A: original Conv3d ──
    print("\n loading model A (Conv3d, original)...")
    model_a = _load_merged_vlm()
    inputs = M._build_inputs(proc, imgs, [{}] * n_samples, resize_short=336)
    inputs_g = {k: (v.cuda() if isinstance(v, torch.Tensor) else v)
                for k, v in inputs.items()}
    inputs_g["pixel_values"] = inputs_g["pixel_values"].to(torch.bfloat16)
    keys = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    args = {k: inputs_g[k] for k in keys if k in inputs_g}
    print(" running Conv3d forward (will be slow ~70s)...")
    h_a = _last_hidden_state(model_a, args)
    print(f" Conv3d hidden shape: {tuple(h_a.shape)}")
    del model_a
    torch.cuda.empty_cache()

    # ── Path B: patched Linear ──
    print("\n loading model B (Linear, patched)...")

    def _fast_forward(self, hidden_states):
        # Lazily swap the Conv3d for an equivalent Linear on first call,
        # then treat each flattened patch as a Linear input row.
        target_dtype = self.proj.weight.dtype
        if isinstance(self.proj, nn.Conv3d):
            self.proj = conv3d_to_linear(self.proj)
            print(f" [patched] Conv3d β Linear at first call")
        if hidden_states.dim() > 2 or hidden_states.shape[-1] != self.proj.in_features:
            hidden_states = hidden_states.reshape(-1, self.proj.in_features)
        return self.proj(hidden_states.to(dtype=target_dtype))

    original_forward = Qwen3VLVisionPatchEmbed.forward
    Qwen3VLVisionPatchEmbed.forward = _fast_forward
    try:
        model_b = _load_merged_vlm()
        print(" running Linear forward (fast)...")
        h_b = _last_hidden_state(model_b, args)
        print(f" Linear hidden shape: {tuple(h_b.shape)}")
        del model_b
        torch.cuda.empty_cache()
    finally:
        # Undo the class-level patch so it does not leak beyond this test.
        Qwen3VLVisionPatchEmbed.forward = original_forward

    # Explicit raise (not assert) so the check survives `python -O`.
    if h_a.shape != h_b.shape:
        raise AssertionError("shapes differ!")
    abs_diff = (h_a - h_b).abs()
    rel_diff = abs_diff / (h_a.abs().clamp_min(1e-3))
    print(f"\n per-token hidden state diff:")
    print(f" max abs: {abs_diff.max().item():.2e}")
    print(f" mean abs: {abs_diff.mean().item():.2e}")
    print(f" mean rel: {rel_diff.mean().item():.2%}")

    # Cosine similarity per (batch, token) position.
    h_a_flat = h_a.reshape(-1, h_a.shape[-1])
    h_b_flat = h_b.reshape(-1, h_b.shape[-1])
    cos = torch.nn.functional.cosine_similarity(h_a_flat, h_b_flat, dim=-1)
    print(f"\n per-token cosine similarity:")
    print(f" mean: {cos.mean().item():.6f}")
    print(f" min: {cos.min().item():.6f}")
    print(f" median: {cos.median().item():.6f}")

    # Mean-pool per sample over dim 1 (all positions). NOTE(review): the
    # attention_mask is not applied, so padding positions (if the batch is
    # padded) are included. Both paths use identical inputs, so the
    # comparison is still like-for-like.
    h_a_pool = h_a.mean(dim=1)  # (B, D)
    h_b_pool = h_b.mean(dim=1)
    pool_cos = torch.nn.functional.cosine_similarity(h_a_pool, h_b_pool, dim=-1)
    print(f"\n per-sample MEAN-POOLED belief cosine similarity:")
    for i, c in enumerate(pool_cos.tolist()):
        print(f" sample {i}: {c:.8f}")
    print(f" mean: {pool_cos.mean().item():.8f}")
    if pool_cos.min().item() > 0.99:
        print(f"\n β DOWNSTREAM IMPACT NEGLIGIBLE (pooled cos > 0.99)")
        return True
    print(f"\n β οΈ pooled cosine < 0.99 β investigate before using")
    return False
def main():
    """Run the three Conv3d-vs-Linear checks in order and print a PASS/FAIL summary.

    Test 1 and Test 2 share a freshly initialised Conv3d with the same
    shape as Qwen3-VL-4B's patch_embed. Test 3 loads the real checkpoint.
    """
    rule = "=" * 70
    print(rule)
    print("Verify Conv3d β Linear correctness for Qwen3VLVisionPatchEmbed")
    print(rule)

    # Randomly initialised stand-in with the patch-embed's geometry.
    conv = nn.Conv3d(
        in_channels=3,
        out_channels=1024,
        kernel_size=(2, 16, 16),
        stride=(2, 16, 16),
        bias=True,
    )

    # List elements are evaluated left to right, so the tests run in order.
    outcomes = [
        (" Test 1 (fp32 math equivalence): ", test_fp32_equivalence(conv)),
        (" Test 2 (bf16 cosine sim): ", test_bf16_relative(conv)),
        (" Test 3 (downstream belief sim): ", test_downstream_belief_diff()),
    ]

    print("\n" + rule)
    print("SUMMARY:")
    for label, passed in outcomes:
        verdict = "PASS" if passed else "FAIL"
        print(label + verdict)

    everything_passed = all(passed for _, passed in outcomes)
    if not everything_passed:
        print("\n β οΈ at least one check failed; review before using.")
    else:
        print("\n βββ Linear replacement is SAFE for inference.")
# Script entry point: run all checks only when executed directly, not on import.
if __name__ == "__main__":
    main()
|