From ce08af063c7791c0f2d418e7ed83ce221a09918a Mon Sep 17 00:00:00 2001 From: Zhang Shuo <52872288+fuyou4546@users.noreply.github.com> Date: Mon, 22 Jun 2026 04:43:56 +0000 Subject: [PATCH] feat(triton): add gemm operator --- src/triton/ops/gemm/build.py | 59 ++++++++++++++++++++++++++++ src/triton/ops/gemm/gemm.h | 76 ++++++++++++++++++++++++++++++++++++ src/triton/ops/gemm/gemm.py | 76 ++++++++++++++++++++++++++++++++++++ 3 files changed, 211 insertions(+) create mode 100644 src/triton/ops/gemm/build.py create mode 100644 src/triton/ops/gemm/gemm.h create mode 100644 src/triton/ops/gemm/gemm.py diff --git a/src/triton/ops/gemm/build.py b/src/triton/ops/gemm/build.py new file mode 100644 index 000000000..6952b52d9 --- /dev/null +++ b/src/triton/ops/gemm/build.py @@ -0,0 +1,59 @@ +from scripts.triton import aot + +_DTYPES = ("fp16", "bf16", "fp32") +_BLOCK_SIZES = ((64, 64, 32), (128, 64, 32), (64, 128, 32)) +_ALIGNMENTS = (16, None) +_NUM_WARPS = 4 +_NUM_STAGES = 3 +_GROUP_SIZE_M = 8 +_DATA_PTRS = ("a_ptr", "b_ptr", "c_ptr") +_I32_SCALARS = ("m", "n", "k", "batch_count") +_I64_SCALARS = ( + "stride_am", + "stride_ak", + "stride_bk", + "stride_bn", + "stride_cm", + "stride_cn", + "batch_stride_a", + "batch_stride_b", + "batch_stride_c", +) + + +def _signature(dtype, block_size, alignment): + block_m, block_n, block_k = block_size + return aot.Signature( + pointer_dtypes={name: dtype for name in _DATA_PTRS}, + pointer_alignments={name: alignment for name in _DATA_PTRS}, + scalar_dtypes={ + "alpha": "fp64", + "beta": "fp64", + **{name: "i32" for name in _I32_SCALARS}, + **{name: "i64" for name in _I64_SCALARS}, + }, + constexprs={ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": _GROUP_SIZE_M, + }, + ) + + +def configs(): + for dtype in _DTYPES: + yield tuple( + aot.CompileConfig( + signature=_signature(dtype, block_size, alignment), + grid=( + f"((m + {block_size[0]} - 1) / {block_size[0]}) * " + f"((n + {block_size[1]} - 1) / {block_size[1]}), batch_count, 1" + ), + out_name=f"infini_ops_triton_gemm_{dtype}", + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) + for block_size in _BLOCK_SIZES + for alignment in _ALIGNMENTS + ) diff --git a/src/triton/ops/gemm/gemm.h b/src/triton/ops/gemm/gemm.h new file mode 100644 index 000000000..d4f785dfc --- /dev/null +++ b/src/triton/ops/gemm/gemm.h @@ -0,0 +1,76 @@ +#ifndef INFINI_OPS_TRITON_GEMM_H_ +#define INFINI_OPS_TRITON_GEMM_H_ + +#include + +#include +#include +#include + +#include "base/gemm.h" +#include "data_type.h" +#include "gemm/infini_ops_triton_gemm.h" + +namespace infini::ops { + +template <> +class Operator : public Gemm { + public: + using Gemm::operator(); + + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) + : Gemm{a, b, alpha, beta, trans_a, trans_b, c} {} + + void operator()(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, std::optional trans_a, + std::optional trans_b, Tensor c) const override { + assert(a_type_ == b_type_ && b_type_ == c_type_ && + "Triton `Gemm` requires A, B, and C tensors to have the same dtype"); + + const auto alpha_value = alpha.value_or(alpha_); + const auto beta_value = beta.value_or(beta_); + const auto trans_a_value = static_cast(trans_a.value_or(trans_a_)); + const auto trans_b_value = static_cast(trans_b.value_or(trans_b_)); + + const auto stride_am = + static_cast(trans_a_value ? a_strides_[a_strides_.size() - 1] + : a_strides_[a_strides_.size() - 2]); + const auto stride_ak = + static_cast(trans_a_value ? a_strides_[a_strides_.size() - 2] + : a_strides_[a_strides_.size() - 1]); + const auto stride_bk = + static_cast(trans_b_value ? b_strides_[b_strides_.size() - 1] + : b_strides_[b_strides_.size() - 2]); + const auto stride_bn = + static_cast(trans_b_value ? b_strides_[b_strides_.size() - 2] + : b_strides_[b_strides_.size() - 1]); + const auto stride_cm = + static_cast(c_strides_[c_strides_.size() - 2]); + const auto stride_cn = + static_cast(c_strides_[c_strides_.size() - 1]); + const auto batch_stride_a = static_cast(batch_stride_a_); + const auto batch_stride_b = static_cast(batch_stride_b_); + const auto batch_stride_c = static_cast(batch_stride_c_); + const auto batch_count = static_cast(batch_count_); + + load_infini_ops_triton_gemm(c_type_); + + auto result = launch_infini_ops_triton_gemm( + c_type_, static_cast(stream_), + reinterpret_cast(const_cast(a.data())), + reinterpret_cast(const_cast(b.data())), + reinterpret_cast(c.data()), alpha_value, beta_value, + static_cast(m_), static_cast(n_), + static_cast(k_), stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, batch_stride_a, batch_stride_b, batch_stride_c, + batch_count); + + assert(result == CUDA_SUCCESS && "Triton `Gemm` launch failed"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/triton/ops/gemm/gemm.py b/src/triton/ops/gemm/gemm.py new file mode 100644 index 000000000..66a769430 --- /dev/null +++ b/src/triton/ops/gemm/gemm.py @@ -0,0 +1,76 @@ +import triton +import triton.language as tl + + +@triton.jit +def kernel( + a_ptr, + b_ptr, + c_ptr, + alpha, + beta, + m, + n, + k, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + batch_stride_a, + batch_stride_b, + batch_stride_c, + batch_count, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(0) + batch = tl.program_id(1) + + num_pid_m = tl.cdiv(m, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(n, BLOCK_SIZE_N) + + group_size = GROUP_SIZE_M * num_pid_n + group = pid // group_size + first_pid_m = group * GROUP_SIZE_M + group_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + + pid_m = first_pid_m + (pid % group_m) + pid_n = (pid % group_size) // group_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_base = a_ptr + batch * batch_stride_a + b_base = b_ptr + batch * batch_stride_b + c_base = c_ptr + batch * batch_stride_c + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k_start in tl.range(0, k, BLOCK_SIZE_K): + k_idxs = k_start + offs_k + + a = tl.load( + a_base + offs_m[:, None] * stride_am + k_idxs[None, :] * stride_ak, + mask=(offs_m[:, None] < m) & (k_idxs[None, :] < k), + other=0.0, + ) + b = tl.load( + b_base + k_idxs[:, None] * stride_bk + offs_n[None, :] * stride_bn, + mask=(k_idxs[:, None] < k) & (offs_n[None, :] < n), + other=0.0, + ) + + acc = tl.dot(a, b, acc) + + c_offsets = c_base + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n) + + c = tl.load(c_offsets, mask=c_mask, other=0.0) + out = alpha * acc + beta * c + + tl.store(c_offsets, out, mask=c_mask)