[feat] gemm support QMM#1115
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Code Review
This pull request adds support for MXFP (Microscaling Floating Point) GEMM, online quantization, and dequantization in the tilelang Ascend backend. Key feedback points out critical compilation errors in the C++ templates due to hardcoded dimensions and reference binding to null pointers, a template name mismatch between Python and C++, and incorrect slice name resolution in codegen for UB Vector tiles. Additionally, the review identifies potential runtime crashes when handling let-bound variables and symbolic shapes in the Python frontend.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| template <typename TSrc, typename TDst> | ||
| AICORE PTO_INLINE void tquant_mxfp8(TileUbDataND<TDst, 1, 1, 1, 1> &dst, | ||
| TileUbDataND<TSrc, 1, 1, 1, 1> &src, | ||
| TileUbDataND<uint8_t, 1, 1, 1, 1> &exp, | ||
| TileUbDataND<TSrc, 1, 1, 1, 1> &max_buf, | ||
| TileUbDataND<TSrc, 1, 1, 1, 1> &scaling) { | ||
| pto::TQUANT<pto::QuantType::MXFP8>(dst, src, &exp, &max_buf, &scaling); | ||
| } | ||
|
|
||
| // tdequant: Integer dequantization (INT8/INT16 -> float32). Wraps | ||
| // pto::TDEQUANT. Available on A2/A3 platforms only. | ||
| // | ||
| // dst = (src - offset) * scale | ||
| // ============================================================================ | ||
|
|
||
| template <typename TSrc, typename TDst> | ||
| AICORE PTO_INLINE void tdequant(TileUbDataND<TDst, 1, 1, 1, 1> &dst, | ||
| TileUbDataND<TSrc, 1, 1, 1, 1> &src, | ||
| TileUbDataND<TDst, 1, 1, 1, 1> &scale, | ||
| TileUbDataND<TDst, 1, 1, 1, 1> &offset) { | ||
| pto::TDEQUANT(dst, src, scale, offset); | ||
| } |
There was a problem hiding this comment.
Critical Compilation Errors in tquant_mxfp8 and tdequant Templates
There are two critical compilation issues in these template definitions:
- Hardcoded Dimensions: The
TileUbDataNDparameters are hardcoded with1, 1, 1, 1dimensions. This will cause compilation failures when passing actual tiles of arbitrary shapes (e.g.,16x128as used in the tests and examples) because C++ template matching will fail. - Reference Binding to
nullptr: The optional parameters (likemax_buf,scaling, andoffset) are defined as references (TileUbDataND &). However, the codegen passesnullptr(sentinel0ornullptr) when they are omitted, which cannot bind to a reference and will fail to compile.
Suggested Fix
Use fully templated tile types to allow automatic dimension and type deduction, and introduce a get_ptr helper or overloads to safely handle optional nullptr arguments.
template <typename T>
AICORE PTO_INLINE T* get_ptr(T &val) {
return &val;
}
AICORE PTO_INLINE std::nullptr_t get_ptr(std::nullptr_t val) {
return nullptr;
}
template <typename TDstTile, typename TSrcTile, typename TExpTile, typename TMaxParam, typename TScaleParam>
AICORE PTO_INLINE void tquant_mxfp8(TDstTile &dst,
TSrcTile &src,
TExpTile &exp,
TMaxParam max_buf,
TScaleParam scaling) {
pto::TQUANT<pto::QuantType::MXFP8>(dst, src, get_ptr(exp), get_ptr(max_buf), get_ptr(scaling));
}
template <typename TDstTile, typename TSrcTile, typename TScaleTile>
AICORE PTO_INLINE void tdequant(TDstTile &dst,
TSrcTile &src,
TScaleTile &scale,
std::nullptr_t offset) {
pto::TDEQUANT(dst, src, scale);
}
template <typename TDstTile, typename TSrcTile, typename TScaleTile, typename TOffsetTile>
AICORE PTO_INLINE void tdequant(TDstTile &dst,
TSrcTile &src,
TScaleTile &scale,
TOffsetTile &offset) {
pto::TDEQUANT(dst, src, scale, offset);
}| template = ( | ||
| f"mma_mx<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, " | ||
| f"{M}, {N}, {K}>" | ||
| ) |
There was a problem hiding this comment.
High Mismatch Mismatch between Python and C++ Template Names
The Python frontend generates the template string using mma_mx, but the C++ template function defined in src/tl_templates/pto/common.h is named mma_mxfp. This mismatch will cause a compilation error (error: 'mma_mx' is not a member of 'tl::ascend_pto').
Rename the template string in Python to mma_mxfp to match the C++ template definition.
| template = ( | |
| f"mma_mx<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, " | |
| f"{M}, {N}, {K}>" | |
| ) | |
| template = ( | |
| f"mma_mxfp<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, " | |
| f"{M}, {N}, {K}>" | |
| ) |
| # The template string should include "mma_mx<...>" | ||
| template = str(call.args[0]) | ||
| assert "mma_mx<" in template | ||
| assert "float8_e4m3_t" in template | ||
| assert "float, " in template |
There was a problem hiding this comment.
High Mismatch in Test Assertion
Update the test assertion to expect mma_mxfp< instead of mma_mx< to align with the correct C++ template function name.
| # The template string should include "mma_mx<...>" | |
| template = str(call.args[0]) | |
| assert "mma_mx<" in template | |
| assert "float8_e4m3_t" in template | |
| assert "float, " in template | |
| # The template string should include "mma_mxfp<..." | |
| template = str(call.args[0]) | |
| assert "mma_mxfp<" in template | |
| assert "float8_e4m3_t" in template | |
| assert "float, " in template |
| auto resolve_or_null = [&](size_t idx, | ||
| const std::string &tile_name) -> std::string { | ||
| if (idx >= op->args.size()) { | ||
| return "nullptr"; | ||
| } | ||
| const CallNode *access_call = op->args[idx].as<CallNode>(); | ||
| if (!access_call) { | ||
| // Integer literal 0 used as null pointer sentinel | ||
| return "nullptr"; | ||
| } | ||
| ShapeInfo info = GetSliceInfo(access_call); | ||
| return ResolveCubeSliceName(info, tile_name); | ||
| }; | ||
|
|
||
| std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND"); | ||
| std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND"); | ||
| std::string exp_name = resolve_or_null(3, kAscendPtoScope + "TileUbDataND"); | ||
| std::string max_name = resolve_or_null(4, kAscendPtoScope + "TileUbDataND"); | ||
| std::string scaling_name = | ||
| resolve_or_null(5, kAscendPtoScope + "TileUbDataND"); |
There was a problem hiding this comment.
High Correctness Issue in TQuantCodegen Slice Name Resolution
TQuantCodegen incorrectly uses ResolveCubeSliceName for TileUbDataND variables. TileUbDataND is a Unified Buffer (UB) Vector tile, not a Cube tile.
Using ResolveCubeSliceName forces RowValid = slice_row and ColValid = slice_col (which is aligned/padded to 32 bytes), completely losing the actual unaligned valid bounds (slice_valid_row and slice_valid_col). This can lead to out-of-bounds memory accesses or incorrect padding during quantization.
Use ResolveUbSliceName(info) instead, which correctly instantiates TileUbDataND with the actual valid bounds.
| auto resolve_or_null = [&](size_t idx, | |
| const std::string &tile_name) -> std::string { | |
| if (idx >= op->args.size()) { | |
| return "nullptr"; | |
| } | |
| const CallNode *access_call = op->args[idx].as<CallNode>(); | |
| if (!access_call) { | |
| // Integer literal 0 used as null pointer sentinel | |
| return "nullptr"; | |
| } | |
| ShapeInfo info = GetSliceInfo(access_call); | |
| return ResolveCubeSliceName(info, tile_name); | |
| }; | |
| std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND"); | |
| std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND"); | |
| std::string exp_name = resolve_or_null(3, kAscendPtoScope + "TileUbDataND"); | |
| std::string max_name = resolve_or_null(4, kAscendPtoScope + "TileUbDataND"); | |
| std::string scaling_name = | |
| resolve_or_null(5, kAscendPtoScope + "TileUbDataND"); | |
| auto resolve_or_null = [&](size_t idx) -> std::string { | |
| if (idx >= op->args.size()) { | |
| return "nullptr"; | |
| } | |
| const CallNode *access_call = op->args[idx].as<CallNode>(); | |
| if (!access_call) { | |
| // Integer literal 0 used as null pointer sentinel | |
| return "nullptr"; | |
| } | |
| ShapeInfo info = GetSliceInfo(access_call); | |
| return ResolveUbSliceName(info); | |
| }; | |
| std::string dst_name = resolve_or_null(1); | |
| std::string src_name = resolve_or_null(2); | |
| std::string exp_name = resolve_or_null(3); | |
| std::string max_name = resolve_or_null(4); | |
| std::string scaling_name = resolve_or_null(5); |
| auto resolve_or_null = [&](size_t idx, | ||
| const std::string &tile_name) -> std::string { | ||
| if (idx >= op->args.size()) { | ||
| return "nullptr"; | ||
| } | ||
| const CallNode *access_call = op->args[idx].as<CallNode>(); | ||
| if (!access_call) { | ||
| return "nullptr"; | ||
| } | ||
| ShapeInfo info = GetSliceInfo(access_call); | ||
| return ResolveCubeSliceName(info, tile_name); | ||
| }; | ||
|
|
||
| std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND"); | ||
| std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND"); | ||
| std::string scale_name = | ||
| resolve_or_null(3, kAscendPtoScope + "TileUbDataND"); | ||
| std::string offset_name = | ||
| resolve_or_null(4, kAscendPtoScope + "TileUbDataND"); | ||
|
|
There was a problem hiding this comment.
High Correctness Issue in TDequantCodegen Slice Name Resolution
Similar to TQuantCodegen, TDequantCodegen incorrectly uses ResolveCubeSliceName for TileUbDataND variables. Use ResolveUbSliceName(info) to preserve the actual unaligned valid bounds.
| auto resolve_or_null = [&](size_t idx, | |
| const std::string &tile_name) -> std::string { | |
| if (idx >= op->args.size()) { | |
| return "nullptr"; | |
| } | |
| const CallNode *access_call = op->args[idx].as<CallNode>(); | |
| if (!access_call) { | |
| return "nullptr"; | |
| } | |
| ShapeInfo info = GetSliceInfo(access_call); | |
| return ResolveCubeSliceName(info, tile_name); | |
| }; | |
| std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND"); | |
| std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND"); | |
| std::string scale_name = | |
| resolve_or_null(3, kAscendPtoScope + "TileUbDataND"); | |
| std::string offset_name = | |
| resolve_or_null(4, kAscendPtoScope + "TileUbDataND"); | |
| auto resolve_or_null = [&](size_t idx) -> std::string { | |
| if (idx >= op->args.size()) { | |
| return "nullptr"; | |
| } | |
| const CallNode *access_call = op->args[idx].as<CallNode>(); | |
| if (!access_call) { | |
| return "nullptr"; | |
| } | |
| ShapeInfo info = GetSliceInfo(access_call); | |
| return ResolveUbSliceName(info); | |
| }; | |
| std::string dst_name = resolve_or_null(1); | |
| std::string src_name = resolve_or_null(2); | |
| std::string scale_name = resolve_or_null(3); | |
| std::string offset_name = resolve_or_null(4); |
| if isinstance(obj, BufferRegion): | ||
| buf, region = obj.buffer, obj.region |
There was a problem hiding this comment.
Medium Potential Crash with Let-Bound Variables
_get_buffer does not resolve let-bound variables (e.g., tir.Var bound to a buffer region). If a let-bound variable is passed to tquant_mxfp8 or tdequant, _get_buffer will raise a ValueError because it only checks for Buffer and BufferRegion.
Resolve let-bound variables at the beginning of _get_buffer using T.has_let_value and T.get_let_value.
| if isinstance(obj, BufferRegion): | |
| buf, region = obj.buffer, obj.region | |
| if isinstance(obj, tir.Var) and T.has_let_value(obj): | |
| obj = T.get_let_value(obj) | |
| if isinstance(obj, BufferRegion): |
| assert int(K) % 64 == 0, ( | ||
| f"MXFP GEMM requires K to be a multiple of 64, got K={K}" | ||
| ) | ||
| kMXScaleFactor = 32 | ||
| expected_sa_cols = int(K) // kMXScaleFactor | ||
| expected_sb_rows = int(K) // kMXScaleFactor | ||
| assert int(Sa_shape[-1]) == expected_sa_cols, ( | ||
| f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), " | ||
| f"got {Sa_shape[-1]}" | ||
| ) | ||
| assert int(Sb_shape[-2]) == expected_sb_rows, ( | ||
| f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), " | ||
| f"got {Sb_shape[-2]}" | ||
| ) |
There was a problem hiding this comment.
Medium Potential Crash with Symbolic Shapes
npu_gemm_mx shape validation uses int(K) and int(Sa_shape[-1]) directly. If the shapes are symbolic variables (tir.Var), calling int() on them will raise a TypeError at runtime.
Handle symbolic shapes defensively by checking isinstance(..., tir.IntImm) before performing integer modulo/division checks.
| assert int(K) % 64 == 0, ( | |
| f"MXFP GEMM requires K to be a multiple of 64, got K={K}" | |
| ) | |
| kMXScaleFactor = 32 | |
| expected_sa_cols = int(K) // kMXScaleFactor | |
| expected_sb_rows = int(K) // kMXScaleFactor | |
| assert int(Sa_shape[-1]) == expected_sa_cols, ( | |
| f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), " | |
| f"got {Sa_shape[-1]}" | |
| ) | |
| assert int(Sb_shape[-2]) == expected_sb_rows, ( | |
| f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), " | |
| f"got {Sb_shape[-2]}" | |
| ) | |
| if isinstance(K, tir.IntImm): | |
| assert K.value % 64 == 0, ( | |
| f"MXFP GEMM requires K to be a multiple of 64, got K={K.value}" | |
| ) | |
| kMXScaleFactor = 32 | |
| expected_sa_cols = K.value // kMXScaleFactor | |
| expected_sb_rows = K.value // kMXScaleFactor | |
| if isinstance(Sa_shape[-1], tir.IntImm): | |
| assert Sa_shape[-1].value == expected_sa_cols, ( | |
| f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), " | |
| f"got {Sa_shape[-1].value}" | |
| ) | |
| if isinstance(Sb_shape[-2], tir.IntImm): | |
| assert Sb_shape[-2].value == expected_sb_rows, ( | |
| f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), " | |
| f"got {Sb_shape[-2].value}" | |
| ) |
df87470 to
a4de75a
Compare
|
/re-test |
|
🔄 Re-running failed jobs Original workflow run: View details Only the failed jobs will be re-executed. |
|
/re-test |
|
🔄 Re-running failed jobs Original workflow run: View details Only the failed jobs will be re-executed. |
No description provided.