Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "ops/blas_dot.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/cdist.hpp"
#include "ops/chunk_gated_delta_rule.hpp"
#include "ops/conv2d.hpp"
#include "ops/cross_entropy.hpp"
#include "ops/deepseek_moe.hpp"
Expand All @@ -46,6 +47,7 @@
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/reciprocal.hpp"
#include "ops/recurrent_gated_delta_rule.hpp"
#include "ops/relu.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
Expand Down
53 changes: 53 additions & 0 deletions include/infinicore/ops/chunk_gated_delta_rule.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include "infinicore.h"

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// Graph-op class declaration for ChunkGatedDeltaRule.
// The type list after the class name matches, position by position, the
// parameter types of chunk_gated_delta_rule_ declared below:
//   out, initial_state, final_state, q, k, v, g, beta, cu_seqlens,
//   initial_state_indices, final_state_indices, use_qk_l2norm, chunk_size.
// NOTE(review): INFINICORE_GRAPH_OP_CLASS is defined outside this header
// (presumably via common/op.hpp or ../graph/graph.hpp) -- confirm which
// members it generates before relying on them.
INFINICORE_GRAPH_OP_CLASS(ChunkGatedDeltaRule,
Tensor,
Tensor,
std::optional<Tensor>,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
std::optional<Tensor>,
std::optional<Tensor>,
std::optional<Tensor>,
bool,
size_t);

// Value-returning form of the chunk gated delta rule op.
//   q, k, v, g, beta        inputs, taken by const reference.
//   initial_state           taken by value (not const), unlike the inputs.
//   cu_seqlens              optional; defaults to std::nullopt.
//   initial_state_indices   optional; defaults to std::nullopt.
//   final_state_indices     optional; defaults to std::nullopt.
//   use_qk_l2norm           defaults to false.
//   chunk_size              defaults to 64.
// Returns a Tensor by value.
// NOTE(review): presumably the returned tensor is `out` and shapes follow the
// comments on infiniopCreateChunkGatedDeltaRuleDescriptor in
// include/infiniop/ops/chunk_gated_delta_rule.h -- confirm in the implementation.
__export Tensor chunk_gated_delta_rule(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &g,
const Tensor &beta,
Tensor initial_state,
std::optional<Tensor> cu_seqlens = std::nullopt,
std::optional<Tensor> initial_state_indices = std::nullopt,
std::optional<Tensor> final_state_indices = std::nullopt,
bool use_qk_l2norm = false,
size_t chunk_size = 64);

// Explicit-output form (trailing underscore): returns void and takes `out`,
// `initial_state` and an optional `final_state` ahead of the inputs.
// Unlike the value-returning form, cu_seqlens and the two index tensors have
// no default here, so callers must pass them explicitly (std::nullopt when
// unused); only use_qk_l2norm (false) and chunk_size (64) are defaulted.
__export void chunk_gated_delta_rule_(Tensor out,
Tensor initial_state,
std::optional<Tensor> final_state,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &g,
const Tensor &beta,
std::optional<Tensor> cu_seqlens,
std::optional<Tensor> initial_state_indices,
std::optional<Tensor> final_state_indices,
bool use_qk_l2norm = false,
size_t chunk_size = 64);

} // namespace infinicore::op
55 changes: 55 additions & 0 deletions include/infinicore/ops/recurrent_gated_delta_rule.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#pragma once

#include "infinicore.h"

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// Graph-op class declaration for RecurrentGatedDeltaRule.
// The type list after the class name matches, position by position, the
// parameter types of recurrent_gated_delta_rule_ declared below:
//   out, initial_state, final_state, q, k, v, g, beta,
//   initial_state_indices, final_state_indices, use_qk_l2norm.
// NOTE(review): INFINICORE_GRAPH_OP_CLASS is defined outside this header
// (presumably via common/op.hpp or ../graph/graph.hpp) -- confirm which
// members it generates before relying on them.
INFINICORE_GRAPH_OP_CLASS(RecurrentGatedDeltaRule,
Tensor,
Tensor,
std::optional<Tensor>,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
std::optional<Tensor>,
std::optional<Tensor>,
bool);

// Value-returning form without state indices.
// All tensors, including initial_state, are taken by const reference;
// use_qk_l2norm defaults to false. Returns a Tensor by value.
// NOTE(review): presumably the returned tensor is `out`; what happens to the
// final state in this mode is not visible from this header -- confirm.
__export Tensor recurrent_gated_delta_rule(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &g,
const Tensor &beta,
const Tensor &initial_state,
bool use_qk_l2norm = false);

