Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@
#include "ops/kv_caching.hpp"
#include "ops/layer_norm.hpp"
#include "ops/linear.hpp"
#include "ops/mamba_selective_scan.hpp"
#include "ops/matmul.hpp"
#include "ops/moe_align.hpp"
#include "ops/moe_fused_dense.hpp"
#include "ops/moe_fused_gate.hpp"
#include "ops/moe_sum.hpp"
#include "ops/moe_topk_sigmoid.hpp"
#include "ops/moe_topk_softmax.hpp"
#include "ops/prepare_moe_input.hpp"
#include "ops/nrm2.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/per_tensor_dequant_i8.hpp"
#include "ops/per_tensor_quant_i8.hpp"
#include "ops/prepare_moe_input.hpp"
#include "ops/quickgelu.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
Expand Down
43 changes: 43 additions & 0 deletions include/infinicore/ops/mamba_selective_scan.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Declares the MambaSelectiveScan graph operator through the project's
// INFINICORE_GRAPH_OP_CLASS macro. The type list after the class name matches
// the parameter list of mamba_selective_scan_() declared further down in this
// header: one by-value Tensor (out), eight `const Tensor &` inputs, and a
// trailing by-value Tensor (state).
INFINICORE_GRAPH_OP_CLASS(
MambaSelectiveScan,
Tensor,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
Tensor);

/// Out-of-place Mamba selective scan: returns the scan output as a Tensor.
///
/// Inputs are x, dt, b, c, a_log, d, gate and dt_bias; `state` is passed as a
/// by-value Tensor handle rather than a const reference.
/// NOTE(review): the Python wrapper documents `state` as updated in place —
/// confirm against the backend kernel.
Tensor mamba_selective_scan(const Tensor &x,
const Tensor &dt,
const Tensor &b,
const Tensor &c,
const Tensor &a_log,
const Tensor &d,
const Tensor &gate,
const Tensor &dt_bias,
Tensor state);

/// Variant of mamba_selective_scan() that writes into a caller-provided `out`
/// tensor instead of returning a new one. The remaining parameters are the
/// same as for mamba_selective_scan().
void mamba_selective_scan_(Tensor out,
const Tensor &x,
const Tensor &dt,
const Tensor &b,
const Tensor &c,
const Tensor &a_log,
const Tensor &d,
const Tensor &gate,
const Tensor &dt_bias,
Tensor state);

} // namespace infinicore::op
7 changes: 4 additions & 3 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
#include "infiniop/ops/blas_dot.h"
#include "infiniop/ops/block_diag.h"
#include "infiniop/ops/broadcast_to.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/causal_conv1d.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/cdist.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
Expand Down Expand Up @@ -85,6 +85,7 @@
#include "infiniop/ops/logcumsumexp.h"
#include "infiniop/ops/logdet.h"
#include "infiniop/ops/lp_norm.h"
#include "infiniop/ops/mamba_selective_scan.h"
#include "infiniop/ops/masked_select.h"
#include "infiniop/ops/matrix_power.h"
#include "infiniop/ops/moe_align.h"
Expand All @@ -93,7 +94,6 @@
#include "infiniop/ops/moe_sum.h"
#include "infiniop/ops/moe_topk_sigmoid.h"
#include "infiniop/ops/moe_topk_softmax.h"
#include "infiniop/ops/prepare_moe_input.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/multi_margin_loss.h"
#include "infiniop/ops/nrm2.h"
Expand All @@ -103,6 +103,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/pixel_shuffle.h"
#include "infiniop/ops/prepare_moe_input.h"
#include "infiniop/ops/quant/per_channel_quant_int8.h"
#include "infiniop/ops/quant/per_tensor_quant_int8.h"
#include "infiniop/ops/quickgelu.h"
Expand All @@ -116,8 +117,8 @@
#include "infiniop/ops/rotg.h"
#include "infiniop/ops/rotm.h"
#include "infiniop/ops/rotmg.h"
#include "infiniop/ops/scal.h"
#include "infiniop/ops/rwkv5_wkv.h"
#include "infiniop/ops/scal.h"
#include "infiniop/ops/scatter.h"
#include "infiniop/ops/selu.h"
#include "infiniop/ops/sigmoid.h"
Expand Down
50 changes: 50 additions & 0 deletions include/infiniop/ops/mamba_selective_scan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef __INFINIOP_MAMBA_SELECTIVE_SCAN_API_H__
#define __INFINIOP_MAMBA_SELECTIVE_SCAN_API_H__

