Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/infiniop/ops/kv_caching/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#ifndef __KV_CACHING_KERNEL_CUH__
#define __KV_CACHING_KERNEL_CUH__

template <typename Tdata>
__device__ void kvCachingKernel(
Tdata *__restrict__ k_cache,
Tdata *__restrict__ v_cache,
const Tdata *__restrict__ k,
const Tdata *__restrict__ v,
const int64_t *__restrict__ past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
// num of ele = B * H * seq_len * D
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int total = batch_size * num_kv_heads * seq_len * hidden_dim;

const int grid_size = blockDim.x * gridDim.x;

for (int idx = tid; idx < total; idx += grid_size) {
// unravel index

int d = idx % hidden_dim;
idx /= hidden_dim;

int s = idx % seq_len;
idx /= seq_len;

int h = idx % num_kv_heads;
int b = idx / num_kv_heads;

int past_len = static_cast<int32_t>(past_kv_lengths[b]);
// write position
int cache_s = past_len + s;
int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0;
int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0;

int k_src_offset = d * (int)k_strides_3 + s * (int)k_strides_2 + h * (int)k_strides_1 + b * (int)k_strides_0;
int v_src_offset = d * (int)v_strides_3 + s * (int)v_strides_2 + h * (int)v_strides_1 + b * (int)v_strides_0;
k_cache[k_cache_offset] = k[k_src_offset];
v_cache[v_cache_offset] = v[v_src_offset];
}
}

#endif // __KV_CACHING_KERNEL_CUH__
106 changes: 106 additions & 0 deletions src/infiniop/ops/kv_caching/info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef __KV_CACHING_INFO_H__
#define __KV_CACHING_INFO_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"

namespace op::kv_caching {

class KVCachingInfo {
private:
KVCachingInfo() = default;

public:
infiniDtype_t dtype;
size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim;
ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3;
ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3;
ptrdiff_t k_strides_0, k_strides_1, k_strides_2, k_strides_3;
ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3;

static utils::Result<KVCachingInfo> createKVCachingInfo(
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {

CHECK_OR_RETURN(
k_cache != nullptr && v_cache != nullptr && k != nullptr && v != nullptr && past_kv_lengths != nullptr,
INFINI_STATUS_NULL_POINTER);

const infiniDtype_t dtype = k_cache->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
Comment thread
wooway777 marked this conversation as resolved.
CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64);

CHECK_OR_RETURN(k_cache->ndim() == 4
&& v_cache->ndim() == 4
&& k->ndim() == 4
&& v->ndim() == 4,
INFINI_STATUS_BAD_TENSOR_SHAPE);

auto shape = k_cache->shape();
CHECK_SAME_SHAPE(shape, v_cache->shape());
CHECK_SAME_SHAPE(k->shape(), v->shape());

size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t max_seq_len = shape[2];
size_t hidden_dim = shape[3];

size_t seq_len = k->shape()[2];

CHECK_OR_RETURN(batch_size == k->dim(0)
|| num_kv_heads == k->dim(1)
|| hidden_dim == k->dim(3),
INFINI_STATUS_BAD_TENSOR_SHAPE);

ptrdiff_t k_cache_strides_0 = k_cache->strides()[0];
ptrdiff_t k_cache_strides_1 = k_cache->strides()[1];
ptrdiff_t k_cache_strides_2 = k_cache->strides()[2];
ptrdiff_t k_cache_strides_3 = k_cache->strides()[3];

ptrdiff_t v_cache_strides_0 = v_cache->strides()[0];
ptrdiff_t v_cache_strides_1 = v_cache->strides()[1];
ptrdiff_t v_cache_strides_2 = v_cache->strides()[2];
ptrdiff_t v_cache_strides_3 = v_cache->strides()[3];

ptrdiff_t k_strides_0 = k->strides()[0];
ptrdiff_t k_strides_1 = k->strides()[1];
ptrdiff_t k_strides_2 = k->strides()[2];
ptrdiff_t k_strides_3 = k->strides()[3];

ptrdiff_t v_strides_0 = v->strides()[0];
ptrdiff_t v_strides_1 = v->strides()[1];
ptrdiff_t v_strides_2 = v->strides()[2];
ptrdiff_t v_strides_3 = v->strides()[3];

return utils::Result<KVCachingInfo>(KVCachingInfo{
dtype,
batch_size,
num_kv_heads,
max_seq_len,
seq_len,
hidden_dim,
k_cache_strides_0,
k_cache_strides_1,
k_cache_strides_2,
k_cache_strides_3,
v_cache_strides_0,
v_cache_strides_1,
v_cache_strides_2,
v_cache_strides_3,
k_strides_0,
k_strides_1,
k_strides_2,
k_strides_3,
v_strides_0,
v_strides_1,
v_strides_2,
v_strides_3});
}
};
} // namespace op::kv_caching

