[CI][BugFix] Flash bwd varlen: zero-init lse/Delta padding to avoid NaN in Dk #2461

Merged
LeiWang1999 merged 1 commit into tile-ai:main from RuneFang:fix_flash_bwd on Jun 26, 2026

@RuneFang RuneFang commented Jun 26, 2026

Contributor

The varlen GQA backward example allocates lse and Delta via out_idx, which uses torch.empty under the hood. Their padding rows (i >= q_current_seqlen) were left untouched by the forward / preprocess kernels and could carry NaN/Inf from uninitialized memory.

These dirty values then leak into the backward kernel:

  • lse is consumed as exp2(qkT*scale - lse_shared[j]); a NaN poisons the whole column of P, dP and downstream dK.

  • Delta is consumed as qkT_cast * (dsT - delta[j]); even masked-out (qkT=0) columns become NaN through 0 * NaN = NaN, which is then accumulated into dK by the GEMM over padded K columns.
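
Both leak paths can be reproduced without a GPU. The following is a minimal pure-Python sketch, not the example's kernel code: the tile sizes, the toy values, and the helper name `backward_dk` are assumptions made only for illustration, and the kernel's masking logic is not modelled. It applies the two formulas quoted above to a tile whose last query column is padding, then reduces over query columns the way the dK GEMM does.

```python
import math

NAN = float("nan")  # stand-in for whatever uninitialized memory happens to hold

# Toy tile: 2 key rows x 3 query columns. Query column j = 2 is padding
# (j >= q_current_seqlen = 2), so its Q value is loaded as 0.
scores = [[0.5, 1.0, 0.0],
          [1.5, 0.25, 0.0]]   # already-scaled q.k scores
dsT = [[0.1, 0.2, 0.0],
       [0.3, 0.4, 0.0]]       # upstream gradient values
q = [1.0, 2.0, 0.0]           # one Q feature per query column


def backward_dk(lse, delta):
    """Return one dK value per key row for the toy tile."""
    dk = []
    for i in range(len(scores)):
        acc = 0.0
        for j in range(len(q)):
            p = 2.0 ** (scores[i][j] - lse[j])   # exp2(qkT*scale - lse[j])
            ds = p * (dsT[i][j] - delta[j])      # qkT_cast * (dsT - delta[j])
            acc += ds * q[j]                     # reduction over query columns
        dk.append(acc)
    return dk


dirty_lse = backward_dk(lse=[1.0, 1.2, NAN], delta=[0.05, 0.07, 0.0])
dirty_delta = backward_dk(lse=[1.0, 1.2, 0.0], delta=[0.05, 0.07, NAN])
clean = backward_dk(lse=[1.0, 1.2, 0.0], delta=[0.05, 0.07, 0.0])

print([math.isnan(x) for x in dirty_lse])    # every entry is NaN
print([math.isnan(x) for x in dirty_delta])  # every entry is NaN
print([math.isfinite(x) for x in clean])     # every entry is finite
```

In both dirty cases the padding column is multiplied by a zero Q value, yet the NaN still reaches every dK row because NaN * 0 is NaN.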

Fix: explicitly write 0.0 to the padding positions of lse (in flash_fwd) and Delta (in flash_bwd_prep), so the backward kernel always reads well-defined values regardless of the host-side allocator.
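
The guarded store described above can be emulated on the host. This is a hedged sketch in plain Python rather than the TileLang kernel code from the PR; the function names `store_valid_only` and `store_with_zero_padding`, the block size, and the sample values are hypothetical.

```python
import math

NAN = float("nan")  # stand-in for torch.empty-style uninitialized memory


def store_valid_only(logsum, q_current_seqlen, block_rows):
    """Pre-fix behaviour: only valid query rows are written."""
    out = [NAN] * block_rows
    for i in range(min(q_current_seqlen, block_rows)):
        out[i] = logsum[i]
    return out


def store_with_zero_padding(logsum, q_current_seqlen, block_rows):
    """Post-fix behaviour: padding rows (i >= q_current_seqlen) get an explicit 0.0."""
    out = [NAN] * block_rows
    for i in range(block_rows):
        out[i] = logsum[i] if i < q_current_seqlen else 0.0
    return out


logsum = [1.0, 1.2, 0.9, 1.1]  # per-row values computed by the tile
before = store_valid_only(logsum, q_current_seqlen=2, block_rows=4)
after = store_with_zero_padding(logsum, q_current_seqlen=2, block_rows=4)

print(before)  # the two padding rows still hold the stand-in NaN
print(after)   # the two padding rows are 0.0
```

Because the kernel itself writes every row of the block, the result no longer depends on what the host allocator left in the buffer.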

Verified by examples/flash_attention/test_example_flash_attention.py::test_example_gqa_bwd_tma_reduce_varlen (previously failed with dK == NaN, now passes; tilelang throughput restored from ~1 TFlops to ~7 TFlops on H100).

Summary

  • Fixed the variable-length GQA backward flash-attention example by explicitly writing 0 into the padding rows of:
    • lse in flash_fwd (previously auto-allocated via out_idx/torch.empty, so padding entries could contain NaN/Inf and later contaminate T.exp2(... - lse), poisoning the whole column of P).
    • Delta in flash_bwd_prep (same padding/uninitialized hazard: a dirty entry turns even masked-out columns into NaN because 0 * NaN = NaN, so it is now cleared before backward uses it).
  • Added inline notes documenting the NaN/Inf propagation risk and why host-side zero-initializing the outputs is insufficient.

Validation

  • Verified examples/flash_attention/test_example_flash_attention.py::test_example_gqa_bwd_tma_reduce_varlen, which previously failed with dK == NaN and now passes.

@github-actions


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai Bot commented Jun 26, 2026

Contributor

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: eab82a43-aa4d-4b6e-8b7a-00c25ed59908

📥 Commits

Reviewing files that changed from the base of the PR and between fd1b1f7 and 1ffceff.

📒 Files selected for processing (1)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

Walkthrough

The varlen flash attention example now explicitly zero-fills padding slots in the forward lse output and the backward preprocessing Delta tensor.

Changes

Varlen padding writes

| Layer / File(s) | Summary |
|---|---|
| Zero-fill forward lse: examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py | flash_fwd writes logsum for valid query positions and 0 for out-of-range lse entries. |
| Zero-fill backward Delta: examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py | flash_bwd_prep writes delta for valid query positions and 0 for out-of-range Delta entries. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

A rabbit hopped through the flash-attn glen,
Tucking zeros in padding again and again.
lse stays tidy, Delta too,
No stray NaNs sneak through the brew.
🐇✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main fix: zero-initializing varlen padding in lse/Delta to prevent NaNs in backward dK. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

RuneFang changed the title from "[BugFix] Flash bwd varlen: zero-init lse/Delta padding to avoid NaN in Dk" to "[CI][BugFix] Flash bwd varlen: zero-init lse/Delta padding to avoid NaN in Dk" on Jun 26, 2026

@LeiWang1999 LeiWang1999 left a comment

Member

Thanks for your contribution. I left a comment. Also, what do you think about not allocating lse and delta through out_idx, and instead initializing them directly with torch.zeros?

Comment thread examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py Outdated
RuneFang commented Jun 26, 2026

Contributor Author

> Thanks for your contribution. I left a comment. Also, what do you think about not allocating lse and delta through out_idx, and instead initializing them directly with torch.zeros?

Thanks — this is a great question. I think the right way to look at it is to separate two concerns that are easy to confuse:

Correctness (kernel side). Initializing lse / delta with torch.zeros on the host doesn't actually prevent the out-of-bounds access inside the kernel: the backward pass still reads lse[..., j] for j >= q_current_seqlen and relies on 0 × dirty_value = 0 to mask it out later. That assumption silently breaks for non-finite floats (0 × NaN = NaN and 0 × Inf = NaN).
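
The difference between masking by multiplication and masking by selection can be shown in a few lines. This is an illustrative sketch only; `mask_by_multiply` and `mask_by_select` are hypothetical helper names, not functions from the example file.

```python
import math


def mask_by_multiply(value, valid):
    """Mask with a 0/1 factor: correct only when value is finite."""
    return (1.0 if valid else 0.0) * value


def mask_by_select(value, valid):
    """Mask by choosing 0.0 outright: the dirty value is never used in arithmetic."""
    return value if valid else 0.0


for dirty in (123.0, float("inf"), float("nan")):
    print(dirty, mask_by_multiply(dirty, False), mask_by_select(dirty, False))
```

For the finite value both helpers return 0.0. For Inf and NaN only the select-based mask does, which is why the multiply-based assumption fails exactly when the padding holds non-finite garbage.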

I think we could add an API for workspace-style outputs (lse, delta, m_partial, ...) that lets users explicitly specify buffer initialization, e.g. out_init={5: "zeros", 6: "zeros"}, to avoid the potential bugs from the current torch.empty default.
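
A rough idea of how such an option could behave, written as a host-side mock in plain Python. Nothing here is an existing TileLang API: `out_init` is only the parameter name proposed above, `allocate_outputs` is a hypothetical helper, the buffers are flat lists, and NaN is used as a stand-in for uninitialized memory.

```python
import math

NAN = float("nan")  # stand-in for the contents of a torch.empty-style buffer


def allocate_outputs(sizes, out_idx, out_init=None):
    """Allocate one flat buffer per output argument index.

    sizes:    maps argument index -> number of elements (hypothetical layout)
    out_idx:  argument indices that are kernel outputs
    out_init: optional map of argument index -> "zeros"; the default is "empty"
    """
    out_init = out_init or {}
    buffers = {}
    for idx in out_idx:
        mode = out_init.get(idx, "empty")
        if mode not in ("empty", "zeros"):
            raise ValueError(f"unsupported init mode for output {idx}: {mode!r}")
        fill = 0.0 if mode == "zeros" else NAN
        buffers[idx] = [fill] * sizes[idx]
    return buffers


# Output 4 keeps the cheap default, outputs 5 and 6 are workspace-style buffers.
outs = allocate_outputs(
    sizes={4: 8, 5: 4, 6: 4},
    out_idx=[4, 5, 6],
    out_init={5: "zeros", 6: "zeros"},
)
print(outs[5], outs[6])
```

Keeping "empty" as the default would preserve the current allocation cost for outputs that the kernel fully overwrites, while letting workspace-style buffers opt in to a defined initial value.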

@LeiWang1999 LeiWang1999 merged commit 1f70387 into tile-ai:main Jun 26, 2026
5 checks passed
@LeiWang1999

Member

LGTM:)
