From c70805c9595e5e165c3e98f178ba17a736a3c86d Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Mon, 2 Mar 2026 14:42:44 +0800 Subject: [PATCH 1/2] issue/1035: kv caching on nvidia --- src/infiniop/ops/kv_caching/cuda/kernel.cuh | 63 ++++++ src/infiniop/ops/kv_caching/info.h | 105 +++++++++ src/infiniop/ops/kv_caching/kv_caching.h | 49 +++++ .../kv_caching/nvidia/kv_caching_nvidia.cu | 159 ++++++++++++++ .../kv_caching/nvidia/kv_caching_nvidia.cuh | 7 + src/infiniop/ops/kv_caching/operator.cc | 32 +++ test/infiniop/kv_caching.py | 205 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 42 ++++ 8 files changed, 662 insertions(+) create mode 100644 src/infiniop/ops/kv_caching/cuda/kernel.cuh create mode 100644 src/infiniop/ops/kv_caching/info.h create mode 100644 src/infiniop/ops/kv_caching/kv_caching.h create mode 100644 src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu create mode 100644 src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh create mode 100644 test/infiniop/kv_caching.py diff --git a/src/infiniop/ops/kv_caching/cuda/kernel.cuh b/src/infiniop/ops/kv_caching/cuda/kernel.cuh new file mode 100644 index 000000000..9b20ca149 --- /dev/null +++ b/src/infiniop/ops/kv_caching/cuda/kernel.cuh @@ -0,0 +1,63 @@ +#ifndef __KV_CACHING_KERNEL_CUH__ +#define __KV_CACHING_KERNEL_CUH__ + +template +__device__ void kvCachingKernel( + Tdata *__restrict__ k_cache, + Tdata *__restrict__ v_cache, + const Tdata *__restrict__ k, + const Tdata *__restrict__ v, + const int64_t *__restrict__ past_kv_lengths, + int batch_size, + int num_kv_heads, + int max_seq_len, + int seq_len, + int hidden_dim, + ptrdiff_t k_cache_strides_0, + ptrdiff_t k_cache_strides_1, + ptrdiff_t k_cache_strides_2, + ptrdiff_t k_cache_strides_3, + ptrdiff_t v_cache_strides_0, + ptrdiff_t v_cache_strides_1, + ptrdiff_t v_cache_strides_2, + ptrdiff_t v_cache_strides_3, + ptrdiff_t k_strides_0, + ptrdiff_t k_strides_1, + ptrdiff_t k_strides_2, + ptrdiff_t k_strides_3, + ptrdiff_t 
v_strides_0, + ptrdiff_t v_strides_1, + ptrdiff_t v_strides_2, + ptrdiff_t v_strides_3) { + // 总元素数 = B * H * seq_len * D + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total = batch_size * num_kv_heads * seq_len * hidden_dim; + + const int grid_size = blockDim.x * gridDim.x; + + for (int idx = tid; idx < total; idx += grid_size) { + // 反解 index + + int d = idx % hidden_dim; + idx /= hidden_dim; + + int s = idx % seq_len; + idx /= seq_len; + + int h = idx % num_kv_heads; + int b = idx / num_kv_heads; + + int past_len = static_cast(past_kv_lengths[b]); + // 写入位置 + int cache_s = past_len + s; + int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0; + int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0; + + int k_src_offset = d * (int)k_strides_3 + s * (int)k_strides_2 + h * (int)k_strides_1 + b * (int)k_strides_0; + int v_src_offset = d * (int)v_strides_3 + s * (int)v_strides_2 + h * (int)v_strides_1 + b * (int)v_strides_0; + k_cache[k_cache_offset] = k[k_src_offset]; + v_cache[v_cache_offset] = v[v_src_offset]; + } +} + +#endif // __KV_CACHING_KERNEL_CUH__ diff --git a/src/infiniop/ops/kv_caching/info.h b/src/infiniop/ops/kv_caching/info.h new file mode 100644 index 000000000..a595b15a9 --- /dev/null +++ b/src/infiniop/ops/kv_caching/info.h @@ -0,0 +1,105 @@ +#ifndef __KV_CACHING_INFO_H__ +#define __KV_CACHING_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +namespace op::kv_caching { + +class KVCachingInfo { +private: + KVCachingInfo() = default; + +public: + infiniDtype_t dtype; + size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim; + ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3; + ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3; + ptrdiff_t 
k_strides_0, k_strides_1, k_strides_2, k_strides_3; + ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3; + + static utils::Result createKVCachingInfo( + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + + CHECK_OR_RETURN( + k_cache != nullptr && v_cache != nullptr && k != nullptr && v != nullptr && past_kv_lengths != nullptr, + INFINI_STATUS_NULL_POINTER); + + const infiniDtype_t dtype = k_cache->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + CHECK_OR_RETURN(k_cache->ndim() == 4 + && v_cache->ndim() == 4 + && k->ndim() == 4 + && v->ndim() == 4, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto shape = k_cache->shape(); + CHECK_SAME_SHAPE(shape, v_cache->shape()); + CHECK_SAME_SHAPE(k->shape(), v->shape()); + + size_t batch_size = shape[0]; + size_t num_kv_heads = shape[1]; + size_t max_seq_len = shape[2]; + size_t hidden_dim = shape[3]; + + size_t seq_len = k->shape()[2]; // number of new tokens appended per sequence + + CHECK_OR_RETURN(batch_size == k->dim(0) // every non-sequence dim of k must equal the cache's + && num_kv_heads == k->dim(1) + && hidden_dim == k->dim(3) && seq_len <= max_seq_len, // and the new tokens must fit in the cache at all + INFINI_STATUS_BAD_TENSOR_SHAPE); + + ptrdiff_t k_cache_strides_0 = k_cache->strides()[0]; + ptrdiff_t k_cache_strides_1 = k_cache->strides()[1]; + ptrdiff_t k_cache_strides_2 = k_cache->strides()[2]; + ptrdiff_t k_cache_strides_3 = k_cache->strides()[3]; + + ptrdiff_t v_cache_strides_0 = v_cache->strides()[0]; + ptrdiff_t v_cache_strides_1 = v_cache->strides()[1]; + ptrdiff_t v_cache_strides_2 = v_cache->strides()[2]; + ptrdiff_t v_cache_strides_3 = v_cache->strides()[3]; + + ptrdiff_t k_strides_0 = k->strides()[0]; + ptrdiff_t k_strides_1 = k->strides()[1]; + ptrdiff_t k_strides_2 = k->strides()[2]; + ptrdiff_t k_strides_3 = k->strides()[3]; + + ptrdiff_t v_strides_0 = v->strides()[0]; + ptrdiff_t v_strides_1 = v->strides()[1]; + ptrdiff_t v_strides_2 = v->strides()[2]; + ptrdiff_t v_strides_3 = 
v->strides()[3]; + + return utils::Result(KVCachingInfo{ + dtype, + batch_size, + num_kv_heads, + max_seq_len, + seq_len, + hidden_dim, + k_cache_strides_0, + k_cache_strides_1, + k_cache_strides_2, + k_cache_strides_3, + v_cache_strides_0, + v_cache_strides_1, + v_cache_strides_2, + v_cache_strides_3, + k_strides_0, + k_strides_1, + k_strides_2, + k_strides_3, + v_strides_0, + v_strides_1, + v_strides_2, + v_strides_3}); + } +}; +} // namespace op::kv_caching + +#endif // __KV_CACHING_INFO_H__ diff --git a/src/infiniop/ops/kv_caching/kv_caching.h b/src/infiniop/ops/kv_caching/kv_caching.h new file mode 100644 index 000000000..e90a0db27 --- /dev/null +++ b/src/infiniop/ops/kv_caching/kv_caching.h @@ -0,0 +1,49 @@ +#ifndef KV_CACHING_H +#define KV_CACHING_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::kv_caching::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + KVCachingInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + KVCachingInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t get_workspace_size() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t k_cache, \ + infiniopTensorDescriptor_t v_cache, \ + infiniopTensorDescriptor_t k, \ + infiniopTensorDescriptor_t v, \ + infiniopTensorDescriptor_t past_kv_lengths); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *k_cache, void *v_cache, \ + const void *k, const void *v, const void *past_kv_lengths, \ + void *stream) const; \ + }; \ + } + +#endif // KV_CACHING_H diff --git 
a/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu b/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu new file mode 100644 index 000000000..c47ee8e53 --- /dev/null +++ b/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cu @@ -0,0 +1,159 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "kv_caching_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_CUDA_KERNEL kvCaching( + Tdata *k_cache, + Tdata *v_cache, + const Tdata *k, + const Tdata *v, + const int64_t *past_kv_lengths, + int batch_size, + int num_kv_heads, + int max_seq_len, + int seq_len, + int hidden_dim, + ptrdiff_t k_cache_strides_0, + ptrdiff_t k_cache_strides_1, + ptrdiff_t k_cache_strides_2, + ptrdiff_t k_cache_strides_3, + ptrdiff_t v_cache_strides_0, + ptrdiff_t v_cache_strides_1, + ptrdiff_t v_cache_strides_2, + ptrdiff_t v_cache_strides_3, + ptrdiff_t k_strides_0, + ptrdiff_t k_strides_1, + ptrdiff_t k_strides_2, + ptrdiff_t k_strides_3, + ptrdiff_t v_strides_0, + ptrdiff_t v_strides_1, + ptrdiff_t v_strides_2, + ptrdiff_t v_strides_3) { + kvCachingKernel(k_cache, v_cache, k, v, past_kv_lengths, + batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim, + k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3, + v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3, + k_strides_0, k_strides_1, k_strides_2, k_strides_3, + v_strides_0, v_strides_1, v_strides_2, v_strides_3); +} + +namespace op::kv_caching::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t 
past_kv_lengths) { + auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel(const KVCachingInfo &info, + Tdata *k_cache, + Tdata *v_cache, + const Tdata *k, + const Tdata *v, + const int64_t *past_kv_lengths, + cudaStream_t stream, void *workspace) { + + int batch_size = static_cast(info.batch_size); + int num_kv_heads = static_cast(info.num_kv_heads); + int max_seq_len = static_cast(info.max_seq_len); + int hidden_dim = static_cast(info.hidden_dim); + + int seq_len = static_cast(info.seq_len); + + int total = batch_size * num_kv_heads * seq_len * hidden_dim; + + ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0; + ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1; + ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2; + ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3; + + ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0; + ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1; + ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2; + ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3; + + ptrdiff_t k_strides_0 = info.k_strides_0; + ptrdiff_t k_strides_1 = info.k_strides_1; + ptrdiff_t k_strides_2 = info.k_strides_2; + ptrdiff_t k_strides_3 = info.k_strides_3; + + ptrdiff_t v_strides_0 = info.v_strides_0; + ptrdiff_t v_strides_1 = info.v_strides_1; + ptrdiff_t v_strides_2 = info.v_strides_2; + ptrdiff_t v_strides_3 = info.v_strides_3; + + int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; + + kvCaching + <<>>(k_cache, v_cache, k, v, past_kv_lengths, + batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim, + k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3, + v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3, + k_strides_0, 
k_strides_1, k_strides_2, k_strides_3, + v_strides_0, v_strides_1, v_strides_2, v_strides_3); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; +#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \ + launchKernel(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace) +#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_KV_CACHING(BLOCK_SIZE, __nv_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) { + CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::kv_caching::nvidia diff --git a/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh b/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh new file mode 100644 index 000000000..91eaa3f4b --- /dev/null +++ b/src/infiniop/ops/kv_caching/nvidia/kv_caching_nvidia.cuh @@ -0,0 +1,7 @@ +#ifndef 
__KV_CACHING_NVIDIA_API_H__ +#define __KV_CACHING_NVIDIA_API_H__ +#include "../kv_caching.h" + +DESCRIPTOR(nvidia) + +#endif // __KV_CACHING_NVIDIA_API_H__ diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc index 34bdf9a99..2081e357d 100644 --- a/src/infiniop/ops/kv_caching/operator.cc +++ b/src/infiniop/ops/kv_caching/operator.cc @@ -8,6 +8,10 @@ #endif #endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#include "nvidia/kv_caching_nvidia.cuh" +#endif + __C infiniStatus_t infiniopCreateKVCachingDescriptor( infiniopHandle_t handle, infiniopKVCachingDescriptor_t *desc_ptr, @@ -42,6 +46,13 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor( #endif #endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif + default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -71,6 +82,13 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize( #if defined(ENABLE_METAX_API) GET_SIZE(INFINI_DEVICE_METAX, ninetoothed); #endif +#endif + +#ifdef ENABLE_NVIDIA_API + GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + GET_SIZE(INFINI_DEVICE_QY, nvidia); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -107,6 +125,13 @@ __C infiniStatus_t infiniopKVCaching( #if defined(ENABLE_METAX_API) CALCULATE(INFINI_DEVICE_METAX, ninetoothed); #endif +#endif + +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -135,6 +160,13 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor( #if defined(ENABLE_METAX_API) DELETE(INFINI_DEVICE_METAX, ninetoothed); #endif +#endif + +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_QY_API + DELETE(INFINI_DEVICE_QY, nvidia); #endif default: return 
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infiniop/kv_caching.py b/test/infiniop/kv_caching.py new file mode 100644 index 000000000..1e5bb3cff --- /dev/null +++ b/test/infiniop/kv_caching.py @@ -0,0 +1,205 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + + +# ============================================================================== +# Reference Implementation +# ============================================================================== +def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + #k_cache.shape=[batch_size, num_kv_heads, max_seq_len, hidden_dim] + #v_cache.shape=[batch_size, num_kv_heads, max_seq_len, hidden_dim] + #k.shape=[batch_size, num_kv_heads, seq_len, hidden_dim] + #v.shape=[batch_size, num_kv_heads, seq_len, hidden_dim] + #past_kv_lengths.shape = [batch_size] + batch_size, num_kv_heads, _, head_dim = k_cache.shape + seq_len = k.shape[2] + + for b in range(batch_size): + past_len = past_kv_lengths[b].item() + for h in range(num_kv_heads): + k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :] + v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :] + + return k_cache, v_cache + + +# ============================================================================== +# Test Configuration (Internal Use Only) +# ============================================================================== +_TEST_CASES_ = [ + # (num_seqs, num_kv_heads, max_seq_len, hidden_dim), strides + ((1, 1, 8, 1), None), + ((1, 8, 32, 32), None), + ((8, 8, 64, 32), None), + ((1, 32, 8, 64), (32768, 1024, 64, 1)), + ((4, 8, 32, 16), (65536, 8192, 256, 16)), + ((8, 16, 64, 128), (8388608, 524288, 8192, 1)), + ((1, 2, 2304, 128), (589824, 294912, 
128, 1)), +] + +# Data types for testing +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 0, "rtol": 0}, + InfiniDtype.BF16: {"atol": 0, "rtol": 0}, + InfiniDtype.F32: {"atol": 0, "rtol": 0}, +} + +# Global flags for controlling test behavior +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def test( + handle, + device, + cache_shape, + strides, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing KVCaching on {InfiniDeviceNames[device]} with cache_shape:{cache_shape}, strides:{strides}, dtype={InfiniDtypeNames[dtype]}" + ) + + import random + + kv_shape = ( + cache_shape[0], + cache_shape[1], + random.randrange(1, cache_shape[2]), + cache_shape[3], + ) + past_shape = (cache_shape[0],) + + k_cache = TestTensor(cache_shape, strides, dtype, device) + v_cache = TestTensor(cache_shape, strides, dtype, device) + + k = TestTensor(kv_shape, None, dtype, device) + v = TestTensor(kv_shape, None, dtype, device) + + past_kv_lengths = TestTensor(past_shape, None, InfiniDtype.I64, device, randint_low=0, randint_high=cache_shape[2] - kv_shape[2]) + + # Run reference implementation + k_cache_ref, v_cache_ref = torch_kv_caching( + k_cache.torch_tensor(), + v_cache.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + past_kv_lengths.torch_tensor()) + + if sync: + sync() + + # Create operator descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateKVCachingDescriptor( + handle, + ctypes.byref(descriptor), + k_cache.descriptor, + v_cache.descriptor, + k.descriptor, + v.descriptor, + past_kv_lengths.descriptor, + ) + ) + + # Get workspace size (likely 0 for this operator, but good practice to include) + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetKVCachingWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = 
TestWorkspace(workspace_size.value, device) + + # Invalidate descriptors to ensure kernel does not rely on them + k.destroy_desc() + v.destroy_desc() + k_cache.destroy_desc() + v_cache.destroy_desc() + past_kv_lengths.destroy_desc() + + # Define the library call as a lambda for profiling + def lib_kv_caching(): + check_error( + LIBINFINIOP.infiniopKVCaching( + descriptor, + workspace.data(), + workspace_size.value, + k_cache.data(), + v_cache.data(), + k.data(), + v.data(), + past_kv_lengths.data(), + None, + ) + ) + + # Execute the custom operator + lib_kv_caching() + + if sync: + sync() + + # Verify correctness + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + print("Verifying K cache...") + debug(k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol) + print("Verifying V cache...") + debug(v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol) + + assert torch.allclose( + k_cache.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol + ) + assert torch.allclose( + v_cache.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol + ) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_kv_caching(k_cache.torch_tensor(), v_cache.torch_tensor(), k.torch_tensor(), v.torch_tensor(), past_kv_lengths.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lib_kv_caching, device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + # Clean up resources + check_error(LIBINFINIOP.infiniopDestroyKVCachingDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options from command line arguments + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 
00283ad3e..275689e78 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1054,6 +1054,48 @@ def scaled_mm_int8_(lib): ] + +@OpRegister.operator +def kv_caching_(lib): + lib.infiniopCreateKVCachingDescriptor.restype = c_int32 + lib.infiniopCreateKVCachingDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + + lib.infiniopGetKVCachingWorkspaceSize.restype = c_int32 + lib.infiniopGetKVCachingWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + + lib.infiniopKVCaching.restype = c_int32 + lib.infiniopKVCaching.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + + lib.infiniopDestroyKVCachingDescriptor.restype = c_int32 + lib.infiniopDestroyKVCachingDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def paged_attention_(lib): lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32 From af394a32c02cf72867963aa232d8148a2d39b1b3 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Tue, 3 Mar 2026 07:19:55 +0000 Subject: [PATCH 2/2] issue/1035 - kv caching on ali, ilu, hygon, qy, metax --- src/infiniop/ops/kv_caching/cuda/kernel.cuh | 6 +- src/infiniop/ops/kv_caching/info.h | 1 + .../ops/kv_caching/metax/kv_caching_metax.h | 7 + .../kv_caching/metax/kv_caching_metax.maca | 160 ++++++++++++++++++ src/infiniop/ops/kv_caching/operator.cc | 111 ++++++------ 5 files changed, 227 insertions(+), 58 deletions(-) create mode 100644 src/infiniop/ops/kv_caching/metax/kv_caching_metax.h create mode 100644 src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca diff --git a/src/infiniop/ops/kv_caching/cuda/kernel.cuh b/src/infiniop/ops/kv_caching/cuda/kernel.cuh index 9b20ca149..658bff5e1 
100644 --- a/src/infiniop/ops/kv_caching/cuda/kernel.cuh +++ b/src/infiniop/ops/kv_caching/cuda/kernel.cuh @@ -29,32 +29,32 @@ __device__ void kvCachingKernel( ptrdiff_t v_strides_1, ptrdiff_t v_strides_2, ptrdiff_t v_strides_3) { - // 总元素数 = B * H * seq_len * D + // num of ele = B * H * seq_len * D int tid = blockIdx.x * blockDim.x + threadIdx.x; int total = batch_size * num_kv_heads * seq_len * hidden_dim; const int grid_size = blockDim.x * gridDim.x; for (int idx = tid; idx < total; idx += grid_size) { - // 反解 index + // unravel index; idx itself is left untouched so the `idx += grid_size` step stays valid int d = idx % hidden_dim; - idx /= hidden_dim; + int rem = idx / hidden_dim; - int s = idx % seq_len; - idx /= seq_len; + int s = rem % seq_len; + rem /= seq_len; - int h = idx % num_kv_heads; - int b = idx / num_kv_heads; + int h = rem % num_kv_heads; + int b = rem / num_kv_heads; int past_len = static_cast(past_kv_lengths[b]); - // 写入位置 + // write position; offsets are kept in ptrdiff_t so large caches do not wrap 32-bit math int cache_s = past_len + s; - int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0; - int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0; + ptrdiff_t k_cache_offset = d * k_cache_strides_3 + cache_s * k_cache_strides_2 + h * k_cache_strides_1 + b * k_cache_strides_0; + ptrdiff_t v_cache_offset = d * v_cache_strides_3 + cache_s * v_cache_strides_2 + h * v_cache_strides_1 + b * v_cache_strides_0; - int k_src_offset = d * (int)k_strides_3 + s * (int)k_strides_2 + h * (int)k_strides_1 + b * (int)k_strides_0; - int v_src_offset = d * (int)v_strides_3 + s * (int)v_strides_2 + h * (int)v_strides_1 + b * (int)v_strides_0; + ptrdiff_t k_src_offset = d * k_strides_3 + s * k_strides_2 + h * k_strides_1 + b * k_strides_0; + ptrdiff_t v_src_offset = d * v_strides_3 + s * v_strides_2 + h * v_strides_1 + b * v_strides_0; k_cache[k_cache_offset] = k[k_src_offset]; v_cache[v_cache_offset] = v[v_src_offset]; } diff --git a/src/infiniop/ops/kv_caching/info.h b/src/infiniop/ops/kv_caching/info.h index a595b15a9..8df348600 100644 --- a/src/infiniop/ops/kv_caching/info.h +++ b/src/infiniop/ops/kv_caching/info.h @@ -32,6 +32,7 @@ class KVCachingInfo { const infiniDtype_t dtype = k_cache->dtype(); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64); CHECK_OR_RETURN(k_cache->ndim() == 4 && v_cache->ndim() == 4 diff --git a/src/infiniop/ops/kv_caching/metax/kv_caching_metax.h b/src/infiniop/ops/kv_caching/metax/kv_caching_metax.h new file mode 100644 index 000000000..083e89de0 --- /dev/null +++ b/src/infiniop/ops/kv_caching/metax/kv_caching_metax.h @@ -0,0 +1,7 @@ +#ifndef __KV_CACHING_METAX_API_H__ +#define __KV_CACHING_METAX_API_H__ +#include "../kv_caching.h" + +DESCRIPTOR(metax) + +#endif // __KV_CACHING_METAX_API_H__ diff --git 
a/src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca b/src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
new file mode 100644
index 000000000..11c776f69
--- /dev/null
+++ b/src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
@@ -0,0 +1,172 @@
+#include "../../../devices/metax/metax_common.h"
+#include "kv_caching_metax.h"
+
+#include "../../../devices/metax/metax_kernel_common.h"
+#include <cub/block/block_reduce.cuh>
+#include <limits>
+
+#include "../../../reduce/cuda/reduce.cuh"
+
+#include "../cuda/kernel.cuh"
+
+// Kernel entry point for the MetaX backend: forwards every argument unchanged
+// to the shared device routine kvCachingKernel (../cuda/kernel.cuh), which
+// copies the new k/v entries into k_cache/v_cache.
+template <typename Tdata>
+INFINIOP_METAX_KERNEL kvCaching(
+    Tdata *k_cache,
+    Tdata *v_cache,
+    const Tdata *k,
+    const Tdata *v,
+    const int64_t *past_kv_lengths,
+    int batch_size,
+    int num_kv_heads,
+    int max_seq_len,
+    int seq_len,
+    int hidden_dim,
+    ptrdiff_t k_cache_strides_0,
+    ptrdiff_t k_cache_strides_1,
+    ptrdiff_t k_cache_strides_2,
+    ptrdiff_t k_cache_strides_3,
+    ptrdiff_t v_cache_strides_0,
+    ptrdiff_t v_cache_strides_1,
+    ptrdiff_t v_cache_strides_2,
+    ptrdiff_t v_cache_strides_3,
+    ptrdiff_t k_strides_0,
+    ptrdiff_t k_strides_1,
+    ptrdiff_t k_strides_2,
+    ptrdiff_t k_strides_3,
+    ptrdiff_t v_strides_0,
+    ptrdiff_t v_strides_1,
+    ptrdiff_t v_strides_2,
+    ptrdiff_t v_strides_3) {
+    kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
+                           batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
+                           k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
+                           v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
+                           k_strides_0, k_strides_1, k_strides_2, k_strides_3,
+                           v_strides_0, v_strides_1, v_strides_2, v_strides_3);
+}
+
+namespace op::kv_caching::metax {
+
+// Backend-private descriptor state: the handle internals that calculate()
+// queries for the device's maximum threads per block.
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+// Frees the Opaque that create() allocated.
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+// Builds a KVCachingInfo from the five tensor descriptors and, once
+// CHECK_RESULT has passed, stores a newly allocated Descriptor in *desc_ptr.
+// NOTE(review): the literal 0 passed to the constructor is presumably the
+// workspace size -- confirm against kv_caching.h.
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t k_cache,
+    infiniopTensorDescriptor_t v_cache,
+    infiniopTensorDescriptor_t k,
+    infiniopTensorDescriptor_t v,
+    infiniopTensorDescriptor_t past_kv_lengths) {
+    auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
+    CHECK_RESULT(info);
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        info.take(), 0, handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Launches kvCaching<Tdata> on `stream` with BLOCK_SIZE threads per block and
+// one thread per copied element (batch * kv_heads * seq_len * hidden_dim).
+// The grid must cover every element: kvCachingKernel overwrites its loop index
+// while decoding coordinates, so it cannot iterate with a smaller grid.
+// `workspace` is unused. NOTE(review): assumes INFINI_STATUS_BAD_TENSOR_SHAPE
+// is the right status for an element count the kernel cannot index -- confirm.
+template <unsigned int BLOCK_SIZE, typename Tdata>
+infiniStatus_t launchKernel(const KVCachingInfo &info,
+                            Tdata *k_cache,
+                            Tdata *v_cache,
+                            const Tdata *k,
+                            const Tdata *v,
+                            const int64_t *past_kv_lengths,
+                            hcStream_t stream, void *workspace) {
+    // Count elements in size_t: the product of four extents can exceed INT_MAX
+    // and would silently wrap if multiplied as int.
+    const size_t total = info.batch_size * info.num_kv_heads * info.seq_len * info.hidden_dim;
+    if (total == 0) {
+        // Nothing to copy; avoid asking the runtime for a zero-block grid.
+        return INFINI_STATUS_SUCCESS;
+    }
+    if (total > static_cast<size_t>(std::numeric_limits<int>::max())) {
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    int batch_size = static_cast<int>(info.batch_size);
+    int num_kv_heads = static_cast<int>(info.num_kv_heads);
+    int max_seq_len = static_cast<int>(info.max_seq_len);
+    int seq_len = static_cast<int>(info.seq_len);
+    int hidden_dim = static_cast<int>(info.hidden_dim);
+
+    // total <= INT_MAX here, so the rounded-up block count fits in unsigned int.
+    unsigned int num_blocks = static_cast<unsigned int>((total + BLOCK_SIZE - 1) / BLOCK_SIZE);
+
+    kvCaching<Tdata>
+        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
+                                               batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
+                                               info.k_cache_strides_0, info.k_cache_strides_1, info.k_cache_strides_2, info.k_cache_strides_3,
+                                               info.v_cache_strides_0, info.v_cache_strides_1, info.v_cache_strides_2, info.v_cache_strides_3,
+                                               info.k_strides_0, info.k_strides_1, info.k_strides_2, info.k_strides_3,
+                                               info.v_strides_0, info.v_strides_1, info.v_strides_2, info.v_strides_3);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Dispatches launchKernel on the descriptor's dtype (F16, F32, BF16) and on the
+// device's maximum threads per block, which selects the compile-time block size.
+// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype and
+// INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED for any other block limit.
+// `workspace_size` is unused.
+infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
+                                     void *k_cache,
+                                     void *v_cache,
+                                     const void *k,
+                                     const void *v,
+                                     const void *past_kv_lengths,
+                                     void *stream_) const {
+    hcStream_t stream = (hcStream_t)stream_;
+#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
+    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
+#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
+    { \
+        if (_info.dtype == INFINI_DTYPE_F16) \
+            return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
+        else if (_info.dtype == INFINI_DTYPE_F32) \
+            return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
+        else if (_info.dtype == INFINI_DTYPE_BF16) \
+            return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
+        else \
+            return INFINI_STATUS_BAD_TENSOR_DTYPE; \
+    }
+
+    // Query the device limit once instead of once per comparison.
+    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
+    if (max_threads == METAX_BLOCK_SIZE_1024) {
+        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
+    } else if (max_threads == METAX_BLOCK_SIZE_512) {
+        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
+    } else if (max_threads == METAX_BLOCK_SIZE_2048) {
+        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
+    } else if (max_threads == METAX_BLOCK_SIZE_4096) {
+        CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
+    } else {
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::kv_caching::metax
diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc
index 2081e357d..46c203197 100644
--- a/src/infiniop/ops/kv_caching/operator.cc
+++ b/src/infiniop/ops/kv_caching/operator.cc
@@ -2,15 +2,12 @@
 #include "../../handle.h"
 #include "infiniop/ops/kv_caching.h"
 
-#if defined(ENABLE_NINETOOTHED)
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
-#include "ninetoothed/kv_caching.h"
-#endif
-#endif
-
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API) || defined(ENABLE_HYGON_API)
 #include "nvidia/kv_caching_nvidia.cuh"
 #endif
+#if defined(ENABLE_METAX_API)
+#include "metax/kv_caching_metax.h"
+#endif
 
 __C infiniStatus_t infiniopCreateKVCachingDescriptor(
     infiniopHandle_t handle,
@@ -34,24 +31,24 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor(
 
     switch (handle->device) {
 
-#if defined(ENABLE_NINETOOTHED)
-#if defined(ENABLE_NVIDIA_API)
-        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#endif
-#if defined(ENABLE_ILUVATAR_API)
-        CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
-#endif
-#if defined(ENABLE_METAX_API)
-        CREATE(INFINI_DEVICE_METAX, ninetoothed);
-#endif
-#endif
-
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ILUVATAR_API
+        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
+#ifdef ENABLE_HYGON_API
+        CREATE(INFINI_DEVICE_HYGON, nvidia);
+#endif
+#if defined(ENABLE_METAX_API)
+        CREATE(INFINI_DEVICE_METAX, metax);
+#endif
 
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -72,24 +69,25 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
 
     switch (desc->device_type) {
 
-#if defined(ENABLE_NINETOOTHED)
-#if defined(ENABLE_NVIDIA_API)
-        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#endif
-#if defined(ENABLE_ILUVATAR_API)
-        GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
-#endif
-#if defined(ENABLE_METAX_API)
-        GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
-#endif
-#endif
-
 #ifdef ENABLE_NVIDIA_API
         GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #ifdef ENABLE_QY_API
         GET_SIZE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ILUVATAR_API
+        GET_SIZE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_ALI_API
+        GET_SIZE(INFINI_DEVICE_ALI, nvidia);
+#endif
+#ifdef ENABLE_HYGON_API
+        GET_SIZE(INFINI_DEVICE_HYGON, nvidia);
+#endif
+#if defined(ENABLE_METAX_API)
+        GET_SIZE(INFINI_DEVICE_METAX, metax);
+#endif
+
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -115,24 +113,25 @@ __C infiniStatus_t infiniopKVCaching(
 
     switch (desc->device_type) {
 
-#if defined(ENABLE_NINETOOTHED)
-#if defined(ENABLE_NVIDIA_API)
-        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#endif
-#if defined(ENABLE_ILUVATAR_API)
-        CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
-#endif
-#if defined(ENABLE_METAX_API)
-        CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
-#endif
-#endif
-
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ILUVATAR_API
+        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
+#ifdef ENABLE_HYGON_API
+        CALCULATE(INFINI_DEVICE_HYGON, nvidia);
+#endif
+#if defined(ENABLE_METAX_API)
+        CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
+
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -150,26 +149,28 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor(
 
     switch (desc->device_type) {
 
-#if defined(ENABLE_NINETOOTHED)
-#if defined(ENABLE_NVIDIA_API)
-        DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
-#endif
-#if defined(ENABLE_ILUVATAR_API)
-        DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
-#endif
-#if defined(ENABLE_METAX_API)
-        DELETE(INFINI_DEVICE_METAX, ninetoothed);
-#endif
-#endif
-
 #ifdef ENABLE_NVIDIA_API
         DELETE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ILUVATAR_API
+        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
+#ifdef ENABLE_HYGON_API
+        DELETE(INFINI_DEVICE_HYGON, nvidia);
+#endif
+#if defined(ENABLE_METAX_API)
+        DELETE(INFINI_DEVICE_METAX, metax);
+#endif
+
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
+
 #undef DELETE
 }