#include "../operator_descriptor.h"
#ifdef __cplusplus
#include <cstddef>
#else
#include <stddef.h>
#endif

/* Handle type for a Mamba selective scan operator descriptor (a pointer to the
 * generic InfiniopDescriptor struct). */
typedef struct InfiniopDescriptor *infiniopMambaSelectiveScanDescriptor_t;

/* Creates a Mamba selective scan descriptor from the tensor descriptors of the
 * output, the eight inputs (x, dt, b, c, a_log, d, gate, dt_bias) and the
 * state tensor. The new descriptor is expected to be returned through
 * *desc_ptr.
 * NOTE(review): the accepted dtypes/shapes/strides are presumably those
 * validated by op::mamba_selective_scan::MambaSelectiveScanInfo::create —
 * confirm against the backend implementation. */
__INFINI_C __export infiniStatus_t infiniopCreateMambaSelectiveScanDescriptor(
infiniopHandle_t handle,
infiniopMambaSelectiveScanDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t dt_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_log_desc,
infiniopTensorDescriptor_t d_desc,
infiniopTensorDescriptor_t gate_desc,
infiniopTensorDescriptor_t dt_bias_desc,
infiniopTensorDescriptor_t state_desc);

/* Queries the workspace size required by `desc`; the value is written to
 * *size. NOTE(review): unit is presumably bytes — confirm. */
__INFINI_C __export infiniStatus_t infiniopGetMambaSelectiveScanWorkspaceSize(
infiniopMambaSelectiveScanDescriptor_t desc,
size_t *size);

/* Runs the Mamba selective scan described by `desc`.
 *
 * `workspace`/`workspace_size` supply scratch memory (see
 * infiniopGetMambaSelectiveScanWorkspaceSize). `out` and `state` are the only
 * non-const data pointers: `out` receives the result, and `state` is mutable.
 * NOTE(review): `state` is presumably read and updated in place, and `stream`
 * is presumably the backend's execution stream/queue handle — confirm. */
__INFINI_C __export infiniStatus_t infiniopMambaSelectiveScan(
infiniopMambaSelectiveScanDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *x,
const void *dt,
const void *b,
const void *c,
const void *a_log,
const void *d,
const void *gate,
const void *dt_bias,
void *state,
void *stream);

/* Destroys a descriptor previously obtained from
 * infiniopCreateMambaSelectiveScanDescriptor. */
__INFINI_C __export infiniStatus_t infiniopDestroyMambaSelectiveScanDescriptor(
infiniopMambaSelectiveScanDescriptor_t desc);

#endif
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .linear import linear
from .linear_w8a8i8 import linear_w8a8i8
from .log_softmax import log_softmax
from .mamba_selective_scan import mamba_selective_scan
from .multi_margin_loss import multi_margin_loss
from .pad import pad
from .prelu import prelu
Expand Down Expand Up @@ -70,6 +71,7 @@
"upsample_bilinear",
"interpolate",
"log_softmax",
"mamba_selective_scan",
"upsample_nearest",
"triplet_margin_with_distance_loss",
"embedding",
Expand Down
29 changes: 29 additions & 0 deletions python/infinicore/nn/functional/mamba_selective_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def mamba_selective_scan(
    x: Tensor,
    dt: Tensor,
    b: Tensor,
    c: Tensor,
    a_log: Tensor,
    d: Tensor,
    gate: Tensor,
    dt_bias: Tensor,
    state: Tensor,
) -> Tensor:
    """Apply the Mamba selective scan and return the output tensor.

    Every argument is unwrapped to its ``_underlying`` native handle and passed
    positionally, in declaration order, to
    ``_infinicore.mamba_selective_scan``. The native result is wrapped in a new
    ``Tensor``. ``state`` is updated in-place by the call.
    """
    operands = (x, dt, b, c, a_log, d, gate, dt_bias, state)
    native_out = _infinicore.mamba_selective_scan(
        *[operand._underlying for operand in operands]
    )
    return Tensor(native_out)
28 changes: 28 additions & 0 deletions src/infinicore/ops/mamba_selective_scan/mamba_selective_scan.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "infinicore/ops/mamba_selective_scan.hpp"
#include "../../utils.hpp"

