Skip to content

[feat] gemm support QMM#1115

Open
erhsh wants to merge 6 commits into
tile-ai:ascendc_ptofrom
erhsh:ascendc_pto_gemm_support_qmm
Open

[feat] gemm support QMM#1115
erhsh wants to merge 6 commits into
tile-ai:ascendc_ptofrom
erhsh:ascendc_pto_gemm_support_qmm

Conversation

@erhsh

@erhsh erhsh commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

No description provided.

@github-actions

github-actions Bot commented Jun 2, 2026

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for MXFP (Microscaling Floating Point) GEMM, online quantization, and dequantization in the tilelang Ascend backend. Key feedback points out critical compilation errors in the C++ templates due to hardcoded dimensions and reference binding to null pointers, a template name mismatch between Python and C++, and incorrect slice name resolution in codegen for UB Vector tiles. Additionally, the review identifies potential runtime crashes when handling let-bound variables and symbolic shapes in the Python frontend.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/tl_templates/pto/common.h Outdated
Comment on lines +1554 to +1575
template <typename TSrc, typename TDst>
AICORE PTO_INLINE void tquant_mxfp8(TileUbDataND<TDst, 1, 1, 1, 1> &dst,
TileUbDataND<TSrc, 1, 1, 1, 1> &src,
TileUbDataND<uint8_t, 1, 1, 1, 1> &exp,
TileUbDataND<TSrc, 1, 1, 1, 1> &max_buf,
TileUbDataND<TSrc, 1, 1, 1, 1> &scaling) {
pto::TQUANT<pto::QuantType::MXFP8>(dst, src, &exp, &max_buf, &scaling);
}

// tdequant: Integer dequantization (INT8/INT16 -> float32). Wraps
// pto::TDEQUANT. Available on A2/A3 platforms only.
//
// dst = (src - offset) * scale
// ============================================================================