// Value-returning form with state indices.
// Differs from recurrent_gated_delta_rule in that initial_state is taken by
// value (not const) and both initial_state_indices and final_state_indices
// are required const-reference arguments. use_qk_l2norm defaults to false.
// NOTE(review): presumably initial_state acts as a state pool updated in
// place, as described for final_state_indices_desc in
// include/infiniop/ops/recurrent_gated_delta_rule.h -- confirm.
__export Tensor recurrent_gated_delta_rule_indexed(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &g,
const Tensor &beta,
Tensor initial_state,
const Tensor &initial_state_indices,
const Tensor &final_state_indices,
bool use_qk_l2norm = false);

// Explicit-output form (trailing underscore): returns void and takes `out`,
// `initial_state` and an optional `final_state` ahead of the inputs.
// The two index tensors are std::optional with no default, so callers must
// pass them explicitly (std::nullopt when unused); only use_qk_l2norm is
// defaulted (false).
__export void recurrent_gated_delta_rule_(Tensor out,
Tensor initial_state,
std::optional<Tensor> final_state,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &g,
const Tensor &beta,
std::optional<Tensor> initial_state_indices,
std::optional<Tensor> final_state_indices,
bool use_qk_l2norm = false);

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "infiniop/ops/broadcast_to.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/cdist.h"
#include "infiniop/ops/chunk_gated_delta_rule.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/cross_entropy.h"
Expand Down Expand Up @@ -100,6 +101,7 @@
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/reciprocal.h"
#include "infiniop/ops/recurrent_gated_delta_rule.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rope.h"
Expand Down
49 changes: 49 additions & 0 deletions include/infiniop/ops/chunk_gated_delta_rule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__
#define __INFINIOP_CHUNK_GATED_DELTA_RULE_API_H__

#include "../operator_descriptor.h"

// Descriptor handle type for the ChunkGatedDeltaRule operator: a pointer to
// struct InfiniopDescriptor. It is the type filled in through `desc_ptr` by
// infiniopCreateChunkGatedDeltaRuleDescriptor and consumed by the other
// ChunkGatedDeltaRule entry points declared in this header.
typedef struct InfiniopDescriptor *infiniopChunkGatedDeltaRuleDescriptor_t;

// Creates a ChunkGatedDeltaRule operator descriptor from tensor descriptors
// and the two scalar options, and reports the result as an infiniStatus_t.
// The new descriptor is handed back through `desc_ptr`.
// Three descriptors are documented as nullable (cu_seqlens_desc,
// initial_state_indices_desc, final_state_indices_desc); final_state_desc is
// expected to be null when final_state_indices_desc is provided.
// Shape legend used in the per-parameter comments: B = batch, T = tokens per
// sequence, Hk/Hv = q,k / v head counts, Dk/Dv = q,k / v head dims
// (NOTE(review): legend inferred from the letter names -- confirm).
// Note that the state layout differs between modes: [B, Hv, Dk, Dv] for the
// legacy path versus [pool_size, Hv, Dv, Dk] for the indexed pool.
__INFINI_C __export infiniStatus_t infiniopCreateChunkGatedDeltaRuleDescriptor(
infiniopHandle_t handle,
infiniopChunkGatedDeltaRuleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv]
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t final_state_desc, // null when final_state_indices_desc is provided
infiniopTensorDescriptor_t q_desc, // padded: [B, T, Hk, Dk]; varlen: [1, total_tokens, Hk, Dk]
infiniopTensorDescriptor_t k_desc, // same shape as q
infiniopTensorDescriptor_t v_desc, // padded: [B, T, Hv, Dv]; varlen: [1, total_tokens, Hv, Dv]
infiniopTensorDescriptor_t g_desc, // padded: [B, T, Hv]; varlen: [1, total_tokens, Hv]
infiniopTensorDescriptor_t beta_desc, // same shape/dtype as g
infiniopTensorDescriptor_t cu_seqlens_desc, // nullable; [B + 1], int32/int64
infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed state-pool reads
infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool
bool use_qk_l2norm,
size_t chunk_size);

