Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/layers/quantization/marlin_support.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#if __has_include("infinicore/ops/gptq_marlin_gemm.hpp")
#if __has_include("infinicore/ops/gptq_marlin_gemm.hpp") && __has_include("infiniop/ops/awq_marlin_repack.h") && __has_include("infiniop/ops/gptq_marlin_repack.h")
#define INFINILM_ENABLE_MARLIN 1
#else
#define INFINILM_ENABLE_MARLIN 0
Expand Down
159 changes: 150 additions & 9 deletions csrc/layers/quantization/marlin_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,140 @@
#include "marlin_utils.hpp"
#include "marlin_support.hpp"

#include "infinicore/context/context.hpp"
#include <algorithm>
#include <cstring>
#include <numeric>
#include <stdexcept>

#if INFINILM_ENABLE_MARLIN
#include <infiniop/ops/awq_marlin_repack.h>
#include <infiniop/ops/gptq_marlin_repack.h>
#endif

namespace infinilm::quantization::marlin {

namespace {

#if INFINILM_ENABLE_MARLIN
/// Converts a non-success infiniop status into an exception.
///
/// @param status Status code returned by an infiniop call.
/// @param expr   Label for the call that produced `status`; it becomes the
///               prefix of the exception message.
/// @throws std::runtime_error with the text "<expr> failed with status <code>"
///         whenever `status` is not INFINI_STATUS_SUCCESS.
void check_infiniop_status(infiniStatus_t status, const char *expr) {
    if (status == INFINI_STATUS_SUCCESS) {
        return;
    }
    std::string message(expr);
    message += " failed with status ";
    message += std::to_string(static_cast<int>(status));
    throw std::runtime_error(message);
}

/// Scope guard around a descriptor handle.
///
/// On destruction the stored handle is passed to `Destroy`, unless the handle
/// compares equal to nullptr. Whatever `Destroy` returns is discarded. The
/// guard cannot be copied.
///
/// @tparam Desc    Handle type; must be comparable with nullptr.
/// @tparam Destroy Callable invoked as `Destroy(handle)` from the destructor.
template <typename Desc, auto Destroy>
class DescriptorGuard {
    Desc handle_;

public:
    /// Takes over responsibility for releasing `desc`.
    explicit DescriptorGuard(Desc desc) : handle_(desc) {}

    ~DescriptorGuard() {
        if (handle_ == nullptr) {
            return; // nothing was created, so there is nothing to release
        }
        Destroy(handle_);
    }

    DescriptorGuard(const DescriptorGuard &) = delete;
    DescriptorGuard &operator=(const DescriptorGuard &) = delete;

