[CI][BugFix] Flash bwd varlen: zero-init lse/Delta padding to avoid NaN in Dk#2461
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughThe varlen flash attention example now explicitly zero-fills padding slots in the forward ChangesVarlen padding writes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
LeiWang1999
left a comment
There was a problem hiding this comment.
Thanks for your contribution. I left a comment. And how do you think that if we avoid out_idx initialize lse and delta and directly initialize with torch.zeros?
Thanks — this is a great question. I think the right way to look at it is to separate two concerns that are easy to confuse: Correctness (kernel side). Initializing I think we could add an API for workspace-style outputs (lse, delta, m_partial, ...) that lets users explicitly specify buffer initialization, e.g. out_init={5: "zeros", 6: "zeros"}, to avoid the potential bugs from the current torch.empty default. |
|
LGTM:) |
The varlen GQA backward example allocates lse and Delta via out_idx, which uses torch.empty under the hood. Their padding rows (i >= q_current_seqlen) were left untouched by the forward / preprocess kernels and could carry NaN/Inf from uninitialized memory.
These dirty values then leak into the backward kernel:
lse is consumed as exp2(qkT*scale - lse_shared[j]); a NaN poisons the whole column of P, dP and downstream dK.
Delta is consumed as qkT_cast * (dsT - delta[j]); even masked-out (qkT=0) columns become NaN through 0 * NaN = NaN, which is then accumulated into dK by the GEMM over padded K columns.
Fix: explicitly write 0.0 to the padding positions of lse (in flash_fwd) and Delta (in flash_bwd_prep), so the backward kernel always reads well-defined values regardless of the host-side allocator.
Verified by examples/flash_attention/test_example_flash_attention.py::test_example_gqa_bwd_tma_reduce_varlen (previously failed with dK == NaN, now passes; tilelang throughput restored from ~1 TFlops to ~7 TFlops on H100).
Summary
0into the padding rows of:lseinflash_fwd(previously auto-allocated viaout_idx/torch.empty, so padding entries could contain NaN/Inf and later contaminateT.exp2(... - lse)because0 * NaN = NaN).Deltainflash_bwd_prep(same padding/uninitialized hazard, now cleared before backward uses it).Validation
examples/flash_attention/test_example_flash_attention.py::test_example_gqa_bwd_tma_reduce_varlen, which previously failed withdK == NaNand now passes.