## Summary
`transformer_engine.pytorch.GroupedLinear` is ~2× slower than a Python loop of `F.linear` calls (against the same stacked weight tensor) on Mixtral-8x7B at EP=2 / seq=8192 across batch sizes 1-16. nsys traces show the cuBLAS grouped-matmul path firing ~6× more `nvjet_*` GEMM launches plus `splitKreduce` + buffer-fill scaffolding. Filing as an investigation thread.
Context: tutorial work in #2642 (`docs/examples/te_mixtral/`).
## Step-time gap
8× B300, BF16, EP=2 (DP=4, 4 experts/rank), seq=8192, 200 timed steps + 10 warmup. Same model, same stacked weights; only the per-expert FFN dispatch differs.
| Batch | Loop (`F.linear` × 4) | `GroupedLinear` | Ratio |
|---|---|---|---|
| 1 | 240 ms | 449 ms | 1.87× |
| 2 | 242 ms | 474 ms | 1.96× |
| 4 | 242 ms | 511 ms | 2.11× |
| 8 | 306 ms | 617 ms | 2.02× |
| 16 | 506 ms | OOM | — |
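The ratio column is just grouped step time over loop step time; a quick arithmetic check of the table:

```python
# Sanity-check the ratio column: grouped step time / loop step time.
# Numbers copied from the table above (ms per step).
loop_ms    = {1: 240, 2: 242, 4: 242, 8: 306}
grouped_ms = {1: 449, 2: 474, 4: 511, 8: 617}

for batch in loop_ms:
    ratio = grouped_ms[batch] / loop_ms[batch]
    print(f"batch={batch:>2}: {ratio:.2f}x")
# batch= 1: 1.87x ... batch= 8: 2.02x
```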
## Profile (nsys, batch=2, BF16, 30 timed steps)
`nsys stats --report cuda_gpu_kern_sum` on `nsys_tier{2,3}_batch2.nsys-rep`:
| Kernel | Loop instances | `GroupedLinear` instances |
|---|---|---|
| `nvjet_sm103_tst_*` BF16 GEMM (sum) | 10,693 | 66,083 (6.2×) |
| `cublasLt::splitKreduce_kernel` | 0 | 48,823 |
| `FillFunctor<bf16>` (zero-init) | (not in top-15) | 138,296 |
| `CUDAFunctor_add` | (not in top-15) | 98,560 |
| Total GPU kernel time (capture) | ~70 s | ~138 s (1.99×) |
NCCL `SendRecv` is actually faster on the grouped path (37 s vs 46 s); the ~2× wall-time gap comes entirely from the additional GPU compute kernels.
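The launch-count blowup in the table can be quantified directly from the instance counts:

```python
# Kernel instance counts from the nsys summary above.
loop_gemm, grouped_gemm = 10_693, 66_083
print(f"GEMM launch ratio: {grouped_gemm / loop_gemm:.1f}x")  # ~6.2x

# Kernels that only appear on the grouped path.
extra = 48_823 + 138_296 + 98_560  # splitKreduce + zero-init fills + adds
print(f"extra non-GEMM launches: {extra:,}")  # 285,679
```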
## Working hypothesis
cuBLAS Lt selects a split-K plan for (M ≈ 4096, K = 4096, N = 14336) grouped GEMMs and decomposes each grouped matmul into many sub-tile launches plus reduction and buffer-fill kernels. Plain `F.linear` at the same shape runs as a single non-split-K `nvjet` launch.
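One way to sanity-check whether split-K should ever win here: at M ≈ 4096 and N = 14336, a single GEMM already yields far more output tiles than any GPU has SMs, so split-K's extra parallelism buys nothing while its reduction and zero-init kernels cost real time. A rough sketch (the tile size and SM count are illustrative assumptions, not measured values):

```python
# Does a (M=4096, K=4096, N=14336) GEMM need split-K to fill the GPU?
# Assume a typical 128x128 output tile; the SM count is an illustrative
# placeholder, not a measured B300 value.
M, K, N = 4096, 4096, 14336
TILE_M = TILE_N = 128   # assumed output-tile size
sm_count = 148          # placeholder SM count

ctas = (M // TILE_M) * (N // TILE_N)
print(f"output tiles per GEMM: {ctas}")              # 32 * 112 = 3584
print(f"waves at {sm_count} SMs: {ctas / sm_count:.1f}")
```

Even at a quarter of these dimensions the tile count exceeds the SM count, which is why split-K looks unnecessary for these shapes.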
## Repro (in PR #2642)
```bash
cd docs/examples/te_mixtral

# Loop (Tier 2)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 2 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
  --warmup-steps 10 --train-steps 200

# GroupedLinear (Tier 3)
torchrun --standalone --nproc_per_node=8 run_finetune_ep.py \
  --improvement 3 --ep-size 2 --batch-size 2 --max-seq-length 8192 \
  --warmup-steps 10 --train-steps 200
```
Code path: `te_mixtral.py:551-580` selects `expert_ffn_mode = "grouped" | "loop"`. Both paths use the same stacked weight tensor.
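For readers without the PR handy, the two modes differ only in how the stacked expert weights are applied; a minimal NumPy sketch of the idea (toy shapes, not the TE API):

```python
import numpy as np

rng = np.random.default_rng(0)
num_experts, m_per_expert, k, n = 4, 8, 16, 32

x = rng.standard_normal((num_experts, m_per_expert, k))  # tokens grouped by expert
w = rng.standard_normal((num_experts, k, n))             # stacked expert weights

# "loop" mode: one matmul per expert against a slice of the stacked weights
y_loop = np.stack([x[e] @ w[e] for e in range(num_experts)])

# "grouped" mode: a single batched matmul over the same stacked tensor
y_grouped = x @ w

assert np.allclose(y_loop, y_grouped)  # identical math, different dispatch
```

The numerics are identical either way; the question in this issue is purely about how cuBLAS schedules the grouped dispatch.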
## Environment
- 8× NVIDIA B300 SXM
- NGC `pytorch-25.12-py3` container (CUDA 13.1)
- torch 2.10, transformer_engine 2.10
- Mixtral-8x7B v0.1, BF16 mixed precision
## Questions
- Is split-K the intended cuBLAS plan for (M ≈ 4k, K = 4k, N = 14k) grouped GEMMs?
- Could `GroupedLinear` set a `cublasLtMatmulPreference` to discourage split-K when per-expert M is already large?
- Has anyone benchmarked `GroupedLinear` vs a per-expert `F.linear` loop at Mixtral FFN shapes?
Happy to share full nsys traces (~500 MB each) on request. Tagging for visibility.