namespace infinicore::op {
// Project macro instantiated for the MambaSelectiveScan op.
// NOTE(review): presumably defines the per-device dispatcher storage that the
// INFINICORE_GRAPH_OP_DISPATCH / registration macros use — confirm in common/op.hpp.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MambaSelectiveScan);

// Constructs the op for one invocation. All ten tensors are first passed to
// the same-device assertion macro, then the dispatch macro is invoked with the
// device type of `out` and the full argument list.
// NOTE(review): the exact effect of both macros is defined outside this file.
MambaSelectiveScan::MambaSelectiveScan(Tensor out, const Tensor &x, const Tensor &dt,
const Tensor &b, const Tensor &c, const Tensor &a_log,
const Tensor &d, const Tensor &gate, const Tensor &dt_bias,
Tensor state) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x, dt, b, c, a_log, d, gate, dt_bias, state);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x, dt, b, c, a_log, d, gate, dt_bias, state);
}
// Static entry point used by mamba_selective_scan_(): forwards every argument
// to the record-or-run macro for this op.
// NOTE(review): judging by the macro name, the call is either recorded into a
// graph or executed immediately — confirm in the macro definition.
void MambaSelectiveScan::execute(Tensor out, const Tensor &x, const Tensor &dt,
const Tensor &b, const Tensor &c, const Tensor &a_log,
const Tensor &d, const Tensor &gate, const Tensor &dt_bias,
Tensor state) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(MambaSelectiveScan, out, x, dt, b, c, a_log, d, gate, dt_bias, state);
}
// Out-of-place variant: allocates an output tensor with the same shape, dtype
// and device as `x`, runs the in-place-output variant into it, and returns it.
Tensor mamba_selective_scan(const Tensor &x,
                            const Tensor &dt,
                            const Tensor &b,
                            const Tensor &c,
                            const Tensor &a_log,
                            const Tensor &d,
                            const Tensor &gate,
                            const Tensor &dt_bias,
                            Tensor state) {
    auto result = Tensor::empty(x->shape(), x->dtype(), x->device());
    mamba_selective_scan_(result, x, dt, b, c, a_log, d, gate, dt_bias, state);
    return result;
}
// Caller-provided-output variant: forwards all arguments unchanged to
// MambaSelectiveScan::execute.
void mamba_selective_scan_(Tensor out,
                           const Tensor &x,
                           const Tensor &dt,
                           const Tensor &b,
                           const Tensor &c,
                           const Tensor &a_log,
                           const Tensor &d,
                           const Tensor &gate,
                           const Tensor &dt_bias,
                           Tensor state) {
    MambaSelectiveScan::execute(out, x, dt, b, c, a_log, d, gate, dt_bias, state);
}
} // namespace infinicore::op
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "../infiniop_impl.hpp"
#include "infinicore/ops/mamba_selective_scan.hpp"