// Reports the workspace size for a ChunkGatedDeltaRule descriptor through the
// `size` out-parameter; the call itself returns an infiniStatus_t.
// NOTE(review): presumably the size is in bytes and is the value to pass as
// `workspace_size` to infiniopChunkGatedDeltaRule -- confirm in the implementation.
__INFINI_C __export infiniStatus_t infiniopGetChunkGatedDeltaRuleWorkspaceSize(
infiniopChunkGatedDeltaRuleDescriptor_t desc,
size_t *size);

// Runs the ChunkGatedDeltaRule operator described by `desc` and returns an
// infiniStatus_t.
// Pointer constness reflects the declared data flow: `out`, `initial_state`
// and `final_state` are mutable, while q/k/v/g/beta, cu_seqlens and both
// index buffers are const.
// `initial_state` is mutable because the descriptor comments state that the
// final state is written in place into the initial_state pool when
// final_state_indices is provided.
// The data pointers appear in the same order as the corresponding *_desc
// arguments of infiniopCreateChunkGatedDeltaRuleDescriptor.
// NOTE(review): presumably cu_seqlens, initial_state_indices,
// final_state_indices (and final_state) are passed as null when the matching
// descriptor was null at creation -- confirm.
// NOTE(review): `stream` is an opaque void*; presumably a backend execution
// stream -- confirm against the device backends.
__INFINI_C __export infiniStatus_t infiniopChunkGatedDeltaRule(
infiniopChunkGatedDeltaRuleDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
void *initial_state,
void *final_state,
const void *q,
const void *k,
const void *v,
const void *g,
const void *beta,
const void *cu_seqlens,
const void *initial_state_indices,
const void *final_state_indices,
void *stream);

// Destroy entry point for a ChunkGatedDeltaRule descriptor; returns an
// infiniStatus_t.
// NOTE(review): presumably releases a descriptor obtained from
// infiniopCreateChunkGatedDeltaRuleDescriptor, after which `desc` must not be
// used -- the implementation is not visible here, confirm.
__INFINI_C __export infiniStatus_t infiniopDestroyChunkGatedDeltaRuleDescriptor(
infiniopChunkGatedDeltaRuleDescriptor_t desc);

#endif
46 changes: 46 additions & 0 deletions include/infiniop/ops/recurrent_gated_delta_rule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#ifndef __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__
#define __INFINIOP_RECURRENT_GATED_DELTA_RULE_API_H__

#include "../operator_descriptor.h"

// Descriptor handle type for the RecurrentGatedDeltaRule operator: a pointer
// to struct InfiniopDescriptor. It is the type filled in through `desc_ptr` by
// infiniopCreateRecurrentGatedDeltaRuleDescriptor and consumed by the other
// RecurrentGatedDeltaRule entry points declared in this header.
typedef struct InfiniopDescriptor *infiniopRecurrentGatedDeltaRuleDescriptor_t;

// Creates a RecurrentGatedDeltaRule operator descriptor from tensor
// descriptors and the use_qk_l2norm flag, and reports the result as an
// infiniStatus_t. The new descriptor is handed back through `desc_ptr`.
// Per the parameter comments this operator handles a single token per
// sequence (T must be 1) and has no cu_seqlens or chunk_size argument, in
// contrast to the chunk variant.
// The two index descriptors are documented as nullable; final_state_desc is
// expected to be null when final_state_indices_desc is provided.
// Note that the state layout differs between modes: [B, Hv, Dk, Dv] for the
// legacy path versus [pool_size, Hv, Dv, Dk] for the indexed pool.
// (NOTE(review): B/T/Hk/Hv/Dk/Dv presumably mean batch, tokens, q,k / v head
// counts and head dims -- confirm.)
__INFINI_C __export infiniStatus_t infiniopCreateRecurrentGatedDeltaRuleDescriptor(
infiniopHandle_t handle,
infiniopRecurrentGatedDeltaRuleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc, // [B, T, Hv, Dv], T must be 1; last dim contiguous
infiniopTensorDescriptor_t initial_state_desc, // legacy: [B, Hv, Dk, Dv]; indexed pool: [pool_size, Hv, Dv, Dk]
infiniopTensorDescriptor_t final_state_desc, // legacy/indexed out-of-place final state; null when final_state_indices_desc is provided
infiniopTensorDescriptor_t q_desc, // [B, T, Hk, Dk], T must be 1; last dim contiguous
infiniopTensorDescriptor_t k_desc, // [B, T, Hk, Dk], same shape as q; last dim contiguous
infiniopTensorDescriptor_t v_desc, // [B, T, Hv, Dv], Hv must be a multiple of Hk; last dim contiguous
infiniopTensorDescriptor_t g_desc, // [B, T, Hv]; may have a different fp dtype from q/k/v/out/state
infiniopTensorDescriptor_t beta_desc, // [B, T, Hv]; same dtype as g
infiniopTensorDescriptor_t initial_state_indices_desc, // nullable; [B], int32/int64; enables indexed pool mode
infiniopTensorDescriptor_t final_state_indices_desc, // nullable; [B], int32/int64; writes final state in-place to initial_state pool
bool use_qk_l2norm);