template <typename TSrc, typename TDst>
AICORE PTO_INLINE void tdequant(TileUbDataND<TDst, 1, 1, 1, 1> &dst,
TileUbDataND<TSrc, 1, 1, 1, 1> &src,
TileUbDataND<TDst, 1, 1, 1, 1> &scale,
TileUbDataND<TDst, 1, 1, 1, 1> &offset) {
pto::TDEQUANT(dst, src, scale, offset);
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Critical Compilation Errors in tquant_mxfp8 and tdequant Templates

There are two critical compilation issues in these template definitions:

  1. Hardcoded Dimensions: The TileUbDataND parameters are hardcoded with 1, 1, 1, 1 dimensions. This will cause compilation failures when passing actual tiles of arbitrary shapes (e.g., 16x128 as used in the tests and examples) because C++ template matching will fail.
  2. Reference Binding to nullptr: The optional parameters (like max_buf, scaling, and offset) are defined as references (TileUbDataND &). However, the codegen passes nullptr (sentinel 0 or nullptr) when they are omitted, which cannot bind to a reference and will fail to compile.

Suggested Fix

Use fully templated tile types to allow automatic dimension and type deduction, and introduce a get_ptr helper or overloads to safely handle optional nullptr arguments.

template <typename T>
AICORE PTO_INLINE T* get_ptr(T &val) {
  return &val;
}

AICORE PTO_INLINE std::nullptr_t get_ptr(std::nullptr_t val) {
  return nullptr;
}

template <typename TDstTile, typename TSrcTile, typename TExpTile, typename TMaxParam, typename TScaleParam>
AICORE PTO_INLINE void tquant_mxfp8(TDstTile &dst,
                                    TSrcTile &src,
                                    TExpTile &exp,
                                    TMaxParam max_buf,
                                    TScaleParam scaling) {
  pto::TQUANT<pto::QuantType::MXFP8>(dst, src, get_ptr(exp), get_ptr(max_buf), get_ptr(scaling));
}

template <typename TDstTile, typename TSrcTile, typename TScaleTile>
AICORE PTO_INLINE void tdequant(TDstTile &dst,
                                TSrcTile &src,
                                TScaleTile &scale,
                                std::nullptr_t offset) {
  pto::TDEQUANT(dst, src, scale);
}

template <typename TDstTile, typename TSrcTile, typename TScaleTile, typename TOffsetTile>
AICORE PTO_INLINE void tdequant(TDstTile &dst,
                                TSrcTile &src,
                                TScaleTile &scale,
                                TOffsetTile &offset) {
  pto::TDEQUANT(dst, src, scale, offset);
}

Comment thread tilelang/language/customize.py Outdated
Comment on lines +310 to +313
template = (
f"mma_mx<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, "
f"{M}, {N}, {K}>"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High Mismatch Mismatch between Python and C++ Template Names

The Python frontend generates the template string using mma_mx, but the C++ template function defined in src/tl_templates/pto/common.h is named mma_mxfp. This mismatch will cause a compilation error (error: 'mma_mx' is not a member of 'tl::ascend_pto').

Rename the template string in Python to mma_mxfp to match the C++ template definition.

Suggested change
template = (
f"mma_mx<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, "
f"{M}, {N}, {K}>"
)
template = (
f"mma_mxfp<{_dtype(A)}, {_dtype(C)}, {scale_ctype}, "
f"{M}, {N}, {K}>"
)

Comment on lines +113 to +117
# The template string should include "mma_mx<...>"
template = str(call.args[0])
assert "mma_mx<" in template
assert "float8_e4m3_t" in template
assert "float, " in template

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High Mismatch in Test Assertion

Update the test assertion to expect mma_mxfp< instead of mma_mx< to align with the correct C++ template function name.

Suggested change
# The template string should include "mma_mx<...>"
template = str(call.args[0])
assert "mma_mx<" in template
assert "float8_e4m3_t" in template
assert "float, " in template
# The template string should include "mma_mxfp<..."
template = str(call.args[0])
assert "mma_mxfp<" in template
assert "float8_e4m3_t" in template
assert "float, " in template

Comment thread src/target/codegen_ascend_pto.cc Outdated
Comment on lines +3179 to +3198
auto resolve_or_null = [&](size_t idx,
const std::string &tile_name) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
// Integer literal 0 used as null pointer sentinel
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveCubeSliceName(info, tile_name);
};

std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND");
std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND");
std::string exp_name = resolve_or_null(3, kAscendPtoScope + "TileUbDataND");
std::string max_name = resolve_or_null(4, kAscendPtoScope + "TileUbDataND");
std::string scaling_name =
resolve_or_null(5, kAscendPtoScope + "TileUbDataND");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High Correctness Issue in TQuantCodegen Slice Name Resolution

TQuantCodegen incorrectly uses ResolveCubeSliceName for TileUbDataND variables. TileUbDataND is a Unified Buffer (UB) Vector tile, not a Cube tile.

Using ResolveCubeSliceName forces RowValid = slice_row and ColValid = slice_col (which is aligned/padded to 32 bytes), completely losing the actual unaligned valid bounds (slice_valid_row and slice_valid_col). This can lead to out-of-bounds memory accesses or incorrect padding during quantization.

Use ResolveUbSliceName(info) instead, which correctly instantiates TileUbDataND with the actual valid bounds.

Suggested change
auto resolve_or_null = [&](size_t idx,
const std::string &tile_name) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
// Integer literal 0 used as null pointer sentinel
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveCubeSliceName(info, tile_name);
};
std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND");
std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND");
std::string exp_name = resolve_or_null(3, kAscendPtoScope + "TileUbDataND");
std::string max_name = resolve_or_null(4, kAscendPtoScope + "TileUbDataND");
std::string scaling_name =
resolve_or_null(5, kAscendPtoScope + "TileUbDataND");
auto resolve_or_null = [&](size_t idx) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
// Integer literal 0 used as null pointer sentinel
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveUbSliceName(info);
};
std::string dst_name = resolve_or_null(1);
std::string src_name = resolve_or_null(2);
std::string exp_name = resolve_or_null(3);
std::string max_name = resolve_or_null(4);
std::string scaling_name = resolve_or_null(5);

Comment thread src/target/codegen_ascend_pto.cc Outdated
Comment on lines +3227 to +3246
auto resolve_or_null = [&](size_t idx,
const std::string &tile_name) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveCubeSliceName(info, tile_name);
};

std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND");
std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND");
std::string scale_name =
resolve_or_null(3, kAscendPtoScope + "TileUbDataND");
std::string offset_name =
resolve_or_null(4, kAscendPtoScope + "TileUbDataND");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

High Correctness Issue in TDequantCodegen Slice Name Resolution

Similar to TQuantCodegen, TDequantCodegen incorrectly uses ResolveCubeSliceName for TileUbDataND variables. Use ResolveUbSliceName(info) to preserve the actual unaligned valid bounds.

