Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/infinicore/ops/paged_attention/paged_attention_infiniops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,27 @@

#include "base/paged_attention_infinilm.h"

#include <cstddef>
#include <optional>

namespace infinicore::op::paged_attention_impl::infiniops {
namespace {
using TensorMeta = ::infinicore::op::infiniops::TensorMeta;

constexpr std::size_t kMaxPagedAttentionSplits = 8;

// Computes the byte size of the scratch buffer handed to the paged-attention
// kernel for query tensor `q`:
//   kMaxPagedAttentionSplits * q.size(0) * q.size(1) * (q.size(2) + 2) floats.
// NOTE(review): the "+ 2" presumably reserves two extra floats per row for
// per-split reduction state -- confirm against the kernel's workspace layout.
std::size_t WorkspaceSizeInBytes(const Tensor &q) {
    const auto extent0 = static_cast<std::size_t>(q->size(0));
    const auto extent1 = static_cast<std::size_t>(q->size(1));
    // The "+ 2" is applied before the cast, exactly as the element count requires.
    const auto padded_extent2 = static_cast<std::size_t>(q->size(2) + 2);

    const std::size_t float_count = kMaxPagedAttentionSplits * extent0 * extent1 * padded_extent2;
    return float_count * sizeof(float);
}

// State built once by plan() and read back by run() through an opaque pointer.
// The struct is aggregate-initialised positionally, so the member order below
// must stay in sync with the initializer list in plan().
struct PlannedMeta {
    // Metadata captured from each operand at planning time.
    TensorMeta out, q, k_cache, v_cache, block_tables, cache_lens;
    // Present only when the caller supplied ALiBi slopes.
    std::optional<TensorMeta> alibi_slopes;
    // `workspace` is the scratch buffer allocated in plan(); it comes first
    // because plan() passes the workspace tensor ahead of the operand tensors.
    // The stale pre-change declaration that omitted `workspace` is removed: it
    // redeclared every operand member and made the struct ill-formed.
    graph::GraphTensor workspace, out_tensor, q_tensor, k_cache_tensor, v_cache_tensor, block_tables_tensor, cache_lens_tensor;
    // Graph handle for the optional ALiBi slopes tensor.
    std::optional<graph::GraphTensor> alibi_slopes_tensor;
    // Scale value forwarded unchanged from plan().
    float scale;
};
Expand All @@ -35,6 +47,7 @@ void *plan(Tensor out,
return new PlannedMeta{
TensorMeta(out), TensorMeta(q), TensorMeta(k_cache), TensorMeta(v_cache), TensorMeta(block_tables), TensorMeta(cache_lens),
alibi_slopes ? std::optional<TensorMeta>{TensorMeta(*alibi_slopes)} : std::nullopt,
graph::GraphTensor(Tensor::empty({WorkspaceSizeInBytes(q)}, DataType::U8, out->device())),
graph::GraphTensor(out), graph::GraphTensor(q), graph::GraphTensor(k_cache), graph::GraphTensor(v_cache), graph::GraphTensor(block_tables), graph::GraphTensor(cache_lens),
alibi_slopes ? std::optional<graph::GraphTensor>{graph::GraphTensor(*alibi_slopes)} : std::nullopt,
scale};
Expand All @@ -44,6 +57,8 @@ void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
infini::ops::Handle handle;
handle.set_stream(context::getStream());
handle.set_workspace(planned->workspace->data());
handle.set_workspace_size_in_bytes(planned->workspace->numel());
infini::ops::Config config;
infini::ops::PagedAttentionInfinilm::Call(
handle,
Expand Down
23 changes: 23 additions & 0 deletions xmake.lua
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,25 @@ local function get_standalone_infinirt_root()
return nil
end

--- Convert the `cuda_arch` config value into a CMake architecture list.
-- Entries may be separated by ';' or ','. Only entries of the form `sm_<digits>`
-- with an optional trailing `a` are kept, reduced to the part after `sm_`;
-- anything else is ignored.
-- @treturn string|nil the kept entries joined with ';', or nil when the option
--   is unset/empty or no entry matched
local function get_infiniops_cuda_architectures()
    local requested = get_config("cuda_arch")
    if not requested or requested == "" then
        return nil
    end

    -- Normalise to a single separator; only gsub's first result (the string) is kept.
    local normalized = requested:gsub(";", ",")

    local selected = {}
    for _, entry in ipairs(normalized:split(",")) do
        local stripped = entry:trim():match("^sm_(%d+a?)$")
        if stripped then
            selected[#selected + 1] = stripped
        end
    end

    if #selected > 0 then
        return table.concat(selected, ";")
    end
    return nil
end

local infiniops_external_built = false

local function build_infiniops_external(xmake_os)
Expand All @@ -318,6 +337,10 @@ local function build_infiniops_external(xmake_os)
if infiniops_ops and #infiniops_ops > 0 then
table.insert(cmake_config_args, "-DINFINI_OPS_OPS=" .. infiniops_ops)
end
local cmake_cuda_architectures = get_infiniops_cuda_architectures()
if cmake_cuda_architectures and cmake_cuda_architectures ~= "" then
table.insert(cmake_config_args, "-DCMAKE_CUDA_ARCHITECTURES=" .. cmake_cuda_architectures)
end
if infinirt_root and infinirt_root ~= "" then
table.insert(cmake_config_args, "-DINFINI_RT_ROOT=" .. infinirt_root)
end
Expand Down
Loading