-
Notifications
You must be signed in to change notification settings - Fork 123
issue/1035: 添加NVIDIA平台上的kv_caching算子 #1039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| #ifndef __KV_CACHING_KERNEL_CUH__ | ||
| #define __KV_CACHING_KERNEL_CUH__ | ||
|
|
||
| template <typename Tdata> | ||
| __device__ void kvCachingKernel( | ||
| Tdata *__restrict__ k_cache, | ||
| Tdata *__restrict__ v_cache, | ||
| const Tdata *__restrict__ k, | ||
| const Tdata *__restrict__ v, | ||
| const int64_t *__restrict__ past_kv_lengths, | ||
| int batch_size, | ||
| int num_kv_heads, | ||
| int max_seq_len, | ||
| int seq_len, | ||
| int hidden_dim, | ||
| ptrdiff_t k_cache_strides_0, | ||
| ptrdiff_t k_cache_strides_1, | ||
| ptrdiff_t k_cache_strides_2, | ||
| ptrdiff_t k_cache_strides_3, | ||
| ptrdiff_t v_cache_strides_0, | ||
| ptrdiff_t v_cache_strides_1, | ||
| ptrdiff_t v_cache_strides_2, | ||
| ptrdiff_t v_cache_strides_3, | ||
| ptrdiff_t k_strides_0, | ||
| ptrdiff_t k_strides_1, | ||
| ptrdiff_t k_strides_2, | ||
| ptrdiff_t k_strides_3, | ||
| ptrdiff_t v_strides_0, | ||
| ptrdiff_t v_strides_1, | ||
| ptrdiff_t v_strides_2, | ||
| ptrdiff_t v_strides_3) { | ||
| // num of ele = B * H * seq_len * D | ||
| int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||
| int total = batch_size * num_kv_heads * seq_len * hidden_dim; | ||
|
|
||
| const int grid_size = blockDim.x * gridDim.x; | ||
|
|
||
| for (int idx = tid; idx < total; idx += grid_size) { | ||
| // unravel index | ||
|
|
||
| int d = idx % hidden_dim; | ||
| idx /= hidden_dim; | ||
|
|
||
| int s = idx % seq_len; | ||
| idx /= seq_len; | ||
|
|
||
| int h = idx % num_kv_heads; | ||
| int b = idx / num_kv_heads; | ||
|
|
||
| int past_len = static_cast<int32_t>(past_kv_lengths[b]); | ||
| // write position | ||
| int cache_s = past_len + s; | ||
| int k_cache_offset = d * (int)k_cache_strides_3 + cache_s * (int)k_cache_strides_2 + h * (int)k_cache_strides_1 + b * (int)k_cache_strides_0; | ||
| int v_cache_offset = d * (int)v_cache_strides_3 + cache_s * (int)v_cache_strides_2 + h * (int)v_cache_strides_1 + b * (int)v_cache_strides_0; | ||
|
|
||
| int k_src_offset = d * (int)k_strides_3 + s * (int)k_strides_2 + h * (int)k_strides_1 + b * (int)k_strides_0; | ||
| int v_src_offset = d * (int)v_strides_3 + s * (int)v_strides_2 + h * (int)v_strides_1 + b * (int)v_strides_0; | ||
| k_cache[k_cache_offset] = k[k_src_offset]; | ||
| v_cache[v_cache_offset] = v[v_src_offset]; | ||
| } | ||
| } | ||
|
|
||
| #endif // __KV_CACHING_KERNEL_CUH__ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| #ifndef __KV_CACHING_INFO_H__ | ||
| #define __KV_CACHING_INFO_H__ | ||
|
|
||
| #include "../../../utils.h" | ||
| #include "../../operator.h" | ||
| #include "../../tensor.h" | ||
|
|
||
| namespace op::kv_caching { | ||
|
|
||
| class KVCachingInfo { | ||
| private: | ||
| KVCachingInfo() = default; | ||
|
|
||
| public: | ||
| infiniDtype_t dtype; | ||
| size_t batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim; | ||
| ptrdiff_t k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3; | ||
| ptrdiff_t v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3; | ||
| ptrdiff_t k_strides_0, k_strides_1, k_strides_2, k_strides_3; | ||
| ptrdiff_t v_strides_0, v_strides_1, v_strides_2, v_strides_3; | ||
|
|
||
| static utils::Result<KVCachingInfo> createKVCachingInfo( | ||
| infiniopTensorDescriptor_t k_cache, | ||
| infiniopTensorDescriptor_t v_cache, | ||
| infiniopTensorDescriptor_t k, | ||
| infiniopTensorDescriptor_t v, | ||
| infiniopTensorDescriptor_t past_kv_lengths) { | ||
|
|
||
| CHECK_OR_RETURN( | ||
| k_cache != nullptr && v_cache != nullptr && k != nullptr && v != nullptr && past_kv_lengths != nullptr, | ||
| INFINI_STATUS_NULL_POINTER); | ||
|
|
||
| const infiniDtype_t dtype = k_cache->dtype(); | ||
| CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); | ||
| CHECK_DTYPE(past_kv_lengths->dtype(), INFINI_DTYPE_I64); | ||
|
|
||
| CHECK_OR_RETURN(k_cache->ndim() == 4 | ||
| && v_cache->ndim() == 4 | ||
| && k->ndim() == 4 | ||
| && v->ndim() == 4, | ||
| INFINI_STATUS_BAD_TENSOR_SHAPE); | ||
|
|
||
| auto shape = k_cache->shape(); | ||
| CHECK_SAME_SHAPE(shape, v_cache->shape()); | ||
| CHECK_SAME_SHAPE(k->shape(), v->shape()); | ||
|
|
||
| size_t batch_size = shape[0]; | ||
| size_t num_kv_heads = shape[1]; | ||
| size_t max_seq_len = shape[2]; | ||
| size_t hidden_dim = shape[3]; | ||
|
|
||
| size_t seq_len = k->shape()[2]; | ||
|
|
||
| CHECK_OR_RETURN(batch_size == k->dim(0) | ||
| || num_kv_heads == k->dim(1) | ||
| || hidden_dim == k->dim(3), | ||
| INFINI_STATUS_BAD_TENSOR_SHAPE); | ||
|
|
||
| ptrdiff_t k_cache_strides_0 = k_cache->strides()[0]; | ||
| ptrdiff_t k_cache_strides_1 = k_cache->strides()[1]; | ||
| ptrdiff_t k_cache_strides_2 = k_cache->strides()[2]; | ||
| ptrdiff_t k_cache_strides_3 = k_cache->strides()[3]; | ||
|
|
||
| ptrdiff_t v_cache_strides_0 = v_cache->strides()[0]; | ||
| ptrdiff_t v_cache_strides_1 = v_cache->strides()[1]; | ||
| ptrdiff_t v_cache_strides_2 = v_cache->strides()[2]; | ||
| ptrdiff_t v_cache_strides_3 = v_cache->strides()[3]; | ||
|
|
||
| ptrdiff_t k_strides_0 = k->strides()[0]; | ||
| ptrdiff_t k_strides_1 = k->strides()[1]; | ||
| ptrdiff_t k_strides_2 = k->strides()[2]; | ||
| ptrdiff_t k_strides_3 = k->strides()[3]; | ||
|
|
||
| ptrdiff_t v_strides_0 = v->strides()[0]; | ||
| ptrdiff_t v_strides_1 = v->strides()[1]; | ||
| ptrdiff_t v_strides_2 = v->strides()[2]; | ||
| ptrdiff_t v_strides_3 = v->strides()[3]; | ||
|
|
||
| return utils::Result<KVCachingInfo>(KVCachingInfo{ | ||
| dtype, | ||
| batch_size, | ||
| num_kv_heads, | ||
| max_seq_len, | ||
| seq_len, | ||
| hidden_dim, | ||
| k_cache_strides_0, | ||
| k_cache_strides_1, | ||
| k_cache_strides_2, | ||
| k_cache_strides_3, | ||
| v_cache_strides_0, | ||
| v_cache_strides_1, | ||
| v_cache_strides_2, | ||
| v_cache_strides_3, | ||
| k_strides_0, | ||
| k_strides_1, | ||
| k_strides_2, | ||
| k_strides_3, | ||
| v_strides_0, | ||
| v_strides_1, | ||
| v_strides_2, | ||
| v_strides_3}); | ||
| } | ||
| }; | ||
| } // namespace op::kv_caching | ||
|
|
||
| #endif // __KV_CACHING_INFO_H__ | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| #ifndef KV_CACHING_H | ||
| #define KV_CACHING_H | ||
|
|
||
| #include "../../operator.h" | ||
| #include "info.h" | ||
|
|
||
| #define DESCRIPTOR(NAMESPACE) \ | ||
| \ | ||
| namespace op::kv_caching::NAMESPACE { \ | ||
| class Descriptor final : public InfiniopDescriptor { \ | ||
| struct Opaque; \ | ||
| Opaque *_opaque; \ | ||
| KVCachingInfo _info; \ | ||
| size_t _workspace_size; \ | ||
| \ | ||
| Descriptor( \ | ||
| Opaque *opaque, \ | ||
| KVCachingInfo info, \ | ||
| size_t workspace_size, \ | ||
| infiniDevice_t device_type, \ | ||
| int device_id) \ | ||
| : InfiniopDescriptor{device_type, device_id}, \ | ||
| _opaque(opaque), \ | ||
| _info(info), \ | ||
| _workspace_size(workspace_size) {} \ | ||
| \ | ||
| public: \ | ||
| ~Descriptor(); \ | ||
| \ | ||
| size_t get_workspace_size() const { return _workspace_size; } \ | ||
| \ | ||
| static infiniStatus_t create( \ | ||
| infiniopHandle_t handle, \ | ||
| Descriptor **desc_ptr, \ | ||
| infiniopTensorDescriptor_t k_cache, \ | ||
| infiniopTensorDescriptor_t v_cache, \ | ||
| infiniopTensorDescriptor_t k, \ | ||
| infiniopTensorDescriptor_t v, \ | ||
| infiniopTensorDescriptor_t past_kv_lengths); \ | ||
| \ | ||
| infiniStatus_t calculate( \ | ||
| void *workspace, size_t workspace_size, \ | ||
| void *k_cache, void *v_cache, \ | ||
| const void *k, const void *v, const void *past_kv_lengths, \ | ||
| void *stream) const; \ | ||
| }; \ | ||
| } | ||
|
|
||
| #endif // KV_CACHING_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| #ifndef __KV_CACHING_METAX_API_H__ | ||
| #define __KV_CACHING_METAX_API_H__ | ||
| #include "../kv_caching.h" | ||
|
|
||
| DESCRIPTOR(metax) | ||
|
|
||
| #endif // __KV_CACHING_METAX_API_H__ |
160 changes: 160 additions & 0 deletions
160
src/infiniop/ops/kv_caching/metax/kv_caching_metax.maca
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| #include "../../../devices/metax/metax_common.h" | ||
| #include "kv_caching_metax.h" | ||
|
|
||
| #include "../../../devices/metax/metax_kernel_common.h" | ||
| #include <cub/block/block_reduce.cuh> | ||
|
|
||
| #include "../../../reduce/cuda/reduce.cuh" | ||
|
|
||
| #include "../cuda/kernel.cuh" | ||
|
|
||
| template <typename Tdata> | ||
| INFINIOP_METAX_KERNEL kvCaching( | ||
| Tdata *k_cache, | ||
| Tdata *v_cache, | ||
| const Tdata *k, | ||
| const Tdata *v, | ||
| const int64_t *past_kv_lengths, | ||
| int batch_size, | ||
| int num_kv_heads, | ||
| int max_seq_len, | ||
| int seq_len, | ||
| int hidden_dim, | ||
| ptrdiff_t k_cache_strides_0, | ||
| ptrdiff_t k_cache_strides_1, | ||
| ptrdiff_t k_cache_strides_2, | ||
| ptrdiff_t k_cache_strides_3, | ||
| ptrdiff_t v_cache_strides_0, | ||
| ptrdiff_t v_cache_strides_1, | ||
| ptrdiff_t v_cache_strides_2, | ||
| ptrdiff_t v_cache_strides_3, | ||
| ptrdiff_t k_strides_0, | ||
| ptrdiff_t k_strides_1, | ||
| ptrdiff_t k_strides_2, | ||
| ptrdiff_t k_strides_3, | ||
| ptrdiff_t v_strides_0, | ||
| ptrdiff_t v_strides_1, | ||
| ptrdiff_t v_strides_2, | ||
| ptrdiff_t v_strides_3) { | ||
| kvCachingKernel<Tdata>(k_cache, v_cache, k, v, past_kv_lengths, | ||
| batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim, | ||
| k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3, | ||
| v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3, | ||
| k_strides_0, k_strides_1, k_strides_2, k_strides_3, | ||
| v_strides_0, v_strides_1, v_strides_2, v_strides_3); | ||
| } | ||
|
|
||
| namespace op::kv_caching::metax { | ||
|
|
||
| struct Descriptor::Opaque { | ||
| std::shared_ptr<device::metax::Handle::Internal> internal; | ||
| }; | ||
|
|
||
| Descriptor::~Descriptor() { | ||
| delete _opaque; | ||
| } | ||
|
|
||
| infiniStatus_t Descriptor::create( | ||
| infiniopHandle_t handle, | ||
| Descriptor **desc_ptr, | ||
| infiniopTensorDescriptor_t k_cache, | ||
| infiniopTensorDescriptor_t v_cache, | ||
| infiniopTensorDescriptor_t k, | ||
| infiniopTensorDescriptor_t v, | ||
| infiniopTensorDescriptor_t past_kv_lengths) { | ||
| auto info = KVCachingInfo::createKVCachingInfo(k_cache, v_cache, k, v, past_kv_lengths); | ||
| CHECK_RESULT(info); | ||
|
|
||
| *desc_ptr = new Descriptor( | ||
| new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()}, | ||
| info.take(), 0, handle->device, handle->device_id); | ||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| template <unsigned int BLOCK_SIZE, typename Tdata> | ||
| infiniStatus_t launchKernel(const KVCachingInfo &info, | ||
| Tdata *k_cache, | ||
| Tdata *v_cache, | ||
| const Tdata *k, | ||
| const Tdata *v, | ||
| const int64_t *past_kv_lengths, | ||
| hcStream_t stream, void *workspace) { | ||
|
|
||
| int batch_size = static_cast<int>(info.batch_size); | ||
| int num_kv_heads = static_cast<int>(info.num_kv_heads); | ||
| int max_seq_len = static_cast<int>(info.max_seq_len); | ||
| int hidden_dim = static_cast<int>(info.hidden_dim); | ||
|
|
||
| int seq_len = static_cast<int>(info.seq_len); | ||
|
|
||
| int total = batch_size * num_kv_heads * seq_len * hidden_dim; | ||
|
|
||
| ptrdiff_t k_cache_strides_0 = info.k_cache_strides_0; | ||
| ptrdiff_t k_cache_strides_1 = info.k_cache_strides_1; | ||
| ptrdiff_t k_cache_strides_2 = info.k_cache_strides_2; | ||
| ptrdiff_t k_cache_strides_3 = info.k_cache_strides_3; | ||
|
|
||
| ptrdiff_t v_cache_strides_0 = info.v_cache_strides_0; | ||
| ptrdiff_t v_cache_strides_1 = info.v_cache_strides_1; | ||
| ptrdiff_t v_cache_strides_2 = info.v_cache_strides_2; | ||
| ptrdiff_t v_cache_strides_3 = info.v_cache_strides_3; | ||
|
|
||
| ptrdiff_t k_strides_0 = info.k_strides_0; | ||
| ptrdiff_t k_strides_1 = info.k_strides_1; | ||
| ptrdiff_t k_strides_2 = info.k_strides_2; | ||
| ptrdiff_t k_strides_3 = info.k_strides_3; | ||
|
|
||
| ptrdiff_t v_strides_0 = info.v_strides_0; | ||
| ptrdiff_t v_strides_1 = info.v_strides_1; | ||
| ptrdiff_t v_strides_2 = info.v_strides_2; | ||
| ptrdiff_t v_strides_3 = info.v_strides_3; | ||
|
|
||
| int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; | ||
|
|
||
| kvCaching<Tdata> | ||
| <<<num_blocks, BLOCK_SIZE, 0, stream>>>(k_cache, v_cache, k, v, past_kv_lengths, | ||
| batch_size, num_kv_heads, max_seq_len, seq_len, hidden_dim, | ||
| k_cache_strides_0, k_cache_strides_1, k_cache_strides_2, k_cache_strides_3, | ||
| v_cache_strides_0, v_cache_strides_1, v_cache_strides_2, v_cache_strides_3, | ||
| k_strides_0, k_strides_1, k_strides_2, k_strides_3, | ||
| v_strides_0, v_strides_1, v_strides_2, v_strides_3); | ||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, | ||
| void *k_cache, | ||
| void *v_cache, | ||
| const void *k, | ||
| const void *v, | ||
| const void *past_kv_lengths, | ||
| void *stream_) const { | ||
| hcStream_t stream = (hcStream_t)stream_; | ||
| #define CALCULATE_KV_CACHING(BLOCK_SIZE, TDATA) \ | ||
| launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)k_cache, (TDATA *)v_cache, (const TDATA *)k, (const TDATA *)v, (const int64_t *)past_kv_lengths, stream, workspace) | ||
| #define CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(BLOCK_SIZE) \ | ||
| { \ | ||
| if (_info.dtype == INFINI_DTYPE_F16) \ | ||
| return CALCULATE_KV_CACHING(BLOCK_SIZE, half); \ | ||
| else if (_info.dtype == INFINI_DTYPE_F32) \ | ||
| return CALCULATE_KV_CACHING(BLOCK_SIZE, float); \ | ||
| else if (_info.dtype == INFINI_DTYPE_BF16) \ | ||
| return CALCULATE_KV_CACHING(BLOCK_SIZE, __hpcc_bfloat16); \ | ||
| else \ | ||
| return INFINI_STATUS_BAD_TENSOR_DTYPE; \ | ||
| } | ||
|
|
||
| if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { | ||
| CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_1024) | ||
| } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) { | ||
| CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_512) | ||
| } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) { | ||
| CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_2048) | ||
| } else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) { | ||
| CALCULATE_KV_CACHING_WITH_BLOCK_SIZE(METAX_BLOCK_SIZE_4096) | ||
| } else { | ||
| return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; | ||
| } | ||
| return INFINI_STATUS_SUCCESS; | ||
| } | ||
|
|
||
| } // namespace op::kv_caching::metax |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.