Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 336 additions & 0 deletions csrc/models/rwkv/rwkv5_for_causal_lm.cpp

Large diffs are not rendered by default.

145 changes: 145 additions & 0 deletions csrc/models/rwkv/rwkv5_for_causal_lm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#pragma once

#include "../../cache/kv_cache.hpp"
#include "../../config/model_config.hpp"
#include "../../layers/linear/linear.hpp"
#include "../infinilm_model.hpp"
#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/layer_norm.hpp"
#include "infinicore/nn/parameter.hpp"
#include "infinicore/tensor.hpp"

#include <memory>
#include <optional>
#include <vector>

namespace infinilm::models::rwkv {

// Attention sub-module of a single RWKV-5 layer.
// Only the declaration is visible in this header; the definitions are
// presumably in rwkv5_for_causal_lm.cpp — confirm there before relying on
// any behavior described with "presumably" below.
class Rwkv5SelfAttention : public infinicore::nn::Module {
public:
// Constructs the module for layer `layer_idx` from `config` on `device`.
Rwkv5SelfAttention(std::shared_ptr<infinilm::config::ModelConfig> config,
size_t layer_idx,
const infinicore::Device &device);

// Runs the module on `hidden_states` and returns the resulting tensor.
// `attn_x_state` and `wkv_state` are non-const references although the method
// itself is const — presumably they carry recurrent state between calls and
// are updated here; TODO confirm against the implementation.
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
infinicore::Tensor &attn_x_state,
infinicore::Tensor &wkv_state) const;

private:
// Helper taking the current hidden states and a mutable state tensor.
// NOTE(review): the name suggests an RWKV-style token shift — confirm.
infinicore::Tensor shifted_hidden_(const infinicore::Tensor &hidden_states,
infinicore::Tensor &state) const;
// Helper that maps a tensor to a tensor.
// NOTE(review): the name suggests a group normalization — confirm.
infinicore::Tensor group_norm_(const infinicore::Tensor &x) const;

// Index of the layer this module belongs to, plus size values.
// NOTE(review): presumably all are derived from `config` in the constructor
// (not shown here) — confirm.
size_t layer_idx_;
size_t hidden_size_;
size_t attention_hidden_size_;
size_t head_size_;
size_t num_heads_;
size_t head_size_divisor_;

// Parameters declared through the project's INFINICORE_NN_PARAMETER macro.
// NOTE(review): names presumably mirror the checkpoint's weight names
// (including the unusual `time_faaaa`) — do not rename without checking the
// weight loader.
INFINICORE_NN_PARAMETER(time_decay);
INFINICORE_NN_PARAMETER(time_faaaa);
INFINICORE_NN_PARAMETER(time_mix_gate);
INFINICORE_NN_PARAMETER(time_mix_key);
INFINICORE_NN_PARAMETER(time_mix_value);
INFINICORE_NN_PARAMETER(time_mix_receptance);
// Sub-modules declared through INFINICORE_NN_MODULE: five ReplicatedLinear
// layers and one LayerNorm.
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, key);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, value);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, receptance);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, gate);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, output);
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_x);
};

// Feed-forward sub-module of a single RWKV-5 layer.
// Only the declaration is visible in this header; confirm behavioral details
// against the implementation file.
class Rwkv5FeedForward : public infinicore::nn::Module {
public:
// Constructs the module for layer `layer_idx` from `config` on `device`.
Rwkv5FeedForward(std::shared_ptr<infinilm::config::ModelConfig> config,
size_t layer_idx,
const infinicore::Device &device);

// Runs the module on `hidden_states` and returns the resulting tensor.
// `ffn_x_state` is a non-const reference although the method is const —
// presumably recurrent state that is updated by the call; TODO confirm.
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
infinicore::Tensor &ffn_x_state) const;

private:
// Helper taking the current hidden states and a mutable state tensor.
// NOTE(review): the name suggests an RWKV-style token shift — confirm.
infinicore::Tensor shifted_hidden_(const infinicore::Tensor &hidden_states,
infinicore::Tensor &state) const;

