From ceef3d0bdabfc63feef01431426546a1e74236bd Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 18 Jun 2026 07:07:55 +0000 Subject: [PATCH 1/3] issue/1210 add delta rule ops --- include/infiniop.h | 2 + include/infiniop/ops/chunk_gated_delta_rule.h | 43 +++ .../infiniop/ops/recurrent_gated_delta_rule.h | 42 +++ .../chunk_gated_delta_rule.h | 57 +++ .../chunk_gated_delta_rule/cuda/kernel.cuh | 283 ++++++++++++++ .../ops/chunk_gated_delta_rule/info.h | 76 ++++ .../nvidia/chunk_gated_delta_rule_nvidia.cu | 181 +++++++++ .../nvidia/chunk_gated_delta_rule_nvidia.cuh | 8 + .../ops/chunk_gated_delta_rule/operator.cc | 110 ++++++ .../cuda/kernel.cuh | 154 ++++++++ .../ops/recurrent_gated_delta_rule/info.h | 74 ++++ .../recurrent_gated_delta_rule_nvidia.cu | 132 +++++++ .../recurrent_gated_delta_rule_nvidia.cuh | 8 + .../recurrent_gated_delta_rule/operator.cc | 106 ++++++ .../recurrent_gated_delta_rule.h | 56 +++ test/infiniop/chunk_gated_delta_rule.py | 344 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 85 +++++ test/infiniop/recurrent_gated_delta_rule.py | 293 +++++++++++++++ 18 files changed, 2054 insertions(+) create mode 100644 include/infiniop/ops/chunk_gated_delta_rule.h create mode 100644 include/infiniop/ops/recurrent_gated_delta_rule.h create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/info.h create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh create mode 100644 src/infiniop/ops/chunk_gated_delta_rule/operator.cc create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/info.h create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/operator.cc create mode 100644 src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h create mode 100644 test/infiniop/chunk_gated_delta_rule.py create mode 100644 test/infiniop/recurrent_gated_delta_rule.py diff --git a/include/infiniop.h b/include/infiniop.h index 5f58805b9..c1482552f 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -31,6 +31,7 @@ #include "infiniop/ops/broadcast_to.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/cdist.h" +#include "infiniop/ops/chunk_gated_delta_rule.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/cross_entropy.h" @@ -100,6 +101,7 @@ #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/reciprocal.h" +#include "infiniop/ops/recurrent_gated_delta_rule.h" #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" diff --git a/include/infiniop/ops/chunk_gated_delta_rule.h b/include/infiniop/ops/chunk_gated_delta_rule.h new file mode 100644 index 000000000..1822e8266 --- /dev/null +++ b/include/infiniop/ops/chunk_gated_delta_rule.h @@ -0,0 +1,43 @@ +#ifndef __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__ +#define __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopChunkGatedDeltaRuleDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm, + size_t chunk_size); + +__INFINI_C __export infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopChunkGatedDeltaRule( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *initial_state, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( + infiniopChunkGatedDeltaRuleDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/recurrent_gated_delta_rule.h b/include/infiniop/ops/recurrent_gated_delta_rule.h new file mode 100644 index 000000000..85c1f3564 --- /dev/null +++ b/include/infiniop/ops/recurrent_gated_delta_rule.h @@ -0,0 +1,42 @@ +#ifndef __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__ +#define __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRecurrentGatedDeltaRuleDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm); + +__INFINI_C __export infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopRecurrentGatedDeltaRule( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *initial_state, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h new file mode 100644 index 000000000..e6850d0bc --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h @@ -0,0 +1,57 @@ +// infiniop/ops/chunk_gated_delta_rule.h + +#ifndef __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ +#define __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::chunk_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + ChunkGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + ChunkGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + const std::optional &initial_state_desc, \ + bool use_qk_l2norm, \ + size_t chunk_size); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, const void *initial_state, \ + void *stream) const; \ + }; \ + } + +#endif // __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ \ No newline at end of file diff --git a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh new file mode 100644 index 000000000..2bba28795 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh @@ -0,0 +1,283 @@ +// op/chunk_gated_delta_rule/cuda/kernel.cuh + +#ifndef __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ +#define __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ + +#include +#include + +#include + +template +__device__ void chunkGatedDeltaRuleKernel( + Tdata *out, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tdata *g, + const Tdata *beta, + const Tdata *initial_state, + bool use_qk_l2norm, + const size_t chunk_size, + const size_t T // Original sequence length, must be passed from host +) { + // Grid Strategy: Each block handles one sequence for one head. + // gridDim.x = B, gridDim.y = H + const size_t batch_idx = blockIdx.x; + const size_t head_idx = blockIdx.y; + const size_t thread_idx = threadIdx.x; + + const size_t H = gridDim.y; + + const size_t T_padded = (T + chunk_size - 1) / chunk_size * chunk_size; + const size_t num_chunks = T_padded / chunk_size; + const float scale = rsqrtf(static_cast(Dk)); + + using BlockScan = cub::BlockScan; + + // --- Shared Memory Layout --- + extern __shared__ char shared_mem_char[]; + Tcompute *shared_mem = reinterpret_cast(shared_mem_char); + + // Pointers to different sections of shared memory + Tcompute *q_s = shared_mem; + Tcompute *k_s = q_s + chunk_size * Dk; + Tcompute *v_s = k_s + chunk_size * Dk; + Tcompute *k_beta_s = v_s + chunk_size * Dv; + Tcompute *g_s = k_beta_s + chunk_size * Dk; + Tcompute *beta_s = g_s + chunk_size; + Tcompute *g_cumsum_s = beta_s + chunk_size; + Tcompute *attn_s = g_cumsum_s + chunk_size; + Tcompute *k_cumdecay_s = attn_s + chunk_size * chunk_size; + Tcompute *value_prime_s = k_cumdecay_s + chunk_size * Dk; + Tcompute *v_prime_s = value_prime_s + chunk_size * Dv; + Tcompute *attn_inter_s = v_prime_s + chunk_size * Dv; + + typename BlockScan::TempStorage *cub_temp_storage = (typename BlockScan::TempStorage *)(attn_inter_s + chunk_size * Dv); + + // --- Main loop over chunks of the sequence --- + for (size_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + const Tdata *current_state_ptr_g = (chunk_idx == 0 && initial_state != nullptr) ? initial_state : final_state; + const ptrdiff_t state_offset = (batch_idx * H + head_idx) * (Dk * Dv); + + __syncthreads(); + size_t chunk_offset = chunk_idx * chunk_size; + + // --- 2.1: Collaborative Loading of chunk data --- + // (This section is unchanged) + for (size_t i = thread_idx; i < chunk_size; i += BLOCK_THREADS) { + size_t t_idx = chunk_offset + i; + if (t_idx < T) { + ptrdiff_t gb_offset = (batch_idx * H * T) + (head_idx * T) + t_idx; + g_s[i] = static_cast(g[gb_offset]); + beta_s[i] = static_cast(beta[gb_offset]); + } else { + g_s[i] = 0.0f; + beta_s[i] = 1.0f; + } + } + for (size_t i = thread_idx; i < chunk_size * Dk; i += BLOCK_THREADS) { + size_t t_local = i / Dk; + size_t d = i % Dk; + size_t t_global = chunk_offset + t_local; + if (t_global < T) { + ptrdiff_t qk_offset = (batch_idx * H * T * Dk) + (head_idx * T * Dk) + (t_global * Dk) + d; + q_s[i] = static_cast(q[qk_offset]); + k_s[i] = static_cast(k[qk_offset]); + } else { + q_s[i] = 0.0f; + k_s[i] = 0.0f; + } + } + for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { + size_t t_local = i / Dv; + size_t d = i % Dv; + size_t t_global = chunk_offset + t_local; + if (t_global < T) { + ptrdiff_t v_offset = (batch_idx * H * T * Dv) + (head_idx * T * Dv) + (t_global * Dv) + d; + v_s[i] = static_cast(v[v_offset]); + } else { + v_s[i] = 0.0f; + } + } + __syncthreads(); + + // --- 2.2: Optional L2 Normalization --- (Unchanged) + if (use_qk_l2norm) { + // This loop is collapsed for brevity. It is correct and unchanged. + for (size_t t = thread_idx; t < chunk_size; t += BLOCK_THREADS) { + size_t t_global = chunk_offset + t; + if (t_global < T) { + Tcompute q_norm_sq = 0.0f; + Tcompute k_norm_sq = 0.0f; + for (size_t d = 0; d < Dk; ++d) { + Tcompute q_val = q_s[t * Dk + d]; + Tcompute k_val = k_s[t * Dk + d]; + q_norm_sq += q_val * q_val; + k_norm_sq += k_val * k_val; + } + Tcompute r_q_norm = rsqrtf(q_norm_sq + 1e-6f); + Tcompute r_k_norm = rsqrtf(k_norm_sq + 1e-6f); + for (size_t d = 0; d < Dk; ++d) { + q_s[t * Dk + d] *= r_q_norm; + k_s[t * Dk + d] *= r_k_norm; + } + } + } + __syncthreads(); + } + + // --- 2.3 Intra-Chunk Calculations --- (Unchanged, all operate on shared memory) + Tcompute g_val = (thread_idx < chunk_size) ? g_s[thread_idx] : 0.0f; + Tcompute g_cumsum_val; + BlockScan(*cub_temp_storage).InclusiveSum(g_val, g_cumsum_val); + if (thread_idx < chunk_size) { + g_cumsum_s[thread_idx] = g_cumsum_val; + } + __syncthreads(); + for (size_t i = thread_idx; i < chunk_size; i += BLOCK_THREADS) { + Tcompute beta_val = beta_s[i]; + for (size_t d = 0; d < Dk; ++d) { + k_beta_s[i * Dk + d] = k_s[i * Dk + d] * beta_val; + } + for (size_t d = 0; d < Dv; ++d) { + v_s[i * Dv + d] *= beta_val; + } + for (size_t d = 0; d < Dk; ++d) { + q_s[i * Dk + d] *= scale; + } + } + __syncthreads(); + for (size_t i = thread_idx; i < chunk_size * chunk_size; i += BLOCK_THREADS) { + size_t row = i / chunk_size; + size_t col = i % chunk_size; + Tcompute dot_prod = 0.0f; + if (col < row) { + for (size_t d = 0; d < Dk; ++d) { dot_prod += k_beta_s[row * Dk + d] * k_s[col * Dk + d]; } + Tcompute decay = expf(g_cumsum_s[row] - g_cumsum_s[col]); + attn_s[i] = -dot_prod * decay; + } else { + attn_s[i] = 0.0f; + } + } + __syncthreads(); + for (size_t i = 1; i < chunk_size; ++i) { + for (size_t j = thread_idx; j < i; j += BLOCK_THREADS) { + Tcompute update_val = 0.0f; + for (size_t l = 0; l < i; ++l) { update_val += attn_s[i * chunk_size + l] * attn_s[l * chunk_size + j]; } + attn_s[i * chunk_size + j] += update_val; + } + __syncthreads(); + } + if (thread_idx < chunk_size) { + attn_s[thread_idx * chunk_size + thread_idx] += 1.0f; + } + __syncthreads(); + for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { + size_t row = i / Dv; + size_t col_v = i % Dv; + Tcompute dot_prod = 0.0f; + for (size_t d = 0; d < chunk_size; ++d) { + dot_prod += attn_s[row * chunk_size + d] * v_s[d * Dv + col_v]; + } + value_prime_s[i] = dot_prod; + } + for (size_t i = thread_idx; i < chunk_size * Dk; i += BLOCK_THREADS) { + size_t row = i / Dk; + int col_k = i % Dk; + Tcompute dot_prod = 0.0f; + for (size_t d = 0; d < chunk_size; ++d) { + dot_prod += attn_s[row * chunk_size + d] * k_beta_s[d * Dk + col_k] * expf(g_cumsum_s[d]); + } + k_cumdecay_s[i] = dot_prod; + } + __syncthreads(); + + // --- 2.4: Inter-Chunk Interaction --- + // (Correctly reads from global memory) + for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { + size_t row = i / Dv; + size_t col_v = i % Dv; + Tcompute sum = 0.0f; + for (size_t d = 0; d < Dk; ++d) { + Tcompute state_val = (initial_state == nullptr && chunk_idx == 0) ? 0.0f : static_cast(current_state_ptr_g[state_offset + d * Dv + col_v]); + sum += k_cumdecay_s[row * Dk + d] * state_val; + } + v_prime_s[i] = sum; + } + for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { + size_t row = i / Dv; + size_t col_v = i % Dv; + Tcompute sum = 0.0f; + Tcompute g_exp = expf(g_cumsum_s[row]); + for (size_t d = 0; d < Dk; ++d) { + Tcompute state_val = (initial_state == nullptr && chunk_idx == 0) ? 0.0f : static_cast(current_state_ptr_g[state_offset + d * Dv + col_v]); + sum += (q_s[row * Dk + d] * g_exp) * state_val; + } + attn_inter_s[i] = sum; + } + __syncthreads(); + + // --- 2.5: Final Output Calculation and Writeback --- (Unchanged) + for (size_t t = thread_idx; t < chunk_size; t += BLOCK_THREADS) { + size_t global_t = chunk_offset + t; + if (global_t < T) { + ptrdiff_t out_offset = (batch_idx * H * T * Dv) + (head_idx * T * Dv) + (global_t * Dv); + for (size_t d_v = 0; d_v < Dv; ++d_v) { + Tcompute intra_sum = 0.0f; + for (size_t j = 0; j <= t; ++j) { + Tcompute dot_qk = 0.0f; + for (size_t d_k = 0; d_k < Dk; ++d_k) { + dot_qk += q_s[t * Dk + d_k] * k_s[j * Dk + d_k]; + } + Tcompute value_prime_j = value_prime_s[j * Dv + d_v]; + Tcompute v_prime_j = v_prime_s[j * Dv + d_v]; + Tcompute v_new_j = value_prime_j - v_prime_j; + Tcompute decay = expf(g_cumsum_s[t] - g_cumsum_s[j]); + intra_sum += (dot_qk * decay) * v_new_j; + } + out[out_offset + d_v] = static_cast(attn_inter_s[t * Dv + d_v] + intra_sum); + } + } + } + + // --- 2.6: Update inter_chunk_state for the next iteration --- + // (Correctly reads-updates-writes to global memory) + __syncthreads(); + Tcompute g_final_cumsum = g_cumsum_s[chunk_size - 1]; + Tcompute final_decay_factor = expf(g_final_cumsum); + Tdata *final_state_ptr = final_state + state_offset; + + for (size_t i = thread_idx; i < Dk * Dv; i += BLOCK_THREADS) { + size_t dk = i / Dv; + size_t dv = i % Dv; + + Tcompute old_state_val; + if (chunk_idx == 0) { + old_state_val = (initial_state != nullptr) ? static_cast(initial_state[state_offset + i]) : 0.0f; + } else { + old_state_val = static_cast(final_state_ptr[i]); + } + Tcompute decayed_state = old_state_val * final_decay_factor; + + Tcompute chunk_contribution = 0.0f; + for (size_t t = 0; t < chunk_size; ++t) { + Tcompute decay_factor = expf(g_final_cumsum - g_cumsum_s[t]); + Tcompute value_prime_t = value_prime_s[t * Dv + dv]; + Tcompute v_prime_t = v_prime_s[t * Dv + dv]; + Tcompute v_new_t = value_prime_t - v_prime_t; + chunk_contribution += (k_s[t * Dk + dk] * decay_factor) * v_new_t; + } + + final_state_ptr[i] = static_cast(decayed_state + chunk_contribution); + } + + // BUG FIX: Add a block-wide memory fence to ensure global memory writes from this + // iteration are visible to all threads before the next iteration begins. + __threadfence_block(); + } +} + +#endif // __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ \ No newline at end of file diff --git a/src/infiniop/ops/chunk_gated_delta_rule/info.h b/src/infiniop/ops/chunk_gated_delta_rule/info.h new file mode 100644 index 000000000..922153d0d --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/info.h @@ -0,0 +1,76 @@ +// infiniop/ops/chunk_gated_delta_rule/info.h + +#ifndef __CHUNK_GATED_DELTA_RULE_INFO_H__ +#define __CHUNK_GATED_DELTA_RULE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op { +namespace chunk_gated_delta_rule { + +class ChunkGatedDeltaRuleInfo { + ChunkGatedDeltaRuleInfo() = default; + +public: + // --- Data Types and Flags --- + infiniDtype_t dtype; + bool use_qk_l2norm; + + // --- Shape Dimensions --- + size_t B, H, T, Dk, Dv, chunk_size; + + // --- Strides for Memory Layout --- + // Strides can be added here if needed for more complex layouts + + static utils::Result + create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + const std::optional &initial_state_desc, + bool use_qk_l2norm, + size_t chunk_size) { + + auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + // Check for consistent data types across all tensors + if (out_desc->dtype() != dtype || final_state_desc->dtype() != dtype || k_desc->dtype() != dtype || v_desc->dtype() != dtype || g_desc->dtype() != dtype || beta_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // Check tensor dimensions + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || out_desc->ndim() != 4 || final_state_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + ChunkGatedDeltaRuleInfo info; + info.dtype = dtype; + info.use_qk_l2norm = use_qk_l2norm; + info.chunk_size = chunk_size; + + auto q_shape = q_desc->shape(); + info.B = q_shape[0]; + info.H = q_shape[1]; + info.T = q_shape[2]; + info.Dk = q_shape[3]; + + info.Dv = v_desc->shape()[3]; + + // Further validation can be added here to ensure all shapes are compatible. + // For example, check if initial_state_desc shape is [B, H, Dk, Dv]. + + return utils::Result(info); + } +}; + +} // namespace chunk_gated_delta_rule +} // namespace op + +#endif // __CHUNK_GATED_DELTA_RULE_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu new file mode 100644 index 000000000..f5f96e859 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu @@ -0,0 +1,181 @@ +// chunk_gated_delta_rule_nvidia.cu + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "chunk_gated_delta_rule_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" +#include + +// Kernel Launcher Wrapper +template +INFINIOP_CUDA_KERNEL chunkGatedDeltaRule( + Tdata *out, Tdata *final_state, + const Tdata *q, const Tdata *k, const Tdata *v, + const Tdata *g, const Tdata *beta, const Tdata *initial_state, + bool use_qk_l2norm, size_t chunk_size, size_t T) { + chunkGatedDeltaRuleKernel( + out, final_state, q, k, v, g, beta, initial_state, use_qk_l2norm, chunk_size, T); +} + +namespace op { +namespace chunk_gated_delta_rule { +namespace nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + const std::optional &initial_state_desc, + bool use_qk_l2norm, + size_t chunk_size) { + auto info = ChunkGatedDeltaRuleInfo::create( + out_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, initial_state_desc, use_qk_l2norm, chunk_size); + CHECK_RESULT(info); + + // Calculate workspace size if needed, here it's 0 + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + bool use_qk_l2norm, + infiniDtype_t dtype, + size_t B, size_t H, size_t T, size_t chunk_size, + cudaStream_t stream) { + dim3 grid(uint32_t(B), uint32_t(H), 1); + dim3 block(NUM_THREADS); + // Shared memory for local Q, K, and one reduction value + // size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); + + using Tcompute = float; + using BlockScan = cub::BlockScan; + // using BlockReduce = cub::BlockReduce; + + // size_t shared_mem_size = ( + // chunk_size * (3 * Dk + Dv + 3) + + // chunk_size * chunk_size + + // Dk * Dv + // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); + // size_t shared_mem_size = ( + // // q_s, k_s, k_beta_s, k_cumdecay_s + // chunk_size * 4 * Dk + + // // v_s, value_prime_s, v_prime_s, attn_inter_s + // chunk_size * 4 * Dv + + // // g_s, beta_s, g_cumsum_s + // chunk_size * 3 + + // // attn_s (removed decay_mask_s) + // chunk_size * chunk_size + + // // inter_chunk_state_s + // Dk * Dv + // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); + + // size_t shared_mem_size = ( + // // q_s, k_s, k_beta_s, k_cumdecay_s + // chunk_size * 4 * Dk + + // // v_s, value_prime_s, v_prime_s (v_new_s is still here from prev version) + // chunk_size * 4 * Dv + + // // g_s, beta_s, g_cumsum_s + // chunk_size * 3 + + // // attn_s + // chunk_size * chunk_size + + // // inter_chunk_state_s + // Dk * Dv + // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); + size_t shared_mem_size = ( + // q_s, k_s, k_beta_s, k_cumdecay_s + chunk_size * 4 * Dk + + // v_s, value_prime_s, v_prime_s, attn_inter_s + chunk_size * 4 * Dv + + // g_s, beta_s, g_cumsum_s + chunk_size * 3 + + // attn_s + chunk_size * chunk_size + // NOTE: Dk * Dv term for inter_chunk_state_s has been removed. + ) + * sizeof(Tcompute) + + sizeof(typename BlockScan::TempStorage); + + if (dtype == INFINI_DTYPE_F16) { + chunkGatedDeltaRule + <<>>( + (half *)out, (half *)final_state, + (const half *)q, (const half *)k, (const half *)v, + (const half *)g, (const half *)beta, (const half *)initial_state, + use_qk_l2norm, chunk_size, T); + } else if (dtype == INFINI_DTYPE_BF16) { + chunkGatedDeltaRule<__nv_bfloat16, float, Dk, Dv, NUM_THREADS> + <<>>( + (__nv_bfloat16 *)out, (__nv_bfloat16 *)final_state, + (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, + (const __nv_bfloat16 *)g, (const __nv_bfloat16 *)beta, (const __nv_bfloat16 *)initial_state, + use_qk_l2norm, chunk_size, T); + } else if (dtype == INFINI_DTYPE_F32) { + chunkGatedDeltaRule + <<>>( + (float *)out, (float *)final_state, + (const float *)q, (const float *)k, (const float *)v, + (const float *)g, (const float *)beta, (const float *)initial_state, + use_qk_l2norm, chunk_size, T); + } else { + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + + // Specialize for common shapes and thread counts + if (_info.Dk == 128 && _info.Dv == 128) { + if (_opaque->internal->maxThreadsPerBlock() >= 128) { + return launchKernel<128, 128, 128>( + out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, + _info.dtype, _info.B, _info.H, _info.T, _info.chunk_size, stream); + } + } else if (_info.Dk == 64 && _info.Dv == 64) { + if (_opaque->internal->maxThreadsPerBlock() >= 64) { + return launchKernel<64, 64, 64>( + out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, + _info.dtype, _info.B, _info.H, _info.T, _info.chunk_size, stream); + } + } + + // Fallback or error for unsupported shapes + // You can add more specializations for other shapes here. + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +} // namespace nvidia +} // namespace chunk_gated_delta_rule +} // namespace op \ No newline at end of file diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh new file mode 100644 index 000000000..b811a7185 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ +#define __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ + +#include "../chunk_gated_delta_rule.h" + +DESCRIPTOR(nvidia) + +#endif // __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/chunk_gated_delta_rule/operator.cc b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc new file mode 100644 index 000000000..776deb94a --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc @@ -0,0 +1,110 @@ +// infiniop/ops/chunk_gated_delta_rule/operator.cc + +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/chunk_gated_delta_rule.h" + +#if defined(ENABLE_NVIDIA_API) +#include "nvidia/chunk_gated_delta_rule_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm, + size_t chunk_size) { + + std::optional initial_state_opt = (initial_state_desc == nullptr) ? std::nullopt : std::optional(initial_state_desc); + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::chunk_gated_delta_rule::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor **>( \ + desc_ptr), \ + out_desc, final_state_desc, q_desc, k_desc, v_desc, g_desc, \ + beta_desc, initial_state_opt, use_qk_l2norm, chunk_size); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize( + infiniopChunkGatedDeltaRuleDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopChunkGatedDeltaRule( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->calculate(workspace, workspace_size, out, final_state, q, k, v, \ + g, beta, initial_state, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( + infiniopChunkGatedDeltaRuleDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh new file mode 100644 index 000000000..a8d68c61e --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh @@ -0,0 +1,154 @@ +// kernel.cuh (in op/recurrent_gated_delta_rule/cuda/) + +#ifndef __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ +#define __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ + +#include +#include +// Tdata: (e.g., half) +// Tcompute: (e.g., float) +template +__device__ void recurrentGatedDeltaRuleKernel( + Tdata *out, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tdata *g, + const Tdata *beta, + const Tdata *initial_state, + bool use_qk_l2norm) { + const int batch_idx = blockIdx.x; + const int head_idx = blockIdx.y; + const int thread_idx = threadIdx.x; + + // T=1 for decode stage, so seq_idx is always 0 + const int seq_idx = 0; + + const size_t H = gridDim.y; + const size_t base_offset_qkv = (batch_idx * H + head_idx) * Dk; // T=1, Dk=Dv for simplicity now + const size_t base_offset_gb = (batch_idx * H + head_idx); // T=1 + const size_t state_offset = (batch_idx * H + head_idx) * Dk * Dv; + + const Tdata *q_ptr = q + base_offset_qkv; + const Tdata *k_ptr = k + base_offset_qkv; + const Tdata *v_ptr = v + base_offset_qkv; // Assuming Dv = Dk + const Tdata *g_ptr = g + base_offset_gb; + const Tdata *beta_ptr = beta + base_offset_gb; + const Tdata *initial_state_ptr = initial_state + state_offset; + + Tdata *out_ptr = out + base_offset_qkv; + Tdata *final_state_ptr = final_state + state_offset; + + extern __shared__ char shared_mem_char[]; + Tcompute *shared_mem = reinterpret_cast(shared_mem_char); + + // shared memory layout: q_local[Dk], k_local[Dk], norm_val[1] + Tcompute *q_local = shared_mem; + Tcompute *k_local = q_local + Dk; + Tcompute *norm_val = k_local + Dk; // for reduction + + // 1. Load Q and K into shared memory and optionally normalize + // Load + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] = static_cast(q_ptr[i]); + k_local[i] = static_cast(k_ptr[i]); + } + + if (use_qk_l2norm) { + __syncthreads(); + // Parallel reduction to compute L2 norm for Q + Tcompute sum_sq = 0.0f; + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + sum_sq += q_local[i] * q_local[i]; + } + // Simplified reduction, for real use CUB will be better + // This part needs a proper block-wide reduction implementation + norm_val[thread_idx] = sum_sq; + __syncthreads(); + if (thread_idx == 0) { + Tcompute total_sum_sq = 0.0f; + for (int i = 0; i < NUM_THREADS; ++i) { + total_sum_sq += norm_val[i]; + } + norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); + } + __syncthreads(); + Tcompute r_norm_q = norm_val[0]; + + // Normalize Q + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] *= r_norm_q; + } + + // Repeat for K + sum_sq = 0.0f; + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + sum_sq += k_local[i] * k_local[i]; + } + norm_val[thread_idx] = sum_sq; + __syncthreads(); + if (thread_idx == 0) { + Tcompute total_sum_sq = 0.0f; + for (int i = 0; i < NUM_THREADS; ++i) { + total_sum_sq += norm_val[i]; + } + norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); + } + __syncthreads(); + Tcompute r_norm_k = norm_val[0]; + + // Normalize K + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + k_local[i] *= r_norm_k; + } + __syncthreads(); + } + + // 2. Perform the recurrent update logic + Tcompute g_t = expf(static_cast(*g_ptr)); + Tcompute beta_t = static_cast(*beta_ptr); + Tcompute scale = rsqrtf(static_cast(Dk)); + + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] *= scale; + } + __syncthreads(); + + // Loop over Dv, each thread computes an element of the delta and output vector + for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { + Tcompute kv_mem = 0.0f; + // Calculate kv_mem = sum(h_{t-1} * k_t) + for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { + Tcompute h_prev = static_cast(initial_state_ptr[dk_idx * Dv + dv_idx]); + kv_mem += (h_prev * g_t) * k_local[dk_idx]; + } + + Tcompute v_t = static_cast(v_ptr[dv_idx]); + Tcompute delta = (v_t - kv_mem) * beta_t; + + // Calculate final state h_t = h_{t-1} * g + k_t * delta + // And write it back + for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { + Tcompute h_prev = static_cast(initial_state_ptr[dk_idx * Dv + dv_idx]); + Tcompute h_final = (h_prev * g_t) + (k_local[dk_idx] * delta); + final_state_ptr[dk_idx * Dv + dv_idx] = static_cast(h_final); + } + + // Calculate output o_t = sum(h_t * q_t) + // This requires another reduction. For simplicity, we assume one thread calculates one output element. + // A more optimized version would have all threads collaborating. + } + __syncthreads(); // Ensure final_state is fully written + + // All threads collaborate to compute the final output vector + for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { + Tcompute out_val = 0.0f; + for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { + out_val += static_cast(final_state_ptr[dk_idx * Dv + dv_idx]) * q_local[dk_idx]; + } + out_ptr[dv_idx] = static_cast(out_val); + } +} + +#endif // __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/info.h b/src/infiniop/ops/recurrent_gated_delta_rule/info.h new file mode 100644 index 000000000..bc6c523ad --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/info.h @@ -0,0 +1,74 @@ +// infiniop/ops/recurrent_gated_delta_rule/info.h + +#ifndef __RECURRENT_GATED_DELTA_RULE_INFO_H__ +#define __RECURRENT_GATED_DELTA_RULE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op { +namespace recurrent_gated_delta_rule { + +class RecurrentGatedDeltaRuleInfo { + RecurrentGatedDeltaRuleInfo() = default; + +public: + // --- Data Types and Flags --- + infiniDtype_t dtype; + bool use_qk_l2norm; + + // --- Shape Dimensions --- + size_t B, H, T, Dk, Dv; + + // --- Strides for Memory Layout --- + // Strides can be added here if needed for more complex layouts + + static utils::Result + create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm) { + + auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + // Check for consistent data types across all tensors + if (out_desc->dtype() != dtype || final_state_desc->dtype() != dtype || k_desc->dtype() != dtype || v_desc->dtype() != dtype || g_desc->dtype() != dtype || beta_desc->dtype() != dtype || initial_state_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // Check tensor dimensions + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || out_desc->ndim() != 4 || final_state_desc->ndim() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + RecurrentGatedDeltaRuleInfo info; + info.dtype = dtype; + info.use_qk_l2norm = use_qk_l2norm; + + auto q_shape = q_desc->shape(); + info.B = q_shape[0]; + info.H = q_shape[1]; + info.T = q_shape[2]; + info.Dk = q_shape[3]; + + info.Dv = v_desc->shape()[3]; + + // Further validation can be added here to ensure all shapes are compatible. + // For example, check if initial_state_desc shape is [B, H, Dk, Dv]. + + return utils::Result(info); + } +}; + +} // namespace recurrent_gated_delta_rule +} // namespace op + +#endif // __RECURRENT_GATED_DELTA_RULE_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu new file mode 100644 index 000000000..6c4b4ebe5 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu @@ -0,0 +1,132 @@ +// recurrent_gated_delta_rule_nvidia.cu + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "recurrent_gated_delta_rule_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" +#include + +// Kernel Launcher Wrapper +template +INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( + Tdata *out, Tdata *final_state, + const Tdata *q, const Tdata *k, const Tdata *v, + const Tdata *g, const Tdata *beta, const Tdata *initial_state, + bool use_qk_l2norm) { + recurrentGatedDeltaRuleKernel( + out, final_state, q, k, v, g, beta, initial_state, use_qk_l2norm); +} + +namespace op { +namespace recurrent_gated_delta_rule { +namespace nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm) { + auto info = RecurrentGatedDeltaRuleInfo::create( + out_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, initial_state_desc, use_qk_l2norm); + CHECK_RESULT(info); + + // Calculate workspace size if needed, here it's 0 + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + bool use_qk_l2norm, + infiniDtype_t dtype, + size_t B, size_t H, + cudaStream_t stream) { + dim3 grid(uint32_t(B), uint32_t(H), 1); + dim3 block(NUM_THREADS); + // Shared memory for local Q, K, and one reduction value + size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); + + if (dtype == INFINI_DTYPE_F16) { + recurrentGatedDeltaRule + <<>>( + (half *)out, (half *)final_state, + (const half *)q, (const half *)k, (const half *)v, + (const half *)g, (const half *)beta, (const half *)initial_state, + use_qk_l2norm); + } else if (dtype == INFINI_DTYPE_BF16) { + recurrentGatedDeltaRule<__nv_bfloat16, float, Dk, Dv, NUM_THREADS> + <<>>( + (__nv_bfloat16 *)out, (__nv_bfloat16 *)final_state, + (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, + (const __nv_bfloat16 *)g, (const __nv_bfloat16 *)beta, (const __nv_bfloat16 *)initial_state, + use_qk_l2norm); + } else if (dtype == INFINI_DTYPE_F32) { + recurrentGatedDeltaRule + <<>>( + (float *)out, (float *)final_state, + (const float *)q, (const float *)k, (const float *)v, + (const float *)g, (const float *)beta, (const float *)initial_state, + use_qk_l2norm); + } else { + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + + // Specialize for common shapes and thread counts + if (_info.Dk == 128 && _info.Dv == 128) { + if (_opaque->internal->maxThreadsPerBlock() >= 128) { + return launchKernel<128, 128, 128>( + out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, + _info.dtype, _info.B, _info.H, stream); + } + } else if (_info.Dk == 64 && _info.Dv == 64) { + if (_opaque->internal->maxThreadsPerBlock() >= 64) { + return launchKernel<64, 64, 64>( + out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, + _info.dtype, _info.B, _info.H, stream); + } + } + + // Fallback or error for unsupported shapes + // You can add more specializations for other shapes here. + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +} // namespace nvidia +} // namespace recurrent_gated_delta_rule +} // namespace op \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh new file mode 100644 index 000000000..61bfb3051 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ +#define __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ + +#include "../recurrent_gated_delta_rule.h" + +DESCRIPTOR(nvidia) + +#endif // __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc new file mode 100644 index 000000000..4852b65eb --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc @@ -0,0 +1,106 @@ +// infiniop/ops/recurrent_gated_delta_rule/operator.cc + +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/recurrent_gated_delta_rule.h" + +#if defined(ENABLE_NVIDIA_API) +#include "nvidia/recurrent_gated_delta_rule_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_desc, + bool use_qk_l2norm) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::recurrent_gated_delta_rule::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor **>( \ + desc_ptr), \ + out_desc, final_state_desc, q_desc, k_desc, v_desc, g_desc, \ + beta_desc, initial_state_desc, use_qk_l2norm); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRecurrentGatedDeltaRule( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *initial_state, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->calculate(workspace, workspace_size, out, final_state, q, k, v, \ + g, beta, initial_state, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} \ No newline at end of file diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h new file mode 100644 index 000000000..08ab5fa48 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h @@ -0,0 +1,56 @@ +// infiniop/ops/recurrent_gated_delta_rule.h + +#ifndef __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ +#define __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::recurrent_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RecurrentGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + RecurrentGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + infiniopTensorDescriptor_t initial_state_desc, \ + bool use_qk_l2norm); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, const void *initial_state, \ + void *stream) const; \ + }; \ + } + +#endif // __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ \ No newline at end of file diff --git a/test/infiniop/chunk_gated_delta_rule.py b/test/infiniop/chunk_gated_delta_rule.py new file mode 100644 index 000000000..b1de7b2a0 --- /dev/null +++ b/test/infiniop/chunk_gated_delta_rule.py @@ -0,0 +1,344 @@ +# test_chunk_gated_delta_rule.py + +import torch +import torch.nn.functional as F +import ctypes +from ctypes import c_uint32, c_float, c_uint64, c_size_t, POINTER, addressof + +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + + +# ============================================================================== +# Reference Implementation +# ============================================================================== +# From modeling_qwen3_next.py, the production PyTorch fallback implementation +def ref_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + + # The production implementation expects (B, T, H, D) and transposes internally + # query, key, value, beta, g = [ + # x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) + # ] + + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + # print("before pad", query.shape, key.shape, value.shape, beta.shape, g.shape) + + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + # print("after pad", query.shape, key.shape, value.shape, beta.shape, g.shape) + + tot_seqs = sequence_length + pad_size + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + # Reshape to chunks (in the head dimension) + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + # This part is quite intricate and involves parallel scan logic. + # We will trust the reference implementation as the ground truth. + g = g.cumsum(dim=-1) + + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = ( + torch.zeros( + batch_size, + num_heads, + k_head_dim, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + if initial_state is None + else initial_state.to(torch.float32) + ) + + core_attn_out = torch.zeros_like(value) + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(0, tot_seqs // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn_intra = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask, 0 + ) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_intra @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2 + ) + @ v_new + ) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = core_attn_out.reshape( + core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1] + ) + core_attn_out = core_attn_out[:, :, :sequence_length] # Unpad + # core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) + core_attn_out = core_attn_out.contiguous().to(initial_dtype) + + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.contiguous().to(initial_dtype) + + return core_attn_out, last_recurrent_state + + +# ============================================================================== +# Test Configuration +# ============================================================================== +# (B, T, H, Dk, Dv, chunk_size, use_qk_l2norm) +# T (seq_len) must be > 1 for this operator +_TEST_CASES_ = [ + (2, 511, 40, 64, 64, 8, True), + # (2, 511, 40, 64, 64, 16, True), + # (4, 1024, 64, 128, 128, 64, False), + (8, 511, 32, 64, 64, 8, True), + (8, 511, 32, 128, 128, 8, True), +] + +# Data types for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +# Tolerance map +_TOLERANCE_MAP = { + InfiniDtype.F16: { + "atol": 1e-3, + "rtol": 1e-3, + }, # Higher tolerance due to complex ops + InfiniDtype.BF16: {"atol": 5e-2, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-4}, +} + +# Global flags +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def test( + handle, + device, + B, + T, + H, + Dk, + Dv, + chunk_size, + use_qk_l2norm, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing ChunkGatedDeltaRule on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, H={H}, Dk={Dk}, Dv={Dv}, chunk_size={chunk_size}, " + f"dtype={InfiniDtypeNames[dtype]}, use_qk_l2norm={use_qk_l2norm}" + ) + + # Input tensors are in (B, H, T, D) layout as they come from the model layers + q = TestTensor((B, H, T, Dk), None, dtype, device) + k = TestTensor((B, H, T, Dk), None, dtype, device) + v = TestTensor((B, H, T, Dv), None, dtype, device) + + g_logsigmoid = torch.randn(B, H, T, dtype=torch.float32) + g = TestTensor.from_torch(F.logsigmoid(g_logsigmoid), dtype, device) + beta_sigmoid = torch.randn(B, H, T, dtype=torch.float32) + beta = TestTensor.from_torch(torch.sigmoid(beta_sigmoid), dtype, device) + + initial_state = TestTensor((B, H, Dk, Dv), None, dtype, device) + # initial_state = None + # final_state = initial_state + + initial_state_desc = ctypes.c_void_p(0) + initial_state_data = ctypes.c_void_p(0) + initial_state_torch = None + if initial_state is not None: + initial_state_desc = initial_state.descriptor + initial_state_data = initial_state.data() + initial_state_torch = initial_state.torch_tensor() + + # Output tensors + out = TestTensor((B, H, T, Dv), None, dtype, device) + # final_state shape is (B, H, Dk, Dv) + final_state = TestTensor((B, H, Dk, Dv), None, dtype, device) + + # Run reference implementation + ans_out, ans_final_state = ref_chunk_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + chunk_size=chunk_size, + initial_state=initial_state_torch, + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + + if sync: + sync() + + # Create operator descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateChunkGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + final_state.descriptor, + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + initial_state_desc, + ctypes.c_bool(use_qk_l2norm), + ctypes.c_size_t(chunk_size), + ) + ) + + # Get workspace size + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetChunkGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + # Invalidate descriptors to ensure kernel does not rely on them + q.destroy_desc() + k.destroy_desc() + v.destroy_desc() + g.destroy_desc() + beta.destroy_desc() + if initial_state is not None: + initial_state.destroy_desc() + out.destroy_desc() + final_state.destroy_desc() + + # Define the library call + def lib_chunk_gated_delta_rule(): + check_error( + LIBINFINIOP.infiniopChunkGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + final_state.data(), + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + initial_state_data, + None, + ) + ) + + # Execute the custom operator + lib_chunk_gated_delta_rule() + + if sync: + sync() + + # Verify correctness + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + print("--- Verifying Output Tensor ---") + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + + if DEBUG: + print("--- Verifying Final State Tensor ---") + debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + assert torch.allclose( + final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + ) + + # Clean up + check_error(LIBINFINIOP.infiniopDestroyChunkGatedDeltaRuleDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index a59bbdf3f..988f3ac8e 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2782,3 +2782,88 @@ def swap_(lib): lib.infiniopDestroySwapDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def recurrent_gated_delta_rule_(lib): + lib.infiniopCreateRecurrentGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopCreateRecurrentGatedDeltaRuleDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_bool, + ] + lib.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize.restype = c_int32 + lib.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopRecurrentGatedDeltaRule.restype = c_int32 + lib.infiniopRecurrentGatedDeltaRule.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.argtypes = [ + infiniopOperatorDescriptor_t + ] + + +@OpRegister.operator +def chunk_gated_delta_rule_(lib): + lib.infiniopCreateChunkGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopCreateChunkGatedDeltaRuleDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_bool, + c_size_t, + ] + lib.infiniopGetChunkGatedDeltaRuleWorkspaceSize.restype = c_int32 + lib.infiniopGetChunkGatedDeltaRuleWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopChunkGatedDeltaRule.restype = c_int32 + lib.infiniopChunkGatedDeltaRule.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.argtypes = [ + infiniopOperatorDescriptor_t + ] diff --git a/test/infiniop/recurrent_gated_delta_rule.py b/test/infiniop/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..8cca95262 --- /dev/null +++ b/test/infiniop/recurrent_gated_delta_rule.py @@ -0,0 +1,293 @@ +# test_recurrent_gated_delta_rule.py + +import torch +import torch.nn.functional as F +import ctypes +from ctypes import c_uint32, c_float, c_uint64, c_size_t, POINTER, addressof + +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + + +# ============================================================================== +# Reference Implementation +# ============================================================================== +# 从 modeling_qwen3_next.py 提供的生产环境PyTorch备选实现 +# 我们将严格对照此函数进行测试 +def ref_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + + # 生产环境的实现期望输入已经是 (B, H, T, D) + # 我们在测试数据生成时会直接生成这种格式,以模拟真实调用场景 + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros( + batch_size, + num_heads, + sequence_length, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + + # 注意:这里的 initial_state 形状是 (B, H, Dk, Dv) + last_recurrent_state = ( + torch.zeros( + batch_size, + num_heads, + k_head_dim, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + if initial_state is None + else initial_state.to(torch.float32) + ) + + for i in range(sequence_length): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_recurrent_state = last_recurrent_state * g_t + kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_recurrent_state = last_recurrent_state + k_t.unsqueeze( + -1 + ) * delta.unsqueeze(-2) + core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = core_attn_out.contiguous().to(initial_dtype) + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.contiguous().to(initial_dtype) + + return core_attn_out, last_recurrent_state + + +# ============================================================================== +# Test Configuration +# ============================================================================== +# (B, T, H, Dk, Dv, use_qk_l2norm) +# T (seq_len) is typically 1 for decode stage +_TEST_CASES_ = [ + (7, 1, 40, 128, 128, True), + (5, 1, 64, 128, 128, False), + (1, 1, 8, 64, 64, True), + # (16, 1, 32, 80, 80, True), +] + +# Data types for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +# Global flags for controlling test behavior +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def test( + handle, + device, + B, + T, + H, + Dk, + Dv, + use_qk_l2norm, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing RecurrentGatedDeltaRule on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, H={H}, Dk={Dk}, Dv={Dv}, dtype={InfiniDtypeNames[dtype]}, " + f"use_qk_l2norm={use_qk_l2norm}" + ) + + # Create input tensors. + # IMPORTANT: We directly create tensors in (B, H, T, D) layout to match the production environment. + q = TestTensor((B, H, T, Dk), None, dtype, device) + k = TestTensor((B, H, T, Dk), None, dtype, device) + v = TestTensor((B, H, T, Dv), None, dtype, device) + # g and beta have shape (B, H, T) + g_logsigmoid = torch.randn(B, H, T, dtype=torch.float32) + g = TestTensor.from_torch(F.logsigmoid(g_logsigmoid), dtype, device) + beta_sigmoid = torch.randn(B, H, T, dtype=torch.float32) + beta = TestTensor.from_torch(torch.sigmoid(beta_sigmoid), dtype, device) + + initial_state = TestTensor((B, H, Dk, Dv), None, dtype, device) + + # Create output tensors + out = TestTensor((B, H, T, Dv), None, dtype, device) + final_state = TestTensor((B, H, Dk, Dv), None, dtype, device) + + # Run reference implementation + ans_out, ans_final_state = ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + + if sync: + sync() + + # Create operator descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRecurrentGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + final_state.descriptor, + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + initial_state.descriptor, + ctypes.c_bool(use_qk_l2norm), + ) + ) + + # Get workspace size and allocate memory + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + # Invalidate descriptors to ensure kernel does not rely on them + q.destroy_desc() + k.destroy_desc() + v.destroy_desc() + g.destroy_desc() + beta.destroy_desc() + initial_state.destroy_desc() + out.destroy_desc() + final_state.destroy_desc() + + # Define the library call as a lambda for profiling + def lib_recurrent_gated_delta_rule(): + check_error( + LIBINFINIOP.infiniopRecurrentGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + final_state.data(), + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + initial_state.data(), + None, + ) + ) + + # Execute the custom operator + lib_recurrent_gated_delta_rule() + + if sync: + sync() + + # Verify correctness + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + # Verify main output + if DEBUG: + print("--- Verifying Output Tensor ---") + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + + # Verify final state + if DEBUG: + print("--- Verifying Final State Tensor ---") + debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + assert torch.allclose( + final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + ) + # print(final_state.actual_tensor(), ans_final_state) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: ref_recurrent_gated_delta_rule( + q.torch_tensor(), k.torch_tensor(), v.torch_tensor(), + g.torch_tensor(), beta.torch_tensor(), initial_state.torch_tensor(), + output_final_state=True, use_qk_l2norm_in_kernel=use_qk_l2norm), + device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lib_recurrent_gated_delta_rule, device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + # Clean up resources + check_error( + LIBINFINIOP.infiniopDestroyRecurrentGatedDeltaRuleDescriptor(descriptor) + ) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options from command line arguments + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From 958a0f864fe7edd2b31e88d8bc16ce68f8b2038d Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Tue, 23 Jun 2026 02:35:12 +0000 Subject: [PATCH 2/3] issue/1210 refactor delta rule ops and optimize --- include/infinicore/ops.hpp | 2 + .../infinicore/ops/chunk_gated_delta_rule.hpp | 53 ++ .../ops/recurrent_gated_delta_rule.hpp | 55 ++ include/infiniop/ops/chunk_gated_delta_rule.h | 24 +- .../infiniop/ops/recurrent_gated_delta_rule.h | 22 +- python/infinicore/nn/functional/__init__.py | 4 + .../nn/functional/chunk_gated_delta_rule.py | 66 ++ .../functional/recurrent_gated_delta_rule.py | 47 + .../chunk_gated_delta_rule.cc | 172 ++++ .../chunk_gated_delta_rule_infiniop.cc | 108 +++ .../recurrent_gated_delta_rule.cc | 166 ++++ .../recurrent_gated_delta_rule_infiniop.cc | 99 +++ src/infinicore/pybind11/ops.hpp | 4 + .../pybind11/ops/chunk_gated_delta_rule.hpp | 42 + .../ops/recurrent_gated_delta_rule.hpp | 37 + .../chunk_gated_delta_rule.h | 99 ++- .../chunk_gated_delta_rule/cuda/kernel.cuh | 822 +++++++++++++----- .../ops/chunk_gated_delta_rule/info.h | 189 +++- .../nvidia/chunk_gated_delta_rule_nvidia.cu | 350 +++++--- .../ops/chunk_gated_delta_rule/operator.cc | 25 +- .../cuda/kernel.cuh | 194 +++-- .../ops/recurrent_gated_delta_rule/info.h | 156 +++- .../recurrent_gated_delta_rule_nvidia.cu | 250 ++++-- .../recurrent_gated_delta_rule/operator.cc | 22 +- .../recurrent_gated_delta_rule.h | 96 +- test/infinicore/ops/chunk_gated_delta_rule.py | 344 ++++++++ .../ops/recurrent_gated_delta_rule.py | 291 +++++++ test/infiniop/chunk_gated_delta_rule.py | 564 +++++++----- test/infiniop/libinfiniop/op_register.py | 13 +- test/infiniop/recurrent_gated_delta_rule.py | 362 +++++--- 30 files changed, 3709 insertions(+), 969 deletions(-) create mode 100644 include/infinicore/ops/chunk_gated_delta_rule.hpp create mode 100644 include/infinicore/ops/recurrent_gated_delta_rule.hpp create mode 100644 python/infinicore/nn/functional/chunk_gated_delta_rule.py create mode 100644 python/infinicore/nn/functional/recurrent_gated_delta_rule.py create mode 100644 src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc create mode 100644 src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc create mode 100644 src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc create mode 100644 src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp create mode 100644 src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp create mode 100644 test/infinicore/ops/chunk_gated_delta_rule.py create mode 100644 test/infinicore/ops/recurrent_gated_delta_rule.py diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 061bfbfd1..17ef59d3d 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -20,6 +20,7 @@ #include "ops/blas_dot.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" +#include "ops/chunk_gated_delta_rule.hpp" #include "ops/conv2d.hpp" #include "ops/cross_entropy.hpp" #include "ops/deepseek_moe.hpp" @@ -46,6 +47,7 @@ #include "ops/random_sample.hpp" #include "ops/rearrange.hpp" #include "ops/reciprocal.hpp" +#include "ops/recurrent_gated_delta_rule.hpp" #include "ops/relu.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" diff --git a/include/infinicore/ops/chunk_gated_delta_rule.hpp b/include/infinicore/ops/chunk_gated_delta_rule.hpp new file mode 100644 index 000000000..5102836e0 --- /dev/null +++ b/include/infinicore/ops/chunk_gated_delta_rule.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "infinicore.h" + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(ChunkGatedDeltaRule, + Tensor, + Tensor, + std::optional, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + std::optional, + std::optional, + std::optional, + bool, + size_t); + +__export Tensor chunk_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + std::optional cu_seqlens = std::nullopt, + std::optional initial_state_indices = std::nullopt, + std::optional final_state_indices = std::nullopt, + bool use_qk_l2norm = false, + size_t chunk_size = 64); + +__export void chunk_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm = false, + size_t chunk_size = 64); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/recurrent_gated_delta_rule.hpp b/include/infinicore/ops/recurrent_gated_delta_rule.hpp new file mode 100644 index 000000000..7d837c67b --- /dev/null +++ b/include/infinicore/ops/recurrent_gated_delta_rule.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "infinicore.h" + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(RecurrentGatedDeltaRule, + Tensor, + Tensor, + std::optional, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + std::optional, + std::optional, + bool); + +__export Tensor recurrent_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + const Tensor &initial_state, + bool use_qk_l2norm = false); + +__export Tensor recurrent_gated_delta_rule_indexed(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + const Tensor &initial_state_indices, + const Tensor &final_state_indices, + bool use_qk_l2norm = false); + +__export void recurrent_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm = false); + +} // namespace infinicore::op diff --git a/include/infiniop/ops/chunk_gated_delta_rule.h b/include/infiniop/ops/chunk_gated_delta_rule.h index 1822e8266..a9a9c74aa 100644 --- a/include/infiniop/ops/chunk_gated_delta_rule.h +++ b/include/infiniop/ops/chunk_gated_delta_rule.h @@ -8,14 +8,17 @@ typedef struct InfiniopDescriptor *infiniopChunkGatedDeltaRuleDescriptor_t; __INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( infiniopHandle_t handle, infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t final_state_desc, - infiniopTensorDescriptor_t q_desc, - infiniopTensorDescriptor_t k_desc, - infiniopTensorDescriptor_t v_desc, - infiniopTensorDescriptor_t g_desc, - infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t out_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv] + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t final_state_desc, // null when final_state_indices_desc is provided + infiniopTensorDescriptor_t q_desc, // padded: [B, T, Hk, Dk]; varlen: [1, total_tokens, Hk, Dk] + infiniopTensorDescriptor_t k_desc, // same shape as q + infiniopTensorDescriptor_t v_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv] + infiniopTensorDescriptor_t g_desc, // padded: [B, T, Hv]; varlen: [1, total_tokens, Hv] + infiniopTensorDescriptor_t beta_desc, // same shape/dtype as g + infiniopTensorDescriptor_t cu_seqlens_desc, // nullable; [B + 1], int32/int64 + infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed state-pool reads + infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool bool use_qk_l2norm, size_t chunk_size); @@ -28,13 +31,16 @@ __INFINI_C __export infiniStatus_t infiniopChunkGatedDeltaRule( void *workspace, size_t workspace_size, void *out, + void *initial_state, void *final_state, const void *q, const void *k, const void *v, const void *g, const void *beta, - const void *initial_state, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, void *stream); __INFINI_C __export infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( diff --git a/include/infiniop/ops/recurrent_gated_delta_rule.h b/include/infiniop/ops/recurrent_gated_delta_rule.h index 85c1f3564..2865cc1f5 100644 --- a/include/infiniop/ops/recurrent_gated_delta_rule.h +++ b/include/infiniop/ops/recurrent_gated_delta_rule.h @@ -8,14 +8,16 @@ typedef struct InfiniopDescriptor *infiniopRecurrentGatedDeltaRuleDescriptor_t; __INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( infiniopHandle_t handle, infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t final_state_desc, - infiniopTensorDescriptor_t q_desc, - infiniopTensorDescriptor_t k_desc, - infiniopTensorDescriptor_t v_desc, - infiniopTensorDescriptor_t g_desc, - infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t out_desc, // [B, T, Hv, Dv], T must be 1; last dim contiguous + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t final_state_desc, // legacy/indexed out-of-place final state; null when final_state_indices_desc is provided + infiniopTensorDescriptor_t q_desc, // [B, T, Hk, Dk], T must be 1; last dim contiguous + infiniopTensorDescriptor_t k_desc, // [B, T, Hk, Dk], same shape as q; last dim contiguous + infiniopTensorDescriptor_t v_desc, // [B, T, Hv, Dv], Hv must be a multiple of Hk; last dim contiguous + infiniopTensorDescriptor_t g_desc, // [B, T, Hv]; may have a different fp dtype from q/k/v/out/state + infiniopTensorDescriptor_t beta_desc, // [B, T, Hv]; same dtype as g + infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed pool mode + infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool bool use_qk_l2norm); __INFINI_C __export infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( @@ -27,13 +29,15 @@ __INFINI_C __export infiniStatus_t infiniopRecurrentGatedDeltaRule( void *workspace, size_t workspace_size, void *out, + void *initial_state, void *final_state, const void *q, const void *k, const void *v, const void *g, const void *beta, - const void *initial_state, + const void *initial_state_indices, + const void *final_state_indices, void *stream); __INFINI_C __export infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index f3b9a0a2b..3c11242cc 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -7,6 +7,7 @@ from .avg_pool1d import avg_pool1d from .binary_cross_entropy_with_logits import binary_cross_entropy_with_logits from .causal_softmax import causal_softmax +from .chunk_gated_delta_rule import chunk_gated_delta_rule from .embedding import embedding from .flash_attention import flash_attention from .gaussian_nll_loss import gaussian_nll_loss @@ -23,6 +24,7 @@ from .pad import pad from .prelu import prelu from .random_sample import random_sample +from .recurrent_gated_delta_rule import recurrent_gated_delta_rule from .relu6 import relu6 from .rms_norm import rms_norm from .rope import RopeAlgo, rope @@ -44,6 +46,7 @@ "conv2d", "adaptive_max_pool1d", "causal_softmax", + "chunk_gated_delta_rule", "embedding", "flash_attention", "gaussian_nll_loss", @@ -56,6 +59,7 @@ "prelu", "relu6", "rms_norm", + "recurrent_gated_delta_rule", "sigmoid", "silu", "smooth_l1_loss", diff --git a/python/infinicore/nn/functional/chunk_gated_delta_rule.py b/python/infinicore/nn/functional/chunk_gated_delta_rule.py new file mode 100644 index 000000000..dbcf15e75 --- /dev/null +++ b/python/infinicore/nn/functional/chunk_gated_delta_rule.py @@ -0,0 +1,66 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def chunk_gated_delta_rule( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + initial_state: Tensor, + *, + cu_seqlens: Tensor | None = None, + initial_state_indices: Tensor | None = None, + final_state_indices: Tensor | None = None, + use_qk_l2norm: bool = False, + chunk_size: int = 64, +) -> Tensor: + """Run chunk gated delta rule and return only ``out``. + + Padded mode shapes: + q, k: ``[B, T, Hk, Dk]`` + v, out: ``[B, T, Hv, Dv]`` + g, beta: ``[B, T, Hv]`` + initial_state: ``[B, Hv, Dk, Dv]`` + + Continuous-batch mode shapes: + Pass ``cu_seqlens`` with shape ``[B + 1]`` and dtype int32/int64. + q, k: ``[1, total_tokens, Hk, Dk]`` + v, out: ``[1, total_tokens, Hv, Dv]`` + g, beta: ``[1, total_tokens, Hv]`` + + Indexed state-pool mode: + initial_state is ``[pool_size, Hv, Dv, Dk]``. + ``initial_state_indices`` and ``final_state_indices`` are both ``[B]`` + int32/int64 tensors. The final state is written in-place into + ``initial_state[final_state_indices]`` and no final state tensor is + returned. + + Notes: + ``Hv`` must be a multiple of ``Hk``. q/k/v/out may be strided in the + first three dimensions, but the last dimension must be contiguous. + g and beta may use a different floating dtype from q/k/v/state. + """ + if (initial_state_indices is None) != (final_state_indices is None): + raise ValueError( + "initial_state_indices and final_state_indices must be provided together" + ) + + return Tensor( + _infinicore.chunk_gated_delta_rule( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + None if cu_seqlens is None else cu_seqlens._underlying, + None + if initial_state_indices is None + else initial_state_indices._underlying, + None if final_state_indices is None else final_state_indices._underlying, + use_qk_l2norm, + chunk_size, + ) + ) diff --git a/python/infinicore/nn/functional/recurrent_gated_delta_rule.py b/python/infinicore/nn/functional/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..323701e93 --- /dev/null +++ b/python/infinicore/nn/functional/recurrent_gated_delta_rule.py @@ -0,0 +1,47 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def recurrent_gated_delta_rule( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + initial_state: Tensor, + *, + initial_state_indices: Tensor | None = None, + final_state_indices: Tensor | None = None, + use_qk_l2norm: bool = False, +) -> Tensor: + if initial_state_indices is None and final_state_indices is None: + return Tensor( + _infinicore.recurrent_gated_delta_rule( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + use_qk_l2norm, + ) + ) + + if initial_state_indices is None or final_state_indices is None: + raise ValueError( + "initial_state_indices and final_state_indices must be provided together" + ) + + return Tensor( + _infinicore.recurrent_gated_delta_rule_indexed( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + initial_state_indices._underlying, + final_state_indices._underlying, + use_qk_l2norm, + ) + ) diff --git a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc new file mode 100644 index 000000000..2fccef444 --- /dev/null +++ b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc @@ -0,0 +1,172 @@ +#include "infinicore/ops/chunk_gated_delta_rule.hpp" +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(ChunkGatedDeltaRule); + +ChunkGatedDeltaRule::ChunkGatedDeltaRule(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state, q, k, v, g, beta); + if (final_state.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state.value()); + } + if (cu_seqlens.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, cu_seqlens.value()); + } + if (initial_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state_indices.value()); + } + if (final_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state_indices.value()); + } + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +void ChunkGatedDeltaRule::execute(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(ChunkGatedDeltaRule, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +static void check_4d_sequence_tensor(const Tensor &x, const char *name) { + if (x->shape().size() != 4) { + throw std::runtime_error(std::string("chunk_gated_delta_rule expects ") + name + " with shape [B, T, H, D] or [1, total_tokens, H, D]"); + } +} + +static Shape chunk_final_state_shape(const Tensor &q, + const Tensor &v, + const Tensor &initial_state, + std::optional cu_seqlens, + std::optional initial_state_indices) { + const auto &q_shape = q->shape(); + const auto &v_shape = v->shape(); + size_t B = cu_seqlens.has_value() ? cu_seqlens.value()->shape()[0] - 1 : v_shape[0]; + size_t Hv = v_shape[2]; + size_t Dk = q_shape[3]; + size_t Dv = v_shape[3]; + if (initial_state_indices.has_value()) { + return {B, Hv, Dv, Dk}; + } + return {B, Hv, Dk, Dv}; +} + +Tensor chunk_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + check_4d_sequence_tensor(q, "q"); + check_4d_sequence_tensor(k, "k"); + check_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(v->shape(), v->dtype(), v->device()); + std::optional final_state = std::nullopt; + if (!final_state_indices.has_value()) { + final_state = Tensor::empty(chunk_final_state_shape(q, v, initial_state, cu_seqlens, initial_state_indices), + initial_state->dtype(), + initial_state->device()); + } + chunk_gated_delta_rule_(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); + return out; +} + +void chunk_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + check_4d_sequence_tensor(q, "q"); + check_4d_sequence_tensor(k, "k"); + check_4d_sequence_tensor(v, "v"); + ChunkGatedDeltaRule::execute(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc new file mode 100644 index 000000000..96a2b3b44 --- /dev/null +++ b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc @@ -0,0 +1,108 @@ +#include "infinicore/ops/chunk_gated_delta_rule.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::chunk_gated_delta_rule_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, ChunkGatedDeltaRule, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, initial_state, q, k, v, g, beta; + std::optional final_state; + std::optional cu_seqlens; + std::optional initial_state_indices; + std::optional final_state_indices; +}; + +void *plan(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + size_t seed = hash_combine(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, ChunkGatedDeltaRule, + seed, + out->desc(), + initial_state->desc(), + final_state.has_value() ? final_state.value()->desc() : nullptr, + q->desc(), + k->desc(), + v->desc(), + g->desc(), + beta->desc(), + cu_seqlens.has_value() ? cu_seqlens.value()->desc() : nullptr, + initial_state_indices.has_value() ? initial_state_indices.value()->desc() : nullptr, + final_state_indices.has_value() ? final_state_indices.value()->desc() : nullptr, + use_qk_l2norm, + chunk_size); + + INFINIOP_WORKSPACE_TENSOR(workspace, ChunkGatedDeltaRule, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(initial_state), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(g), + graph::GraphTensor(beta), + final_state.has_value() ? std::optional(graph::GraphTensor(final_state.value())) : std::nullopt, + cu_seqlens.has_value() ? std::optional(graph::GraphTensor(cu_seqlens.value())) : std::nullopt, + initial_state_indices.has_value() ? std::optional(graph::GraphTensor(initial_state_indices.value())) : std::nullopt, + final_state_indices.has_value() ? std::optional(graph::GraphTensor(final_state_indices.value())) : std::nullopt}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopChunkGatedDeltaRule( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->out->data(), + planned->initial_state->data(), + planned->final_state.has_value() ? planned->final_state.value()->data() : nullptr, + planned->q->data(), + planned->k->data(), + planned->v->data(), + planned->g->data(), + planned->beta->data(), + planned->cu_seqlens.has_value() ? planned->cu_seqlens.value()->data() : nullptr, + planned->initial_state_indices.has_value() ? planned->initial_state_indices.value()->data() : nullptr, + planned->final_state_indices.has_value() ? planned->final_state_indices.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(ChunkGatedDeltaRule, &plan, &run, &cleanup); + +} // namespace infinicore::op::chunk_gated_delta_rule_impl::infiniop diff --git a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc new file mode 100644 index 000000000..4fd7bb69a --- /dev/null +++ b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc @@ -0,0 +1,166 @@ +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RecurrentGatedDeltaRule); + +RecurrentGatedDeltaRule::RecurrentGatedDeltaRule(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state, q, k, v, g, beta); + if (final_state.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state.value()); + } + if (initial_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state_indices.value()); + } + if (final_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state_indices.value()); + } + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +void RecurrentGatedDeltaRule::execute(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RecurrentGatedDeltaRule, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +static Tensor ensure_4d_sequence_tensor(const Tensor &x, const char *name) { + if (x->shape().size() == 4) { + return x; + } + if (x->shape().size() == 3) { + return x->unsqueeze(1); + } + throw std::runtime_error(std::string("recurrent_gated_delta_rule expects ") + name + " with shape [B, T, H, D] or [B, H, D]"); +} + +static Shape recurrent_output_shape(const Tensor &v) { + const auto &shape = v->shape(); + return {shape[0], shape[1], shape[2], shape[3]}; +} + +Tensor recurrent_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + const Tensor &initial_state, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device()); + Shape final_state_shape = {v4->shape()[0], v4->shape()[2], q4->shape()[3], v4->shape()[3]}; + auto final_state = Tensor::empty(final_state_shape, initial_state->dtype(), initial_state->device()); + recurrent_gated_delta_rule_(out, + initial_state, + final_state, + q4, + k4, + v4, + g, + beta, + std::nullopt, + std::nullopt, + use_qk_l2norm); + return out; +} + +Tensor recurrent_gated_delta_rule_indexed(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + const Tensor &initial_state_indices, + const Tensor &final_state_indices, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device()); + recurrent_gated_delta_rule_(out, + initial_state, + std::nullopt, + q4, + k4, + v4, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); + return out; +} + +void recurrent_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + RecurrentGatedDeltaRule::execute(out, + initial_state, + final_state, + q4, + k4, + v4, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc new file mode 100644 index 000000000..ba52b2c40 --- /dev/null +++ b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc @@ -0,0 +1,99 @@ +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::recurrent_gated_delta_rule_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RecurrentGatedDeltaRule, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, initial_state, q, k, v, g, beta; + std::optional final_state; + std::optional initial_state_indices; + std::optional final_state_indices; +}; + +void *plan(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + size_t seed = hash_combine(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RecurrentGatedDeltaRule, + seed, + out->desc(), + initial_state->desc(), + final_state.has_value() ? final_state.value()->desc() : nullptr, + q->desc(), + k->desc(), + v->desc(), + g->desc(), + beta->desc(), + initial_state_indices.has_value() ? initial_state_indices.value()->desc() : nullptr, + final_state_indices.has_value() ? final_state_indices.value()->desc() : nullptr, + use_qk_l2norm); + + INFINIOP_WORKSPACE_TENSOR(workspace, RecurrentGatedDeltaRule, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(initial_state), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(g), + graph::GraphTensor(beta), + final_state.has_value() ? std::optional(graph::GraphTensor(final_state.value())) : std::nullopt, + initial_state_indices.has_value() ? std::optional(graph::GraphTensor(initial_state_indices.value())) : std::nullopt, + final_state_indices.has_value() ? std::optional(graph::GraphTensor(final_state_indices.value())) : std::nullopt}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopRecurrentGatedDeltaRule( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->out->data(), + planned->initial_state->data(), + planned->final_state.has_value() ? planned->final_state.value()->data() : nullptr, + planned->q->data(), + planned->k->data(), + planned->v->data(), + planned->g->data(), + planned->beta->data(), + planned->initial_state_indices.has_value() ? planned->initial_state_indices.value()->data() : nullptr, + planned->final_state_indices.has_value() ? planned->final_state_indices.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RecurrentGatedDeltaRule, &plan, &run, &cleanup); + +} // namespace infinicore::op::recurrent_gated_delta_rule_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index ccf473ba8..435c1c119 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -34,6 +34,7 @@ #include "ops/cat.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" +#include "ops/chunk_gated_delta_rule.hpp" #include "ops/conv2d.hpp" #include "ops/cross_entropy.hpp" #include "ops/diff.hpp" @@ -89,6 +90,7 @@ #include "ops/random_sample.hpp" #include "ops/rearrange.hpp" #include "ops/reciprocal.hpp" +#include "ops/recurrent_gated_delta_rule.hpp" #include "ops/relu6.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" @@ -241,7 +243,9 @@ inline void bind(py::module &m) { bind_addcmul(m); bind_cdist(m); bind_binary_cross_entropy_with_logits(m); + bind_chunk_gated_delta_rule(m); bind_reciprocal(m); + bind_recurrent_gated_delta_rule(m); bind_upsample_bilinear(m); bind_kthvalue(m); bind_ldexp(m); diff --git a/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp new file mode 100644 index 000000000..6d8bf5cf2 --- /dev/null +++ b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "infinicore/ops/chunk_gated_delta_rule.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_chunk_gated_delta_rule(py::module &m) { + m.def("chunk_gated_delta_rule", + &op::chunk_gated_delta_rule, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("cu_seqlens") = std::nullopt, + py::arg("initial_state_indices") = std::nullopt, + py::arg("final_state_indices") = std::nullopt, + py::arg("use_qk_l2norm") = false, + py::arg("chunk_size") = 64, + R"doc(Chunk gated delta rule. Returns out only. + +Padded mode: + q/k: [B, T, Hk, Dk], v/out: [B, T, Hv, Dv], g/beta: [B, T, Hv], + initial_state: [B, Hv, Dk, Dv]. + +Continuous-batch mode: + pass cu_seqlens [B + 1]; q/k: [1, total_tokens, Hk, Dk], + v/out: [1, total_tokens, Hv, Dv], g/beta: [1, total_tokens, Hv]. + +Indexed pool mode: + initial_state is [pool_size, Hv, Dv, Dk]. Provide both initial_state_indices + and final_state_indices [B]; final state is written in-place to initial_state. +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp b/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp new file mode 100644 index 000000000..127b62042 --- /dev/null +++ b/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_recurrent_gated_delta_rule(py::module &m) { + m.def("recurrent_gated_delta_rule", + &op::recurrent_gated_delta_rule, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("use_qk_l2norm") = false, + R"doc(Recurrent gated delta rule. Returns out only.)doc"); + + m.def("recurrent_gated_delta_rule_indexed", + &op::recurrent_gated_delta_rule_indexed, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("initial_state_indices"), + py::arg("final_state_indices"), + py::arg("use_qk_l2norm") = false, + R"doc(Recurrent gated delta rule with indexed in-place state pool. Returns out only.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h index e6850d0bc..808e0e52c 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h +++ b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h @@ -6,52 +6,57 @@ #include "../../operator.h" #include "info.h" -#define DESCRIPTOR(NAMESPACE) \ - \ - namespace op::chunk_gated_delta_rule::NAMESPACE { \ - class Descriptor final : public InfiniopDescriptor { \ - struct Opaque; \ - Opaque *_opaque; \ - ChunkGatedDeltaRuleInfo _info; \ - size_t _workspace_size; \ - \ - Descriptor( \ - Opaque *opaque, \ - ChunkGatedDeltaRuleInfo info, \ - size_t workspace_size, \ - infiniDevice_t device_type, \ - int device_id) \ - : InfiniopDescriptor{device_type, device_id}, \ - _opaque(opaque), \ - _info(info), \ - _workspace_size(workspace_size) {} \ - \ - public: \ - ~Descriptor(); \ - \ - size_t workspaceSize() const { return _workspace_size; } \ - \ - static infiniStatus_t create( \ - infiniopHandle_t handle, \ - Descriptor **desc_ptr, \ - infiniopTensorDescriptor_t out_desc, \ - infiniopTensorDescriptor_t final_state_desc, \ - infiniopTensorDescriptor_t q_desc, \ - infiniopTensorDescriptor_t k_desc, \ - infiniopTensorDescriptor_t v_desc, \ - infiniopTensorDescriptor_t g_desc, \ - infiniopTensorDescriptor_t beta_desc, \ - const std::optional &initial_state_desc, \ - bool use_qk_l2norm, \ - size_t chunk_size); \ - \ - infiniStatus_t calculate( \ - void *workspace, size_t workspace_size, \ - void *out, void *final_state, \ - const void *q, const void *k, const void *v, \ - const void *g, const void *beta, const void *initial_state, \ - void *stream) const; \ - }; \ +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::chunk_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + ChunkGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + ChunkGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t initial_state_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + infiniopTensorDescriptor_t cu_seqlens_desc, \ + infiniopTensorDescriptor_t initial_state_indices_desc, \ + infiniopTensorDescriptor_t final_state_indices_desc, \ + bool use_qk_l2norm, \ + size_t chunk_size); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *initial_state, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, const void *cu_seqlens, \ + const void *initial_state_indices, \ + const void *final_state_indices, \ + void *stream) const; \ + }; \ } -#endif // __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ \ No newline at end of file +#endif // __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh index 2bba28795..cc167b6c7 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh +++ b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh @@ -1,283 +1,667 @@ -// op/chunk_gated_delta_rule/cuda/kernel.cuh - #ifndef __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ #define __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ -#include +#include +#include #include -#include +__device__ inline int64_t loadOptionalIndex( + const void *indices, + bool is_i64, + int idx, + int fallback) { + if (indices == nullptr) { + return static_cast(fallback); + } + return is_i64 + ? static_cast(indices)[idx] + : static_cast(static_cast(indices)[idx]); +} + +template +__device__ inline float loadAsFloat(const T *ptr, ptrdiff_t offset) { + return static_cast(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { + return __half2float(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { + return __bfloat162float(ptr[offset]); +} + +#define CGDR_FOR(idx, n) \ + for (int idx = threadIdx.x; idx < static_cast(n); idx += blockDim.x) -template +template +__device__ Tcompute blockReduceSum(Tcompute v) { + __shared__ Tcompute smem[NUM_THREADS]; + smem[threadIdx.x] = v; + __syncthreads(); + + for (int s = NUM_THREADS / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + smem[threadIdx.x] += smem[threadIdx.x + s]; + } + __syncthreads(); + } + return smem[0]; +} + +template < + typename Tdata, + typename Tgate, + typename Tcompute, + size_t Dk, + size_t Dv, + size_t NUM_THREADS> __device__ void chunkGatedDeltaRuleKernel( + Tcompute *state_workspace, Tdata *out, + Tdata *initial_state, Tdata *final_state, const Tdata *q, const Tdata *k, const Tdata *v, - const Tdata *g, - const Tdata *beta, - const Tdata *initial_state, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, bool use_qk_l2norm, - const size_t chunk_size, - const size_t T // Original sequence length, must be passed from host -) { - // Grid Strategy: Each block handles one sequence for one head. - // gridDim.x = B, gridDim.y = H - const size_t batch_idx = blockIdx.x; - const size_t head_idx = blockIdx.y; - const size_t thread_idx = threadIdx.x; - - const size_t H = gridDim.y; - - const size_t T_padded = (T + chunk_size - 1) / chunk_size * chunk_size; - const size_t num_chunks = T_padded / chunk_size; - const float scale = rsqrtf(static_cast(Dk)); - - using BlockScan = cub::BlockScan; - - // --- Shared Memory Layout --- - extern __shared__ char shared_mem_char[]; - Tcompute *shared_mem = reinterpret_cast(shared_mem_char); - - // Pointers to different sections of shared memory - Tcompute *q_s = shared_mem; - Tcompute *k_s = q_s + chunk_size * Dk; - Tcompute *v_s = k_s + chunk_size * Dk; - Tcompute *k_beta_s = v_s + chunk_size * Dv; - Tcompute *g_s = k_beta_s + chunk_size * Dk; - Tcompute *beta_s = g_s + chunk_size; - Tcompute *g_cumsum_s = beta_s + chunk_size; - Tcompute *attn_s = g_cumsum_s + chunk_size; - Tcompute *k_cumdecay_s = attn_s + chunk_size * chunk_size; - Tcompute *value_prime_s = k_cumdecay_s + chunk_size * Dk; - Tcompute *v_prime_s = value_prime_s + chunk_size * Dv; - Tcompute *attn_inter_s = v_prime_s + chunk_size * Dv; - - typename BlockScan::TempStorage *cub_temp_storage = (typename BlockScan::TempStorage *)(attn_inter_s + chunk_size * Dv); - - // --- Main loop over chunks of the sequence --- - for (size_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { - const Tdata *current_state_ptr_g = (chunk_idx == 0 && initial_state != nullptr) ? initial_state : final_state; - const ptrdiff_t state_offset = (batch_idx * H + head_idx) * (Dk * Dv); + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { - __syncthreads(); - size_t chunk_offset = chunk_idx * chunk_size; - - // --- 2.1: Collaborative Loading of chunk data --- - // (This section is unchanged) - for (size_t i = thread_idx; i < chunk_size; i += BLOCK_THREADS) { - size_t t_idx = chunk_offset + i; - if (t_idx < T) { - ptrdiff_t gb_offset = (batch_idx * H * T) + (head_idx * T) + t_idx; - g_s[i] = static_cast(g[gb_offset]); - beta_s[i] = static_cast(beta[gb_offset]); - } else { - g_s[i] = 0.0f; - beta_s[i] = 1.0f; + const int batch_idx = blockIdx.x; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); + + if (key_head_idx >= static_cast(Hk)) { + return; + } + + int64_t token_begin = 0; + int64_t token_end = static_cast(T); + + if (has_cu_seqlens) { + token_begin = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx, 0); + token_end = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx + 1, 0); + if (token_begin < 0 || token_end < token_begin || token_end > static_cast(T)) { + return; + } + } + + int64_t read_slot = batch_idx; + int64_t write_slot = batch_idx; + + if (indexed_state_pool) { + read_slot = loadOptionalIndex( + initial_state_indices, + initial_state_indices_i64, + batch_idx, + batch_idx); + + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadOptionalIndex( + final_state_indices, + final_state_indices_i64, + batch_idx, + batch_idx); + + if (read_slot < 0 || write_slot < 0 || read_slot >= static_cast(pool_size) || write_slot >= static_cast(pool_size)) { + return; + } + } + + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + + Tdata *final_state_target = nullptr; + ptrdiff_t final_base = 0; + + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } + + const ptrdiff_t per_block_workspace = static_cast(Dk * Dv) + static_cast(chunk_size * Dk) * 3 + static_cast(chunk_size * Dv) * 3 + static_cast(chunk_size * chunk_size) + static_cast(chunk_size) * 3; + + const ptrdiff_t workspace_block = (static_cast(batch_idx) * gridDim.y + static_cast(value_head_idx)) * per_block_workspace; + + Tcompute *state_local = state_workspace + workspace_block; + Tcompute *q_buf = state_local + Dk * Dv; + Tcompute *k_buf = q_buf + chunk_size * Dk; + Tcompute *k_cumdecay = k_buf + chunk_size * Dk; + Tcompute *v_beta = k_cumdecay + chunk_size * Dk; + Tcompute *v_mid = v_beta + chunk_size * Dv; + Tcompute *v_new = v_mid + chunk_size * Dv; + Tcompute *attn = v_new + chunk_size * Dv; + Tcompute *g_cum = attn + chunk_size * chunk_size; + Tcompute *beta_buf = g_cum + chunk_size; + Tcompute *row_buf = beta_buf + chunk_size; + + const int token_batch = has_cu_seqlens ? 0 : batch_idx; + const Tcompute scale = rsqrtf(static_cast(Dk)); + + // Load initial state. + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t read_idx = indexed_state_pool + ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 + : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + + state_local[i] = static_cast( + loadAsFloat(initial_state, read_idx)); + } + __syncthreads(); + + for (int64_t chunk_begin = token_begin; + chunk_begin < token_end; + chunk_begin += static_cast(chunk_size)) { + + const int64_t remaining = token_end - chunk_begin; + const int actual_len = static_cast( + remaining < static_cast(chunk_size) + ? remaining + : static_cast(chunk_size)); + + // Load beta and cumulative gate. + if (threadIdx.x == 0) { + Tcompute running_g = 0; + for (int t = 0; t < static_cast(chunk_size); ++t) { + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t gate_offset = static_cast(token_batch) * g_s0 + static_cast(token_idx) * g_s1 + static_cast(value_head_idx) * g_s2; + ptrdiff_t beta_offset = static_cast(token_batch) * beta_s0 + static_cast(token_idx) * beta_s1 + static_cast(value_head_idx) * beta_s2; + + running_g += static_cast( + loadAsFloat(g, gate_offset)); + beta_buf[t] = static_cast( + loadAsFloat(beta, beta_offset)); + g_cum[t] = running_g; + } else { + beta_buf[t] = 0; + g_cum[t] = running_g; + } } } - for (size_t i = thread_idx; i < chunk_size * Dk; i += BLOCK_THREADS) { - size_t t_local = i / Dk; - size_t d = i % Dk; - size_t t_global = chunk_offset + t_local; - if (t_global < T) { - ptrdiff_t qk_offset = (batch_idx * H * T * Dk) + (head_idx * T * Dk) + (t_global * Dk) + d; - q_s[i] = static_cast(q[qk_offset]); - k_s[i] = static_cast(k[qk_offset]); + __syncthreads(); + + // Load q/k/v_beta. + CGDR_FOR(x, chunk_size * Dk) { + int t = x / Dk; + int dk = x % Dk; + + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t q_base = static_cast(token_batch) * q_s0 + static_cast(token_idx) * q_s1 + static_cast(key_head_idx) * q_s2; + ptrdiff_t k_base = static_cast(token_batch) * k_s0 + static_cast(token_idx) * k_s1 + static_cast(key_head_idx) * k_s2; + + q_buf[x] = static_cast(loadAsFloat(q, q_base + dk)) * scale; + k_buf[x] = static_cast(loadAsFloat(k, k_base + dk)); } else { - q_s[i] = 0.0f; - k_s[i] = 0.0f; + q_buf[x] = 0; + k_buf[x] = 0; } } - for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { - size_t t_local = i / Dv; - size_t d = i % Dv; - size_t t_global = chunk_offset + t_local; - if (t_global < T) { - ptrdiff_t v_offset = (batch_idx * H * T * Dv) + (head_idx * T * Dv) + (t_global * Dv) + d; - v_s[i] = static_cast(v[v_offset]); + + CGDR_FOR(x, chunk_size * Dv) { + int t = x / Dv; + int dv = x % Dv; + + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t v_base = static_cast(token_batch) * v_s0 + static_cast(token_idx) * v_s1 + static_cast(value_head_idx) * v_s2; + + v_beta[x] = static_cast(loadAsFloat(v, v_base + dv)) * beta_buf[t]; } else { - v_s[i] = 0.0f; + v_beta[x] = 0; } } __syncthreads(); - // --- 2.2: Optional L2 Normalization --- (Unchanged) + // Optional q/k L2 norm. if (use_qk_l2norm) { - // This loop is collapsed for brevity. It is correct and unchanged. - for (size_t t = thread_idx; t < chunk_size; t += BLOCK_THREADS) { - size_t t_global = chunk_offset + t; - if (t_global < T) { - Tcompute q_norm_sq = 0.0f; - Tcompute k_norm_sq = 0.0f; - for (size_t d = 0; d < Dk; ++d) { - Tcompute q_val = q_s[t * Dk + d]; - Tcompute k_val = k_s[t * Dk + d]; - q_norm_sq += q_val * q_val; - k_norm_sq += k_val * k_val; - } - Tcompute r_q_norm = rsqrtf(q_norm_sq + 1e-6f); - Tcompute r_k_norm = rsqrtf(k_norm_sq + 1e-6f); - for (size_t d = 0; d < Dk; ++d) { - q_s[t * Dk + d] *= r_q_norm; - k_s[t * Dk + d] *= r_k_norm; - } + for (int t = 0; t < static_cast(chunk_size); ++t) { + Tcompute q_sum = 0; + Tcompute k_sum = 0; + + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_sum += q_buf[t * Dk + dk] * q_buf[t * Dk + dk]; + k_sum += k_buf[t * Dk + dk] * k_buf[t * Dk + dk]; } - } - __syncthreads(); - } - // --- 2.3 Intra-Chunk Calculations --- (Unchanged, all operate on shared memory) - Tcompute g_val = (thread_idx < chunk_size) ? g_s[thread_idx] : 0.0f; - Tcompute g_cumsum_val; - BlockScan(*cub_temp_storage).InclusiveSum(g_val, g_cumsum_val); - if (thread_idx < chunk_size) { - g_cumsum_s[thread_idx] = g_cumsum_val; - } - __syncthreads(); - for (size_t i = thread_idx; i < chunk_size; i += BLOCK_THREADS) { - Tcompute beta_val = beta_s[i]; - for (size_t d = 0; d < Dk; ++d) { - k_beta_s[i * Dk + d] = k_s[i * Dk + d] * beta_val; - } - for (size_t d = 0; d < Dv; ++d) { - v_s[i * Dv + d] *= beta_val; - } - for (size_t d = 0; d < Dk; ++d) { - q_s[i * Dk + d] *= scale; + q_sum = blockReduceSum(q_sum); + k_sum = blockReduceSum(k_sum); + + Tcompute q_norm = rsqrtf(q_sum / (scale * scale) + 1e-6f); + Tcompute k_norm = rsqrtf(k_sum + 1e-6f); + + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_buf[t * Dk + dk] *= q_norm; + k_buf[t * Dk + dk] *= k_norm; + } + __syncthreads(); } } - __syncthreads(); - for (size_t i = thread_idx; i < chunk_size * chunk_size; i += BLOCK_THREADS) { - size_t row = i / chunk_size; - size_t col = i % chunk_size; - Tcompute dot_prod = 0.0f; - if (col < row) { - for (size_t d = 0; d < Dk; ++d) { dot_prod += k_beta_s[row * Dk + d] * k_s[col * Dk + d]; } - Tcompute decay = expf(g_cumsum_s[row] - g_cumsum_s[col]); - attn_s[i] = -dot_prod * decay; + + // Build lower-triangular attn. + CGDR_FOR(x, chunk_size * chunk_size) { + int i = x / chunk_size; + int j = x % chunk_size; + + if (j < i) { + Tcompute dot = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + dot += k_buf[i * Dk + dk] * beta_buf[i] * k_buf[j * Dk + dk]; + } + attn[x] = -dot * expf(g_cum[i] - g_cum[j]); } else { - attn_s[i] = 0.0f; + attn[x] = 0; } } __syncthreads(); - for (size_t i = 1; i < chunk_size; ++i) { - for (size_t j = thread_idx; j < i; j += BLOCK_THREADS) { - Tcompute update_val = 0.0f; - for (size_t l = 0; l < i; ++l) { update_val += attn_s[i * chunk_size + l] * attn_s[l * chunk_size + j]; } - attn_s[i * chunk_size + j] += update_val; + + // Triangular solve-like correction. + // Sequential in i, parallel in j. + for (int i = 1; i < static_cast(chunk_size); ++i) { + CGDR_FOR(m, chunk_size) { + row_buf[m] = m < i ? attn[i * chunk_size + m] : 0; + } + __syncthreads(); + + for (int j = threadIdx.x; j < i; j += blockDim.x) { + Tcompute correction = 0; + for (int m = 0; m < i; ++m) { + correction += row_buf[m] * attn[m * chunk_size + j]; + } + attn[i * chunk_size + j] = row_buf[j] + correction; } __syncthreads(); } - if (thread_idx < chunk_size) { - attn_s[thread_idx * chunk_size + thread_idx] += 1.0f; + + CGDR_FOR(i, chunk_size) { + attn[i * chunk_size + i] = 1; } __syncthreads(); - for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { - size_t row = i / Dv; - size_t col_v = i % Dv; - Tcompute dot_prod = 0.0f; - for (size_t d = 0; d < chunk_size; ++d) { - dot_prod += attn_s[row * chunk_size + d] * v_s[d * Dv + col_v]; + + // v_mid = attn @ v_beta. + CGDR_FOR(x, chunk_size * Dv) { + int i = x / Dv; + int dv = x % Dv; + + Tcompute sum = 0; + for (int j = 0; j < static_cast(chunk_size); ++j) { + sum += attn[i * chunk_size + j] * v_beta[j * Dv + dv]; } - value_prime_s[i] = dot_prod; + v_mid[x] = sum; } - for (size_t i = thread_idx; i < chunk_size * Dk; i += BLOCK_THREADS) { - size_t row = i / Dk; - int col_k = i % Dk; - Tcompute dot_prod = 0.0f; - for (size_t d = 0; d < chunk_size; ++d) { - dot_prod += attn_s[row * chunk_size + d] * k_beta_s[d * Dk + col_k] * expf(g_cumsum_s[d]); + + // k_cumdecay = attn @ (k * beta * exp(g)). + CGDR_FOR(x, chunk_size * Dk) { + int i = x / Dk; + int dk = x % Dk; + + Tcompute sum = 0; + for (int j = 0; j < static_cast(chunk_size); ++j) { + sum += attn[i * chunk_size + j] * k_buf[j * Dk + dk] * beta_buf[j] * expf(g_cum[j]); } - k_cumdecay_s[i] = dot_prod; + k_cumdecay[x] = sum; } __syncthreads(); - // --- 2.4: Inter-Chunk Interaction --- - // (Correctly reads from global memory) - for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { - size_t row = i / Dv; - size_t col_v = i % Dv; - Tcompute sum = 0.0f; - for (size_t d = 0; d < Dk; ++d) { - Tcompute state_val = (initial_state == nullptr && chunk_idx == 0) ? 0.0f : static_cast(current_state_ptr_g[state_offset + d * Dv + col_v]); - sum += k_cumdecay_s[row * Dk + d] * state_val; + // v_new = v_mid - k_cumdecay @ state. + CGDR_FOR(x, chunk_size * Dv) { + int i = x / Dv; + int dv = x % Dv; + + Tcompute v_prime = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + v_prime += k_cumdecay[i * Dk + dk] * state_local[dk * Dv + dv]; } - v_prime_s[i] = sum; + v_new[x] = v_mid[x] - v_prime; } - for (size_t i = thread_idx; i < chunk_size * Dv; i += BLOCK_THREADS) { - size_t row = i / Dv; - size_t col_v = i % Dv; - Tcompute sum = 0.0f; - Tcompute g_exp = expf(g_cumsum_s[row]); - for (size_t d = 0; d < Dk; ++d) { - Tcompute state_val = (initial_state == nullptr && chunk_idx == 0) ? 0.0f : static_cast(current_state_ptr_g[state_offset + d * Dv + col_v]); - sum += (q_s[row * Dk + d] * g_exp) * state_val; + __syncthreads(); + + // Output. + CGDR_FOR(x, actual_len * Dv) { + int i = x / Dv; + int dv = x % Dv; + + int64_t token_idx = chunk_begin + i; + ptrdiff_t out_base = static_cast(token_batch) * out_s0 + static_cast(token_idx) * out_s1 + static_cast(value_head_idx) * out_s2; + + Tcompute out_val = 0; + Tcompute q_decay = expf(g_cum[i]); + + for (int dk = 0; dk < static_cast(Dk); ++dk) { + out_val += q_buf[i * Dk + dk] * q_decay * state_local[dk * Dv + dv]; + } + + for (int j = 0; j <= i; ++j) { + Tcompute qk_attn = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + qk_attn += q_buf[i * Dk + dk] * k_buf[j * Dk + dk]; + } + + qk_attn *= expf(g_cum[i] - g_cum[j]); + out_val += qk_attn * v_new[j * Dv + dv]; } - attn_inter_s[i] = sum; + + out[out_base + dv] = static_cast(out_val); } __syncthreads(); - // --- 2.5: Final Output Calculation and Writeback --- (Unchanged) - for (size_t t = thread_idx; t < chunk_size; t += BLOCK_THREADS) { - size_t global_t = chunk_offset + t; - if (global_t < T) { - ptrdiff_t out_offset = (batch_idx * H * T * Dv) + (head_idx * T * Dv) + (global_t * Dv); - for (size_t d_v = 0; d_v < Dv; ++d_v) { - Tcompute intra_sum = 0.0f; - for (size_t j = 0; j <= t; ++j) { - Tcompute dot_qk = 0.0f; - for (size_t d_k = 0; d_k < Dk; ++d_k) { - dot_qk += q_s[t * Dk + d_k] * k_s[j * Dk + d_k]; - } - Tcompute value_prime_j = value_prime_s[j * Dv + d_v]; - Tcompute v_prime_j = v_prime_s[j * Dv + d_v]; - Tcompute v_new_j = value_prime_j - v_prime_j; - Tcompute decay = expf(g_cumsum_s[t] - g_cumsum_s[j]); - intra_sum += (dot_qk * decay) * v_new_j; - } - out[out_offset + d_v] = static_cast(attn_inter_s[t * Dv + d_v] + intra_sum); - } + // Update state. + const Tcompute last_decay = expf(g_cum[chunk_size - 1]); + + CGDR_FOR(x, Dk * Dv) { + int dk = x / Dv; + int dv = x % Dv; + + Tcompute next_state = state_local[x] * last_decay; + + for (int i = 0; i < static_cast(chunk_size); ++i) { + next_state += k_buf[i * Dk + dk] * expf(g_cum[chunk_size - 1] - g_cum[i]) * v_new[i * Dv + dv]; } + + state_local[x] = next_state; + } + __syncthreads(); + } + + // Store final state. + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t write_idx; + if (indexed_state_pool) { + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + + write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; + } else { + write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; + } + + final_state_target[write_idx] = static_cast(state_local[i]); + } +} + +template < + typename Tdata, + typename Tgate, + typename Tcompute, + size_t Dk, + size_t Dv, + size_t NUM_THREADS> +__device__ void chunkGatedDeltaRuleRecurrentKernel( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + + const int batch_idx = blockIdx.x; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); + + if (key_head_idx >= static_cast(Hk)) { + return; + } + + int64_t token_begin = 0; + int64_t token_end = static_cast(T); + + if (has_cu_seqlens) { + token_begin = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx, 0); + token_end = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx + 1, 0); + if (token_begin < 0 || token_end < token_begin || token_end > static_cast(T)) { + return; } + } + + int64_t read_slot = batch_idx; + int64_t write_slot = batch_idx; - // --- 2.6: Update inter_chunk_state for the next iteration --- - // (Correctly reads-updates-writes to global memory) + if (indexed_state_pool) { + read_slot = loadOptionalIndex( + initial_state_indices, + initial_state_indices_i64, + batch_idx, + batch_idx); + + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadOptionalIndex( + final_state_indices, + final_state_indices_i64, + batch_idx, + batch_idx); + + if (read_slot < 0 || write_slot < 0 || read_slot >= static_cast(pool_size) || write_slot >= static_cast(pool_size)) { + return; + } + } + + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + + Tdata *final_state_target = nullptr; + ptrdiff_t final_base = 0; + + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } + + const ptrdiff_t workspace_block = (static_cast(batch_idx) * gridDim.y + static_cast(value_head_idx)) * static_cast(Dk * Dv); + + Tcompute *state_local = state_workspace + workspace_block; + __shared__ Tcompute q_vec[Dk]; + __shared__ Tcompute k_vec[Dk]; + __shared__ Tcompute v_new[Dv]; + __shared__ Tcompute scalar_buf[3]; + + const int token_batch = has_cu_seqlens ? 0 : batch_idx; + const Tcompute scale = rsqrtf(static_cast(Dk)); + + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t read_idx = indexed_state_pool + ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 + : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + + state_local[i] = static_cast( + loadAsFloat(initial_state, read_idx)); + } + __syncthreads(); + + for (int64_t token_idx = token_begin; token_idx < token_end; ++token_idx) { + ptrdiff_t q_base = static_cast(token_batch) * q_s0 + static_cast(token_idx) * q_s1 + static_cast(key_head_idx) * q_s2; + ptrdiff_t k_base = static_cast(token_batch) * k_s0 + static_cast(token_idx) * k_s1 + static_cast(key_head_idx) * k_s2; + + Tcompute q_sum = 0; + Tcompute k_sum = 0; + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + Tcompute q_raw = static_cast(loadAsFloat(q, q_base + dk)); + Tcompute k_raw = static_cast(loadAsFloat(k, k_base + dk)); + q_vec[dk] = q_raw; + k_vec[dk] = k_raw; + q_sum += q_raw * q_raw; + k_sum += k_raw * k_raw; + } + + q_sum = blockReduceSum(q_sum); + k_sum = blockReduceSum(k_sum); + + if (threadIdx.x == 0) { + ptrdiff_t gate_offset = static_cast(token_batch) * g_s0 + static_cast(token_idx) * g_s1 + static_cast(value_head_idx) * g_s2; + ptrdiff_t beta_offset = static_cast(token_batch) * beta_s0 + static_cast(token_idx) * beta_s1 + static_cast(value_head_idx) * beta_s2; + + scalar_buf[0] = expf(static_cast(loadAsFloat(g, gate_offset))); + scalar_buf[1] = static_cast(loadAsFloat(beta, beta_offset)); + scalar_buf[2] = use_qk_l2norm + ? rsqrtf(q_sum + static_cast(1e-6)) * scale + : scale; + } __syncthreads(); - Tcompute g_final_cumsum = g_cumsum_s[chunk_size - 1]; - Tcompute final_decay_factor = expf(g_final_cumsum); - Tdata *final_state_ptr = final_state + state_offset; - for (size_t i = thread_idx; i < Dk * Dv; i += BLOCK_THREADS) { - size_t dk = i / Dv; - size_t dv = i % Dv; + const Tcompute decay = scalar_buf[0]; + const Tcompute beta_t = scalar_buf[1]; + const Tcompute q_scale = scalar_buf[2]; + const Tcompute k_scale = use_qk_l2norm + ? rsqrtf(k_sum + static_cast(1e-6)) + : static_cast(1); - Tcompute old_state_val; - if (chunk_idx == 0) { - old_state_val = (initial_state != nullptr) ? static_cast(initial_state[state_offset + i]) : 0.0f; - } else { - old_state_val = static_cast(final_state_ptr[i]); + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_vec[dk] *= q_scale; + k_vec[dk] *= k_scale; + } + __syncthreads(); + + for (int dv = threadIdx.x; dv < static_cast(Dv); dv += blockDim.x) { + Tcompute projected = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + projected += k_vec[dk] * state_local[dk * Dv + dv]; } - Tcompute decayed_state = old_state_val * final_decay_factor; - - Tcompute chunk_contribution = 0.0f; - for (size_t t = 0; t < chunk_size; ++t) { - Tcompute decay_factor = expf(g_final_cumsum - g_cumsum_s[t]); - Tcompute value_prime_t = value_prime_s[t * Dv + dv]; - Tcompute v_prime_t = v_prime_s[t * Dv + dv]; - Tcompute v_new_t = value_prime_t - v_prime_t; - chunk_contribution += (k_s[t * Dk + dk] * decay_factor) * v_new_t; + + ptrdiff_t v_base = static_cast(token_batch) * v_s0 + static_cast(token_idx) * v_s1 + static_cast(value_head_idx) * v_s2; + Tcompute v_raw = static_cast(loadAsFloat(v, v_base + dv)); + v_new[dv] = beta_t * (v_raw - decay * projected); + } + __syncthreads(); + + CGDR_FOR(x, Dk * Dv) { + int dk = x / Dv; + int dv = x % Dv; + state_local[x] = decay * state_local[x] + k_vec[dk] * v_new[dv]; + } + __syncthreads(); + + for (int dv = threadIdx.x; dv < static_cast(Dv); dv += blockDim.x) { + Tcompute out_val = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + out_val += q_vec[dk] * state_local[dk * Dv + dv]; } - final_state_ptr[i] = static_cast(decayed_state + chunk_contribution); + ptrdiff_t out_base = static_cast(token_batch) * out_s0 + static_cast(token_idx) * out_s1 + static_cast(value_head_idx) * out_s2; + out[out_base + dv] = static_cast(out_val); + } + __syncthreads(); + } + + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t write_idx; + if (indexed_state_pool) { + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + + write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; + } else { + write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; } - // BUG FIX: Add a block-wide memory fence to ensure global memory writes from this - // iteration are visible to all threads before the next iteration begins. - __threadfence_block(); + final_state_target[write_idx] = static_cast(state_local[i]); } } -#endif // __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ \ No newline at end of file +#endif // __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/info.h b/src/infiniop/ops/chunk_gated_delta_rule/info.h index 922153d0d..4dcb25319 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/info.h +++ b/src/infiniop/ops/chunk_gated_delta_rule/info.h @@ -5,7 +5,6 @@ #include "../../../utils.h" #include "../../tensor.h" -#include #include namespace op { @@ -15,56 +14,188 @@ class ChunkGatedDeltaRuleInfo { ChunkGatedDeltaRuleInfo() = default; public: - // --- Data Types and Flags --- - infiniDtype_t dtype; - bool use_qk_l2norm; - - // --- Shape Dimensions --- - size_t B, H, T, Dk, Dv, chunk_size; + infiniDtype_t data_dtype; + infiniDtype_t gate_dtype; + infiniDtype_t cu_seqlens_dtype; + infiniDtype_t initial_state_indices_dtype; + infiniDtype_t final_state_indices_dtype; - // --- Strides for Memory Layout --- - // Strides can be added here if needed for more complex layouts + bool use_qk_l2norm; + bool has_cu_seqlens; + bool has_initial_state_indices; + bool has_final_state_indices; + bool indexed_state_pool; + + size_t B, T, total_tokens, Hk, Hv, Dk, Dv, chunk_size, pool_size, value_heads_per_key_head; + + std::vector out_strides; + std::vector initial_state_strides; + std::vector final_state_strides; + std::vector q_strides; + std::vector k_strides; + std::vector v_strides; + std::vector g_strides; + std::vector beta_strides; static utils::Result create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - const std::optional &initial_state_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm, size_t chunk_size) { - auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (out_desc == nullptr || initial_state_desc == nullptr || q_desc == nullptr || k_desc == nullptr || v_desc == nullptr || g_desc == nullptr || beta_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } - // Check for consistent data types across all tensors - if (out_desc->dtype() != dtype || final_state_desc->dtype() != dtype || k_desc->dtype() != dtype || v_desc->dtype() != dtype || g_desc->dtype() != dtype || beta_desc->dtype() != dtype) { + if (chunk_size == 0) { + return INFINI_STATUS_BAD_PARAM; + } + + auto data_dtype = q_desc->dtype(); + CHECK_DTYPE(data_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (k_desc->dtype() != data_dtype || v_desc->dtype() != data_dtype || out_desc->dtype() != data_dtype || initial_state_desc->dtype() != data_dtype || (final_state_desc != nullptr && final_state_desc->dtype() != data_dtype)) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - // Check tensor dimensions - if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || out_desc->ndim() != 4 || final_state_desc->ndim() != 4) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; + auto gate_dtype = g_desc->dtype(); + CHECK_DTYPE(gate_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (beta_desc->dtype() != gate_dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; } - ChunkGatedDeltaRuleInfo info; - info.dtype = dtype; - info.use_qk_l2norm = use_qk_l2norm; - info.chunk_size = chunk_size; + bool has_cu = cu_seqlens_desc != nullptr; + bool has_initial_indices = initial_state_indices_desc != nullptr; + bool has_final_indices = final_state_indices_desc != nullptr; + bool indexed_pool = has_initial_indices || has_final_indices; + + if (has_final_indices && final_state_desc != nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + if (!has_final_indices && final_state_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || out_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || (!has_final_indices && final_state_desc->ndim() != 4)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } auto q_shape = q_desc->shape(); - info.B = q_shape[0]; - info.H = q_shape[1]; - info.T = q_shape[2]; - info.Dk = q_shape[3]; + auto k_shape = k_desc->shape(); + auto v_shape = v_desc->shape(); + auto out_shape = out_desc->shape(); + auto g_shape = g_desc->shape(); + auto beta_shape = beta_desc->shape(); + + size_t B = q_shape[0], T = q_shape[1], Hk = q_shape[2], Dk = q_shape[3]; + size_t Hv = v_shape[2], Dv = v_shape[3], total_tokens = T; + + if (has_cu) { + if (cu_seqlens_desc->ndim() != 1 || cu_seqlens_desc->shape()[0] < 2) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + B = cu_seqlens_desc->shape()[0] - 1; + if (q_shape[0] != 1 || k_shape[0] != 1 || v_shape[0] != 1 || out_shape[0] != 1 || g_shape[0] != 1 || beta_shape[0] != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + total_tokens = q_shape[1]; + T = total_tokens; + } - info.Dv = v_desc->shape()[3]; + if (k_shape[0] != q_shape[0] || k_shape[1] != q_shape[1] || k_shape[2] != Hk || k_shape[3] != Dk || v_shape[0] != q_shape[0] || v_shape[1] != q_shape[1] || out_shape[0] != q_shape[0] || out_shape[1] != q_shape[1] || out_shape[2] != Hv || out_shape[3] != Dv || g_shape[0] != q_shape[0] || g_shape[1] != q_shape[1] || g_shape[2] != Hv || beta_shape[0] != q_shape[0] || beta_shape[1] != q_shape[1] || beta_shape[2] != Hv || Hk == 0 || Hv == 0 || Hv % Hk != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (q_desc->strides()[3] != 1 || k_desc->strides()[3] != 1 || v_desc->strides()[3] != 1 || out_desc->strides()[3] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + auto initial_shape = initial_state_desc->shape(); + size_t pool_size = initial_shape[0]; + if (indexed_pool) { + if (initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (!has_final_indices) { + auto final_shape = final_state_desc->shape(); + if (indexed_pool) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + } - // Further validation can be added here to ensure all shapes are compatible. - // For example, check if initial_state_desc shape is [B, H, Dk, Dv]. + infiniDtype_t cu_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t initial_indices_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t final_indices_dtype = INFINI_DTYPE_INVALID; + if (has_cu) { + cu_dtype = cu_seqlens_desc->dtype(); + CHECK_DTYPE(cu_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_initial_indices) { + if (initial_state_indices_desc->ndim() != 1 || initial_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + initial_indices_dtype = initial_state_indices_desc->dtype(); + CHECK_DTYPE(initial_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_final_indices) { + if (final_state_indices_desc->ndim() != 1 || final_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + final_indices_dtype = final_state_indices_desc->dtype(); + CHECK_DTYPE(final_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + + ChunkGatedDeltaRuleInfo info; + info.data_dtype = data_dtype; + info.gate_dtype = gate_dtype; + info.cu_seqlens_dtype = cu_dtype; + info.initial_state_indices_dtype = initial_indices_dtype; + info.final_state_indices_dtype = final_indices_dtype; + info.use_qk_l2norm = use_qk_l2norm; + info.has_cu_seqlens = has_cu; + info.has_initial_state_indices = has_initial_indices; + info.has_final_state_indices = has_final_indices; + info.indexed_state_pool = indexed_pool; + info.B = B; + info.T = T; + info.total_tokens = total_tokens; + info.Hk = Hk; + info.Hv = Hv; + info.Dk = Dk; + info.Dv = Dv; + info.chunk_size = chunk_size; + info.pool_size = pool_size; + info.value_heads_per_key_head = Hv / Hk; + info.out_strides = out_desc->strides(); + info.initial_state_strides = initial_state_desc->strides(); + if (final_state_desc != nullptr) { + info.final_state_strides = final_state_desc->strides(); + } + info.q_strides = q_desc->strides(); + info.k_strides = k_desc->strides(); + info.v_strides = v_desc->strides(); + info.g_strides = g_desc->strides(); + info.beta_strides = beta_desc->strides(); return utils::Result(info); } @@ -73,4 +204,4 @@ class ChunkGatedDeltaRuleInfo { } // namespace chunk_gated_delta_rule } // namespace op -#endif // __CHUNK_GATED_DELTA_RULE_INFO_H__ \ No newline at end of file +#endif // __CHUNK_GATED_DELTA_RULE_INFO_H__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu index f5f96e859..461f95ad5 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu @@ -1,22 +1,130 @@ -// chunk_gated_delta_rule_nvidia.cu - #include "../../../devices/nvidia/nvidia_common.cuh" -#include "chunk_gated_delta_rule_nvidia.cuh" - #include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "chunk_gated_delta_rule_nvidia.cuh" #include "../cuda/kernel.cuh" #include -// Kernel Launcher Wrapper -template +template INFINIOP_CUDA_KERNEL chunkGatedDeltaRule( - Tdata *out, Tdata *final_state, - const Tdata *q, const Tdata *k, const Tdata *v, - const Tdata *g, const Tdata *beta, const Tdata *initial_state, - bool use_qk_l2norm, size_t chunk_size, size_t T) { - chunkGatedDeltaRuleKernel( - out, final_state, q, k, v, g, beta, initial_state, use_qk_l2norm, chunk_size, T); + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + chunkGatedDeltaRuleRecurrentKernel( + state_workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, cu_seqlens_i64, + initial_state_indices_i64, final_state_indices_i64, use_qk_l2norm, + has_cu_seqlens, indexed_state_pool, T, chunk_size, pool_size, Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, q_s0, q_s1, q_s2, k_s0, k_s1, + k_s2, v_s0, v_s1, v_s2, g_s0, g_s1, g_s2, beta_s0, beta_s1, beta_s2); +} + +template +INFINIOP_CUDA_KERNEL chunkGatedDeltaRuleChunked( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + chunkGatedDeltaRuleKernel( + state_workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, cu_seqlens_i64, + initial_state_indices_i64, final_state_indices_i64, use_qk_l2norm, + has_cu_seqlens, indexed_state_pool, T, chunk_size, pool_size, Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, q_s0, q_s1, q_s2, k_s0, k_s1, + k_s2, v_s0, v_s1, v_s2, g_s0, g_s1, g_s2, beta_s0, beta_s1, beta_s2); } namespace op { @@ -35,147 +143,163 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - const std::optional &initial_state_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm, size_t chunk_size) { auto info = ChunkGatedDeltaRuleInfo::create( - out_desc, final_state_desc, q_desc, k_desc, v_desc, - g_desc, beta_desc, initial_state_desc, use_qk_l2norm, chunk_size); + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, cu_seqlens_desc, initial_state_indices_desc, + final_state_indices_desc, use_qk_l2norm, chunk_size); CHECK_RESULT(info); - // Calculate workspace size if needed, here it's 0 - size_t workspace_size = 0; + auto info_value = info.take(); + // We always want to use fast path, slow path is kept as a ref + const bool use_chunked_fallback = false; + const size_t per_block_workspace = use_chunked_fallback + ? info_value.Dk * info_value.Dv + info_value.chunk_size * info_value.Dk * 3 + info_value.chunk_size * info_value.Dv * 3 + info_value.chunk_size * info_value.chunk_size + info_value.chunk_size * 3 + : info_value.Dk * info_value.Dv; + const size_t workspace_size = info_value.B * info_value.Hv * per_block_workspace * sizeof(float); *desc_ptr = new Descriptor( new Opaque{reinterpret_cast(handle)->internal()}, - info.take(), workspace_size, handle->device, handle->device_id); + info_value, workspace_size, handle->device, handle->device_id); - return infiniStatus_t::INFINI_STATUS_SUCCESS; + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernelWithGateDtype( + void *workspace, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + const ChunkGatedDeltaRuleInfo &info, + cudaStream_t stream) { + +#define LAUNCH_ARGS(TYPE) \ + static_cast(workspace), static_cast(out), static_cast(initial_state), static_cast(final_state), \ + static_cast(q), static_cast(k), static_cast(v), \ + static_cast(g), static_cast(beta), cu_seqlens, initial_state_indices, \ + final_state_indices, info.cu_seqlens_dtype == INFINI_DTYPE_I64, \ + info.initial_state_indices_dtype == INFINI_DTYPE_I64, info.final_state_indices_dtype == INFINI_DTYPE_I64, \ + info.use_qk_l2norm, info.has_cu_seqlens, info.indexed_state_pool, info.T, info.chunk_size, info.pool_size, info.Hk, \ + info.value_heads_per_key_head, info.out_strides[0], info.out_strides[1], info.out_strides[2], \ + info.initial_state_strides[0], info.initial_state_strides[1], info.initial_state_strides[2], \ + info.initial_state_strides[3], info.final_state_strides.empty() ? 0 : info.final_state_strides[0], \ + info.final_state_strides.empty() ? 0 : info.final_state_strides[1], info.final_state_strides.empty() ? 0 : info.final_state_strides[2], \ + info.final_state_strides.empty() ? 0 : info.final_state_strides[3], info.q_strides[0], info.q_strides[1], \ + info.q_strides[2], info.k_strides[0], info.k_strides[1], info.k_strides[2], info.v_strides[0], \ + info.v_strides[1], info.v_strides[2], info.g_strides[0], info.g_strides[1], info.g_strides[2], \ + info.beta_strides[0], info.beta_strides[1], info.beta_strides[2] + +#define LAUNCH_GATE(TYPE) \ + do { \ + if (false) { \ + chunkGatedDeltaRuleChunked \ + <<>>(LAUNCH_ARGS(TYPE)); \ + } else { \ + chunkGatedDeltaRule \ + <<>>(LAUNCH_ARGS(TYPE)); \ + } \ + } while (0) + + if (info.gate_dtype == INFINI_DTYPE_F16) { + LAUNCH_GATE(half); + } else if (info.gate_dtype == INFINI_DTYPE_BF16) { + LAUNCH_GATE(__nv_bfloat16); + } else if (info.gate_dtype == INFINI_DTYPE_F32) { + LAUNCH_GATE(float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +#undef LAUNCH_GATE +#undef LAUNCH_ARGS + + return INFINI_STATUS_SUCCESS; } template infiniStatus_t launchKernel( - void *out, void *final_state, - const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, - bool use_qk_l2norm, - infiniDtype_t dtype, - size_t B, size_t H, size_t T, size_t chunk_size, + void *workspace, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + const ChunkGatedDeltaRuleInfo &info, cudaStream_t stream) { - dim3 grid(uint32_t(B), uint32_t(H), 1); - dim3 block(NUM_THREADS); - // Shared memory for local Q, K, and one reduction value - // size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); - - using Tcompute = float; - using BlockScan = cub::BlockScan; - // using BlockReduce = cub::BlockReduce; - - // size_t shared_mem_size = ( - // chunk_size * (3 * Dk + Dv + 3) + - // chunk_size * chunk_size + - // Dk * Dv - // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); - // size_t shared_mem_size = ( - // // q_s, k_s, k_beta_s, k_cumdecay_s - // chunk_size * 4 * Dk + - // // v_s, value_prime_s, v_prime_s, attn_inter_s - // chunk_size * 4 * Dv + - // // g_s, beta_s, g_cumsum_s - // chunk_size * 3 + - // // attn_s (removed decay_mask_s) - // chunk_size * chunk_size + - // // inter_chunk_state_s - // Dk * Dv - // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); - - // size_t shared_mem_size = ( - // // q_s, k_s, k_beta_s, k_cumdecay_s - // chunk_size * 4 * Dk + - // // v_s, value_prime_s, v_prime_s (v_new_s is still here from prev version) - // chunk_size * 4 * Dv + - // // g_s, beta_s, g_cumsum_s - // chunk_size * 3 + - // // attn_s - // chunk_size * chunk_size + - // // inter_chunk_state_s - // Dk * Dv - // ) * sizeof(Tcompute) + sizeof(typename BlockScan::TempStorage) + sizeof(typename BlockReduce::TempStorage); - size_t shared_mem_size = ( - // q_s, k_s, k_beta_s, k_cumdecay_s - chunk_size * 4 * Dk + - // v_s, value_prime_s, v_prime_s, attn_inter_s - chunk_size * 4 * Dv + - // g_s, beta_s, g_cumsum_s - chunk_size * 3 + - // attn_s - chunk_size * chunk_size - // NOTE: Dk * Dv term for inter_chunk_state_s has been removed. - ) - * sizeof(Tcompute) - + sizeof(typename BlockScan::TempStorage); - - if (dtype == INFINI_DTYPE_F16) { - chunkGatedDeltaRule - <<>>( - (half *)out, (half *)final_state, - (const half *)q, (const half *)k, (const half *)v, - (const half *)g, (const half *)beta, (const half *)initial_state, - use_qk_l2norm, chunk_size, T); - } else if (dtype == INFINI_DTYPE_BF16) { - chunkGatedDeltaRule<__nv_bfloat16, float, Dk, Dv, NUM_THREADS> - <<>>( - (__nv_bfloat16 *)out, (__nv_bfloat16 *)final_state, - (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, - (const __nv_bfloat16 *)g, (const __nv_bfloat16 *)beta, (const __nv_bfloat16 *)initial_state, - use_qk_l2norm, chunk_size, T); - } else if (dtype == INFINI_DTYPE_F32) { - chunkGatedDeltaRule - <<>>( - (float *)out, (float *)final_state, - (const float *)q, (const float *)k, (const float *)v, - (const float *)g, (const float *)beta, (const float *)initial_state, - use_qk_l2norm, chunk_size, T); - } else { - return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + if (info.data_dtype == INFINI_DTYPE_F16) { + return launchKernelWithGateDtype( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); + } + if (info.data_dtype == INFINI_DTYPE_BF16) { + return launchKernelWithGateDtype<__nv_bfloat16, Dk, Dv, NUM_THREADS>( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); } - return infiniStatus_t::INFINI_STATUS_SUCCESS; + if (info.data_dtype == INFINI_DTYPE_F32) { + return launchKernelWithGateDtype( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; } infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *out, void *final_state, + void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, + const void *g, const void *beta, const void *cu_seqlens, + const void *initial_state_indices, const void *final_state_indices, void *stream_) const { cudaStream_t stream = (cudaStream_t)stream_; + if (workspace == nullptr || workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } - // Specialize for common shapes and thread counts if (_info.Dk == 128 && _info.Dv == 128) { if (_opaque->internal->maxThreadsPerBlock() >= 128) { return launchKernel<128, 128, 128>( - out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, - _info.dtype, _info.B, _info.H, _info.T, _info.chunk_size, stream); + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, _info, stream); } } else if (_info.Dk == 64 && _info.Dv == 64) { if (_opaque->internal->maxThreadsPerBlock() >= 64) { return launchKernel<64, 64, 64>( - out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, - _info.dtype, _info.B, _info.H, _info.T, _info.chunk_size, stream); + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, _info, stream); } } - // Fallback or error for unsupported shapes - // You can add more specializations for other shapes here. - return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_SHAPE; + return INFINI_STATUS_BAD_TENSOR_SHAPE; } } // namespace nvidia } // namespace chunk_gated_delta_rule -} // namespace op \ No newline at end of file +} // namespace op diff --git a/src/infiniop/ops/chunk_gated_delta_rule/operator.cc b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc index 776deb94a..631d9ca44 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/operator.cc +++ b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc @@ -12,18 +12,19 @@ __INFINI_C infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( infiniopHandle_t handle, infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm, size_t chunk_size) { - std::optional initial_state_opt = (initial_state_desc == nullptr) ? std::nullopt : std::optional(initial_state_desc); - #define CREATE(CASE, NAMESPACE) \ case CASE: \ return op::chunk_gated_delta_rule::NAMESPACE::Descriptor::create( \ @@ -31,8 +32,10 @@ __INFINI_C infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( reinterpret_cast< \ op::chunk_gated_delta_rule::NAMESPACE::Descriptor **>( \ desc_ptr), \ - out_desc, final_state_desc, q_desc, k_desc, v_desc, g_desc, \ - beta_desc, initial_state_opt, use_qk_l2norm, chunk_size); + out_desc, initial_state_desc, final_state_desc, q_desc, \ + k_desc, v_desc, g_desc, beta_desc, cu_seqlens_desc, \ + initial_state_indices_desc, final_state_indices_desc, \ + use_qk_l2norm, chunk_size); switch (handle->device) { #ifdef ENABLE_NVIDIA_API @@ -68,16 +71,18 @@ __INFINI_C infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize( __INFINI_C infiniStatus_t infiniopChunkGatedDeltaRule( infiniopChunkGatedDeltaRuleDescriptor_t desc, void *workspace, size_t workspace_size, - void *out, void *final_state, + void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, + const void *g, const void *beta, const void *cu_seqlens, + const void *initial_state_indices, const void *final_state_indices, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast< \ op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ - ->calculate(workspace, workspace_size, out, final_state, q, k, v, \ - g, beta, initial_state, stream); + ->calculate(workspace, workspace_size, out, initial_state, \ + final_state, q, k, v, g, beta, cu_seqlens, \ + initial_state_indices, final_state_indices, stream); switch (desc->device_type) { #ifdef ENABLE_NVIDIA_API @@ -107,4 +112,4 @@ __INFINI_C infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } #undef DESTROY -} \ No newline at end of file +} diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh index a8d68c61e..db9161626 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh +++ b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh @@ -4,66 +4,148 @@ #define __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ #include +#include +#include #include -// Tdata: (e.g., half) -// Tcompute: (e.g., float) -template + +__device__ inline int64_t loadStateIndex( + const void *indices, + bool is_i64, + int batch_idx, + int fallback) { + if (indices == nullptr) { + return static_cast(fallback); + } + return is_i64 + ? static_cast(indices)[batch_idx] + : static_cast(static_cast(indices)[batch_idx]); +} + +template +__device__ inline float loadAsFloat(const T *ptr, ptrdiff_t offset) { + return static_cast(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { + return __half2float(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { + return __bfloat162float(ptr[offset]); +} + +template __device__ void recurrentGatedDeltaRuleKernel( Tdata *out, + Tdata *initial_state, Tdata *final_state, const Tdata *q, const Tdata *k, const Tdata *v, - const Tdata *g, - const Tdata *beta, - const Tdata *initial_state, - bool use_qk_l2norm) { + const Tgate *g, + const Tgate *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool indexed_state_pool, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { const int batch_idx = blockIdx.x; - const int head_idx = blockIdx.y; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); const int thread_idx = threadIdx.x; - // T=1 for decode stage, so seq_idx is always 0 - const int seq_idx = 0; - - const size_t H = gridDim.y; - const size_t base_offset_qkv = (batch_idx * H + head_idx) * Dk; // T=1, Dk=Dv for simplicity now - const size_t base_offset_gb = (batch_idx * H + head_idx); // T=1 - const size_t state_offset = (batch_idx * H + head_idx) * Dk * Dv; + if (key_head_idx >= static_cast(Hk)) { + return; + } - const Tdata *q_ptr = q + base_offset_qkv; - const Tdata *k_ptr = k + base_offset_qkv; - const Tdata *v_ptr = v + base_offset_qkv; // Assuming Dv = Dk - const Tdata *g_ptr = g + base_offset_gb; - const Tdata *beta_ptr = beta + base_offset_gb; - const Tdata *initial_state_ptr = initial_state + state_offset; + constexpr int seq_idx = 0; + const ptrdiff_t q_base = static_cast(batch_idx) * q_s0 + seq_idx * q_s1 + static_cast(key_head_idx) * q_s2; + const ptrdiff_t k_base = static_cast(batch_idx) * k_s0 + seq_idx * k_s1 + static_cast(key_head_idx) * k_s2; + const ptrdiff_t v_base = static_cast(batch_idx) * v_s0 + seq_idx * v_s1 + static_cast(value_head_idx) * v_s2; + const ptrdiff_t out_base = static_cast(batch_idx) * out_s0 + seq_idx * out_s1 + static_cast(value_head_idx) * out_s2; + const ptrdiff_t gate_offset = static_cast(batch_idx) * g_s0 + seq_idx * g_s1 + static_cast(value_head_idx) * g_s2; + const ptrdiff_t beta_offset = static_cast(batch_idx) * beta_s0 + seq_idx * beta_s1 + static_cast(value_head_idx) * beta_s2; + + int64_t read_slot = static_cast(batch_idx); + int64_t write_slot = static_cast(batch_idx); + if (indexed_state_pool) { + read_slot = loadStateIndex(initial_state_indices, initial_state_indices_i64, batch_idx, batch_idx); + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadStateIndex(final_state_indices, final_state_indices_i64, batch_idx, batch_idx); + if (read_slot < 0 || write_slot < 0) { + for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { + out[out_base + dv_idx] = static_cast(0.0f); + } + return; + } + } - Tdata *out_ptr = out + base_offset_qkv; - Tdata *final_state_ptr = final_state + state_offset; + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + ptrdiff_t final_base = 0; + Tdata *final_state_target = nullptr; + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else if (indexed_state_pool) { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } extern __shared__ char shared_mem_char[]; Tcompute *shared_mem = reinterpret_cast(shared_mem_char); - // shared memory layout: q_local[Dk], k_local[Dk], norm_val[1] Tcompute *q_local = shared_mem; Tcompute *k_local = q_local + Dk; - Tcompute *norm_val = k_local + Dk; // for reduction + Tcompute *norm_val = k_local + Dk; - // 1. Load Q and K into shared memory and optionally normalize - // Load for (int i = thread_idx; i < Dk; i += NUM_THREADS) { - q_local[i] = static_cast(q_ptr[i]); - k_local[i] = static_cast(k_ptr[i]); + q_local[i] = static_cast(loadAsFloat(q, q_base + i)); + k_local[i] = static_cast(loadAsFloat(k, k_base + i)); } if (use_qk_l2norm) { __syncthreads(); - // Parallel reduction to compute L2 norm for Q Tcompute sum_sq = 0.0f; for (int i = thread_idx; i < Dk; i += NUM_THREADS) { sum_sq += q_local[i] * q_local[i]; } - // Simplified reduction, for real use CUB will be better - // This part needs a proper block-wide reduction implementation norm_val[thread_idx] = sum_sq; __syncthreads(); if (thread_idx == 0) { @@ -76,12 +158,10 @@ __device__ void recurrentGatedDeltaRuleKernel( __syncthreads(); Tcompute r_norm_q = norm_val[0]; - // Normalize Q for (int i = thread_idx; i < Dk; i += NUM_THREADS) { q_local[i] *= r_norm_q; } - // Repeat for K sum_sq = 0.0f; for (int i = thread_idx; i < Dk; i += NUM_THREADS) { sum_sq += k_local[i] * k_local[i]; @@ -98,16 +178,14 @@ __device__ void recurrentGatedDeltaRuleKernel( __syncthreads(); Tcompute r_norm_k = norm_val[0]; - // Normalize K for (int i = thread_idx; i < Dk; i += NUM_THREADS) { k_local[i] *= r_norm_k; } __syncthreads(); } - // 2. Perform the recurrent update logic - Tcompute g_t = expf(static_cast(*g_ptr)); - Tcompute beta_t = static_cast(*beta_ptr); + Tcompute g_t = expf(static_cast(loadAsFloat(g, gate_offset))); + Tcompute beta_t = static_cast(loadAsFloat(beta, beta_offset)); Tcompute scale = rsqrtf(static_cast(Dk)); for (int i = thread_idx; i < Dk; i += NUM_THREADS) { @@ -115,40 +193,34 @@ __device__ void recurrentGatedDeltaRuleKernel( } __syncthreads(); - // Loop over Dv, each thread computes an element of the delta and output vector for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { Tcompute kv_mem = 0.0f; - // Calculate kv_mem = sum(h_{t-1} * k_t) for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { - Tcompute h_prev = static_cast(initial_state_ptr[dk_idx * Dv + dv_idx]); + ptrdiff_t state_idx = indexed_state_pool + ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 + : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; + Tcompute h_prev = static_cast(loadAsFloat(initial_state, state_idx)); kv_mem += (h_prev * g_t) * k_local[dk_idx]; } - Tcompute v_t = static_cast(v_ptr[dv_idx]); + Tcompute v_t = static_cast(loadAsFloat(v, v_base + dv_idx)); Tcompute delta = (v_t - kv_mem) * beta_t; + Tcompute out_val = 0.0f; - // Calculate final state h_t = h_{t-1} * g + k_t * delta - // And write it back for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { - Tcompute h_prev = static_cast(initial_state_ptr[dk_idx * Dv + dv_idx]); + ptrdiff_t read_state_idx = indexed_state_pool + ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 + : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; + ptrdiff_t write_state_idx = indexed_state_pool + ? final_base + static_cast(dv_idx) * (final_state_indices != nullptr ? initial_s2 : final_s2) + static_cast(dk_idx) * (final_state_indices != nullptr ? initial_s3 : final_s3) + : final_base + static_cast(dk_idx) * final_s2 + static_cast(dv_idx) * final_s3; + Tcompute h_prev = static_cast(loadAsFloat(initial_state, read_state_idx)); Tcompute h_final = (h_prev * g_t) + (k_local[dk_idx] * delta); - final_state_ptr[dk_idx * Dv + dv_idx] = static_cast(h_final); - } - - // Calculate output o_t = sum(h_t * q_t) - // This requires another reduction. For simplicity, we assume one thread calculates one output element. - // A more optimized version would have all threads collaborating. - } - __syncthreads(); // Ensure final_state is fully written - - // All threads collaborate to compute the final output vector - for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { - Tcompute out_val = 0.0f; - for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { - out_val += static_cast(final_state_ptr[dk_idx * Dv + dv_idx]) * q_local[dk_idx]; + out_val += h_final * q_local[dk_idx]; + final_state_target[write_state_idx] = static_cast(h_final); } - out_ptr[dv_idx] = static_cast(out_val); + out[out_base + dv_idx] = static_cast(out_val); } } -#endif // __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ \ No newline at end of file +#endif // __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/info.h b/src/infiniop/ops/recurrent_gated_delta_rule/info.h index bc6c523ad..51644964a 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/info.h +++ b/src/infiniop/ops/recurrent_gated_delta_rule/info.h @@ -5,7 +5,6 @@ #include "../../../utils.h" #include "../../tensor.h" -#include #include namespace op { @@ -15,54 +14,159 @@ class RecurrentGatedDeltaRuleInfo { RecurrentGatedDeltaRuleInfo() = default; public: - // --- Data Types and Flags --- - infiniDtype_t dtype; + infiniDtype_t data_dtype; + infiniDtype_t gate_dtype; + infiniDtype_t initial_state_indices_dtype; + infiniDtype_t final_state_indices_dtype; bool use_qk_l2norm; + bool has_initial_state_indices; + bool has_final_state_indices; + bool indexed_state_pool; - // --- Shape Dimensions --- - size_t B, H, T, Dk, Dv; + size_t B, Hk, Hv, T, Dk, Dv, pool_size, value_heads_per_key_head; - // --- Strides for Memory Layout --- - // Strides can be added here if needed for more complex layouts + std::vector out_strides; + std::vector initial_state_strides; + std::vector final_state_strides; + std::vector q_strides; + std::vector k_strides; + std::vector v_strides; + std::vector g_strides; + std::vector beta_strides; static utils::Result create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm) { - auto dtype = q_desc->dtype(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (out_desc == nullptr || initial_state_desc == nullptr || q_desc == nullptr || k_desc == nullptr || v_desc == nullptr || g_desc == nullptr || beta_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } - // Check for consistent data types across all tensors - if (out_desc->dtype() != dtype || final_state_desc->dtype() != dtype || k_desc->dtype() != dtype || v_desc->dtype() != dtype || g_desc->dtype() != dtype || beta_desc->dtype() != dtype || initial_state_desc->dtype() != dtype) { + auto data_dtype = q_desc->dtype(); + CHECK_DTYPE(data_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (k_desc->dtype() != data_dtype || v_desc->dtype() != data_dtype || out_desc->dtype() != data_dtype || initial_state_desc->dtype() != data_dtype || (final_state_desc != nullptr && final_state_desc->dtype() != data_dtype)) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - // Check tensor dimensions - if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || out_desc->ndim() != 4 || final_state_desc->ndim() != 4) { + auto gate_dtype = g_desc->dtype(); + CHECK_DTYPE(gate_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (beta_desc->dtype() != gate_dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + bool has_initial_indices = initial_state_indices_desc != nullptr; + bool has_final_indices = final_state_indices_desc != nullptr; + bool indexed_pool = has_initial_indices || has_final_indices; + + if (has_final_indices && final_state_desc != nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + if (!has_final_indices && final_state_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || out_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || (!has_final_indices && final_state_desc->ndim() != 4)) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - RecurrentGatedDeltaRuleInfo info; - info.dtype = dtype; - info.use_qk_l2norm = use_qk_l2norm; + auto q_shape = q_desc->shape(); // [B, T, Hk, Dk] + auto k_shape = k_desc->shape(); // [B, T, Hk, Dk] + auto v_shape = v_desc->shape(); // [B, T, Hv, Dv] + auto out_shape = out_desc->shape(); // [B, T, Hv, Dv] + auto g_shape = g_desc->shape(); // [B, T, Hv] + auto beta_shape = beta_desc->shape(); // [B, T, Hv] + + size_t B = q_shape[0], T = q_shape[1], Hk = q_shape[2], Dk = q_shape[3]; + size_t Hv = v_shape[2], Dv = v_shape[3]; - auto q_shape = q_desc->shape(); - info.B = q_shape[0]; - info.H = q_shape[1]; - info.T = q_shape[2]; - info.Dk = q_shape[3]; + if (T != 1 || k_shape[0] != B || k_shape[1] != T || k_shape[2] != Hk || k_shape[3] != Dk || v_shape[0] != B || v_shape[1] != T || out_shape[0] != B || out_shape[1] != T || out_shape[2] != Hv || out_shape[3] != Dv || g_shape[0] != B || g_shape[1] != T || g_shape[2] != Hv || beta_shape[0] != B || beta_shape[1] != T || beta_shape[2] != Hv || Hk == 0 || Hv == 0 || Hv % Hk != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (q_desc->strides()[3] != 1 || k_desc->strides()[3] != 1 || v_desc->strides()[3] != 1 || out_desc->strides()[3] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } - info.Dv = v_desc->shape()[3]; + auto initial_shape = initial_state_desc->shape(); + size_t pool_size = initial_shape[0]; + if (indexed_pool) { + // Indexed pool layout is [pool_size, Hv, Dv, Dk]. + if (initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + // Legacy layout is [B, Hv, Dk, Dv]. + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (!has_final_indices) { + auto final_shape = final_state_desc->shape(); + if (indexed_pool) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + } - // Further validation can be added here to ensure all shapes are compatible. - // For example, check if initial_state_desc shape is [B, H, Dk, Dv]. + infiniDtype_t initial_indices_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t final_indices_dtype = INFINI_DTYPE_INVALID; + if (has_initial_indices) { + if (initial_state_indices_desc->ndim() != 1 || initial_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + initial_indices_dtype = initial_state_indices_desc->dtype(); + CHECK_DTYPE(initial_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_final_indices) { + if (final_state_indices_desc->ndim() != 1 || final_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + final_indices_dtype = final_state_indices_desc->dtype(); + CHECK_DTYPE(final_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + + RecurrentGatedDeltaRuleInfo info; + info.data_dtype = data_dtype; + info.gate_dtype = gate_dtype; + info.initial_state_indices_dtype = initial_indices_dtype; + info.final_state_indices_dtype = final_indices_dtype; + info.use_qk_l2norm = use_qk_l2norm; + info.has_initial_state_indices = has_initial_indices; + info.has_final_state_indices = has_final_indices; + info.indexed_state_pool = indexed_pool; + info.B = B; + info.Hk = Hk; + info.Hv = Hv; + info.T = T; + info.Dk = Dk; + info.Dv = Dv; + info.pool_size = pool_size; + info.value_heads_per_key_head = Hv / Hk; + info.out_strides = out_desc->strides(); + info.initial_state_strides = initial_state_desc->strides(); + if (final_state_desc != nullptr) { + info.final_state_strides = final_state_desc->strides(); + } + info.q_strides = q_desc->strides(); + info.k_strides = k_desc->strides(); + info.v_strides = v_desc->strides(); + info.g_strides = g_desc->strides(); + info.beta_strides = beta_desc->strides(); return utils::Result(info); } @@ -71,4 +175,4 @@ class RecurrentGatedDeltaRuleInfo { } // namespace recurrent_gated_delta_rule } // namespace op -#endif // __RECURRENT_GATED_DELTA_RULE_INFO_H__ \ No newline at end of file +#endif // __RECURRENT_GATED_DELTA_RULE_INFO_H__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu index 6c4b4ebe5..3c227266e 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu @@ -8,15 +8,59 @@ #include "../cuda/kernel.cuh" #include -// Kernel Launcher Wrapper -template +template INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( - Tdata *out, Tdata *final_state, + Tdata *out, Tdata *initial_state, Tdata *final_state, const Tdata *q, const Tdata *k, const Tdata *v, - const Tdata *g, const Tdata *beta, const Tdata *initial_state, - bool use_qk_l2norm) { - recurrentGatedDeltaRuleKernel( - out, final_state, q, k, v, g, beta, initial_state, use_qk_l2norm); + const Tgate *g, const Tgate *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool indexed_state_pool, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + recurrentGatedDeltaRuleKernel( + out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_state_indices_i64, final_state_indices_i64, + use_qk_l2norm, indexed_state_pool, + Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, + initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, + q_s0, q_s1, q_s2, + k_s0, k_s1, k_s2, + v_s0, v_s1, v_s2, + g_s0, g_s1, g_s2, + beta_s0, beta_s1, beta_s2); } namespace op { @@ -35,20 +79,22 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm) { auto info = RecurrentGatedDeltaRuleInfo::create( - out_desc, final_state_desc, q_desc, k_desc, v_desc, - g_desc, beta_desc, initial_state_desc, use_qk_l2norm); + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, initial_state_indices_desc, final_state_indices_desc, + use_qk_l2norm); CHECK_RESULT(info); - // Calculate workspace size if needed, here it's 0 size_t workspace_size = 0; *desc_ptr = new Descriptor( @@ -58,75 +104,173 @@ infiniStatus_t Descriptor::create( return infiniStatus_t::INFINI_STATUS_SUCCESS; } -template -infiniStatus_t launchKernel( - void *out, void *final_state, +template +infiniStatus_t launchKernelTyped( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, - bool use_qk_l2norm, - infiniDtype_t dtype, - size_t B, size_t H, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, cudaStream_t stream) { - dim3 grid(uint32_t(B), uint32_t(H), 1); + dim3 grid(uint32_t(_info.B), uint32_t(_info.Hv), 1); dim3 block(NUM_THREADS); - // Shared memory for local Q, K, and one reduction value size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); - if (dtype == INFINI_DTYPE_F16) { - recurrentGatedDeltaRule - <<>>( - (half *)out, (half *)final_state, - (const half *)q, (const half *)k, (const half *)v, - (const half *)g, (const half *)beta, (const half *)initial_state, - use_qk_l2norm); - } else if (dtype == INFINI_DTYPE_BF16) { - recurrentGatedDeltaRule<__nv_bfloat16, float, Dk, Dv, NUM_THREADS> - <<>>( - (__nv_bfloat16 *)out, (__nv_bfloat16 *)final_state, - (const __nv_bfloat16 *)q, (const __nv_bfloat16 *)k, (const __nv_bfloat16 *)v, - (const __nv_bfloat16 *)g, (const __nv_bfloat16 *)beta, (const __nv_bfloat16 *)initial_state, - use_qk_l2norm); - } else if (dtype == INFINI_DTYPE_F32) { - recurrentGatedDeltaRule - <<>>( - (float *)out, (float *)final_state, - (const float *)q, (const float *)k, (const float *)v, - (const float *)g, (const float *)beta, (const float *)initial_state, - use_qk_l2norm); - } else { + auto final_s0 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[0]; + auto final_s1 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[1]; + auto final_s2 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[2]; + auto final_s3 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[3]; + + recurrentGatedDeltaRule + <<>>( + static_cast(out), + static_cast(initial_state), + static_cast(final_state), + static_cast(q), + static_cast(k), + static_cast(v), + static_cast(g), + static_cast(beta), + initial_state_indices, + final_state_indices, + initial_state_indices_i64, + final_state_indices_i64, + _info.use_qk_l2norm, + _info.indexed_state_pool, + _info.Hk, + _info.value_heads_per_key_head, + _info.out_strides[0], + _info.out_strides[1], + _info.out_strides[2], + _info.initial_state_strides[0], + _info.initial_state_strides[1], + _info.initial_state_strides[2], + _info.initial_state_strides[3], + final_s0, + final_s1, + final_s2, + final_s3, + _info.q_strides[0], + _info.q_strides[1], + _info.q_strides[2], + _info.k_strides[0], + _info.k_strides[1], + _info.k_strides[2], + _info.v_strides[0], + _info.v_strides[1], + _info.v_strides[2], + _info.g_strides[0], + _info.g_strides[1], + _info.g_strides[2], + _info.beta_strides[0], + _info.beta_strides[1], + _info.beta_strides[2]); + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernelForGate( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + cudaStream_t stream) { + switch (_info.gate_dtype) { + case INFINI_DTYPE_F16: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_BF16: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_F32: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + default: + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +template +infiniStatus_t launchKernel( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + cudaStream_t stream) { + switch (_info.data_dtype) { + case INFINI_DTYPE_F16: + return launchKernelForGate( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_BF16: + return launchKernelForGate<__nv_bfloat16, Dk, Dv, NUM_THREADS>( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_F32: + return launchKernelForGate( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + default: return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; } - return infiniStatus_t::INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *out, void *final_state, + void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, void *stream_) const { cudaStream_t stream = (cudaStream_t)stream_; - // Specialize for common shapes and thread counts + if (_info.has_initial_state_indices && initial_state_indices == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + if (_info.has_final_state_indices && final_state_indices == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + if (!_info.has_final_state_indices && final_state == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + bool initial_indices_i64 = _info.initial_state_indices_dtype == INFINI_DTYPE_I64; + bool final_indices_i64 = _info.final_state_indices_dtype == INFINI_DTYPE_I64; + if (_info.Dk == 128 && _info.Dv == 128) { if (_opaque->internal->maxThreadsPerBlock() >= 128) { return launchKernel<128, 128, 128>( - out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, - _info.dtype, _info.B, _info.H, stream); + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_indices_i64, final_indices_i64, stream); } } else if (_info.Dk == 64 && _info.Dv == 64) { if (_opaque->internal->maxThreadsPerBlock() >= 64) { return launchKernel<64, 64, 64>( - out, final_state, q, k, v, g, beta, initial_state, _info.use_qk_l2norm, - _info.dtype, _info.B, _info.H, stream); + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_indices_i64, final_indices_i64, stream); } } - // Fallback or error for unsupported shapes - // You can add more specializations for other shapes here. return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_SHAPE; } } // namespace nvidia } // namespace recurrent_gated_delta_rule -} // namespace op \ No newline at end of file +} // namespace op diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc index 4852b65eb..465e8d6c9 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc +++ b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc @@ -12,13 +12,15 @@ __INFINI_C infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( infiniopHandle_t handle, infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, infiniopTensorDescriptor_t final_state_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t g_desc, infiniopTensorDescriptor_t beta_desc, - infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, bool use_qk_l2norm) { #define CREATE(CASE, NAMESPACE) \ case CASE: \ @@ -27,8 +29,9 @@ __INFINI_C infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( reinterpret_cast< \ op::recurrent_gated_delta_rule::NAMESPACE::Descriptor **>( \ desc_ptr), \ - out_desc, final_state_desc, q_desc, k_desc, v_desc, g_desc, \ - beta_desc, initial_state_desc, use_qk_l2norm); + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, \ + v_desc, g_desc, beta_desc, initial_state_indices_desc, \ + final_state_indices_desc, use_qk_l2norm); switch (handle->device) { #ifdef ENABLE_NVIDIA_API @@ -64,16 +67,19 @@ __INFINI_C infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( __INFINI_C infiniStatus_t infiniopRecurrentGatedDeltaRule( infiniopRecurrentGatedDeltaRuleDescriptor_t desc, void *workspace, size_t workspace_size, - void *out, void *final_state, + void *out, void *initial_state, void *final_state, const void *q, const void *k, const void *v, - const void *g, const void *beta, const void *initial_state, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast< \ op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ - ->calculate(workspace, workspace_size, out, final_state, q, k, v, \ - g, beta, initial_state, stream); + ->calculate(workspace, workspace_size, out, initial_state, \ + final_state, q, k, v, g, beta, initial_state_indices, \ + final_state_indices, stream); switch (desc->device_type) { #ifdef ENABLE_NVIDIA_API @@ -103,4 +109,4 @@ __INFINI_C infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } #undef DESTROY -} \ No newline at end of file +} diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h index 08ab5fa48..14846ea25 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h +++ b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h @@ -6,51 +6,55 @@ #include "../../operator.h" #include "info.h" -#define DESCRIPTOR(NAMESPACE) \ - \ - namespace op::recurrent_gated_delta_rule::NAMESPACE { \ - class Descriptor final : public InfiniopDescriptor { \ - struct Opaque; \ - Opaque *_opaque; \ - RecurrentGatedDeltaRuleInfo _info; \ - size_t _workspace_size; \ - \ - Descriptor( \ - Opaque *opaque, \ - RecurrentGatedDeltaRuleInfo info, \ - size_t workspace_size, \ - infiniDevice_t device_type, \ - int device_id) \ - : InfiniopDescriptor{device_type, device_id}, \ - _opaque(opaque), \ - _info(info), \ - _workspace_size(workspace_size) {} \ - \ - public: \ - ~Descriptor(); \ - \ - size_t workspaceSize() const { return _workspace_size; } \ - \ - static infiniStatus_t create( \ - infiniopHandle_t handle, \ - Descriptor **desc_ptr, \ - infiniopTensorDescriptor_t out_desc, \ - infiniopTensorDescriptor_t final_state_desc, \ - infiniopTensorDescriptor_t q_desc, \ - infiniopTensorDescriptor_t k_desc, \ - infiniopTensorDescriptor_t v_desc, \ - infiniopTensorDescriptor_t g_desc, \ - infiniopTensorDescriptor_t beta_desc, \ - infiniopTensorDescriptor_t initial_state_desc, \ - bool use_qk_l2norm); \ - \ - infiniStatus_t calculate( \ - void *workspace, size_t workspace_size, \ - void *out, void *final_state, \ - const void *q, const void *k, const void *v, \ - const void *g, const void *beta, const void *initial_state, \ - void *stream) const; \ - }; \ +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::recurrent_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RecurrentGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + RecurrentGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t initial_state_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + infiniopTensorDescriptor_t initial_state_indices_desc, \ + infiniopTensorDescriptor_t final_state_indices_desc, \ + bool use_qk_l2norm); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *initial_state, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, \ + const void *initial_state_indices, \ + const void *final_state_indices, \ + void *stream) const; \ + }; \ } -#endif // __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ \ No newline at end of file +#endif // __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ diff --git a/test/infinicore/ops/chunk_gated_delta_rule.py b/test/infinicore/ops/chunk_gated_delta_rule.py new file mode 100644 index 000000000..0165868de --- /dev/null +++ b/test/infinicore/ops/chunk_gated_delta_rule.py @@ -0,0 +1,344 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import infinicore +import torch +from framework import ( + BaseOperatorTest, + TensorSpec, + TestCase, + GenericTestRunner, + TensorInitializer, +) + +# Test cases: +# (n_khead, kdim, n_vhead, vdim, seqlens, init_state_indices, final_state_indices, state_pool_size) +_VARLEN_TEST_CASES_DATA = [ + (16, 128, 48, 128, (13,), (0,), (0,), 1), + (16, 128, 48, 128, (13,), (1,), (0,), 2), + (16, 128, 48, 128, (13, 20), (1, 1), (0, 0), 4), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, +} + +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def torch_chunk_gated_delta_rule_ref( + q, + k, + v, + g, + beta, + initial_state, + cu_seqlens=None, + initial_state_indices=None, + final_state_indices=None, + use_qk_l2norm=False, + chunk_size=64, +): + + def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + def _run_one(q, k, v, g, beta, init_state): + # q/k: [1, T, Hk, Dk] + # v/out: [1, T, Hv, Dv] + # g/beta: [1, T, Hv] + # init_state: [B, Hv, Dv, Dk] + # return state: [B, Hv, Dv, Dk] + initial_dtype = q.dtype + + if use_qk_l2norm: + q = l2norm(q, dim=-1, eps=1e-6) + k = l2norm(k, dim=-1, eps=1e-6) + + B, T, Hk, Dk = q.shape + _, _, Hv, Dv = v.shape + assert B == 1 + assert Hv % Hk == 0 + assert init_state.shape == (B, Hv, Dv, Dk) + + group_size = Hv // Hk + if group_size != 1: + q = q.repeat_interleave(group_size, dim=2) + k = k.repeat_interleave(group_size, dim=2) + + q = q.transpose(1, 2).contiguous().float() + k = k.transpose(1, 2).contiguous().float() + v = v.transpose(1, 2).contiguous().float() + + beta = beta.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + + B, H, sequence_length, Dk = k.shape + Dv = v.shape[-1] + + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + + q = torch.nn.functional.pad(q, (0, 0, 0, pad_size)) + k = torch.nn.functional.pad(k, (0, 0, 0, pad_size)) + v = torch.nn.functional.pad(v, (0, 0, 0, pad_size)) + beta = torch.nn.functional.pad(beta, (0, pad_size)) + g = torch.nn.functional.pad(g, (0, pad_size)) + + total_sequence_length = sequence_length + pad_size + scale = 1 / (q.shape[-1] ** 0.5) + q = q * scale + + v_beta = v * beta.unsqueeze(-1) + k_beta = k * beta.unsqueeze(-1) + + q, k, v, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (q, k, v, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + + attn = -((k_beta @ k.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + v = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + # [B, Hv, Dv, Dk] -> [B, Hv, Dk, Dv] + last_state = init_state.transpose(-1, -2).contiguous().float().clone() + + out = torch.zeros_like(v) + + for i in range(total_sequence_length // chunk_size): + q_i = q[:, :, i] + k_i = k[:, :, i] + v_i = v[:, :, i] + + attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] + + v_prime = k_cumdecay[:, :, i] @ last_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state + out[:, :, i] = attn_inter + attn @ v_new + + last_state = ( + last_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1]) + out = out[:, :, :sequence_length] + out = out.transpose(1, 2).contiguous().to(initial_dtype) + + # [B, Hv, Dk, Dv] -> [B, Hv, Dv, Dk] + final_state = last_state.transpose(-1, -2).contiguous().to(init_state.dtype) + + return out, final_state + + if cu_seqlens is None: + out, final_state = _run_one(q, k, v, g, beta, initial_state) + + if initial_state_indices is not None: + for b, dst in enumerate(final_state_indices.cpu().tolist()): + initial_state[dst].copy_(final_state[b].to(initial_state.dtype)) + + return out + + cu = cu_seqlens.cpu().tolist() + batch = len(cu) - 1 + total_tokens = cu[-1] + + out = torch.empty_like(v[:, :total_tokens]) + indexed_state_pool = initial_state_indices is not None + + for b in range(batch): + start = cu[b] + end = cu[b + 1] + + q_b = q[:, start:end] + k_b = k[:, start:end] + v_b = v[:, start:end] + g_b = g[:, start:end] + beta_b = beta[:, start:end] + + if indexed_state_pool: + src = int(initial_state_indices[b].item()) + init_b = initial_state[src : src + 1] + else: + init_b = initial_state[b : b + 1] + + out_b, final_b = _run_one(q_b, k_b, v_b, g_b, beta_b, init_b) + out[:, start:end].copy_(out_b) + + if indexed_state_pool: + dst = int(final_state_indices[b].item()) + initial_state[dst].copy_(final_b[0].to(initial_state.dtype)) + + return out + + +def parse_varlen_test_cases(): + tests = [] + + for ( + n_khead, + kdim, + n_vhead, + vdim, + seqlens, + init_state_indices, + final_state_indices, + state_pool_size, + ) in _VARLEN_TEST_CASES_DATA: + batch = len(seqlens) + total_tokens = sum(seqlens) + cu_seqlens = [0] + for seqlen in seqlens: + cu_seqlens.append(cu_seqlens[-1] + seqlen) + + q_shape = (1, total_tokens, n_khead, kdim) + k_shape = (1, total_tokens, n_khead, kdim) + v_shape = (1, total_tokens, n_vhead, vdim) + g_shape = (1, total_tokens, n_vhead) + beta_shape = (1, total_tokens, n_vhead) + + # Indexed state-pool mode in your wrapper doc: + # initial_state: [pool_size, Hv, Dv, Dk] + state_shape = (state_pool_size, n_vhead, vdim, kdim) + + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-3}) + + q_spec = TensorSpec.from_tensor(q_shape, None, dtype, scale=0.2, bias=-0.1) + k_spec = TensorSpec.from_tensor(k_shape, None, dtype, scale=0.2, bias=-0.1) + v_spec = TensorSpec.from_tensor(v_shape, None, dtype, scale=0.2, bias=-0.1) + g_spec = TensorSpec.from_tensor( + g_shape, None, infinicore.float32, scale=0.02, bias=-0.01 + ) + beta_spec = TensorSpec.from_tensor( + beta_shape, None, infinicore.float32, scale=0.5, bias=0.0 + ) + state_spec = TensorSpec.from_tensor( + state_shape, None, dtype, init_mode=TensorInitializer.ZEROS + ) + + cu_seqlens_spec = TensorSpec.from_tensor( + (batch + 1,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(cu_seqlens, dtype=torch.int32), + ) + + init_indices_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(init_state_indices, dtype=torch.int32), + ) + final_indices_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(final_state_indices, dtype=torch.int32), + ) + + tests.append( + TestCase( + inputs=[ + q_spec, + k_spec, + v_spec, + g_spec, + beta_spec, + state_spec, + cu_seqlens_spec, + init_indices_spec, + final_indices_spec, + ], + kwargs={ + "use_qk_l2norm": True, + "chunk_size": 64, + }, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="ChunkGatedDeltaRule - VARLEN_INDEXED_STATE_POOL", + ) + ) + + return tests + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("ChunkGatedDeltaRule") + + def get_test_cases(self): + return parse_varlen_test_cases() + + def torch_operator(self, *args, **kwargs): + args = list(args) + args[5] = args[5].clone() + return torch_chunk_gated_delta_rule_ref(*args, **kwargs) + + def infinicore_operator( + self, + q, + k, + v, + g, + beta, + states, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size=64, + ): + return infinicore.nn.functional.chunk_gated_delta_rule( + q, + k, + v, + g, + beta, + states, + cu_seqlens=cu_seqlens, + initial_state_indices=initial_state_indices, + final_state_indices=final_state_indices, + use_qk_l2norm=use_qk_l2norm, + chunk_size=chunk_size, + ) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/recurrent_gated_delta_rule.py b/test/infinicore/ops/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..4fbc59490 --- /dev/null +++ b/test/infinicore/ops/recurrent_gated_delta_rule.py @@ -0,0 +1,291 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import torch.nn.functional as torch_F +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) + +import infinicore + +# Test cases: +# (B, T, Hk, Hv, Dk, Dv, use_qk_l2norm, strided_qkv) +_TEST_CASES = [ + (7, 1, 40, 40, 128, 128, True, False), + (5, 1, 64, 64, 128, 128, False, False), + (1, 1, 8, 8, 64, 64, True, False), + (2, 1, 4, 8, 64, 64, False, False), + (2, 1, 4, 8, 64, 64, True, True), +] +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, +} + + +def ref_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + use_qk_l2norm=False, +): + initial_dtype = query.dtype + if use_qk_l2norm: + query = torch_F.normalize(query, p=2, dim=-1) + key = torch_F.normalize(key, p=2, dim=-1) + + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + state = initial_state.contiguous().to(torch.float32).clone() + batch_size, sequence_length, key_heads, _ = key.shape + value_heads, v_head_dim = value.shape[2], value.shape[-1] + value_heads_per_key_head = value_heads // key_heads + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + out = torch.zeros( + batch_size, + sequence_length, + value_heads, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + + for i in range(sequence_length): + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + q_t = query[:, i, kh] + k_t = key[:, i, kh] + v_t = value[:, i, vh] + g_t = g[:, i, vh].exp().view(batch_size, 1, 1) + beta_t = beta[:, i, vh].view(batch_size, 1) + state_t = state[:, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + state[:, vh] = state_t + out[:, i, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) + + return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + + +def strided_bthd_strides(shape): + _, T, H, D = shape + return (8 * T * H * D, 4 * H * D, 2 * D, 1) + + +def tensor_spec(shape, dtype, strides=None): + return TensorSpec.from_tensor(shape, strides, dtype, scale=0.2, bias=-0.1) + + +def gate_spec(shape): + return TensorSpec.from_tensor( + shape, None, infinicore.float32, scale=0.02, bias=-0.01 + ) + + +def beta_spec(shape): + return TensorSpec.from_tensor(shape, None, infinicore.float32, scale=0.5, bias=0.0) + + +def index_spec(values): + values = torch.tensor(values, dtype=torch.int64) + return TensorSpec.from_tensor( + values.shape, + None, + infinicore.int64, + init_mode=TensorInitializer.MANUAL, + set_tensor=values, + ) + + +def parse_test_cases(): + tests = [] + + for B, T, Hk, Hv, Dk, Dv, use_qk_l2norm, strided_qkv in _TEST_CASES: + q_shape = (B, T, Hk, Dk) + k_shape = (B, T, Hk, Dk) + v_shape = (B, T, Hv, Dv) + gate_shape = (B, T, Hv) + initial_state_shape = (B, Hv, Dk, Dv) + pool_size = B * 2 + 3 + state_pool_shape = (pool_size, Hv, Dv, Dk) + q_strides = strided_bthd_strides(q_shape) if strided_qkv else None + k_strides = strided_bthd_strides(k_shape) if strided_qkv else None + v_strides = strided_bthd_strides(v_shape) if strided_qkv else None + + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + + base_inputs = [ + tensor_spec(q_shape, dtype, q_strides), + tensor_spec(k_shape, dtype, k_strides), + tensor_spec(v_shape, dtype, v_strides), + gate_spec(gate_shape), + beta_spec(gate_shape), + ] + + tests.append( + TestCase( + inputs=base_inputs + + [TensorSpec.from_tensor(initial_state_shape, None, dtype)], + kwargs={ + "mode": "legacy", + "use_qk_l2norm": use_qk_l2norm, + }, + description=( + f"legacy B={B}, T={T}, Hk={Hk}, Hv={Hv}, " + f"Dk={Dk}, Dv={Dv}, dtype={dtype}, " + f"strided_qkv={strided_qkv}, l2norm={use_qk_l2norm}" + ), + tolerance=tol, + ) + ) + + tests.append( + TestCase( + inputs=base_inputs + + [ + TensorSpec.from_tensor(state_pool_shape, None, dtype), + index_spec(range(1, B + 1)), + index_spec(range(B + 1, 2 * B + 1)), + ], + kwargs={ + "mode": "indexed_pool", + "use_qk_l2norm": use_qk_l2norm, + }, + output_count=2, + description=( + f"indexed pool B={B}, T={T}, Hk={Hk}, Hv={Hv}, " + f"Dk={Dk}, Dv={Dv}, dtype={dtype}, " + f"strided_qkv={strided_qkv}, l2norm={use_qk_l2norm}" + ), + tolerance=tol, + ) + ) + + for dtype in _TENSOR_DTYPES: + tests.append( + TestCase( + inputs=[ + tensor_spec((1, 48, 128), dtype), + tensor_spec((1, 48, 128), dtype), + tensor_spec((1, 48, 128), dtype), + gate_spec((1, 1, 48)), + beta_spec((1, 1, 48)), + TensorSpec.from_tensor((1, 48, 128, 128), None, dtype), + ], + kwargs={"mode": "user_3d", "use_qk_l2norm": False}, + description=f"user 3D repro dtype={dtype}", + tolerance=_TOLERANCE_MAP[dtype], + ) + ) + + return tests + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("recurrent_gated_delta_rule") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): + mode = kwargs.pop("mode") + use_qk_l2norm = kwargs.pop("use_qk_l2norm", False) + + if mode == "legacy": + out, _ = ref_recurrent_gated_delta_rule( + q, k, v, g, beta, initial_state, use_qk_l2norm=use_qk_l2norm + ) + return out + + if mode == "indexed_pool": + initial_state_indices, final_state_indices = args + state_pool = initial_state.clone() + gathered_initial_state = ( + state_pool[initial_state_indices].transpose(-1, -2).contiguous() + ) + out, final_state = ref_recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + gathered_initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + state_pool[final_state_indices] = final_state.transpose(-1, -2).contiguous() + return out, state_pool + + if mode == "user_3d": + out, _ = ref_recurrent_gated_delta_rule( + q.unsqueeze(1), + k.unsqueeze(1), + v.unsqueeze(1), + g, + beta, + initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + return out + + raise ValueError(f"Unsupported test mode: {mode}") + + def infinicore_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): + mode = kwargs.pop("mode") + use_qk_l2norm = kwargs.pop("use_qk_l2norm", False) + + if mode == "legacy" or mode == "user_3d": + return infinicore.nn.functional.recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + + if mode == "indexed_pool": + initial_state_indices, final_state_indices = args + out = infinicore.nn.functional.recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + initial_state, + initial_state_indices=initial_state_indices, + final_state_indices=final_state_indices, + use_qk_l2norm=use_qk_l2norm, + ) + return out, initial_state + + raise ValueError(f"Unsupported test mode: {mode}") + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infiniop/chunk_gated_delta_rule.py b/test/infiniop/chunk_gated_delta_rule.py index b1de7b2a0..674ecad8e 100644 --- a/test/infiniop/chunk_gated_delta_rule.py +++ b/test/infiniop/chunk_gated_delta_rule.py @@ -1,9 +1,8 @@ -# test_chunk_gated_delta_rule.py +import ctypes +from ctypes import c_uint64 import torch import torch.nn.functional as F -import ctypes -from ctypes import c_uint32, c_float, c_uint64, c_size_t, POINTER, addressof from libinfiniop import ( LIBINFINIOP, @@ -14,7 +13,6 @@ get_args, debug, get_tolerance, - profile_operation, InfiniDtype, InfiniDtypeNames, InfiniDeviceNames, @@ -23,19 +21,14 @@ ) -# ============================================================================== -# Reference Implementation -# ============================================================================== -# From modeling_qwen3_next.py, the production PyTorch fallback implementation def ref_chunk_gated_delta_rule( query, key, value, g, beta, - chunk_size=64, - initial_state=None, - output_final_state=False, + initial_state, + cu_seqlens=None, use_qk_l2norm_in_kernel=False, ): initial_dtype = query.dtype @@ -43,231 +36,144 @@ def ref_chunk_gated_delta_rule( query = F.normalize(query, p=2, dim=-1) key = F.normalize(key, p=2, dim=-1) - # The production implementation expects (B, T, H, D) and transposes internally - # query, key, value, beta, g = [ - # x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) - # ] - query, key, value, beta, g = [ x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] - - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] - # print("before pad", query.shape, key.shape, value.shape, beta.shape, g.shape) - - pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)) - key = F.pad(key, (0, 0, 0, pad_size)) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - # print("after pad", query.shape, key.shape, value.shape, beta.shape, g.shape) - - tot_seqs = sequence_length + pad_size - scale = 1 / (query.shape[-1] ** 0.5) + state = initial_state.contiguous().to(torch.float32).clone() + + if cu_seqlens is None: + batch_size, sequence_length, key_heads, k_head_dim = key.shape + spans = [(b, 0, sequence_length) for b in range(batch_size)] + else: + key_heads, k_head_dim = key.shape[2], key.shape[3] + batch_size = cu_seqlens.numel() - 1 + spans = [ + (b, int(cu_seqlens[b].item()), int(cu_seqlens[b + 1].item())) + for b in range(batch_size) + ] + + value_heads, v_head_dim = value.shape[2], value.shape[3] + value_heads_per_key_head = value_heads // key_heads + scale = 1 / (k_head_dim**0.5) query = query * scale + out = torch.zeros_like(value, dtype=torch.float32) + + for b, start, end in spans: + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + state_t = state[b, vh] + for t in range(start, end): + token_b = 0 if cu_seqlens is not None else b + q_t = query[token_b, t, kh] + k_t = key[token_b, t, kh] + v_t = value[token_b, t, vh] + g_t = g[token_b, t, vh].exp() + beta_t = beta[token_b, t, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + state[b, vh] = state_t + out[token_b, t, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) + + return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + + +_PADDED_TEST_CASES_DATA = [ + # B, T, n_khead, kdim, n_vhead, vdim, chunk_size, use_qk_l2norm, strided_qkv + (2, 17, 4, 64, 4, 64, 8, True, False), + (2, 19, 4, 64, 8, 64, 8, False, False), + (2, 13, 4, 64, 8, 64, 8, True, True), +] - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - - # Reshape to chunks (in the head dimension) - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) +# Test cases: (n_khead, kdim, n_vhead, vdim, (seqlens), (init_state_indices), +# final_state_indices, state_pool_size) +_VARLEN_TEST_CASES_DATA = [ + (4, 64, 8, 64, (1, 17, 3, 9), (1, 2, 3, 4), (5, 6, 7, 8), 13), + (16, 128, 48, 128, (13,), (0,), (0,), 1), +] - mask = torch.triu( - torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), - diagonal=0, - ) +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] - # This part is quite intricate and involves parallel scan logic. - # We will trust the reference implementation as the ground truth. - g = g.cumsum(dim=-1) +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-2, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-4}, +} - decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() +DEBUG = False - attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) +def parse_test_cases(): + tests = [] + for case in _PADDED_TEST_CASES_DATA: + tests.append(("padded", case)) + for case in _VARLEN_TEST_CASES_DATA: + tests.append(("varlen_indexed_pool", case)) + return tests - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) +def bthd_strides(B, T, H, D, strided): + if not strided: + return None + return (T * H * D * 2, H * D * 2, D * 2, 1) - last_recurrent_state = ( - torch.zeros( - batch_size, - num_heads, - k_head_dim, - v_head_dim, - device=value.device, - dtype=torch.float32, - ) - if initial_state is None - else initial_state.to(torch.float32) - ) - core_attn_out = torch.zeros_like(value) - mask = torch.triu( - torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), - diagonal=1, +def make_gate(shape, device): + return TestTensor.from_torch( + F.logsigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device ) - for i in range(0, tot_seqs // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn_intra = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( - mask, 0 - ) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn_intra @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() - + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2 - ) - @ v_new - ) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape( - core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1] +def make_beta(shape, device): + return TestTensor.from_torch( + torch.sigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device ) - core_attn_out = core_attn_out[:, :, :sequence_length] # Unpad - # core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) - core_attn_out = core_attn_out.contiguous().to(initial_dtype) - - if last_recurrent_state is not None: - last_recurrent_state = last_recurrent_state.contiguous().to(initial_dtype) - - return core_attn_out, last_recurrent_state - - -# ============================================================================== -# Test Configuration -# ============================================================================== -# (B, T, H, Dk, Dv, chunk_size, use_qk_l2norm) -# T (seq_len) must be > 1 for this operator -_TEST_CASES_ = [ - (2, 511, 40, 64, 64, 8, True), - # (2, 511, 40, 64, 64, 16, True), - # (4, 1024, 64, 128, 128, 64, False), - (8, 511, 32, 64, 64, 8, True), - (8, 511, 32, 128, 128, 8, True), -] -# Data types for testing -_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] -# Tolerance map -_TOLERANCE_MAP = { - InfiniDtype.F16: { - "atol": 1e-3, - "rtol": 1e-3, - }, # Higher tolerance due to complex ops - InfiniDtype.BF16: {"atol": 5e-2, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-4}, -} - -# Global flags -DEBUG = False -PROFILE = False -NUM_PRERUN = 10 -NUM_ITERATIONS = 100 - - -def test( +def run_op( handle, device, - B, - T, - H, - Dk, - Dv, - chunk_size, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, use_qk_l2norm, - dtype=InfiniDtype.F16, - sync=None, + chunk_size, ): - print( - f"Testing ChunkGatedDeltaRule on {InfiniDeviceNames[device]} with " - f"B={B}, T={T}, H={H}, Dk={Dk}, Dv={Dv}, chunk_size={chunk_size}, " - f"dtype={InfiniDtypeNames[dtype]}, use_qk_l2norm={use_qk_l2norm}" - ) - - # Input tensors are in (B, H, T, D) layout as they come from the model layers - q = TestTensor((B, H, T, Dk), None, dtype, device) - k = TestTensor((B, H, T, Dk), None, dtype, device) - v = TestTensor((B, H, T, Dv), None, dtype, device) - - g_logsigmoid = torch.randn(B, H, T, dtype=torch.float32) - g = TestTensor.from_torch(F.logsigmoid(g_logsigmoid), dtype, device) - beta_sigmoid = torch.randn(B, H, T, dtype=torch.float32) - beta = TestTensor.from_torch(torch.sigmoid(beta_sigmoid), dtype, device) - - initial_state = TestTensor((B, H, Dk, Dv), None, dtype, device) - # initial_state = None - # final_state = initial_state - - initial_state_desc = ctypes.c_void_p(0) - initial_state_data = ctypes.c_void_p(0) - initial_state_torch = None - if initial_state is not None: - initial_state_desc = initial_state.descriptor - initial_state_data = initial_state.data() - initial_state_torch = initial_state.torch_tensor() - - # Output tensors - out = TestTensor((B, H, T, Dv), None, dtype, device) - # final_state shape is (B, H, Dk, Dv) - final_state = TestTensor((B, H, Dk, Dv), None, dtype, device) - - # Run reference implementation - ans_out, ans_final_state = ref_chunk_gated_delta_rule( - q.torch_tensor(), - k.torch_tensor(), - v.torch_tensor(), - g.torch_tensor(), - beta.torch_tensor(), - chunk_size=chunk_size, - initial_state=initial_state_torch, - output_final_state=True, - use_qk_l2norm_in_kernel=use_qk_l2norm, - ) - - if sync: - sync() - - # Create operator descriptor descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateChunkGatedDeltaRuleDescriptor( handle, ctypes.byref(descriptor), out.descriptor, - final_state.descriptor, + initial_state.descriptor, + final_state.descriptor if final_state is not None else ctypes.c_void_p(0), q.descriptor, k.descriptor, v.descriptor, g.descriptor, beta.descriptor, - initial_state_desc, + cu_seqlens.descriptor if cu_seqlens is not None else ctypes.c_void_p(0), + initial_state_indices.descriptor + if initial_state_indices is not None + else ctypes.c_void_p(0), + final_state_indices.descriptor + if final_state_indices is not None + else ctypes.c_void_p(0), ctypes.c_bool(use_qk_l2norm), ctypes.c_size_t(chunk_size), ) ) - # Get workspace size workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetChunkGatedDeltaRuleWorkspaceSize( @@ -276,69 +182,251 @@ def test( ) workspace = TestWorkspace(workspace_size.value, q.device) - # Invalidate descriptors to ensure kernel does not rely on them - q.destroy_desc() - k.destroy_desc() - v.destroy_desc() - g.destroy_desc() - beta.destroy_desc() - if initial_state is not None: - initial_state.destroy_desc() - out.destroy_desc() - final_state.destroy_desc() - - # Define the library call - def lib_chunk_gated_delta_rule(): - check_error( - LIBINFINIOP.infiniopChunkGatedDeltaRule( - descriptor, - workspace.data(), - workspace_size.value, - out.data(), - final_state.data(), - q.data(), - k.data(), - v.data(), - g.data(), - beta.data(), - initial_state_data, - None, - ) + for tensor in [ + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + ]: + if tensor is not None: + tensor.destroy_desc() + + check_error( + LIBINFINIOP.infiniopChunkGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + initial_state.data(), + final_state.data() if final_state is not None else None, + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + cu_seqlens.data() if cu_seqlens is not None else None, + initial_state_indices.data() if initial_state_indices is not None else None, + final_state_indices.data() if final_state_indices is not None else None, + None, ) + ) + check_error(LIBINFINIOP.infiniopDestroyChunkGatedDeltaRuleDescriptor(descriptor)) - # Execute the custom operator - lib_chunk_gated_delta_rule() +def test_padded( + handle, + device, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + B, T, Hk, Dk, Hv, Dv, chunk_size, use_qk_l2norm, strided_qkv = test_case + print( + f"Testing ChunkGatedDeltaRule on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, chunk={chunk_size}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " + f"l2norm={use_qk_l2norm}" + ) + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + + ans_out, ans_final_state = ref_chunk_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) if sync: sync() - # Verify correctness - atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + run_op( + handle, + device, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + None, + None, + None, + use_qk_l2norm, + chunk_size, + ) + if sync: + sync() + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: - print("--- Verifying Output Tensor ---") debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + ) + + +def test_varlen_indexed_pool( + handle, + device, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + ( + Hk, + Dk, + Hv, + Dv, + lengths, + initial_state_indices_data, + final_state_indices_data, + pool_size, + ) = test_case + chunk_size = 8 + use_qk_l2norm = True + lengths = tuple(lengths) + initial_state_indices_data = tuple(initial_state_indices_data) + final_state_indices_data = tuple(final_state_indices_data) + B = len(lengths) + total_tokens = sum(lengths) + print( + f"Testing ChunkGatedDeltaRule varlen indexed pool on {InfiniDeviceNames[device]} with " + f"lengths={lengths}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, chunk={chunk_size}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, " + f"initial_indices={initial_state_indices_data}, final_indices={final_state_indices_data}, " + f"pool_size={pool_size}" + ) + q = TestTensor((1, total_tokens, Hk, Dk), None, dtype, device) + k = TestTensor((1, total_tokens, Hk, Dk), None, dtype, device) + v = TestTensor((1, total_tokens, Hv, Dv), None, dtype, device) + g = make_gate((1, total_tokens, Hv), device) + beta = make_beta((1, total_tokens, Hv), device) + out = TestTensor((1, total_tokens, Hv, Dv), None, dtype, device, mode="zeros") + + cu = torch.tensor( + [0] + list(torch.tensor(lengths).cumsum(0).tolist()), + dtype=torch.int64, + device=q.torch_tensor().device, + ) + cu_seqlens = TestTensor.from_torch(cu, InfiniDtype.I64, device) + initial_state_pool = TestTensor((pool_size, Hv, Dv, Dk), None, dtype, device) + initial_state_indices_torch = torch.tensor( + initial_state_indices_data, dtype=torch.int64, device=q.torch_tensor().device + ) + final_state_indices_torch = torch.tensor( + final_state_indices_data, dtype=torch.int64, device=q.torch_tensor().device + ) + initial_state_indices = TestTensor.from_torch( + initial_state_indices_torch, InfiniDtype.I64, device + ) + final_state_indices = TestTensor.from_torch( + final_state_indices_torch, InfiniDtype.I64, device + ) + + gathered_initial = ( + initial_state_pool.torch_tensor()[initial_state_indices_torch] + .transpose(-1, -2) + .contiguous() + ) + ans_out, ans_final_state = ref_chunk_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + gathered_initial, + cu_seqlens=cu, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + ans_pool = initial_state_pool.torch_tensor().clone() + ans_pool[final_state_indices_torch] = ans_final_state.transpose(-1, -2).contiguous() + if sync: + sync() + + run_op( + handle, + device, + out, + initial_state_pool, + None, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size, + ) + if sync: + sync() + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: - print("--- Verifying Final State Tensor ---") - debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug(initial_state_pool.actual_tensor(), ans_pool, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) assert torch.allclose( - final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + initial_state_pool.actual_tensor(), ans_pool, atol=atol, rtol=rtol ) - # Clean up - check_error(LIBINFINIOP.infiniopDestroyChunkGatedDeltaRuleDescriptor(descriptor)) + +def test( + handle, + device, + mode, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + if mode == "padded": + return test_padded(handle, device, test_case, dtype=dtype, sync=sync) + if mode == "varlen_indexed_pool": + return test_varlen_indexed_pool( + handle, device, test_case, dtype=dtype, sync=sync + ) + raise ValueError(f"Unknown chunk_gated_delta_rule test mode: {mode}") if __name__ == "__main__": args = get_args() DEBUG = args.debug - PROFILE = args.profile - NUM_PRERUN = args.num_prerun - NUM_ITERATIONS = args.num_iterations for device in get_test_devices(args): - test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + test_operator(device, test, parse_test_cases(), _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 988f3ac8e..4e6dfb33b 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1266,7 +1266,6 @@ def per_tensor_quant_int8_(lib): c_void_p, c_void_p, c_void_p, - c_void_p, c_bool, c_void_p, ] @@ -2792,12 +2791,14 @@ def recurrent_gated_delta_rule_(lib): POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + c_void_p, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, + c_void_p, + c_void_p, c_bool, ] lib.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize.restype = c_int32 @@ -2819,6 +2820,8 @@ def recurrent_gated_delta_rule_(lib): c_void_p, c_void_p, c_void_p, + c_void_p, + c_void_p, ] lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.restype = c_int32 lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.argtypes = [ @@ -2834,12 +2837,15 @@ def chunk_gated_delta_rule_(lib): POINTER(infiniopOperatorDescriptor_t), infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + c_void_p, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, c_void_p, + c_void_p, + c_void_p, c_bool, c_size_t, ] @@ -2862,6 +2868,9 @@ def chunk_gated_delta_rule_(lib): c_void_p, c_void_p, c_void_p, + c_void_p, + c_void_p, + c_void_p, ] lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.restype = c_int32 lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.argtypes = [ diff --git a/test/infiniop/recurrent_gated_delta_rule.py b/test/infiniop/recurrent_gated_delta_rule.py index 8cca95262..25a40911f 100644 --- a/test/infiniop/recurrent_gated_delta_rule.py +++ b/test/infiniop/recurrent_gated_delta_rule.py @@ -1,9 +1,10 @@ # test_recurrent_gated_delta_rule.py +import ctypes +from ctypes import c_uint64 + import torch import torch.nn.functional as F -import ctypes -from ctypes import c_uint32, c_float, c_uint64, c_size_t, POINTER, addressof from libinfiniop import ( LIBINFINIOP, @@ -23,11 +24,6 @@ ) -# ============================================================================== -# Reference Implementation -# ============================================================================== -# 从 modeling_qwen3_next.py 提供的生产环境PyTorch备选实现 -# 我们将严格对照此函数进行测试 def ref_recurrent_gated_delta_rule( query, key, @@ -43,54 +39,43 @@ def ref_recurrent_gated_delta_rule( query = F.normalize(query, p=2, dim=-1) key = F.normalize(key, p=2, dim=-1) - # 生产环境的实现期望输入已经是 (B, H, T, D) - # 我们在测试数据生成时会直接生成这种格式,以模拟真实调用场景 query, key, value, beta, g = [ x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) ] + initial_state = initial_state.contiguous().to(torch.float32).clone() - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] + batch_size, sequence_length, key_heads, k_head_dim = key.shape + value_heads, v_head_dim = value.shape[2], value.shape[-1] + value_heads_per_key_head = value_heads // key_heads scale = 1 / (query.shape[-1] ** 0.5) query = query * scale core_attn_out = torch.zeros( batch_size, - num_heads, sequence_length, + value_heads, v_head_dim, device=value.device, dtype=torch.float32, ) - - # 注意:这里的 initial_state 形状是 (B, H, Dk, Dv) - last_recurrent_state = ( - torch.zeros( - batch_size, - num_heads, - k_head_dim, - v_head_dim, - device=value.device, - dtype=torch.float32, - ) - if initial_state is None - else initial_state.to(torch.float32) - ) + last_recurrent_state = initial_state for i in range(sequence_length): - q_t = query[:, :, i] - k_t = key[:, :, i] - v_t = value[:, :, i] - g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) - beta_t = beta[:, :, i].unsqueeze(-1) - - last_recurrent_state = last_recurrent_state * g_t - kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - delta = (v_t - kv_mem) * beta_t - last_recurrent_state = last_recurrent_state + k_t.unsqueeze( - -1 - ) * delta.unsqueeze(-2) - core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + q_t = query[:, i, kh] + k_t = key[:, i, kh] + v_t = value[:, i, vh] + g_t = g[:, i, vh].exp().view(batch_size, 1, 1) + beta_t = beta[:, i, vh].view(batch_size, 1) + state_t = last_recurrent_state[:, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + last_recurrent_state[:, vh] = state_t + core_attn_out[:, i, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) if not output_final_state: last_recurrent_state = None @@ -102,71 +87,89 @@ def ref_recurrent_gated_delta_rule( return core_attn_out, last_recurrent_state -# ============================================================================== -# Test Configuration -# ============================================================================== -# (B, T, H, Dk, Dv, use_qk_l2norm) -# T (seq_len) is typically 1 for decode stage _TEST_CASES_ = [ - (7, 1, 40, 128, 128, True), - (5, 1, 64, 128, 128, False), - (1, 1, 8, 64, 64, True), - # (16, 1, 32, 80, 80, True), + (7, 1, 40, 40, 128, 128, True, False), + (5, 1, 64, 64, 128, 128, False, False), + (1, 1, 8, 8, 64, 64, True, False), + (2, 1, 4, 8, 64, 64, False, False), + (2, 1, 4, 8, 64, 64, True, True), ] -# Data types for testing _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] -# Tolerance map for different data types _TOLERANCE_MAP = { InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, } -# Global flags for controlling test behavior DEBUG = False PROFILE = False NUM_PRERUN = 10 NUM_ITERATIONS = 100 +def bthd_strides(B, T, H, D, strided): + if not strided: + return None + return (T * H * D * 2, H * D * 2, D * 2, 1) + + +def make_gate(shape, device): + return TestTensor.from_torch( + F.logsigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + +def make_beta(shape, device): + return TestTensor.from_torch( + torch.sigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + def test( handle, device, B, T, - H, + Hk, + Hv, Dk, Dv, use_qk_l2norm, + strided_qkv, dtype=InfiniDtype.F16, sync=None, ): print( f"Testing RecurrentGatedDeltaRule on {InfiniDeviceNames[device]} with " - f"B={B}, T={T}, H={H}, Dk={Dk}, Dv={Dv}, dtype={InfiniDtypeNames[dtype]}, " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " f"use_qk_l2norm={use_qk_l2norm}" ) - # Create input tensors. - # IMPORTANT: We directly create tensors in (B, H, T, D) layout to match the production environment. - q = TestTensor((B, H, T, Dk), None, dtype, device) - k = TestTensor((B, H, T, Dk), None, dtype, device) - v = TestTensor((B, H, T, Dv), None, dtype, device) - # g and beta have shape (B, H, T) - g_logsigmoid = torch.randn(B, H, T, dtype=torch.float32) - g = TestTensor.from_torch(F.logsigmoid(g_logsigmoid), dtype, device) - beta_sigmoid = torch.randn(B, H, T, dtype=torch.float32) - beta = TestTensor.from_torch(torch.sigmoid(beta_sigmoid), dtype, device) - - initial_state = TestTensor((B, H, Dk, Dv), None, dtype, device) - - # Create output tensors - out = TestTensor((B, H, T, Dv), None, dtype, device) - final_state = TestTensor((B, H, Dk, Dv), None, dtype, device) + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + + initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) - # Run reference implementation ans_out, ans_final_state = ref_recurrent_gated_delta_rule( q.torch_tensor(), k.torch_tensor(), @@ -181,25 +184,25 @@ def test( if sync: sync() - # Create operator descriptor descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateRecurrentGatedDeltaRuleDescriptor( handle, ctypes.byref(descriptor), out.descriptor, + initial_state.descriptor, final_state.descriptor, q.descriptor, k.descriptor, v.descriptor, g.descriptor, beta.descriptor, - initial_state.descriptor, + ctypes.c_void_p(0), + ctypes.c_void_p(0), ctypes.c_bool(use_qk_l2norm), ) ) - # Get workspace size and allocate memory workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( @@ -208,7 +211,6 @@ def test( ) workspace = TestWorkspace(workspace_size.value, q.device) - # Invalidate descriptors to ensure kernel does not rely on them q.destroy_desc() k.destroy_desc() v.destroy_desc() @@ -218,7 +220,6 @@ def test( out.destroy_desc() final_state.destroy_desc() - # Define the library call as a lambda for profiling def lib_recurrent_gated_delta_rule(): check_error( LIBINFINIOP.infiniopRecurrentGatedDeltaRule( @@ -226,53 +227,216 @@ def lib_recurrent_gated_delta_rule(): workspace.data(), workspace_size.value, out.data(), + initial_state.data(), final_state.data(), q.data(), k.data(), v.data(), g.data(), beta.data(), - initial_state.data(), + None, + None, None, ) ) - # Execute the custom operator lib_recurrent_gated_delta_rule() if sync: sync() - # Verify correctness atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - - # Verify main output if DEBUG: - print("--- Verifying Output Tensor ---") debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) - assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) - - # Verify final state - if DEBUG: - print("--- Verifying Final State Tensor ---") debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) assert torch.allclose( final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol ) - # print(final_state.actual_tensor(), ans_final_state) - # Profiling workflow if PROFILE: - # fmt: off - profile_operation("PyTorch", lambda: ref_recurrent_gated_delta_rule( - q.torch_tensor(), k.torch_tensor(), v.torch_tensor(), - g.torch_tensor(), beta.torch_tensor(), initial_state.torch_tensor(), - output_final_state=True, use_qk_l2norm_in_kernel=use_qk_l2norm), - device, NUM_PRERUN, NUM_ITERATIONS) - profile_operation(" lib", lib_recurrent_gated_delta_rule, device, NUM_PRERUN, NUM_ITERATIONS) - # fmt: on - - # Clean up resources + profile_operation( + "PyTorch", + lambda: ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lib_recurrent_gated_delta_rule, + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error( + LIBINFINIOP.infiniopDestroyRecurrentGatedDeltaRuleDescriptor(descriptor) + ) + + +def test_indexed_pool_inplace( + handle, + device, + B, + T, + Hk, + Hv, + Dk, + Dv, + use_qk_l2norm, + strided_qkv, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing RecurrentGatedDeltaRule indexed pool inplace on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " + f"use_qk_l2norm={use_qk_l2norm}" + ) + + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + + pool_size = B * 2 + 3 + initial_state_pool = TestTensor((pool_size, Hv, Dv, Dk), None, dtype, device) + index_device = q.torch_tensor().device + initial_state_indices_torch = torch.arange( + 1, B + 1, dtype=torch.int64, device=index_device + ) + final_state_indices_torch = torch.arange( + B + 1, 2 * B + 1, dtype=torch.int64, device=index_device + ) + initial_state_indices = TestTensor.from_torch( + initial_state_indices_torch, InfiniDtype.I64, device + ) + final_state_indices = TestTensor.from_torch( + final_state_indices_torch, InfiniDtype.I64, device + ) + + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + + gathered_initial_state = ( + initial_state_pool.torch_tensor()[initial_state_indices_torch] + .transpose(-1, -2) + .contiguous() + ) + ans_out, ans_final_state = ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + gathered_initial_state, + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + ans_initial_state_pool = initial_state_pool.torch_tensor().clone() + ans_initial_state_pool[final_state_indices_torch] = ans_final_state.transpose( + -1, -2 + ).contiguous() + + if sync: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRecurrentGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + initial_state_pool.descriptor, + ctypes.c_void_p(0), + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + initial_state_indices.descriptor, + final_state_indices.descriptor, + ctypes.c_bool(use_qk_l2norm), + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + q.destroy_desc() + k.destroy_desc() + v.destroy_desc() + g.destroy_desc() + beta.destroy_desc() + initial_state_pool.destroy_desc() + initial_state_indices.destroy_desc() + final_state_indices.destroy_desc() + out.destroy_desc() + + check_error( + LIBINFINIOP.infiniopRecurrentGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + initial_state_pool.data(), + None, + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + initial_state_indices.data(), + final_state_indices.data(), + None, + ) + ) + + if sync: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug( + initial_state_pool.actual_tensor(), + ans_initial_state_pool, + atol=atol, + rtol=rtol, + ) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + initial_state_pool.actual_tensor(), ans_initial_state_pool, atol=atol, rtol=rtol + ) + check_error( LIBINFINIOP.infiniopDestroyRecurrentGatedDeltaRuleDescriptor(descriptor) ) @@ -281,7 +445,6 @@ def lib_recurrent_gated_delta_rule(): if __name__ == "__main__": args = get_args() - # Configure testing options from command line arguments DEBUG = args.debug PROFILE = args.profile NUM_PRERUN = args.num_prerun @@ -289,5 +452,6 @@ def lib_recurrent_gated_delta_rule(): for device in get_test_devices(args): test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + test_operator(device, test_indexed_pool_inplace, _TEST_CASES_, _TENSOR_DTYPES) print("\033[92mTest passed!\033[0m") From 308d158ccd78e333e3ea19bb055982804563de9c Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 25 Jun 2026 09:08:35 +0000 Subject: [PATCH 3/3] issue/1210 fix format --- .../nvidia/chunk_gated_delta_rule_nvidia.cuh | 2 +- .../nvidia/recurrent_gated_delta_rule_nvidia.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh index b811a7185..5dde15f7d 100644 --- a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh @@ -5,4 +5,4 @@ DESCRIPTOR(nvidia) -#endif // __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ \ No newline at end of file +#endif // __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh index 61bfb3051..82236a31d 100644 --- a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh @@ -5,4 +5,4 @@ DESCRIPTOR(nvidia) -#endif // __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ \ No newline at end of file +#endif // __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__