From c77ea4e5dfe55ab54ae5afaa9bb5cf1ba793ceaa Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 25 Jun 2026 06:50:39 +0000 Subject: [PATCH 1/2] issue/1192: add marlin repack operators Squash AWQ/GPTQ Marlin repack operator commits from issue/1192-clean. --- include/infiniop/ops/awq_marlin_repack.h | 27 + include/infiniop/ops/gptq_marlin_repack.h | 29 + .../ops/awq_marlin_repack/awq_marlin_repack.h | 48 ++ .../ops/awq_marlin_repack/cuda/kernel.cuh | 197 +++++++ src/infiniop/ops/awq_marlin_repack/info.h | 49 ++ .../ops/awq_marlin_repack/marlin/marlin.cuh | 178 ++++++ .../nvidia/awq_marlin_repack_nvidia.cu | 122 ++++ .../nvidia/awq_marlin_repack_nvidia.cuh | 8 + .../ops/awq_marlin_repack/operator.cc | 101 ++++ .../ops/gptq_marlin_repack/cuda/kernel.cuh | 252 +++++++++ .../gptq_marlin_repack/gptq_marlin_repack.h | 50 ++ src/infiniop/ops/gptq_marlin_repack/info.h | 55 ++ .../ops/gptq_marlin_repack/marlin/marlin.cuh | 178 ++++++ .../nvidia/gptq_marlin_repack_nvidia.cu | 134 +++++ .../nvidia/gptq_marlin_repack_nvidia.cuh | 8 + .../ops/gptq_marlin_repack/operator.cc | 104 ++++ test/infiniop/awq_marlin_repack.py | 443 +++++++++++++++ test/infiniop/gptq_marlin_repack.py | 531 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 67 +++ 19 files changed, 2581 insertions(+) create mode 100644 include/infiniop/ops/awq_marlin_repack.h create mode 100644 include/infiniop/ops/gptq_marlin_repack.h create mode 100644 src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h create mode 100644 src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/info.h create mode 100644 src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu create mode 100644 src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/operator.cc create mode 100644 src/infiniop/ops/gptq_marlin_repack/cuda/kernel.cuh create mode 100644 src/infiniop/ops/gptq_marlin_repack/gptq_marlin_repack.h create mode 100644 src/infiniop/ops/gptq_marlin_repack/info.h create mode 100644 src/infiniop/ops/gptq_marlin_repack/marlin/marlin.cuh create mode 100644 src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu create mode 100644 src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cuh create mode 100644 src/infiniop/ops/gptq_marlin_repack/operator.cc create mode 100644 test/infiniop/awq_marlin_repack.py create mode 100644 test/infiniop/gptq_marlin_repack.py diff --git a/include/infiniop/ops/awq_marlin_repack.h b/include/infiniop/ops/awq_marlin_repack.h new file mode 100644 index 000000000..017ff5568 --- /dev/null +++ b/include/infiniop/ops/awq_marlin_repack.h @@ -0,0 +1,27 @@ +#ifndef __INFINIOP_AWQ_MARLIN_REPACK_API_H__ +#define __INFINIOP_AWQ_MARLIN_REPACK_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopAwqMarlinRepackDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateAwqMarlinRepackDescriptor(infiniopHandle_t handle, + infiniopAwqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit); + +__INFINI_C __export infiniStatus_t infiniopGetAwqMarlinRepackWorkspaceSize(infiniopAwqMarlinRepackDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopAwqMarlinRepack(infiniopAwqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyAwqMarlinRepackDescriptor(infiniopAwqMarlinRepackDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/gptq_marlin_repack.h b/include/infiniop/ops/gptq_marlin_repack.h new file mode 100644 index 000000000..c3b588fa5 --- /dev/null +++ b/include/infiniop/ops/gptq_marlin_repack.h @@ -0,0 +1,29 @@ +#ifndef __INFINIOP_GPTQ_MARLIN_REPACK_API_H__ +#define __INFINIOP_GPTQ_MARLIN_REPACK_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopGptqMarlinRepackDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateGptqMarlinRepackDescriptor(infiniopHandle_t handle, + infiniopGptqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t perm_desc, + int64_t num_bits, + bool is_a_8bit); + +__INFINI_C __export infiniStatus_t infiniopGetGptqMarlinRepackWorkspaceSize(infiniopGptqMarlinRepackDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopGptqMarlinRepack(infiniopGptqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + const void *perm, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyGptqMarlinRepackDescriptor(infiniopGptqMarlinRepackDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h b/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h new file mode 100644 index 000000000..e2768173f --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h @@ -0,0 +1,48 @@ +#ifndef AWQ_MARLIN_REPACK_H +#define AWQ_MARLIN_REPACK_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::awq_marlin_repack::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AwqMarlinRepackInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + AwqMarlinRepackInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + infiniopTensorDescriptor_t input_desc, \ + int64_t num_bits, \ + bool is_a_8bit); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *output, \ + const void *input, \ + void *stream) const; \ + }; \ + } + +#endif // AWQ_MARLIN_REPACK_H diff --git a/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh b/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh new file mode 100644 index 000000000..b7288b603 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh @@ -0,0 +1,197 @@ +#include "../marlin/marlin.cuh" + +namespace marlin { + +template +__device__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = target_tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = target_tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * target_tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4 *sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * target_tile_k_size; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (is_a_8bit) { + int cur_elem = tc_row + i; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + sh_stride * (cur_elem + 16)]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } else { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + } + + constexpr int tile_size = target_tile_k_size * target_tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin diff --git a/src/infiniop/ops/awq_marlin_repack/info.h b/src/infiniop/ops/awq_marlin_repack/info.h new file mode 100644 index 000000000..c9dea93fe --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/info.h @@ -0,0 +1,49 @@ +#ifndef __AWQ_MARLIN_REPACK_INFO_H__ +#define __AWQ_MARLIN_REPACK_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include "marlin/marlin.cuh" +#include + +#include + +namespace op::awq_marlin_repack { + +class AwqMarlinRepackInfo { + AwqMarlinRepackInfo() = default; + +public: + infiniDtype_t output_dtype, input_dtype; + size_t size_k, size_n; + int64_t num_bits; + bool is_a_8bit; + + static utils::Result create( + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { + CHECK_OR_RETURN( + output_desc != nullptr && input_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + const infiniDtype_t output_dtype = output_desc->dtype(); + const infiniDtype_t input_dtype = input_desc->dtype(); + CHECK_DTYPE(input_dtype, INFINI_DTYPE_I32); + CHECK_DTYPE(input_dtype, output_dtype); + + size_t size_k = input_desc->dim(0); + int const pack_factor = 32 / num_bits; + size_t size_n = input_desc->dim(1) * pack_factor; + + CHECK_OR_RETURN(size_k / marlin::tile_size == output_desc->dim(0) || size_n * marlin::tile_size / pack_factor == output_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result( + AwqMarlinRepackInfo{output_dtype, input_dtype, size_k, size_n, num_bits, is_a_8bit}); + } +}; + +} // namespace op::awq_marlin_repack + +#endif // __AWQ_MARLIN_REPACK_INFO_H__ diff --git a/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh b/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh new file mode 100644 index 000000000..f3d897d27 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh @@ -0,0 +1,178 @@ +#pragma once + +#ifndef _marlin_cuh +#define _marlin_cuh + +#include +#include +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +template +__device__ __forceinline__ uint32_t __cvta_generic_to_shared(T *ptr) { + size_t smem_addr; + asm volatile( + "cvta.to.shared.u64 %0, %1;" + : "=l"(smem_addr) + : "l"(ptr)); + return static_cast(smem_addr); +} + +namespace MARLIN_NAMESPACE_NAME { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; +} + +__device__ inline void cp_async_fence() {} + +template +__device__ inline void cp_async_wait() {} + +#else + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu new file mode 100644 index 000000000..ffa7ff44c --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu @@ -0,0 +1,122 @@ +#if defined(ENABLE_NVIDIA_API) +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "awq_marlin_repack_nvidia.cuh" +#include + +template +INFINIOP_CUDA_KERNEL awqMarlinRepackKernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + marlin::awq_marlin_repack_kernel( + b_q_weight_ptr, out_ptr, + size_k, size_n); +} + +#define CALL_IF(NUM_BITS, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + awqMarlinRepackKernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + awqMarlinRepackKernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +infiniStatus_t awqMarlinRepack(uint32_t *out_ptr, const uint32_t *b_q_weight_ptr, int64_t size_k, + int64_t size_n, int64_t num_bits, + bool is_a_8bit, cudaStream_t stream) { + // Verify compatibility with marlin tile of 16x64 + if (size_k % marlin::tile_k_size != 0) { + std::cout << "size_k = " << size_k << " is not divisible by tile_k_size = " << marlin::tile_k_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (size_n % marlin::tile_n_size != 0) { + std::cout << "size_n = " << size_n << " is not divisible by tile_n_size = " << marlin::tile_n_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_bits != 4 && num_bits != 8) { + std::cout << "num_bits must be 4 or 8. Got = " << num_bits << std::endl; + return INFINI_STATUS_BAD_PARAM; + } + + int const pack_factor = 32 / num_bits; + + // Get dev info + int device_id = 0; + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id); + assert(max_shared_mem > 0 && "max_shared_mem must be greater than 0"); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(8, false) + CALL_IF(4, true) + CALL_IF(8, true) + else { + assert(false && "Unsupported repack config: num_bits, is_a_8bit"); + } + + return INFINI_STATUS_SUCCESS; +} + +namespace op::awq_marlin_repack::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { + + auto handle = reinterpret_cast(handle_); + auto result = AwqMarlinRepackInfo::create(output_desc, input_desc, num_bits, is_a_8bit); + + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{handle->internal()}, + result.take(), + workspace_size, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, size_t workspace_size, + void *output, + const void *input, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + + int64_t size_k = static_cast(_info.size_k); + int64_t size_n = static_cast(_info.size_n); + int64_t num_bits = _info.num_bits; + bool is_a_8bit = _info.is_a_8bit; + + awqMarlinRepack((uint32_t *)output, (const uint32_t *)input, size_k, + size_n, num_bits, + is_a_8bit, stream); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::awq_marlin_repack::nvidia +#endif diff --git a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh new file mode 100644 index 000000000..3cbec6c66 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __AWQ_MARLIN_REPACK_CUDA_CUH__ +#define __AWQ_MARLIN_REPACK_CUDA_CUH__ + +#include "../awq_marlin_repack.h" + +DESCRIPTOR(nvidia) + +#endif // __AWQ_MARLIN_REPACK_CUDA_CUH__ diff --git a/src/infiniop/ops/awq_marlin_repack/operator.cc b/src/infiniop/ops/awq_marlin_repack/operator.cc new file mode 100644 index 000000000..ea02da87d --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/operator.cc @@ -0,0 +1,101 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/awq_marlin_repack.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/awq_marlin_repack_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateAwqMarlinRepackDescriptor( + infiniopHandle_t handle, + infiniopAwqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::awq_marlin_repack::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + num_bits, \ + is_a_8bit) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetAwqMarlinRepackWorkspaceSize(infiniopAwqMarlinRepackDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopAwqMarlinRepack( + infiniopAwqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, output, input, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyAwqMarlinRepackDescriptor(infiniopAwqMarlinRepackDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/src/infiniop/ops/gptq_marlin_repack/cuda/kernel.cuh b/src/infiniop/ops/gptq_marlin_repack/cuda/kernel.cuh new file mode 100644 index 000000000..9424588b8 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/cuda/kernel.cuh @@ -0,0 +1,252 @@ +#include "../marlin/marlin.cuh" + +namespace marlin { + +template +__device__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = target_tile_k_size / 4; + + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = target_tile_k_size / pack_factor; + + constexpr int stage_n_threads = target_tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * target_tile_k_size) / 4; + + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * target_tile_n_size; + + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + } else { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * target_tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; + + constexpr int sh_stride = target_tile_n_size; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + static_assert(!is_a_8bit); + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) { + if constexpr (is_a_8bit) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8]; + } else { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + } + +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]); + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + if constexpr (is_a_8bit) + vals[4 + i] = (b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask; + else + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = target_tile_k_size * target_tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin diff --git a/src/infiniop/ops/gptq_marlin_repack/gptq_marlin_repack.h b/src/infiniop/ops/gptq_marlin_repack/gptq_marlin_repack.h new file mode 100644 index 000000000..3a059b664 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/gptq_marlin_repack.h @@ -0,0 +1,50 @@ +#ifndef GPTQ_MARLIN_REPACK_H +#define GPTQ_MARLIN_REPACK_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::gptq_marlin_repack::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + GptqMarlinRepackInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + GptqMarlinRepackInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + infiniopTensorDescriptor_t input_desc, \ + infiniopTensorDescriptor_t perm_desc, \ + int64_t num_bits, \ + bool is_a_8bit); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *output, \ + const void *input, \ + const void *perm, \ + void *stream) const; \ + }; \ + } + +#endif // GPTQ_MARLIN_REPACK_H diff --git a/src/infiniop/ops/gptq_marlin_repack/info.h b/src/infiniop/ops/gptq_marlin_repack/info.h new file mode 100644 index 000000000..dd724d358 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/info.h @@ -0,0 +1,55 @@ +#ifndef __AWQ_MARLIN_REPACK_INFO_H__ +#define __AWQ_MARLIN_REPACK_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include "marlin/marlin.cuh" +#include + +#include + +namespace op::gptq_marlin_repack { + +class GptqMarlinRepackInfo { + GptqMarlinRepackInfo() = default; + +public: + infiniDtype_t output_dtype, input_dtype; + size_t size_k, size_n; + int64_t num_bits; + bool is_a_8bit, has_perm; + + static utils::Result create( + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t perm_desc, + int64_t num_bits, + bool is_a_8bit) { + CHECK_OR_RETURN( + output_desc != nullptr && input_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + const infiniDtype_t output_dtype = output_desc->dtype(); + const infiniDtype_t input_dtype = input_desc->dtype(); + + CHECK_DTYPE(input_dtype, output_dtype); + + int const pack_factor = 32 / num_bits; + size_t size_k = input_desc->dim(0) * pack_factor; + size_t size_n = input_desc->dim(1); + + CHECK_OR_RETURN(size_k / marlin::tile_size == output_desc->dim(0) || size_n * marlin::tile_size / pack_factor == output_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + bool has_perm = false; + + if (perm_desc != nullptr && perm_desc->dim(0) != 0) { + has_perm = true; + } + + return utils::Result( + GptqMarlinRepackInfo{output_dtype, input_dtype, size_k, size_n, num_bits, is_a_8bit, has_perm}); + } +}; + +} // namespace op::gptq_marlin_repack + +#endif // __AWQ_MARLIN_REPACK_INFO_H__ diff --git a/src/infiniop/ops/gptq_marlin_repack/marlin/marlin.cuh b/src/infiniop/ops/gptq_marlin_repack/marlin/marlin.cuh new file mode 100644 index 000000000..f3d897d27 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/marlin/marlin.cuh @@ -0,0 +1,178 @@ +#pragma once + +#ifndef _marlin_cuh +#define _marlin_cuh + +#include +#include +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +template +__device__ __forceinline__ uint32_t __cvta_generic_to_shared(T *ptr) { + size_t smem_addr; + asm volatile( + "cvta.to.shared.u64 %0, %1;" + : "=l"(smem_addr) + : "l"(ptr)); + return static_cast(smem_addr); +} + +namespace MARLIN_NAMESPACE_NAME { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; +} + +__device__ inline void cp_async_fence() {} + +template +__device__ inline void cp_async_wait() {} + +#else + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu new file mode 100644 index 000000000..50e99afbe --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu @@ -0,0 +1,134 @@ +#if defined(ENABLE_NVIDIA_API) +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "gptq_marlin_repack_nvidia.cuh" +#include + +template +INFINIOP_CUDA_KERNEL gptqMarlinRepackKernel( + const uint32_t *__restrict__ b_q_weight_ptr, + const uint32_t *__restrict__ perm_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + + marlin::gptq_marlin_repack_kernel( + b_q_weight_ptr, perm_ptr, out_ptr, + size_k, size_n); +} + +#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + gptqMarlinRepackKernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptqMarlinRepackKernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + +infiniStatus_t gptqMarlinRepack(uint32_t *out_ptr, const uint32_t *b_q_weight_ptr, const uint32_t *perm_ptr, + int64_t size_k, int64_t size_n, int64_t num_bits, + bool is_a_8bit, bool has_perm, cudaStream_t stream) { + + // Verify compatibility with marlin tile of 16x64 + if (size_k % marlin::tile_k_size != 0) { + std::cout << "size_k = " << size_k << " is not divisible by tile_k_size = " << marlin::tile_k_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (size_n % marlin::tile_n_size != 0) { + std::cout << "size_n = " << size_n << " is not divisible by tile_n_size = " << marlin::tile_n_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_bits != 4 && num_bits != 8) { + std::cout << "num_bits must be 4 or 8. Got = " << num_bits << std::endl; + return INFINI_STATUS_BAD_PARAM; + } + + // Get dev info + int device_id = 0; + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id); + assert(max_shared_mem > 0 && "max_shared_mem must be greater than 0"); + + if (false) { + } + CALL_IF(4, false, false) + CALL_IF(4, true, false) + CALL_IF(8, false, false) + CALL_IF(8, true, false) + + CALL_IF(4, false, true) + CALL_IF(8, false, true) + else { + fprintf(stderr, "Unsupported repack config: num_bits = %ld, has_perm = %s, is_a_8bit = %s\n", + num_bits, + has_perm ? "true" : "false", + is_a_8bit ? "true" : "false"); + assert(false); + } + + return INFINI_STATUS_SUCCESS; +} + +namespace op::gptq_marlin_repack::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t perm_desc, + int64_t num_bits, + bool is_a_8bit) { + + auto handle = reinterpret_cast(handle_); + auto result = GptqMarlinRepackInfo::create(output_desc, input_desc, perm_desc, num_bits, is_a_8bit); + + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{handle->internal()}, + result.take(), + workspace_size, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, size_t workspace_size, + void *output, + const void *input, + const void *perm, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + + int64_t size_k = static_cast(_info.size_k); + int64_t size_n = static_cast(_info.size_n); + int64_t num_bits = _info.num_bits; + bool is_a_8bit = _info.is_a_8bit; + bool has_perm = _info.has_perm; + + gptqMarlinRepack((uint32_t *)output, (const uint32_t *)input, (const uint32_t *)perm, + size_k, size_n, num_bits, + is_a_8bit, has_perm, stream); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_marlin_repack::nvidia +#endif diff --git a/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cuh b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cuh new file mode 100644 index 000000000..25a537c9e --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __GPTQ_MARLIN_REPACK_CUDA_CUH__ +#define __GPTQ_MARLIN_REPACK_CUDA_CUH__ + +#include "../gptq_marlin_repack.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_MARLIN_REPACK_CUDA_CUH__ diff --git a/src/infiniop/ops/gptq_marlin_repack/operator.cc b/src/infiniop/ops/gptq_marlin_repack/operator.cc new file mode 100644 index 000000000..966574877 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_repack/operator.cc @@ -0,0 +1,104 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_marlin_repack.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/gptq_marlin_repack_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqMarlinRepackDescriptor( + infiniopHandle_t handle, + infiniopGptqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t perm_desc, + int64_t num_bits, + bool is_a_8bit) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_marlin_repack::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + perm_desc, \ + num_bits, \ + is_a_8bit) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetGptqMarlinRepackWorkspaceSize(infiniopGptqMarlinRepackDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopGptqMarlinRepack( + infiniopGptqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + const void *perm, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, output, input, perm, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyGptqMarlinRepackDescriptor(infiniopGptqMarlinRepackDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/test/infiniop/awq_marlin_repack.py b/test/infiniop/awq_marlin_repack.py new file mode 100644 index 000000000..3d2b5507e --- /dev/null +++ b/test/infiniop/awq_marlin_repack.py @@ -0,0 +1,443 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + to_torch_dtype, +) +import itertools +import numpy +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +GPTQ_MARLIN_TILE = 16 +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 256] + +MARLIN_REPACK_NK_FACTORS = [ + (4, 8), + (7, 5), + (13, 11), +] + +def to_iter(x): + return x if isinstance(x, (list, tuple)) else (x,) + + +_TEST_CASES = list( + itertools.product( + to_iter(MARLIN_K_CHUNKS), + to_iter(MARLIN_N_CHUNKS), + to_iter([scalar_types.uint4]), + to_iter([True, False]), + to_iter(MARLIN_REPACK_NK_FACTORS), + to_iter([128]), + ) +) + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int | None, + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type.is_integer(), ( + "Floating point quantization may work but has not been tested" + ) + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) + +def get_weight_perm(num_bits: int, is_a_8bit: bool = False): + perm_list: list[int] = [] + if is_a_8bit: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + 4 * (i % 4 + 4), + 4 * (i % 4 + 4) + 1, + 4 * (i % 4 + 4) + 2, + 4 * (i % 4 + 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(2): + perm_list.extend([p + 512 * j for p in perm1]) + else: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 1, 2, 3]) + else: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + +def marlin_permute_weights( + q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False +): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + if is_a_8bit: + # Permute weights to 32x32 marlin tiles + q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile)) + else: + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + +def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def awq_marlin_repack_torch(b_weight, size_k, size_n, group_size, quant_type, is_a_8bit): + # Quantize + w_ref, q_w, s, zp = quantize_weights( + b_weight, quant_type, group_size, zero_points=True + ) + + # Pack to AWQ format + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # Pack to Marlin format + weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit + ) + return marlin_q_w_1 + + +def test( + handle, + device, + k_chunk, + n_chunk, + quant_type, + is_a_8bit, + nk_factors, + group_size=128, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing awq_marlin_repack on {device} with k_chunk:{k_chunk}, n_chunk:{n_chunk}, is_a_8bit:{is_a_8bit}, nk_factors:{nk_factors}, group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + n_factor, k_factor = nk_factors + + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + + b_weight = TestTensor((size_k, size_n), None, dtype, device) + + w_ref, q_w, s, zp = quantize_weights( + b_weight.torch_tensor(), quant_type, group_size, zero_points=True + ) + + # Pack to AWQ format + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) + + ans = awq_marlin_repack_torch(b_weight.torch_tensor(), size_k, size_n, group_size, quant_type, is_a_8bit) + + input = TestTensor( + q_w_awq.shape, + q_w_awq.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=q_w_awq, + ) + output = TestTensor(ans.shape, None, InfiniDtype.I32, device, mode="zeros") + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateAwqMarlinRepackDescriptor( + handle, + ctypes.byref(descriptor), + output.descriptor, + input.descriptor, + quant_type.size_bits, + is_a_8bit, + ) + ) + + # Invalidate descriptors (same pattern as other tests) + for tensor in [ + output, + input, + ]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAwqMarlinRepackWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_awq_marlin_repack(): + check_error( + LIBINFINIOP.infiniopAwqMarlinRepack( + descriptor, + workspace.data(), + workspace_size.value, + output.data(), + input.data(), + None, + ) + ) + + lib_awq_marlin_repack() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(output.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(output.actual_tensor(), ans, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: awq_marlin_repack_torch(b_weight.torch_tensor(), size_k, size_n, group_size, quant_type, is_a_8bit), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lambda: lib_awq_marlin_repack(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error(LIBINFINIOP.infiniopDestroyAwqMarlinRepackDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/gptq_marlin_repack.py b/test/infiniop/gptq_marlin_repack.py new file mode 100644 index 000000000..d013ec947 --- /dev/null +++ b/test/infiniop/gptq_marlin_repack.py @@ -0,0 +1,531 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + to_torch_dtype, + torch_device_map, +) +import itertools +import numpy +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +ACT_ORDER_OPTS = [False, True] +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 256] +ACT_ORDER_OPTS = [False, True] +MARLIN_REPACK_NK_FACTORS = [ + (4, 8), + (7, 5), + (13, 11), +] + +def to_iter(x): + return x if isinstance(x, (list, tuple)) else (x,) + + +_TEST_CASES = list( + itertools.product( + to_iter(MARLIN_K_CHUNKS), + to_iter(MARLIN_N_CHUNKS), + to_iter([scalar_types.uint4b8, scalar_types.uint8b128]), + to_iter(ACT_ORDER_OPTS), + to_iter([True, False]), + to_iter(MARLIN_REPACK_NK_FACTORS), + to_iter([128]), + ) +) + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_TILE = 16 + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int | None, + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type.is_integer(), ( + "Floating point quantization may work but has not been tested" + ) + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: torch.Tensor | None = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: torch.Tensor | None = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, ( + f"Unsupported gptq type = {quant_type}" + ) + assert group_size in SUPPORTED_GROUP_SIZES + [size_k], ( + f"Unsupported groupsize = {group_size}" + ) + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert group_size < size_k, ( + "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + +def get_weight_perm(num_bits: int, is_a_8bit: bool = False): + perm_list: list[int] = [] + if is_a_8bit: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + 4 * (i % 4 + 4), + 4 * (i % 4 + 4) + 1, + 4 * (i % 4 + 4) + 2, + 4 * (i % 4 + 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(2): + perm_list.extend([p + 512 * j for p in perm1]) + else: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 1, 2, 3]) + else: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + +def marlin_permute_weights( + q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False +): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + if is_a_8bit: + # Permute weights to 32x32 marlin tiles + q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile)) + else: + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + +def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + +def gptq_marlin_repack_torch(b_weight, size_k, size_n, group_size, quant_type, act_order, is_a_8bit): + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + b_weight, quant_type, group_size, act_order + ) + + # Pack to GPTQ format + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Pack to Marlin format + weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit + ) + return q_w_gptq, sort_indices, marlin_q_w_1 + +def test( + handle, + device, + k_chunk, + n_chunk, + quant_type, + act_order, + is_a_8bit, + nk_factors, + group_size=128, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing gptq_marlin_repack on {device} with k_chunk:{k_chunk}, n_chunk:{n_chunk}, act_order:{act_order}, is_a_8bit:{is_a_8bit}, nk_factors:{nk_factors}, group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + n_factor, k_factor = nk_factors + + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + if is_a_8bit: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + test_dtype = to_torch_dtype(dtype) + device_str = torch_device_map[device] + b_weight = torch.randn((size_k, size_n), dtype=test_dtype, device=device_str) #must be randn + + q_w_gptq, sort_indices, ans = gptq_marlin_repack_torch(b_weight, size_k, size_n, group_size, quant_type, act_order, is_a_8bit) + + input = TestTensor( + q_w_gptq.shape, + q_w_gptq.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=q_w_gptq, + ) + perm = TestTensor( + sort_indices.shape, + sort_indices.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=sort_indices, + ) + + output = TestTensor(ans.shape, None, InfiniDtype.I32, device, mode="zeros") + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateGptqMarlinRepackDescriptor( + handle, + ctypes.byref(descriptor), + output.descriptor, + input.descriptor, + perm.descriptor, + quant_type.size_bits, + is_a_8bit, + ) + ) + + for tensor in [ + output, + input, + perm, + ]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetGptqMarlinRepackWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_gptq_marlin_repack(): + check_error( + LIBINFINIOP.infiniopGptqMarlinRepack( + descriptor, + workspace.data(), + workspace_size.value, + output.data(), + input.data(), + perm.data(), + None, + ) + ) + + lib_gptq_marlin_repack() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(output.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(output.actual_tensor().to(ans.dtype), ans, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: gptq_marlin_repack_torch(b_weight.torch_tensor(), size_k, size_n, group_size, quant_type, act_order, is_a_8bit), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lambda: lib_gptq_marlin_repack(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error(LIBINFINIOP.infiniopDestroyGptqMarlinRepackDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index ad88fcb43..815b91a15 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1401,6 +1401,39 @@ def gptq_marlin_gemm_(lib): ] +@OpRegister.operator +def gptq_marlin_repack_(lib): + lib.infiniopCreateGptqMarlinRepackDescriptor.restype = c_int32 + lib.infiniopCreateGptqMarlinRepackDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int64, + c_bool, + ] + lib.infiniopGetGptqMarlinRepackWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqMarlinRepackWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopGptqMarlinRepack.restype = c_int32 + lib.infiniopGptqMarlinRepack.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyGptqMarlinRepackDescriptor.restype = c_int32 + lib.infiniopDestroyGptqMarlinRepackDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def gptq_qyblas_gemm_(lib): lib.infiniopCreateGptqQyblasGemmDescriptor.restype = c_int32 @@ -1494,6 +1527,40 @@ def awq_marlin_gemm_(lib): ] +@OpRegister.operator +def awq_marlin_repack_(lib): + lib.infiniopCreateAwqMarlinRepackDescriptor.restype = c_int32 + lib.infiniopCreateAwqMarlinRepackDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int64, + c_bool, + ] + + lib.infiniopGetAwqMarlinRepackWorkspaceSize.restype = c_int32 + lib.infiniopGetAwqMarlinRepackWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopAwqMarlinRepack.restype = c_int32 + lib.infiniopAwqMarlinRepack.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyAwqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopDestroyAwqMarlinGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32 From 025e755339aee53ea0aac5c767e2ae826d31a552 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 25 Jun 2026 06:50:39 +0000 Subject: [PATCH 2/2] fix: make add graph capture safe for marlin --- .../gptq_marlin_gemm_infiniop.cc | 2 + src/infiniop/ops/add/add.h | 2 +- src/infiniop/ops/add/nvidia/add_nvidia.cu | 109 ++++++++++++++++-- src/infiniop/ops/add/nvidia/add_nvidia.cuh | 42 ++++++- src/infiniop/ops/add/operator.cc | 47 +++++++- .../nvidia/awq_marlin_repack_nvidia.cu | 2 +- .../ops/fused_ffn/nvidia/fused_ffn_nvidia.cu | 2 +- .../nvidia/gptq_marlin_repack_nvidia.cu | 2 +- 8 files changed, 190 insertions(+), 18 deletions(-) diff --git a/src/infinicore/ops/gptq_marlin_gemm/gptq_marlin_gemm_infiniop.cc b/src/infinicore/ops/gptq_marlin_gemm/gptq_marlin_gemm_infiniop.cc index 5d89bb862..e0448e6b8 100644 --- a/src/infinicore/ops/gptq_marlin_gemm/gptq_marlin_gemm_infiniop.cc +++ b/src/infinicore/ops/gptq_marlin_gemm/gptq_marlin_gemm_infiniop.cc @@ -138,6 +138,7 @@ void run_with_workspace(void *planned_meta) { return tensor->numel() == 0 ? nullptr : tensor->data(); }; + context::setDeviceMemoryAsync(planned->workspace->data(), 0, planned->workspace->nbytes(), context::getStream()); INFINICORE_CHECK_ERROR(infiniopGptqMarlinGemm( planned->descriptor->desc, planned->workspace->data(), @@ -183,6 +184,7 @@ void direct_with_workspace(Tensor workspace, Tensor out, const Tensor &a, const return tensor->numel() == 0 ? nullptr : tensor->data(); }; + context::setDeviceMemoryAsync(workspace->data(), 0, workspace->nbytes(), context::getStream()); INFINICORE_CHECK_ERROR(infiniopGptqMarlinGemm( descriptor->desc, workspace->data(), diff --git a/src/infiniop/ops/add/add.h b/src/infiniop/ops/add/add.h index 606e54676..c7caa6b57 100644 --- a/src/infiniop/ops/add/add.h +++ b/src/infiniop/ops/add/add.h @@ -43,4 +43,4 @@ }; \ } -#endif // ADD_H \ No newline at end of file +#endif // ADD_H diff --git a/src/infiniop/ops/add/nvidia/add_nvidia.cu b/src/infiniop/ops/add/nvidia/add_nvidia.cu index 543a89fb2..787f45a0e 100644 --- a/src/infiniop/ops/add/nvidia/add_nvidia.cu +++ b/src/infiniop/ops/add/nvidia/add_nvidia.cu @@ -4,6 +4,86 @@ #include "add_nvidia.cuh" namespace op::add::nvidia { +namespace { + +template +INFINIOP_CUDA_KERNEL addKernel( + size_t output_size, + size_t ndim, + bool output_contiguous, + const bool *__restrict__ input_contiguous, + const bool *__restrict__ input_broadcasted, + const size_t *__restrict__ output_shape, + const size_t *__restrict__ input_shapes, + const ptrdiff_t *__restrict__ output_strides, + const ptrdiff_t *__restrict__ input_strides, + T *output, + const T *__restrict__ a, + const T *__restrict__ b, + size_t offset) { + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset; + if (idx >= output_size) { + return; + } + + size_t out_idx = op::elementwise::nvidia::getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides); + op::elementwise::nvidia::InputIndexer indexer{ + idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides}; + output[out_idx] = cuda::AddOp{}(a[indexer(0)], b[indexer(1)]); +} + +template +infiniStatus_t launchAddKernel( + const op::elementwise::ElementwiseInfo &info, + const std::shared_ptr &internal, + void *workspace, + void *output, + const void *a, + const void *b, + cudaStream_t stream) { + + auto output_size = info.getOutputSize(); + if (output_size == 0) { + return INFINI_STATUS_SUCCESS; + } + + auto ndim = info.getNdim(); + auto *d_meta_start = reinterpret_cast(workspace); + CHECK_CUDA(cudaMemcpyAsync(d_meta_start, info.getMetaStart(), info.getMetaMemSize(), cudaMemcpyHostToDevice, stream)); + + auto *d_output_shape = reinterpret_cast(d_meta_start); + auto *d_output_strides = reinterpret_cast(d_output_shape + ndim); + auto *d_input_shapes = reinterpret_cast(d_output_strides + ndim); + auto *d_input_strides = reinterpret_cast(d_input_shapes + info.getInputSize() * ndim); + auto *d_input_contiguous = reinterpret_cast(d_input_strides + info.getInputSize() * ndim); + auto *d_input_broadcasted = reinterpret_cast(d_input_contiguous + info.getInputSize()); + + dim3 block_dims(std::min(256U, static_cast(internal->maxThreadsPerBlock()))); + dim3 grid_dims(std::min(uint32_t(CEIL_DIV(output_size, block_dims.x)), static_cast(internal->gridSizeX()))); + size_t step = grid_dims.x * block_dims.x; + + for (size_t i = 0; i < output_size; i += step) { + addKernel<<>>( + output_size, + ndim, + info.isOutputContiguous(), + d_input_contiguous, + d_input_broadcasted, + d_output_shape, + d_input_shapes, + d_output_strides, + d_input_strides, + reinterpret_cast(output), + reinterpret_cast(a), + reinterpret_cast(b), + i); + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace Descriptor::~Descriptor() = default; @@ -26,8 +106,18 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); - // create CUDA elementwise descriptor - CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc_vec); + CHECK_RESULT(info_result); + auto info = info_result.take(); + auto workspace_size = info.getMetaMemSize(); + + *desc_ptr = new Descriptor( + dtype, + std::move(info), + handle->internal(), + workspace_size, + handle->device, + handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -36,7 +126,8 @@ infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *output, - std::vector inputs, + const void *a, + const void *b, void *stream) const { if (workspace_size < _workspace_size) { @@ -45,17 +136,17 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::AddOp, half>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::AddOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); case INFINI_DTYPE_F32: - return _device_info->calculate<256, cuda::AddOp, float>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); case INFINI_DTYPE_I32: - return _device_info->calculate<256, cuda::AddOp, int32_t>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); case INFINI_DTYPE_I64: - return _device_info->calculate<256, cuda::AddOp, int64_t>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::AddOp, double>(_info, workspace, output, inputs, stream); + return launchAddKernel(_info, _internal, workspace, output, a, b, reinterpret_cast(stream)); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/add/nvidia/add_nvidia.cuh b/src/infiniop/ops/add/nvidia/add_nvidia.cuh index b9f48250d..1abc71b4f 100644 --- a/src/infiniop/ops/add/nvidia/add_nvidia.cuh +++ b/src/infiniop/ops/add/nvidia/add_nvidia.cuh @@ -1,8 +1,48 @@ #ifndef __ADD_CUDA_API_H__ #define __ADD_CUDA_API_H__ +#include "../../../devices/nvidia/nvidia_handle.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" -ELEMENTWISE_DESCRIPTOR(add, nvidia) +namespace op::add::nvidia { +class Descriptor final : public InfiniopDescriptor { + infiniDtype_t _dtype; + op::elementwise::ElementwiseInfo _info; + std::shared_ptr _internal; + size_t _workspace_size; + + Descriptor( + infiniDtype_t dtype, + op::elementwise::ElementwiseInfo info, + std::shared_ptr internal, + size_t workspace_size, + infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, + _dtype(dtype), + _info(std::move(info)), + _internal(std::move(internal)), + _workspace_size(workspace_size) {} + +public: + ~Descriptor(); + + size_t workspaceSize() const { return _workspace_size; } + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + std::vector input_descs); + + infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *output, + const void *a, + const void *b, + void *stream) const; +}; +} // namespace op::add::nvidia #endif // __ADD_CUDA_API_H__ diff --git a/src/infiniop/ops/add/operator.cc b/src/infiniop/ops/add/operator.cc index 88677c75b..9546a7577 100644 --- a/src/infiniop/ops/add/operator.cc +++ b/src/infiniop/ops/add/operator.cc @@ -23,6 +23,7 @@ #ifdef ENABLE_ASCEND_API #include "ascend/add_ascend.h" #endif +#include __INFINI_C infiniStatus_t infiniopCreateAddDescriptor( infiniopHandle_t handle, @@ -132,6 +133,36 @@ __INFINI_C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t de return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } +namespace { + +template +infiniStatus_t calculateAdd( + const Descriptor *desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { + const std::vector inputs{a, b}; + return desc->calculate(workspace, workspace_size, c, inputs, stream); +} + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API) +infiniStatus_t calculateAdd( + const op::add::nvidia::Descriptor *desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { + return desc->calculate(workspace, workspace_size, c, a, b, stream); +} +#endif + +} // namespace + __INFINI_C infiniStatus_t infiniopAdd( infiniopAddDescriptor_t desc, void *workspace, @@ -141,10 +172,18 @@ __INFINI_C infiniStatus_t infiniopAdd( const void *b, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ - return reinterpret_cast(desc) \ - ->calculate(workspace, workspace_size, c, {a, b}, stream) +// NVIDIA Add keeps explicit a/b pointers because the generic elementwise +// input-vector path copies inputs.data() from host to device workspace before +// launching the kernel. During CUDA graph capture, that H2D node records the +// host source address; if infiniopAdd used a temporary vector such as {a, b}, +// graph replay could later read from an invalid host address and copy bad input +// pointers into device workspace. Other backends keep their original vector +// interface through calculateAdd's default forwarding path. +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return calculateAdd( \ + reinterpret_cast(desc), \ + workspace, workspace_size, c, a, b, stream) switch (desc->device_type) { diff --git a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu index ffa7ff44c..c4eb1d1df 100644 --- a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu +++ b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu @@ -119,4 +119,4 @@ Descriptor::calculate( } } // namespace op::awq_marlin_repack::nvidia -#endif +#endif // ENABLE_NVIDIA_API diff --git a/src/infiniop/ops/fused_ffn/nvidia/fused_ffn_nvidia.cu b/src/infiniop/ops/fused_ffn/nvidia/fused_ffn_nvidia.cu index e50b55c72..18336b224 100644 --- a/src/infiniop/ops/fused_ffn/nvidia/fused_ffn_nvidia.cu +++ b/src/infiniop/ops/fused_ffn/nvidia/fused_ffn_nvidia.cu @@ -423,7 +423,7 @@ infiniStatus_t Descriptor::calculate( if (_opaque->has_residual && !fuse_residual) { CHECK_STATUS(_opaque->residual_add->calculate( inner_ws, inner_ws_size, - out, {out, residual}, stream)); + out, out, residual, stream)); } return INFINI_STATUS_SUCCESS; diff --git a/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu index 50e99afbe..1c709fc4f 100644 --- a/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu +++ b/src/infiniop/ops/gptq_marlin_repack/nvidia/gptq_marlin_repack_nvidia.cu @@ -131,4 +131,4 @@ Descriptor::calculate( } } // namespace op::gptq_marlin_repack::nvidia -#endif +#endif // ENABLE_NVIDIA_API