// Layer index and size values.
// NOTE(review): presumably derived from `config` in the constructor — confirm.
size_t layer_idx_;
size_t hidden_size_;
size_t intermediate_size_;

// Two parameters (INFINICORE_NN_PARAMETER) and three ReplicatedLinear
// sub-modules (INFINICORE_NN_MODULE).
INFINICORE_NN_PARAMETER(time_mix_key);
INFINICORE_NN_PARAMETER(time_mix_receptance);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, key);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, receptance);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, value);
};

// One RWKV-5 layer: two LayerNorm modules, an Rwkv5SelfAttention and an
// Rwkv5FeedForward, as declared below. How they are combined is defined in the
// implementation file, not in this header.
class Rwkv5Block : public infinicore::nn::Module {
public:
// Constructs the block for layer `layer_idx` from `config` on `device`.
Rwkv5Block(std::shared_ptr<infinilm::config::ModelConfig> config,
size_t layer_idx,
const infinicore::Device &device);

// Runs the block on `hidden_states` and returns the resulting tensor.
// The three state tensors are non-const references although the method is
// const — presumably they are forwarded to the attention and feed-forward
// sub-modules and updated there; TODO confirm.
infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
infinicore::Tensor &attn_x_state,
infinicore::Tensor &wkv_state,
infinicore::Tensor &ffn_x_state) const;

private:
size_t layer_idx_;
// NOTE(review): presumably the layer interval used for activation
// rescaling — confirm against the implementation.
size_t rescale_every_;
// Held as a plain shared_ptr rather than declared via INFINICORE_NN_MODULE
// like the members below.
// NOTE(review): presumably optional and only set for some layers (may be
// null) — confirm before dereferencing.
std::shared_ptr<infinicore::nn::LayerNorm> pre_ln_;
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln1);
INFINICORE_NN_MODULE(Rwkv5SelfAttention, attention);
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln2);
INFINICORE_NN_MODULE(Rwkv5FeedForward, feed_forward);
};

// RWKV-5 backbone: an Embedding, a vector of Rwkv5Block and a final LayerNorm,
// as declared below. Definitions are in the implementation file.
class Rwkv5Model : public infinicore::nn::Module {
public:
// Constructs the backbone from `config` on `device`.
Rwkv5Model(std::shared_ptr<infinilm::config::ModelConfig> config,
const infinicore::Device &device);

// Runs the backbone on `input` and returns a tensor.
// The three state tensors are non-const references although the method is
// const — presumably they hold the recurrent state for all layers and are
// updated by the call; TODO confirm layout and update rules in the .cpp.
infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input,
infinicore::Tensor &attn_x_state,
infinicore::Tensor &wkv_state,
infinicore::Tensor &ffn_x_state) const;

private:
// NOTE(review): presumably the layer interval used for activation
// rescaling — confirm against the implementation.
size_t rescale_every_;

INFINICORE_NN_MODULE(infinicore::nn::Embedding, embeddings);
INFINICORE_NN_MODULE_VEC(Rwkv5Block, blocks);
INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_out);
};

// RWKV-5 model exposed through the InfinilmModel interface: wraps an
// Rwkv5Model (`rwkv`) and a ReplicatedLinear (`head`) and overrides
// `forward` and `reset_cache`.
class Rwkv5ForCausalLM : public infinilm::InfinilmModel {
public:
// Constructs the model from `config` on `device`.
Rwkv5ForCausalLM(std::shared_ptr<infinilm::config::ModelConfig> config,
const infinicore::Device &device);

// InfinilmModel override. Declared const; the state tensors below are
// `mutable`, so a const call may still modify them.
Output forward(const Input &input) const override;
// InfinilmModel override.
// NOTE(review): presumably resets the recurrent state tensors held by this
// object — confirm how `cache_config` is used in the implementation.
void reset_cache(const cache::CacheConfig *cache_config) override;

private:
// Const helper that can still modify the mutable state members.
// NOTE(review): presumably (re)allocates the state tensors when `batch_size`
// differs from `state_batch_size_` — confirm.
void ensure_state_(size_t batch_size) const;

// Size values, device and data type.
// NOTE(review): presumably captured from `config`/`device` in the
// constructor — confirm.
size_t num_hidden_layers_;
size_t hidden_size_;
size_t num_heads_;
size_t head_size_;
infinicore::Device device_;
infinicore::DataType dtype_;
// State owned by the model object itself. Marked `mutable` so that const
// member functions (`forward`, `ensure_state_`) can modify it.
// NOTE(review): nothing in this declaration synchronizes access to these
// members — confirm concurrent `forward` calls on one instance are not expected.
mutable size_t state_batch_size_ = 0;
mutable infinicore::Tensor attn_x_state_;
mutable infinicore::Tensor wkv_state_;
mutable infinicore::Tensor ffn_x_state_;

INFINICORE_NN_MODULE(Rwkv5Model, rwkv);
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, head);
};

