Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions csrc/cache/mamba_cache.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "mamba_cache.hpp"

#include "../global_state/global_state.hpp"
#include "infinicore/context/context.hpp"

namespace infinilm::cache {

infinicore::Tensor MambaCache::create_layer_conv_state(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size conv_kernel_dim,
infinicore::DataType dtype,
size_t pool_size) {
const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
const auto [num_rank_k_heads, num_rank_v_heads] = get_rank_head_counts(num_k_heads, num_v_heads, rank_info.tp_size);
const size_t conv_state_len = conv_kernel_dim > 0 ? conv_kernel_dim - 1 : 0;
const size_t conv_dim = 2 * num_rank_k_heads * k_dim + num_rank_v_heads * v_dim;

auto conv_state = infinicore::Tensor::zeros(
{pool_size, conv_dim, conv_state_len},
dtype,
rank_info.device);
infinicore::context::syncStream();
return conv_state;
}

infinicore::Tensor MambaCache::create_layer_ssm_state(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::DataType dtype,
size_t pool_size) {
const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
const auto rank_head_counts = get_rank_head_counts(num_k_heads, num_v_heads, rank_info.tp_size);
const size_t num_rank_v_heads = rank_head_counts.second;

auto ssm_state = infinicore::Tensor::zeros(
{pool_size, num_rank_v_heads, v_dim, k_dim},
dtype,
rank_info.device);

infinicore::context::syncStream();
return ssm_state;
}

std::pair<size_t, size_t> MambaCache::get_rank_head_counts(
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
size_t tp_size) {
bool is_kv_replica = (num_k_heads < tp_size && num_v_heads < tp_size && num_k_heads == num_v_heads && tp_size % num_k_heads == 0);
size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / tp_size);
size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / tp_size);
return {num_rank_k_heads, num_rank_v_heads};
}

} // namespace infinilm::cache
39 changes: 39 additions & 0 deletions csrc/cache/mamba_cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include "base_cache.hpp"

#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include <cstddef>
#include <infinicore/dtype.hpp>
#include <utility>

namespace infinilm::cache {

class MambaCache {
public:
static infinicore::Tensor create_layer_conv_state(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size conv_kernel_dim,
infinicore::DataType dtype,
size_t pool_size);

static infinicore::Tensor create_layer_ssm_state(
infinicore::Size k_dim,
infinicore::Size v_dim,
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::DataType dtype,
size_t pool_size);

private:
static std::pair<size_t, size_t> get_rank_head_counts(
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
size_t tp_size);
};

} // namespace infinilm::cache
7 changes: 7 additions & 0 deletions csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
to_device(cu_seqlens),
to_device(block_tables),
to_device(slot_mapping),
to_device(mamba_init_state_indices),
to_device(mamba_final_state_indices),
to_device_vec(pixel_values),
to_device_vec(image_bound),
to_device_vec(tgt_sizes),
Expand All @@ -166,6 +168,11 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
input.block_tables,
input.slot_mapping};

infinilm::global_state::get_forward_context().mamba_metadata = {
input.input_offsets,
input.mamba_init_state_indices,
input.mamba_final_state_indices};

global_state::get_forward_context().mm_metadata = {
image_req_ids};

Expand Down
5 changes: 5 additions & 0 deletions csrc/engine/rank_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ RankWorker::RankWorker(
cv_.wait(lk, [&] { return init_done_; });
}

RankWorker::~RankWorker() {
close();
}