// Reports the workspace size for a RecurrentGatedDeltaRule descriptor through
// the `size` out-parameter; the call itself returns an infiniStatus_t.
// NOTE(review): presumably the size is in bytes and is the value to pass as
// `workspace_size` to infiniopRecurrentGatedDeltaRule -- confirm in the implementation.
__INFINI_C __export infiniStatus_t infiniopGetRecurrentGatedDeltaRuleWorkspaceSize(
infiniopRecurrentGatedDeltaRuleDescriptor_t desc,
size_t *size);

// Runs the RecurrentGatedDeltaRule operator described by `desc` and returns
// an infiniStatus_t.
// Pointer constness reflects the declared data flow: `out`, `initial_state`
// and `final_state` are mutable, while q/k/v/g/beta and both index buffers
// are const.
// `initial_state` is mutable because the descriptor comments state that the
// final state is written in place into the initial_state pool when
// final_state_indices is provided.
// The data pointers appear in the same order as the corresponding *_desc
// arguments of infiniopCreateRecurrentGatedDeltaRuleDescriptor.
// NOTE(review): presumably initial_state_indices, final_state_indices (and
// final_state) are passed as null when the matching descriptor was null at
// creation -- confirm.
// NOTE(review): `stream` is an opaque void*; presumably a backend execution
// stream -- confirm against the device backends.
__INFINI_C __export infiniStatus_t infiniopRecurrentGatedDeltaRule(
infiniopRecurrentGatedDeltaRuleDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
void *initial_state,
void *final_state,
const void *q,
const void *k,
const void *v,
const void *g,
const void *beta,
const void *initial_state_indices,
const void *final_state_indices,
void *stream);

// Destroy entry point for a RecurrentGatedDeltaRule descriptor; returns an
// infiniStatus_t.
// NOTE(review): presumably releases a descriptor obtained from
// infiniopCreateRecurrentGatedDeltaRuleDescriptor, after which `desc` must
// not be used -- the implementation is not visible here, confirm.
__INFINI_C __export infiniStatus_t infiniopDestroyRecurrentGatedDeltaRuleDescriptor(
infiniopRecurrentGatedDeltaRuleDescriptor_t desc);

