From e0848a56d4810b5878c3d7eeded3dede85d5de59 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 26 Jun 2026 09:04:42 +0000 Subject: [PATCH 1/4] issue/446 support qwen3.5 llm --- csrc/cache/mamba_cache.cpp | 59 ++++++ csrc/cache/mamba_cache.hpp | 39 ++++ csrc/engine/infer_engine.cpp | 7 + csrc/engine/rank_worker.cpp | 5 + csrc/engine/rank_worker.hpp | 6 + csrc/global_state/forward_context.hpp | 12 ++ csrc/models/infinilm_model.hpp | 4 + csrc/models/qwen3_5/qwen3_5_attention.cpp | 140 ++++++++++++++ csrc/models/qwen3_5/qwen3_5_attention.hpp | 48 +++++ csrc/models/qwen3_5/qwen3_5_decoderLayer.cpp | 66 +++++++ csrc/models/qwen3_5/qwen3_5_decoderLayer.hpp | 37 ++++ csrc/models/qwen3_5/qwen3_5_for_causal_lm.cpp | 95 ++++++++++ csrc/models/qwen3_5/qwen3_5_for_causal_lm.hpp | 26 +++ .../qwen3_5/qwen3_5_fused_qkv_linear.cpp | 73 ++++++++ .../qwen3_5/qwen3_5_fused_qkv_linear.hpp | 50 +++++ csrc/models/qwen3_5/qwen3_5_model.cpp | 44 +++++ csrc/models/qwen3_5/qwen3_5_model.hpp | 29 +++ csrc/models/qwen3_5/qwen3_5_vision.cpp | 133 ++++++++++++++ csrc/models/qwen3_5/qwen3_5_vision.hpp | 120 ++++++++++++ .../qwen3_next_allocate_kv_cache_tensors.cpp | 129 +++++++++---- .../qwen3_next_allocate_kv_cache_tensors.hpp | 26 +++ .../qwen3_next/qwen3_next_decoderLayer.cpp | 4 +- .../qwen3_next/qwen3_next_for_causal_lm.cpp | 15 +- .../qwen3_next/qwen3_next_for_causal_lm.hpp | 5 - .../qwen3_next/qwen3_next_gated_deltanet.cpp | 138 +++++++++++--- .../qwen3_next/qwen3_next_gated_deltanet.hpp | 37 ++-- csrc/pybind11/engine/engine.hpp | 8 + examples/test_infer.py | 6 + python/infinilm/infer_engine.py | 18 ++ python/infinilm/llm/llm.py | 144 ++------------- python/infinilm/modeling_utils.py | 19 ++ python/infinilm/processors/__init__.py | 11 +- .../processors/basic_llm_processor.py | 1 + .../infinilm/processors/qwen3_5_processor.py | 171 ++++++++++++++++++ 34 files changed, 1489 insertions(+), 236 deletions(-) create mode 100644 csrc/cache/mamba_cache.cpp create mode 100644 csrc/cache/mamba_cache.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_attention.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_attention.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_decoderLayer.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_decoderLayer.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_for_causal_lm.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_for_causal_lm.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_model.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_model.hpp create mode 100644 csrc/models/qwen3_5/qwen3_5_vision.cpp create mode 100644 csrc/models/qwen3_5/qwen3_5_vision.hpp create mode 100644 csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.hpp create mode 100644 python/infinilm/processors/qwen3_5_processor.py diff --git a/csrc/cache/mamba_cache.cpp b/csrc/cache/mamba_cache.cpp new file mode 100644 index 000000000..0f65f1879 --- /dev/null +++ b/csrc/cache/mamba_cache.cpp @@ -0,0 +1,59 @@ +#include "mamba_cache.hpp" + +#include "../global_state/global_state.hpp" +#include "infinicore/context/context.hpp" + +namespace infinilm::cache { + +infinicore::Tensor MambaCache::create_layer_conv_state( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::Size conv_kernel_dim, + infinicore::DataType dtype, + size_t pool_size) { + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + const auto [num_rank_k_heads, num_rank_v_heads] = get_rank_head_counts(num_k_heads, num_v_heads, rank_info.tp_size); + const size_t conv_state_len = conv_kernel_dim > 0 ? conv_kernel_dim - 1 : 0; + const size_t conv_dim = 2 * num_rank_k_heads * k_dim + num_rank_v_heads * v_dim; + + auto conv_state = infinicore::Tensor::zeros( + {pool_size, conv_dim, conv_state_len}, + dtype, + rank_info.device); + infinicore::context::syncStream(); + return conv_state; +} + +infinicore::Tensor MambaCache::create_layer_ssm_state( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::DataType dtype, + size_t pool_size) { + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + const auto rank_head_counts = get_rank_head_counts(num_k_heads, num_v_heads, rank_info.tp_size); + const size_t num_rank_v_heads = rank_head_counts.second; + + auto ssm_state = infinicore::Tensor::zeros( + {pool_size, num_rank_v_heads, v_dim, k_dim}, + dtype, + rank_info.device); + + infinicore::context::syncStream(); + return ssm_state; +} + +std::pair MambaCache::get_rank_head_counts( + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + size_t tp_size) { + bool is_kv_replica = (num_k_heads < tp_size && num_v_heads < tp_size && num_k_heads == num_v_heads && tp_size % num_k_heads == 0); + size_t num_rank_k_heads = is_kv_replica ? 1 : (num_k_heads / tp_size); + size_t num_rank_v_heads = is_kv_replica ? 1 : (num_v_heads / tp_size); + return {num_rank_k_heads, num_rank_v_heads}; +} + +} // namespace infinilm::cache diff --git a/csrc/cache/mamba_cache.hpp b/csrc/cache/mamba_cache.hpp new file mode 100644 index 000000000..f8d25415c --- /dev/null +++ b/csrc/cache/mamba_cache.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "base_cache.hpp" + +#include "infinicore/device.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinilm::cache { + +class MambaCache { +public: + static infinicore::Tensor create_layer_conv_state( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::Size conv_kernel_dim, + infinicore::DataType dtype, + size_t pool_size); + + static infinicore::Tensor create_layer_ssm_state( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::DataType dtype, + size_t pool_size); + +private: + static std::pair get_rank_head_counts( + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + size_t tp_size); +}; + +} // namespace infinilm::cache diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 7c8429236..f77fe0250 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -153,6 +153,8 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { to_device(cu_seqlens), to_device(block_tables), to_device(slot_mapping), + to_device(mamba_init_state_indices), + to_device(mamba_final_state_indices), to_device_vec(pixel_values), to_device_vec(image_bound), to_device_vec(tgt_sizes), @@ -166,6 +168,11 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { input.block_tables, input.slot_mapping}; + infinilm::global_state::get_forward_context().mamba_metadata = { + input.input_offsets, + input.mamba_init_state_indices, + input.mamba_final_state_indices}; + global_state::get_forward_context().mm_metadata = { image_req_ids}; diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index 52bc237d7..b295d754f 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -35,6 +35,10 @@ RankWorker::RankWorker( cv_.wait(lk, [&] { return init_done_; }); } +RankWorker::~RankWorker() { + close(); +} + std::string RankWorker::info() const { std::stringstream ss; @@ -513,6 +517,7 @@ void RankWorker::thread_loop() { // Top-level exception: ensure any waiters are woken and the thread exits cleanly. { std::lock_guard lk(mutex_); + init_done_ = true; should_exit_ = true; job_done_ = true; } diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index a26a3cd69..56ecfef5b 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -52,6 +52,10 @@ class RankWorker { std::optional block_tables; /// Slot ids for each token `[seq]`. Used for paged cache. std::optional slot_mapping; + /// Mamba state cache indices read at the start of each request forward. + std::optional mamba_init_state_indices; + /// Mamba state cache indices written with the final state of each request forward. + std::optional mamba_final_state_indices; /// Image pixel values for multi-modal models. std::optional> pixel_values; /// Image placeholder bounds for MiniCPM-V style replacement. @@ -81,6 +85,8 @@ class RankWorker { bool enable_graph_compiling, backends::AttentionBackend attention_backend); + ~RankWorker(); + // Submit a parameter load job and wait until the load completes on the worker thread. void load_param(const std::string &name, const infinicore::Tensor ¶m); diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp index 2568fc7ee..18a841394 100644 --- a/csrc/global_state/forward_context.hpp +++ b/csrc/global_state/forward_context.hpp @@ -44,10 +44,22 @@ struct MultiModalMetadata { std::optional> image_req_ids; }; +struct MambaMetadata { + /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. + std::optional input_offsets; + /// State cache indices read at the start of each request forward. + std::optional init_state_indices; + /// State cache indices written with the final state of each request forward. + std::optional final_state_indices; +}; + struct ForwardContext { AttentionMetadata attn_metadata; + MambaMetadata mamba_metadata; MultiModalMetadata mm_metadata; std::vector kv_cache_vec; + std::vector conv_state_vec; + std::vector ssm_state_vec; }; void initialize_forward_context(ForwardContext &forward_context); diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index c76a29f08..1e9f1d3c9 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -34,6 +34,10 @@ class InfinilmModel : public infinicore::nn::Module { std::optional block_tables; /// Slot ids for each token `[seq]`. Used for paged cache. std::optional slot_mapping; + /// Mamba state cache indices read at the start of each request forward, of shape `[num_requests]`. + std::optional mamba_init_state_indices; + /// Mamba state cache indices written with the final state of each request forward, of shape `[num_requests]`. + std::optional mamba_final_state_indices; /// Image pixel values for multi-modal models. /// Vector of tensors. Shape is model-specific (e.g. LLaVA: [batch, 3, H, W], MiniCPM-V: [n_patch, 3, filter_H, H * W / filter_H]). std::optional> pixel_values; diff --git a/csrc/models/qwen3_5/qwen3_5_attention.cpp b/csrc/models/qwen3_5/qwen3_5_attention.cpp new file mode 100644 index 000000000..de9093f95 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_attention.cpp @@ -0,0 +1,140 @@ +#include "qwen3_5_attention.hpp" +#include "../../global_state/global_state.hpp" +#include "../../layers/attention/attention.hpp" +#include "../../utils.hpp" +#include +#include + +namespace infinilm::models::qwen3_5 { + +Qwen35Attention::Qwen35Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + hidden_size_ = model_config->get("hidden_size"); + head_dim_ = model_config->get("head_dim"); + rotary_dim_ = model_config->get_rotary_dim(); + + const auto &dtype{model_config->get_dtype()}; + size_t total_num_heads = model_config->get("num_attention_heads"); + size_t total_num_kv_heads = model_config->get("num_key_value_heads"); + bool use_bias = model_config->get_or("attention_bias", true); + bool use_output_bias = model_config->get_or("attention_output_bias", false); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size(); + if ((total_num_kv_heads < tp_size) || (0 != (total_num_kv_heads % tp_size))) { + throw std::runtime_error("infinilm::models::qwen3_5::Qwen35Attention: num_key_value_heads must be divisible by tp_size"); + } + + num_attention_heads_ = total_num_heads / tp_size; + num_key_value_heads_ = total_num_kv_heads / tp_size; + + auto quantization_method = model_config->get_quantization_method(); + auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }; + qkv_proj_ = std::make_shared( + hidden_size_, head_dim_, total_num_heads, total_num_kv_heads, + "q_proj", "k_proj", "v_proj", register_fn, + quantization_method, use_bias, dtype, device, rank_info); + o_proj_ = this->register_module( + "o_proj", total_num_heads * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + + rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device); + + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + attn_ = std::make_shared(num_attention_heads_, head_dim_, scaling, num_key_value_heads_, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + + infinilm::layers::attention::init_kv_cache_quant_params(register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); +} + +infinicore::Tensor Qwen35Attention::forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const { + if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) { + return forward_static_(positions, hidden_states); + } + return forward_paged_(positions, hidden_states); +} + +infinicore::Tensor Qwen35Attention::forward_static_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + auto [q, gate, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); + k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); + + auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids; + } else { + throw std::runtime_error("infinilm::models::qwen3_5::Qwen35Attention: Unexpected position_ids shape"); + } + + auto q_rotary = q_reshaped->narrow({{3, 0, rotary_dim_}}); + auto k_rotary = k_reshaped->narrow({{3, 0, rotary_dim_}}); + rotary_emb_->forward(q_rotary, pos_ids_for_rope, true); + rotary_emb_->forward(k_rotary, pos_ids_for_rope, true); + + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); + attn_output = infinicore::op::mul(attn_output, infinicore::op::sigmoid(gate)); + return o_proj_->forward(attn_output); +} + +infinicore::Tensor Qwen35Attention::forward_paged_(const infinicore::Tensor &position_ids, + const infinicore::Tensor &hidden_states) const { + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + ASSERT_EQ(batch_size, 1); + + auto [q, gate, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); + q_reshaped = q_norm_->forward(q_reshaped); + k_reshaped = k_norm_->forward(k_reshaped); + + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids; + } else { + throw std::runtime_error("Unexpected position_ids shape"); + } + + auto q_rotary = q_reshaped->narrow({{2, 0, rotary_dim_}}); + auto k_rotary = k_reshaped->narrow({{2, 0, rotary_dim_}}); + rotary_emb_->forward(q_rotary, pos_ids_for_rope, true); + rotary_emb_->forward(k_rotary, pos_ids_for_rope, true); + + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); + attn_output = infinicore::op::mul(attn_output, infinicore::op::sigmoid(gate)->view(attn_output->shape())); + return o_proj_->forward(attn_output); +} +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_attention.hpp b/csrc/models/qwen3_5/qwen3_5_attention.hpp new file mode 100644 index 000000000..68d0586f0 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_attention.hpp @@ -0,0 +1,48 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include "qwen3_5_fused_qkv_linear.hpp" + +namespace infinilm::models::qwen3_5 { +class Qwen35Attention : public infinicore::nn::Module { +public: + Qwen35Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + + size_t layer_idx() const { return layer_idx_; } + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + +private: + infinicore::Tensor forward_static_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + + infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + +protected: + std::shared_ptr qkv_proj_; + std::shared_ptr o_proj_; + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + std::shared_ptr rotary_emb_; + + std::shared_ptr attn_; + ::infinilm::backends::AttentionBackend attention_backend_; + size_t layer_idx_; + size_t num_attention_heads_; + size_t num_key_value_heads_; + size_t hidden_size_; + size_t head_dim_; + size_t rotary_dim_; + + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); +}; +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_decoderLayer.cpp b/csrc/models/qwen3_5/qwen3_5_decoderLayer.cpp new file mode 100644 index 000000000..70964bb69 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_decoderLayer.cpp @@ -0,0 +1,66 @@ +#include "qwen3_5_decoderLayer.hpp" +#include "infinicore/ops.hpp" +#include +#include +#include + +namespace infinilm::models::qwen3_5 { + +Qwen35DecoderLayer::Qwen35DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx) { + + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, model_config, device); + + const std::vector layer_types = model_config->get>("layer_types"); + layer_type_ = layer_types[layer_idx]; + if ("linear_attention" == layer_type_) { + INFINICORE_NN_MODULE_INIT(linear_attn, model_config, layer_idx, device); + } else if ("full_attention" == layer_type_) { + INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device); + } else { + throw std::runtime_error("infinilm::models::qwen3_5::Qwen35DecoderLayer: unsupported layer_type '" + layer_type_ + "' for layer " + std::to_string(layer_idx)); + } +} + +std::tuple Qwen35DecoderLayer::forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) { + input_layernorm_->forward_inplace(hidden_states, residual); + if ("linear_attention" == layer_type_) { + hidden_states = linear_attn_->forward(hidden_states); + } else if ("full_attention" == layer_type_) { + hidden_states = self_attn_->forward(positions, hidden_states); + } + + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = mlp_->forward(hidden_states); + return std::make_tuple(hidden_states, residual); +} + +infinicore::Tensor Qwen35DecoderLayer::forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states) { + auto residual = hidden_states; + hidden_states = input_layernorm_->forward(hidden_states); + if ("linear_attention" == layer_type_) { + hidden_states = linear_attn_->forward(hidden_states); + } else if ("full_attention" == layer_type_) { + hidden_states = self_attn_->forward(positions, hidden_states); + } + hidden_states = infinicore::op::add(residual, hidden_states); + + residual = hidden_states; + hidden_states = post_attention_layernorm_->forward(hidden_states); + hidden_states = mlp_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + return hidden_states; +} + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_decoderLayer.hpp b/csrc/models/qwen3_5/qwen3_5_decoderLayer.hpp new file mode 100644 index 000000000..751586bbe --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_decoderLayer.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "../qwen3_next/qwen3_next_gated_deltanet.hpp" +#include "qwen3_5_attention.hpp" +#include +#include + +namespace infinilm::models::qwen3_5 { + +class Qwen35DecoderLayer : public infinicore::nn::Module { +public: + Qwen35DecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + std::tuple forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states, + infinicore::Tensor &residual); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + infinicore::Tensor &hidden_states); + + size_t layer_idx() const { return layer_idx_; } + +protected: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(Qwen35Attention, self_attn); + INFINICORE_NN_MODULE(qwen3_next::Qwen3NextGatedDeltaNet, linear_attn); + INFINICORE_NN_MODULE(infinilm::layers::MLP, mlp); + +private: + size_t layer_idx_; + std::string layer_type_; +}; + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_for_causal_lm.cpp b/csrc/models/qwen3_5/qwen3_5_for_causal_lm.cpp new file mode 100644 index 000000000..6752082fe --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_for_causal_lm.cpp @@ -0,0 +1,95 @@ +#include "qwen3_5_for_causal_lm.hpp" + +#include "../qwen3_next/qwen3_next_for_causal_lm.hpp" +#include "../../global_state/global_state.hpp" +#include "../models_registry.hpp" +#include +#include +#include + +namespace infinilm::models::qwen3_5 { + +Qwen35ForCausalLM::Qwen35ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + size_t hidden_size = model_config->get("hidden_size"); + size_t vocab_size = model_config->get("vocab_size"); + const auto &dtype{model_config->get_dtype()}; + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +infinilm::InfinilmModel::Output Qwen35ForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +void Qwen35ForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { + if (cache_config == nullptr) { + cache_config_.reset(); + } else { + cache_config_ = cache_config->unique_copy(); + } + model_->reset_cache(cache_config); +} + +std::shared_ptr create_qwen3_5_model_config(std::shared_ptr model_config) { + const std::string model_type = model_config->get("model_type"); + if ("qwen3_5" != model_type) { + throw std::runtime_error("infinilm::models::qwen3_5::create_qwen3_next_model_config: model_type is not qwen3_5"); + } + + nlohmann::json &config_json = model_config->get_config_json(); + if (config_json.contains("text_config") && config_json["text_config"].is_object()) { + const nlohmann::json &text_config_json = config_json["text_config"]; + for (auto it = text_config_json.begin(); it != text_config_json.end(); ++it) { + if (!config_json.contains(it.key())) { + config_json[it.key()] = it.value(); + } + } + if (!config_json.contains("dtype") && config_json.contains("torch_dtype")) { + config_json["dtype"] = config_json["torch_dtype"]; + } + } + if (!config_json.contains("rope_theta") && + config_json.contains("rope_parameters") && + config_json["rope_parameters"].is_object() && + config_json["rope_parameters"].contains("rope_theta")) { + // TODO: This is only a temporary loader shim. Qwen3.6 uses mRoPE, + // which needs proper support in InfiniCore instead of treating it as + // plain RoPE through a top-level rope_theta. + config_json["rope_theta"] = config_json["rope_parameters"]["rope_theta"]; + } + if (!config_json.contains("partial_rotary_factor") && + config_json.contains("rope_parameters") && + config_json["rope_parameters"].is_object() && + config_json["rope_parameters"].contains("partial_rotary_factor")) { + config_json["partial_rotary_factor"] = config_json["rope_parameters"]["partial_rotary_factor"]; + } + if (!config_json.contains("layer_types")) { + size_t full_attention_interval = model_config->get("full_attention_interval"); + size_t num_hidden_layers = model_config->get("num_hidden_layers"); + std::vector layer_types; + layer_types.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; i++) { + layer_types.push_back(bool((i + 1) % full_attention_interval) ? "linear_attention" : "full_attention"); + } + config_json["layer_types"] = layer_types; + } + + if (!config_json.contains("attention_bias")) { + config_json["attention_bias"] = false; + } + return model_config; +} + +} // namespace infinilm::models::qwen3_5 + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen3_5, + infinilm::models::qwen3_5::Qwen35ForCausalLM, + infinilm::models::qwen3_5::create_qwen3_5_model_config); +} // namespace diff --git a/csrc/models/qwen3_5/qwen3_5_for_causal_lm.hpp b/csrc/models/qwen3_5/qwen3_5_for_causal_lm.hpp new file mode 100644 index 000000000..e2baddae8 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_for_causal_lm.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "qwen3_5_model.hpp" +#include +#include + +namespace infinilm::models::qwen3_5 { + + +class Qwen35ForCausalLM : public InfinilmModel { +public: + Qwen35ForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + + void reset_cache(const cache::CacheConfig *cache_config) override; + +protected: + INFINICORE_NN_MODULE(Qwen35Model, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_qwen3_5_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.cpp b/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.cpp new file mode 100644 index 000000000..aca1b5937 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.cpp @@ -0,0 +1,73 @@ +#include "qwen3_5_fused_qkv_linear.hpp" + +namespace infinilm::models::qwen3_5 { + +Qwen35FusedQKVLinear::Qwen35FusedQKVLinear(size_t hidden_size, + size_t head_dim, + size_t num_q_head, + size_t num_kv_head, + const std::string &q_name, + const std::string &k_name, + const std::string &v_name, + infinilm::layers::linear::RegisterParamFn register_fn, + std::shared_ptr quantization, + bool bias, + const infinicore::DataType &dtype, + const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : infinilm::nn::ColumnParallelLinear( + hidden_size, + num_q_head * head_dim * 2 + num_kv_head * head_dim * calculate_kv_replicas(num_kv_head, rank_info.tp_size) * 2, + quantization == nullptr ? std::make_shared() : quantization, + bias, + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size), + head_dim_(head_dim), + local_num_q_heads_(num_q_head / tp_size_), + q_proj_out_size_(num_q_head * head_dim * 2 / tp_size_), + q_out_size_(num_q_head * head_dim / tp_size_), + k_out_size_(calculate_kv_replicas(num_kv_head, rank_info.tp_size) * num_kv_head * head_dim / tp_size_), + v_out_size_(calculate_kv_replicas(num_kv_head, rank_info.tp_size) * num_kv_head * head_dim / tp_size_), + num_kv_head_(num_kv_head), + register_fn_(register_fn) { + split_infos_ = { + {q_name, 0, q_proj_out_size_, 0}, + {k_name, q_proj_out_size_, k_out_size_, num_kv_head_}, + {v_name, q_proj_out_size_ + k_out_size_, v_out_size_, num_kv_head_}, + }; + auto params = this->split_params(split_infos_, tp_rank_, tp_size_, num_kv_head_); + for (auto &sp : params) { + register_fn_(sp.full_name, std::move(sp.param)); + } +} + +std::tuple +Qwen35FusedQKVLinear::forward_split(infinicore::Tensor &input) { + auto output = this->forward(input); + auto shape = output->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + auto q_proj_out = output->narrow({{2, 0, q_proj_out_size_}}); + auto q_proj_heads = q_proj_out->view({batch_size, seq_len, local_num_q_heads_, head_dim_ * 2}); + auto q_out = q_proj_heads->narrow({{3, 0, head_dim_}}); + auto gate_out = q_proj_heads->narrow({{3, head_dim_, head_dim_}}); + auto k_out = output->narrow({{2, q_proj_out_size_, k_out_size_}}); + auto v_out = output->narrow({{2, q_proj_out_size_ + k_out_size_, v_out_size_}}); + + return std::make_tuple(q_out, gate_out, k_out, v_out); +} + +void Qwen35FusedQKVLinear::process_weights_after_loading() { + BaseLinear::process_weights_after_loading(); + if (register_fn_ && !split_infos_.empty()) { + auto params = this->split_params(split_infos_, tp_rank_, tp_size_, num_kv_head_); + for (auto &sp : params) { + register_fn_(sp.full_name, std::move(sp.param)); + } + } +} + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.hpp b/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.hpp new file mode 100644 index 000000000..c155030a2 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_fused_qkv_linear.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include "../../layers/linear/linear.hpp" + +namespace infinilm::models::qwen3_5 { + +class Qwen35FusedQKVLinear : public infinilm::nn::ColumnParallelLinear { +public: + Qwen35FusedQKVLinear(size_t hidden_size, + size_t head_dim, + size_t num_q_head, + size_t num_kv_head, + const std::string &q_name, + const std::string &k_name, + const std::string &v_name, + infinilm::layers::linear::RegisterParamFn register_fn, + std::shared_ptr quantization = nullptr, + bool bias = false, + const infinicore::DataType &dtype = infinicore::DataType::F32, + const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + + void process_weights_after_loading() override; + + std::tuple + forward_split(infinicore::Tensor &input); + +private: + static size_t calculate_kv_replicas(size_t num_kv_head, size_t tp_size) { + if (num_kv_head % tp_size == 0) { + return 1; + } + if (tp_size % num_kv_head == 0) { + return (tp_size + num_kv_head - 1) / num_kv_head; + } + throw std::runtime_error("Invalid KV head configuration"); + } + + size_t head_dim_; + size_t local_num_q_heads_; + size_t q_proj_out_size_; + size_t q_out_size_; + size_t k_out_size_; + size_t v_out_size_; + size_t num_kv_head_; + infinilm::layers::linear::RegisterParamFn register_fn_; + std::vector split_infos_; +}; + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_model.cpp b/csrc/models/qwen3_5/qwen3_5_model.cpp new file mode 100644 index 000000000..cf4ab0044 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_model.cpp @@ -0,0 +1,44 @@ +#include "qwen3_5_model.hpp" + +#include "../../global_state/global_state.hpp" +#include "../qwen3_next/qwen3_next_allocate_kv_cache_tensors.hpp" + +#include + +namespace infinilm::models::qwen3_5 { + +Qwen35Model::Qwen35Model(std::shared_ptr model_config, + const infinicore::Device &device) + : model_config_(model_config) { + const auto &dtype{model_config->get_dtype()}; + nlohmann::json &config_json = model_config->get_config_json(); + + if (config_json.contains("vision_config") && !config_json["vision_config"].is_null()) { + INFINICORE_NN_MODULE_INIT(visual, config_json["vision_config"], dtype, device); + } + INFINICORE_NN_MODULE_INIT(language_model, model_config, device); +} + +infinicore::Tensor Qwen35Model::forward(const InfinilmModel::Input &input) const { + return language_model_->forward(input); +} + +void Qwen35Model::reset_cache(const cache::CacheConfig *cache_config) { + if (nullptr == cache_config) { + return; + } + + auto &forward_context = infinilm::global_state::get_forward_context(); + forward_context.kv_cache_vec.clear(); + forward_context.conv_state_vec.clear(); + forward_context.ssm_state_vec.clear(); + + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; + + auto cache_vectors = infinilm::models::qwen3_next::qwen3_next_allocate_cache_tensors(cache_config, model_config_, attention_backend); + forward_context.kv_cache_vec = std::move(cache_vectors.kv_cache_tensors); + forward_context.conv_state_vec = std::move(cache_vectors.conv_state_tensors); + forward_context.ssm_state_vec = std::move(cache_vectors.ssm_state_tensors); +} + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_model.hpp b/csrc/models/qwen3_5/qwen3_5_model.hpp new file mode 100644 index 000000000..01b713807 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_model.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include "../infinilm_model.hpp" +#include "qwen3_5_decoderLayer.hpp" +#include "qwen3_5_vision.hpp" + +namespace infinilm::models::qwen3_5 { + +using Qwen35LanguageModel = infinilm::layers::causal_lm_templates::TextModel; + +class Qwen35Model : public infinicore::nn::Module { +public: + Qwen35Model(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const InfinilmModel::Input &input) const; + + void reset_cache(const cache::CacheConfig *cache_config); + +protected: + INFINICORE_NN_MODULE(Qwen35VisionModel, visual); + INFINICORE_NN_MODULE(Qwen35LanguageModel, language_model); + +private: + std::shared_ptr model_config_; +}; + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_vision.cpp b/csrc/models/qwen3_5/qwen3_5_vision.cpp new file mode 100644 index 000000000..759a14817 --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_vision.cpp @@ -0,0 +1,133 @@ +#include "qwen3_5_vision.hpp" + +#include +#include +#include + +namespace infinilm::models::qwen3_5 { +namespace { + +size_t get_size_or_first(const nlohmann::json &config, const char *key, size_t default_value) { + if (!config.contains(key) || config.at(key).is_null()) { + return default_value; + } + const auto &value = config.at(key); + if (value.is_array()) { + return value.empty() ? default_value : value.at(0).get(); + } + return value.get(); +} + +} // namespace + +Qwen35VisionPatchProj::Qwen35VisionPatchProj(size_t in_channels, + size_t hidden_size, + size_t temporal_patch_size, + size_t patch_size, + const infinicore::DataType &dtype, + const infinicore::Device &device) + : in_channels_(in_channels), + hidden_size_(hidden_size), + temporal_patch_size_(temporal_patch_size), + patch_size_(patch_size) { + INFINICORE_NN_PARAMETER_INIT(weight, ({hidden_size_, in_channels_, temporal_patch_size_, patch_size_, patch_size_}, dtype, device)); + INFINICORE_NN_PARAMETER_INIT(bias, ({hidden_size_}, dtype, device)); +} + +infinicore::Tensor Qwen35VisionPatchProj::forward(const infinicore::Tensor &hidden_states) const { + throw std::runtime_error("Qwen35VisionPatchProj::forward is not implemented yet"); +} + +Qwen35VisionPatchEmbed::Qwen35VisionPatchEmbed(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + const size_t in_channels = config.value("in_channels", 3); + const size_t hidden_size = config.value("hidden_size", 1152); + const size_t temporal_patch_size = get_size_or_first(config, "temporal_patch_size", 2); + const size_t patch_size = get_size_or_first(config, "patch_size", 16); + INFINICORE_NN_MODULE_INIT(proj, in_channels, hidden_size, temporal_patch_size, patch_size, dtype, device); +} + +infinicore::Tensor Qwen35VisionPatchEmbed::forward(const infinicore::Tensor &hidden_states) const { + return proj_->forward(hidden_states); +} + +Qwen35VisionAttention::Qwen35VisionAttention(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) + : hidden_size_(config.value("hidden_size", 1152)), + num_heads_(config.value("num_heads", 16)) { + INFINICORE_NN_MODULE_INIT(qkv, hidden_size_, hidden_size_ * 3, true, dtype, device); + INFINICORE_NN_MODULE_INIT(proj, hidden_size_, hidden_size_, true, dtype, device); +} + +infinicore::Tensor Qwen35VisionAttention::forward(const infinicore::Tensor &hidden_states) const { + throw std::runtime_error("Qwen35VisionAttention::forward is not implemented yet"); +} + +Qwen35VisionMLP::Qwen35VisionMLP(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + const size_t hidden_size = config.value("hidden_size", 1152); + const size_t intermediate_size = config.value("intermediate_size", 4304); + INFINICORE_NN_MODULE_INIT(linear_fc1, hidden_size, intermediate_size, true, dtype, device); + INFINICORE_NN_MODULE_INIT(linear_fc2, intermediate_size, hidden_size, true, dtype, device); +} + +infinicore::Tensor Qwen35VisionMLP::forward(const infinicore::Tensor &hidden_states) const { + throw std::runtime_error("Qwen35VisionMLP::forward is not implemented yet"); +} + +Qwen35VisionBlock::Qwen35VisionBlock(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + const size_t hidden_size = config.value("hidden_size", 1152); + const double norm_eps = config.value("layer_norm_eps", config.value("rms_norm_eps", 1e-6)); + INFINICORE_NN_MODULE_INIT(norm1, hidden_size, norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(norm2, hidden_size, norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(attn, config, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, config, dtype, device); +} + +infinicore::Tensor Qwen35VisionBlock::forward(const infinicore::Tensor &hidden_states) const { + throw std::runtime_error("Qwen35VisionBlock::forward is not implemented yet"); +} + +Qwen35VisionPatchMerger::Qwen35VisionPatchMerger(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + const size_t hidden_size = config.value("hidden_size", 1152); + const size_t out_hidden_size = config.value("out_hidden_size", hidden_size); + const size_t spatial_merge_size = config.value("spatial_merge_size", 2); + const size_t merged_size = hidden_size * spatial_merge_size * spatial_merge_size; + const double norm_eps = config.value("layer_norm_eps", config.value("rms_norm_eps", 1e-6)); + INFINICORE_NN_MODULE_INIT(norm, hidden_size, norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(linear_fc1, merged_size, merged_size, true, dtype, device); + INFINICORE_NN_MODULE_INIT(linear_fc2, merged_size, out_hidden_size, true, dtype, device); +} + +infinicore::Tensor Qwen35VisionPatchMerger::forward(const infinicore::Tensor &hidden_states) const { + throw std::runtime_error("Qwen35VisionPatchMerger::forward is not implemented yet"); +} + +Qwen35VisionModel::Qwen35VisionModel(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device) { + const size_t hidden_size = config.value("hidden_size", 1152); + const size_t num_position_embeddings = config.value("num_position_embeddings", 2304); + const size_t depth = config.value("depth", config.value("num_hidden_layers", 27)); + + INFINICORE_NN_MODULE_INIT(patch_embed, config, dtype, device); + INFINICORE_NN_MODULE_INIT(pos_embed, num_position_embeddings, hidden_size, std::nullopt, dtype, device); + blocks_.reserve(depth); + for (size_t i = 0; i < depth; ++i) { + blocks_.push_back(this->register_module("blocks." + std::to_string(i), config, dtype, device)); + } + INFINICORE_NN_MODULE_INIT(merger, config, dtype, device); +} + +infinicore::Tensor Qwen35VisionModel::forward(const infinicore::Tensor &pixel_values) const { + throw std::runtime_error("Qwen35VisionModel::forward is not implemented yet"); +} + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_5/qwen3_5_vision.hpp b/csrc/models/qwen3_5/qwen3_5_vision.hpp new file mode 100644 index 000000000..012d3833f --- /dev/null +++ b/csrc/models/qwen3_5/qwen3_5_vision.hpp @@ -0,0 +1,120 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "../../layers/common_modules.hpp" + +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/layer_norm.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm::models::qwen3_5 { + +class Qwen35VisionPatchProj : public infinicore::nn::Module { +public: + Qwen35VisionPatchProj(size_t in_channels, + size_t hidden_size, + size_t temporal_patch_size, + size_t patch_size, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + size_t in_channels_; + size_t hidden_size_; + size_t temporal_patch_size_; + size_t patch_size_; + + INFINICORE_NN_PARAMETER(weight); + INFINICORE_NN_PARAMETER(bias); +}; + +class Qwen35VisionPatchEmbed : public infinicore::nn::Module { +public: + Qwen35VisionPatchEmbed(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(Qwen35VisionPatchProj, proj); +}; + +class Qwen35VisionAttention : public infinicore::nn::Module { +public: + Qwen35VisionAttention(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + size_t hidden_size_; + size_t num_heads_; + + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, qkv); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, proj); +}; + +class Qwen35VisionMLP : public infinicore::nn::Module { +public: + Qwen35VisionMLP(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, linear_fc1); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, linear_fc2); +}; + +class Qwen35VisionBlock : public infinicore::nn::Module { +public: + Qwen35VisionBlock(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, norm1); + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, norm2); + INFINICORE_NN_MODULE(Qwen35VisionAttention, attn); + INFINICORE_NN_MODULE(Qwen35VisionMLP, mlp); +}; + +class Qwen35VisionPatchMerger : public infinicore::nn::Module { +public: + Qwen35VisionPatchMerger(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, norm); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, linear_fc1); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, linear_fc2); +}; + +class Qwen35VisionModel : public infinicore::nn::Module { +public: + Qwen35VisionModel(const nlohmann::json &config, + const infinicore::DataType &dtype, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &pixel_values) const; + +private: + INFINICORE_NN_MODULE(Qwen35VisionPatchEmbed, patch_embed); + INFINICORE_NN_MODULE(infinicore::nn::Embedding, pos_embed); + INFINICORE_NN_MODULE_VEC(Qwen35VisionBlock, blocks); + INFINICORE_NN_MODULE(Qwen35VisionPatchMerger, merger); +}; + +} // namespace infinilm::models::qwen3_5 diff --git a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp index a95abab82..b077732c8 100644 --- a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp +++ b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp @@ -1,13 +1,16 @@ -#include "../../backends/attention_backends.hpp" -#include "../../cache/kv_cache.hpp" -#include "qwen3_next_for_causal_lm.hpp" +#include "qwen3_next_allocate_kv_cache_tensors.hpp" + +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" +#include "infinicore/context/context.hpp" +#include #include -#include +#include #include namespace infinilm::models::qwen3_next { -std::vector qwen3_next_allocate_kv_cache_tensors( +AllocatedHybridCache qwen3_next_allocate_cache_tensors( const cache::CacheConfig *cache_config, const std::shared_ptr &text_config, const backends::AttentionBackend &attention_backend) { @@ -18,36 +21,92 @@ std::vector qwen3_next_allocate_kv_cache_tensors( throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: text_config is null"); } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const size_t max_position_embeddings = text_config->get("max_position_embeddings"); + + const size_t linear_conv_kernel_dim = text_config->get("linear_conv_kernel_dim"); + const size_t linear_key_head_dim = text_config->get("linear_key_head_dim"); + const size_t linear_num_key_heads = text_config->get("linear_num_key_heads"); + const size_t linear_num_value_heads = text_config->get("linear_num_value_heads"); + const size_t linear_value_head_dim = text_config->get("linear_value_head_dim"); + + const auto &dtype{text_config->get_dtype()}; + const auto &kv_cache_dtype{text_config->get_kv_cache_dtype()}; + const std::vector layer_types = text_config->get>("layer_types"); + std::vector kv_cache_vec; + std::vector conv_state_vec; + std::vector ssm_state_vec; + kv_cache_vec.reserve(num_hidden_layers); + conv_state_vec.reserve(num_hidden_layers); + ssm_state_vec.reserve(num_hidden_layers); + + auto allocate_linear_attention_cache = [&](size_t layer_idx, size_t pool_size) { + auto conv_state = cache::MambaCache::create_layer_conv_state( + linear_key_head_dim, + linear_value_head_dim, + linear_num_key_heads, + linear_num_value_heads, + linear_conv_kernel_dim, + dtype, + pool_size); + auto ssm_state = cache::MambaCache::create_layer_ssm_state( + linear_key_head_dim, + linear_value_head_dim, + linear_num_key_heads, + linear_num_value_heads, + dtype, + pool_size); + + kv_cache_vec.emplace_back(); + conv_state_vec.push_back(std::move(conv_state)); + ssm_state_vec.push_back(std::move(ssm_state)); + }; + + auto allocate_static_full_attention_cache = [&](size_t layer_idx, const cache::StaticKVCacheConfig &config) { + auto kv_cache = cache::StaticKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + max_position_embeddings, + kv_cache_dtype, + config); + + kv_cache_vec.push_back(std::move(kv_cache)); + conv_state_vec.emplace_back(); + ssm_state_vec.emplace_back(); + }; + + auto allocate_paged_full_attention_cache = [&](size_t layer_idx, const cache::PagedKVCacheConfig &config) { + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + kv_cache_dtype, + config); + + kv_cache_vec.push_back(std::move(kv_cache)); + conv_state_vec.emplace_back(); + ssm_state_vec.emplace_back(); + }; + switch (attention_backend) { case backends::AttentionBackend::STATIC_ATTN: { auto static_kv_cache_config = dynamic_cast(cache_config); if (nullptr == static_kv_cache_config) { throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid static kv cache config type"); } - const size_t num_hidden_layers = text_config->get("num_hidden_layers"); - kv_cache_vec.reserve(num_hidden_layers); - - const size_t head_dim = text_config->get("head_dim"); - const size_t num_key_value_heads = text_config->get("num_key_value_heads"); - const size_t max_position_embeddings = text_config->get("max_position_embeddings"); - const auto &dtype{text_config->get_dtype()}; - const std::vector layer_types = text_config->get>("layer_types"); for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { const std::string &layer_type = layer_types[layer_idx]; if ("linear_attention" == layer_type) { - kv_cache_vec.emplace_back(); + allocate_linear_attention_cache(layer_idx, static_kv_cache_config->max_batch_size()); } else if ("full_attention" == layer_type) { - auto kv_cache = cache::StaticKVCache::create_layer_kv_cache( - head_dim, - head_dim, - num_key_value_heads, - num_key_value_heads, - max_position_embeddings, - dtype, - *static_kv_cache_config); - kv_cache_vec.push_back(kv_cache); + allocate_static_full_attention_cache(layer_idx, *static_kv_cache_config); } else { throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: unsupported layer_type '" + layer_type + "' for layer " + std::to_string(layer_idx)); } @@ -62,27 +121,14 @@ std::vector qwen3_next_allocate_kv_cache_tensors( if (nullptr == paged_kv_cache_config) { throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid paged kv cache config type"); } - const size_t num_hidden_layers = text_config->get("num_hidden_layers"); - kv_cache_vec.reserve(num_hidden_layers); - - const size_t head_dim = text_config->get("head_dim"); - const size_t num_key_value_heads = text_config->get("num_key_value_heads"); - const auto &dtype{text_config->get_dtype()}; - const std::vector layer_types = text_config->get>("layer_types"); + const size_t mamba_pool_size = std::max(1, paged_kv_cache_config->num_blocks() / 4); for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { const std::string &layer_type = layer_types[layer_idx]; if ("linear_attention" == layer_type) { - kv_cache_vec.emplace_back(); + allocate_linear_attention_cache(layer_idx, mamba_pool_size); } else if ("full_attention" == layer_type) { - auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( - head_dim, - head_dim, - num_key_value_heads, - num_key_value_heads, - dtype, - *paged_kv_cache_config); - kv_cache_vec.push_back(kv_cache); + allocate_paged_full_attention_cache(layer_idx, *paged_kv_cache_config); } else { throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: unsupported layer_type '" + layer_type + "' for layer " + std::to_string(layer_idx)); } @@ -92,7 +138,10 @@ std::vector qwen3_next_allocate_kv_cache_tensors( default: throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: Unsupported attention backend: " + std::to_string(static_cast(attention_backend))); } - return kv_cache_vec; + return AllocatedHybridCache{ + std::move(kv_cache_vec), + std::move(conv_state_vec), + std::move(ssm_state_vec)}; } } // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.hpp b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.hpp new file mode 100644 index 000000000..a4b5190f1 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "../../backends/attention_backends.hpp" +#include "../../cache/kv_cache.hpp" +#include "../../cache/mamba_cache.hpp" +#include "../../config/model_config.hpp" + +#include +#include +#include +#include + +namespace infinilm::models::qwen3_next { + +struct AllocatedHybridCache { + std::vector kv_cache_tensors; + std::vector conv_state_tensors; + std::vector ssm_state_tensors; +}; + +AllocatedHybridCache qwen3_next_allocate_cache_tensors( + const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend); + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp b/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp index 7b396cb2d..4c61832c8 100644 --- a/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp +++ b/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp @@ -35,7 +35,7 @@ std::tuple Qwen3NextDecoderLayer::forwar infinicore::Tensor &residual) { input_layernorm_->forward_inplace(hidden_states, residual); if ("linear_attention" == layer_type_) { - hidden_states = linear_attn_->forward(positions, hidden_states); + hidden_states = linear_attn_->forward(hidden_states); } else if ("full_attention" == layer_type_) { hidden_states = self_attn_->forward(positions, hidden_states); } @@ -50,7 +50,7 @@ infinicore::Tensor Qwen3NextDecoderLayer::forward(const infinicore::Tensor &posi auto residual = hidden_states; hidden_states = input_layernorm_->forward(hidden_states); if ("linear_attention" == layer_type_) { - hidden_states = linear_attn_->forward(positions, hidden_states); + hidden_states = linear_attn_->forward(hidden_states); } else if ("full_attention" == layer_type_) { hidden_states = self_attn_->forward(positions, hidden_states); } diff --git a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp index 7d2f8a6e4..1b6354540 100644 --- a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp +++ b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp @@ -1,8 +1,10 @@ #include "qwen3_next_for_causal_lm.hpp" #include "../../global_state/global_state.hpp" #include "../models_registry.hpp" +#include "qwen3_next_allocate_kv_cache_tensors.hpp" #include #include +#include #include namespace infinilm::models::qwen3_next { @@ -31,12 +33,17 @@ void Qwen3NextForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { } cache_config_ = cache_config->unique_copy(); - auto &kv_cache_vec = infinilm::global_state::get_forward_context().kv_cache_vec; - kv_cache_vec.clear(); + auto &forward_context = infinilm::global_state::get_forward_context(); + forward_context.kv_cache_vec.clear(); + forward_context.conv_state_vec.clear(); + forward_context.ssm_state_vec.clear(); + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; - auto new_kv_cache_vec = qwen3_next_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend); - kv_cache_vec = std::move(new_kv_cache_vec); + auto cache_vectors = qwen3_next_allocate_cache_tensors(cache_config, model_config_, attention_backend); + forward_context.kv_cache_vec = std::move(cache_vectors.kv_cache_tensors); + forward_context.conv_state_vec = std::move(cache_vectors.conv_state_tensors); + forward_context.ssm_state_vec = std::move(cache_vectors.ssm_state_tensors); } std::shared_ptr create_qwen3_next_model_config(std::shared_ptr model_config) { diff --git a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp index b8d1a6a94..0cbe45320 100644 --- a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp +++ b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp @@ -24,9 +24,4 @@ class Qwen3NextForCausalLM : public InfinilmModel { std::shared_ptr create_qwen3_next_model_config(std::shared_ptr model_config); -/** Implemented in `qwen3_next_allocate_kv_cache_tensors.cpp`. */ -std::vector qwen3_next_allocate_kv_cache_tensors( - const cache::CacheConfig *cache_config, - const std::shared_ptr &text_config, - const backends::AttentionBackend &attention_backend); } // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp index d39917870..8d12da356 100644 --- a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp @@ -1,21 +1,20 @@ #include "qwen3_next_gated_deltanet.hpp" -#include -namespace infinilm::models::qwen3_next { +#include "../../global_state/global_state.hpp" -FakeConv1d::FakeConv1d(size_t in_channels, - size_t out_channels, - size_t kernel_size, - size_t stride, - size_t padding, - size_t dilation, - size_t groups, - bool bias, - const infinicore::DataType dtype, - const infinicore::Device device) { - - INFINICORE_NN_PARAMETER_INIT(weight, ({out_channels, 1, kernel_size}, dtype, device)); -} +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace infinilm::models::qwen3_next { Qwen3NextGatedDeltaNet::Qwen3NextGatedDeltaNet(std::shared_ptr model_config, size_t layer_idx, @@ -28,33 +27,116 @@ Qwen3NextGatedDeltaNet::Qwen3NextGatedDeltaNet(std::shared_ptrget("linear_key_head_dim"); size_t linear_value_head_dim = model_config->get("linear_value_head_dim"); - size_t key_dim = linear_key_head_dim * linear_num_key_heads; - size_t value_dim = linear_value_head_dim * linear_num_value_heads; + linear_num_value_heads_ = linear_num_value_heads; + linear_num_key_heads_ = linear_num_key_heads; + linear_key_head_dim_ = linear_key_head_dim; + linear_value_head_dim_ = linear_value_head_dim; + key_dim_ = linear_key_head_dim_ * linear_num_key_heads_; + value_dim_ = linear_value_head_dim_ * linear_num_value_heads_; size_t linear_conv_kernel_dim = model_config->get("linear_conv_kernel_dim"); double rms_norm_eps = model_config->get("rms_norm_eps"); - size_t conv_dim = key_dim * 2 + value_dim; - INFINICORE_NN_MODULE_INIT(conv1d, conv_dim, conv_dim, linear_conv_kernel_dim, 1, linear_conv_kernel_dim - 1, 1, 1, false, dtype, device); + size_t conv_dim = key_dim_ * 2 + value_dim_; + conv_dim_ = conv_dim; + conv_state_len_ = linear_conv_kernel_dim > 0 ? linear_conv_kernel_dim - 1 : 0; + conv1d_weight_ = infinicore::nn::Parameter({conv_dim, 1, linear_conv_kernel_dim}, dtype, device); + this->register_parameter("conv1d.weight", conv1d_weight_); - size_t projection_size_qkvz = key_dim * 2 + value_dim * 2; - size_t projection_size_ba = linear_num_value_heads * 2; + size_t projection_size_qkv = key_dim_ * 2 + value_dim_; - in_proj_qkvz_ = this->register_module("in_proj_qkvz", hidden_size, projection_size_qkvz, false, dtype, device); - in_proj_ba_ = this->register_module("in_proj_ba", hidden_size, projection_size_ba, false, dtype, device); + in_proj_qkv_ = this->register_module("in_proj_qkv", hidden_size, projection_size_qkv, false, dtype, device); + in_proj_z_ = this->register_module("in_proj_z", hidden_size, value_dim_, false, dtype, device); + in_proj_a_ = this->register_module("in_proj_a", hidden_size, linear_num_value_heads, false, dtype, device); + in_proj_b_ = this->register_module("in_proj_b", hidden_size, linear_num_value_heads, false, dtype, device); INFINICORE_NN_PARAMETER_INIT(dt_bias, ({linear_num_value_heads}, dtype, device)); INFINICORE_NN_PARAMETER_INIT(A_log, ({linear_num_value_heads}, dtype, device)); INFINICORE_NN_MODULE_INIT(norm, linear_value_head_dim, rms_norm_eps, dtype, device); - out_proj_ = this->register_module("out_proj", value_dim, hidden_size, false, dtype, device); + out_proj_ = this->register_module("out_proj", value_dim_, hidden_size, false, dtype, device); } -infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &positions, - const infinicore::Tensor &hidden_states) const { - spdlog::error("Qwen3NextGatedDeltaNet: forward not implemented"); - return hidden_states; +infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hidden_states) const { + + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + auto qkv = in_proj_qkv_->forward(hidden_states_mutable); + auto z = in_proj_z_->forward(hidden_states_mutable); + auto a = in_proj_a_->forward(hidden_states_mutable); + auto b = in_proj_b_->forward(hidden_states_mutable); + + auto &forward_context = infinilm::global_state::get_forward_context(); + auto &mamba_metadata = forward_context.mamba_metadata; + + auto conv_out = infinicore::op::causal_conv1d( + qkv, + forward_context.conv_state_vec[layer_idx_], + conv1d_weight_, + std::nullopt, + mamba_metadata.input_offsets.value(), + mamba_metadata.init_state_indices.value(), + mamba_metadata.final_state_indices.value()); + auto conv_qkv = infinicore::op::silu(conv_out); + + auto q = conv_qkv->narrow({{2, 0, key_dim_}}); + auto k = conv_qkv->narrow({{2, key_dim_, key_dim_}}); + auto v = conv_qkv->narrow({{2, key_dim_ * 2, value_dim_}}); + bool is_decode = mamba_metadata.input_offsets.value()->shape()[0] - 1 == seq_len; + infinicore::Tensor delta_out; + if (is_decode) { + auto ssm_state = forward_context.ssm_state_vec[layer_idx_]; + auto q_delta = q->view({seq_len, 1, linear_num_key_heads_, linear_key_head_dim_}); + auto k_delta = k->view({seq_len, 1, linear_num_key_heads_, linear_key_head_dim_}); + auto v_delta = v->view({seq_len, 1, linear_num_value_heads_, linear_value_head_dim_}); + + auto a_heads = a->view({seq_len, 1, linear_num_value_heads_}); + auto b_heads = b->view({seq_len, 1, linear_num_value_heads_}); + auto [g, beta] = infinicore::op::fused_gated_delta_net_gating(A_log_, a_heads, b_heads, dt_bias_); + + delta_out = infinicore::op::recurrent_gated_delta_rule_indexed( + q_delta, + k_delta, + v_delta, + g, + beta, + ssm_state, + mamba_metadata.init_state_indices.value(), + mamba_metadata.final_state_indices.value(), + true); + delta_out = delta_out->view({seq_len, linear_num_value_heads_, linear_value_head_dim_}); + } else { + auto ssm_state = forward_context.ssm_state_vec[layer_idx_]; + auto q_delta = q->view({1, seq_len, linear_num_key_heads_, linear_key_head_dim_}); + auto k_delta = k->view({1, seq_len, linear_num_key_heads_, linear_key_head_dim_}); + auto v_delta = v->view({1, seq_len, linear_num_value_heads_, linear_value_head_dim_}); + + auto a_heads = a->view({1, seq_len, linear_num_value_heads_}); + auto b_heads = b->view({1, seq_len, linear_num_value_heads_}); + auto [g, beta] = infinicore::op::fused_gated_delta_net_gating(A_log_, a_heads, b_heads, dt_bias_); + + delta_out = infinicore::op::chunk_gated_delta_rule( + q_delta, + k_delta, + v_delta, + g, + beta, + ssm_state, + mamba_metadata.input_offsets.value(), + mamba_metadata.init_state_indices.value(), + mamba_metadata.final_state_indices.value(), + true); + delta_out = delta_out->view({seq_len, linear_num_value_heads_, linear_value_head_dim_}); + } + + auto v_norm = norm_->forward(delta_out->view({batch_size * seq_len * linear_num_value_heads_, linear_value_head_dim_})) + ->view({batch_size, seq_len, value_dim_}); + auto gated = infinicore::op::mul(v_norm, infinicore::op::silu(z)); + return out_proj_->forward(gated); } } // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp index 8aba34ded..e91c3c24c 100644 --- a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp @@ -5,43 +5,34 @@ namespace infinilm::models::qwen3_next { using Qwen3Next_Fake_RMSNormGated = infinicore::nn::RMSNorm; -class FakeConv1d : public infinicore::nn::Module { -public: - FakeConv1d(size_t in_channels, - size_t out_channels, - size_t kernel_size, - size_t stride, - size_t padding, - size_t dilation, - size_t groups, - bool bias, - const infinicore::DataType dtype, - const infinicore::Device device); - -private: - size_t layer_idx_; - INFINICORE_NN_PARAMETER(weight); -}; - class Qwen3NextGatedDeltaNet : public infinicore::nn::Module { public: Qwen3NextGatedDeltaNet(std::shared_ptr model_config, size_t layer_idx, const infinicore::Device &device); - infinicore::Tensor forward(const infinicore::Tensor &positions, - const infinicore::Tensor &hidden_states) const; + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; private: - std::shared_ptr in_proj_qkvz_; - std::shared_ptr in_proj_ba_; - INFINICORE_NN_MODULE(FakeConv1d, conv1d); + std::shared_ptr in_proj_qkv_; + std::shared_ptr in_proj_z_; + std::shared_ptr in_proj_a_; + std::shared_ptr in_proj_b_; + INFINICORE_NN_PARAMETER(conv1d_weight); INFINICORE_NN_PARAMETER(dt_bias); INFINICORE_NN_PARAMETER(A_log); INFINICORE_NN_MODULE(Qwen3Next_Fake_RMSNormGated, norm); std::shared_ptr out_proj_; size_t layer_idx_; + size_t linear_num_value_heads_; + size_t linear_num_key_heads_; + size_t linear_key_head_dim_; + size_t linear_value_head_dim_; + size_t key_dim_; + size_t value_dim_; + size_t conv_dim_; + size_t conv_state_len_; }; } // namespace infinilm::models::qwen3_next diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 4c5058ac5..cf2364dc2 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -136,6 +136,8 @@ inline void bind_infer_engine(py::module &m) { std::optional cu_seqlens, std::optional block_tables, std::optional slot_mapping, + std::optional mamba_init_state_indices, + std::optional mamba_final_state_indices, std::optional> pixel_values, std::optional> image_bound, std::optional> tgt_sizes, @@ -150,6 +152,8 @@ inline void bind_infer_engine(py::module &m) { std::move(cu_seqlens), std::move(block_tables), std::move(slot_mapping), + std::move(mamba_init_state_indices), + std::move(mamba_final_state_indices), std::move(pixel_values), std::move(image_bound), std::move(tgt_sizes), @@ -195,6 +199,8 @@ inline void bind_infer_engine(py::module &m) { py::arg("cu_seqlens") = std::nullopt, py::arg("block_tables") = std::nullopt, py::arg("slot_mapping") = std::nullopt, + py::arg("mamba_init_state_indices") = std::nullopt, + py::arg("mamba_final_state_indices") = std::nullopt, py::arg("pixel_values") = std::nullopt, py::arg("image_bound") = std::nullopt, py::arg("tgt_sizes") = std::nullopt, @@ -207,6 +213,8 @@ inline void bind_infer_engine(py::module &m) { .def_readwrite("cu_seqlens", &InferEngine::Input::cu_seqlens) .def_readwrite("block_tables", &InferEngine::Input::block_tables) .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping) + .def_readwrite("mamba_init_state_indices", &InferEngine::Input::mamba_init_state_indices) + .def_readwrite("mamba_final_state_indices", &InferEngine::Input::mamba_final_state_indices) .def_readwrite("pixel_values", &InferEngine::Input::pixel_values) .def_readwrite("image_bound", &InferEngine::Input::image_bound) .def_readwrite("tgt_sizes", &InferEngine::Input::tgt_sizes) diff --git a/examples/test_infer.py b/examples/test_infer.py index 14db1b6ef..fe42f8e2f 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -12,6 +12,8 @@ def test( tp=1, enable_paged_attn=False, enable_graph=False, + num_blocks=512, + block_size=256, top_k=1, top_p=1.0, temperature=1.0, @@ -35,6 +37,8 @@ def test( cache_type="paged" if enable_paged_attn else "static", max_batch_size=len(prompts), max_tokens=max_new_tokens, + num_blocks=num_blocks, + block_size=block_size, temperature=temperature, top_k=top_k, top_p=top_p, @@ -101,6 +105,8 @@ def test( tp=tp, enable_paged_attn=enable_paged_attn, enable_graph=enable_graph, + num_blocks=cfg.num_blocks, + block_size=cfg.block_size, top_k=cfg.top_k, top_p=cfg.top_p, temperature=cfg.temperature, diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 10cf58be2..c6d651134 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -101,6 +101,10 @@ def dtype(self): torch_dtype = self.hf_config.get("torch_dtype") if torch_dtype is None: torch_dtype = self.hf_config.get("dtype") + if torch_dtype is None: + text_config = self.hf_config.get("text_config") + if isinstance(text_config, dict): + torch_dtype = text_config.get("torch_dtype") or text_config.get("dtype") return parse_dtype(torch_dtype) @property @@ -144,6 +148,8 @@ def forward( cu_seqlens=None, block_tables=None, slot_mapping=None, + mamba_init_state_indices=None, + mamba_final_state_indices=None, pixel_values=None, image_bound=None, tgt_sizes=None, @@ -174,6 +180,16 @@ def forward( slot_mapping = ( slot_mapping._underlying if slot_mapping is not None else None ) + mamba_init_state_indices = ( + mamba_init_state_indices._underlying + if mamba_init_state_indices is not None + else None + ) + mamba_final_state_indices = ( + mamba_final_state_indices._underlying + if mamba_final_state_indices is not None + else None + ) def convert_tensor_list(tensor_list_): if tensor_list_ is None: @@ -200,6 +216,8 @@ def convert_tensor_list(tensor_list_): cu_seqlens=cu_seqlens, block_tables=block_tables, slot_mapping=slot_mapping, + mamba_init_state_indices=mamba_init_state_indices, + mamba_final_state_indices=mamba_final_state_indices, pixel_values=pixel_values, image_bound=image_bound, tgt_sizes=tgt_sizes, diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 243bb5764..81695c79b 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -67,10 +67,14 @@ def __init__(self, config: EngineConfig): f"KV Connector created: {config.kv_transfer_config.kv_connector} " f"(role={config.kv_transfer_config.kv_role})" ) + llm_config = self.model_runner.model_engine.hf_config + if "text_config" in llm_config: + llm_config = llm_config["text_config"] - max_position_embeddings = self.model_runner.model_engine.hf_config.get( + max_position_embeddings = llm_config.get( "max_position_embeddings", config.max_cache_len ) + max_num_batched_tokens = int( os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", max_position_embeddings) ) @@ -721,13 +725,13 @@ def add_request( elif prompt is not None: prompt_token_ids = self.engine.tokenize(prompt) else: - assert ( - messages is not None - ), "Either messages or prompt/prompt_token_ids must be provided" + assert messages is not None, ( + "Either messages or prompt/prompt_token_ids must be provided" + ) - assert ( - apply_chat_template - ), "apply_chat_template needs to be true for multi-role conversation" + assert apply_chat_template, ( + "apply_chat_template needs to be true for multi-role conversation" + ) prompt = self.engine.apply_chat_template( messages, add_generation_prompt=add_generation_prompt @@ -752,128 +756,4 @@ def add_request( ) if sampling_params is None: - sampling_params = SamplingParams(max_tokens=self.config.max_tokens) - elif sampling_params.max_tokens is None: - sampling_params = sampling_params.clone() - sampling_params.max_tokens = self.config.max_tokens - - request = InferenceRequest( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - processed_inputs=processed_inputs, - mm_token_index_mappings=mm_index_mappings, - sampling_params=sampling_params, - eos_token_ids=self.engine.eos_token_ids, - request_data=request_data, - ) - - if request_data and "kv_transfer_params" in request_data: - kv_params = request_data["kv_transfer_params"] - request.kv_transfer_params = kv_params - - # Initialize output queue for streaming - _ = request.output_queue - - self.engine.add_request(request) - return request - - def add_chat_request( - self, - messages: List[dict], - sampling_params: Optional[SamplingParams] = None, - request_id: Optional[str] = None, - request_data: Optional[dict] = None, - add_generation_prompt: bool = True, - **kwargs, - ) -> InferenceRequest: - """Add a chat request to the engine. - - Args: - messages: List of message dicts (chat conversation). - sampling_params: Sampling parameters. - request_id: Optional request ID. - request_data: Optional request data dict. - - Returns: - The created InferenceRequest object. - """ - - return self.add_request( - messages=messages, - apply_chat_template=True, - add_generation_prompt=add_generation_prompt, - sampling_params=sampling_params, - request_id=request_id, - request_data=request_data, - ) - - async def stream_request( - self, - request: InferenceRequest, - timeout: float = 100.0, - request_timeout: Optional[float] = None, - ) -> AsyncIterator[TokenOutput]: - """Stream tokens from a request. - - Args: - request: The inference request to stream from. - timeout: Timeout for waiting on each token. - - Yields: - TokenOutput objects for each generated token. - """ - import asyncio - - start = time.time() - try: - while True: - try: - if request_timeout and time.time() - start > float(request_timeout): - logger.warning( - f"Request {request.request_id} exceeded request timeout of {request_timeout} seconds" - ) - self.add_aborted_req(request, FinishReason.TIMEOUT) - - token_output = await asyncio.wait_for( - request.output_queue.async_q.get(), timeout=timeout - ) - - request.output_queue.async_q.task_done() - - yield token_output - - if token_output.finished: - break - except asyncio.TimeoutError: - logger.warning( - f"Timeout while waiting for token from request {request.request_id}" - ) - if request.is_aborted(): - while not request.output_queue.async_q.empty(): - try: - token_output = request.output_queue.async_q.get_nowait() - request.output_queue.async_q.task_done() - yield token_output - except asyncio.QueueEmpty: - break - - yield TokenOutput( - request_id=request.request_id, - token_id=-1, - token_text="", - finished=True, - finish_reason=request.finish_reason, - generated_text=request.generated_text, - ) - break - continue - except Exception as e: - logger.error( - f"Error while streaming request {request.request_id}: {e}" - ) - break - finally: - # Unified cleanup point: runs whether the loop exits normally, - # via exception, or via aclose() (GeneratorExit from Starlette). - await request.close() + sampling_params = SamplingParams(max_ \ No newline at end of file diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 94f4016a9..ec05f8620 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -661,10 +661,29 @@ def _remap_mamba(state_dict, config=None): # Model type → remap function mapping +def _remap_qwen3_5(state_dict, config=None): + """Apply Qwen3.5-specific load-time weight fixes.""" + state_dict = drop_keys(state_dict, ["mtp."]) + + norm_weight_suffixes = ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "self_attn.q_norm.weight", + "self_attn.k_norm.weight", + ) + + for key, tensor in state_dict.items(): + if key == "model.norm.weight" or key.endswith(norm_weight_suffixes): + state_dict[key] = tensor + torch.ones_like(tensor) + + return state_dict + + _WEIGHT_REMAPPER = { "glm4": _remap_glm4, "chatglm": _remap_chatglm, "baichuan": _remap_baichuan, "gpt2": _remap_gpt2, "mamba": _remap_mamba, + "qwen3_5": _remap_qwen3_5, } diff --git a/python/infinilm/processors/__init__.py b/python/infinilm/processors/__init__.py index 67f97acac..e804d7fc5 100644 --- a/python/infinilm/processors/__init__.py +++ b/python/infinilm/processors/__init__.py @@ -1,4 +1,5 @@ import importlib +import json import pkgutil from pathlib import Path from transformers import AutoConfig @@ -30,8 +31,14 @@ def from_pretrained(cls, model_dir_path: str, **kwargs) -> InfinilmProcessor: registered Processor. Falls back to the registered default processor for unregistered or standard architectures. """ - config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True) - model_type = config.model_type.lower() + try: + config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True) + model_type = config.model_type.lower() + except ValueError: + config_path = Path(model_dir_path) / "config.json" + with open(config_path, "r") as f: + config = json.load(f) + model_type = config["model_type"].lower() processor_cls = get_processor_class(model_type) return processor_cls(model_dir_path) diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index 6948aa41c..177685777 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -73,6 +73,7 @@ def build_model_inputs( temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1, + **kwargs, ) -> dict: """Process a batch of data and return a dictionary of model inputs.""" if isinstance(scheduler_output, StaticSchedulerOutput): diff --git a/python/infinilm/processors/qwen3_5_processor.py b/python/infinilm/processors/qwen3_5_processor.py new file mode 100644 index 000000000..d59d59141 --- /dev/null +++ b/python/infinilm/processors/qwen3_5_processor.py @@ -0,0 +1,171 @@ +from typing_extensions import override + +from transformers import AutoProcessor, AutoTokenizer + +from ..llm.scheduler import SchedulerOutput +from ..llm.static_scheduler import StaticSchedulerOutput +from .basic_llm_processor import BasicLLMProcessor +from .processor import register_processor + + +@register_processor("qwen3_5") +class Qwen35Processor(BasicLLMProcessor): + def __init__(self, model_dir_path: str): + try: + self.processor = AutoProcessor.from_pretrained( + model_dir_path, trust_remote_code=True + ) + self.tokenizer = self.processor.tokenizer + except Exception: + self.processor = None + self.tokenizer = AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + + + + @override + def __call__( + self, + prompt, + images=None, + videos=None, + audios=None, + return_tensors: str = None, + **kwargs, + ) -> dict: + if not images and not videos and not audios: + return self.tokenizer( + prompt, return_tensors=return_tensors, add_special_tokens=False + ) + + if self.processor is None: + raise RuntimeError("Qwen3.5 multimodal processor is not available") + + return self.processor( + text=prompt, + images=images, + videos=videos, + return_tensors=return_tensors or "pt", + **kwargs, + ) + + @override + def apply_chat_template( + self, + conversation, + add_generation_prompt: bool = False, + tokenize: bool = True, + **kwargs, + ): + normalized_conversation = [] + for message in conversation: + content = message["content"] + if not isinstance(content, list): + normalized_conversation.append(message) + continue + + normalized_content = [] + for item in content: + item_type = item.get("type") + if item_type == "text": + normalized_content.append({"type": "text", "text": item.get("text", "")}) + elif item_type == "image_url": + normalized_content.append({"type": "image"}) + elif item_type == "video_url": + normalized_content.append({"type": "video"}) + else: + raise NotImplementedError(f"Unsupported Qwen3.5 content type: {item_type}") + + normalized_conversation.append( + {"role": message.get("role", "user"), "content": normalized_content} + ) + + template_owner = self.processor if self.processor is not None else self.tokenizer + return template_owner.apply_chat_template( + conversation=normalized_conversation, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + **kwargs, + ) + + @override + def build_model_inputs( + self, + scheduler_output: SchedulerOutput | StaticSchedulerOutput, + temperature: float = 1.0, + top_p: float = 0.8, + top_k: int = 1, + **kwargs, + ) -> dict: + if isinstance(scheduler_output, StaticSchedulerOutput): + model_inputs = self._build_model_input_from_static_scheduler_output( + scheduler_output, temperature, top_p, top_k + ) + elif isinstance(scheduler_output, SchedulerOutput): + model_inputs = self._build_model_input_from_batch_scheduler_output( + scheduler_output, temperature, top_p, top_k + ) + self._append_qwen35_mm_inputs(model_inputs, scheduler_output) + else: + raise ValueError( + "scheduler_output must be an instance of SchedulerOutput or StaticSchedulerOutput" + ) + + # TODO(qwen3_5): The scheduler should own stable mamba cache ids. For now + # use a per-forward arange so the C++ model input and mamba metadata path + # can be exercised without encoding cache policy in the processor. + num_requests = len(scheduler_output.scheduled_requests) + init_indices = list(range(num_requests)) + final_indices = list(range(num_requests)) + + import infinicore + + model_inputs["mamba_init_state_indices"] = infinicore.from_list( + init_indices, dtype=infinicore.int32 + ) + model_inputs["mamba_final_state_indices"] = infinicore.from_list( + final_indices, dtype=infinicore.int32 + ) + return model_inputs + + def _append_qwen35_mm_inputs( + self, model_inputs: dict, scheduler_output: SchedulerOutput + ) -> None: + import infinicore + import torch + + pixel_values = [] + image_req_ids = [] + for req_id, req in enumerate(scheduler_output.scheduled_requests): + processed_inputs = req.processed_inputs + if ( + not scheduler_output.is_prefill + or processed_inputs is None + or "pixel_values" not in processed_inputs + ): + continue + + pixel_value = processed_inputs["pixel_values"] + if isinstance(pixel_value, list): + pixel_values.extend(pixel_value) + else: + pixel_values.append(pixel_value) + image_req_ids.append(req_id) + + if pixel_values: + pixel_values = [ + infinicore.from_torch(t if isinstance(t, torch.Tensor) else torch.as_tensor(t)) + for t in pixel_values + ] + model_inputs["pixel_values"] = pixel_values + model_inputs["image_req_ids"] = image_req_ids + + @override + def get_mm_token_index_list( + self, prompt_token_ids, image_ids=None, video_ids=None, audio_ids=None, **kwargs + ): + mm_token_index_list = [] + + + return mm_token_index_list From b2f70e6f2f69833d7c9b11d428f6048d18f9a1e9 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 26 Jun 2026 09:04:42 +0000 Subject: [PATCH 2/4] issue/446 support TP in gated deltanet layer --- csrc/layers/linear/fused_linear.cpp | 15 +- csrc/layers/linear/fused_linear.hpp | 11 ++ .../qwen3_next/qwen3_next_gated_deltanet.cpp | 175 +++++++++++++----- .../qwen3_next/qwen3_next_gated_deltanet.hpp | 53 ++++-- python/infinilm/modeling_utils.py | 21 ++- 5 files changed, 208 insertions(+), 67 deletions(-) diff --git a/csrc/layers/linear/fused_linear.cpp b/csrc/layers/linear/fused_linear.cpp index 2c0520901..1bcb96c94 100644 --- a/csrc/layers/linear/fused_linear.cpp +++ b/csrc/layers/linear/fused_linear.cpp @@ -84,7 +84,20 @@ QKVParallelLinear::QKVParallelLinear(size_t hidden_size, const infinicore::DataType &dtype, const infinicore::Device &device, engine::distributed::RankInfo rank_info) - : QKVParallelLinear(hidden_size, head_dim, num_q_head, num_kv_head, quantization, bias, dtype, device, rank_info) { + : QKVParallelLinear(hidden_size, head_dim, head_dim, head_dim, num_q_head, num_kv_head, num_kv_head, bias, bias, bias, q_name, k_name, v_name, register_fn, quantization, dtype, device, rank_info) { +} + +QKVParallelLinear::QKVParallelLinear(size_t hidden_size, + size_t q_dim, size_t k_dim, size_t v_dim, + size_t num_q_head, size_t num_k_head, size_t num_v_head, + bool q_bias, bool k_bias, bool v_bias, + const std::string &q_name, const std::string &k_name, const std::string &v_name, + RegisterParamFn register_fn, + std::shared_ptr quantization, + const infinicore::DataType &dtype, + const infinicore::Device &device, + engine::distributed::RankInfo rank_info) + : QKVParallelLinear(hidden_size, q_dim, k_dim, v_dim, num_q_head, num_k_head, num_v_head, q_bias, k_bias, v_bias, quantization, dtype, device, rank_info) { register_fn_ = register_fn; split_infos_ = { {q_name, 0, q_out_size_, 0}, diff --git a/csrc/layers/linear/fused_linear.hpp b/csrc/layers/linear/fused_linear.hpp index 6e4a34856..8773a081c 100644 --- a/csrc/layers/linear/fused_linear.hpp +++ b/csrc/layers/linear/fused_linear.hpp @@ -27,6 +27,17 @@ class QKVParallelLinear : public infinilm::nn::ColumnParallelLinear { const infinicore::Device &device = infinicore::Device(), engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + QKVParallelLinear(size_t hidden_size, + size_t q_dim, size_t k_dim, size_t v_dim, + size_t num_q_head, size_t num_k_head, size_t num_v_head, + bool q_bias, bool k_bias, bool v_bias, + const std::string &q_name, const std::string &k_name, const std::string &v_name, + RegisterParamFn register_fn, + std::shared_ptr quantization = nullptr, + const infinicore::DataType &dtype = infinicore::DataType::F32, + const infinicore::Device &device = infinicore::Device(), + engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); + QKVParallelLinear(size_t hidden_size, size_t head_dim, size_t num_q_head, size_t num_kv_head, diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp index 8d12da356..0bd8bf509 100644 --- a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp @@ -16,6 +16,87 @@ namespace infinilm::models::qwen3_next { +Qwen3NextCausalConv1D::Qwen3NextCausalConv1D(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + const auto &dtype{model_config->get_dtype()}; + size_t linear_num_value_heads = model_config->get("linear_num_value_heads"); + size_t linear_num_key_heads = model_config->get("linear_num_key_heads"); + size_t linear_key_head_dim = model_config->get("linear_key_head_dim"); + size_t linear_value_head_dim = model_config->get("linear_value_head_dim"); + size_t linear_conv_kernel_dim = model_config->get("linear_conv_kernel_dim"); + + size_t key_dim = linear_key_head_dim * linear_num_key_heads; + size_t value_dim = linear_value_head_dim * linear_num_value_heads; + size_t conv_dim = key_dim * 2 + value_dim; + + size_t conv_state_len = linear_conv_kernel_dim > 0 ? linear_conv_kernel_dim - 1 : 0; + weight_ = infinicore::nn::Parameter({conv_dim, 1, linear_conv_kernel_dim}, dtype, device); + this->register_parameter("weight", weight_); + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + tp_size_ = rank_info.tp_size; + tp_rank_ = rank_info.tp_rank; + conv_kernel_dim_ = linear_conv_kernel_dim; + auto tp_size = tp_size_; + full_qk_dim_ = linear_num_key_heads * linear_key_head_dim; + full_v_dim_ = linear_num_value_heads * linear_value_head_dim; + local_qk_dim_ = (linear_num_key_heads >= tp_size ? linear_num_key_heads / tp_size : 1) * linear_key_head_dim; + local_v_dim_ = (linear_num_value_heads >= tp_size ? linear_num_value_heads / tp_size : 1) * linear_value_head_dim; + local_conv_dim_ = local_qk_dim_ * 2 + local_v_dim_; +} + +void Qwen3NextCausalConv1D::process_weights_after_loading() { + if (tp_size_ <= 1 || weight_->size(0) == local_conv_dim_) { + return; + } + + const size_t expected_full_conv_dim = full_qk_dim_ * 2 + full_v_dim_; + if (weight_->size(0) != expected_full_conv_dim) { + throw std::runtime_error("Qwen3NextCausalConv1D: unexpected conv1d weight shape for TP slicing"); + } + + auto local_weight = infinicore::Tensor::empty( + {local_conv_dim_, 1, conv_kernel_dim_}, + weight_->dtype(), + weight_->device()); + + const size_t src_qk0_offset = tp_rank_ * local_qk_dim_; + const size_t src_qk1_offset = full_qk_dim_ + tp_rank_ * local_qk_dim_; + const size_t src_v_offset = 2 * full_qk_dim_ + tp_rank_ * local_v_dim_; + + const size_t dst_qk0_offset = 0; + const size_t dst_qk1_offset = local_qk_dim_; + const size_t dst_v_offset = 2 * local_qk_dim_; + + local_weight->narrow({{0, dst_qk0_offset, local_qk_dim_}}) + ->copy_from(weight_->narrow({{0, src_qk0_offset, local_qk_dim_}})); + local_weight->narrow({{0, dst_qk1_offset, local_qk_dim_}}) + ->copy_from(weight_->narrow({{0, src_qk1_offset, local_qk_dim_}})); + local_weight->narrow({{0, dst_v_offset, local_v_dim_}}) + ->copy_from(weight_->narrow({{0, src_v_offset, local_v_dim_}})); + + weight_ = infinicore::nn::Parameter(local_weight); + parameters_["weight"] = weight_; +} + +infinicore::Tensor Qwen3NextCausalConv1D::forward(const infinicore::Tensor &qkv) const { + auto &forward_context = infinilm::global_state::get_forward_context(); + auto &mamba_metadata = forward_context.mamba_metadata; + + auto conv_out = infinicore::op::causal_conv1d( + qkv, + forward_context.conv_state_vec[layer_idx_], + weight_->narrow({{0, 0, local_conv_dim_}}), // narrow in case load is skipped + std::nullopt, + mamba_metadata.input_offsets.value(), + mamba_metadata.init_state_indices.value(), + mamba_metadata.final_state_indices.value()); + auto conv_qkv = infinicore::op::silu(conv_out); + return conv_qkv; +} + Qwen3NextGatedDeltaNet::Qwen3NextGatedDeltaNet(std::shared_ptr model_config, size_t layer_idx, const infinicore::Device &device) { @@ -26,36 +107,40 @@ Qwen3NextGatedDeltaNet::Qwen3NextGatedDeltaNet(std::shared_ptrget("linear_num_key_heads"); size_t linear_key_head_dim = model_config->get("linear_key_head_dim"); size_t linear_value_head_dim = model_config->get("linear_value_head_dim"); - - linear_num_value_heads_ = linear_num_value_heads; - linear_num_key_heads_ = linear_num_key_heads; - linear_key_head_dim_ = linear_key_head_dim; - linear_value_head_dim_ = linear_value_head_dim; - key_dim_ = linear_key_head_dim_ * linear_num_key_heads_; - value_dim_ = linear_value_head_dim_ * linear_num_value_heads_; - - size_t linear_conv_kernel_dim = model_config->get("linear_conv_kernel_dim"); + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + auto tp_size = rank_info.tp_size; + auto tp_rank = rank_info.tp_rank; + local_num_value_heads_ = linear_num_value_heads / tp_size; + local_num_key_heads_ = linear_num_key_heads / tp_size; + key_head_dim_ = linear_key_head_dim; + value_head_dim_ = linear_value_head_dim; + size_t value_dim = linear_value_head_dim * linear_num_value_heads; + local_key_dim_ = key_head_dim_ * local_num_key_heads_; + local_value_dim_ = value_head_dim_ * local_num_value_heads_; double rms_norm_eps = model_config->get("rms_norm_eps"); - size_t conv_dim = key_dim_ * 2 + value_dim_; - conv_dim_ = conv_dim; - conv_state_len_ = linear_conv_kernel_dim > 0 ? linear_conv_kernel_dim - 1 : 0; - conv1d_weight_ = infinicore::nn::Parameter({conv_dim, 1, linear_conv_kernel_dim}, dtype, device); - this->register_parameter("conv1d.weight", conv1d_weight_); - - size_t projection_size_qkv = key_dim_ * 2 + value_dim_; + conv1d_ = this->register_module("conv1d", model_config, layer_idx, device); - in_proj_qkv_ = this->register_module("in_proj_qkv", hidden_size, projection_size_qkv, false, dtype, device); - in_proj_z_ = this->register_module("in_proj_z", hidden_size, value_dim_, false, dtype, device); - in_proj_a_ = this->register_module("in_proj_a", hidden_size, linear_num_value_heads, false, dtype, device); - in_proj_b_ = this->register_module("in_proj_b", hidden_size, linear_num_value_heads, false, dtype, device); + size_t projection_size_qkv = local_key_dim_ * 2 + local_value_dim_; + auto quantization_method = model_config->get_quantization_method(); + auto register_fn = [this](const std::string &n, infinicore::nn::Parameter p) { this->register_parameter(n, std::move(p)); }; + in_proj_qkv_ = std::make_shared( + hidden_size, linear_key_head_dim, linear_key_head_dim, linear_value_head_dim, linear_num_key_heads, linear_num_key_heads, linear_num_value_heads, + false, false, false, + "in_proj_q", "in_proj_k", "in_proj_v", register_fn, + quantization_method, dtype, device, rank_info); + in_proj_z_ = this->register_module("in_proj_z", hidden_size, value_dim, false, dtype, device, tp_rank, tp_size); + in_proj_a_ = this->register_module("in_proj_a", hidden_size, linear_num_value_heads, false, dtype, device, tp_rank, tp_size); + in_proj_b_ = this->register_module("in_proj_b", hidden_size, linear_num_value_heads, false, dtype, device, tp_rank, tp_size); - INFINICORE_NN_PARAMETER_INIT(dt_bias, ({linear_num_value_heads}, dtype, device)); - INFINICORE_NN_PARAMETER_INIT(A_log, ({linear_num_value_heads}, dtype, device)); + INFINICORE_NN_PARAMETER_INIT(dt_bias, ({linear_num_value_heads}, dtype, device, 0, tp_rank, tp_size)); + INFINICORE_NN_PARAMETER_INIT(A_log, ({linear_num_value_heads}, dtype, device, 0, tp_rank, tp_size)); INFINICORE_NN_MODULE_INIT(norm, linear_value_head_dim, rms_norm_eps, dtype, device); - out_proj_ = this->register_module("out_proj", value_dim_, hidden_size, false, dtype, device); + out_proj_ = this->register_module( + "out_proj", value_dim, hidden_size, quantization_method, + false, dtype, device, rank_info.tp_rank, rank_info.tp_size, rank_info.comm); } infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hidden_states) const { @@ -73,29 +158,21 @@ infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hid auto &forward_context = infinilm::global_state::get_forward_context(); auto &mamba_metadata = forward_context.mamba_metadata; - auto conv_out = infinicore::op::causal_conv1d( - qkv, - forward_context.conv_state_vec[layer_idx_], - conv1d_weight_, - std::nullopt, - mamba_metadata.input_offsets.value(), - mamba_metadata.init_state_indices.value(), - mamba_metadata.final_state_indices.value()); - auto conv_qkv = infinicore::op::silu(conv_out); + auto conv_qkv = this->conv1d_->forward(qkv); - auto q = conv_qkv->narrow({{2, 0, key_dim_}}); - auto k = conv_qkv->narrow({{2, key_dim_, key_dim_}}); - auto v = conv_qkv->narrow({{2, key_dim_ * 2, value_dim_}}); + auto q = conv_qkv->narrow({{2, 0, local_key_dim_}}); + auto k = conv_qkv->narrow({{2, local_key_dim_, local_key_dim_}}); + auto v = conv_qkv->narrow({{2, local_key_dim_ * 2, local_value_dim_}}); bool is_decode = mamba_metadata.input_offsets.value()->shape()[0] - 1 == seq_len; infinicore::Tensor delta_out; if (is_decode) { auto ssm_state = forward_context.ssm_state_vec[layer_idx_]; - auto q_delta = q->view({seq_len, 1, linear_num_key_heads_, linear_key_head_dim_}); - auto k_delta = k->view({seq_len, 1, linear_num_key_heads_, linear_key_head_dim_}); - auto v_delta = v->view({seq_len, 1, linear_num_value_heads_, linear_value_head_dim_}); + auto q_delta = q->view({seq_len, 1, local_num_key_heads_, key_head_dim_}); + auto k_delta = k->view({seq_len, 1, local_num_key_heads_, key_head_dim_}); + auto v_delta = v->view({seq_len, 1, local_num_value_heads_, value_head_dim_}); - auto a_heads = a->view({seq_len, 1, linear_num_value_heads_}); - auto b_heads = b->view({seq_len, 1, linear_num_value_heads_}); + auto a_heads = a->view({seq_len, 1, local_num_value_heads_}); + auto b_heads = b->view({seq_len, 1, local_num_value_heads_}); auto [g, beta] = infinicore::op::fused_gated_delta_net_gating(A_log_, a_heads, b_heads, dt_bias_); delta_out = infinicore::op::recurrent_gated_delta_rule_indexed( @@ -108,15 +185,15 @@ infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hid mamba_metadata.init_state_indices.value(), mamba_metadata.final_state_indices.value(), true); - delta_out = delta_out->view({seq_len, linear_num_value_heads_, linear_value_head_dim_}); + delta_out = delta_out->view({seq_len, local_num_value_heads_, value_head_dim_}); } else { auto ssm_state = forward_context.ssm_state_vec[layer_idx_]; - auto q_delta = q->view({1, seq_len, linear_num_key_heads_, linear_key_head_dim_}); - auto k_delta = k->view({1, seq_len, linear_num_key_heads_, linear_key_head_dim_}); - auto v_delta = v->view({1, seq_len, linear_num_value_heads_, linear_value_head_dim_}); + auto q_delta = q->view({1, seq_len, local_num_key_heads_, key_head_dim_}); + auto k_delta = k->view({1, seq_len, local_num_key_heads_, key_head_dim_}); + auto v_delta = v->view({1, seq_len, local_num_value_heads_, value_head_dim_}); - auto a_heads = a->view({1, seq_len, linear_num_value_heads_}); - auto b_heads = b->view({1, seq_len, linear_num_value_heads_}); + auto a_heads = a->view({1, seq_len, local_num_value_heads_}); + auto b_heads = b->view({1, seq_len, local_num_value_heads_}); auto [g, beta] = infinicore::op::fused_gated_delta_net_gating(A_log_, a_heads, b_heads, dt_bias_); delta_out = infinicore::op::chunk_gated_delta_rule( @@ -130,11 +207,11 @@ infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hid mamba_metadata.init_state_indices.value(), mamba_metadata.final_state_indices.value(), true); - delta_out = delta_out->view({seq_len, linear_num_value_heads_, linear_value_head_dim_}); + delta_out = delta_out->view({seq_len, local_num_value_heads_, value_head_dim_}); } - auto v_norm = norm_->forward(delta_out->view({batch_size * seq_len * linear_num_value_heads_, linear_value_head_dim_})) - ->view({batch_size, seq_len, value_dim_}); + auto v_norm = norm_->forward(delta_out->view({batch_size * seq_len * local_num_value_heads_, value_head_dim_})) + ->view({batch_size, seq_len, local_value_dim_}); auto gated = infinicore::op::mul(v_norm, infinicore::op::silu(z)); return out_proj_->forward(gated); } diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp index e91c3c24c..e7d64e7a9 100644 --- a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp @@ -3,7 +3,29 @@ #include "../../layers/common_modules.hpp" namespace infinilm::models::qwen3_next { -using Qwen3Next_Fake_RMSNormGated = infinicore::nn::RMSNorm; +class Qwen3NextCausalConv1D : public infinicore::nn::Module { +public: + Qwen3NextCausalConv1D( + std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &qkv) const; + void process_weights_after_loading() override; + +private: + infinicore::nn::Parameter weight_; + + size_t layer_idx_; + size_t local_conv_dim_; + size_t full_qk_dim_; + size_t full_v_dim_; + size_t local_qk_dim_; + size_t local_v_dim_; + size_t conv_kernel_dim_; + size_t tp_rank_; + size_t tp_size_; +}; class Qwen3NextGatedDeltaNet : public infinicore::nn::Module { public: @@ -14,25 +36,24 @@ class Qwen3NextGatedDeltaNet : public infinicore::nn::Module { infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; private: - std::shared_ptr in_proj_qkv_; - std::shared_ptr in_proj_z_; - std::shared_ptr in_proj_a_; - std::shared_ptr in_proj_b_; - INFINICORE_NN_PARAMETER(conv1d_weight); + std::shared_ptr in_proj_qkv_; + std::shared_ptr in_proj_z_; + std::shared_ptr in_proj_a_; + std::shared_ptr in_proj_b_; + std::shared_ptr conv1d_; + INFINICORE_NN_PARAMETER(dt_bias); INFINICORE_NN_PARAMETER(A_log); - INFINICORE_NN_MODULE(Qwen3Next_Fake_RMSNormGated, norm); - std::shared_ptr out_proj_; + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + std::shared_ptr out_proj_; size_t layer_idx_; - size_t linear_num_value_heads_; - size_t linear_num_key_heads_; - size_t linear_key_head_dim_; - size_t linear_value_head_dim_; - size_t key_dim_; - size_t value_dim_; - size_t conv_dim_; - size_t conv_state_len_; + size_t local_num_value_heads_; + size_t local_num_key_heads_; + size_t key_head_dim_; + size_t value_head_dim_; + size_t local_key_dim_; + size_t local_value_dim_; }; } // namespace infinilm::models::qwen3_next diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index ec05f8620..bd6673295 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -661,9 +661,11 @@ def _remap_mamba(state_dict, config=None): # Model type → remap function mapping -def _remap_qwen3_5(state_dict, config=None): +def _remap_qwen3_5(state_dict, config): """Apply Qwen3.5-specific load-time weight fixes.""" state_dict = drop_keys(state_dict, ["mtp."]) + llm_config = config["text_config"] + key_dim = llm_config["linear_key_head_dim"] * llm_config["linear_num_key_heads"] norm_weight_suffixes = ( "input_layernorm.weight", @@ -672,9 +674,26 @@ def _remap_qwen3_5(state_dict, config=None): "self_attn.k_norm.weight", ) + to_drop = [] + to_add = {} for key, tensor in state_dict.items(): if key == "model.norm.weight" or key.endswith(norm_weight_suffixes): state_dict[key] = tensor + torch.ones_like(tensor) + elif key.endswith("linear_attn.in_proj_qkv.weight"): + prefix = key[: -len("in_proj_qkv.weight")] + to_add[prefix + "in_proj_q.weight"] = state_dict[key][ + :key_dim, : + ].contiguous() + to_add[prefix + "in_proj_k.weight"] = state_dict[key][ + key_dim : key_dim * 2, : + ].contiguous() + to_add[prefix + "in_proj_v.weight"] = state_dict[key][ + key_dim * 2 :, : + ].contiguous() + to_drop.append(key) + + state_dict = drop_keys(state_dict, to_drop) + state_dict.update(to_add) return state_dict From 379b3fab41967cc2c45f12b37801394d443b504c Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 26 Jun 2026 09:04:42 +0000 Subject: [PATCH 3/4] issue/446 fix bench.py --- examples/bench.py | 3 +++ python/infinilm/infer_engine.py | 35 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/examples/bench.py b/examples/bench.py index 2e3d4624d..4d3e685b3 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -51,6 +51,9 @@ def _normalize_config(config, model_type): """ normalized = dict(config) + if "text_config" in normalized: + normalized = normalized["text_config"] + key_map = _CONFIG_KEY_MAP.get(model_type) if not key_map: diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index c6d651134..80ba7b60a 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -95,6 +95,13 @@ def __init__( self.use_cache = False self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig) + llm_config = self.hf_config.get("text_config", self.hf_config) + layer_types = llm_config.get("layer_types") or [] + self.has_mamba_cache = "linear_attention" in layer_types or ( + "linear_conv_kernel_dim" in llm_config + and "linear_num_key_heads" in llm_config + and "linear_num_value_heads" in llm_config + ) @property def dtype(self): @@ -261,6 +268,20 @@ def generate( block_tables = None max_blocks_per_batch = 0 + mamba_state_indices = None + if self.has_mamba_cache: + if not self.enable_paged_attn: + raise RuntimeError( + "Low-level generate for mamba-cache models currently requires paged attention" + ) + mamba_pool_size = max(2, self.get_cache_config().num_blocks() // 4) + if batch_size > mamba_pool_size - 1: + raise RuntimeError( + f"Batch size {batch_size} exceeds available mamba cache rows " + f"{mamba_pool_size - 1}" + ) + mamba_state_indices = list(range(1, batch_size + 1)) + if self.enable_paged_attn: paged_block_size = self.get_cache_config().block_size() max_blocks_per_batch = ( @@ -340,6 +361,18 @@ def generate( [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32 ) + mamba_init_state_indices = None + mamba_final_state_indices = None + if mamba_state_indices is not None: + mamba_init_state_indices = infinicore.from_list( + [0] * batch_size if iter == 0 else mamba_state_indices, + dtype=infinicore.int32, + ) + mamba_final_state_indices = infinicore.from_list( + mamba_state_indices, + dtype=infinicore.int32, + ) + output_id = self( input_ids=input_ids, pixel_values=pixel_values if iter == 0 else None, @@ -350,6 +383,8 @@ def generate( cu_seqlens=cu_seqlens, block_tables=block_tables, slot_mapping=slot_mapping, + mamba_init_state_indices=mamba_init_state_indices, + mamba_final_state_indices=mamba_final_state_indices, image_bound=image_bound if iter == 0 else None, tgt_sizes=tgt_sizes if iter == 0 else None, temperature=generation_config.temperature, From fb86c4e5e69ff35fabff2c462180739043a82d54 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 26 Jun 2026 09:05:34 +0000 Subject: [PATCH 4/4] issue/446 support serving with naive mamba cache manager --- .../qwen3_next_allocate_kv_cache_tensors.cpp | 2 +- python/infinilm/llm/cache_manager.py | 38 +++++++++ python/infinilm/llm/llm.py | 14 ++++ python/infinilm/llm/request.py | 7 +- python/infinilm/llm/scheduler.py | 81 +++++++++++++++---- .../infinilm/processors/qwen3_5_processor.py | 37 ++++++--- 6 files changed, 145 insertions(+), 34 deletions(-) diff --git a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp index b077732c8..3aaf898cb 100644 --- a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp +++ b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp @@ -121,7 +121,7 @@ AllocatedHybridCache qwen3_next_allocate_cache_tensors( if (nullptr == paged_kv_cache_config) { throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid paged kv cache config type"); } - const size_t mamba_pool_size = std::max(1, paged_kv_cache_config->num_blocks() / 4); + const size_t mamba_pool_size = std::max(2, paged_kv_cache_config->num_blocks() / 4); for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { const std::string &layer_type = layer_types[layer_idx]; diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py index 8e8007cdd..3e3ba9fdd 100644 --- a/python/infinilm/llm/cache_manager.py +++ b/python/infinilm/llm/cache_manager.py @@ -35,6 +35,44 @@ def free(self) -> None: self.token_ids = [] +class MambaCacheManager: + """Manages request ownership of mamba state cache rows. + + Row 0 is reserved as the permanent zero state. Request-owned rows are + allocated from [1, num_blocks). + """ + + ZERO_STATE_INDEX = 0 + + def __init__(self, num_blocks: int): + if num_blocks < 2: + raise ValueError("mamba cache pool size must be at least 2") + self.num_blocks = num_blocks + self.free_block_ids: deque = deque(range(1, num_blocks)) + self.used_block_ids: Set[int] = set() + + def can_allocate(self) -> bool: + return bool(self.free_block_ids) + + def allocate(self) -> int | None: + if not self.free_block_ids: + return None + block_id = self.free_block_ids.popleft() + self.used_block_ids.add(block_id) + return block_id + + def free(self, block_id: int | None) -> None: + if block_id is None or block_id == self.ZERO_STATE_INDEX: + return + if block_id not in self.used_block_ids: + return + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def get_num_free_blocks(self) -> int: + return len(self.free_block_ids) + + class BlockManager: """Manages Paged KV Cache allocation with prefix caching support. diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index 81695c79b..5b078eea9 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -74,6 +74,13 @@ def __init__(self, config: EngineConfig): max_position_embeddings = llm_config.get( "max_position_embeddings", config.max_cache_len ) + layer_types = llm_config.get("layer_types") or [] + has_mamba_cache = "linear_attention" in layer_types or ( + "linear_conv_kernel_dim" in llm_config + and "linear_num_key_heads" in llm_config + and "linear_num_value_heads" in llm_config + ) + num_mamba_cache_blocks = max(2, config.num_blocks // 4) max_num_batched_tokens = int( os.getenv("INFINILM_MAX_NUM_BATCHED_TOKENS", max_position_embeddings) @@ -86,8 +93,15 @@ def __init__(self, config: EngineConfig): block_size=config.block_size, max_num_batched_tokens=max_num_batched_tokens, connector=connector, + has_mamba_cache=has_mamba_cache, + num_mamba_cache_blocks=num_mamba_cache_blocks, ) logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}") + if has_mamba_cache: + logger.info( + "Using Mamba cache with num_blocks=%s, zero_state_index=0", + num_mamba_cache_blocks, + ) else: raise ValueError(f"Unsupported cache_type: {config.cache_type}") diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 16f12efe5..4576cd8cc 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -148,9 +148,7 @@ def __init__( # Generation state self.generated_token_ids: List[int] = [] - self.generated_text: str = ( - "" # generated_text == tokenizer.decode(generated_token_ids[:_token_decode_offset]) - ) + self.generated_text: str = "" # generated_text == tokenizer.decode(generated_token_ids[:_token_decode_offset]) self.status: RequestStatus = RequestStatus.WAITING self.finish_reason: Optional[FinishReason] = None @@ -163,6 +161,9 @@ def __init__( self.num_computed_tokens: int = 0 # Total tokens computed (local + remote) self.num_blocks: int = 0 + # Mamba cache management. None means no mamba cache row is currently owned. + self.mamba_cache_index: Optional[int] = None + # PD disaggregation support self.kv_transfer_params: Optional[dict] = ( None # KV transfer parameters from the router diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py index c99c54f59..e24dd00ba 100644 --- a/python/infinilm/llm/scheduler.py +++ b/python/infinilm/llm/scheduler.py @@ -7,7 +7,7 @@ import logging from typing import List, Optional from infinilm.llm.request import RequestStatus, InferenceRequest -from infinilm.llm.cache_manager import BlockManager +from infinilm.llm.cache_manager import BlockManager, MambaCacheManager logger = logging.getLogger(__name__) @@ -42,6 +42,8 @@ def __init__( block_size: int = 256, max_num_batched_tokens: int = 1024, connector=None, + has_mamba_cache: bool = False, + num_mamba_cache_blocks: int | None = None, ): self.waiting_queue = janus.Queue() self.running_queue = janus.Queue() @@ -54,6 +56,12 @@ def __init__( self.remote_kv_requests: dict[str, InferenceRequest] = {} self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) + self.has_mamba_cache = has_mamba_cache + self.mamba_cache_manager = ( + MambaCacheManager(num_mamba_cache_blocks or max(2, num_blocks // 4)) + if has_mamba_cache + else None + ) self.block_size = block_size self.max_num_batched_tokens = max_num_batched_tokens self.connector = connector @@ -106,23 +114,30 @@ def schedule(self) -> Optional[SchedulerOutput]: req_tokens = req.get_input_tokens() if req.num_computed_tokens == 0: - ( - cached_block_table, - num_local_computed_tokens, - blocks_blueprint, - ) = self.cache_manager.get_computed_blocks( - req_tokens, req.get_mm_token_index_mappings() - ) - if self.connector is not None: - ext_tokens, load_kv_async = ( - self.connector.get_num_new_matched_tokens( - req, num_local_computed_tokens - ) - ) - num_external_computed_tokens = ext_tokens - else: + if self.has_mamba_cache: + cached_block_table = [] + num_local_computed_tokens = 0 + blocks_blueprint = [] load_kv_async = False num_external_computed_tokens = 0 + else: + ( + cached_block_table, + num_local_computed_tokens, + blocks_blueprint, + ) = self.cache_manager.get_computed_blocks( + req_tokens, req.get_mm_token_index_mappings() + ) + if self.connector is not None: + ext_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + req, num_local_computed_tokens + ) + ) + num_external_computed_tokens = ext_tokens + else: + load_kv_async = False + num_external_computed_tokens = 0 num_computed_tokens = ( num_local_computed_tokens + num_external_computed_tokens @@ -181,6 +196,17 @@ def schedule(self) -> Optional[SchedulerOutput]: deferred_requests.append(req) break + if self.has_mamba_cache and req.mamba_cache_index is None: + req.mamba_cache_index = self.mamba_cache_manager.allocate() + if req.mamba_cache_index is None: + self.cache_manager.free_blocks(req_blocks) + logger.warning( + "Insufficient mamba cache rows for request %s, deferring.", + req.request_id, + ) + deferred_requests.append(req) + break + req.block_table = req_blocks req.slot_mapping = slot_mapping req.num_blocks = len(req_blocks) @@ -373,6 +399,9 @@ def complete_requests(self, requests: List[InferenceRequest]): self.cache_manager.free_blocks(req.block_table) elif req.block_table and delay_free_blocks: self.pending_free_blocks[req.request_id] = list(req.block_table) + if self.mamba_cache_manager is not None: + self.mamba_cache_manager.free(req.mamba_cache_index) + req.mamba_cache_index = None if req.status == RequestStatus.CANCELED: logger.info( @@ -396,6 +425,13 @@ def can_accept_request( num_local_computed_tokens: int, current_prefill_extra_blocks: int = 0, ) -> bool: + if ( + self.mamba_cache_manager is not None + and request.mamba_cache_index is None + and not self.mamba_cache_manager.can_allocate() + ): + return False + total_required_blocks = 0 # Calculate blocks needed for running requests @@ -484,7 +520,7 @@ def update_from_output(self, model_output): def get_cache_stats(self) -> dict: """Get cache statistics.""" - return { + stats = { "num_blocks": self.cache_manager.num_blocks, "block_size": self.cache_manager.block_size, "num_free_blocks": self.cache_manager.get_num_free_blocks(), @@ -492,3 +528,14 @@ def get_cache_stats(self) -> dict: "num_pending_blocks": len(self.cache_manager.pending_block_ids), "num_used_blocks": len(self.cache_manager.used_block_ids), } + if self.mamba_cache_manager is not None: + stats.update( + { + "num_mamba_cache_blocks": self.mamba_cache_manager.num_blocks, + "num_free_mamba_cache_blocks": self.mamba_cache_manager.get_num_free_blocks(), + "num_used_mamba_cache_blocks": len( + self.mamba_cache_manager.used_block_ids + ), + } + ) + return stats diff --git a/python/infinilm/processors/qwen3_5_processor.py b/python/infinilm/processors/qwen3_5_processor.py index d59d59141..168020e9f 100644 --- a/python/infinilm/processors/qwen3_5_processor.py +++ b/python/infinilm/processors/qwen3_5_processor.py @@ -22,8 +22,6 @@ def __init__(self, model_dir_path: str): model_dir_path, trust_remote_code=True ) - - @override def __call__( self, @@ -69,19 +67,25 @@ def apply_chat_template( for item in content: item_type = item.get("type") if item_type == "text": - normalized_content.append({"type": "text", "text": item.get("text", "")}) + normalized_content.append( + {"type": "text", "text": item.get("text", "")} + ) elif item_type == "image_url": normalized_content.append({"type": "image"}) elif item_type == "video_url": normalized_content.append({"type": "video"}) else: - raise NotImplementedError(f"Unsupported Qwen3.5 content type: {item_type}") + raise NotImplementedError( + f"Unsupported Qwen3.5 content type: {item_type}" + ) normalized_conversation.append( {"role": message.get("role", "user"), "content": normalized_content} ) - template_owner = self.processor if self.processor is not None else self.tokenizer + template_owner = ( + self.processor if self.processor is not None else self.tokenizer + ) return template_owner.apply_chat_template( conversation=normalized_conversation, add_generation_prompt=add_generation_prompt, @@ -112,12 +116,18 @@ def build_model_inputs( "scheduler_output must be an instance of SchedulerOutput or StaticSchedulerOutput" ) - # TODO(qwen3_5): The scheduler should own stable mamba cache ids. For now - # use a per-forward arange so the C++ model input and mamba metadata path - # can be exercised without encoding cache policy in the processor. - num_requests = len(scheduler_output.scheduled_requests) - init_indices = list(range(num_requests)) - final_indices = list(range(num_requests)) + init_indices = [] + final_indices = [] + for req in scheduler_output.scheduled_requests: + if req.mamba_cache_index is None: + raise RuntimeError( + f"Request {req.request_id} has no assigned mamba cache index" + ) + if scheduler_output.is_prefill: + init_indices.append(0) + else: + init_indices.append(req.mamba_cache_index) + final_indices.append(req.mamba_cache_index) import infinicore @@ -155,7 +165,9 @@ def _append_qwen35_mm_inputs( if pixel_values: pixel_values = [ - infinicore.from_torch(t if isinstance(t, torch.Tensor) else torch.as_tensor(t)) + infinicore.from_torch( + t if isinstance(t, torch.Tensor) else torch.as_tensor(t) + ) for t in pixel_values ] model_inputs["pixel_values"] = pixel_values @@ -166,6 +178,5 @@ def get_mm_token_index_list( self, prompt_token_ids, image_ids=None, video_ids=None, audio_ids=None, **kwargs ): mm_token_index_list = [] - return mm_token_index_list