From edaf67f45e2b08d2ff73b740d0267ec25f4dc3b4 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 25 Jun 2026 06:11:34 +0000 Subject: [PATCH] feat: use gpu marlin repack for quant weights --- csrc/layers/quantization/marlin_support.hpp | 2 +- csrc/layers/quantization/marlin_utils.cpp | 159 ++++++++++++++++++-- 2 files changed, 151 insertions(+), 10 deletions(-) diff --git a/csrc/layers/quantization/marlin_support.hpp b/csrc/layers/quantization/marlin_support.hpp index c87e131e2..3a230c44f 100644 --- a/csrc/layers/quantization/marlin_support.hpp +++ b/csrc/layers/quantization/marlin_support.hpp @@ -1,6 +1,6 @@ #pragma once -#if __has_include("infinicore/ops/gptq_marlin_gemm.hpp") +#if __has_include("infinicore/ops/gptq_marlin_gemm.hpp") && __has_include("infiniop/ops/awq_marlin_repack.h") && __has_include("infiniop/ops/gptq_marlin_repack.h") #define INFINILM_ENABLE_MARLIN 1 #else #define INFINILM_ENABLE_MARLIN 0 diff --git a/csrc/layers/quantization/marlin_utils.cpp b/csrc/layers/quantization/marlin_utils.cpp index 155a80844..47546e244 100644 --- a/csrc/layers/quantization/marlin_utils.cpp +++ b/csrc/layers/quantization/marlin_utils.cpp @@ -1,14 +1,140 @@ #include "marlin_utils.hpp" +#include "marlin_support.hpp" +#include "infinicore/context/context.hpp" #include #include #include #include +#if INFINILM_ENABLE_MARLIN +#include +#include +#endif + namespace infinilm::quantization::marlin { namespace { +#if INFINILM_ENABLE_MARLIN +void check_infiniop_status(infiniStatus_t status, const char *expr) { + if (status != INFINI_STATUS_SUCCESS) { + throw std::runtime_error(std::string(expr) + " failed with status " + std::to_string(static_cast(status))); + } +} + +template +class DescriptorGuard { +public: + explicit DescriptorGuard(Desc desc) : desc_(desc) {} + DescriptorGuard(const DescriptorGuard &) = delete; + DescriptorGuard &operator=(const DescriptorGuard &) = delete; + ~DescriptorGuard() { + if (desc_ != nullptr) { + Destroy(desc_); + } + } + Desc get() const { + return desc_; + } + +private: + Desc desc_; +}; + +infinicore::Tensor make_workspace(size_t workspace_size, const infinicore::Device &device) { + if (workspace_size == 0) { + return infinicore::Tensor(); + } + return infinicore::Tensor::empty({workspace_size}, infinicore::DataType::U8, device); +} + +void *workspace_data(infinicore::Tensor &workspace) { + return workspace ? workspace->data() : nullptr; +} + +infinicore::Tensor awq_marlin_repack_gpu(const infinicore::Tensor &qweight, size_t size_k, size_t size_n, int num_bits) { + const size_t pack_factor = 32 / num_bits; + auto qweight_contiguous = qweight->is_contiguous() ? qweight : qweight->contiguous(); + auto output = infinicore::Tensor::empty({size_k / 16, size_n * 16 / pack_factor}, infinicore::DataType::I32, qweight_contiguous->device()); + + infiniopAwqMarlinRepackDescriptor_t raw_desc = nullptr; + check_infiniop_status( + infiniopCreateAwqMarlinRepackDescriptor( + infinicore::context::getInfiniopHandle(qweight_contiguous->device()), + &raw_desc, + output->desc(), + qweight_contiguous->desc(), + num_bits, + false), + "infiniopCreateAwqMarlinRepackDescriptor"); + DescriptorGuard desc(raw_desc); + + size_t workspace_size = 0; + check_infiniop_status( + infiniopGetAwqMarlinRepackWorkspaceSize(desc.get(), &workspace_size), + "infiniopGetAwqMarlinRepackWorkspaceSize"); + auto workspace = make_workspace(workspace_size, qweight_contiguous->device()); + + check_infiniop_status( + infiniopAwqMarlinRepack( + desc.get(), + workspace_data(workspace), + workspace_size, + output->data(), + qweight_contiguous->data(), + infinicore::context::getStream()), + "infiniopAwqMarlinRepack"); + infinicore::context::syncStream(); + return output; +} + +infinicore::Tensor gptq_marlin_repack_gpu( + const infinicore::Tensor &qweight, + const infinicore::Tensor &perm, + size_t size_k, + size_t size_n, + int num_bits) { + const size_t pack_factor = 32 / num_bits; + auto qweight_contiguous = qweight->is_contiguous() ? qweight : qweight->contiguous(); + auto output = infinicore::Tensor::empty({size_k / 16, size_n * 16 / pack_factor}, infinicore::DataType::I32, qweight_contiguous->device()); + auto perm_desc = (perm && perm->numel() != 0) ? perm->desc() : nullptr; + const void *perm_data = (perm && perm->numel() != 0) ? perm->data() : nullptr; + + infiniopGptqMarlinRepackDescriptor_t raw_desc = nullptr; + check_infiniop_status( + infiniopCreateGptqMarlinRepackDescriptor( + infinicore::context::getInfiniopHandle(qweight_contiguous->device()), + &raw_desc, + output->desc(), + qweight_contiguous->desc(), + perm_desc, + num_bits, + false), + "infiniopCreateGptqMarlinRepackDescriptor"); + DescriptorGuard desc(raw_desc); + + size_t workspace_size = 0; + check_infiniop_status( + infiniopGetGptqMarlinRepackWorkspaceSize(desc.get(), &workspace_size), + "infiniopGetGptqMarlinRepackWorkspaceSize"); + auto workspace = make_workspace(workspace_size, qweight_contiguous->device()); + + check_infiniop_status( + infiniopGptqMarlinRepack( + desc.get(), + workspace_data(workspace), + workspace_size, + output->data(), + qweight_contiguous->data(), + perm_data, + infinicore::context::getStream()), + "infiniopGptqMarlinRepack"); + infinicore::context::syncStream(); + return output; +} +#endif + std::vector scale_perm() { std::vector perm; perm.reserve(64); @@ -153,13 +279,10 @@ std::vector repack_to_marlin_tiles(size_t size_k, size_t size_n, int nu } if (num_bits == 4) { - out[out_offset + static_cast(th * 4 + warp)] = - static_cast(pack_repack_values(vals, num_bits, false)); + out[out_offset + static_cast(th * 4 + warp)] = static_cast(pack_repack_values(vals, num_bits, false)); } else { - out[out_offset + static_cast(th * 8 + warp * 2)] = - static_cast(pack_repack_values(vals, num_bits, false)); - out[out_offset + static_cast(th * 8 + warp * 2 + 1)] = - static_cast(pack_repack_values(vals, num_bits, true)); + out[out_offset + static_cast(th * 8 + warp * 2)] = static_cast(pack_repack_values(vals, num_bits, false)); + out[out_offset + static_cast(th * 8 + warp * 2 + 1)] = static_cast(pack_repack_values(vals, num_bits, true)); } } } @@ -191,6 +314,14 @@ infinicore::Tensor make_i32_tensor(const std::vector &data, const std:: infinicore::Tensor awq_marlin_repack(const infinicore::Tensor &qweight, size_t size_k, size_t size_n, int num_bits) { check_repack_shape(size_k, size_n, num_bits); const size_t pack_factor = 32 / num_bits; +#if INFINILM_ENABLE_MARLIN + if (qweight->dtype() != infinicore::DataType::I32 || qweight->shape() != std::vector{size_k, size_n / pack_factor}) { + throw std::runtime_error("awq_marlin_repack: unexpected qweight shape or dtype"); + } + if (qweight->device().getType() == infinicore::Device::Type::NVIDIA) { + return awq_marlin_repack_gpu(qweight, size_k, size_n, num_bits); + } +#endif auto cpu = to_cpu_contiguous(qweight); if (cpu->dtype() != infinicore::DataType::I32 || cpu->shape() != std::vector{size_k, size_n / pack_factor}) { throw std::runtime_error("awq_marlin_repack: unexpected qweight shape or dtype"); @@ -205,6 +336,17 @@ infinicore::Tensor awq_marlin_repack(const infinicore::Tensor &qweight, size_t s infinicore::Tensor gptq_marlin_repack(const infinicore::Tensor &qweight, const infinicore::Tensor &perm, size_t size_k, size_t size_n, int num_bits) { check_repack_shape(size_k, size_n, num_bits); const size_t pack_factor = 32 / num_bits; +#if INFINILM_ENABLE_MARLIN + if (qweight->dtype() != infinicore::DataType::I32 || qweight->shape() != std::vector{size_k / pack_factor, size_n}) { + throw std::runtime_error("gptq_marlin_repack: unexpected qweight shape or dtype"); + } + if (perm && perm->numel() != 0 && (perm->dtype() != infinicore::DataType::I32 || perm->numel() != size_k)) { + throw std::runtime_error("gptq_marlin_repack: unexpected perm shape or dtype"); + } + if (qweight->device().getType() == infinicore::Device::Type::NVIDIA) { + return gptq_marlin_repack_gpu(qweight, perm, size_k, size_n, num_bits); + } +#endif auto cpu = to_cpu_contiguous(qweight); if (cpu->dtype() != infinicore::DataType::I32 || cpu->shape() != std::vector{size_k / pack_factor, size_n}) { throw std::runtime_error("gptq_marlin_repack: unexpected qweight shape or dtype"); @@ -283,8 +425,7 @@ infinicore::Tensor awq_to_marlin_zero_points(const infinicore::Tensor &qzeros, s std::vector unpermuted(unpacked.size()); for (size_t row = 0; row < unpacked.size() / undo_interleave.size(); ++row) { for (size_t i = 0; i < undo_interleave.size(); ++i) { - unpermuted[row * undo_interleave.size() + i] = - unpacked[row * undo_interleave.size() + static_cast(undo_interleave[i])]; + unpermuted[row * undo_interleave.size() + i] = unpacked[row * undo_interleave.size() + static_cast(undo_interleave[i])]; } } @@ -300,7 +441,7 @@ infinicore::Tensor awq_to_marlin_zero_points(const infinicore::Tensor &qzeros, s } const std::vector interleave = num_bits == 4 ? std::vector{0, 2, 4, 6, 1, 3, 5, 7} - : std::vector{0, 2, 1, 3}; + : std::vector{0, 2, 1, 3}; std::vector interleaved(permuted.size()); for (size_t row = 0; row < permuted.size() / interleave.size(); ++row) { for (size_t i = 0; i < interleave.size(); ++i) {