#endif
4 changes: 4 additions & 0 deletions python/infinicore/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .avg_pool1d import avg_pool1d
from .binary_cross_entropy_with_logits import binary_cross_entropy_with_logits
from .causal_softmax import causal_softmax
from .chunk_gated_delta_rule import chunk_gated_delta_rule
from .embedding import embedding
from .flash_attention import flash_attention
from .gaussian_nll_loss import gaussian_nll_loss
Expand All @@ -23,6 +24,7 @@
from .pad import pad
from .prelu import prelu
from .random_sample import random_sample
from .recurrent_gated_delta_rule import recurrent_gated_delta_rule
from .relu6 import relu6
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
Expand All @@ -44,6 +46,7 @@
"conv2d",
"adaptive_max_pool1d",
"causal_softmax",
"chunk_gated_delta_rule",
"embedding",
"flash_attention",
"gaussian_nll_loss",
Expand All @@ -56,6 +59,7 @@
"prelu",
"relu6",
"rms_norm",
"recurrent_gated_delta_rule",
"sigmoid",
"silu",
"smooth_l1_loss",
Expand Down
66 changes: 66 additions & 0 deletions python/infinicore/nn/functional/chunk_gated_delta_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def chunk_gated_delta_rule(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    g: Tensor,
    beta: Tensor,
    initial_state: Tensor,
    *,
    cu_seqlens: Tensor | None = None,
    initial_state_indices: Tensor | None = None,
    final_state_indices: Tensor | None = None,
    use_qk_l2norm: bool = False,
    chunk_size: int = 64,
) -> Tensor:
    """Apply the chunk gated delta rule; only ``out`` is returned.

    Args:
        q, k: ``[B, T, Hk, Dk]`` when padded, ``[1, total_tokens, Hk, Dk]``
            in continuous-batch mode.
        v: ``[B, T, Hv, Dv]`` when padded, ``[1, total_tokens, Hv, Dv]`` in
            continuous-batch mode. ``out`` has the same shape.
        g, beta: ``[B, T, Hv]`` when padded, ``[1, total_tokens, Hv]`` in
            continuous-batch mode. Their floating dtype may differ from that
            of q/k/v/state.
        initial_state: ``[B, Hv, Dk, Dv]`` in padded mode, or a state pool of
            shape ``[pool_size, Hv, Dv, Dk]`` in indexed state-pool mode.
        cu_seqlens: Optional ``[B + 1]`` int32/int64 tensor; passing it
            selects continuous-batch mode.
        initial_state_indices, final_state_indices: Optional ``[B]``
            int32/int64 tensors selecting indexed state-pool mode. They must
            be given together. The final state is then written in-place into
            ``initial_state[final_state_indices]`` and no final state tensor
            is returned.
        use_qk_l2norm: Flag forwarded unchanged to the native op.
        chunk_size: Chunk size forwarded unchanged to the native op.

    Returns:
        The native result wrapped in a ``Tensor``.

    Raises:
        ValueError: If exactly one of ``initial_state_indices`` and
            ``final_state_indices`` is provided.

    Notes:
        ``Hv`` must be a multiple of ``Hk``. q/k/v/out may be strided in
        their first three dimensions, but the last dimension must be
        contiguous.
    """
    has_initial_indices = initial_state_indices is not None
    has_final_indices = final_state_indices is not None
    if has_initial_indices != has_final_indices:
        raise ValueError(
            "initial_state_indices and final_state_indices must be provided together"
        )

    def _raw(tensor):
        # Optional tensors cross the binding either as None or as the
        # wrapped ``_underlying`` object.
        return None if tensor is None else tensor._underlying

    result = _infinicore.chunk_gated_delta_rule(
        q._underlying,
        k._underlying,
        v._underlying,
        g._underlying,
        beta._underlying,
        initial_state._underlying,
        _raw(cu_seqlens),
        _raw(initial_state_indices),
        _raw(final_state_indices),
        use_qk_l2norm,
        chunk_size,
    )
    return Tensor(result)
47 changes: 47 additions & 0 deletions python/infinicore/nn/functional/recurrent_gated_delta_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def recurrent_gated_delta_rule(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    g: Tensor,
    beta: Tensor,
    initial_state: Tensor,
    *,
    initial_state_indices: Tensor | None = None,
    final_state_indices: Tensor | None = None,
    use_qk_l2norm: bool = False,
) -> Tensor:
    """Apply the recurrent gated delta rule via the native binding.

    Dispatches to ``_infinicore.recurrent_gated_delta_rule`` when neither
    index tensor is given, and to
    ``_infinicore.recurrent_gated_delta_rule_indexed`` when both are given.

    Args:
        q, k, v, g, beta, initial_state: Tensors whose ``_underlying``
            objects are forwarded positionally, in this order.
        initial_state_indices, final_state_indices: Optional tensors that
            must be provided together; supplying both selects the indexed
            variant.
        use_qk_l2norm: Flag forwarded unchanged to the native op.

    Returns:
        The native result wrapped in a ``Tensor``.

    Raises:
        ValueError: If exactly one of ``initial_state_indices`` and
            ``final_state_indices`` is provided.

    NOTE(review): shapes presumably follow the comments on
    ``infiniopCreateRecurrentGatedDeltaRuleDescriptor`` (``T`` must be 1,
    indexed pool is ``[pool_size, Hv, Dv, Dk]``) -- confirm.
    """
    use_indexed = initial_state_indices is not None
    if use_indexed != (final_state_indices is not None):
        raise ValueError(
            "initial_state_indices and final_state_indices must be provided together"
        )

    # Both native entry points share this leading positional prefix.
    shared_args = (
        q._underlying,
        k._underlying,
        v._underlying,
        g._underlying,
        beta._underlying,
        initial_state._underlying,
    )

    if use_indexed:
        result = _infinicore.recurrent_gated_delta_rule_indexed(
            *shared_args,
            initial_state_indices._underlying,
            final_state_indices._underlying,
            use_qk_l2norm,
        )
    else:
        result = _infinicore.recurrent_gated_delta_rule(
            *shared_args,
            use_qk_l2norm,
        )
    return Tensor(result)
Loading
Loading