diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 061bfbfd1..17ef59d3d 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -20,6 +20,7 @@ #include "ops/blas_dot.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" +#include "ops/chunk_gated_delta_rule.hpp" #include "ops/conv2d.hpp" #include "ops/cross_entropy.hpp" #include "ops/deepseek_moe.hpp" @@ -46,6 +47,7 @@ #include "ops/random_sample.hpp" #include "ops/rearrange.hpp" #include "ops/reciprocal.hpp" +#include "ops/recurrent_gated_delta_rule.hpp" #include "ops/relu.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" diff --git a/include/infinicore/ops/chunk_gated_delta_rule.hpp b/include/infinicore/ops/chunk_gated_delta_rule.hpp new file mode 100644 index 000000000..5102836e0 --- /dev/null +++ b/include/infinicore/ops/chunk_gated_delta_rule.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "infinicore.h" + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(ChunkGatedDeltaRule, + Tensor, + Tensor, + std::optional, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + std::optional, + std::optional, + std::optional, + bool, + size_t); + +__export Tensor chunk_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + std::optional cu_seqlens = std::nullopt, + std::optional initial_state_indices = std::nullopt, + std::optional final_state_indices = std::nullopt, + bool use_qk_l2norm = false, + size_t chunk_size = 64); + +__export void chunk_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm = false, + size_t chunk_size = 64); + +} // namespace infinicore::op diff --git a/include/infinicore/ops/recurrent_gated_delta_rule.hpp b/include/infinicore/ops/recurrent_gated_delta_rule.hpp new file mode 100644 index 000000000..7d837c67b --- /dev/null +++ b/include/infinicore/ops/recurrent_gated_delta_rule.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "infinicore.h" + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(RecurrentGatedDeltaRule, + Tensor, + Tensor, + std::optional, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + const Tensor &, + std::optional, + std::optional, + bool); + +__export Tensor recurrent_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + const Tensor &initial_state, + bool use_qk_l2norm = false); + +__export Tensor recurrent_gated_delta_rule_indexed(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + const Tensor &initial_state_indices, + const Tensor &final_state_indices, + bool use_qk_l2norm = false); + +__export void recurrent_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm = false); + +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 5f58805b9..c1482552f 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -31,6 +31,7 @@ #include "infiniop/ops/broadcast_to.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/cdist.h" +#include "infiniop/ops/chunk_gated_delta_rule.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/cross_entropy.h" @@ -100,6 +101,7 @@ #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/reciprocal.h" +#include "infiniop/ops/recurrent_gated_delta_rule.h" #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rope.h" diff --git a/include/infiniop/ops/chunk_gated_delta_rule.h b/include/infiniop/ops/chunk_gated_delta_rule.h new file mode 100644 index 000000000..a9a9c74aa --- /dev/null +++ b/include/infiniop/ops/chunk_gated_delta_rule.h @@ -0,0 +1,49 @@ +#ifndef __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__ +#define __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopChunkGatedDeltaRuleDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv] + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t final_state_desc, // null when final_state_indices_desc is provided + infiniopTensorDescriptor_t q_desc, // padded: [B, T, Hk, Dk]; varlen: [1, total_tokens, Hk, Dk] + infiniopTensorDescriptor_t k_desc, // same shape as q + infiniopTensorDescriptor_t v_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv] + infiniopTensorDescriptor_t g_desc, // padded: [B, T, Hv]; varlen: [1, total_tokens, Hv] + infiniopTensorDescriptor_t beta_desc, // same shape/dtype as g + infiniopTensorDescriptor_t cu_seqlens_desc, // nullable; [B + 1], int32/int64 + infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed state-pool reads + infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool + bool use_qk_l2norm, + size_t chunk_size); + +__INFINI_C __export infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopChunkGatedDeltaRule( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( + infiniopChunkGatedDeltaRuleDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/recurrent_gated_delta_rule.h b/include/infiniop/ops/recurrent_gated_delta_rule.h new file mode 100644 index 000000000..2865cc1f5 --- /dev/null +++ b/include/infiniop/ops/recurrent_gated_delta_rule.h @@ -0,0 +1,46 @@ +#ifndef __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__ +#define __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRecurrentGatedDeltaRuleDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, // [B, T, Hv, Dv], T must be 1; last dim contiguous + infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk] + infiniopTensorDescriptor_t final_state_desc, // legacy/indexed out-of-place final state; null when final_state_indices_desc is provided + infiniopTensorDescriptor_t q_desc, // [B, T, Hk, Dk], T must be 1; last dim contiguous + infiniopTensorDescriptor_t k_desc, // [B, T, Hk, Dk], same shape as q; last dim contiguous + infiniopTensorDescriptor_t v_desc, // [B, T, Hv, Dv], Hv must be a multiple of Hk; last dim contiguous + infiniopTensorDescriptor_t g_desc, // [B, T, Hv]; may have a different fp dtype from q/k/v/out/state + infiniopTensorDescriptor_t beta_desc, // [B, T, Hv]; same dtype as g + infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed pool mode + infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool + bool use_qk_l2norm); + +__INFINI_C __export infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopRecurrentGatedDeltaRule( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc); + +#endif diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index f3b9a0a2b..3c11242cc 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -7,6 +7,7 @@ from .avg_pool1d import avg_pool1d from .binary_cross_entropy_with_logits import binary_cross_entropy_with_logits from .causal_softmax import causal_softmax +from .chunk_gated_delta_rule import chunk_gated_delta_rule from .embedding import embedding from .flash_attention import flash_attention from .gaussian_nll_loss import gaussian_nll_loss @@ -23,6 +24,7 @@ from .pad import pad from .prelu import prelu from .random_sample import random_sample +from .recurrent_gated_delta_rule import recurrent_gated_delta_rule from .relu6 import relu6 from .rms_norm import rms_norm from .rope import RopeAlgo, rope @@ -44,6 +46,7 @@ "conv2d", "adaptive_max_pool1d", "causal_softmax", + "chunk_gated_delta_rule", "embedding", "flash_attention", "gaussian_nll_loss", @@ -56,6 +59,7 @@ "prelu", "relu6", "rms_norm", + "recurrent_gated_delta_rule", "sigmoid", "silu", "smooth_l1_loss", diff --git a/python/infinicore/nn/functional/chunk_gated_delta_rule.py b/python/infinicore/nn/functional/chunk_gated_delta_rule.py new file mode 100644 index 000000000..dbcf15e75 --- /dev/null +++ b/python/infinicore/nn/functional/chunk_gated_delta_rule.py @@ -0,0 +1,66 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def chunk_gated_delta_rule( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + initial_state: Tensor, + *, + cu_seqlens: Tensor | None = None, + initial_state_indices: Tensor | None = None, + final_state_indices: Tensor | None = None, + use_qk_l2norm: bool = False, + chunk_size: int = 64, +) -> Tensor: + """Run chunk gated delta rule and return only ``out``. + + Padded mode shapes: + q, k: ``[B, T, Hk, Dk]`` + v, out: ``[B, T, Hv, Dv]`` + g, beta: ``[B, T, Hv]`` + initial_state: ``[B, Hv, Dk, Dv]`` + + Continuous-batch mode shapes: + Pass ``cu_seqlens`` with shape ``[B + 1]`` and dtype int32/int64. + q, k: ``[1, total_tokens, Hk, Dk]`` + v, out: ``[1, total_tokens, Hv, Dv]`` + g, beta: ``[1, total_tokens, Hv]`` + + Indexed state-pool mode: + initial_state is ``[pool_size, Hv, Dv, Dk]``. + ``initial_state_indices`` and ``final_state_indices`` are both ``[B]`` + int32/int64 tensors. The final state is written in-place into + ``initial_state[final_state_indices]`` and no final state tensor is + returned. + + Notes: + ``Hv`` must be a multiple of ``Hk``. q/k/v/out may be strided in the + first three dimensions, but the last dimension must be contiguous. + g and beta may use a different floating dtype from q/k/v/state. + """ + if (initial_state_indices is None) != (final_state_indices is None): + raise ValueError( + "initial_state_indices and final_state_indices must be provided together" + ) + + return Tensor( + _infinicore.chunk_gated_delta_rule( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + None if cu_seqlens is None else cu_seqlens._underlying, + None + if initial_state_indices is None + else initial_state_indices._underlying, + None if final_state_indices is None else final_state_indices._underlying, + use_qk_l2norm, + chunk_size, + ) + ) diff --git a/python/infinicore/nn/functional/recurrent_gated_delta_rule.py b/python/infinicore/nn/functional/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..323701e93 --- /dev/null +++ b/python/infinicore/nn/functional/recurrent_gated_delta_rule.py @@ -0,0 +1,47 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def recurrent_gated_delta_rule( + q: Tensor, + k: Tensor, + v: Tensor, + g: Tensor, + beta: Tensor, + initial_state: Tensor, + *, + initial_state_indices: Tensor | None = None, + final_state_indices: Tensor | None = None, + use_qk_l2norm: bool = False, +) -> Tensor: + if initial_state_indices is None and final_state_indices is None: + return Tensor( + _infinicore.recurrent_gated_delta_rule( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + use_qk_l2norm, + ) + ) + + if initial_state_indices is None or final_state_indices is None: + raise ValueError( + "initial_state_indices and final_state_indices must be provided together" + ) + + return Tensor( + _infinicore.recurrent_gated_delta_rule_indexed( + q._underlying, + k._underlying, + v._underlying, + g._underlying, + beta._underlying, + initial_state._underlying, + initial_state_indices._underlying, + final_state_indices._underlying, + use_qk_l2norm, + ) + ) diff --git a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc new file mode 100644 index 000000000..2fccef444 --- /dev/null +++ b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.cc @@ -0,0 +1,172 @@ +#include "infinicore/ops/chunk_gated_delta_rule.hpp" +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(ChunkGatedDeltaRule); + +ChunkGatedDeltaRule::ChunkGatedDeltaRule(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state, q, k, v, g, beta); + if (final_state.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state.value()); + } + if (cu_seqlens.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, cu_seqlens.value()); + } + if (initial_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state_indices.value()); + } + if (final_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state_indices.value()); + } + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +void ChunkGatedDeltaRule::execute(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(ChunkGatedDeltaRule, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +static void check_4d_sequence_tensor(const Tensor &x, const char *name) { + if (x->shape().size() != 4) { + throw std::runtime_error(std::string("chunk_gated_delta_rule expects ") + name + " with shape [B, T, H, D] or [1, total_tokens, H, D]"); + } +} + +static Shape chunk_final_state_shape(const Tensor &q, + const Tensor &v, + const Tensor &initial_state, + std::optional cu_seqlens, + std::optional initial_state_indices) { + const auto &q_shape = q->shape(); + const auto &v_shape = v->shape(); + size_t B = cu_seqlens.has_value() ? cu_seqlens.value()->shape()[0] - 1 : v_shape[0]; + size_t Hv = v_shape[2]; + size_t Dk = q_shape[3]; + size_t Dv = v_shape[3]; + if (initial_state_indices.has_value()) { + return {B, Hv, Dv, Dk}; + } + return {B, Hv, Dk, Dv}; +} + +Tensor chunk_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + check_4d_sequence_tensor(q, "q"); + check_4d_sequence_tensor(k, "k"); + check_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(v->shape(), v->dtype(), v->device()); + std::optional final_state = std::nullopt; + if (!final_state_indices.has_value()) { + final_state = Tensor::empty(chunk_final_state_shape(q, v, initial_state, cu_seqlens, initial_state_indices), + initial_state->dtype(), + initial_state->device()); + } + chunk_gated_delta_rule_(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); + return out; +} + +void chunk_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + check_4d_sequence_tensor(q, "q"); + check_4d_sequence_tensor(k, "k"); + check_4d_sequence_tensor(v, "v"); + ChunkGatedDeltaRule::execute(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc new file mode 100644 index 000000000..96a2b3b44 --- /dev/null +++ b/src/infinicore/ops/chunk_gated_delta_rule/chunk_gated_delta_rule_infiniop.cc @@ -0,0 +1,108 @@ +#include "infinicore/ops/chunk_gated_delta_rule.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::chunk_gated_delta_rule_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, ChunkGatedDeltaRule, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, initial_state, q, k, v, g, beta; + std::optional final_state; + std::optional cu_seqlens; + std::optional initial_state_indices; + std::optional final_state_indices; +}; + +void *plan(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional cu_seqlens, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm, + size_t chunk_size) { + size_t seed = hash_combine(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, ChunkGatedDeltaRule, + seed, + out->desc(), + initial_state->desc(), + final_state.has_value() ? final_state.value()->desc() : nullptr, + q->desc(), + k->desc(), + v->desc(), + g->desc(), + beta->desc(), + cu_seqlens.has_value() ? cu_seqlens.value()->desc() : nullptr, + initial_state_indices.has_value() ? initial_state_indices.value()->desc() : nullptr, + final_state_indices.has_value() ? final_state_indices.value()->desc() : nullptr, + use_qk_l2norm, + chunk_size); + + INFINIOP_WORKSPACE_TENSOR(workspace, ChunkGatedDeltaRule, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(initial_state), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(g), + graph::GraphTensor(beta), + final_state.has_value() ? std::optional(graph::GraphTensor(final_state.value())) : std::nullopt, + cu_seqlens.has_value() ? std::optional(graph::GraphTensor(cu_seqlens.value())) : std::nullopt, + initial_state_indices.has_value() ? std::optional(graph::GraphTensor(initial_state_indices.value())) : std::nullopt, + final_state_indices.has_value() ? std::optional(graph::GraphTensor(final_state_indices.value())) : std::nullopt}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopChunkGatedDeltaRule( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->out->data(), + planned->initial_state->data(), + planned->final_state.has_value() ? planned->final_state.value()->data() : nullptr, + planned->q->data(), + planned->k->data(), + planned->v->data(), + planned->g->data(), + planned->beta->data(), + planned->cu_seqlens.has_value() ? planned->cu_seqlens.value()->data() : nullptr, + planned->initial_state_indices.has_value() ? planned->initial_state_indices.value()->data() : nullptr, + planned->final_state_indices.has_value() ? planned->final_state_indices.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(ChunkGatedDeltaRule, &plan, &run, &cleanup); + +} // namespace infinicore::op::chunk_gated_delta_rule_impl::infiniop diff --git a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc new file mode 100644 index 000000000..4fd7bb69a --- /dev/null +++ b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.cc @@ -0,0 +1,166 @@ +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RecurrentGatedDeltaRule); + +RecurrentGatedDeltaRule::RecurrentGatedDeltaRule(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state, q, k, v, g, beta); + if (final_state.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state.value()); + } + if (initial_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, initial_state_indices.value()); + } + if (final_state_indices.has_value()) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, final_state_indices.value()); + } + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +void RecurrentGatedDeltaRule::execute(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RecurrentGatedDeltaRule, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +static Tensor ensure_4d_sequence_tensor(const Tensor &x, const char *name) { + if (x->shape().size() == 4) { + return x; + } + if (x->shape().size() == 3) { + return x->unsqueeze(1); + } + throw std::runtime_error(std::string("recurrent_gated_delta_rule expects ") + name + " with shape [B, T, H, D] or [B, H, D]"); +} + +static Shape recurrent_output_shape(const Tensor &v) { + const auto &shape = v->shape(); + return {shape[0], shape[1], shape[2], shape[3]}; +} + +Tensor recurrent_gated_delta_rule(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + const Tensor &initial_state, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device()); + Shape final_state_shape = {v4->shape()[0], v4->shape()[2], q4->shape()[3], v4->shape()[3]}; + auto final_state = Tensor::empty(final_state_shape, initial_state->dtype(), initial_state->device()); + recurrent_gated_delta_rule_(out, + initial_state, + final_state, + q4, + k4, + v4, + g, + beta, + std::nullopt, + std::nullopt, + use_qk_l2norm); + return out; +} + +Tensor recurrent_gated_delta_rule_indexed(const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + Tensor initial_state, + const Tensor &initial_state_indices, + const Tensor &final_state_indices, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + auto out = Tensor::empty(recurrent_output_shape(v4), v4->dtype(), v4->device()); + recurrent_gated_delta_rule_(out, + initial_state, + std::nullopt, + q4, + k4, + v4, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); + return out; +} + +void recurrent_gated_delta_rule_(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + Tensor q4 = ensure_4d_sequence_tensor(q, "q"); + Tensor k4 = ensure_4d_sequence_tensor(k, "k"); + Tensor v4 = ensure_4d_sequence_tensor(v, "v"); + RecurrentGatedDeltaRule::execute(out, + initial_state, + final_state, + q4, + k4, + v4, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc new file mode 100644 index 000000000..ba52b2c40 --- /dev/null +++ b/src/infinicore/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule_infiniop.cc @@ -0,0 +1,99 @@ +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::recurrent_gated_delta_rule_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RecurrentGatedDeltaRule, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, initial_state, q, k, v, g, beta; + std::optional final_state; + std::optional initial_state_indices; + std::optional final_state_indices; +}; + +void *plan(Tensor out, + Tensor initial_state, + std::optional final_state, + const Tensor &q, + const Tensor &k, + const Tensor &v, + const Tensor &g, + const Tensor &beta, + std::optional initial_state_indices, + std::optional final_state_indices, + bool use_qk_l2norm) { + size_t seed = hash_combine(out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + initial_state_indices, + final_state_indices, + use_qk_l2norm); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RecurrentGatedDeltaRule, + seed, + out->desc(), + initial_state->desc(), + final_state.has_value() ? final_state.value()->desc() : nullptr, + q->desc(), + k->desc(), + v->desc(), + g->desc(), + beta->desc(), + initial_state_indices.has_value() ? initial_state_indices.value()->desc() : nullptr, + final_state_indices.has_value() ? final_state_indices.value()->desc() : nullptr, + use_qk_l2norm); + + INFINIOP_WORKSPACE_TENSOR(workspace, RecurrentGatedDeltaRule, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(initial_state), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(g), + graph::GraphTensor(beta), + final_state.has_value() ? std::optional(graph::GraphTensor(final_state.value())) : std::nullopt, + initial_state_indices.has_value() ? std::optional(graph::GraphTensor(initial_state_indices.value())) : std::nullopt, + final_state_indices.has_value() ? std::optional(graph::GraphTensor(final_state_indices.value())) : std::nullopt}; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopRecurrentGatedDeltaRule( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->out->data(), + planned->initial_state->data(), + planned->final_state.has_value() ? planned->final_state.value()->data() : nullptr, + planned->q->data(), + planned->k->data(), + planned->v->data(), + planned->g->data(), + planned->beta->data(), + planned->initial_state_indices.has_value() ? planned->initial_state_indices.value()->data() : nullptr, + planned->final_state_indices.has_value() ? planned->final_state_indices.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RecurrentGatedDeltaRule, &plan, &run, &cleanup); + +} // namespace infinicore::op::recurrent_gated_delta_rule_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index ccf473ba8..435c1c119 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -34,6 +34,7 @@ #include "ops/cat.hpp" #include "ops/causal_softmax.hpp" #include "ops/cdist.hpp" +#include "ops/chunk_gated_delta_rule.hpp" #include "ops/conv2d.hpp" #include "ops/cross_entropy.hpp" #include "ops/diff.hpp" @@ -89,6 +90,7 @@ #include "ops/random_sample.hpp" #include "ops/rearrange.hpp" #include "ops/reciprocal.hpp" +#include "ops/recurrent_gated_delta_rule.hpp" #include "ops/relu6.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" @@ -241,7 +243,9 @@ inline void bind(py::module &m) { bind_addcmul(m); bind_cdist(m); bind_binary_cross_entropy_with_logits(m); + bind_chunk_gated_delta_rule(m); bind_reciprocal(m); + bind_recurrent_gated_delta_rule(m); bind_upsample_bilinear(m); bind_kthvalue(m); bind_ldexp(m); diff --git a/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp new file mode 100644 index 000000000..6d8bf5cf2 --- /dev/null +++ b/src/infinicore/pybind11/ops/chunk_gated_delta_rule.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include "infinicore/ops/chunk_gated_delta_rule.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_chunk_gated_delta_rule(py::module &m) { + m.def("chunk_gated_delta_rule", + &op::chunk_gated_delta_rule, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("cu_seqlens") = std::nullopt, + py::arg("initial_state_indices") = std::nullopt, + py::arg("final_state_indices") = std::nullopt, + py::arg("use_qk_l2norm") = false, + py::arg("chunk_size") = 64, + R"doc(Chunk gated delta rule. Returns out only. + +Padded mode: + q/k: [B, T, Hk, Dk], v/out: [B, T, Hv, Dv], g/beta: [B, T, Hv], + initial_state: [B, Hv, Dk, Dv]. + +Continuous-batch mode: + pass cu_seqlens [B + 1]; q/k: [1, total_tokens, Hk, Dk], + v/out: [1, total_tokens, Hv, Dv], g/beta: [1, total_tokens, Hv]. + +Indexed pool mode: + initial_state is [pool_size, Hv, Dv, Dk]. Provide both initial_state_indices + and final_state_indices [B]; final state is written in-place to initial_state. +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp b/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp new file mode 100644 index 000000000..127b62042 --- /dev/null +++ b/src/infinicore/pybind11/ops/recurrent_gated_delta_rule.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "infinicore/ops/recurrent_gated_delta_rule.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_recurrent_gated_delta_rule(py::module &m) { + m.def("recurrent_gated_delta_rule", + &op::recurrent_gated_delta_rule, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("use_qk_l2norm") = false, + R"doc(Recurrent gated delta rule. Returns out only.)doc"); + + m.def("recurrent_gated_delta_rule_indexed", + &op::recurrent_gated_delta_rule_indexed, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("g"), + py::arg("beta"), + py::arg("initial_state"), + py::arg("initial_state_indices"), + py::arg("final_state_indices"), + py::arg("use_qk_l2norm") = false, + R"doc(Recurrent gated delta rule with indexed in-place state pool. Returns out only.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h new file mode 100644 index 000000000..808e0e52c --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/chunk_gated_delta_rule.h @@ -0,0 +1,62 @@ +// infiniop/ops/chunk_gated_delta_rule.h + +#ifndef __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ +#define __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::chunk_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + ChunkGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + ChunkGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t initial_state_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + infiniopTensorDescriptor_t cu_seqlens_desc, \ + infiniopTensorDescriptor_t initial_state_indices_desc, \ + infiniopTensorDescriptor_t final_state_indices_desc, \ + bool use_qk_l2norm, \ + size_t chunk_size); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *initial_state, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, const void *cu_seqlens, \ + const void *initial_state_indices, \ + const void *final_state_indices, \ + void *stream) const; \ + }; \ + } + +#endif // __INFINIOP_CHUNK_GATED_DELTA_RULE_H__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh new file mode 100644 index 000000000..cc167b6c7 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/cuda/kernel.cuh @@ -0,0 +1,667 @@ +#ifndef __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ +#define __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ + +#include +#include +#include + +__device__ inline int64_t loadOptionalIndex( + const void *indices, + bool is_i64, + int idx, + int fallback) { + if (indices == nullptr) { + return static_cast(fallback); + } + return is_i64 + ? static_cast(indices)[idx] + : static_cast(static_cast(indices)[idx]); +} + +template +__device__ inline float loadAsFloat(const T *ptr, ptrdiff_t offset) { + return static_cast(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { + return __half2float(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { + return __bfloat162float(ptr[offset]); +} + +#define CGDR_FOR(idx, n) \ + for (int idx = threadIdx.x; idx < static_cast(n); idx += blockDim.x) + +template +__device__ Tcompute blockReduceSum(Tcompute v) { + __shared__ Tcompute smem[NUM_THREADS]; + smem[threadIdx.x] = v; + __syncthreads(); + + for (int s = NUM_THREADS / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + smem[threadIdx.x] += smem[threadIdx.x + s]; + } + __syncthreads(); + } + return smem[0]; +} + +template < + typename Tdata, + typename Tgate, + typename Tcompute, + size_t Dk, + size_t Dv, + size_t NUM_THREADS> +__device__ void chunkGatedDeltaRuleKernel( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + + const int batch_idx = blockIdx.x; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); + + if (key_head_idx >= static_cast(Hk)) { + return; + } + + int64_t token_begin = 0; + int64_t token_end = static_cast(T); + + if (has_cu_seqlens) { + token_begin = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx, 0); + token_end = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx + 1, 0); + if (token_begin < 0 || token_end < token_begin || token_end > static_cast(T)) { + return; + } + } + + int64_t read_slot = batch_idx; + int64_t write_slot = batch_idx; + + if (indexed_state_pool) { + read_slot = loadOptionalIndex( + initial_state_indices, + initial_state_indices_i64, + batch_idx, + batch_idx); + + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadOptionalIndex( + final_state_indices, + final_state_indices_i64, + batch_idx, + batch_idx); + + if (read_slot < 0 || write_slot < 0 || read_slot >= static_cast(pool_size) || write_slot >= static_cast(pool_size)) { + return; + } + } + + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + + Tdata *final_state_target = nullptr; + ptrdiff_t final_base = 0; + + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } + + const ptrdiff_t per_block_workspace = static_cast(Dk * Dv) + static_cast(chunk_size * Dk) * 3 + static_cast(chunk_size * Dv) * 3 + static_cast(chunk_size * chunk_size) + static_cast(chunk_size) * 3; + + const ptrdiff_t workspace_block = (static_cast(batch_idx) * gridDim.y + static_cast(value_head_idx)) * per_block_workspace; + + Tcompute *state_local = state_workspace + workspace_block; + Tcompute *q_buf = state_local + Dk * Dv; + Tcompute *k_buf = q_buf + chunk_size * Dk; + Tcompute *k_cumdecay = k_buf + chunk_size * Dk; + Tcompute *v_beta = k_cumdecay + chunk_size * Dk; + Tcompute *v_mid = v_beta + chunk_size * Dv; + Tcompute *v_new = v_mid + chunk_size * Dv; + Tcompute *attn = v_new + chunk_size * Dv; + Tcompute *g_cum = attn + chunk_size * chunk_size; + Tcompute *beta_buf = g_cum + chunk_size; + Tcompute *row_buf = beta_buf + chunk_size; + + const int token_batch = has_cu_seqlens ? 0 : batch_idx; + const Tcompute scale = rsqrtf(static_cast(Dk)); + + // Load initial state. + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t read_idx = indexed_state_pool + ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 + : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + + state_local[i] = static_cast( + loadAsFloat(initial_state, read_idx)); + } + __syncthreads(); + + for (int64_t chunk_begin = token_begin; + chunk_begin < token_end; + chunk_begin += static_cast(chunk_size)) { + + const int64_t remaining = token_end - chunk_begin; + const int actual_len = static_cast( + remaining < static_cast(chunk_size) + ? remaining + : static_cast(chunk_size)); + + // Load beta and cumulative gate. + if (threadIdx.x == 0) { + Tcompute running_g = 0; + for (int t = 0; t < static_cast(chunk_size); ++t) { + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t gate_offset = static_cast(token_batch) * g_s0 + static_cast(token_idx) * g_s1 + static_cast(value_head_idx) * g_s2; + ptrdiff_t beta_offset = static_cast(token_batch) * beta_s0 + static_cast(token_idx) * beta_s1 + static_cast(value_head_idx) * beta_s2; + + running_g += static_cast( + loadAsFloat(g, gate_offset)); + beta_buf[t] = static_cast( + loadAsFloat(beta, beta_offset)); + g_cum[t] = running_g; + } else { + beta_buf[t] = 0; + g_cum[t] = running_g; + } + } + } + __syncthreads(); + + // Load q/k/v_beta. + CGDR_FOR(x, chunk_size * Dk) { + int t = x / Dk; + int dk = x % Dk; + + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t q_base = static_cast(token_batch) * q_s0 + static_cast(token_idx) * q_s1 + static_cast(key_head_idx) * q_s2; + ptrdiff_t k_base = static_cast(token_batch) * k_s0 + static_cast(token_idx) * k_s1 + static_cast(key_head_idx) * k_s2; + + q_buf[x] = static_cast(loadAsFloat(q, q_base + dk)) * scale; + k_buf[x] = static_cast(loadAsFloat(k, k_base + dk)); + } else { + q_buf[x] = 0; + k_buf[x] = 0; + } + } + + CGDR_FOR(x, chunk_size * Dv) { + int t = x / Dv; + int dv = x % Dv; + + if (t < actual_len) { + int64_t token_idx = chunk_begin + t; + ptrdiff_t v_base = static_cast(token_batch) * v_s0 + static_cast(token_idx) * v_s1 + static_cast(value_head_idx) * v_s2; + + v_beta[x] = static_cast(loadAsFloat(v, v_base + dv)) * beta_buf[t]; + } else { + v_beta[x] = 0; + } + } + __syncthreads(); + + // Optional q/k L2 norm. + if (use_qk_l2norm) { + for (int t = 0; t < static_cast(chunk_size); ++t) { + Tcompute q_sum = 0; + Tcompute k_sum = 0; + + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_sum += q_buf[t * Dk + dk] * q_buf[t * Dk + dk]; + k_sum += k_buf[t * Dk + dk] * k_buf[t * Dk + dk]; + } + + q_sum = blockReduceSum(q_sum); + k_sum = blockReduceSum(k_sum); + + Tcompute q_norm = rsqrtf(q_sum / (scale * scale) + 1e-6f); + Tcompute k_norm = rsqrtf(k_sum + 1e-6f); + + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_buf[t * Dk + dk] *= q_norm; + k_buf[t * Dk + dk] *= k_norm; + } + __syncthreads(); + } + } + + // Build lower-triangular attn. + CGDR_FOR(x, chunk_size * chunk_size) { + int i = x / chunk_size; + int j = x % chunk_size; + + if (j < i) { + Tcompute dot = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + dot += k_buf[i * Dk + dk] * beta_buf[i] * k_buf[j * Dk + dk]; + } + attn[x] = -dot * expf(g_cum[i] - g_cum[j]); + } else { + attn[x] = 0; + } + } + __syncthreads(); + + // Triangular solve-like correction. + // Sequential in i, parallel in j. + for (int i = 1; i < static_cast(chunk_size); ++i) { + CGDR_FOR(m, chunk_size) { + row_buf[m] = m < i ? attn[i * chunk_size + m] : 0; + } + __syncthreads(); + + for (int j = threadIdx.x; j < i; j += blockDim.x) { + Tcompute correction = 0; + for (int m = 0; m < i; ++m) { + correction += row_buf[m] * attn[m * chunk_size + j]; + } + attn[i * chunk_size + j] = row_buf[j] + correction; + } + __syncthreads(); + } + + CGDR_FOR(i, chunk_size) { + attn[i * chunk_size + i] = 1; + } + __syncthreads(); + + // v_mid = attn @ v_beta. + CGDR_FOR(x, chunk_size * Dv) { + int i = x / Dv; + int dv = x % Dv; + + Tcompute sum = 0; + for (int j = 0; j < static_cast(chunk_size); ++j) { + sum += attn[i * chunk_size + j] * v_beta[j * Dv + dv]; + } + v_mid[x] = sum; + } + + // k_cumdecay = attn @ (k * beta * exp(g)). + CGDR_FOR(x, chunk_size * Dk) { + int i = x / Dk; + int dk = x % Dk; + + Tcompute sum = 0; + for (int j = 0; j < static_cast(chunk_size); ++j) { + sum += attn[i * chunk_size + j] * k_buf[j * Dk + dk] * beta_buf[j] * expf(g_cum[j]); + } + k_cumdecay[x] = sum; + } + __syncthreads(); + + // v_new = v_mid - k_cumdecay @ state. + CGDR_FOR(x, chunk_size * Dv) { + int i = x / Dv; + int dv = x % Dv; + + Tcompute v_prime = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + v_prime += k_cumdecay[i * Dk + dk] * state_local[dk * Dv + dv]; + } + v_new[x] = v_mid[x] - v_prime; + } + __syncthreads(); + + // Output. + CGDR_FOR(x, actual_len * Dv) { + int i = x / Dv; + int dv = x % Dv; + + int64_t token_idx = chunk_begin + i; + ptrdiff_t out_base = static_cast(token_batch) * out_s0 + static_cast(token_idx) * out_s1 + static_cast(value_head_idx) * out_s2; + + Tcompute out_val = 0; + Tcompute q_decay = expf(g_cum[i]); + + for (int dk = 0; dk < static_cast(Dk); ++dk) { + out_val += q_buf[i * Dk + dk] * q_decay * state_local[dk * Dv + dv]; + } + + for (int j = 0; j <= i; ++j) { + Tcompute qk_attn = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + qk_attn += q_buf[i * Dk + dk] * k_buf[j * Dk + dk]; + } + + qk_attn *= expf(g_cum[i] - g_cum[j]); + out_val += qk_attn * v_new[j * Dv + dv]; + } + + out[out_base + dv] = static_cast(out_val); + } + __syncthreads(); + + // Update state. + const Tcompute last_decay = expf(g_cum[chunk_size - 1]); + + CGDR_FOR(x, Dk * Dv) { + int dk = x / Dv; + int dv = x % Dv; + + Tcompute next_state = state_local[x] * last_decay; + + for (int i = 0; i < static_cast(chunk_size); ++i) { + next_state += k_buf[i * Dk + dk] * expf(g_cum[chunk_size - 1] - g_cum[i]) * v_new[i * Dv + dv]; + } + + state_local[x] = next_state; + } + __syncthreads(); + } + + // Store final state. + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t write_idx; + if (indexed_state_pool) { + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + + write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; + } else { + write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; + } + + final_state_target[write_idx] = static_cast(state_local[i]); + } +} + +template < + typename Tdata, + typename Tgate, + typename Tcompute, + size_t Dk, + size_t Dv, + size_t NUM_THREADS> +__device__ void chunkGatedDeltaRuleRecurrentKernel( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + + const int batch_idx = blockIdx.x; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); + + if (key_head_idx >= static_cast(Hk)) { + return; + } + + int64_t token_begin = 0; + int64_t token_end = static_cast(T); + + if (has_cu_seqlens) { + token_begin = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx, 0); + token_end = loadOptionalIndex(cu_seqlens, cu_seqlens_i64, batch_idx + 1, 0); + if (token_begin < 0 || token_end < token_begin || token_end > static_cast(T)) { + return; + } + } + + int64_t read_slot = batch_idx; + int64_t write_slot = batch_idx; + + if (indexed_state_pool) { + read_slot = loadOptionalIndex( + initial_state_indices, + initial_state_indices_i64, + batch_idx, + batch_idx); + + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadOptionalIndex( + final_state_indices, + final_state_indices_i64, + batch_idx, + batch_idx); + + if (read_slot < 0 || write_slot < 0 || read_slot >= static_cast(pool_size) || write_slot >= static_cast(pool_size)) { + return; + } + } + + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + + Tdata *final_state_target = nullptr; + ptrdiff_t final_base = 0; + + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } + + const ptrdiff_t workspace_block = (static_cast(batch_idx) * gridDim.y + static_cast(value_head_idx)) * static_cast(Dk * Dv); + + Tcompute *state_local = state_workspace + workspace_block; + __shared__ Tcompute q_vec[Dk]; + __shared__ Tcompute k_vec[Dk]; + __shared__ Tcompute v_new[Dv]; + __shared__ Tcompute scalar_buf[3]; + + const int token_batch = has_cu_seqlens ? 0 : batch_idx; + const Tcompute scale = rsqrtf(static_cast(Dk)); + + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t read_idx = indexed_state_pool + ? initial_base + static_cast(dv) * initial_s2 + static_cast(dk) * initial_s3 + : initial_base + static_cast(dk) * initial_s2 + static_cast(dv) * initial_s3; + + state_local[i] = static_cast( + loadAsFloat(initial_state, read_idx)); + } + __syncthreads(); + + for (int64_t token_idx = token_begin; token_idx < token_end; ++token_idx) { + ptrdiff_t q_base = static_cast(token_batch) * q_s0 + static_cast(token_idx) * q_s1 + static_cast(key_head_idx) * q_s2; + ptrdiff_t k_base = static_cast(token_batch) * k_s0 + static_cast(token_idx) * k_s1 + static_cast(key_head_idx) * k_s2; + + Tcompute q_sum = 0; + Tcompute k_sum = 0; + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + Tcompute q_raw = static_cast(loadAsFloat(q, q_base + dk)); + Tcompute k_raw = static_cast(loadAsFloat(k, k_base + dk)); + q_vec[dk] = q_raw; + k_vec[dk] = k_raw; + q_sum += q_raw * q_raw; + k_sum += k_raw * k_raw; + } + + q_sum = blockReduceSum(q_sum); + k_sum = blockReduceSum(k_sum); + + if (threadIdx.x == 0) { + ptrdiff_t gate_offset = static_cast(token_batch) * g_s0 + static_cast(token_idx) * g_s1 + static_cast(value_head_idx) * g_s2; + ptrdiff_t beta_offset = static_cast(token_batch) * beta_s0 + static_cast(token_idx) * beta_s1 + static_cast(value_head_idx) * beta_s2; + + scalar_buf[0] = expf(static_cast(loadAsFloat(g, gate_offset))); + scalar_buf[1] = static_cast(loadAsFloat(beta, beta_offset)); + scalar_buf[2] = use_qk_l2norm + ? rsqrtf(q_sum + static_cast(1e-6)) * scale + : scale; + } + __syncthreads(); + + const Tcompute decay = scalar_buf[0]; + const Tcompute beta_t = scalar_buf[1]; + const Tcompute q_scale = scalar_buf[2]; + const Tcompute k_scale = use_qk_l2norm + ? rsqrtf(k_sum + static_cast(1e-6)) + : static_cast(1); + + for (int dk = threadIdx.x; dk < static_cast(Dk); dk += blockDim.x) { + q_vec[dk] *= q_scale; + k_vec[dk] *= k_scale; + } + __syncthreads(); + + for (int dv = threadIdx.x; dv < static_cast(Dv); dv += blockDim.x) { + Tcompute projected = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + projected += k_vec[dk] * state_local[dk * Dv + dv]; + } + + ptrdiff_t v_base = static_cast(token_batch) * v_s0 + static_cast(token_idx) * v_s1 + static_cast(value_head_idx) * v_s2; + Tcompute v_raw = static_cast(loadAsFloat(v, v_base + dv)); + v_new[dv] = beta_t * (v_raw - decay * projected); + } + __syncthreads(); + + CGDR_FOR(x, Dk * Dv) { + int dk = x / Dv; + int dv = x % Dv; + state_local[x] = decay * state_local[x] + k_vec[dk] * v_new[dv]; + } + __syncthreads(); + + for (int dv = threadIdx.x; dv < static_cast(Dv); dv += blockDim.x) { + Tcompute out_val = 0; + for (int dk = 0; dk < static_cast(Dk); ++dk) { + out_val += q_vec[dk] * state_local[dk * Dv + dv]; + } + + ptrdiff_t out_base = static_cast(token_batch) * out_s0 + static_cast(token_idx) * out_s1 + static_cast(value_head_idx) * out_s2; + out[out_base + dv] = static_cast(out_val); + } + __syncthreads(); + } + + CGDR_FOR(i, Dk * Dv) { + int dk = i / Dv; + int dv = i % Dv; + + ptrdiff_t write_idx; + if (indexed_state_pool) { + const ptrdiff_t s2 = final_state_indices != nullptr ? initial_s2 : final_s2; + const ptrdiff_t s3 = final_state_indices != nullptr ? initial_s3 : final_s3; + + write_idx = final_base + static_cast(dv) * s2 + static_cast(dk) * s3; + } else { + write_idx = final_base + static_cast(dk) * final_s2 + static_cast(dv) * final_s3; + } + + final_state_target[write_idx] = static_cast(state_local[i]); + } +} + +#endif // __CHUNK_GATED_DELTA_RULE_KERNEL_CUH__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/info.h b/src/infiniop/ops/chunk_gated_delta_rule/info.h new file mode 100644 index 000000000..4dcb25319 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/info.h @@ -0,0 +1,207 @@ +// infiniop/ops/chunk_gated_delta_rule/info.h + +#ifndef __CHUNK_GATED_DELTA_RULE_INFO_H__ +#define __CHUNK_GATED_DELTA_RULE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op { +namespace chunk_gated_delta_rule { + +class ChunkGatedDeltaRuleInfo { + ChunkGatedDeltaRuleInfo() = default; + +public: + infiniDtype_t data_dtype; + infiniDtype_t gate_dtype; + infiniDtype_t cu_seqlens_dtype; + infiniDtype_t initial_state_indices_dtype; + infiniDtype_t final_state_indices_dtype; + + bool use_qk_l2norm; + bool has_cu_seqlens; + bool has_initial_state_indices; + bool has_final_state_indices; + bool indexed_state_pool; + + size_t B, T, total_tokens, Hk, Hv, Dk, Dv, chunk_size, pool_size, value_heads_per_key_head; + + std::vector out_strides; + std::vector initial_state_strides; + std::vector final_state_strides; + std::vector q_strides; + std::vector k_strides; + std::vector v_strides; + std::vector g_strides; + std::vector beta_strides; + + static utils::Result + create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm, + size_t chunk_size) { + + if (out_desc == nullptr || initial_state_desc == nullptr || q_desc == nullptr || k_desc == nullptr || v_desc == nullptr || g_desc == nullptr || beta_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + if (chunk_size == 0) { + return INFINI_STATUS_BAD_PARAM; + } + + auto data_dtype = q_desc->dtype(); + CHECK_DTYPE(data_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (k_desc->dtype() != data_dtype || v_desc->dtype() != data_dtype || out_desc->dtype() != data_dtype || initial_state_desc->dtype() != data_dtype || (final_state_desc != nullptr && final_state_desc->dtype() != data_dtype)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + auto gate_dtype = g_desc->dtype(); + CHECK_DTYPE(gate_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (beta_desc->dtype() != gate_dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + bool has_cu = cu_seqlens_desc != nullptr; + bool has_initial_indices = initial_state_indices_desc != nullptr; + bool has_final_indices = final_state_indices_desc != nullptr; + bool indexed_pool = has_initial_indices || has_final_indices; + + if (has_final_indices && final_state_desc != nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + if (!has_final_indices && final_state_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || out_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || (!has_final_indices && final_state_desc->ndim() != 4)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + auto q_shape = q_desc->shape(); + auto k_shape = k_desc->shape(); + auto v_shape = v_desc->shape(); + auto out_shape = out_desc->shape(); + auto g_shape = g_desc->shape(); + auto beta_shape = beta_desc->shape(); + + size_t B = q_shape[0], T = q_shape[1], Hk = q_shape[2], Dk = q_shape[3]; + size_t Hv = v_shape[2], Dv = v_shape[3], total_tokens = T; + + if (has_cu) { + if (cu_seqlens_desc->ndim() != 1 || cu_seqlens_desc->shape()[0] < 2) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + B = cu_seqlens_desc->shape()[0] - 1; + if (q_shape[0] != 1 || k_shape[0] != 1 || v_shape[0] != 1 || out_shape[0] != 1 || g_shape[0] != 1 || beta_shape[0] != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + total_tokens = q_shape[1]; + T = total_tokens; + } + + if (k_shape[0] != q_shape[0] || k_shape[1] != q_shape[1] || k_shape[2] != Hk || k_shape[3] != Dk || v_shape[0] != q_shape[0] || v_shape[1] != q_shape[1] || out_shape[0] != q_shape[0] || out_shape[1] != q_shape[1] || out_shape[2] != Hv || out_shape[3] != Dv || g_shape[0] != q_shape[0] || g_shape[1] != q_shape[1] || g_shape[2] != Hv || beta_shape[0] != q_shape[0] || beta_shape[1] != q_shape[1] || beta_shape[2] != Hv || Hk == 0 || Hv == 0 || Hv % Hk != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (q_desc->strides()[3] != 1 || k_desc->strides()[3] != 1 || v_desc->strides()[3] != 1 || out_desc->strides()[3] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + auto initial_shape = initial_state_desc->shape(); + size_t pool_size = initial_shape[0]; + if (indexed_pool) { + if (initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (!has_final_indices) { + auto final_shape = final_state_desc->shape(); + if (indexed_pool) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + } + + infiniDtype_t cu_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t initial_indices_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t final_indices_dtype = INFINI_DTYPE_INVALID; + if (has_cu) { + cu_dtype = cu_seqlens_desc->dtype(); + CHECK_DTYPE(cu_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_initial_indices) { + if (initial_state_indices_desc->ndim() != 1 || initial_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + initial_indices_dtype = initial_state_indices_desc->dtype(); + CHECK_DTYPE(initial_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_final_indices) { + if (final_state_indices_desc->ndim() != 1 || final_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + final_indices_dtype = final_state_indices_desc->dtype(); + CHECK_DTYPE(final_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + + ChunkGatedDeltaRuleInfo info; + info.data_dtype = data_dtype; + info.gate_dtype = gate_dtype; + info.cu_seqlens_dtype = cu_dtype; + info.initial_state_indices_dtype = initial_indices_dtype; + info.final_state_indices_dtype = final_indices_dtype; + info.use_qk_l2norm = use_qk_l2norm; + info.has_cu_seqlens = has_cu; + info.has_initial_state_indices = has_initial_indices; + info.has_final_state_indices = has_final_indices; + info.indexed_state_pool = indexed_pool; + info.B = B; + info.T = T; + info.total_tokens = total_tokens; + info.Hk = Hk; + info.Hv = Hv; + info.Dk = Dk; + info.Dv = Dv; + info.chunk_size = chunk_size; + info.pool_size = pool_size; + info.value_heads_per_key_head = Hv / Hk; + info.out_strides = out_desc->strides(); + info.initial_state_strides = initial_state_desc->strides(); + if (final_state_desc != nullptr) { + info.final_state_strides = final_state_desc->strides(); + } + info.q_strides = q_desc->strides(); + info.k_strides = k_desc->strides(); + info.v_strides = v_desc->strides(); + info.g_strides = g_desc->strides(); + info.beta_strides = beta_desc->strides(); + + return utils::Result(info); + } +}; + +} // namespace chunk_gated_delta_rule +} // namespace op + +#endif // __CHUNK_GATED_DELTA_RULE_INFO_H__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu new file mode 100644 index 000000000..461f95ad5 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cu @@ -0,0 +1,305 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "chunk_gated_delta_rule_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include + +template +INFINIOP_CUDA_KERNEL chunkGatedDeltaRule( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + chunkGatedDeltaRuleRecurrentKernel( + state_workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, cu_seqlens_i64, + initial_state_indices_i64, final_state_indices_i64, use_qk_l2norm, + has_cu_seqlens, indexed_state_pool, T, chunk_size, pool_size, Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, q_s0, q_s1, q_s2, k_s0, k_s1, + k_s2, v_s0, v_s1, v_s2, g_s0, g_s1, g_s2, beta_s0, beta_s1, beta_s2); +} + +template +INFINIOP_CUDA_KERNEL chunkGatedDeltaRuleChunked( + Tcompute *state_workspace, + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + bool cu_seqlens_i64, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool has_cu_seqlens, + bool indexed_state_pool, + size_t T, + size_t chunk_size, + size_t pool_size, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + chunkGatedDeltaRuleKernel( + state_workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, cu_seqlens_i64, + initial_state_indices_i64, final_state_indices_i64, use_qk_l2norm, + has_cu_seqlens, indexed_state_pool, T, chunk_size, pool_size, Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, q_s0, q_s1, q_s2, k_s0, k_s1, + k_s2, v_s0, v_s1, v_s2, g_s0, g_s1, g_s2, beta_s0, beta_s1, beta_s2); +} + +namespace op { +namespace chunk_gated_delta_rule { +namespace nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm, + size_t chunk_size) { + auto info = ChunkGatedDeltaRuleInfo::create( + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, cu_seqlens_desc, initial_state_indices_desc, + final_state_indices_desc, use_qk_l2norm, chunk_size); + CHECK_RESULT(info); + + auto info_value = info.take(); + // We always want to use fast path, slow path is kept as a ref + const bool use_chunked_fallback = false; + const size_t per_block_workspace = use_chunked_fallback + ? info_value.Dk * info_value.Dv + info_value.chunk_size * info_value.Dk * 3 + info_value.chunk_size * info_value.Dv * 3 + info_value.chunk_size * info_value.chunk_size + info_value.chunk_size * 3 + : info_value.Dk * info_value.Dv; + const size_t workspace_size = info_value.B * info_value.Hv * per_block_workspace * sizeof(float); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info_value, workspace_size, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernelWithGateDtype( + void *workspace, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + const ChunkGatedDeltaRuleInfo &info, + cudaStream_t stream) { + +#define LAUNCH_ARGS(TYPE) \ + static_cast(workspace), static_cast(out), static_cast(initial_state), static_cast(final_state), \ + static_cast(q), static_cast(k), static_cast(v), \ + static_cast(g), static_cast(beta), cu_seqlens, initial_state_indices, \ + final_state_indices, info.cu_seqlens_dtype == INFINI_DTYPE_I64, \ + info.initial_state_indices_dtype == INFINI_DTYPE_I64, info.final_state_indices_dtype == INFINI_DTYPE_I64, \ + info.use_qk_l2norm, info.has_cu_seqlens, info.indexed_state_pool, info.T, info.chunk_size, info.pool_size, info.Hk, \ + info.value_heads_per_key_head, info.out_strides[0], info.out_strides[1], info.out_strides[2], \ + info.initial_state_strides[0], info.initial_state_strides[1], info.initial_state_strides[2], \ + info.initial_state_strides[3], info.final_state_strides.empty() ? 0 : info.final_state_strides[0], \ + info.final_state_strides.empty() ? 0 : info.final_state_strides[1], info.final_state_strides.empty() ? 0 : info.final_state_strides[2], \ + info.final_state_strides.empty() ? 0 : info.final_state_strides[3], info.q_strides[0], info.q_strides[1], \ + info.q_strides[2], info.k_strides[0], info.k_strides[1], info.k_strides[2], info.v_strides[0], \ + info.v_strides[1], info.v_strides[2], info.g_strides[0], info.g_strides[1], info.g_strides[2], \ + info.beta_strides[0], info.beta_strides[1], info.beta_strides[2] + +#define LAUNCH_GATE(TYPE) \ + do { \ + if (false) { \ + chunkGatedDeltaRuleChunked \ + <<>>(LAUNCH_ARGS(TYPE)); \ + } else { \ + chunkGatedDeltaRule \ + <<>>(LAUNCH_ARGS(TYPE)); \ + } \ + } while (0) + + if (info.gate_dtype == INFINI_DTYPE_F16) { + LAUNCH_GATE(half); + } else if (info.gate_dtype == INFINI_DTYPE_BF16) { + LAUNCH_GATE(__nv_bfloat16); + } else if (info.gate_dtype == INFINI_DTYPE_F32) { + LAUNCH_GATE(float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +#undef LAUNCH_GATE +#undef LAUNCH_ARGS + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + void *workspace, + void *out, + void *initial_state, + void *final_state, + const void *q, + const void *k, + const void *v, + const void *g, + const void *beta, + const void *cu_seqlens, + const void *initial_state_indices, + const void *final_state_indices, + const ChunkGatedDeltaRuleInfo &info, + cudaStream_t stream) { + if (info.data_dtype == INFINI_DTYPE_F16) { + return launchKernelWithGateDtype( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); + } + if (info.data_dtype == INFINI_DTYPE_BF16) { + return launchKernelWithGateDtype<__nv_bfloat16, Dk, Dv, NUM_THREADS>( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); + } + if (info.data_dtype == INFINI_DTYPE_F32) { + return launchKernelWithGateDtype( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, info, stream); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *cu_seqlens, + const void *initial_state_indices, const void *final_state_indices, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + if (workspace == nullptr || workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + if (_info.Dk == 128 && _info.Dv == 128) { + if (_opaque->internal->maxThreadsPerBlock() >= 128) { + return launchKernel<128, 128, 128>( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, _info, stream); + } + } else if (_info.Dk == 64 && _info.Dv == 64) { + if (_opaque->internal->maxThreadsPerBlock() >= 64) { + return launchKernel<64, 64, 64>( + workspace, out, initial_state, final_state, q, k, v, g, beta, cu_seqlens, + initial_state_indices, final_state_indices, _info, stream); + } + } + + return INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +} // namespace nvidia +} // namespace chunk_gated_delta_rule +} // namespace op diff --git a/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh new file mode 100644 index 000000000..5dde15f7d --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/nvidia/chunk_gated_delta_rule_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ +#define __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ + +#include "../chunk_gated_delta_rule.h" + +DESCRIPTOR(nvidia) + +#endif // __CHUNK_GATED_DELTA_RULE_NVIDIA_H__ diff --git a/src/infiniop/ops/chunk_gated_delta_rule/operator.cc b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc new file mode 100644 index 000000000..631d9ca44 --- /dev/null +++ b/src/infiniop/ops/chunk_gated_delta_rule/operator.cc @@ -0,0 +1,115 @@ +// infiniop/ops/chunk_gated_delta_rule/operator.cc + +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/chunk_gated_delta_rule.h" + +#if defined(ENABLE_NVIDIA_API) +#include "nvidia/chunk_gated_delta_rule_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t cu_seqlens_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm, + size_t chunk_size) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::chunk_gated_delta_rule::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor **>( \ + desc_ptr), \ + out_desc, initial_state_desc, final_state_desc, q_desc, \ + k_desc, v_desc, g_desc, beta_desc, cu_seqlens_desc, \ + initial_state_indices_desc, final_state_indices_desc, \ + use_qk_l2norm, chunk_size); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize( + infiniopChunkGatedDeltaRuleDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopChunkGatedDeltaRule( + infiniopChunkGatedDeltaRuleDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, const void *cu_seqlens, + const void *initial_state_indices, const void *final_state_indices, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->calculate(workspace, workspace_size, out, initial_state, \ + final_state, q, k, v, g, beta, cu_seqlens, \ + initial_state_indices, final_state_indices, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor( + infiniopChunkGatedDeltaRuleDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast< \ + op::chunk_gated_delta_rule::NAMESPACE::Descriptor *>(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh new file mode 100644 index 000000000..db9161626 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/cuda/kernel.cuh @@ -0,0 +1,226 @@ +// kernel.cuh (in op/recurrent_gated_delta_rule/cuda/) + +#ifndef __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ +#define __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ + +#include +#include +#include +#include + +__device__ inline int64_t loadStateIndex( + const void *indices, + bool is_i64, + int batch_idx, + int fallback) { + if (indices == nullptr) { + return static_cast(fallback); + } + return is_i64 + ? static_cast(indices)[batch_idx] + : static_cast(static_cast(indices)[batch_idx]); +} + +template +__device__ inline float loadAsFloat(const T *ptr, ptrdiff_t offset) { + return static_cast(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat(const half *ptr, ptrdiff_t offset) { + return __half2float(ptr[offset]); +} + +template <> +__device__ inline float loadAsFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr, ptrdiff_t offset) { + return __bfloat162float(ptr[offset]); +} + +template +__device__ void recurrentGatedDeltaRuleKernel( + Tdata *out, + Tdata *initial_state, + Tdata *final_state, + const Tdata *q, + const Tdata *k, + const Tdata *v, + const Tgate *g, + const Tgate *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool indexed_state_pool, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + const int batch_idx = blockIdx.x; + const int value_head_idx = blockIdx.y; + const int key_head_idx = value_head_idx / static_cast(value_heads_per_key_head); + const int thread_idx = threadIdx.x; + + if (key_head_idx >= static_cast(Hk)) { + return; + } + + constexpr int seq_idx = 0; + const ptrdiff_t q_base = static_cast(batch_idx) * q_s0 + seq_idx * q_s1 + static_cast(key_head_idx) * q_s2; + const ptrdiff_t k_base = static_cast(batch_idx) * k_s0 + seq_idx * k_s1 + static_cast(key_head_idx) * k_s2; + const ptrdiff_t v_base = static_cast(batch_idx) * v_s0 + seq_idx * v_s1 + static_cast(value_head_idx) * v_s2; + const ptrdiff_t out_base = static_cast(batch_idx) * out_s0 + seq_idx * out_s1 + static_cast(value_head_idx) * out_s2; + const ptrdiff_t gate_offset = static_cast(batch_idx) * g_s0 + seq_idx * g_s1 + static_cast(value_head_idx) * g_s2; + const ptrdiff_t beta_offset = static_cast(batch_idx) * beta_s0 + seq_idx * beta_s1 + static_cast(value_head_idx) * beta_s2; + + int64_t read_slot = static_cast(batch_idx); + int64_t write_slot = static_cast(batch_idx); + if (indexed_state_pool) { + read_slot = loadStateIndex(initial_state_indices, initial_state_indices_i64, batch_idx, batch_idx); + write_slot = final_state_indices == nullptr + ? static_cast(batch_idx) + : loadStateIndex(final_state_indices, final_state_indices_i64, batch_idx, batch_idx); + if (read_slot < 0 || write_slot < 0) { + for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { + out[out_base + dv_idx] = static_cast(0.0f); + } + return; + } + } + + const ptrdiff_t initial_base = indexed_state_pool + ? static_cast(read_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1 + : static_cast(batch_idx) * initial_s0 + static_cast(value_head_idx) * initial_s1; + ptrdiff_t final_base = 0; + Tdata *final_state_target = nullptr; + if (indexed_state_pool && final_state_indices != nullptr) { + final_state_target = initial_state; + final_base = static_cast(write_slot) * initial_s0 + static_cast(value_head_idx) * initial_s1; + } else if (indexed_state_pool) { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } else { + final_state_target = final_state; + final_base = static_cast(batch_idx) * final_s0 + static_cast(value_head_idx) * final_s1; + } + + extern __shared__ char shared_mem_char[]; + Tcompute *shared_mem = reinterpret_cast(shared_mem_char); + + Tcompute *q_local = shared_mem; + Tcompute *k_local = q_local + Dk; + Tcompute *norm_val = k_local + Dk; + + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] = static_cast(loadAsFloat(q, q_base + i)); + k_local[i] = static_cast(loadAsFloat(k, k_base + i)); + } + + if (use_qk_l2norm) { + __syncthreads(); + Tcompute sum_sq = 0.0f; + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + sum_sq += q_local[i] * q_local[i]; + } + norm_val[thread_idx] = sum_sq; + __syncthreads(); + if (thread_idx == 0) { + Tcompute total_sum_sq = 0.0f; + for (int i = 0; i < NUM_THREADS; ++i) { + total_sum_sq += norm_val[i]; + } + norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); + } + __syncthreads(); + Tcompute r_norm_q = norm_val[0]; + + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] *= r_norm_q; + } + + sum_sq = 0.0f; + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + sum_sq += k_local[i] * k_local[i]; + } + norm_val[thread_idx] = sum_sq; + __syncthreads(); + if (thread_idx == 0) { + Tcompute total_sum_sq = 0.0f; + for (int i = 0; i < NUM_THREADS; ++i) { + total_sum_sq += norm_val[i]; + } + norm_val[0] = rsqrtf(total_sum_sq + 1e-6f); + } + __syncthreads(); + Tcompute r_norm_k = norm_val[0]; + + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + k_local[i] *= r_norm_k; + } + __syncthreads(); + } + + Tcompute g_t = expf(static_cast(loadAsFloat(g, gate_offset))); + Tcompute beta_t = static_cast(loadAsFloat(beta, beta_offset)); + Tcompute scale = rsqrtf(static_cast(Dk)); + + for (int i = thread_idx; i < Dk; i += NUM_THREADS) { + q_local[i] *= scale; + } + __syncthreads(); + + for (int dv_idx = thread_idx; dv_idx < Dv; dv_idx += NUM_THREADS) { + Tcompute kv_mem = 0.0f; + for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { + ptrdiff_t state_idx = indexed_state_pool + ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 + : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; + Tcompute h_prev = static_cast(loadAsFloat(initial_state, state_idx)); + kv_mem += (h_prev * g_t) * k_local[dk_idx]; + } + + Tcompute v_t = static_cast(loadAsFloat(v, v_base + dv_idx)); + Tcompute delta = (v_t - kv_mem) * beta_t; + Tcompute out_val = 0.0f; + + for (int dk_idx = 0; dk_idx < Dk; ++dk_idx) { + ptrdiff_t read_state_idx = indexed_state_pool + ? initial_base + static_cast(dv_idx) * initial_s2 + static_cast(dk_idx) * initial_s3 + : initial_base + static_cast(dk_idx) * initial_s2 + static_cast(dv_idx) * initial_s3; + ptrdiff_t write_state_idx = indexed_state_pool + ? final_base + static_cast(dv_idx) * (final_state_indices != nullptr ? initial_s2 : final_s2) + static_cast(dk_idx) * (final_state_indices != nullptr ? initial_s3 : final_s3) + : final_base + static_cast(dk_idx) * final_s2 + static_cast(dv_idx) * final_s3; + Tcompute h_prev = static_cast(loadAsFloat(initial_state, read_state_idx)); + Tcompute h_final = (h_prev * g_t) + (k_local[dk_idx] * delta); + out_val += h_final * q_local[dk_idx]; + final_state_target[write_state_idx] = static_cast(h_final); + } + out[out_base + dv_idx] = static_cast(out_val); + } +} + +#endif // __RECURRENT_GATED_DELTA_RULE_KERNEL_CUH__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/info.h b/src/infiniop/ops/recurrent_gated_delta_rule/info.h new file mode 100644 index 000000000..51644964a --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/info.h @@ -0,0 +1,178 @@ +// infiniop/ops/recurrent_gated_delta_rule/info.h + +#ifndef __RECURRENT_GATED_DELTA_RULE_INFO_H__ +#define __RECURRENT_GATED_DELTA_RULE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op { +namespace recurrent_gated_delta_rule { + +class RecurrentGatedDeltaRuleInfo { + RecurrentGatedDeltaRuleInfo() = default; + +public: + infiniDtype_t data_dtype; + infiniDtype_t gate_dtype; + infiniDtype_t initial_state_indices_dtype; + infiniDtype_t final_state_indices_dtype; + bool use_qk_l2norm; + bool has_initial_state_indices; + bool has_final_state_indices; + bool indexed_state_pool; + + size_t B, Hk, Hv, T, Dk, Dv, pool_size, value_heads_per_key_head; + + std::vector out_strides; + std::vector initial_state_strides; + std::vector final_state_strides; + std::vector q_strides; + std::vector k_strides; + std::vector v_strides; + std::vector g_strides; + std::vector beta_strides; + + static utils::Result + create(infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm) { + + if (out_desc == nullptr || initial_state_desc == nullptr || q_desc == nullptr || k_desc == nullptr || v_desc == nullptr || g_desc == nullptr || beta_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + auto data_dtype = q_desc->dtype(); + CHECK_DTYPE(data_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (k_desc->dtype() != data_dtype || v_desc->dtype() != data_dtype || out_desc->dtype() != data_dtype || initial_state_desc->dtype() != data_dtype || (final_state_desc != nullptr && final_state_desc->dtype() != data_dtype)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + auto gate_dtype = g_desc->dtype(); + CHECK_DTYPE(gate_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (beta_desc->dtype() != gate_dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + bool has_initial_indices = initial_state_indices_desc != nullptr; + bool has_final_indices = final_state_indices_desc != nullptr; + bool indexed_pool = has_initial_indices || has_final_indices; + + if (has_final_indices && final_state_desc != nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + if (!has_final_indices && final_state_desc == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + if (q_desc->ndim() != 4 || k_desc->ndim() != 4 || v_desc->ndim() != 4 || out_desc->ndim() != 4 || g_desc->ndim() != 3 || beta_desc->ndim() != 3 || initial_state_desc->ndim() != 4 || (!has_final_indices && final_state_desc->ndim() != 4)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + auto q_shape = q_desc->shape(); // [B, T, Hk, Dk] + auto k_shape = k_desc->shape(); // [B, T, Hk, Dk] + auto v_shape = v_desc->shape(); // [B, T, Hv, Dv] + auto out_shape = out_desc->shape(); // [B, T, Hv, Dv] + auto g_shape = g_desc->shape(); // [B, T, Hv] + auto beta_shape = beta_desc->shape(); // [B, T, Hv] + + size_t B = q_shape[0], T = q_shape[1], Hk = q_shape[2], Dk = q_shape[3]; + size_t Hv = v_shape[2], Dv = v_shape[3]; + + if (T != 1 || k_shape[0] != B || k_shape[1] != T || k_shape[2] != Hk || k_shape[3] != Dk || v_shape[0] != B || v_shape[1] != T || out_shape[0] != B || out_shape[1] != T || out_shape[2] != Hv || out_shape[3] != Dv || g_shape[0] != B || g_shape[1] != T || g_shape[2] != Hv || beta_shape[0] != B || beta_shape[1] != T || beta_shape[2] != Hv || Hk == 0 || Hv == 0 || Hv % Hk != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (q_desc->strides()[3] != 1 || k_desc->strides()[3] != 1 || v_desc->strides()[3] != 1 || out_desc->strides()[3] != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + auto initial_shape = initial_state_desc->shape(); + size_t pool_size = initial_shape[0]; + if (indexed_pool) { + // Indexed pool layout is [pool_size, Hv, Dv, Dk]. + if (initial_shape[1] != Hv || initial_shape[2] != Dv || initial_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + // Legacy layout is [B, Hv, Dk, Dv]. + if (initial_shape[0] != B || initial_shape[1] != Hv || initial_shape[2] != Dk || initial_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (!has_final_indices) { + auto final_shape = final_state_desc->shape(); + if (indexed_pool) { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dv || final_shape[3] != Dk) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + if (final_shape[0] != B || final_shape[1] != Hv || final_shape[2] != Dk || final_shape[3] != Dv) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + } + + infiniDtype_t initial_indices_dtype = INFINI_DTYPE_INVALID; + infiniDtype_t final_indices_dtype = INFINI_DTYPE_INVALID; + if (has_initial_indices) { + if (initial_state_indices_desc->ndim() != 1 || initial_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + initial_indices_dtype = initial_state_indices_desc->dtype(); + CHECK_DTYPE(initial_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + if (has_final_indices) { + if (final_state_indices_desc->ndim() != 1 || final_state_indices_desc->shape()[0] != B) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + final_indices_dtype = final_state_indices_desc->dtype(); + CHECK_DTYPE(final_indices_dtype, INFINI_DTYPE_I32, INFINI_DTYPE_I64); + } + + RecurrentGatedDeltaRuleInfo info; + info.data_dtype = data_dtype; + info.gate_dtype = gate_dtype; + info.initial_state_indices_dtype = initial_indices_dtype; + info.final_state_indices_dtype = final_indices_dtype; + info.use_qk_l2norm = use_qk_l2norm; + info.has_initial_state_indices = has_initial_indices; + info.has_final_state_indices = has_final_indices; + info.indexed_state_pool = indexed_pool; + info.B = B; + info.Hk = Hk; + info.Hv = Hv; + info.T = T; + info.Dk = Dk; + info.Dv = Dv; + info.pool_size = pool_size; + info.value_heads_per_key_head = Hv / Hk; + info.out_strides = out_desc->strides(); + info.initial_state_strides = initial_state_desc->strides(); + if (final_state_desc != nullptr) { + info.final_state_strides = final_state_desc->strides(); + } + info.q_strides = q_desc->strides(); + info.k_strides = k_desc->strides(); + info.v_strides = v_desc->strides(); + info.g_strides = g_desc->strides(); + info.beta_strides = beta_desc->strides(); + + return utils::Result(info); + } +}; + +} // namespace recurrent_gated_delta_rule +} // namespace op + +#endif // __RECURRENT_GATED_DELTA_RULE_INFO_H__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu new file mode 100644 index 000000000..3c227266e --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cu @@ -0,0 +1,276 @@ +// recurrent_gated_delta_rule_nvidia.cu + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "recurrent_gated_delta_rule_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" +#include + +template +INFINIOP_CUDA_KERNEL recurrentGatedDeltaRule( + Tdata *out, Tdata *initial_state, Tdata *final_state, + const Tdata *q, const Tdata *k, const Tdata *v, + const Tgate *g, const Tgate *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + bool use_qk_l2norm, + bool indexed_state_pool, + size_t Hk, + size_t value_heads_per_key_head, + ptrdiff_t out_s0, + ptrdiff_t out_s1, + ptrdiff_t out_s2, + ptrdiff_t initial_s0, + ptrdiff_t initial_s1, + ptrdiff_t initial_s2, + ptrdiff_t initial_s3, + ptrdiff_t final_s0, + ptrdiff_t final_s1, + ptrdiff_t final_s2, + ptrdiff_t final_s3, + ptrdiff_t q_s0, + ptrdiff_t q_s1, + ptrdiff_t q_s2, + ptrdiff_t k_s0, + ptrdiff_t k_s1, + ptrdiff_t k_s2, + ptrdiff_t v_s0, + ptrdiff_t v_s1, + ptrdiff_t v_s2, + ptrdiff_t g_s0, + ptrdiff_t g_s1, + ptrdiff_t g_s2, + ptrdiff_t beta_s0, + ptrdiff_t beta_s1, + ptrdiff_t beta_s2) { + recurrentGatedDeltaRuleKernel( + out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_state_indices_i64, final_state_indices_i64, + use_qk_l2norm, indexed_state_pool, + Hk, value_heads_per_key_head, + out_s0, out_s1, out_s2, + initial_s0, initial_s1, initial_s2, initial_s3, + final_s0, final_s1, final_s2, final_s3, + q_s0, q_s1, q_s2, + k_s0, k_s1, k_s2, + v_s0, v_s1, v_s2, + g_s0, g_s1, g_s2, + beta_s0, beta_s1, beta_s2); +} + +namespace op { +namespace recurrent_gated_delta_rule { +namespace nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm) { + auto info = RecurrentGatedDeltaRuleInfo::create( + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, v_desc, + g_desc, beta_desc, initial_state_indices_desc, final_state_indices_desc, + use_qk_l2norm); + CHECK_RESULT(info); + + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernelTyped( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + cudaStream_t stream) { + dim3 grid(uint32_t(_info.B), uint32_t(_info.Hv), 1); + dim3 block(NUM_THREADS); + size_t shared_mem_size = (Dk + Dk + NUM_THREADS) * sizeof(float); + + auto final_s0 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[0]; + auto final_s1 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[1]; + auto final_s2 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[2]; + auto final_s3 = _info.final_state_strides.empty() ? 0 : _info.final_state_strides[3]; + + recurrentGatedDeltaRule + <<>>( + static_cast(out), + static_cast(initial_state), + static_cast(final_state), + static_cast(q), + static_cast(k), + static_cast(v), + static_cast(g), + static_cast(beta), + initial_state_indices, + final_state_indices, + initial_state_indices_i64, + final_state_indices_i64, + _info.use_qk_l2norm, + _info.indexed_state_pool, + _info.Hk, + _info.value_heads_per_key_head, + _info.out_strides[0], + _info.out_strides[1], + _info.out_strides[2], + _info.initial_state_strides[0], + _info.initial_state_strides[1], + _info.initial_state_strides[2], + _info.initial_state_strides[3], + final_s0, + final_s1, + final_s2, + final_s3, + _info.q_strides[0], + _info.q_strides[1], + _info.q_strides[2], + _info.k_strides[0], + _info.k_strides[1], + _info.k_strides[2], + _info.v_strides[0], + _info.v_strides[1], + _info.v_strides[2], + _info.g_strides[0], + _info.g_strides[1], + _info.g_strides[2], + _info.beta_strides[0], + _info.beta_strides[1], + _info.beta_strides[2]); + return infiniStatus_t::INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernelForGate( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + cudaStream_t stream) { + switch (_info.gate_dtype) { + case INFINI_DTYPE_F16: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_BF16: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_F32: + return launchKernelTyped( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + default: + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +template +infiniStatus_t launchKernel( + const RecurrentGatedDeltaRuleInfo &_info, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + bool initial_state_indices_i64, + bool final_state_indices_i64, + cudaStream_t stream) { + switch (_info.data_dtype) { + case INFINI_DTYPE_F16: + return launchKernelForGate( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_BF16: + return launchKernelForGate<__nv_bfloat16, Dk, Dv, NUM_THREADS>( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + case INFINI_DTYPE_F32: + return launchKernelForGate( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, initial_state_indices_i64, final_state_indices_i64, stream); + default: + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + + if (_info.has_initial_state_indices && initial_state_indices == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + if (_info.has_final_state_indices && final_state_indices == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + if (!_info.has_final_state_indices && final_state == nullptr) { + return INFINI_STATUS_NULL_POINTER; + } + + bool initial_indices_i64 = _info.initial_state_indices_dtype == INFINI_DTYPE_I64; + bool final_indices_i64 = _info.final_state_indices_dtype == INFINI_DTYPE_I64; + + if (_info.Dk == 128 && _info.Dv == 128) { + if (_opaque->internal->maxThreadsPerBlock() >= 128) { + return launchKernel<128, 128, 128>( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_indices_i64, final_indices_i64, stream); + } + } else if (_info.Dk == 64 && _info.Dv == 64) { + if (_opaque->internal->maxThreadsPerBlock() >= 64) { + return launchKernel<64, 64, 64>( + _info, out, initial_state, final_state, q, k, v, g, beta, + initial_state_indices, final_state_indices, + initial_indices_i64, final_indices_i64, stream); + } + } + + return infiniStatus_t::INFINI_STATUS_BAD_TENSOR_SHAPE; +} + +} // namespace nvidia +} // namespace recurrent_gated_delta_rule +} // namespace op diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh new file mode 100644 index 000000000..82236a31d --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/nvidia/recurrent_gated_delta_rule_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ +#define __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ + +#include "../recurrent_gated_delta_rule.h" + +DESCRIPTOR(nvidia) + +#endif // __RECURRENT_GATED_DELTA_RULE_NVIDIA_H__ diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc new file mode 100644 index 000000000..465e8d6c9 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/operator.cc @@ -0,0 +1,112 @@ +// infiniop/ops/recurrent_gated_delta_rule/operator.cc + +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/recurrent_gated_delta_rule.h" + +#if defined(ENABLE_NVIDIA_API) +#include "nvidia/recurrent_gated_delta_rule_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor( + infiniopHandle_t handle, + infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t initial_state_desc, + infiniopTensorDescriptor_t final_state_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t g_desc, + infiniopTensorDescriptor_t beta_desc, + infiniopTensorDescriptor_t initial_state_indices_desc, + infiniopTensorDescriptor_t final_state_indices_desc, + bool use_qk_l2norm) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::recurrent_gated_delta_rule::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor **>( \ + desc_ptr), \ + out_desc, initial_state_desc, final_state_desc, q_desc, k_desc, \ + v_desc, g_desc, beta_desc, initial_state_indices_desc, \ + final_state_indices_desc, use_qk_l2norm); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopRecurrentGatedDeltaRule( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, void *initial_state, void *final_state, + const void *q, const void *k, const void *v, + const void *g, const void *beta, + const void *initial_state_indices, + const void *final_state_indices, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc) \ + ->calculate(workspace, workspace_size, out, initial_state, \ + final_state, q, k, v, g, beta, initial_state_indices, \ + final_state_indices, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__INFINI_C infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor( + infiniopRecurrentGatedDeltaRuleDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast< \ + op::recurrent_gated_delta_rule::NAMESPACE::Descriptor *>(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h new file mode 100644 index 000000000..14846ea25 --- /dev/null +++ b/src/infiniop/ops/recurrent_gated_delta_rule/recurrent_gated_delta_rule.h @@ -0,0 +1,60 @@ +// infiniop/ops/recurrent_gated_delta_rule.h + +#ifndef __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ +#define __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::recurrent_gated_delta_rule::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + RecurrentGatedDeltaRuleInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + RecurrentGatedDeltaRuleInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t initial_state_desc, \ + infiniopTensorDescriptor_t final_state_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t g_desc, \ + infiniopTensorDescriptor_t beta_desc, \ + infiniopTensorDescriptor_t initial_state_indices_desc, \ + infiniopTensorDescriptor_t final_state_indices_desc, \ + bool use_qk_l2norm); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, void *initial_state, void *final_state, \ + const void *q, const void *k, const void *v, \ + const void *g, const void *beta, \ + const void *initial_state_indices, \ + const void *final_state_indices, \ + void *stream) const; \ + }; \ + } + +#endif // __INFINIOP_RECURRENT_GATED_DELTA_RULE_H__ diff --git a/test/infinicore/ops/chunk_gated_delta_rule.py b/test/infinicore/ops/chunk_gated_delta_rule.py new file mode 100644 index 000000000..0165868de --- /dev/null +++ b/test/infinicore/ops/chunk_gated_delta_rule.py @@ -0,0 +1,344 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import infinicore +import torch +from framework import ( + BaseOperatorTest, + TensorSpec, + TestCase, + GenericTestRunner, + TensorInitializer, +) + +# Test cases: +# (n_khead, kdim, n_vhead, vdim, seqlens, init_state_indices, final_state_indices, state_pool_size) +_VARLEN_TEST_CASES_DATA = [ + (16, 128, 48, 128, (13,), (0,), (0,), 1), + (16, 128, 48, 128, (13,), (1,), (0,), 2), + (16, 128, 48, 128, (13, 20), (1, 1), (0, 0), 4), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, +} + +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def torch_chunk_gated_delta_rule_ref( + q, + k, + v, + g, + beta, + initial_state, + cu_seqlens=None, + initial_state_indices=None, + final_state_indices=None, + use_qk_l2norm=False, + chunk_size=64, +): + + def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return x * inv_norm + + def _run_one(q, k, v, g, beta, init_state): + # q/k: [1, T, Hk, Dk] + # v/out: [1, T, Hv, Dv] + # g/beta: [1, T, Hv] + # init_state: [B, Hv, Dv, Dk] + # return state: [B, Hv, Dv, Dk] + initial_dtype = q.dtype + + if use_qk_l2norm: + q = l2norm(q, dim=-1, eps=1e-6) + k = l2norm(k, dim=-1, eps=1e-6) + + B, T, Hk, Dk = q.shape + _, _, Hv, Dv = v.shape + assert B == 1 + assert Hv % Hk == 0 + assert init_state.shape == (B, Hv, Dv, Dk) + + group_size = Hv // Hk + if group_size != 1: + q = q.repeat_interleave(group_size, dim=2) + k = k.repeat_interleave(group_size, dim=2) + + q = q.transpose(1, 2).contiguous().float() + k = k.transpose(1, 2).contiguous().float() + v = v.transpose(1, 2).contiguous().float() + + beta = beta.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + + B, H, sequence_length, Dk = k.shape + Dv = v.shape[-1] + + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + + q = torch.nn.functional.pad(q, (0, 0, 0, pad_size)) + k = torch.nn.functional.pad(k, (0, 0, 0, pad_size)) + v = torch.nn.functional.pad(v, (0, 0, 0, pad_size)) + beta = torch.nn.functional.pad(beta, (0, pad_size)) + g = torch.nn.functional.pad(g, (0, pad_size)) + + total_sequence_length = sequence_length + pad_size + scale = 1 / (q.shape[-1] ** 0.5) + q = q * scale + + v_beta = v * beta.unsqueeze(-1) + k_beta = k * beta.unsqueeze(-1) + + q, k, v, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (q, k, v, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() + + attn = -((k_beta @ k.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + v = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + # [B, Hv, Dv, Dk] -> [B, Hv, Dk, Dv] + last_state = init_state.transpose(-1, -2).contiguous().float().clone() + + out = torch.zeros_like(v) + + for i in range(total_sequence_length // chunk_size): + q_i = q[:, :, i] + k_i = k[:, :, i] + v_i = v[:, :, i] + + attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] + + v_prime = k_cumdecay[:, :, i] @ last_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state + out[:, :, i] = attn_inter + attn @ v_new + + last_state = ( + last_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1]) + out = out[:, :, :sequence_length] + out = out.transpose(1, 2).contiguous().to(initial_dtype) + + # [B, Hv, Dk, Dv] -> [B, Hv, Dv, Dk] + final_state = last_state.transpose(-1, -2).contiguous().to(init_state.dtype) + + return out, final_state + + if cu_seqlens is None: + out, final_state = _run_one(q, k, v, g, beta, initial_state) + + if initial_state_indices is not None: + for b, dst in enumerate(final_state_indices.cpu().tolist()): + initial_state[dst].copy_(final_state[b].to(initial_state.dtype)) + + return out + + cu = cu_seqlens.cpu().tolist() + batch = len(cu) - 1 + total_tokens = cu[-1] + + out = torch.empty_like(v[:, :total_tokens]) + indexed_state_pool = initial_state_indices is not None + + for b in range(batch): + start = cu[b] + end = cu[b + 1] + + q_b = q[:, start:end] + k_b = k[:, start:end] + v_b = v[:, start:end] + g_b = g[:, start:end] + beta_b = beta[:, start:end] + + if indexed_state_pool: + src = int(initial_state_indices[b].item()) + init_b = initial_state[src : src + 1] + else: + init_b = initial_state[b : b + 1] + + out_b, final_b = _run_one(q_b, k_b, v_b, g_b, beta_b, init_b) + out[:, start:end].copy_(out_b) + + if indexed_state_pool: + dst = int(final_state_indices[b].item()) + initial_state[dst].copy_(final_b[0].to(initial_state.dtype)) + + return out + + +def parse_varlen_test_cases(): + tests = [] + + for ( + n_khead, + kdim, + n_vhead, + vdim, + seqlens, + init_state_indices, + final_state_indices, + state_pool_size, + ) in _VARLEN_TEST_CASES_DATA: + batch = len(seqlens) + total_tokens = sum(seqlens) + cu_seqlens = [0] + for seqlen in seqlens: + cu_seqlens.append(cu_seqlens[-1] + seqlen) + + q_shape = (1, total_tokens, n_khead, kdim) + k_shape = (1, total_tokens, n_khead, kdim) + v_shape = (1, total_tokens, n_vhead, vdim) + g_shape = (1, total_tokens, n_vhead) + beta_shape = (1, total_tokens, n_vhead) + + # Indexed state-pool mode in your wrapper doc: + # initial_state: [pool_size, Hv, Dv, Dk] + state_shape = (state_pool_size, n_vhead, vdim, kdim) + + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-3}) + + q_spec = TensorSpec.from_tensor(q_shape, None, dtype, scale=0.2, bias=-0.1) + k_spec = TensorSpec.from_tensor(k_shape, None, dtype, scale=0.2, bias=-0.1) + v_spec = TensorSpec.from_tensor(v_shape, None, dtype, scale=0.2, bias=-0.1) + g_spec = TensorSpec.from_tensor( + g_shape, None, infinicore.float32, scale=0.02, bias=-0.01 + ) + beta_spec = TensorSpec.from_tensor( + beta_shape, None, infinicore.float32, scale=0.5, bias=0.0 + ) + state_spec = TensorSpec.from_tensor( + state_shape, None, dtype, init_mode=TensorInitializer.ZEROS + ) + + cu_seqlens_spec = TensorSpec.from_tensor( + (batch + 1,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(cu_seqlens, dtype=torch.int32), + ) + + init_indices_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(init_state_indices, dtype=torch.int32), + ) + final_indices_spec = TensorSpec.from_tensor( + (batch,), + None, + infinicore.int32, + init_mode=TensorInitializer.MANUAL, + set_tensor=torch.tensor(final_state_indices, dtype=torch.int32), + ) + + tests.append( + TestCase( + inputs=[ + q_spec, + k_spec, + v_spec, + g_spec, + beta_spec, + state_spec, + cu_seqlens_spec, + init_indices_spec, + final_indices_spec, + ], + kwargs={ + "use_qk_l2norm": True, + "chunk_size": 64, + }, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="ChunkGatedDeltaRule - VARLEN_INDEXED_STATE_POOL", + ) + ) + + return tests + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("ChunkGatedDeltaRule") + + def get_test_cases(self): + return parse_varlen_test_cases() + + def torch_operator(self, *args, **kwargs): + args = list(args) + args[5] = args[5].clone() + return torch_chunk_gated_delta_rule_ref(*args, **kwargs) + + def infinicore_operator( + self, + q, + k, + v, + g, + beta, + states, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size=64, + ): + return infinicore.nn.functional.chunk_gated_delta_rule( + q, + k, + v, + g, + beta, + states, + cu_seqlens=cu_seqlens, + initial_state_indices=initial_state_indices, + final_state_indices=final_state_indices, + use_qk_l2norm=use_qk_l2norm, + chunk_size=chunk_size, + ) + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/recurrent_gated_delta_rule.py b/test/infinicore/ops/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..4fbc59490 --- /dev/null +++ b/test/infinicore/ops/recurrent_gated_delta_rule.py @@ -0,0 +1,291 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import torch.nn.functional as torch_F +from framework import ( + BaseOperatorTest, + GenericTestRunner, + TensorInitializer, + TensorSpec, + TestCase, +) + +import infinicore + +# Test cases: +# (B, T, Hk, Hv, Dk, Dv, use_qk_l2norm, strided_qkv) +_TEST_CASES = [ + (7, 1, 40, 40, 128, 128, True, False), + (5, 1, 64, 64, 128, 128, False, False), + (1, 1, 8, 8, 64, 64, True, False), + (2, 1, 4, 8, 64, 64, False, False), + (2, 1, 4, 8, 64, 64, True, True), +] +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 5e-3, "rtol": 5e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-5}, +} + + +def ref_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + use_qk_l2norm=False, +): + initial_dtype = query.dtype + if use_qk_l2norm: + query = torch_F.normalize(query, p=2, dim=-1) + key = torch_F.normalize(key, p=2, dim=-1) + + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + state = initial_state.contiguous().to(torch.float32).clone() + batch_size, sequence_length, key_heads, _ = key.shape + value_heads, v_head_dim = value.shape[2], value.shape[-1] + value_heads_per_key_head = value_heads // key_heads + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + out = torch.zeros( + batch_size, + sequence_length, + value_heads, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + + for i in range(sequence_length): + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + q_t = query[:, i, kh] + k_t = key[:, i, kh] + v_t = value[:, i, vh] + g_t = g[:, i, vh].exp().view(batch_size, 1, 1) + beta_t = beta[:, i, vh].view(batch_size, 1) + state_t = state[:, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + state[:, vh] = state_t + out[:, i, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) + + return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + + +def strided_bthd_strides(shape): + _, T, H, D = shape + return (8 * T * H * D, 4 * H * D, 2 * D, 1) + + +def tensor_spec(shape, dtype, strides=None): + return TensorSpec.from_tensor(shape, strides, dtype, scale=0.2, bias=-0.1) + + +def gate_spec(shape): + return TensorSpec.from_tensor( + shape, None, infinicore.float32, scale=0.02, bias=-0.01 + ) + + +def beta_spec(shape): + return TensorSpec.from_tensor(shape, None, infinicore.float32, scale=0.5, bias=0.0) + + +def index_spec(values): + values = torch.tensor(values, dtype=torch.int64) + return TensorSpec.from_tensor( + values.shape, + None, + infinicore.int64, + init_mode=TensorInitializer.MANUAL, + set_tensor=values, + ) + + +def parse_test_cases(): + tests = [] + + for B, T, Hk, Hv, Dk, Dv, use_qk_l2norm, strided_qkv in _TEST_CASES: + q_shape = (B, T, Hk, Dk) + k_shape = (B, T, Hk, Dk) + v_shape = (B, T, Hv, Dv) + gate_shape = (B, T, Hv) + initial_state_shape = (B, Hv, Dk, Dv) + pool_size = B * 2 + 3 + state_pool_shape = (pool_size, Hv, Dv, Dk) + q_strides = strided_bthd_strides(q_shape) if strided_qkv else None + k_strides = strided_bthd_strides(k_shape) if strided_qkv else None + v_strides = strided_bthd_strides(v_shape) if strided_qkv else None + + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + + base_inputs = [ + tensor_spec(q_shape, dtype, q_strides), + tensor_spec(k_shape, dtype, k_strides), + tensor_spec(v_shape, dtype, v_strides), + gate_spec(gate_shape), + beta_spec(gate_shape), + ] + + tests.append( + TestCase( + inputs=base_inputs + + [TensorSpec.from_tensor(initial_state_shape, None, dtype)], + kwargs={ + "mode": "legacy", + "use_qk_l2norm": use_qk_l2norm, + }, + description=( + f"legacy B={B}, T={T}, Hk={Hk}, Hv={Hv}, " + f"Dk={Dk}, Dv={Dv}, dtype={dtype}, " + f"strided_qkv={strided_qkv}, l2norm={use_qk_l2norm}" + ), + tolerance=tol, + ) + ) + + tests.append( + TestCase( + inputs=base_inputs + + [ + TensorSpec.from_tensor(state_pool_shape, None, dtype), + index_spec(range(1, B + 1)), + index_spec(range(B + 1, 2 * B + 1)), + ], + kwargs={ + "mode": "indexed_pool", + "use_qk_l2norm": use_qk_l2norm, + }, + output_count=2, + description=( + f"indexed pool B={B}, T={T}, Hk={Hk}, Hv={Hv}, " + f"Dk={Dk}, Dv={Dv}, dtype={dtype}, " + f"strided_qkv={strided_qkv}, l2norm={use_qk_l2norm}" + ), + tolerance=tol, + ) + ) + + for dtype in _TENSOR_DTYPES: + tests.append( + TestCase( + inputs=[ + tensor_spec((1, 48, 128), dtype), + tensor_spec((1, 48, 128), dtype), + tensor_spec((1, 48, 128), dtype), + gate_spec((1, 1, 48)), + beta_spec((1, 1, 48)), + TensorSpec.from_tensor((1, 48, 128, 128), None, dtype), + ], + kwargs={"mode": "user_3d", "use_qk_l2norm": False}, + description=f"user 3D repro dtype={dtype}", + tolerance=_TOLERANCE_MAP[dtype], + ) + ) + + return tests + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("recurrent_gated_delta_rule") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): + mode = kwargs.pop("mode") + use_qk_l2norm = kwargs.pop("use_qk_l2norm", False) + + if mode == "legacy": + out, _ = ref_recurrent_gated_delta_rule( + q, k, v, g, beta, initial_state, use_qk_l2norm=use_qk_l2norm + ) + return out + + if mode == "indexed_pool": + initial_state_indices, final_state_indices = args + state_pool = initial_state.clone() + gathered_initial_state = ( + state_pool[initial_state_indices].transpose(-1, -2).contiguous() + ) + out, final_state = ref_recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + gathered_initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + state_pool[final_state_indices] = final_state.transpose(-1, -2).contiguous() + return out, state_pool + + if mode == "user_3d": + out, _ = ref_recurrent_gated_delta_rule( + q.unsqueeze(1), + k.unsqueeze(1), + v.unsqueeze(1), + g, + beta, + initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + return out + + raise ValueError(f"Unsupported test mode: {mode}") + + def infinicore_operator(self, q, k, v, g, beta, initial_state, *args, **kwargs): + mode = kwargs.pop("mode") + use_qk_l2norm = kwargs.pop("use_qk_l2norm", False) + + if mode == "legacy" or mode == "user_3d": + return infinicore.nn.functional.recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + initial_state, + use_qk_l2norm=use_qk_l2norm, + ) + + if mode == "indexed_pool": + initial_state_indices, final_state_indices = args + out = infinicore.nn.functional.recurrent_gated_delta_rule( + q, + k, + v, + g, + beta, + initial_state, + initial_state_indices=initial_state_indices, + final_state_indices=final_state_indices, + use_qk_l2norm=use_qk_l2norm, + ) + return out, initial_state + + raise ValueError(f"Unsupported test mode: {mode}") + + +def main(): + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infiniop/chunk_gated_delta_rule.py b/test/infiniop/chunk_gated_delta_rule.py new file mode 100644 index 000000000..674ecad8e --- /dev/null +++ b/test/infiniop/chunk_gated_delta_rule.py @@ -0,0 +1,432 @@ +import ctypes +from ctypes import c_uint64 + +import torch +import torch.nn.functional as F + +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + + +def ref_chunk_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + cu_seqlens=None, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + state = initial_state.contiguous().to(torch.float32).clone() + + if cu_seqlens is None: + batch_size, sequence_length, key_heads, k_head_dim = key.shape + spans = [(b, 0, sequence_length) for b in range(batch_size)] + else: + key_heads, k_head_dim = key.shape[2], key.shape[3] + batch_size = cu_seqlens.numel() - 1 + spans = [ + (b, int(cu_seqlens[b].item()), int(cu_seqlens[b + 1].item())) + for b in range(batch_size) + ] + + value_heads, v_head_dim = value.shape[2], value.shape[3] + value_heads_per_key_head = value_heads // key_heads + scale = 1 / (k_head_dim**0.5) + query = query * scale + out = torch.zeros_like(value, dtype=torch.float32) + + for b, start, end in spans: + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + state_t = state[b, vh] + for t in range(start, end): + token_b = 0 if cu_seqlens is not None else b + q_t = query[token_b, t, kh] + k_t = key[token_b, t, kh] + v_t = value[token_b, t, vh] + g_t = g[token_b, t, vh].exp() + beta_t = beta[token_b, t, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + state[b, vh] = state_t + out[token_b, t, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) + + return out.contiguous().to(initial_dtype), state.contiguous().to(initial_dtype) + + +_PADDED_TEST_CASES_DATA = [ + # B, T, n_khead, kdim, n_vhead, vdim, chunk_size, use_qk_l2norm, strided_qkv + (2, 17, 4, 64, 4, 64, 8, True, False), + (2, 19, 4, 64, 8, 64, 8, False, False), + (2, 13, 4, 64, 8, 64, 8, True, True), +] + +# Test cases: (n_khead, kdim, n_vhead, vdim, (seqlens), (init_state_indices), +# final_state_indices, state_pool_size) +_VARLEN_TEST_CASES_DATA = [ + (4, 64, 8, 64, (1, 17, 3, 9), (1, 2, 3, 4), (5, 6, 7, 8), 13), + (16, 128, 48, 128, (13,), (0,), (0,), 1), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-2, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-2, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-4, "rtol": 1e-4}, +} + +DEBUG = False + + +def parse_test_cases(): + tests = [] + for case in _PADDED_TEST_CASES_DATA: + tests.append(("padded", case)) + for case in _VARLEN_TEST_CASES_DATA: + tests.append(("varlen_indexed_pool", case)) + return tests + + +def bthd_strides(B, T, H, D, strided): + if not strided: + return None + return (T * H * D * 2, H * D * 2, D * 2, 1) + + +def make_gate(shape, device): + return TestTensor.from_torch( + F.logsigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + +def make_beta(shape, device): + return TestTensor.from_torch( + torch.sigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + +def run_op( + handle, + device, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size, +): + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateChunkGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + initial_state.descriptor, + final_state.descriptor if final_state is not None else ctypes.c_void_p(0), + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + cu_seqlens.descriptor if cu_seqlens is not None else ctypes.c_void_p(0), + initial_state_indices.descriptor + if initial_state_indices is not None + else ctypes.c_void_p(0), + final_state_indices.descriptor + if final_state_indices is not None + else ctypes.c_void_p(0), + ctypes.c_bool(use_qk_l2norm), + ctypes.c_size_t(chunk_size), + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetChunkGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + for tensor in [ + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + ]: + if tensor is not None: + tensor.destroy_desc() + + check_error( + LIBINFINIOP.infiniopChunkGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + initial_state.data(), + final_state.data() if final_state is not None else None, + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + cu_seqlens.data() if cu_seqlens is not None else None, + initial_state_indices.data() if initial_state_indices is not None else None, + final_state_indices.data() if final_state_indices is not None else None, + None, + ) + ) + check_error(LIBINFINIOP.infiniopDestroyChunkGatedDeltaRuleDescriptor(descriptor)) + + +def test_padded( + handle, + device, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + B, T, Hk, Dk, Hv, Dv, chunk_size, use_qk_l2norm, strided_qkv = test_case + print( + f"Testing ChunkGatedDeltaRule on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, chunk={chunk_size}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " + f"l2norm={use_qk_l2norm}" + ) + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + + ans_out, ans_final_state = ref_chunk_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + if sync: + sync() + + run_op( + handle, + device, + out, + initial_state, + final_state, + q, + k, + v, + g, + beta, + None, + None, + None, + use_qk_l2norm, + chunk_size, + ) + + if sync: + sync() + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + ) + + +def test_varlen_indexed_pool( + handle, + device, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + ( + Hk, + Dk, + Hv, + Dv, + lengths, + initial_state_indices_data, + final_state_indices_data, + pool_size, + ) = test_case + chunk_size = 8 + use_qk_l2norm = True + lengths = tuple(lengths) + initial_state_indices_data = tuple(initial_state_indices_data) + final_state_indices_data = tuple(final_state_indices_data) + B = len(lengths) + total_tokens = sum(lengths) + print( + f"Testing ChunkGatedDeltaRule varlen indexed pool on {InfiniDeviceNames[device]} with " + f"lengths={lengths}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, chunk={chunk_size}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, " + f"initial_indices={initial_state_indices_data}, final_indices={final_state_indices_data}, " + f"pool_size={pool_size}" + ) + q = TestTensor((1, total_tokens, Hk, Dk), None, dtype, device) + k = TestTensor((1, total_tokens, Hk, Dk), None, dtype, device) + v = TestTensor((1, total_tokens, Hv, Dv), None, dtype, device) + g = make_gate((1, total_tokens, Hv), device) + beta = make_beta((1, total_tokens, Hv), device) + out = TestTensor((1, total_tokens, Hv, Dv), None, dtype, device, mode="zeros") + + cu = torch.tensor( + [0] + list(torch.tensor(lengths).cumsum(0).tolist()), + dtype=torch.int64, + device=q.torch_tensor().device, + ) + cu_seqlens = TestTensor.from_torch(cu, InfiniDtype.I64, device) + initial_state_pool = TestTensor((pool_size, Hv, Dv, Dk), None, dtype, device) + initial_state_indices_torch = torch.tensor( + initial_state_indices_data, dtype=torch.int64, device=q.torch_tensor().device + ) + final_state_indices_torch = torch.tensor( + final_state_indices_data, dtype=torch.int64, device=q.torch_tensor().device + ) + initial_state_indices = TestTensor.from_torch( + initial_state_indices_torch, InfiniDtype.I64, device + ) + final_state_indices = TestTensor.from_torch( + final_state_indices_torch, InfiniDtype.I64, device + ) + + gathered_initial = ( + initial_state_pool.torch_tensor()[initial_state_indices_torch] + .transpose(-1, -2) + .contiguous() + ) + ans_out, ans_final_state = ref_chunk_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + gathered_initial, + cu_seqlens=cu, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + ans_pool = initial_state_pool.torch_tensor().clone() + ans_pool[final_state_indices_torch] = ans_final_state.transpose(-1, -2).contiguous() + if sync: + sync() + + run_op( + handle, + device, + out, + initial_state_pool, + None, + q, + k, + v, + g, + beta, + cu_seqlens, + initial_state_indices, + final_state_indices, + use_qk_l2norm, + chunk_size, + ) + + if sync: + sync() + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug(initial_state_pool.actual_tensor(), ans_pool, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + initial_state_pool.actual_tensor(), ans_pool, atol=atol, rtol=rtol + ) + + +def test( + handle, + device, + mode, + test_case, + dtype=InfiniDtype.F16, + sync=None, +): + if mode == "padded": + return test_padded(handle, device, test_case, dtype=dtype, sync=sync) + if mode == "varlen_indexed_pool": + return test_varlen_indexed_pool( + handle, device, test_case, dtype=dtype, sync=sync + ) + raise ValueError(f"Unknown chunk_gated_delta_rule test mode: {mode}") + + +if __name__ == "__main__": + args = get_args() + DEBUG = args.debug + + for device in get_test_devices(args): + test_operator(device, test, parse_test_cases(), _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index a59bbdf3f..4e6dfb33b 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1266,7 +1266,6 @@ def per_tensor_quant_int8_(lib): c_void_p, c_void_p, c_void_p, - c_void_p, c_bool, c_void_p, ] @@ -2782,3 +2781,98 @@ def swap_(lib): lib.infiniopDestroySwapDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def recurrent_gated_delta_rule_(lib): + lib.infiniopCreateRecurrentGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopCreateRecurrentGatedDeltaRuleDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_void_p, + c_bool, + ] + lib.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize.restype = c_int32 + lib.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopRecurrentGatedDeltaRule.restype = c_int32 + lib.infiniopRecurrentGatedDeltaRule.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopDestroyRecurrentGatedDeltaRuleDescriptor.argtypes = [ + infiniopOperatorDescriptor_t + ] + + +@OpRegister.operator +def chunk_gated_delta_rule_(lib): + lib.infiniopCreateChunkGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopCreateChunkGatedDeltaRuleDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_void_p, + c_void_p, + c_bool, + c_size_t, + ] + lib.infiniopGetChunkGatedDeltaRuleWorkspaceSize.restype = c_int32 + lib.infiniopGetChunkGatedDeltaRuleWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopChunkGatedDeltaRule.restype = c_int32 + lib.infiniopChunkGatedDeltaRule.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.restype = c_int32 + lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.argtypes = [ + infiniopOperatorDescriptor_t + ] diff --git a/test/infiniop/recurrent_gated_delta_rule.py b/test/infiniop/recurrent_gated_delta_rule.py new file mode 100644 index 000000000..25a40911f --- /dev/null +++ b/test/infiniop/recurrent_gated_delta_rule.py @@ -0,0 +1,457 @@ +# test_recurrent_gated_delta_rule.py + +import ctypes +from ctypes import c_uint64 + +import torch +import torch.nn.functional as F + +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + + +def ref_recurrent_gated_delta_rule( + query, + key, + value, + g, + beta, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel=False, +): + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + + query, key, value, beta, g = [ + x.contiguous().to(torch.float32) for x in (query, key, value, beta, g) + ] + initial_state = initial_state.contiguous().to(torch.float32).clone() + + batch_size, sequence_length, key_heads, k_head_dim = key.shape + value_heads, v_head_dim = value.shape[2], value.shape[-1] + value_heads_per_key_head = value_heads // key_heads + scale = 1 / (query.shape[-1] ** 0.5) + query = query * scale + + core_attn_out = torch.zeros( + batch_size, + sequence_length, + value_heads, + v_head_dim, + device=value.device, + dtype=torch.float32, + ) + last_recurrent_state = initial_state + + for i in range(sequence_length): + for vh in range(value_heads): + kh = vh // value_heads_per_key_head + q_t = query[:, i, kh] + k_t = key[:, i, kh] + v_t = value[:, i, vh] + g_t = g[:, i, vh].exp().view(batch_size, 1, 1) + beta_t = beta[:, i, vh].view(batch_size, 1) + state_t = last_recurrent_state[:, vh] + + state_t = state_t * g_t + kv_mem = (state_t * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + state_t = state_t + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + last_recurrent_state[:, vh] = state_t + core_attn_out[:, i, vh] = (state_t * q_t.unsqueeze(-1)).sum(dim=-2) + + if not output_final_state: + last_recurrent_state = None + + core_attn_out = core_attn_out.contiguous().to(initial_dtype) + if last_recurrent_state is not None: + last_recurrent_state = last_recurrent_state.contiguous().to(initial_dtype) + + return core_attn_out, last_recurrent_state + + +_TEST_CASES_ = [ + (7, 1, 40, 40, 128, 128, True, False), + (5, 1, 64, 64, 128, 128, False, False), + (1, 1, 8, 8, 64, 64, True, False), + (2, 1, 4, 8, 64, 64, False, False), + (2, 1, 4, 8, 64, 64, True, True), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def bthd_strides(B, T, H, D, strided): + if not strided: + return None + return (T * H * D * 2, H * D * 2, D * 2, 1) + + +def make_gate(shape, device): + return TestTensor.from_torch( + F.logsigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + +def make_beta(shape, device): + return TestTensor.from_torch( + torch.sigmoid(torch.randn(*shape, dtype=torch.float32)), InfiniDtype.F32, device + ) + + +def test( + handle, + device, + B, + T, + Hk, + Hv, + Dk, + Dv, + use_qk_l2norm, + strided_qkv, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing RecurrentGatedDeltaRule on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " + f"use_qk_l2norm={use_qk_l2norm}" + ) + + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + + initial_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + final_state = TestTensor((B, Hv, Dk, Dv), None, dtype, device) + + ans_out, ans_final_state = ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + + if sync: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRecurrentGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + initial_state.descriptor, + final_state.descriptor, + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + ctypes.c_void_p(0), + ctypes.c_void_p(0), + ctypes.c_bool(use_qk_l2norm), + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + q.destroy_desc() + k.destroy_desc() + v.destroy_desc() + g.destroy_desc() + beta.destroy_desc() + initial_state.destroy_desc() + out.destroy_desc() + final_state.destroy_desc() + + def lib_recurrent_gated_delta_rule(): + check_error( + LIBINFINIOP.infiniopRecurrentGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + initial_state.data(), + final_state.data(), + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + None, + None, + None, + ) + ) + + lib_recurrent_gated_delta_rule() + + if sync: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug(final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + final_state.actual_tensor(), ans_final_state, atol=atol, rtol=rtol + ) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + initial_state.torch_tensor(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lib_recurrent_gated_delta_rule, + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error( + LIBINFINIOP.infiniopDestroyRecurrentGatedDeltaRuleDescriptor(descriptor) + ) + + +def test_indexed_pool_inplace( + handle, + device, + B, + T, + Hk, + Hv, + Dk, + Dv, + use_qk_l2norm, + strided_qkv, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing RecurrentGatedDeltaRule indexed pool inplace on {InfiniDeviceNames[device]} with " + f"B={B}, T={T}, Hk={Hk}, Hv={Hv}, Dk={Dk}, Dv={Dv}, " + f"dtype={InfiniDtypeNames[dtype]}, gate_dtype=F32, strided_qkv={strided_qkv}, " + f"use_qk_l2norm={use_qk_l2norm}" + ) + + q = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + k = TestTensor( + (B, T, Hk, Dk), bthd_strides(B, T, Hk, Dk, strided_qkv), dtype, device + ) + v = TestTensor( + (B, T, Hv, Dv), bthd_strides(B, T, Hv, Dv, strided_qkv), dtype, device + ) + g = make_gate((B, T, Hv), device) + beta = make_beta((B, T, Hv), device) + + pool_size = B * 2 + 3 + initial_state_pool = TestTensor((pool_size, Hv, Dv, Dk), None, dtype, device) + index_device = q.torch_tensor().device + initial_state_indices_torch = torch.arange( + 1, B + 1, dtype=torch.int64, device=index_device + ) + final_state_indices_torch = torch.arange( + B + 1, 2 * B + 1, dtype=torch.int64, device=index_device + ) + initial_state_indices = TestTensor.from_torch( + initial_state_indices_torch, InfiniDtype.I64, device + ) + final_state_indices = TestTensor.from_torch( + final_state_indices_torch, InfiniDtype.I64, device + ) + + out = TestTensor( + (B, T, Hv, Dv), + bthd_strides(B, T, Hv, Dv, strided_qkv), + dtype, + device, + mode="zeros", + ) + + gathered_initial_state = ( + initial_state_pool.torch_tensor()[initial_state_indices_torch] + .transpose(-1, -2) + .contiguous() + ) + ans_out, ans_final_state = ref_recurrent_gated_delta_rule( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + g.torch_tensor(), + beta.torch_tensor(), + gathered_initial_state, + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm, + ) + ans_initial_state_pool = initial_state_pool.torch_tensor().clone() + ans_initial_state_pool[final_state_indices_torch] = ans_final_state.transpose( + -1, -2 + ).contiguous() + + if sync: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateRecurrentGatedDeltaRuleDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + initial_state_pool.descriptor, + ctypes.c_void_p(0), + q.descriptor, + k.descriptor, + v.descriptor, + g.descriptor, + beta.descriptor, + initial_state_indices.descriptor, + final_state_indices.descriptor, + ctypes.c_bool(use_qk_l2norm), + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetRecurrentGatedDeltaRuleWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + q.destroy_desc() + k.destroy_desc() + v.destroy_desc() + g.destroy_desc() + beta.destroy_desc() + initial_state_pool.destroy_desc() + initial_state_indices.destroy_desc() + final_state_indices.destroy_desc() + out.destroy_desc() + + check_error( + LIBINFINIOP.infiniopRecurrentGatedDeltaRule( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + initial_state_pool.data(), + None, + q.data(), + k.data(), + v.data(), + g.data(), + beta.data(), + initial_state_indices.data(), + final_state_indices.data(), + None, + ) + ) + + if sync: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + debug( + initial_state_pool.actual_tensor(), + ans_initial_state_pool, + atol=atol, + rtol=rtol, + ) + assert torch.allclose(out.actual_tensor(), ans_out, atol=atol, rtol=rtol) + assert torch.allclose( + initial_state_pool.actual_tensor(), ans_initial_state_pool, atol=atol, rtol=rtol + ) + + check_error( + LIBINFINIOP.infiniopDestroyRecurrentGatedDeltaRuleDescriptor(descriptor) + ) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + test_operator(device, test_indexed_pool_inplace, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m")