// Takes a shared ModelConfig and returns a shared ModelConfig.
// NOTE(review): presumably fills in or adapts RWKV-5 specific fields; whether
// the input object is modified in place or a new one is returned is not
// visible from this declaration — confirm in the implementation.
std::shared_ptr<infinilm::config::ModelConfig> create_rwkv5_model_config(std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::rwkv
54 changes: 46 additions & 8 deletions examples/bench.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.infer_engine import (
GenerationConfig,
InferEngine,
read_hf_generation_config,
)
from infinilm.base_config import BaseConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.processors import AutoInfinilmProcessor
Expand Down Expand Up @@ -38,6 +42,10 @@
"num_key_value_heads": "num_attention_heads",
"head_dim": lambda cfg: cfg["hidden_size"] // cfg["num_attention_heads"],
},
"rwkv5": {
"num_key_value_heads": "num_attention_heads",
"head_dim": "head_size",
},
}


Expand Down Expand Up @@ -88,6 +96,23 @@ def read_json_file(file_path):
return json.load(file)


def resolve_generation_defaults(model_path, top_k, top_p, temperature):
    """Resolve sampling parameters, filling unset ones from the model's generation config.

    For each of ``top_k``, ``top_p`` and ``temperature`` the first value that
    is not ``None`` wins, in this order:

    1. the value passed by the caller,
    2. the entry of the same name in ``read_hf_generation_config(model_path)``,
    3. a built-in fallback (``1`` for ``top_k``, ``1.0`` for the other two).

    The chosen value is converted with ``int`` (``top_k``) or ``float``
    (``top_p``, ``temperature``).

    Returns:
        A ``(top_k, top_p, temperature)`` tuple.
    """
    hf_defaults = read_hf_generation_config(model_path)

    # (caller-supplied value, generation-config key, built-in fallback, coercion)
    specs = (
        (top_k, "top_k", 1, int),
        (top_p, "top_p", 1.0, float),
        (temperature, "temperature", 1.0, float),
    )

    resolved = []
    for explicit, key, builtin, coerce in specs:
        chosen = explicit
        if chosen is None:
            # Only consult the generation config when the caller gave nothing.
            chosen = hf_defaults.get(key)
            if chosen is None:
                chosen = builtin
        resolved.append(coerce(chosen))
    return tuple(resolved)