namespace infinicore::op::mamba_selective_scan_impl::infiniop {
// Introduces the `Descriptor` type used below for the MambaSelectiveScan op.
// NOTE(review): the trailing 100 is presumably the descriptor cache capacity —
// confirm in infiniop_impl.hpp.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, MambaSelectiveScan, 100);
// State captured by plan() and consumed by run(): the operator descriptor plus
// graph-tensor handles for the workspace, the output, the eight inputs and the
// state. Member order matters because plan() uses aggregate initialization.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace;
    graph::GraphTensor out;
    graph::GraphTensor x;
    graph::GraphTensor dt;
    graph::GraphTensor b;
    graph::GraphTensor c;
    graph::GraphTensor a_log;
    graph::GraphTensor d;
    graph::GraphTensor gate;
    graph::GraphTensor dt_bias;
    graph::GraphTensor state;
};
// Builds the plan for one invocation and returns it as an opaque pointer.
//
// Steps: hash the ten tensors into `seed`, hand `seed` and the tensor
// descriptors to the get-or-create macro to obtain `descriptor`, let the
// workspace macro produce `workspace`, then heap-allocate a PlannedMeta that
// holds the descriptor and a GraphTensor for each tensor.
// The returned pointer is owned by the caller and is released by cleanup().
void *plan(Tensor out, const Tensor &x, const Tensor &dt, const Tensor &b, const Tensor &c, const Tensor &a_log, const Tensor &d, const Tensor &gate, const Tensor &dt_bias, Tensor state) {
size_t seed = hash_combine(out, x, dt, b, c, a_log, d, gate, dt_bias, state);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(Descriptor, descriptor, MambaSelectiveScan, seed, out->desc(), x->desc(), dt->desc(), b->desc(), c->desc(), a_log->desc(), d->desc(), gate->desc(), dt_bias->desc(), state->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, MambaSelectiveScan, descriptor);
return new PlannedMeta{descriptor, graph::GraphTensor(workspace), graph::GraphTensor(out), graph::GraphTensor(x), graph::GraphTensor(dt), graph::GraphTensor(b), graph::GraphTensor(c), graph::GraphTensor(a_log), graph::GraphTensor(d), graph::GraphTensor(gate), graph::GraphTensor(dt_bias), graph::GraphTensor(state)};
}
// Executes a plan produced by plan(): reinterprets the opaque pointer as a
// PlannedMeta and calls infiniopMambaSelectiveScan with the stored descriptor,
// the data pointers of every tensor, and the current context stream. The
// status is checked by INFINICORE_CHECK_ERROR.
// NOTE(review): workspace->numel() is passed as workspace_size, which assumes
// the workspace tensor has one-byte elements — confirm.
void run(void *planned_meta) {
auto p = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopMambaSelectiveScan(p->descriptor->desc, p->workspace->data(), p->workspace->numel(), p->out->data(), p->x->data(), p->dt->data(), p->b->data(), p->c->data(), p->a_log->data(), p->d->data(), p->gate->data(), p->dt_bias->data(), p->state->data(), context::getStream()));
}
// Releases a plan created by plan() and nulls the caller's slot so the stale
// pointer cannot be reused.
void cleanup(void **planned_meta_ptr) {
    auto **typed_slot = reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    delete *typed_slot;
    *planned_meta_ptr = nullptr;
}
// Registers the plan/run/cleanup triple above as the MambaSelectiveScan
// implementation. NOTE(review): per the macro name this applies to every
// device type — confirm in the macro definition.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MambaSelectiveScan, &plan, &run, &cleanup);
} // namespace infinicore::op::mamba_selective_scan_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
#include "ops/logdet.hpp"
#include "ops/logical_and.hpp"
#include "ops/logical_not.hpp"
#include "ops/mamba_selective_scan.hpp"
#include "ops/masked_select.hpp"
#include "ops/matmul.hpp"
#include "ops/mha.hpp"
Expand Down Expand Up @@ -178,6 +179,7 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_logdet(m);
bind_matmul(m);
bind_mamba_selective_scan(m);
bind_kron(m);
bind_mul(m);
bind_nrm2(m);
Expand Down
34 changes: 34 additions & 0 deletions src/infinicore/pybind11/ops/mamba_selective_scan.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/mamba_selective_scan.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes op::mamba_selective_scan to Python as `mamba_selective_scan` on
// module `m`. The nine py::arg entries name the parameters in the same order
// as the C++ signature, so they can be passed positionally or by keyword.
// The raw-string docstring below is runtime text shown to Python users.
inline void bind_mamba_selective_scan(py::module &m) {
m.def("mamba_selective_scan",
&op::mamba_selective_scan,
py::arg("x"),
py::arg("dt"),
py::arg("b"),
py::arg("c"),
py::arg("a_log"),
py::arg("d"),
py::arg("gate"),
py::arg("dt_bias"),
py::arg("state"),
R"doc(Mamba selective scan. Returns out and updates state in-place.

Shapes:
x, dt, gate, out: [batch, seq_len, intermediate]
b, c: [batch, seq_len, state_size]
a_log: [intermediate, state_size]
d, dt_bias: [intermediate]
state: [batch, intermediate, state_size], float32
)doc");
}

} // namespace infinicore::ops
71 changes: 71 additions & 0 deletions src/infiniop/ops/mamba_selective_scan/info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef __MAMBA_SELECTIVE_SCAN_INFO_H__
#define __MAMBA_SELECTIVE_SCAN_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"

namespace op::mamba_selective_scan {

// Validated problem description for the Mamba selective scan operator.
//
// Instances are only produced by create(), which checks the dtype, rank,
// shape and innermost stride of the tensor descriptors and, on success,
// records the element dtype and the four problem dimensions.
class MambaSelectiveScanInfo {
    MambaSelectiveScanInfo() = default;

public:
    infiniDtype_t dtype;  // element type shared by every tensor except state
    size_t batch;         // x.shape[0]
    size_t seq_len;       // x.shape[1]
    size_t intermediate;  // x.shape[2]
    size_t state_size;    // b.shape[2]