#endif // __KV_CACHING_INFO_H__
49 changes: 49 additions & 0 deletions src/infiniop/ops/kv_caching/kv_caching.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef KV_CACHING_H
#define KV_CACHING_H

#include "../../operator.h"
#include "info.h"

#define DESCRIPTOR(NAMESPACE) \
\
namespace op::kv_caching::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
KVCachingInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
KVCachingInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t get_workspace_size() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t k_cache, \
infiniopTensorDescriptor_t v_cache, \
infiniopTensorDescriptor_t k, \
infiniopTensorDescriptor_t v, \
infiniopTensorDescriptor_t past_kv_lengths); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *k_cache, void *v_cache, \
const void *k, const void *v, const void *past_kv_lengths, \
void *stream) const; \
}; \
}

#endif // KV_CACHING_H
7 changes: 7 additions & 0 deletions src/infiniop/ops/kv_caching/metax/kv_caching_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef __KV_CACHING_METAX_API_H__
#define __KV_CACHING_METAX_API_H__
#include "../kv_caching.h"

DESCRIPTOR(metax)

#endif // __KV_CACHING_METAX_API_H__
160 changes: 160 additions & 0 deletions src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#include "../../../devices/metax/metax_common.h"
#include "kv_caching_metax.h"

#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"

#include "../cuda/kernel.cuh"

template <typename Tdata>
INFINIOP_METAX_KERNEL kvCaching(
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
int batch_size,
int num_kv_heads,
int max_seq_len,
int seq_len,
int hidden_dim,
ptrdiff_t k_cache_strides_0,
ptrdiff_t k_cache_strides_1,
ptrdiff_t k_cache_strides_2,
ptrdiff_t k_cache_strides_3,
ptrdiff_t v_cache_strides_0,
ptrdiff_t v_cache_strides_1,
ptrdiff_t v_cache_strides_2,
ptrdiff_t v_cache_strides_3,
ptrdiff_t k_strides_0,
ptrdiff_t k_strides_1,
ptrdiff_t k_strides_2,
ptrdiff_t k_strides_3,
ptrdiff_t v_strides_0,
ptrdiff_t v_strides_1,
ptrdiff_t v_strides_2,
ptrdiff_t v_strides_3) {
kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
}

namespace op::kv_caching::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths);
CHECK_RESULT(info);

*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}

template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t launchKernel(const KVCachingInfo &info,
Tdata *k_cache,
Tdata *v_cache,
const Tdata *k,
const Tdata *v,
const int64_t *past_kv_lengths,
hcStream_t stream, void *workspace) {

int batch_size = static_cast<int>(info.batch_size);
int num_kv_heads = static_cast<int>(info.num_kv_heads);
int max_seq_len = static_cast<int>(info.max_seq_len);
int hidden_dim = static_cast<int>(info.hidden_dim);

int seq_len = static_cast<int>(info.seq_len);

int total = batch_size * num_kv_heads * seq_len * hidden_dim;

ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0;
ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1;
ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2;
ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3;

ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0;
ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1;
ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2;
ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3;

ptrdiff_t k_strides_0 = info.k_strides_0;
ptrdiff_t k_strides_1 = info.k_strides_1;
ptrdiff_t k_strides_2 = info.k_strides_2;
ptrdiff_t k_strides_3 = info.k_strides_3;

ptrdiff_t v_strides_0 = info.v_strides_0;
ptrdiff_t v_strides_1 = info.v_strides_1;
ptrdiff_t v_strides_2 = info.v_strides_2;
ptrdiff_t v_strides_3 = info.v_strides_3;

int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;

kvCaching<Tdata>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths,
batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim,
k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3,
v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3,
k_strides_0, k_strides_1, k_strides_2, k_strides_3,
v_strides_0, v_strides_1, v_strides_2, v_strides_3);
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
#define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \
launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace)
#define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \
{ \
if (_info.dtype == INFINI_DTYPE_F16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \
else if (_info.dtype == INFINI_DTYPE_F32) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \
else if (_info.dtype == INFINI_DTYPE_BF16) \
return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \
else \
return INFINI_STATUS_BAD_TENSOR_DTYPE; \
}

if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}

} // namespace op::kv_caching::metax
Loading