Qwen3-VL Vision Patch Embedding: 1000× Slowdown from nn.Conv3d on Blackwell GPUs
Author: Anonymous · Date: 2026-05-03
Status: confirmed bug · workaround validated · upstream patch proposed
Component: transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionPatchEmbed
TL;DR
Qwen3VLVisionPatchEmbed.forward runs at ~16 seconds per call for a single
8-frame video clip on RTX 5090 (Blackwell, sm_120) with PyTorch 2.9 +
CUDA 12.8 + cuDNN 9.10.0.2 + bf16. The bottleneck is a single nn.Conv3d op
whose kernel_size == stride == [2, 16, 16] configuration falls into a
degenerate cuDNN slow-path. Replacing it with a mathematically equivalent
nn.Linear makes it run in ~0.3 ms — a >50,000× speedup on the
isolated layer, and ~64× end-to-end on the full vision tower forward.
This bug makes large-scale belief-cache extraction effectively impossible:
extracting features for 29,169 multisrc-val samples would have taken
~6 days with Conv3d, but completes in ~2 hours with the Linear
replacement. Mathematical equivalence is proven and downstream belief
cosine similarity > 0.99.
1. Environment
Python: 3.14.0
PyTorch: 2.9.0+cu128
CUDA: 12.8
cuDNN: 9.10.0.2 (91002)
transformers: 5.0.0.dev0
flash-attn: 2.8.3 (installed)
GPU: NVIDIA GeForce RTX 5090 (Blackwell, compute capability 12.0)
OS: Linux-6.8.0-110-generic-x86_64-with-glibc2.39
Hardware: 32 GB VRAM, 24 CPU cores, 62 GB RAM.
2. The buggy implementation
File:
~/miniconda3/envs/lkalert/lib/python3.14/site-packages/
transformers/models/qwen3_vl/modeling_qwen3_vl.py
Lines 59–76:
```python
# (module-level imports, not part of lines 59–76)
import torch
from torch import nn


class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size                    # 16
        self.temporal_patch_size = config.temporal_patch_size  # 2
        self.in_channels = config.in_channels                  # 3
        self.embed_dim = config.hidden_size                    # 1024

        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        # ▼ The slow op:
        self.proj = nn.Conv3d(
            self.in_channels, self.embed_dim,
            kernel_size=kernel_size, stride=kernel_size, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size,
            self.patch_size, self.patch_size,
        )
        hidden_states = self.proj(
            hidden_states.to(dtype=target_dtype)
        ).view(-1, self.embed_dim)
        return hidden_states
```
The convolution has kernel_size == stride, no padding, no dilation.
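Applying the output-size formula from the PyTorch Conv3d documentation to the shapes that `forward` actually passes in shows why this Conv3d is degenerate. Each row of `pixel_values` is viewed as its own (3, 2, 16, 16) volume, so Conv3d receives a batch of 6,080 inputs that are each exactly one kernel in size. The sketch below (plain Python; `conv_out_len` is our own helper name) computes the resulting output extent. We have not verified whether this huge-batch, 1×1×1-output shape is specifically what misleads cuDNN's algorithm choice.

```python
def conv_out_len(n_in: int, kernel: int, stride: int, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one axis, per the shape formula in the PyTorch Conv3d docs."""
    return (n_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# forward() views each 1536-value row as one (C=3, T=2, H=16, W=16) volume, so the
# spatial extent of every input along (T, H, W) equals the kernel size.
kernel = stride = (2, 16, 16)
out_shape = tuple(conv_out_len(k, k, s) for k, s in zip(kernel, stride))

# grid_thw from Appendix A: 8 entries of [1, 20, 38] -> number of rows in pixel_values
n_rows = sum(t * h * w for t, h, w in [(1, 20, 38)] * 8)

print(n_rows, out_shape)  # 6080 (1, 1, 1)
```

With a 1×1×1 output per input, each of the 6,080 "convolutions" reduces to one dot product per output channel.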
3. Discovery timeline
The slowdown was found while attempting to extract per-frame Qwen3-VL-4B
belief features for the LKAlert paper's multisrc-val evaluation set
(29,169 samples). The end-to-end extraction script
[training/Policy/make_cot_belief_cache.py] was running at 138 seconds per
DataLoader iteration with --batch_size 8, projecting to 5–6 days of
wall-clock time. Profiling proceeded in five stages.
Stage 1 — confirm GPU is healthy
Pure matmul benchmark on RTX 5090:
matmul 4096x4096: 0.8 ms total/10, 182.3 TFLOPS
matmul 8192x8192: 4.9 ms total/10, 223.7 TFLOPS
Times are per matmul (the total over 10 calls divided by 10). The hardware delivers ~200 TFLOPS in bf16, which is within spec, so the GPU itself is fine.
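The TFLOPS column is consistent with the millisecond column only if the latter is the mean time of a single matmul. Here is a quick check that counts 2·n³ FLOPs per n×n matmul (`matmul_tflops` is our own helper name):

```python
def matmul_tflops(n: int, ms_per_matmul: float) -> float:
    """Achieved TFLOPS for one (n x n) @ (n x n) matmul, counting 2*n**3 FLOPs (multiply + add)."""
    return 2 * n**3 / (ms_per_matmul * 1e-3) / 1e12


print(f"{matmul_tflops(8192, 4.9):.1f}")    # 224.4, matching the reported 223.7 within rounding of 4.9 ms
print(f"{matmul_tflops(4096, 0.754):.1f}")  # 182.3, so the reported figure implies ~0.754 ms
```

The 0.8 ms shown for the 4096 case is therefore a rounded value.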
Stage 2 — eliminate batching as the cause
Tested forward time at multiple batch sizes:
| batch_size | total time | per-sample | seq_len | VRAM |
|---|---|---|---|---|
| 1 | 16.5 s | 16.5 s | 1653 | 9.7 GB |
| 4 | 65.3 s | 16.3 s | 2133 | 10.0 GB |
| 8 | 148 s | 18.5 s | 2133 | 10.0 GB |
| 16 | 145 s | 9.3 s | 2133 | 10.0 GB |
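Per-sample time stays at ~16–18 s from batch 1 through batch 8, meaning total time grows linearly with batch size. This rules out a DataLoader, collate, or padding bug. At batch 16 the total time plateaus at roughly the batch-8 value (~145 s), which suggests the bottleneck is per-token rather than per-sample.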
Stage 3 — eliminate attention as the cause
Tested all three attn_implementation settings on Qwen3-VL:
| attn_implementation | bs=1 forward | bs=8 forward |
|---|---|---|
| eager | 17.1 s | n/a |
| sdpa | 16.5 s | 145.6 s |
| flash_attention_2 | 16.5 s | 147.6 s |
All three are identically slow. A monkey-patch replacing
Qwen3VLVisionAttention.forward with a clean SDPA implementation also gave
no speedup (still ~150 s at bs=8). Attention is not the bottleneck.
Stage 4 — granular component timing
Per-component timing of Qwen3VLVisionModel.forward for bs=1 (8 frames,
6080 visual patches):
```text
patch_embed:                  16,111.3 ms   ← ~99% of forward time
pos_embed_interpolate:            22.8 ms
rot_pos_emb:                      20.7 ms
block[0]:                         23.4 ms   (warmup)
block[1..23] (23 layers):          1.4 ms each
block ALL total (24 layers):      56.4 ms   ← entire transformer is fast
merger:                            0.5 ms
─────────────────────────────────────
TOTAL                           ≈ 16,212 ms
```
The 24-layer ViT transformer takes 56 ms in total. The single Conv3d patch projection takes 16,111 ms. That is about 286× the time of all 24 transformer blocks combined, and about 160× the time of everything else in the vision tower (≈100 ms).
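The shares can be recomputed from the Stage 4 numbers. The values below are copied from the dump above; nothing new is measured here.

```python
timings_ms = {
    "patch_embed": 16_111.3,
    "pos_embed_interpolate": 22.8,
    "rot_pos_emb": 20.7,
    "blocks (24 layers)": 56.4,
    "merger": 0.5,
}

total = sum(timings_ms.values())
patch = timings_ms["patch_embed"]
rest = total - patch  # everything in the vision tower except patch_embed

share = patch / total
vs_blocks = patch / timings_ms["blocks (24 layers)"]
vs_rest = patch / rest

print(f"total {total:.1f} ms, patch_embed share {share:.1%}")            # total 16211.7 ms, patch_embed share 99.4%
print(f"{vs_blocks:.0f}x the 24 blocks, {vs_rest:.0f}x everything else")  # 286x the 24 blocks, 160x everything else
```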
Stage 5 — pinpoint the slow op
Source inspection of Qwen3VLVisionPatchEmbed.proj reveals
nn.Conv3d(3, 1024, kernel=[2,16,16], stride=[2,16,16]). With
stride == kernel, this convolution has zero overlap between output
positions. Each output element is a function of exactly one disjoint
3-channel × 2-frame × 16×16-pixel window — i.e., a per-window dot product.
This is mathematically a flatten + linear projection, not a true 3-D convolution.
4. Root-cause analysis
Why the cuDNN path is slow
cuDNN's forward-convolution dispatch does not appear to special-case kernel_size == stride && dilation == 1 && padding == 0. For typical 3-D convolutions (overlapping kernels, e.g. video models) this does not matter, because cuDNN picks implicit-GEMM-style algorithms that exploit spatial reuse.
For the patchification case (no spatial reuse), cuDNN still goes through the full 3-D path. At the time of writing, on Blackwell (sm_120), this path appears to fall back to a generic, unfused, non-tensor-core kernel for bf16 with tiny kernels. We did not bisect to the exact kernel name. Still, the empirical >50,000× slowdown of the isolated layer versus its Linear equivalent is consistent with "loops + scalar ops" rather than "tensor-core GEMM".
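A back-of-the-envelope FLOP count makes the "scalar loops" reading concrete. The patch projection is a (6080 × 1536) @ (1536 × 1024) matmul, and dividing its FLOPs by the two measured times gives the achieved throughput. This is a consistency check under the stated shapes. It is not evidence of which kernel cuDNN actually ran.

```python
n_patches = 6080                 # visual patches for one 8-frame clip (Appendix A)
in_dim = 3 * 2 * 16 * 16         # in_channels * k_t * k_h * k_w = 1536
out_dim = 1024                   # embed_dim
flops = 2 * n_patches * in_dim * out_dim  # one multiply + one add per weight per patch

conv_s = 16.1113                 # measured Conv3d patch_embed time (Stage 4)
linear_s = 0.3e-3                # measured nn.Linear time (isolated micro-benchmark)

print(f"work:   {flops / 1e9:.1f} GFLOP")                  # work:   19.1 GFLOP
print(f"Conv3d: {flops / conv_s / 1e9:.2f} GFLOP/s")       # Conv3d: 1.19 GFLOP/s
print(f"Linear: {flops / linear_s / 1e12:.1f} TFLOP/s")    # Linear: 63.8 TFLOP/s
```

About 1.2 GFLOP/s is roughly five orders of magnitude below the ~200 TFLOPS measured in Stage 1. By contrast, the Linear path reaches a plausible fraction of peak for a skinny GEMM.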
Layered responsibility
| Layer | Role in the slowdown | Who could fix it, and how |
|---|---|---|
| HuggingFace transformers (Qwen3-VL design) | Source: uses nn.Conv3d for a non-convolutional op | transformers: replace with nn.Linear (small PR, see section 8) |
| cuDNN 9.10.0.2 | Yes: slow path for stride==kernel Conv3d on sm_120 + bf16 | NVIDIA |
| PyTorch 2.9 | Dispatcher does not short-circuit stride==kernel | PyTorch team: route stride==kernel to bmm/Linear in the dispatcher |
Most pragmatic fix: change one line in transformers.
Why this wasn't noticed earlier
- The same pattern exists in Qwen2-VL and Qwen2.5-VL (same nn.Conv3d design). Earlier extractions on these checkpoints may have run on Hopper (sm_90) or older cuDNN, where the slow path didn't trigger. Alternatively, they may have completed despite being slow because the datasets were smaller.
- Earlier Qwen3-VL extractions in this repo (DAD test = 466 samples, DADA test = 1001 samples) did run at 16 s/sample. The user simply waited 2–4 hours per extraction without noticing the inefficiency. The bug only became blocking when extracting 29,169 multisrc samples.
- Standard ImageNet ViT benchmarks use Conv2d (not Conv3d) for patch embed; Qwen-VL is unusual in needing a 3-D op (because of the temporal patch dimension).
5. Mathematical equivalence proof
Claim
For an nn.Conv3d configured with kernel_size = stride (and padding = 0,
dilation = 1, groups = 1), the operation is exactly equivalent to:
y = x.flatten() @ W.flatten().T + b
where W.flatten() reshapes the convolution kernel from
(out_dim, in_C, k_t, k_h, k_w) to (out_dim, in_C·k_t·k_h·k_w) in
row-major (C-style) order, and x.flatten() similarly reshapes the input
patch.
Proof
nn.Conv3d defines, for output position (t', h', w'):
y[k, t', h', w'] = b[k] + Σ_{c, dt, dh, dw} W[k, c, dt, dh, dw] · x[c, s_t·t' + dt, s_h·h' + dh, s_w·w' + dw]
with s_t, s_h, s_w the strides and dt, dh, dw ranging over the kernel
extents [0, k_t), [0, k_h), [0, k_w).
When s_t = k_t, s_h = k_h, s_w = k_w (the patchification case), the input
windows for distinct output positions are disjoint:
window(t') = [t'·k_t, (t'+1)·k_t) non-overlapping
window(h') = [h'·k_h, (h'+1)·k_h) non-overlapping
window(w') = [w'·k_w, (w'+1)·k_w) non-overlapping
For each disjoint window, the convolution output is exactly the dot product between the flattened window contents and the flattened kernel:
y[k, t', h', w'] = b[k] + Σ_{c, dt, dh, dw}
W[k, c, dt, dh, dw]
· x[c, t'·k_t + dt, h'·k_h + dh, w'·k_w + dw]
= b[k] + ⟨ flatten(W[k]) , flatten(window(t', h', w')) ⟩
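If we reshape the input tensor so that each disjoint window is a row, this is literally nn.Linear's definition:
y = x_flat @ W_flat.T + b    where W_flat = W.reshape(out_dim, -1)
                             and x_flat is the (N_patches, in_C·k_t·k_h·k_w) matrix with one window per row
so y has shape (N_patches, out_dim), with row p holding the outputs at one position (t', h', w').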
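The flattening order must be consistent on both sides. A row-major .reshape() / .view() of the kernel (with no permutation) flattens it in (c, dt, dh, dw) order. In Qwen3-VL the processor already delivers pixel_values as (N_patches, 1536) rows. The original forward's .view(-1, in_C, k_t, k_h, k_w) interprets each of those rows in exactly the same (c, dt, dh, dw) order. A single .reshape(out_dim, -1) of the kernel and .reshape(N, -1) of the input therefore give the equivalence. ∎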
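The proof can be checked numerically on small random tensors. The sketch below uses NumPy rather than PyTorch so that it runs anywhere. `conv3d_naive` evaluates the Conv3d definition term by term. `patchify_linear` cuts the volume into disjoint windows, flattens them row-major, and applies x_flat @ W_flat.T + b. The shapes are toy values (not Qwen's 3×2×16×16) so the Python loops stay fast, and both helper names are ours.

```python
import numpy as np


def conv3d_naive(x, weight, bias, stride):
    """Conv3d straight from its definition (padding 0, dilation 1, groups 1), with explicit loops.

    x: (C, T, H, W) volume; weight: (K, C, k_t, k_h, k_w); bias: (K,).
    """
    K, C, kt, kh, kw = weight.shape
    st, sh, sw = stride
    _, T, H, W = x.shape
    To, Ho, Wo = (T - kt) // st + 1, (H - kh) // sh + 1, (W - kw) // sw + 1
    y = np.empty((K, To, Ho, Wo))
    for t in range(To):
        for h in range(Ho):
            for v in range(Wo):
                window = x[:, t * st:t * st + kt, h * sh:h * sh + kh, v * sw:v * sw + kw]
                # y[k, t', h', w'] = b[k] + sum over (c, dt, dh, dw) of W[k, c, dt, dh, dw] * window
                y[:, t, h, v] = bias + np.tensordot(weight, window, axes=4)
    return y


def patchify_linear(x, weight, bias):
    """Same result when stride == kernel (T, H, W divisible by the kernel): one window per row, then x_flat @ W_flat.T + b."""
    K, C, kt, kh, kw = weight.shape
    _, T, H, W = x.shape
    To, Ho, Wo = T // kt, H // kh, W // kw
    # Split each axis into (window index, offset in window), then bring window indices to the front.
    windows = x.reshape(C, To, kt, Ho, kh, Wo, kw).transpose(1, 3, 5, 0, 2, 4, 6)
    x_flat = windows.reshape(To * Ho * Wo, C * kt * kh * kw)  # rows in (c, dt, dh, dw) order
    w_flat = weight.reshape(K, -1)                            # same (c, dt, dh, dw) order
    y = x_flat @ w_flat.T + bias                              # nn.Linear: (N_patches, K)
    return y.T.reshape(K, To, Ho, Wo)                         # back to Conv3d's (K, T', H', W') layout


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 8, 12))         # (C, T, H, W)
weight = rng.standard_normal((5, 3, 2, 4, 4))  # (K, C, k_t, k_h, k_w)
bias = rng.standard_normal(5)

ref = conv3d_naive(x, weight, bias, stride=(2, 4, 4))
fast = patchify_linear(x, weight, bias)
print(ref.shape, np.abs(ref - fast).max())  # shape (5, 2, 2, 3); the difference is float64 round-off
```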
Implementation
```python
import torch
from torch import nn


def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear:
    """Build a mathematically equivalent Linear for a Conv3d with stride == kernel."""
    # The equivalence only holds for disjoint, unpadded, undilated, ungrouped windows.
    assert tuple(conv.stride) == tuple(conv.kernel_size), "requires stride == kernel_size"
    assert tuple(conv.padding) == (0, 0, 0) and tuple(conv.dilation) == (1, 1, 1)
    assert conv.groups == 1
    out_dim = conv.out_channels
    in_dim = (conv.in_channels * conv.kernel_size[0]
              * conv.kernel_size[1] * conv.kernel_size[2])
    new = nn.Linear(in_dim, out_dim, bias=conv.bias is not None,
                    device=conv.weight.device, dtype=conv.weight.dtype)
    with torch.no_grad():
        # Conv3d weight: (out, in_C, k_t, k_h, k_w) → row-major flatten to (out, in_dim)
        new.weight.copy_(conv.weight.reshape(out_dim, in_dim))
        if conv.bias is not None:
            new.bias.copy_(conv.bias)
    return new
```
6. Verification
6.1 Numerical equivalence
Three tests defined in
tools/verify_patch_embed_correctness.py:
| Test | Tolerance | Result | What it proves |
|---|---|---|---|
| fp32 math equivalence | max abs diff < 1e-5 | < 1e-7 (typical) | Conv3d ≡ Linear up to fp32 round-off |
| bf16 numerical noise | cosine sim > 0.999 | ~0.9995 | bf16 accumulation noise is bounded |
| Downstream belief output (after 24-layer ViT) | per-sample pooled cos > 0.99 | > 0.999 | head receives indistinguishable features |
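For readers reproducing these checks, the two kinds of metric in the table can be sketched as follows. This is not the code in tools/verify_patch_embed_correctness.py. In particular, mean pooling over tokens is our assumption about what "pooled" means.

```python
import numpy as np


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Largest elementwise |a - b|."""
    return float(np.max(np.abs(a - b)))


def pooled_cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of the mean-pooled (tokens, dim) features of one sample."""
    pa, pb = a.mean(axis=0), b.mean(axis=0)
    return float(pa @ pb / (np.linalg.norm(pa) * np.linalg.norm(pb)))


rng = np.random.default_rng(0)
ref = rng.standard_normal((760, 1024))               # one sample: tokens x hidden dim
noisy = ref + 1e-3 * rng.standard_normal(ref.shape)  # small perturbation standing in for bf16 noise
print(max_abs_diff(ref, noisy), pooled_cosine(ref, noisy))
```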
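The bf16 max absolute difference of 1.56e-2 on the patch_embed output alone is ordinary rounding noise. The value is exactly 2⁻⁶, which is one bf16 ulp for values of magnitude 2–4, so a difference of this size is what a single last-bit rounding disagreement produces at that scale. It is also far below a crude random-walk estimate for a 1536-term dot product with per-term bf16 rounding: √1536 · 2⁻⁷ ≈ 0.31, relative to the typical term magnitude. This is consistent with accumulation in fp32 (as the tensor-core FMA GEMM behind nn.Linear does for bf16) followed by rounding to bf16 only at the output.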
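The ulp and rounding figures above can be checked without a GPU. The sketch below emulates bfloat16 round-to-nearest-even by rounding the fp32 bit pattern to its top 16 bits, which is a standard trick. The helper names are ours, and the sketch handles finite, nonzero inputs only.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even) via its fp32 bit pattern. NaN is not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # x as fp32 bits (already rounded to fp32)
    lsb = (bits >> 16) & 1                               # lowest bit that bf16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000            # round to nearest even, drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bf16 values near nonzero x (8 significant bits)."""
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1, so the binary exponent is e - 1
    return 2.0 ** (e - 1 - 7)


print(bf16_ulp(3.0))                      # 0.015625 (= 2**-6, the observed max diff)
print(to_bf16(3.01))                      # 3.015625, the nearest bf16 value
print(round(math.sqrt(1536) * 2**-7, 2))  # 0.31
```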
6.2 End-to-end speedup
Benchmark on RTX 5090, single 8-frame video clip (6080 visual patches at short-edge 336):
| forward | bs=1 | bs=8 | bs=16 | end-to-end (29,169 samples) |
|---|---|---|---|---|
| Conv3d (current) | 16.5 s | 150 s | 145 s | ~6 days |
| Linear (patched) | 0.27 s | 2.16 s | (TBD) | ~2.2 hours |
| Speedup | 61× | 70× | — | ~65× |
Patch-embed micro-benchmark (just the layer in isolation):
| | Conv3d | Linear | Speedup |
|---|---|---|---|
| time per forward | 16,111 ms | 0.3 ms | >50,000× |
7. Workaround code
The following workaround lives in
tools/run_qwen3_cache_fast.py in this repository:
```python
import torch
import torch.nn as nn
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed


def _fast_patch_embed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Lazy in-place replacement: the first call swaps Conv3d → Linear, then
    runs the equivalent flat-projection forward."""
    target_dtype = self.proj.weight.dtype
    if isinstance(self.proj, nn.Conv3d):
        # First call on this instance: convert in place
        conv = self.proj
        out_dim = conv.out_channels
        in_dim = (conv.in_channels * conv.kernel_size[0]
                  * conv.kernel_size[1] * conv.kernel_size[2])
        w_flat = conv.weight.detach().reshape(out_dim, in_dim).contiguous()
        bias = conv.bias.detach().clone() if conv.bias is not None else None
        new_proj = nn.Linear(in_dim, out_dim, bias=bias is not None)
        new_proj.weight.data.copy_(w_flat)
        if bias is not None:
            new_proj.bias.data.copy_(bias)
        new_proj.to(device=conv.weight.device, dtype=conv.weight.dtype)
        self.proj = new_proj  # in-place attribute swap
    # self.proj is now nn.Linear; route through it
    if hidden_states.dim() > 2 or hidden_states.shape[-1] != self.proj.in_features:
        hidden_states = hidden_states.reshape(-1, self.proj.in_features)
    return self.proj(hidden_states.to(dtype=target_dtype))


# Apply class-level patch BEFORE any model is instantiated
Qwen3VLVisionPatchEmbed.forward = _fast_patch_embed_forward
```
Apply once at process start; the lazy in-place conversion is triggered
on the first forward of each Qwen3VLVisionPatchEmbed instance.
Properties
- No model weight modification: the weight values are preserved exactly and the checkpoint on disk is untouched. Only the type of self.proj changes (Conv3d → Linear) at inference time. As a consequence, a state_dict saved after conversion holds proj.weight as a 2-D (out, 1536) tensor rather than in the 5-D Conv3d shape.
- No effect on training: the patch is only applied in our inference pipeline.
- Idempotent: re-applying does nothing, because the isinstance check skips conversion when self.proj is already nn.Linear.
- Resumable: make_cot_belief_cache.py writes per-chunk .pt files, so a crashed run can resume.
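The lazy, idempotent swap relies on two Python behaviours. First, methods are looked up on the class at call time, so patching the class also affects instances created earlier. Second, the isinstance guard makes the conversion one-shot per instance. The toy below demonstrates both with hypothetical stand-in classes (no torch). One caveat we have not verified for this pipeline: wrappers that replace an instance's forward at load time (for example, accelerate hooks installed when loading with device_map) capture the bound method that existed at that moment. That is one reason to apply the patch before the model is created.

```python
class SlowProj:
    """Hypothetical stand-in for nn.Conv3d (plain Python, no torch)."""
    def __init__(self, weight):
        self.weight = weight


class FastProj:
    """Hypothetical stand-in for the equivalent nn.Linear."""
    def __init__(self, weight):
        self.weight = weight


class PatchEmbed:
    """Hypothetical stand-in for Qwen3VLVisionPatchEmbed."""
    def __init__(self, weight):
        self.proj = SlowProj(weight)

    def forward(self, x):
        return ("slow", x)


conversion_log = []  # ids of instances that were converted


def fast_forward(self, x):
    if isinstance(self.proj, SlowProj):         # only true on the first call per instance
        self.proj = FastProj(self.proj.weight)  # reuse the same weights, swap the op
        conversion_log.append(id(self))
    return ("fast", x)


early = PatchEmbed([1.0, 2.0])     # created before the patch
PatchEmbed.forward = fast_forward  # class-level patch, as in the workaround
late = PatchEmbed([3.0])           # created after the patch

results = [early.forward(1), early.forward(2), late.forward(3)]
print(results)              # [('fast', 1), ('fast', 2), ('fast', 3)]
print(len(conversion_log))  # 2
```

Each instance is converted exactly once, and repeat calls skip the conversion.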
8. Proposed upstream fix
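Replacing the Conv3d with an nn.Linear in transformers/models/qwen3_vl/modeling_qwen3_vl.py takes a few lines in __init__ and forward, and removes the slowdown for all users of Qwen3-VL. Outputs are unchanged up to bf16 rounding noise (section 6.1), provided that legacy checkpoints are reshaped on load as described in the backward-compatibility note below: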
```diff
 class Qwen3VLVisionPatchEmbed(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
         self.patch_size = config.patch_size
         self.temporal_patch_size = config.temporal_patch_size
         self.in_channels = config.in_channels
         self.embed_dim = config.hidden_size
-        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
-        self.proj = nn.Conv3d(
-            self.in_channels, self.embed_dim,
-            kernel_size=kernel_size, stride=kernel_size, bias=True,
-        )
+        in_dim = (self.in_channels * self.temporal_patch_size
+                  * self.patch_size * self.patch_size)
+        self.proj = nn.Linear(in_dim, self.embed_dim, bias=True)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         target_dtype = self.proj.weight.dtype
-        hidden_states = hidden_states.view(
-            -1, self.in_channels, self.temporal_patch_size,
-            self.patch_size, self.patch_size,
-        )
-        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+        hidden_states = hidden_states.reshape(-1, self.proj.in_features).to(dtype=target_dtype)
+        hidden_states = self.proj(hidden_states)
         return hidden_states
```
Backward-compatibility note for upstream maintainers
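The change must also keep existing pretrained checkpoints loading. Those checkpoints store proj.weight in the Conv3d shape (out, in, k_t, k_h, k_w), while the Linear layer expects (out, in·k_t·k_h·k_w). A _load_from_state_dict override that performs the same reshape handles nn.Module.load_state_dict. Maintainers should confirm that from_pretrained also routes through it, since fast or meta-device loading paths may assign tensors without calling per-module hooks. If it does not, the same reshape must be applied in the checkpoint-conversion logic instead. The hook: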
```python
class Qwen3VLVisionPatchEmbed(nn.Module):
    ...  # __init__ / forward as in the diff above

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Backward compat: reshape the 5-D Conv3d weight found in legacy checkpoints
        key = prefix + "proj.weight"
        if key in state_dict and state_dict[key].dim() == 5:
            out_dim = state_dict[key].shape[0]
            state_dict[key] = state_dict[key].reshape(out_dim, -1)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
```
This makes the upstream patch transparent to all existing
Qwen3-VL-*-Instruct checkpoints on the HuggingFace hub.
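If the loading path bypasses per-module hooks, a one-off conversion of the checkpoint tensors achieves the same result. The sketch below assumes the relevant keys end in "patch_embed.proj.weight"; check this against the actual checkpoint. The example key prefix and the helper name `upgrade_patch_embed_weight` are illustrative only.

```python
import numpy as np


def upgrade_patch_embed_weight(state_dict, suffix="patch_embed.proj.weight"):
    """Return a copy of state_dict with legacy 5-D Conv3d patch-embed weights reshaped to 2-D.

    Works for any array type with .ndim/.shape/.reshape (NumPy arrays, torch tensors).
    Entries that are already 2-D, and all other keys, pass through unchanged.
    """
    upgraded = {}
    for key, value in state_dict.items():
        if key.endswith(suffix) and value.ndim == 5:
            # (out, in_C, k_t, k_h, k_w) -> (out, in_C*k_t*k_h*k_w), row-major as in the proof
            upgraded[key] = value.reshape(value.shape[0], -1)
        else:
            upgraded[key] = value
    return upgraded


legacy = {
    "visual.patch_embed.proj.weight": np.zeros((1024, 3, 2, 16, 16), dtype=np.float32),
    "visual.patch_embed.proj.bias": np.zeros(1024, dtype=np.float32),
}
new = upgrade_patch_embed_weight(legacy)
print(new["visual.patch_embed.proj.weight"].shape)  # (1024, 1536)
```

The bias already has the Linear shape (out,), so it passes through untouched. Running the function twice is harmless.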
9. Reproduction recipe
Profilers used in discovery (in this repo):
```text
tools/profile_qwen3_cache.py               # forward speed at multiple bs
tools/profile_qwen3_attn.py                # tests sdpa/flash/eager
tools/profile_qwen3_breakdown.py           # processor / xfer / fwd timing
tools/profile_qwen3_visionfix.py           # forces attn on every block
tools/profile_qwen3_monkeypatch.py         # replaces vision attention forward
tools/profile_qwen3_per_layer.py           # ★ identifies patch_embed as bottleneck
tools/profile_qwen3_patchembed_fix.py      # ★ confirms Linear fix gives 64× speedup
tools/verify_patch_embed_correctness.py    # ★ fp32 + bf16 + downstream verification
tools/run_qwen3_cache_fast.py              # production launcher with the patch
```
Reproduction (~30 s):
```shell
cd PROJECT_ROOT
python -u tools/profile_qwen3_per_layer.py
# Expected: patch_embed: ~16,000 ms; all 24 transformer blocks: ~50 ms
```
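The profiler scripts are not reproduced here. Below is a minimal sketch of the timing pattern a per-layer profiler needs, with warmup and an optional synchronization callback. The helper `time_ms` is hypothetical. On GPU one would pass torch.cuda.synchronize, because CUDA kernels launch asynchronously.

```python
import time
from statistics import median
from typing import Callable, Optional


def time_ms(fn: Callable[[], object], reps: int = 3, warmup: int = 1,
            sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):  # first calls pay one-off costs (cf. block[0] at 23.4 ms vs ~1.4 ms)
        fn()
    if sync is not None:
        sync()               # make sure warmup work has finished before timing starts
    samples = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()           # wait for queued work so it is counted in this sample
        samples.append((time.perf_counter() - start) * 1e3)
    return median(samples)


print(f"{time_ms(lambda: sum(range(100_000))):.2f} ms")
```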
10. Impact summary
For LKAlert paper §5 main table (multisrc-val binary_AP for v3-pomdp-v2):
- Without this fix: infeasible (~6 days wall-clock, exceeds paper deadline)
- With this fix: ~2 hours wall-clock for a 29,169-sample feature cache
- Verified equivalent: downstream belief cosine sim > 0.999
For the broader community: anyone running Qwen3-VL inference in bf16 on an RTX 5090 with this software stack is silently paying a ~50,000× cost on the patch projection. The same may be true on other Blackwell GPUs, which we have not tested. A small upstream PR (section 8) would resolve this.
Appendix A: full per-layer timing dump (bs=1)
```text
[device check] ✓ all submodules on cuda
[prep inputs bs=1]
  pixel_values: (6080, 1536)   # 8 frames × 760 patches × 1536 features
  grid_thw: (8, 3), values:
    [[1, 20, 38], [1, 20, 38], ..., [1, 20, 38]]
  vision tower has 24 blocks
[component timing]
  patch_embed:            16111.3 ms   ⚠️ the bug
  pos_embed_interpolate:     22.8 ms
  rot_pos_emb:               20.7 ms
  block[0]:                  23.4 ms   (warmup)
  block[1]:                   1.5 ms
  block[2]:                   1.4 ms
  block[23]:                  1.4 ms
  block 0-2 mean:             8.8 ms
  block ALL mean:             2.3 ms
  block ALL total:           56.4 ms
  merger:                     0.5 ms
[zoom: block[0] attn vs mlp]
  attn (3 reps): 2.4 ms total = 0.8 ms/call
  mlp  (3 reps): 1.8 ms total = 0.6 ms/call
```
Appendix B: per-batch-size scaling
Pre-fix (nn.Conv3d):
| bs | total time | per-sample | seq_len | VRAM |
|---|---|---|---|---|
| 1 | 16.7 s | 16.7 s | 1653 | 9.7 GB |
| 4 | 65.3 s | 16.3 s | 2133 | 10.0 GB |
| 8 | 148 s | 18.5 s | 2133 | 10.0 GB |
| 16 | 145 s | 9.3 s | 2133 | 10.0 GB |
Post-fix (nn.Linear):
| bs | total time | per-sample |
|---|---|---|
| 1 | 0.27 s | 0.27 s |
| 8 | 2.16 s | 0.27 s |
Linear keeps a constant ~0.27 s/sample across batch sizes. We attribute the remaining time mainly to tokenization + GPU transfer rather than to the vision tower itself.
Appendix C: related code paths in this repo
The slowdown affects two existing scripts in our codebase that build Qwen3-VL belief caches; both should be migrated to use the workaround:
- training/Policy/make_cot_belief_cache.py: main belief cache builder
- training/Policy/make_belief_cache_v2.py: older variant
To run cached extraction with the fix today, use
tools/run_qwen3_cache_fast.py instead, which applies the monkey-patch
before importing the cache builder. The CLI surface is identical.