def get_test_cases(
model_path: str,
batch_size_list: list[int],
Expand Down Expand Up @@ -146,13 +171,18 @@ def get_test_cases(
return case_dict


prompt_path = (
"examples/bench_prompt.md"
if os.path.isfile("examples/bench_prompt.md")
else "InfiniLM/examples/bench_prompt.md"
)
with open(prompt_path, "r") as f:
prompt = f.read()
def read_bench_prompt():
    """Return the full text of the benchmark prompt file.

    Uses ``examples/bench_prompt.md`` (relative to the current working
    directory) when it is an existing regular file; otherwise falls back to
    ``InfiniLM/examples/bench_prompt.md``. Errors from ``open`` propagate.
    """
    candidate = "examples/bench_prompt.md"
    if not os.path.isfile(candidate):
        candidate = "InfiniLM/examples/bench_prompt.md"
    with open(candidate, "r") as handle:
        contents = handle.read()
    return contents


def has_prompt_override():
    """Report whether ``--prompt`` was given on the command line.

    Scans ``sys.argv[1:]`` and returns ``True`` as soon as an argument is
    exactly ``--prompt`` or starts with ``--prompt=``; returns ``False`` when
    no argument matches.
    """
    for token in sys.argv[1:]:
        if token == "--prompt" or token.startswith("--prompt="):
            return True
    return False


def repeat_prompt(input_ids: list[int], target_length: int):
Expand All @@ -174,8 +204,10 @@ def __init__(
cache_config=None,
enable_graph=False,
attn_backend="default",
prompt=None,
) -> None:
model_path = os.path.expanduser(model_path)
prompt = read_bench_prompt() if prompt is None else prompt
# ---------------------------------------------------------------------------- #
# 创建模型,
# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -286,6 +318,9 @@ def run(
enable_paged_attn = cfg.enable_paged_attn
enable_graph = cfg.enable_graph
attn_backend = cfg.attn
cfg.top_k, cfg.top_p, cfg.temperature = resolve_generation_defaults(
model_path, cfg.top_k, cfg.top_p, cfg.temperature
)

if isinstance(batch_size, int):
batch_size = [batch_size]
Expand Down Expand Up @@ -319,6 +354,8 @@ def run(
if enable_paged_attn and attn_backend == "default":
attn_backend = "paged-attn"

prompt = cfg.prompt if has_prompt_override() else read_bench_prompt()

test = TestModel(
model_path,
infini_device=infini_device,
Expand All @@ -327,6 +364,7 @@ def run(
cache_config=cache_config,
enable_graph=enable_graph,
attn_backend=attn_backend,
prompt=prompt,
)

# ---------------------------------------------------------------------------- #
Expand Down
6 changes: 3 additions & 3 deletions examples/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def test(
tp=1,
enable_paged_attn=False,
enable_graph=False,
top_k=1,
top_p=1.0,
temperature=1.0,
top_k=None,
top_p=None,
temperature=None,
attn_backend="default",
use_mla=False,
image_path=None,
Expand Down
6 changes: 3 additions & 3 deletions python/infinilm/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def _add_common_args(self):
self.parser.add_argument(
"--prompt", type=str, default="How are you", help="default prompt text"
)
self.parser.add_argument("--top-k", type=int, default=1)
self.parser.add_argument("--top-p", type=float, default=1.0)
self.parser.add_argument("--temperature", type=float, default=1.0)
self.parser.add_argument("--top-k", type=int, default=None)
self.parser.add_argument("--top-p", type=float, default=None)
self.parser.add_argument("--temperature", type=float, default=None)

# --- debug ---
self.parser.add_argument("--warmup", action="store_true")
Expand Down
6 changes: 3 additions & 3 deletions python/infinilm/config/engine_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class EngineConfig:
num_blocks: int = 512
block_size: int = 256
max_cache_len: int = 4096
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
enable_graph: bool = False
attn_backend: str = "default"
use_mla: bool = False
Expand Down
12 changes: 11 additions & 1 deletion python/infinilm/infer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def read_hf_config(model_path):
and config_dict.get("dtype") is None
):
config_dict["torch_dtype"] = "float32"
if config_dict.get("model_type") == "rwkv5":
if config_dict.get("torch_dtype") is None and config_dict.get("dtype") is None:
config_dict["torch_dtype"] = "bfloat16"
config_dict.setdefault(
"max_position_embeddings", config_dict.get("context_length", 4096)
)
if "model_type" not in config_dict:
raise ValueError(
f"`model_type` is not specified in the config file `{config_path}`."
Expand Down Expand Up @@ -239,6 +245,10 @@ def generate(
if _measure_and_log_time:
time_measurements = []

is_rwkv = self.model_type == "rwkv5"
if is_rwkv:
self.reset_cache(self.get_cache_config())

block_tables = None
max_blocks_per_batch = 0
if self.enable_paged_attn:
Expand All @@ -263,7 +273,7 @@ def generate(

batch_size, seq_len = input_ids.shape[:2]

if self.enable_paged_attn:
if self.enable_paged_attn and not is_rwkv:
input_ids = input_ids.view([1, batch_size * seq_len])
position_ids = infinicore.from_list(
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
Expand Down
Loading