std::string RankWorker::info() const {
std::stringstream ss;

Expand Down Expand Up @@ -513,6 +517,7 @@ void RankWorker::thread_loop() {
// Top-level exception: ensure any waiters are woken and the thread exits cleanly.
{
std::lock_guard<std::mutex> lk(mutex_);
init_done_ = true;
should_exit_ = true;
job_done_ = true;
}
Expand Down
6 changes: 6 additions & 0 deletions csrc/engine/rank_worker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class RankWorker {
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
std::optional<infinicore::Tensor> slot_mapping;
/// Mamba state cache indices read at the start of each request forward.
std::optional<infinicore::Tensor> mamba_init_state_indices;
/// Mamba state cache indices written with the final state of each request forward.
std::optional<infinicore::Tensor> mamba_final_state_indices;
/// Image pixel values for multi-modal models.
std::optional<std::vector<infinicore::Tensor>> pixel_values;
/// Image placeholder bounds for MiniCPM-V style replacement.
Expand Down Expand Up @@ -81,6 +85,8 @@ class RankWorker {
bool enable_graph_compiling,
backends::AttentionBackend attention_backend);

~RankWorker();

// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
const infinicore::Tensor &param);
Expand Down
12 changes: 12 additions & 0 deletions csrc/global_state/forward_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,22 @@ struct MultiModalMetadata {
std::optional<std::vector<size_t>> image_req_ids;
};

struct MambaMetadata {
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
std::optional<infinicore::Tensor> input_offsets;
/// State cache indices read at the start of each request forward.
std::optional<infinicore::Tensor> init_state_indices;
/// State cache indices written with the final state of each request forward.
std::optional<infinicore::Tensor> final_state_indices;
};

struct ForwardContext {
AttentionMetadata attn_metadata;
MambaMetadata mamba_metadata;
MultiModalMetadata mm_metadata;
std::vector<infinicore::Tensor> kv_cache_vec;
std::vector<infinicore::Tensor> conv_state_vec;
std::vector<infinicore::Tensor> ssm_state_vec;
};

void initialize_forward_context(ForwardContext &forward_context);
Expand Down
15 changes: 14 additions & 1 deletion csrc/layers/linear/fused_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,20 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
const infinicore::DataType &dtype,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: QKVParallelLinear(hidden_size, head_dim, num_q_head, num_kv_head, quantization, bias, dtype, device, rank_info) {
: QKVParallelLinear(hidden_size, head_dim, head_dim, head_dim, num_q_head, num_kv_head, num_kv_head, bias, bias, bias, q_name, k_name, v_name, register_fn, quantization, dtype, device, rank_info) {
}

QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
size_t q_dim, size_t k_dim, size_t v_dim,
size_t num_q_head, size_t num_k_head, size_t num_v_head,
bool q_bias, bool k_bias, bool v_bias,
const std::string &q_name, const std::string &k_name, const std::string &v_name,
RegisterParamFn register_fn,
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization,
const infinicore::DataType &dtype,
const infinicore::Device &device,
engine::distributed::RankInfo rank_info)
: QKVParallelLinear(hidden_size, q_dim, k_dim, v_dim, num_q_head, num_k_head, num_v_head, q_bias, k_bias, v_bias, quantization, dtype, device, rank_info) {
register_fn_ = register_fn;
split_infos_ = {
{q_name, 0, q_out_size_, 0},
Expand Down
11 changes: 11 additions & 0 deletions csrc/layers/linear/fused_linear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear {
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

QKVParallelLinear(size_t hidden_size,
size_t q_dim, size_t k_dim, size_t v_dim,
size_t num_q_head, size_t num_k_head, size_t num_v_head,
bool q_bias, bool k_bias, bool v_bias,
const std::string &q_name, const std::string &k_name, const std::string &v_name,
RegisterParamFn register_fn,
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization = nullptr,
const infinicore::DataType &dtype = infinicore::DataType::F32,
const infinicore::Device &device = infinicore::Device(),
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

QKVParallelLinear(size_t hidden_size,
size_t head_dim,
size_t num_q_head, size_t num_kv_head,
Expand Down
4 changes: 4 additions & 0 deletions csrc/models/infinilm_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class InfinilmModel : public infinicore::nn::Module {
std::optional<infinicore::Tensor> block_tables;
/// Slot ids for each token `[seq]`. Used for paged cache.
std::optional<infinicore::Tensor> slot_mapping;
/// Mamba state cache indices read at the start of each request forward, of shape `[num_requests]`.
std::optional<infinicore::Tensor> mamba_init_state_indices;
/// Mamba state cache indices written with the final state of each request forward, of shape `[num_requests]`.
std::optional<infinicore::Tensor> mamba_final_state_indices;
/// Image pixel values for multi-modal models.
/// Vector of tensors. Shape is model-specific (e.g. LLaVA: [batch, 3, H, W], MiniCPM-V: [n_patch, 3, filter_H, H * W / filter_H]).
std::optional<std::vector<infinicore::Tensor>> pixel_values;
Expand Down
140 changes: 140 additions & 0 deletions csrc/models/qwen3_5/qwen3_5_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#include "qwen3_5_attention.hpp"
#include "../../global_state/global_state.hpp"
#include "../../layers/attention/attention.hpp"
#include "../../utils.hpp"
#include <infinicore/ops/mul.hpp>
#include <infinicore/ops/sigmoid.hpp>

namespace infinilm::models::qwen3_5 {

Qwen35Attention::Qwen35Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
size_t layer_idx,
const infinicore::Device &device) {
layer_idx_ = layer_idx;
hidden_size_ = model_config->get<size_t>("hidden_size");
head_dim_ = model_config->get<size_t>("head_dim");
rotary_dim_ = model_config->get_rotary_dim();

const auto &dtype{model_config->get_dtype()};
size_t total_num_heads = model_config->get<size_t>("num_attention_heads");
size_t total_num_kv_heads = model_config->get<size_t>("num_key_value_heads");
bool use_bias = model_config->get_or<bool>("attention_bias", true);
bool use_output_bias = model_config->get_or<bool>("attention_output_bias", false);
double rms_norm_eps = model_config->get<double>("rms_norm_eps");

attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;
const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();
if ((total_num_kv_heads < tp_size) || (0 != (total_num_kv_heads % tp_size))) {
throw std::runtime_error("infinilm::models::qwen3_5::Qwen35Attention: num_key_value_heads must be divisible by tp_size");
}

num_attention_heads_ = total_num_heads / tp_size;
num_key_value_heads_ = total_num_kv_heads / tp_size;

auto quantization_method = model_config->get_quantization_method();
auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); };
qkv_proj_ = std::make_shared<Qwen35FusedQKVLinear>(
hidden_size_, head_dim_, total_num_heads, total_num_kv_heads,
"q_proj", "k_proj", "v_proj", register_fn,
quantization_method, use_bias, dtype, device, rank_info);
o_proj_ = this->register_module<layers::linear::RowParallelLinear>(
"o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method,
use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm);

rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device);

float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
attn_ = std::make_shared<infinilm::layers::attention::AttentionLayer>(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_,
kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);

INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device);
INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device);

infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_);
}

infinicore::Tensor Qwen35Attention::forward(const infinicore::Tensor &positions,
const infinicore::Tensor &hidden_states) const {
if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) {
return forward_static_(positions, hidden_states);
}
return forward_paged_(positions, hidden_states);
}

infinicore::Tensor Qwen35Attention::forward_static_(const infinicore::Tensor &position_ids,
const infinicore::Tensor &hidden_states) const {
auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape();
size_t batch_size = shape[0];
size_t seq_len = shape[1];

auto [q, gate, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_}));
k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_}));

auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});

auto pos_shape = position_ids->shape();
infinicore::Tensor pos_ids_for_rope = position_ids;
if (pos_shape.size() == 2) {
auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
} else if (pos_shape.size() == 1) {
pos_ids_for_rope = position_ids;
} else {
throw std::runtime_error("infinilm::models::qwen3_5::Qwen35Attention: Unexpected position_ids shape");
}

auto q_rotary = q_reshaped->narrow({{3, 0, rotary_dim_}});
auto k_rotary = k_reshaped->narrow({{3, 0, rotary_dim_}});
rotary_emb_->forward(q_rotary, pos_ids_for_rope, true);
rotary_emb_->forward(k_rotary, pos_ids_for_rope, true);

auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped);
attn_output = infinicore::op::mul(attn_output, infinicore::op::sigmoid(gate));
return o_proj_->forward(attn_output);
}

infinicore::Tensor Qwen35Attention::forward_paged_(const infinicore::Tensor &position_ids,
const infinicore::Tensor &hidden_states) const {
auto hidden_states_mutable = hidden_states;
auto shape = hidden_states->shape();
size_t batch_size = shape[0];
size_t seq_len = shape[1];

ASSERT_EQ(batch_size, 1);

auto [q, gate, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_});
auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_});
auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_});
q_reshaped = q_norm_->forward(q_reshaped);
k_reshaped = k_norm_->forward(k_reshaped);

auto pos_shape = position_ids->shape();
infinicore::Tensor pos_ids_for_rope = position_ids;
if (pos_shape.size() == 2) {
auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
pos_ids_for_rope = pos_narrowed->view({pos_shape[1]});
} else if (pos_shape.size() == 1) {
pos_ids_for_rope = position_ids;
} else {
throw std::runtime_error("Unexpected position_ids shape");
}

auto q_rotary = q_reshaped->narrow({{2, 0, rotary_dim_}});
auto k_rotary = k_reshaped->narrow({{2, 0, rotary_dim_}});
rotary_emb_->forward(q_rotary, pos_ids_for_rope, true);
rotary_emb_->forward(k_rotary, pos_ids_for_rope, true);

auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped);
attn_output = infinicore::op::mul(attn_output, infinicore::op::sigmoid(gate)->view(attn_output->shape()));
return o_proj_->forward(attn_output);
}
} // namespace infinilm::models::qwen3_5
Loading