
Qwen3-VL Vision Patch Embedding: 1000× Slowdown from nn.Conv3d on Blackwell GPUs

Author: Anonymous
Date: 2026-05-03
Status: confirmed bug · workaround validated · upstream patch proposed
Component: transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionPatchEmbed


TL;DR

Qwen3VLVisionPatchEmbed.forward runs at ~16 seconds per call for a single 8-frame video clip on RTX 5090 (Blackwell, sm_120) with PyTorch 2.9 + CUDA 12.8 + cuDNN 9.10.0.2 + bf16. The bottleneck is a single nn.Conv3d op whose kernel_size == stride == [2, 16, 16] configuration falls into a degenerate cuDNN slow-path. Replacing it with a mathematically equivalent nn.Linear makes it run in ~0.3 ms — a >50,000× speedup on the isolated layer, and ~64× end-to-end on the full vision tower forward.

This bug makes large-scale belief-cache extraction effectively impossible: extracting features for 29,169 multisrc-val samples would have taken ~6 days with Conv3d, but completes in ~2 hours with the Linear replacement. The two layers are mathematically equivalent (Section 5), and the downstream belief features agree with cosine similarity > 0.99 (Section 6.1).
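
The wall-clock figures above follow from simple arithmetic. The sketch below assumes a constant per-sample cost and uses the bs=1 timings reported in Section 6.2 (16.5 s with Conv3d, 0.27 s with Linear); the function name is illustrative only.

```python
# Wall-clock projection for the 29,169-sample extraction, assuming the
# per-sample cost stays constant over the whole run.
N_SAMPLES = 29_169
SECONDS_PER_SAMPLE_CONV3D = 16.5   # bs=1 forward with nn.Conv3d (Section 6.2)
SECONDS_PER_SAMPLE_LINEAR = 0.27   # bs=1 forward with nn.Linear (Section 6.2)


def projected_hours(n_samples: int, seconds_per_sample: float) -> float:
    """Total wall-clock hours for n_samples at a fixed per-sample cost."""
    return n_samples * seconds_per_sample / 3600.0


conv3d_hours = projected_hours(N_SAMPLES, SECONDS_PER_SAMPLE_CONV3D)
linear_hours = projected_hours(N_SAMPLES, SECONDS_PER_SAMPLE_LINEAR)

print(f"Conv3d: {conv3d_hours:.1f} h ({conv3d_hours / 24:.1f} days)")
print(f"Linear: {linear_hours:.1f} h")
print(f"Ratio:  {conv3d_hours / linear_hours:.0f}x")
```

This gives roughly 5.6 days versus 2.2 hours, matching the "~6 days" and "~2 hours" estimates.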


1. Environment

Python:        3.14.0
PyTorch:       2.9.0+cu128
CUDA:          12.8
cuDNN:         9.10.0.2 (91002)
transformers:  5.0.0.dev0
flash-attn:    2.8.3 (installed)
GPU:           NVIDIA GeForce RTX 5090 (Blackwell, compute capability 12.0)
OS:            Linux-6.8.0-110-generic-x86_64-with-glibc2.39

Hardware: 32 GB VRAM, 24 CPU cores, 62 GB RAM.


2. The buggy implementation

File:

~/miniconda3/envs/lkalert/lib/python3.14/site-packages/
    transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines 59–76:

class Qwen3VLVisionPatchEmbed(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size                  # 16
        self.temporal_patch_size = config.temporal_patch_size  # 2
        self.in_channels = config.in_channels                 # 3
        self.embed_dim = config.hidden_size                   # 1024

        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
        # ▼ The slow op:
        self.proj = nn.Conv3d(
            self.in_channels, self.embed_dim,
            kernel_size=kernel_size, stride=kernel_size, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size,
            self.patch_size, self.patch_size,
        )
        hidden_states = self.proj(
            hidden_states.to(dtype=target_dtype)
        ).view(-1, self.embed_dim)
        return hidden_states

The convolution has kernel_size == stride, no padding, no dilation.
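
The condition that matters here can be checked programmatically. The helper below is a hypothetical sketch (the name `is_patchify_conv` is not part of transformers or PyTorch); it returns True only when a Conv3d reads disjoint windows and can therefore be treated as a per-window linear projection.

```python
import torch.nn as nn


def is_patchify_conv(conv: nn.Conv3d) -> bool:
    """True when the conv reads disjoint windows: kernel == stride,
    zero padding, no dilation, a single group."""
    return (
        tuple(conv.kernel_size) == tuple(conv.stride)
        and tuple(conv.padding) == (0, 0, 0)   # string paddings are treated as "no"
        and tuple(conv.dilation) == (1, 1, 1)
        and conv.groups == 1
    )


# Same configuration as Qwen3VLVisionPatchEmbed.proj
patch_conv = nn.Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
# A typical overlapping video convolution, for contrast
overlapping_conv = nn.Conv3d(3, 8, kernel_size=3, stride=1, padding=1)

print(is_patchify_conv(patch_conv))
print(is_patchify_conv(overlapping_conv))
```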


3. Discovery timeline

The slowdown was found while attempting to extract per-frame Qwen3-VL-4B belief features for the LKAlert paper's multisrc-val evaluation set (29,169 samples). The end-to-end extraction script [training/Policy/make_cot_belief_cache.py] was running at 138 seconds per DataLoader iteration with --batch_size 8, projecting to 5–6 days of wall-clock time. Profiling proceeded in five stages.

Stage 1 — confirm GPU is healthy

Pure matmul benchmark on RTX 5090:

matmul 4096x4096:  0.8 ms total/10,  182.3 TFLOPS
matmul 8192x8192:  4.9 ms total/10,  223.7 TFLOPS

Hardware delivers ~200 TFLOPs bf16 — within spec. GPU is fine.
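
A health check of this kind can be sketched as follows. This is not the script used for the numbers above; it is a minimal stand-in that times square matmuls and converts them to TFLOPS, falling back to a small fp32 problem on CPU so that it runs anywhere.

```python
import time
from typing import Optional

import torch


def matmul_tflops(n: int, iters: int = 10, device: Optional[str] = None) -> float:
    """Time `iters` n x n matmuls and return the achieved TFLOPS."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    a = torch.randn(n, n, device=device, dtype=dtype)
    b = torch.randn(n, n, device=device, dtype=dtype)

    for _ in range(3):          # warm-up, excluded from timing
        a @ b
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock

    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    flops = 2 * n**3 * iters    # one multiply and one add per inner-product term
    return flops / elapsed / 1e12


size = 4096 if torch.cuda.is_available() else 512
print(f"matmul {size}x{size}: {matmul_tflops(size):.2f} TFLOPS")
```

The explicit `torch.cuda.synchronize()` calls matter: CUDA kernels launch asynchronously, so timing without them measures launch overhead rather than compute.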

Stage 2 — eliminate batching as the cause

Tested forward time at multiple batch sizes:

| batch_size | total time | per-sample | seq_len | VRAM |
|---|---|---|---|---|
| 1 | 16.5 s | 16.5 s | 1653 | 9.7 GB |
| 4 | 65.3 s | 16.3 s | 2133 | 10.0 GB |
| 8 | 148 s | 18.5 s | 2133 | 10.0 GB |
| 16 | 145 s | 9.3 s | 2133 | 10.0 GB |

Per-sample time stays at ~16–18 s for batch sizes 1 through 8, ruling out a DataLoader, collate, or padding bug. At batch size 16 the total time plateaus at ~145 s, about the same as batch size 8, suggesting the bottleneck is per-token rather than per-sample.

Stage 3 — eliminate attention as the cause

Tested all three attn_implementation settings on Qwen3-VL:

| attn_implementation | bs=1 forward | bs=8 forward |
|---|---|---|
| eager | 17.1 s | |
| sdpa | 16.5 s | 145.6 s |
| flash_attention_2 | 16.5 s | 147.6 s |

All three are identically slow. A monkey-patch replacing Qwen3VLVisionAttention.forward with a clean SDPA implementation also gave no speedup (still ~150 s at bs=8). Attention is not the bottleneck.

Stage 4 — granular component timing

Per-component timing of Qwen3VLVisionModel.forward for bs=1 (8 frames, 6080 visual patches):

patch_embed:            16,111.3 ms   ← ~99% of the vision-tower total
pos_embed_interpolate:      22.8 ms
rot_pos_emb:                20.7 ms
block[0]:                   23.4 ms   (warmup)
block[1..23] (23 layers):    1.4 ms each
block ALL total (24 layers):56.4 ms   ← entire transformer is fast
merger:                      0.5 ms
─────────────────────────────────────
TOTAL                  ≈ 16,212 ms

The 24-layer ViT transformer takes 56 ms total. The single Conv3d patch projection takes 16,111 ms: about 286× the time of all 24 transformer blocks, and about 160× the rest of the vision tower combined (≈100 ms).
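
Per-component timings like these can be collected with forward hooks. The sketch below is an assumption about how such a profiler might look, not the contents of tools/profile_qwen3_per_layer.py; `attach_timers` and the toy model are hypothetical, and the submodule names would need to match the real vision tower.

```python
import time
from collections import OrderedDict, defaultdict

import torch
import torch.nn as nn


def attach_timers(model: nn.Module, names: list) -> dict:
    """Record the wall-clock time (ms) of every forward of the named submodules."""
    timings = defaultdict(list)
    starts = {}

    def sync():
        # CUDA launches are asynchronous; synchronize so the clock sees real work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for name in names:
        module = model.get_submodule(name)

        def pre_hook(mod, args, _name=name):
            sync()
            starts[_name] = time.perf_counter()

        def post_hook(mod, args, output, _name=name):
            sync()
            timings[_name].append((time.perf_counter() - starts[_name]) * 1e3)

        module.register_forward_pre_hook(pre_hook)
        module.register_forward_hook(post_hook)
    return timings


# Toy stand-in for the vision tower: a "patch_embed" followed by one "block".
toy = nn.Sequential(OrderedDict(
    patch_embed=nn.Linear(1536, 64),
    block=nn.Linear(64, 64),
))
timings = attach_timers(toy, ["patch_embed", "block"])
with torch.no_grad():
    toy(torch.randn(32, 1536))

for name, values in timings.items():
    print(f"{name}: {values[0]:.3f} ms")
```

Synchronizing around every hooked module serializes the GPU, so the absolute numbers are slightly pessimistic, but a single layer that dominates the total still stands out clearly.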

Stage 5 — pinpoint the slow op

Source inspection of Qwen3VLVisionPatchEmbed.proj reveals nn.Conv3d(3, 1024, kernel=[2,16,16], stride=[2,16,16]). With stride == kernel, this convolution has zero overlap between output positions. Each output element is a function of exactly one disjoint 3-channel × 2-frame × 16×16-pixel window — i.e., a per-window dot product.

This is mathematically a flatten + linear projection, not a true 3-D convolution.


4. Root-cause analysis

Why the cuDNN path is slow

cuDNN's convolution_forward dispatcher does not detect the special case kernel_size == stride && dilation == 1 && padding == 0. For typical 3D convolutions (overlapping kernels, e.g. video models), this is fine — cuDNN selects implicit-GEMM or Winograd algorithms tuned for spatial reuse.

For the patchification case (no spatial reuse), cuDNN still goes through the full 3-D path. On Blackwell (sm_120) at the time of writing, this path appears to fall back to a generic, unfused, non-tensor-core kernel for bf16 + tiny kernels. We did not bisect to the exact kernel name, but the measured slowdown of more than 1000× relative to the Linear equivalent is consistent with "loops + scalar ops" rather than "tensor-core GEMM".

Layered responsibility

| Layer | Has bug? | Could fix? |
|---|---|---|
| HuggingFace transformers (Qwen3-VL design) | Source: chose nn.Conv3d for a non-convolutional op | Replace with nn.Linear (1-line PR) |
| cuDNN 9.10.0.2 | Yes: slow path for stride==kernel Conv3d on sm_120 + bf16 | NVIDIA |
| PyTorch 2.9 | Could short-circuit stride==kernel to bmm/Linear in dispatcher | PyTorch team |

Most pragmatic fix: replace the Conv3d with a Linear in transformers (Section 8).

Why this wasn't noticed earlier

  1. The same pattern exists in Qwen2-VL and Qwen2.5-VL (same nn.Conv3d design). Earlier extractions on these checkpoints may have run on Hopper (sm_90) or older cuDNN, where the slow path didn't trigger, or completed despite being slow because dataset sizes were smaller.
  2. Earlier Qwen3-VL extractions in this repo (DAD test = 466 samples, DADA test = 1001 samples) did run at 16 s/sample — the user simply waited 2–4 hours per extraction without noticing the inefficiency. The bug only became blocking when extracting 29,169 multisrc samples.
  3. Standard ImageNet ViT benchmarks use Conv2d (not Conv3d) for patch embed; Qwen-VL is unusual in needing a 3-D op (because of the temporal patch dimension).

5. Mathematical equivalence proof

Claim

For an nn.Conv3d configured with kernel_size = stride (and padding = 0, dilation = 1, groups = 1), the operation is exactly equivalent to:

y = x.flatten() @ W.flatten().T + b

where W.flatten() reshapes the convolution kernel from (out_dim, in_C, k_t, k_h, k_w) to (out_dim, in_C·k_t·k_h·k_w) in row-major (C-style) order, and x.flatten() similarly reshapes the input patch.

Proof

nn.Conv3d defines, for output position (t', h', w'):

y[k, t', h', w'] = b[k] + Σ_{c, dt, dh, dw}  W[k, c, dt, dh, dw] · x[c, s_t·t' + dt, s_h·h' + dh, s_w·w' + dw]

with s_t, s_h, s_w the strides and dt, dh, dw ranging over the kernel extents [0, k_t), [0, k_h), [0, k_w).

When s_t = k_t, s_h = k_h, s_w = k_w (the patchification case), the input windows for distinct output positions are disjoint:

window(t') = [t'·k_t, (t'+1)·k_t)        non-overlapping
window(h') = [h'·k_h, (h'+1)·k_h)        non-overlapping
window(w') = [w'·k_w, (w'+1)·k_w)        non-overlapping

For each disjoint window, the convolution output is exactly the dot product between the flattened window contents and the flattened kernel:

y[k, t', h', w']  =  b[k]  +  Σ_{c, dt, dh, dw}
                                W[k, c, dt, dh, dw]
                              · x[c, t'·k_t + dt, h'·k_h + dh, w'·k_w + dw]

                  =  b[k]  +  ⟨ flatten(W[k]) , flatten(window(t', h', w')) ⟩

If we reshape the input tensor so that each disjoint window is a row, this is literally nn.Linear's definition:

y = x_flat @ W_flat.T + b            where W_flat = W.reshape(out_dim, -1)
                                            x_flat = x.reshape(N_patches, -1)

The flattening order must be consistent on both sides. PyTorch's default row-major (.reshape() / .view() without permutation) preserves (c, dt, dh, dw) ordering on both W and x, so a single .reshape(out_dim, -1) of the kernel and .reshape(N, -1) of the input gives the equivalence. ∎
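
The argument can be checked numerically on a full input volume. The sketch below uses small stand-in sizes (not the real 3 × 2 × 16 × 16 → 1024 configuration) and cuts a (C, T, H, W) tensor into disjoint windows with reshape and permute before applying the flattened kernel. In the actual model the input already arrives as one row per window, as the (6080, 1536) pixel_values shape in Appendix A suggests, so this step is only needed for the demonstration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, kt, kh, kw, out_dim = 3, 2, 4, 4, 8   # small stand-ins for 3, 2, 16, 16, 1024
T, H, W = 4, 8, 12                       # multiples of the kernel extents
conv = nn.Conv3d(C, out_dim, kernel_size=(kt, kh, kw), stride=(kt, kh, kw))

x = torch.randn(1, C, T, H, W)
nt, nh, nw = T // kt, H // kh, W // kw

with torch.no_grad():
    y_conv = conv(x)                     # (1, out_dim, nt, nh, nw)

    # Split every axis into (window index, offset inside the window), then move
    # the window indices to the front so each row is one window flattened in
    # (c, dt, dh, dw) order, the same order as the flattened kernel.
    windows = (
        x.reshape(C, nt, kt, nh, kh, nw, kw)
         .permute(1, 3, 5, 0, 2, 4, 6)   # (nt, nh, nw, C, kt, kh, kw)
         .reshape(nt * nh * nw, C * kt * kh * kw)
    )
    w_flat = conv.weight.reshape(out_dim, -1)
    y_lin = F.linear(windows, w_flat, conv.bias)          # (N_patches, out_dim)

    # Bring the conv output into the same (N_patches, out_dim) layout.
    y_ref = y_conv[0].permute(1, 2, 3, 0).reshape(-1, out_dim)
    max_diff = (y_lin - y_ref).abs().max().item()

print(f"max abs diff: {max_diff:.2e}")
```

If the permutation is left out (a plain reshape of the whole volume), the rows no longer correspond to windows and the outputs disagree, which is the "consistent flattening order" requirement in practice.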

Implementation

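import torch
import torch.nn as nn


def conv3d_to_linear(conv: nn.Conv3d) -> nn.Linear:
    """Build a mathematically equivalent Linear for a Conv3d with stride == kernel."""
    # The equivalence only holds for disjoint windows (see the Claim above).
    if (tuple(conv.stride) != tuple(conv.kernel_size)
            or tuple(conv.padding) != (0, 0, 0)
            or tuple(conv.dilation) != (1, 1, 1) or conv.groups != 1):
        raise ValueError("requires kernel_size == stride, padding 0, dilation 1, groups 1")
    out_dim = conv.out_channels
    in_dim = (conv.in_channels * conv.kernel_size[0]
              * conv.kernel_size[1] * conv.kernel_size[2])
    # Allocate the Linear directly on the conv's device and dtype.
    new = nn.Linear(in_dim, out_dim, bias=conv.bias is not None,
                    device=conv.weight.device, dtype=conv.weight.dtype)
    with torch.no_grad():
        # Conv3d weight: (out, in_C, k_t, k_h, k_w) -> row-major flatten to (out, in_dim)
        new.weight.copy_(conv.weight.reshape(out_dim, in_dim))
        if conv.bias is not None:
            new.bias.copy_(conv.bias)
    return new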

6. Verification

6.1 Numerical equivalence

Three tests defined in tools/verify_patch_embed_correctness.py:

| Test | Tolerance | Result | What it proves |
|---|---|---|---|
| fp32 math equivalence | max abs diff < 1e-5 | < 1e-7 (typical) | Conv3d ≡ Linear up to fp32 round-off |
| bf16 numerical noise | cosine sim > 0.999 | ~0.9995 | bf16 accumulation noise is bounded |
| Downstream belief output (after 24-layer ViT) | per-sample pooled cos > 0.99 | > 0.999 | head receives indistinguishable features |
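
The first two checks can be approximated in a few lines. This is a simplified sketch, not the contents of tools/verify_patch_embed_correctness.py: it uses an embedding width of 64 instead of 1024, random inputs instead of real frames, and compares the bf16 Linear against the fp32 reference (rather than against a bf16 Conv3d) so that it also runs on CPU.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
IN_DIM, EMBED = 3 * 2 * 16 * 16, 64      # 1536 inputs; 64 stands in for 1024

conv = nn.Conv3d(3, EMBED, kernel_size=(2, 16, 16), stride=(2, 16, 16))
linear = nn.Linear(IN_DIM, EMBED)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(EMBED, -1))
    linear.bias.copy_(conv.bias)

patches = torch.randn(128, IN_DIM)       # (N_patches, 1536), like pixel_values

with torch.no_grad():
    # Test 1: fp32 equivalence of Conv3d and Linear.
    y_conv = conv(patches.view(-1, 3, 2, 16, 16)).view(-1, EMBED)
    y_lin = linear(patches)
    fp32_max_abs_diff = (y_conv - y_lin).abs().max().item()

    # Test 2 (variant): bf16 Linear versus the fp32 reference.
    linear_bf16 = copy.deepcopy(linear).to(torch.bfloat16)
    y_bf16 = linear_bf16(patches.to(torch.bfloat16)).float()
    cosine = F.cosine_similarity(y_lin.flatten(), y_bf16.flatten(), dim=0).item()

print(f"fp32 max abs diff:      {fp32_max_abs_diff:.2e}")
print(f"bf16 vs fp32 cosine sim: {cosine:.6f}")
```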

The bf16 absolute difference of 1.56e-2 observed on the patch_embed output alone is well below the rough random-walk estimate sqrt(N_inputs) · ε_bf16 ≈ √1536 · 2⁻⁷ ≈ 0.3 for naive accumulation in bf16; nn.Linear's use of fma + tensor cores keeps the error well inside that bound.
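
As a worked check of that estimate, the short calculation below evaluates √N · ε for the patch-embed dimensions, taking ε_bf16 = 2⁻⁷ (the spacing of bf16 values just above 1.0). The random-walk model is a rough heuristic, not a rigorous bound.

```python
import math

N_INPUTS = 3 * 2 * 16 * 16        # 1536 terms in each dot product
EPS_BF16 = 2.0 ** -7              # bf16 machine epsilon (8-bit significand)
OBSERVED_ABS_DIFF = 1.56e-2       # value reported for the patch_embed output

random_walk_estimate = math.sqrt(N_INPUTS) * EPS_BF16
print(f"sqrt(N) * eps = {random_walk_estimate:.3f}")
print(f"observed / estimate = {OBSERVED_ABS_DIFF / random_walk_estimate:.3f}")
```

The estimate evaluates to about 0.31, so the observed difference is roughly 5% of it.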

6.2 End-to-end speedup

Benchmark on RTX 5090, single 8-frame video clip (6080 visual patches at short-edge 336):

| forward | bs=1 | bs=8 | bs=16 | end-to-end (29,169 samples) |
|---|---|---|---|---|
| Conv3d (current) | 16.5 s | 150 s | 145 s | ~6 days |
| Linear (patched) | 0.27 s | 2.16 s | (TBD) | ~2.2 hours |
| Speedup | 61× | 70× | | ~65× |

Patch-embed micro-benchmark (just the layer in isolation):

| | Conv3d | Linear | speedup |
|---|---|---|---|
| time per forward | 16,111 ms | 0.3 ms | >50,000× |
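
An isolated-layer comparison along these lines can be sketched as below. It is an assumed reconstruction, not the code in tools/profile_qwen3_patchembed_fix.py; on a machine without CUDA it shrinks the problem and runs in fp32, in which case the timings say nothing about the Blackwell slow path and only confirm that the harness works.

```python
import time

import torch
import torch.nn as nn


def time_ms(fn, iters: int = 3) -> float:
    """Average wall-clock milliseconds per call, after one warm-up call."""
    fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1e3 / iters


on_gpu = torch.cuda.is_available()
device = "cuda" if on_gpu else "cpu"
dtype = torch.bfloat16 if on_gpu else torch.float32
n_patches = 6080 if on_gpu else 64       # 6080 matches the 8-frame clip above
embed_dim = 1024 if on_gpu else 32
in_dim = 3 * 2 * 16 * 16

conv = nn.Conv3d(3, embed_dim, kernel_size=(2, 16, 16), stride=(2, 16, 16))
conv = conv.to(device=device, dtype=dtype)
linear = nn.Linear(in_dim, embed_dim).to(device=device, dtype=dtype)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(embed_dim, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(n_patches, in_dim, device=device, dtype=dtype)

with torch.no_grad():
    conv_ms = time_ms(lambda: conv(x.view(-1, 3, 2, 16, 16)).view(-1, embed_dim))
    linear_ms = time_ms(lambda: linear(x))

print(f"Conv3d: {conv_ms:.2f} ms   Linear: {linear_ms:.2f} ms   "
      f"ratio: {conv_ms / linear_ms:.1f}x")
```

On an affected GPU the Conv3d line alone can take on the order of a minute with these settings (one warm-up plus three timed calls at ~16 s each).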

7. Workaround code

The following workaround is in tools/run_qwen3_cache_fast.py at this repository:

import torch
import torch.nn as nn
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed


def _fast_patch_embed_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Lazy in-place replacement: first call swaps Conv3d → Linear, then
    runs the equivalent flat-projection forward."""
    target_dtype = self.proj.weight.dtype

    if isinstance(self.proj, nn.Conv3d):
        # First call on this instance: convert in place
        conv = self.proj
        out_dim = conv.out_channels
        in_dim = (conv.in_channels * conv.kernel_size[0]
                  * conv.kernel_size[1] * conv.kernel_size[2])
        w_flat = conv.weight.detach().reshape(out_dim, in_dim).contiguous()
        bias = conv.bias.detach().clone() if conv.bias is not None else None
        new_proj = nn.Linear(in_dim, out_dim, bias=bias is not None)
        new_proj.weight.data.copy_(w_flat)
        if bias is not None:
            new_proj.bias.data.copy_(bias)
        new_proj.to(device=conv.weight.device, dtype=conv.weight.dtype)
        self.proj = new_proj  # in-place attribute swap

    # self.proj is now nn.Linear; route through it
    if hidden_states.dim() > 2 or hidden_states.shape[-1] != self.proj.in_features:
        hidden_states = hidden_states.reshape(-1, self.proj.in_features)
    return self.proj(hidden_states.to(dtype=target_dtype))


# Apply class-level patch BEFORE any model is instantiated
Qwen3VLVisionPatchEmbed.forward = _fast_patch_embed_forward

Apply once at process start; the lazy in-place conversion is triggered on the first forward of each Qwen3VLVisionPatchEmbed instance.
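
The mechanism can be exercised without loading Qwen3-VL. The sketch below uses a hypothetical `ToyPatchEmbed` that mirrors the structure of the real class with small dimensions; it shows that replacing `forward` on the class also affects instances created earlier, that the first call swaps Conv3d for Linear, and that later calls reuse the Linear.

```python
import torch
import torch.nn as nn


class ToyPatchEmbed(nn.Module):
    """Hypothetical stand-in with the same layout as Qwen3VLVisionPatchEmbed."""

    def __init__(self, in_channels=3, t=2, p=4, embed_dim=8):
        super().__init__()
        self.in_channels, self.t, self.p, self.embed_dim = in_channels, t, p, embed_dim
        k = (t, p, p)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=k, stride=k)

    def forward(self, x):
        x = x.view(-1, self.in_channels, self.t, self.p, self.p)
        return self.proj(x).view(-1, self.embed_dim)


def fast_forward(self, x):
    """Lazy swap on first call, then a plain flat projection."""
    if isinstance(self.proj, nn.Conv3d):
        conv = self.proj
        linear = nn.Linear(conv.weight[0].numel(), conv.out_channels,
                           bias=conv.bias is not None,
                           device=conv.weight.device, dtype=conv.weight.dtype)
        with torch.no_grad():
            linear.weight.copy_(conv.weight.reshape(conv.out_channels, -1))
            if conv.bias is not None:
                linear.bias.copy_(conv.bias)
        self.proj = linear
    return self.proj(x.reshape(-1, self.proj.in_features))


torch.manual_seed(0)
model = ToyPatchEmbed()                  # created before the patch is applied
x = torch.randn(5, 3 * 2 * 4 * 4)
with torch.no_grad():
    y_before = model(x)                  # original Conv3d path

ToyPatchEmbed.forward = fast_forward     # class-level patch
with torch.no_grad():
    y_after = model(x)                   # first call converts in place
    y_again = model(x)                   # second call skips the conversion

print(type(model.proj).__name__, torch.allclose(y_before, y_after, atol=1e-5))
```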

Properties

  • No change to weight values: the parameters are preserved exactly; only the module type and the shape of proj.weight change (Conv3d, 5-D → Linear, 2-D) at inference time, so a state_dict saved after the first forward holds the reshaped tensor.
  • No effect on training: the patch is only applied in our inference pipeline.
  • Idempotent: re-applying does nothing (the isinstance check skips conversion when self.proj is already nn.Linear).
  • Resumable: make_cot_belief_cache.py writes per-chunk .pt files, so a crashed run can resume.

8. Proposed upstream fix

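Replacing the Conv3d construction and the reshape in forward in transformers/models/qwen3_vl/modeling_qwen3_vl.py removes the slowdown for all users of Qwen3-VL without changing the computed outputs (the checkpoint weight shape is handled in the backward-compatibility note below):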

 class Qwen3VLVisionPatchEmbed(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
         self.patch_size = config.patch_size
         self.temporal_patch_size = config.temporal_patch_size
         self.in_channels = config.in_channels
         self.embed_dim = config.hidden_size

-        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
-        self.proj = nn.Conv3d(
-            self.in_channels, self.embed_dim,
-            kernel_size=kernel_size, stride=kernel_size, bias=True,
-        )
+        in_dim = (self.in_channels * self.temporal_patch_size
+                  * self.patch_size * self.patch_size)
+        self.proj = nn.Linear(in_dim, self.embed_dim, bias=True)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         target_dtype = self.proj.weight.dtype
-        hidden_states = hidden_states.view(
-            -1, self.in_channels, self.temporal_patch_size,
-            self.patch_size, self.patch_size,
-        )
-        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+        hidden_states = hidden_states.reshape(-1, self.proj.in_features).to(dtype=target_dtype)
+        hidden_states = self.proj(hidden_states)
         return hidden_states

Backward-compatibility note for upstream maintainers

The change must also update the state_dict key remapping path so existing pretrained checkpoints (which save weights under the Conv3d shape (out, in, k_t, k_h, k_w)) load correctly into the Linear layer shape (out, in·k_t·k_h·k_w). A _load_from_state_dict hook that does the same reshape should be sufficient for loading paths that go through torch.nn.Module.load_state_dict:

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    # Backward compat: reshape Conv3d weight in legacy checkpoints
    key = prefix + "proj.weight"
    if key in state_dict and state_dict[key].dim() == 5:
        out_dim = state_dict[key].shape[0]
        state_dict[key] = state_dict[key].reshape(out_dim, -1)
    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

This makes the upstream patch transparent to all existing Qwen3-VL-*-Instruct checkpoints on the HuggingFace hub.
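
The hook can be tried in isolation with plain `load_state_dict`. The sketch below uses a hypothetical `LinearPatchEmbed` with small dimensions and a synthetic "legacy" state dict taken from a Conv3d. It does not exercise the transformers `from_pretrained` loading path, which would need to be checked separately.

```python
import torch
import torch.nn as nn


class LinearPatchEmbed(nn.Module):
    """Hypothetical Linear-based patch embed that accepts legacy Conv3d weights."""

    def __init__(self, in_channels=3, t=2, p=4, embed_dim=8):
        super().__init__()
        self.proj = nn.Linear(in_channels * t * p * p, embed_dim, bias=True)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Runs before the child `proj` loads its parameters, so the 5-D
        # Conv3d kernel can be flattened in place to the Linear shape.
        key = prefix + "proj.weight"
        if key in state_dict and state_dict[key].dim() == 5:
            out_dim = state_dict[key].shape[0]
            state_dict[key] = state_dict[key].reshape(out_dim, -1)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, x):
        return self.proj(x.reshape(-1, self.proj.in_features))


torch.manual_seed(0)
legacy_conv = nn.Conv3d(3, 8, kernel_size=(2, 4, 4), stride=(2, 4, 4))
legacy_state = {
    "proj.weight": legacy_conv.weight.detach().clone(),   # shape (8, 3, 2, 4, 4)
    "proj.bias": legacy_conv.bias.detach().clone(),
}

model = LinearPatchEmbed()
model.load_state_dict(legacy_state)      # strict load succeeds via the hook

x = torch.randn(5, 3 * 2 * 4 * 4)
with torch.no_grad():
    y_legacy = legacy_conv(x.view(-1, 3, 2, 4, 4)).view(-1, 8)
    y_new = model(x)

print(tuple(model.proj.weight.shape), torch.allclose(y_legacy, y_new, atol=1e-5))
```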


9. Reproduction recipe

Profilers used in discovery (in this repo):

tools/profile_qwen3_cache.py            # forward speed at multiple bs
tools/profile_qwen3_attn.py             # tests sdpa/flash/eager
tools/profile_qwen3_breakdown.py        # processor / xfer / fwd timing
tools/profile_qwen3_visionfix.py        # forces attn on every block
tools/profile_qwen3_monkeypatch.py      # replaces vision attention forward
tools/profile_qwen3_per_layer.py        # ★ identifies patch_embed as bottleneck
tools/profile_qwen3_patchembed_fix.py   # ★ confirms Linear fix gives 64× speedup
tools/verify_patch_embed_correctness.py # ★ fp32 + bf16 + downstream verification
tools/run_qwen3_cache_fast.py           # production launcher with the patch

Reproduction (~30 s):

cd PROJECT_ROOT
python -u tools/profile_qwen3_per_layer.py
# Expected: patch_embed: ~16,000 ms; all 24 transformer blocks: ~50 ms

10. Impact summary

For LKAlert paper §5 main table (multisrc-val binary_AP for v3-pomdp-v2):

  • Without this fix: infeasible (~6 days wall-clock, exceeds paper deadline)
  • With this fix: ~2 hours wall-clock for a 29,169-sample feature cache
  • Verified equivalent: downstream belief cosine sim > 0.999

For the broader community: anyone running Qwen3-VL inference in bf16 on an RTX 5090 (and possibly other Blackwell GPUs) with a comparable PyTorch/cuDNN stack may be silently paying a >50,000× cost on the patch projection. A small upstream PR (Section 8) would resolve this.


Appendix A: full per-layer timing dump (bs=1)

[device check]   ✓ all submodules on cuda

[prep inputs bs=1]
   pixel_values: (6080, 1536)            # 8 frames × 760 patches × 1536 features
   grid_thw: (8, 3), values:
       [[1, 20, 38], [1, 20, 38], ..., [1, 20, 38]]
   vision tower has 24 blocks

[component timing]
   patch_embed:           16111.3 ms   ⚠️  the bug
   pos_embed_interpolate:    22.8 ms
   rot_pos_emb:              20.7 ms
   block[0]:                 23.4 ms   (warmup)
   block[1]:                  1.5 ms
   block[2]:                  1.4 ms
   block[23]:                 1.4 ms
   block 0-2 mean:            8.8 ms
   block ALL mean:            2.3 ms
   block ALL total:          56.4 ms
   merger:                    0.5 ms

[zoom: block[0] attn vs mlp]
   attn (3 reps): 2.4 ms total = 0.8 ms/call
   mlp  (3 reps): 1.8 ms total = 0.6 ms/call

Appendix B: per-batch-size scaling

Pre-fix (nn.Conv3d):

| bs | total time | per-sample | seq_len | VRAM |
|---|---|---|---|---|
| 1 | 16.7 s | 16.7 s | 1653 | 9.7 GB |
| 4 | 65.3 s | 16.3 s | 2133 | 10.0 GB |
| 8 | 148 s | 18.5 s | 2133 | 10.0 GB |
| 16 | 145 s | 9.3 s | 2133 | 10.0 GB |

Post-fix (nn.Linear):

| bs | total time | per-sample |
|---|---|---|
| 1 | 0.27 s | 0.27 s |
| 8 | 2.16 s | 0.27 s |

Linear keeps a constant ~0.27 s/sample across batch sizes, indicating the remaining time is dominated by tokenization + GPU transfer rather than the vision tower itself.


Appendix C: related code paths in this repo

The slowdown affects two existing scripts in our codebase that build Qwen3-VL belief caches; both should be migrated to use the workaround:

  1. training/Policy/make_cot_belief_cache.py — main belief cache builder
  2. training/Policy/make_belief_cache_v2.py — older variant

To run cached extraction with the fix today, use tools/run_qwen3_cache_fast.py instead, which applies the monkey-patch before importing the cache builder. The CLI surface is identical.