    // Validates the descriptors and builds the info object.
    //
    // Returns INFINI_STATUS_BAD_TENSOR_DTYPE, INFINI_STATUS_BAD_TENSOR_SHAPE or
    // INFINI_STATUS_BAD_TENSOR_STRIDES (checked in that order) when a
    // constraint is violated, otherwise the populated info.
    //
    // Accepted layout:
    //   out, x, dt, gate : [batch, seq_len, intermediate]
    //   b, c             : [batch, seq_len, state_size]
    //   a_log            : [intermediate, state_size]
    //   d, dt_bias       : [intermediate]
    //   state            : [>= batch, intermediate, state_size], always F32
    static utils::Result<MambaSelectiveScanInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t x_desc,
        infiniopTensorDescriptor_t dt_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t c_desc,
        infiniopTensorDescriptor_t a_log_desc,
        infiniopTensorDescriptor_t d_desc,
        infiniopTensorDescriptor_t gate_desc,
        infiniopTensorDescriptor_t dt_bias_desc,
        infiniopTensorDescriptor_t state_desc) {
        // --- dtype: x decides the element type; only F16/BF16/F32 are allowed.
        const auto dtype = x_desc->dtype();
        const bool dtype_supported = dtype == INFINI_DTYPE_F16
                                  || dtype == INFINI_DTYPE_BF16
                                  || dtype == INFINI_DTYPE_F32;
        if (!dtype_supported) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // Every tensor other than state must use that same element type.
        const infiniopTensorDescriptor_t same_dtype_as_x[] = {
            out_desc, dt_desc, b_desc, c_desc, gate_desc,
            a_log_desc, d_desc, dt_bias_desc};
        for (const auto desc : same_dtype_as_x) {
            if (desc->dtype() != dtype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        }
        // The state is kept in float32 regardless of the compute dtype.
        if (state_desc->dtype() != INFINI_DTYPE_F32) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        // --- rank: must be verified before any shape()/strides() indexing below.
        const bool rank3_ok = out_desc->ndim() == 3 && x_desc->ndim() == 3
                           && dt_desc->ndim() == 3 && b_desc->ndim() == 3
                           && c_desc->ndim() == 3 && gate_desc->ndim() == 3
                           && state_desc->ndim() == 3;
        const bool param_rank_ok = a_log_desc->ndim() == 2
                                && d_desc->ndim() == 1
                                && dt_bias_desc->ndim() == 1;
        if (!(rank3_ok && param_rank_ok)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- shapes: out, dt and gate mirror x exactly.
        const auto x_shape = x_desc->shape();
        if (out_desc->shape() != x_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (dt_desc->shape() != x_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (gate_desc->shape() != x_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        const size_t n_batch = x_shape[0];
        const size_t n_seq = x_shape[1];
        const size_t n_inter = x_shape[2];

        // b and c share one shape and agree with x on batch and seq_len.
        const auto b_shape = b_desc->shape();
        if (c_desc->shape() != b_shape) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (b_shape[0] != n_batch || b_shape[1] != n_seq) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        const size_t n_state = b_shape[2];

        // Per-channel parameters.
        const auto a_log_shape = a_log_desc->shape();
        if (a_log_shape[0] != n_inter || a_log_shape[1] != n_state) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (d_desc->shape()[0] != n_inter) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (dt_bias_desc->shape()[0] != n_inter) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // The state's leading dimension may exceed batch; it only has to be
        // large enough to hold one entry per batch element.
        const auto state_shape = state_desc->shape();
        if (state_shape[0] < n_batch) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (state_shape[1] != n_inter || state_shape[2] != n_state) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // --- strides: the innermost dimension must be densely packed.
        // NOTE(review): gate, d and dt_bias strides are not validated here —
        // confirm the kernels either honour their strides or require a check.
        const auto unit_stride = [](infiniopTensorDescriptor_t desc, size_t axis) {
            return desc->strides()[axis] == 1;
        };
        const bool inner_dense = unit_stride(out_desc, 2) && unit_stride(x_desc, 2)
                              && unit_stride(dt_desc, 2) && unit_stride(b_desc, 2)
                              && unit_stride(c_desc, 2) && unit_stride(a_log_desc, 1)
                              && unit_stride(state_desc, 2);
        if (!inner_dense) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        return utils::Result<MambaSelectiveScanInfo>(
            MambaSelectiveScanInfo{dtype, n_batch, n_seq, n_inter, n_state});
    }
};

} // namespace op::mamba_selective_scan
#endif
Loading
Loading