[CUDA] Add SM120 NVF4 block-scale MMA support #2364
Add warp-level mxf4nvf4 block-scale MMA lowering and coverage so TileLang can validate NVF4 kernels against SM120/CUTLASS behavior.
📝 Walkthrough
Adds SM120 NVF4 block-scale MMA support end-to-end: TL builtin intrinsic, PTX/mma template, CodeGen emission, 4-bit layout transforms and emitter wiring, CUTLASS CUDA reference kernel, TileLang correctness harness with packing/decoding utilities, and CUDA-gated tests.
Apply clang-format updates and remove an unused block-scale emitter local.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py`:
- Around line 253-256: The code sets cutlass_root using a hardcoded developer
path which is unsafe; update the logic in the block that defines repo and
cutlass_root (variables cutlass_root, repo and the use of
os.environ["CUTLASS_ROOT"]) to remove the hardcoded "/data/home/..." fallback
and instead: prefer the CUTLASS_ROOT environment variable if set, otherwise fall
back to repo / "3rdparty" / "cutlass"; ensure cutlass_root.exists() is checked
and raise a clear error or log if neither location exists so the failure is
explicit.
In `@tilelang/language/ast/ir.py`:
- Line 2140: The module's __all__ includes "ptx_mma_block_scale" but no symbol
by that name is defined or imported, which breaks wildcard imports; either
remove "ptx_mma_block_scale" from the __all__ list or add a proper
definition/import for ptx_mma_block_scale (e.g., define the function/class or
import it from its source) so the name is actually bound in this module; update
the __all__ entry near the existing list and ensure the symbol name matches
exactly the defined/imported identifier.
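The wildcard-import failure described above can be caught by a small consistency check. The following is a minimal sketch that uses an in-memory stand-in module; `demo_ir` and `unbound_exports` are hypothetical names for illustration and are not part of TileLang.

```python
import types


def unbound_exports(module: types.ModuleType) -> list[str]:
    """Return names listed in module.__all__ that the module does not define."""
    exported = getattr(module, "__all__", [])
    return [name for name in exported if not hasattr(module, name)]


# Hypothetical stand-in for a module whose __all__ names a missing symbol.
demo_ir = types.ModuleType("demo_ir")
demo_ir.ptx_mma = lambda *args: None
demo_ir.__all__ = ["ptx_mma", "ptx_mma_block_scale"]

missing = unbound_exports(demo_ir)
print(missing)  # ['ptx_mma_block_scale']
```

Running a helper like this against the real module in a unit test would flag the stale entry before a `from module import *` statement fails with `AttributeError`.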
📒 Files selected for processing (16)
- maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py
- maint/gemm/cutlass_nvf4_ref.cu
- src/cuda/codegen/codegen_cuda.cc
- src/cuda/codegen/codegen_cuda.h
- src/op/builtin.cc
- src/op/builtin.h
- src/tl_templates/cuda/instruction/mma_block_scale.h
- testing/python/language/test_tilelang_language_nvf4_mma_block_scale.py
- tilelang/cuda/intrinsics/__init__.py
- tilelang/cuda/intrinsics/layout/mma_layout.py
- tilelang/cuda/intrinsics/macro/__init__.py
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
- tilelang/intrinsics/__init__.py
- tilelang/language/ast/ir.py
- tilelang/language/tir/ir.py
- tilelang/language/tir/op.py
Apply ruff formatting and remove trailing whitespace from the NVF4 block-scale emitter changes.
Remove a developer-specific CUTLASS fallback path, bind the block-scale MMA AST helper, and apply clang-format to the CUDA codegen changes.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/cuda/intrinsics/macro/mma_macro_generator.py (2)
1402-1407: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Preserve outer-dimension bases for plain scale buffers.
`_scale_region_parts()` treats `Buffer` differently from the A/B paths and drops every prefix dimension by returning `([], 0, 0)`. The accesses at Lines 1478-1482 and Line 1510 then only index the trailing two axes, which breaks direct N-D scale buffers and sliced views. Reuse `_legalize_to_buffer_region()` here so scale buffers follow the same region contract as A/B.
Suggested fix

         @staticmethod
         def _scale_region_parts(scale_buf: Buffer | BufferRegion):
    -        if isinstance(scale_buf, BufferRegion):
    -            return scale_buf.buffer, [r.min for r in scale_buf.region[:-2]], scale_buf.region[-2].min, scale_buf.region[-1].min
    -        if isinstance(scale_buf, Buffer):
    -            return scale_buf, [], 0, 0
    -        raise ValueError(f"Unsupported scale buffer type: {type(scale_buf)}")
    +        region = TensorCoreIntrinEmitter._legalize_to_buffer_region(scale_buf)
    +        return (
    +            region.buffer,
    +            [r.min for r in region.region[:-2]],
    +            region.region[-2].min,
    +            region.region[-1].min,
    +        )

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py` around lines 1402 - 1407, _scale_region_parts currently returns (buffer, [], 0, 0) for plain Buffer which drops outer-dimension bases and breaks N-D scale buffers; change it to call and reuse _legalize_to_buffer_region(scale_buf) so both Buffer and BufferRegion follow the same region contract as A/B. Update _scale_region_parts to accept either type, call _legalize_to_buffer_region when scale_buf is a Buffer to obtain the buffer and full region, and then extract the same tuples (buffer, [r.min for r in region[:-2]], region[-2].min, region[-1].min) as for BufferRegion; keep the existing ValueError for unsupported types. Ensure references to _scale_region_parts usage (the indexing code that expects preserved outer bases) continue to work unchanged.
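To illustrate why the plain-buffer case should go through the same legalization step, here is a minimal sketch with toy stand-ins for the TIR types. `Range`, `Buffer`, `BufferRegion`, `legalize_to_region`, and `scale_region_parts` below are hypothetical simplifications, not the real TVM/TileLang classes or the emitter's actual helpers.

```python
from dataclasses import dataclass


@dataclass
class Range:  # hypothetical stand-in for a TIR range
    min: int
    extent: int


@dataclass
class Buffer:  # hypothetical stand-in for a TIR buffer
    shape: tuple


@dataclass
class BufferRegion:  # hypothetical stand-in for a TIR buffer region
    buffer: Buffer
    region: list


def legalize_to_region(buf):
    """Wrap a plain Buffer in a full-extent region; pass regions through."""
    if isinstance(buf, BufferRegion):
        return buf
    if isinstance(buf, Buffer):
        return BufferRegion(buf, [Range(0, extent) for extent in buf.shape])
    raise ValueError(f"Unsupported scale buffer type: {type(buf)}")


def scale_region_parts(buf):
    """Return (buffer, outer bases, row base, col base) for either input type."""
    region = legalize_to_region(buf)
    outer = [r.min for r in region.region[:-2]]
    return region.buffer, outer, region.region[-2].min, region.region[-1].min


scales = Buffer(shape=(2, 128, 16))
_, outer_plain, row_plain, col_plain = scale_region_parts(scales)
view = BufferRegion(scales, [Range(1, 1), Range(32, 64), Range(0, 16)])
_, outer_view, row_view, col_view = scale_region_parts(view)
print(outer_plain, row_plain, col_plain)  # [0] 0 0
print(outer_view, row_view, col_view)     # [1] 32 0
```

With this shape of helper, a 3-D plain buffer yields one outer base (`[0]`) rather than an empty list, so downstream indexing always receives `ndim - 2` prefix indices regardless of whether the caller passed a buffer or a sliced view.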
1282-1333: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Validate the fixed NVF4 contract in the constructor.
This emitter always lowers as `mxf4nvf4` with `k64`, `e2m1`/`e2m1`, and the block-scale fragment layouts from this class, but it never rejects incompatible `a_dtype`, `b_dtype`, or `accum_dtype` inputs. A caller can currently instantiate this public API with, for example, `float16` operands and still get NVF4 PTX emitted against mismatched fragment assumptions. Fail fast here instead of silently generating wrong code.
Suggested guard

         def __init__(
             self,
             a_dtype: str = T.float4_e2m1fn,
             b_dtype: str = T.float4_e2m1fn,
             accum_dtype: str = T.float32,
    @@
             kind: str = "mxf4nvf4",
             scale_vec_size: int = 4,
             stype: str = "ue4m3",
         ):
    +        if str(DataType(a_dtype)) != str(T.float4_e2m1fn) or str(DataType(b_dtype)) != str(T.float4_e2m1fn):
    +            raise ValueError("SM120 block-scale MMA currently only supports float4_e2m1fn operands")
    +        if str(DataType(accum_dtype)) != str(T.float32):
    +            raise ValueError("SM120 block-scale MMA currently only supports float32 accumulation")
             self.block_scale_config = _get_block_scale_mma_config(kind, scale_vec_size, stype)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py` around lines 1282 - 1333, The constructor currently forces block-scale config to NVF4 but does not validate the caller-provided dtypes, so callers can pass incompatible a_dtype/b_dtype/accum_dtype and silently generate wrong code; after calling _get_block_scale_mma_config(...) in __init__, add a guard that checks the resolved self.block_scale_config.kind (and/or its expected dtype descriptors from the config) against the incoming a_dtype, b_dtype, and accum_dtype parameters (e.g., ensure a_dtype and b_dtype match the NVF4 fragment dtypes such as T.float4_e2m1fn and accum_dtype matches the expected accumulator like T.float32 or whatever the config exposes), and raise a ValueError with a clear message if they mismatch; perform this validation before calling super().__init__ so invalid combinations fail fast.
maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py (1)
342-354: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
The FP32 reference diagnostics are missing the UE4M3 block scales.
`ref` is built from the decoded FP4 payloads only, so `diff_tl_ref` and `diff_cutlass_ref` are not checking the same computation as the block-scaled GEMMs. In the default `NVF4_SCALE_MODE="varying"` case, those numbers are inherently misleading. Either apply `sfa_logical`/`sfb_logical` per 16-K chunk when building `ref`, or drop these prints until that scale-aware reference exists.
Suggested minimal change

    -    ref = _decode_rowmajor_fp4(a, M, K) @ _decode_rowmajor_fp4(b, N, K).T
    -    diff_tl_ref = (c_tl - ref).abs()
    -    diff_cutlass_ref = (c_cutlass - ref).abs()
         print("scale_mode:", scale_mode)
         print("input_mode:", input_mode)
         print("max_abs_diff:", diff.max().item())
         print("mean_abs_diff:", diff.mean().item())
         print("max_abs_diff_transposed:", diff_t.max().item())
         print("mean_abs_diff_transposed:", diff_t.mean().item())
    -    print("max_abs_diff_tilelang_ref:", diff_tl_ref.max().item())
    -    print("mean_abs_diff_tilelang_ref:", diff_tl_ref.mean().item())
    -    print("max_abs_diff_cutlass_ref:", diff_cutlass_ref.max().item())
    -    print("mean_abs_diff_cutlass_ref:", diff_cutlass_ref.mean().item())

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py` around lines 342 - 354, The FP32 reference `ref` is computed from raw decoded FP4 payloads so it doesn't include the UE4M3 block scales (`sfa_logical`/`sfb_logical`), making `diff_tl_ref` and `diff_cutlass_ref` invalid under NVF4_SCALE_MODE="varying"; fix by applying the per-block scales to the decoded tensors before matmul: after calling `_decode_rowmajor_fp4(a, M, K)` and `_decode_rowmajor_fp4(b, N, K)`, multiply each 16xK chunk of the decoded A by the corresponding entries in `sfa_logical` and each 16xK chunk of the decoded B by `sfb_logical` (or apply equivalent broadcasting per 16-K block) so `ref = scaled_decoded_a @ scaled_decoded_b.T` matches the block-scaled GEMM, otherwise remove the `diff_tl_ref`/`diff_cutlass_ref` prints until the scale-aware reference is implemented; reference symbols: `ref`, `_decode_rowmajor_fp4`, `sfa_logical`, `sfb_logical`, `diff_tl_ref`, `diff_cutlass_ref`.
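One way to build the scale-aware reference discussed above is to decode each FP4 code and multiply it by the scale of its 16-element K block before accumulating. The following is a minimal pure-Python sketch. It assumes the scales are already decoded to floats and stored as one value per row per 16-K block; the helper names and this layout are illustrative and may not match the harness's actual `sfa_logical`/`sfb_logical` format.

```python
# E2M1 magnitudes indexed by the low 3 bits; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK_K = 16  # K elements sharing one scale factor (per the 16-K chunks above)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequantize(codes, scales):
    """Apply per-row, per-16-K-block scales to a matrix of FP4 codes."""
    return [
        [decode_e2m1(c) * row_scales[k // BLOCK_K] for k, c in enumerate(row)]
        for row, row_scales in zip(codes, scales)
    ]


def reference_gemm(a_codes, a_scales, b_codes, b_scales):
    """Compute C = A_scaled @ B_scaled^T with A as M x K and B as N x K."""
    a = dequantize(a_codes, a_scales)
    b = dequantize(b_codes, b_scales)
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


# Tiny example: M = N = 1, K = 32, so each row has two scale blocks.
a_codes = [[0x2] * 32]                # every element decodes to +1.0
b_codes = [[0x4] * 16 + [0xA] * 16]   # +2.0 for block 0, -1.0 for block 1
c = reference_gemm(a_codes, [[1.0, 0.5]], b_codes, [[2.0, 4.0]])
print(c)  # [[32.0]]
```

In the example, block 0 contributes 16 × (1.0 × 4.0) = 64 and block 1 contributes 16 × (0.5 × -4.0) = -32. An unscaled reference would give 16 × 2.0 + 16 × (-1.0) = 16 instead, which is why the unscaled diagnostics are misleading when the scales vary.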
♻️ Duplicate comments (1)
maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py (1)
247-256: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Fail fast when the CUTLASS root is missing.
This now avoids the developer-local fallback, but it still passes non-existent include roots into `load()`. If `CUTLASS_ROOT` is unset or wrong and `3rdparty/cutlass` is absent, the harness fails later with a compiler error instead of an explicit setup error.
Suggested fix

         repo = Path(__file__).resolve().parents[2]
         cutlass_root_env = os.environ.get("CUTLASS_ROOT")
         cutlass_root = Path(cutlass_root_env) if cutlass_root_env else repo / "3rdparty" / "cutlass"
    +    if not cutlass_root.exists():
    +        raise RuntimeError(
    +            f"CUTLASS not found at {cutlass_root}. "
    +            "Set CUTLASS_ROOT or ensure 3rdparty/cutlass exists."
    +        )
         return load(

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py` around lines 247 - 256, Check that the computed cutlass_root (from cutlass_root_env or repo / "3rdparty" / "cutlass") actually exists before calling load; if it does not exist, raise a clear RuntimeError instructing the developer to set CUTLASS_ROOT or populate 3rdparty/cutlass, and only pass existing include paths (cutlass_root / "include" and cutlass_root / "tools" / "util" / "include") into the load(...) call instead of blindly passing non-existent paths; refer to the variables cutlass_root_env, cutlass_root, extra_include_paths and the load(...) call to locate where to add the existence check and error raise.
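The existence check described above can be factored into a small helper that validates both include directories before anything is handed to the compiler. This is a minimal sketch under stated assumptions: `resolve_cutlass_includes` and its `env` parameter are hypothetical names introduced here for testability, not the harness's actual code.

```python
import os
from pathlib import Path


def resolve_cutlass_includes(repo: Path, env=None) -> list[Path]:
    """Return CUTLASS include dirs, or raise if the checkout is incomplete."""
    env = os.environ if env is None else env
    root_env = env.get("CUTLASS_ROOT")
    # Prefer CUTLASS_ROOT; otherwise fall back to the vendored checkout.
    root = Path(root_env) if root_env else repo / "3rdparty" / "cutlass"
    includes = [root / "include", root / "tools" / "util" / "include"]
    missing = [str(p) for p in includes if not p.is_dir()]
    if missing:
        raise RuntimeError(
            f"CUTLASS include paths not found: {missing}. "
            "Set CUTLASS_ROOT or ensure 3rdparty/cutlass exists."
        )
    return includes
```

A caller could then pass the returned list as the extra include paths, so a missing checkout surfaces as a setup error rather than a compiler error.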
📒 Files selected for processing (6)
- maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py
- maint/gemm/cutlass_nvf4_ref.cu
- src/cuda/codegen/codegen_cuda.cc
- src/tl_templates/cuda/instruction/mma_block_scale.h
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
- tilelang/language/ast/ir.py
🚧 Files skipped from review as they are similar to previous changes (3)
- src/cuda/codegen/codegen_cuda.cc
- src/tl_templates/cuda/instruction/mma_block_scale.h
- maint/gemm/cutlass_nvf4_ref.cu
Make the CUTLASS comparison reference scale-aware, validate block-scale emitter dtypes, preserve scale buffer regions, and check CUTLASS include paths explicitly.
|
Really appreciate your contribution! I think it's better to provide high-level API for blockscaled mma on sm120, like existing |
|
Thanks, that makes sense. This PR is currently a low-level SM120 NVF4 MMA emitter. It checks that the new mma.sync instruction works, and that the fragment layout and scale-factor lane mapping are correct. I agree that users should have a higher-level API, similar to the existing SM100 blockscaled GEMM API. The SM120 version should not copy the TCGEN05/TMEM behavior directly, but it can provide a similar user experience on top of the SM120 warp-level mma.sync path. For this PR, I can move the SM120 blockscaled MMA helper into gemm_sm120.h, so it sits with the SM120 GEMM code. Then we can add a higher-level tile op API as a follow-up, unless you prefer to include that in this PR. |
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/tir/op.py (1)
1079-1119: ⚠️ Potential issue | 🟠 Major
Fix `tl.ptx_mma_block_scale` intrinsic arity: `set_num_inputs(21)` doesn’t match CUDA lowering expectations.
- `src/op/builtin.cc:244-247` registers `tl.ptx_mma_block_scale` with `.set_num_inputs(21)`.
- `src/cuda/codegen/codegen_cuda.cc:2852` lowers the intrinsic with `ICHECK_EQ(op->args.size(), 17U)` and consumes `op->args[0..16]` (accum_dtype/shape/layouts/k/vec_size/dtypes + A/B/C pointers & offsets + `scale_a`/`scale_b`).
- The Python wrapper `language/tir/op.py:1079-1119` passes exactly those 17 `call_intrin` args, so the registration count should be aligned (likely change `.set_num_inputs(21)` → `17`).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/language/tir/op.py` around lines 1079 - 1119, The registered intrinsic tl.ptx_mma_block_scale has a mismatched arity: the Python wrapper function ptx_mma_block_scale builds 17 call_intrin arguments and the CUDA lowering asserts ICHECK_EQ(..., 17), so update the intrinsic registration to use .set_num_inputs(17) (replace the current .set_num_inputs(21)) so the registration count matches the call_intrin arguments and the CUDA lowering expectations.
📒 Files selected for processing (5)
- src/cuda/codegen/codegen_cuda.cc
- src/tl_templates/cuda/gemm_sm120.h
- testing/python/language/test_tilelang_language_nvf4_mma_block_scale.py
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
- tilelang/language/tir/op.py
💤 Files with no reviewable changes (1)
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
🚧 Files skipped from review as they are similar to previous changes (1)
- testing/python/language/test_tilelang_language_nvf4_mma_block_scale.py
|
Nice work! May I ask whether we support TMA load for operands and SFs? If so, we may also illustrate that in our example. Also, I'm curious whether you've considered warp-specialization (either automatic version or handwritten) and its influence on the performance. |
|
Yes, I am still tuning to optimize the performance. |
Add warp-level mxf4nvf4 block-scale MMA lowering and coverage so TileLang can validate NVF4 kernels against SM120/CUTLASS behavior.
Accuracy Validation
To validate the SM120 NVF4 block-scale MMA lowering, I added a CUTLASS-based reference check in `maint/gemm/correctness_evaluation_nvf4_vs_cutlass.py`.
The validation compares TileLang's warp-level `mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3` path against CUTLASS SM120 block-scaled NVFP4 GEMM using the same packed NVF4 A/B inputs and the same UE4M3 scale factors.
Test configuration:
- M=N=128, K=256
- 128x128x256
- e2m1
- float32
- ue4m3
- scale_vec::4X
Result: