Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/triton/ops/gemm/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from scripts.triton import aot

# Element types for which a kernel family is generated; each one is used for
# all three data pointers (A, B and C) of a given build.
_DTYPES = ("fp16", "bf16", "fp32")

# Tile shapes, unpacked as (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K).
_BLOCK_SIZES = (
    (64, 64, 32),
    (128, 64, 32),
    (64, 128, 32),
)

# Pointer alignment variants handed to `aot.Signature`.
# NOTE(review): `None` presumably means "no alignment assumption" - confirm
# against `aot.Signature`.
_ALIGNMENTS = (16, None)

# Launch/tuning parameters shared by every generated variant.
_NUM_WARPS = 4
_NUM_STAGES = 3
_GROUP_SIZE_M = 8

# Kernel arguments grouped by how they are typed in the signature.
_DATA_PTRS = ("a_ptr", "b_ptr", "c_ptr")
_I32_SCALARS = ("m", "n", "k", "batch_count")

# All strides are declared as 64-bit integers: first the per-matrix strides,
# then the per-batch strides.
_MATRIX_STRIDE_SCALARS = (
    "stride_am",
    "stride_ak",
    "stride_bk",
    "stride_bn",
    "stride_cm",
    "stride_cn",
)
_BATCH_STRIDE_SCALARS = (
    "batch_stride_a",
    "batch_stride_b",
    "batch_stride_c",
)
_I64_SCALARS = _MATRIX_STRIDE_SCALARS + _BATCH_STRIDE_SCALARS


def _signature(dtype, block_size, alignment):
    """Build the `aot.Signature` for one GEMM kernel variant.

    Args:
        dtype: Element type name applied to every data pointer.
        block_size: Triple unpacked as (BLOCK_SIZE_M, BLOCK_SIZE_N,
            BLOCK_SIZE_K).
        alignment: Alignment value applied to every data pointer.

    Returns:
        An `aot.Signature` describing pointer types and alignments, scalar
        argument types, and the compile-time constants of the kernel.
    """
    tile_m, tile_n, tile_k = block_size

    # `alpha` and `beta` come first, followed by the 32-bit and then the
    # 64-bit integer scalars; insertion order is kept as-is.
    scalar_types = {"alpha": "fp64", "beta": "fp64"}
    scalar_types.update(dict.fromkeys(_I32_SCALARS, "i32"))
    scalar_types.update(dict.fromkeys(_I64_SCALARS, "i64"))

    compile_time_constants = {
        "BLOCK_SIZE_M": tile_m,
        "BLOCK_SIZE_N": tile_n,
        "BLOCK_SIZE_K": tile_k,
        "GROUP_SIZE_M": _GROUP_SIZE_M,
    }

    return aot.Signature(
        pointer_dtypes=dict.fromkeys(_DATA_PTRS, dtype),
        pointer_alignments=dict.fromkeys(_DATA_PTRS, alignment),
        scalar_dtypes=scalar_types,
        constexprs=compile_time_constants,
    )


def configs():
    """Yield the compile configurations, grouped per element type.

    Yields:
        One tuple per entry of `_DTYPES`. Each tuple holds an
        `aot.CompileConfig` for every (block size, alignment) combination,
        with block size as the outer and alignment as the inner loop. All
        configs of one tuple share the same `out_name`.
    """

    def _variant(dtype, block_size, alignment):
        # The grid is emitted as an expression string: one program per
        # output tile on the first axis and `batch_count` on the second.
        grid = (
            f"((m + {block_size[0]} - 1) / {block_size[0]}) * "
            f"((n + {block_size[1]} - 1) / {block_size[1]}), batch_count, 1"
        )
        return aot.CompileConfig(
            signature=_signature(dtype, block_size, alignment),
            grid=grid,
            out_name=f"infini_ops_triton_gemm_{dtype}",
            num_warps=_NUM_WARPS,
            num_stages=_NUM_STAGES,
        )

    for dtype in _DTYPES:
        yield tuple(
            _variant(dtype, block_size, alignment)
            for block_size in _BLOCK_SIZES
            for alignment in _ALIGNMENTS
        )
76 changes: 76 additions & 0 deletions src/triton/ops/gemm/gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#ifndef INFINI_OPS_TRITON_GEMM_H_
#define INFINI_OPS_TRITON_GEMM_H_

#include <cuda.h>

#include <cassert>
#include <cstdint>
#include <optional>

#include "base/gemm.h"
#include "data_type.h"
#include "gemm/infini_ops_triton_gemm.h"

namespace infini::ops {

template <>
class Operator<Gemm, Device::Type::kNvidia, 8> : public Gemm {
public:
using Gemm::operator();

Operator(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c)
: Gemm{a, b, alpha, beta, trans_a, trans_b, c} {}

void operator()(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c) const override {
assert(a_type_ == b_type_ && b_type_ == c_type_ &&
"Triton `Gemm` requires A, B, and C tensors to have the same dtype");

const auto alpha_value = alpha.value_or(alpha_);
const auto beta_value = beta.value_or(beta_);
const auto trans_a_value = static_cast<bool>(trans_a.value_or(trans_a_));
const auto trans_b_value = static_cast<bool>(trans_b.value_or(trans_b_));

const auto stride_am =
static_cast<int64_t>(trans_a_value ? a_strides_[a_strides_.size() - 1]
: a_strides_[a_strides_.size() - 2]);
const auto stride_ak =
static_cast<int64_t>(trans_a_value ? a_strides_[a_strides_.size() - 2]
: a_strides_[a_strides_.size() - 1]);
const auto stride_bk =
static_cast<int64_t>(trans_b_value ? b_strides_[b_strides_.size() - 1]
: b_strides_[b_strides_.size() - 2]);
const auto stride_bn =
static_cast<int64_t>(trans_b_value ? b_strides_[b_strides_.size() - 2]
: b_strides_[b_strides_.size() - 1]);
const auto stride_cm =
static_cast<int64_t>(c_strides_[c_strides_.size() - 2]);
const auto stride_cn =
static_cast<int64_t>(c_strides_[c_strides_.size() - 1]);
const auto batch_stride_a = static_cast<int64_t>(batch_stride_a_);
const auto batch_stride_b = static_cast<int64_t>(batch_stride_b_);
const auto batch_stride_c = static_cast<int64_t>(batch_stride_c_);
const auto batch_count = static_cast<int32_t>(batch_count_);

load_infini_ops_triton_gemm(c_type_);

auto result = launch_infini_ops_triton_gemm(
c_type_, static_cast<CUstream>(stream_),
reinterpret_cast<CUdeviceptr>(const_cast<void*>(a.data())),
reinterpret_cast<CUdeviceptr>(const_cast<void*>(b.data())),
reinterpret_cast<CUdeviceptr>(c.data()), alpha_value, beta_value,
static_cast<int32_t>(m_), static_cast<int32_t>(n_),
static_cast<int32_t>(k_), stride_am, stride_ak, stride_bk, stride_bn,
stride_cm, stride_cn, batch_stride_a, batch_stride_b, batch_stride_c,
batch_count);

assert(result == CUDA_SUCCESS && "Triton `Gemm` launch failed");
}
};

} // namespace infini::ops

#endif
76 changes: 76 additions & 0 deletions src/triton/ops/gemm/gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import triton
import triton.language as tl


@triton.jit
def kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    alpha,
    beta,
    m,
    n,
    k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    batch_stride_a,
    batch_stride_b,
    batch_stride_c,
    batch_count,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one output tile of ``C = alpha * (A @ B) + beta * C``.

    Program axis 0 enumerates the output tiles of a single matrix and
    program axis 1 selects the batch entry. A is addressed as (m, k) through
    ``stride_am``/``stride_ak``, B as (k, n) through ``stride_bk``/
    ``stride_bn`` and C as (m, n) through ``stride_cm``/``stride_cn``; the
    ``batch_stride_*`` values are the element offsets between consecutive
    batch entries. The product is accumulated in float32.

    When ``beta`` is zero, C is treated as write-only: its previous contents
    are not read, so NaN/Inf values in an uninitialised output buffer cannot
    leak into the result.

    ``batch_count`` is not used in the body; it is part of the signature.
    """
    pid = tl.program_id(0)
    batch = tl.program_id(1)

    num_pid_m = tl.cdiv(m, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(n, BLOCK_SIZE_N)

    # Grouped tile ordering: consecutive programs walk down a band of up to
    # GROUP_SIZE_M tile rows before moving to the next tile column. The last
    # band may be shorter, hence the `minimum`.
    group_size = GROUP_SIZE_M * num_pid_n
    group = pid // group_size
    first_pid_m = group * GROUP_SIZE_M
    group_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M)

    pid_m = first_pid_m + (pid % group_m)
    pid_n = (pid % group_size) // group_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_base = a_ptr + batch * batch_stride_a
    b_base = b_ptr + batch * batch_stride_b
    c_base = c_ptr + batch * batch_stride_c

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_start in tl.range(0, k, BLOCK_SIZE_K):
        k_idxs = k_start + offs_k

        # Out-of-range rows/columns and the K tail are loaded as zeros so
        # they contribute nothing to the dot product.
        a = tl.load(
            a_base + offs_m[:, None] * stride_am + k_idxs[None, :] * stride_ak,
            mask=(offs_m[:, None] < m) & (k_idxs[None, :] < k),
            other=0.0,
        )
        b = tl.load(
            b_base + k_idxs[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=(k_idxs[:, None] < k) & (offs_n[None, :] < n),
            other=0.0,
        )

        acc = tl.dot(a, b, acc)

    c_ptrs = c_base + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < m) & (offs_n[None, :] < n)

    # Skip the read of C entirely when beta == 0: masked-off lanes receive
    # 0.0, so `beta * c` is exactly zero instead of `0 * NaN = NaN` when the
    # output buffer holds garbage.
    c = tl.load(c_ptrs, mask=c_mask & (beta != 0.0), other=0.0)
    out = alpha * acc + beta * c

    tl.store(c_ptrs, out, mask=c_mask)
Loading