    /// Returns the stored handle without giving up responsibility for it.
    Desc get() const {
        return handle_;
    }
};

/// Allocates scratch storage for an infiniop operator.
///
/// @param workspace_size Number of U8 elements requested.
/// @param device         Device passed through to the tensor allocation.
/// @return A default-constructed tensor when `workspace_size` is 0; otherwise
///         an uninitialized tensor of shape {workspace_size} and dtype U8.
infinicore::Tensor make_workspace(size_t workspace_size, const infinicore::Device &device) {
    return workspace_size == 0
             ? infinicore::Tensor()
             : infinicore::Tensor::empty({workspace_size}, infinicore::DataType::U8, device);
}

/// Returns the data pointer of `workspace`, or nullptr when the tensor
/// evaluates to false (as the default-constructed tensor produced by
/// make_workspace for a zero-sized request presumably does).
void *workspace_data(infinicore::Tensor &workspace) {
    if (!workspace) {
        return nullptr;
    }
    return workspace->data();
}

/// Repacks AWQ-packed weights with the infiniop AwqMarlinRepack operator.
///
/// @param qweight  Packed weight tensor; a contiguous copy is made if needed.
/// @param size_k   K dimension used to size the output (size_k / 16 rows).
/// @param size_n   N dimension used to size the output
///                 (size_n * 16 / pack_factor columns).
/// @param num_bits Bits per quantized value; pack_factor is 32 / num_bits.
/// @return A new I32 tensor on qweight's device holding the repacked weights.
/// @throws std::runtime_error (via check_infiniop_status) if any infiniop call
///         does not report success.
infinicore::Tensor awq_marlin_repack_gpu(const infinicore::Tensor &qweight, size_t size_k, size_t size_n, int num_bits) {
    const size_t pack_factor = 32 / num_bits;
    auto src = qweight->is_contiguous() ? qweight : qweight->contiguous();
    const auto &device = src->device();

    const size_t out_rows = size_k / 16;
    const size_t out_cols = size_n * 16 / pack_factor;
    auto output = infinicore::Tensor::empty({out_rows, out_cols}, infinicore::DataType::I32, device);

    // Create the operator descriptor and hand it to a guard right away so it
    // is destroyed on every exit path below.
    infiniopAwqMarlinRepackDescriptor_t raw_desc = nullptr;
    const auto create_status = infiniopCreateAwqMarlinRepackDescriptor(
        infinicore::context::getInfiniopHandle(device),
        &raw_desc,
        output->desc(),
        src->desc(),
        num_bits,
        false);
    check_infiniop_status(create_status, "infiniopCreateAwqMarlinRepackDescriptor");
    DescriptorGuard<infiniopAwqMarlinRepackDescriptor_t, infiniopDestroyAwqMarlinRepackDescriptor> desc(raw_desc);

    // Ask the operator how much scratch space it needs and allocate it.
    size_t workspace_size = 0;
    const auto size_status = infiniopGetAwqMarlinRepackWorkspaceSize(desc.get(), &workspace_size);
    check_infiniop_status(size_status, "infiniopGetAwqMarlinRepackWorkspaceSize");
    auto workspace = make_workspace(workspace_size, device);

    const auto run_status = infiniopAwqMarlinRepack(
        desc.get(),
        workspace_data(workspace),
        workspace_size,
        output->data(),
        src->data(),
        infinicore::context::getStream());
    check_infiniop_status(run_status, "infiniopAwqMarlinRepack");

    // Synchronize before the local tensors and the descriptor go out of scope.
    infinicore::context::syncStream();
    return output;
}

/// Repacks GPTQ-packed weights with the infiniop GptqMarlinRepack operator.
///
/// @param qweight  Packed weight tensor; a contiguous copy is made if needed.
/// @param perm     Optional permutation tensor. When it is empty (evaluates to
///                 false or has zero elements) a null descriptor and a null
///                 data pointer are passed to the operator instead.
/// @param size_k   K dimension used to size the output (size_k / 16 rows).
/// @param size_n   N dimension used to size the output
///                 (size_n * 16 / pack_factor columns).
/// @param num_bits Bits per quantized value; pack_factor is 32 / num_bits.
/// @return A new I32 tensor on qweight's device holding the repacked weights.
/// @throws std::runtime_error (via check_infiniop_status) if any infiniop call
///         does not report success.
infinicore::Tensor gptq_marlin_repack_gpu(
    const infinicore::Tensor &qweight,
    const infinicore::Tensor &perm,
    size_t size_k,
    size_t size_n,
    int num_bits) {
    const size_t pack_factor = 32 / num_bits;
    auto src = qweight->is_contiguous() ? qweight : qweight->contiguous();
    const auto &device = src->device();

    const size_t out_rows = size_k / 16;
    const size_t out_cols = size_n * 16 / pack_factor;
    auto output = infinicore::Tensor::empty({out_rows, out_cols}, infinicore::DataType::I32, device);

    // NOTE(review): perm's data pointer is forwarded as-is; presumably it must
    // live on the same device as qweight — confirm against the callers.
    const bool has_perm = perm && perm->numel() != 0;
    auto perm_desc = has_perm ? perm->desc() : nullptr;
    const void *perm_data = has_perm ? perm->data() : nullptr;

    // Create the operator descriptor and hand it to a guard right away so it
    // is destroyed on every exit path below.
    infiniopGptqMarlinRepackDescriptor_t raw_desc = nullptr;
    const auto create_status = infiniopCreateGptqMarlinRepackDescriptor(
        infinicore::context::getInfiniopHandle(device),
        &raw_desc,
        output->desc(),
        src->desc(),
        perm_desc,
        num_bits,
        false);
    check_infiniop_status(create_status, "infiniopCreateGptqMarlinRepackDescriptor");
    DescriptorGuard<infiniopGptqMarlinRepackDescriptor_t, infiniopDestroyGptqMarlinRepackDescriptor> desc(raw_desc);

    // Ask the operator how much scratch space it needs and allocate it.
    size_t workspace_size = 0;
    const auto size_status = infiniopGetGptqMarlinRepackWorkspaceSize(desc.get(), &workspace_size);
    check_infiniop_status(size_status, "infiniopGetGptqMarlinRepackWorkspaceSize");
    auto workspace = make_workspace(workspace_size, device);

    const auto run_status = infiniopGptqMarlinRepack(
        desc.get(),
        workspace_data(workspace),
        workspace_size,
        output->data(),
        src->data(),
        perm_data,
        infinicore::context::getStream());
    check_infiniop_status(run_status, "infiniopGptqMarlinRepack");

    // Synchronize before the local tensors and the descriptor go out of scope.
    infinicore::context::syncStream();
    return output;
}
#endif

std::vector<int> scale_perm() {
std::vector<int> perm;
perm.reserve(64);
Expand Down Expand Up @@ -153,13 +279,10 @@ std::vector<int32_t> repack_to_marlin_tiles(size_t size_k, size_t size_n, int nu
}

if (num_bits == 4) {
out[out_offset + static_cast<size_t>(th * 4 + warp)] =
static_cast<int32_t>(pack_repack_values(vals, num_bits, false));
out[out_offset + static_cast<size_t>(th * 4 + warp)] = static_cast<int32_t>(pack_repack_values(vals, num_bits, false));
} else {
out[out_offset + static_cast<size_t>(th * 8 + warp * 2)] =
static_cast<int32_t>(pack_repack_values(vals, num_bits, false));
out[out_offset + static_cast<size_t>(th * 8 + warp * 2 + 1)] =
static_cast<int32_t>(pack_repack_values(vals, num_bits, true));
out[out_offset + static_cast<size_t>(th * 8 + warp * 2)] = static_cast<int32_t>(pack_repack_values(vals, num_bits, false));
out[out_offset + static_cast<size_t>(th * 8 + warp * 2 + 1)] = static_cast<int32_t>(pack_repack_values(vals, num_bits, true));
}
}
}
Expand Down Expand Up @@ -191,6 +314,14 @@ infinicore::Tensor make_i32_tensor(const std::vector<int32_t> &data, const std::
infinicore::Tensor awq_marlin_repack(const infinicore::Tensor &qweight, size_t size_k, size_t size_n, int num_bits) {
check_repack_shape(size_k, size_n, num_bits);
const size_t pack_factor = 32 / num_bits;
#if INFINILM_ENABLE_MARLIN
if (qweight->dtype() != infinicore::DataType::I32 || qweight->shape() != std::vector<size_t>{size_k, size_n / pack_factor}) {
throw std::runtime_error("awq_marlin_repack: unexpected qweight shape or dtype");
}
if (qweight->device().getType() == infinicore::Device::Type::NVIDIA) {
return awq_marlin_repack_gpu(qweight, size_k, size_n, num_bits);
}
#endif
auto cpu = to_cpu_contiguous(qweight);
if (cpu->dtype() != infinicore::DataType::I32 || cpu->shape() != std::vector<size_t>{size_k, size_n / pack_factor}) {
throw std::runtime_error("awq_marlin_repack: unexpected qweight shape or dtype");
Expand All @@ -205,6 +336,17 @@ infinicore::Tensor awq_marlin_repack(const infinicore::Tensor &qweight, size_t s
infinicore::Tensor gptq_marlin_repack(const infinicore::Tensor &qweight, const infinicore::Tensor &perm, size_t size_k, size_t size_n, int num_bits) {
check_repack_shape(size_k, size_n, num_bits);
const size_t pack_factor = 32 / num_bits;
#if INFINILM_ENABLE_MARLIN
if (qweight->dtype() != infinicore::DataType::I32 || qweight->shape() != std::vector<size_t>{size_k / pack_factor, size_n}) {
throw std::runtime_error("gptq_marlin_repack: unexpected qweight shape or dtype");
}
if (perm && perm->numel() != 0 && (perm->dtype() != infinicore::DataType::I32 || perm->numel() != size_k)) {
throw std::runtime_error("gptq_marlin_repack: unexpected perm shape or dtype");
}
if (qweight->device().getType() == infinicore::Device::Type::NVIDIA) {
return gptq_marlin_repack_gpu(qweight, perm, size_k, size_n, num_bits);
}
#endif
auto cpu = to_cpu_contiguous(qweight);
if (cpu->dtype() != infinicore::DataType::I32 || cpu->shape() != std::vector<size_t>{size_k / pack_factor, size_n}) {
throw std::runtime_error("gptq_marlin_repack: unexpected qweight shape or dtype");
Expand Down Expand Up @@ -283,8 +425,7 @@ infinicore::Tensor awq_to_marlin_zero_points(const infinicore::Tensor &qzeros, s
std::vector<int32_t> unpermuted(unpacked.size());
for (size_t row = 0; row < unpacked.size() / undo_interleave.size(); ++row) {
for (size_t i = 0; i < undo_interleave.size(); ++i) {
unpermuted[row * undo_interleave.size() + i] =
unpacked[row * undo_interleave.size() + static_cast<size_t>(undo_interleave[i])];
unpermuted[row * undo_interleave.size() + i] = unpacked[row * undo_interleave.size() + static_cast<size_t>(undo_interleave[i])];
}
}

Expand All @@ -300,7 +441,7 @@ infinicore::Tensor awq_to_marlin_zero_points(const infinicore::Tensor &qzeros, s
}

const std::vector<int> interleave = num_bits == 4 ? std::vector<int>{0, 2, 4, 6, 1, 3, 5, 7}
: std::vector<int>{0, 2, 1, 3};
: std::vector<int>{0, 2, 1, 3};
std::vector<int32_t> interleaved(permuted.size());
for (size_t row = 0; row < permuted.size() / interleave.size(); ++row) {
for (size_t i = 0; i < interleave.size(); ++i) {
Expand Down
Loading