Suggested change
auto resolve_or_null = [&](size_t idx,
const std::string &tile_name) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveCubeSliceName(info, tile_name);
};
std::string dst_name = resolve_or_null(1, kAscendPtoScope + "TileUbDataND");
std::string src_name = resolve_or_null(2, kAscendPtoScope + "TileUbDataND");
std::string scale_name =
resolve_or_null(3, kAscendPtoScope + "TileUbDataND");
std::string offset_name =
resolve_or_null(4, kAscendPtoScope + "TileUbDataND");
auto resolve_or_null = [&](size_t idx) -> std::string {
if (idx >= op->args.size()) {
return "nullptr";
}
const CallNode *access_call = op->args[idx].as<CallNode>();
if (!access_call) {
return "nullptr";
}
ShapeInfo info = GetSliceInfo(access_call);
return ResolveUbSliceName(info);
};
std::string dst_name = resolve_or_null(1);
std::string src_name = resolve_or_null(2);
std::string scale_name = resolve_or_null(3);
std::string offset_name = resolve_or_null(4);

Comment on lines +2256 to +2257
if isinstance(obj, BufferRegion):
buf, region = obj.buffer, obj.region

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Medium Potential Crash with Let-Bound Variables

_get_buffer does not resolve let-bound variables (e.g., tir.Var bound to a buffer region). If a let-bound variable is passed to tquant_mxfp8 or tdequant, _get_buffer will raise a ValueError because it only checks for Buffer and BufferRegion.

Resolve let-bound variables at the beginning of _get_buffer using T.has_let_value and T.get_let_value.

Suggested change
if isinstance(obj, BufferRegion):
buf, region = obj.buffer, obj.region
if isinstance(obj, tir.Var) and T.has_let_value(obj):
obj = T.get_let_value(obj)
if isinstance(obj, BufferRegion):

Comment thread tilelang/language/customize.py Outdated
Comment on lines +256 to +269
assert int(K) % 64 == 0, (
f"MXFP GEMM requires K to be a multiple of 64, got K={K}"
)
kMXScaleFactor = 32
expected_sa_cols = int(K) // kMXScaleFactor
expected_sb_rows = int(K) // kMXScaleFactor
assert int(Sa_shape[-1]) == expected_sa_cols, (
f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), "
f"got {Sa_shape[-1]}"
)
assert int(Sb_shape[-2]) == expected_sb_rows, (
f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), "
f"got {Sb_shape[-2]}"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Medium Potential Crash with Symbolic Shapes

npu_gemm_mx shape validation uses int(K) and int(Sa_shape[-1]) directly. If the shapes are symbolic variables (tir.Var), calling int() on them will raise a TypeError at runtime.

Handle symbolic shapes defensively by checking isinstance(..., tir.IntImm) before performing integer modulo/division checks.

Suggested change
assert int(K) % 64 == 0, (
f"MXFP GEMM requires K to be a multiple of 64, got K={K}"
)
kMXScaleFactor = 32
expected_sa_cols = int(K) // kMXScaleFactor
expected_sb_rows = int(K) // kMXScaleFactor
assert int(Sa_shape[-1]) == expected_sa_cols, (
f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), "
f"got {Sa_shape[-1]}"
)
assert int(Sb_shape[-2]) == expected_sb_rows, (
f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), "
f"got {Sb_shape[-2]}"
)
if isinstance(K, tir.IntImm):
assert K.value % 64 == 0, (
f"MXFP GEMM requires K to be a multiple of 64, got K={K.value}"
)
kMXScaleFactor = 32
expected_sa_cols = K.value // kMXScaleFactor
expected_sb_rows = K.value // kMXScaleFactor
if isinstance(Sa_shape[-1], tir.IntImm):
assert Sa_shape[-1].value == expected_sa_cols, (
f"scale_A column mismatch: expected {expected_sa_cols} (K/{kMXScaleFactor}), "
f"got {Sa_shape[-1].value}"
)
if isinstance(Sb_shape[-2], tir.IntImm):
assert Sb_shape[-2].value == expected_sb_rows, (
f"scale_B row mismatch: expected {expected_sb_rows} (K/{kMXScaleFactor}), "
f"got {Sb_shape[-2].value}"
)

@erhsh erhsh force-pushed the ascendc_pto_gemm_support_qmm branch from df87470 to a4de75a Compare June 2, 2026 11:32
@erhsh

erhsh commented Jun 2, 2026

Copy link
Copy Markdown
Contributor Author

/re-test

@github-actions

github-actions Bot commented Jun 2, 2026

Copy link
Copy Markdown

🔄 Re-running failed jobs

Original workflow run: View details

Only the failed jobs will be re-executed.

@erhsh

erhsh commented Jun 2, 2026

Copy link
Copy Markdown
Contributor Author

/re-test

@github-actions

github-actions Bot commented Jun 2, 2026

Copy link
Copy Markdown

🔄 Re-running failed jobs

Original workflow run: View details

Only the failed jobs will be re-executed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant