diff --git a/3rdparty/tvm b/3rdparty/tvm index 0be33607c2..8435b89211 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0be33607c200d65964af75641e7420cfd6d1c6c4 +Subproject commit 8435b8921105dbf8d53c1e0e66b0564bdb7b2253 diff --git a/src/backend/rocm/codegen/codegen_hip.cc b/src/backend/rocm/codegen/codegen_hip.cc index 6d288a4db5..ecb12248e6 100644 --- a/src/backend/rocm/codegen/codegen_hip.cc +++ b/src/backend/rocm/codegen/codegen_hip.cc @@ -56,9 +56,12 @@ std::optional GetAccessPtrElementType(const PrimExpr &expr) { } int GetTileLangCPAsyncTransferBytes(const CallNode *op) { - ICHECK(op->args.size() == 3 || op->args.size() == 4) - << "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " - "src_access_ptr, num_elems, [predicate])"; + // Accepts ptx_cp_async / ptx_cp_async_lds (3 or 4 args: dst, src, num_elems, + // [predicate]) and ptx_cp_async_lds_rsrc (5 args: dst, src, num_elems, + // rsrc_var, base_var) -- only args[0..2] are read here. + ICHECK(op->args.size() == 3 || op->args.size() == 4 || op->args.size() == 5) + << "tl::ptx_cp_async family expects 3-5 arguments (dst_access_ptr, " + "src_access_ptr, num_elems, ...)"; const auto *num_elems_imm = op->args[2].as(); ICHECK(num_elems_imm) << "tl::ptx_cp_async num_elems must be IntImm, but got " << op->args[2]; @@ -1169,9 +1172,31 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; }; - if (op->op.same_as(builtin::ptx_cp_async())) { - // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, - // args[3] = predicate (optional) + if (op->op.same_as(tl::ptx_make_buffer_resource())) { + // Expression form: emits make_wave_buffer_resource((const void*)(ptr)). + // The enclosing LetStmt visitor recognises this Call and emits `auto x =`. 
+ ICHECK(op->args.size() == 1) + << "ptx_make_buffer_resource expects 1 argument (global_ptr)"; + std::string ptr = this->PrintExpr(op->args[0]); + os << "make_wave_buffer_resource((const void*)(" << ptr << "))"; + } else if (op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { + // args = [dst, src, num_elems, rsrc_var, base_var]. arg 2 is the logical + // element count inherited from the ptx_cp_async_lds call that + // HoistBufferResource rewrote into this rsrc form -- the helper does the + // src/dst width-equality and {4,8,16} validation that the plain + // ptx_cp_async path also relies on. + ICHECK(op->args.size() == 5) << "ptx_cp_async_lds_rsrc expects 5 arguments"; + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + int total_bytes = GetTileLangCPAsyncTransferBytes(op); + std::string size = std::to_string(total_bytes); + std::string rsrc = this->PrintExpr(op->args[3]); + std::string base = this->PrintExpr(op->args[4]); + this->PrintIndent(); + this->stream << "tl::cp_async_gs_lds_with_rsrc<" << size << ">(" << dst + << ", " << src << ", " << rsrc << ", " << base << ");\n"; + } else if (op->op.same_as(builtin::ptx_cp_async())) { + // builtin::ptx_cp_async stores byte width directly in arg 2. ICHECK(op->args.size() == 3 || op->args.size() == 4) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, bytes, [predicate])"; @@ -1189,7 +1214,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << ", " << src << ", " << condition << ");\n"; } - } else if (op->op.same_as(tl::ptx_cp_async())) { + } else if (op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())) { + // Both store logical element count in arg 2; convert to bytes via + // GetTileLangCPAsyncTransferBytes. 
+ // + // tl::ptx_cp_async_lds is normally rewritten to ptx_cp_async_lds_rsrc + // by the HoistBufferResource pass. If a call survives the rewrite + // (e.g. an access_ptr shape _extract_buffer_var can't pattern-match, + // or the pass found nothing to hoist), fall back to the synchronous + // tl::cp_async_gs path here -- correctness is preserved at + // the cost of giving up the buffer_load_dwordx4...lds fast path for + // that particular call. Treat both ops identically in codegen. int total_bytes = GetTileLangCPAsyncTransferBytes(op); std::string dst = this->PrintExpr(op->args[0]); std::string src = this->PrintExpr(op->args[1]); @@ -1207,6 +1243,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { print_extern_call_stmt("tl::cp_async_commit"); } else if (op->op.same_as(builtin::ptx_wait_group())) { int n = Downcast(op->args[0])->value; + // AMDGPU s_waitcnt vmcnt field is 6-bit (max 63); clamp to keep the + // "n"(cnt) immediate constraint in tl::cp_async_wait valid. + if (n > 63) + n = 63; std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; print_extern_call_stmt(func_name, 1); } else if (op->op.same_as(builtin::create_barriers())) { @@ -1693,7 +1733,33 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { } void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { - if (op->attr_key == tl::attr::kLexicalAllocScope) { + if (op->attr_key == "buffer_resource_var") { + // Hoisted resource descriptor from the HoistBufferResource Python pass. 
+ // Emits: auto {rsrc_var} = make_wave_buffer_resource((const + // void*)({buf_var})); + auto rsrc_var = Downcast(op->node); + std::string rsrc_vid = AllocVarID(rsrc_var.get()); + std::string buf_ptr = PrintExpr(op->value); + this->PrintIndent(); + this->stream << "auto " << rsrc_vid + << " = make_wave_buffer_resource((const void*)(" << buf_ptr + << "));\n"; + this->VisitStmt(op->body); + return; + } else if (op->attr_key == "buffer_base_var") { + // Hoisted readfirstlane base address from the HoistBufferResource pass. + // Emits: uint32_t {base_var} = __builtin_amdgcn_readfirstlane( + // (uint32_t)(uintptr_t)({buf_var})); + auto base_var = Downcast(op->node); + std::string base_vid = AllocVarID(base_var.get()); + std::string buf_ptr = PrintExpr(op->value); + this->PrintIndent(); + this->stream << "uint32_t " << base_vid + << " = __builtin_amdgcn_readfirstlane(" + << "(uint32_t)(uintptr_t)(" << buf_ptr << "));\n"; + this->VisitStmt(op->body); + return; + } else if (op->attr_key == tl::attr::kLexicalAllocScope) { PrintIndent(); stream << "{\n"; int scope = BeginScope(); @@ -1728,6 +1794,24 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) { CodeGenC::VisitStmt_(op); } +void CodeGenTileLangHIP::VisitStmt_(const BindNode *op) { + // For Bind(var = ptx_make_buffer_resource(buf)), emit `auto x = ...;` + // instead of the C-typed declaration the base class would produce. The + // return type is int32x4_t and naming it explicitly is brittle across + // backends, so `auto` keeps the template lookup in make_wave_buffer_resource + // responsible for the type. The body that follows the bind is handled by + // the enclosing SeqStmt visitor. 
+ if (auto *call = op->value.as()) { + if (call->op.same_as(tl::ptx_make_buffer_resource())) { + std::string value = PrintExpr(op->value); + PrintIndent(); + stream << "auto " << AllocVarID(op->var.get()) << " = " << value << ";\n"; + return; + } + } + CodeGenC::VisitStmt_(op); +} + void CodeGenTileLangHIP::VisitStmt_(const AllocBufferNode *op) { std::string vid = AllocVarID(op->buffer->data.get()); diff --git a/src/backend/rocm/codegen/codegen_hip.h b/src/backend/rocm/codegen/codegen_hip.h index c5c06cf78e..af3a34fea1 100644 --- a/src/backend/rocm/codegen/codegen_hip.h +++ b/src/backend/rocm/codegen/codegen_hip.h @@ -53,6 +53,7 @@ class CodeGenTileLangHIP final : public CodeGenC { void VisitExpr_(const ShuffleNode *op, std::ostream &os) final; // NOLINT(*) void VisitStmt_(const AllocBufferNode *op) final; void VisitStmt_(const AttrStmtNode *op) final; + void VisitStmt_(const BindNode *op) final; void VisitStmt_(const BufferStoreNode *op) final; // Override this as a work around for __grid_constant__ parameter diff --git a/src/backend/rocm/op/copy.cc b/src/backend/rocm/op/copy.cc index 4ccf537c0b..65f9815133 100644 --- a/src/backend/rocm/op/copy.cc +++ b/src/backend/rocm/op/copy.cc @@ -133,7 +133,9 @@ struct Copy { auto inject_result = InjectPTXAsyncCopy(lowered_loop, /*enable_auto_async_copy=*/true, /*async_without_async_commit_wait=*/ - no_implicit_commit_wait || GetIsAsyncCopy(op)); + no_implicit_commit_wait || GetIsAsyncCopy(op), + /*enable_buffer_load_lds=*/ + TargetIsGfx950(T.target)); Stmt cp_async_loop = inject_result.stmt; if (!inject_result.injected_ptx_async_copy) { DLOG(WARNING) << "cp.async rewrite miss for copy src=" << op.src->name diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 5ff74cf749..fcaca6b069 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -457,7 +457,10 @@ static Layout MakeQuarterBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle 
= xor2x2(c, FloorDiv(s, 4)); PrimExpr index = vec + (c_swizzle + s * 2) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeQuarterBankSwizzleLayout(const Buffer &buffer) { @@ -486,7 +489,10 @@ static Layout MakeHalfBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor4x4(c, FloorDiv(s, 2)); PrimExpr index = vec + (c_swizzle + s * 4) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeHalfBankSwizzleLayout(const Buffer &buffer) { @@ -515,7 +521,10 @@ static Layout MakeFullBankSwizzleLayout2D(int stride, int continuous, PrimExpr vec = FloorMod(j, vector_size); PrimExpr c_swizzle = xor8x8(c, s); PrimExpr index = vec + (c_swizzle + s * 8) * vector_size; - return Layout(Array{stride, continuous}, {tc, ts, index}); + PrimExpr swizzle_delta = (c_swizzle - c) * vector_size; + Layout result(Array{stride, continuous}, {tc, ts, index}); + const_cast(result.get())->SetSwizzleDelta(swizzle_delta); + return result; } Layout makeFullBankSwizzleLayout(const Buffer &buffer) { diff --git a/src/layout/layout.cc b/src/layout/layout.cc index 4e9c555f23..5329d00c0f 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -504,7 +504,33 @@ Layout LayoutNode::Expand(const Array &leading_shape) const { new_forward_index.push_back(Substitute(e, vmap)); } - return Layout(new_input_size, new_forward_index); + Layout result(new_input_size, new_forward_index); + // Propagate swizzle_delta_ through Expand: substitute placeholder + // indices so the delta keeps referring 
to the same physical input + // dimension after the leading-shape prefix is added. + if (swizzle_delta_.defined()) { + const_cast(result.get()) + ->SetSwizzleDelta(Substitute(swizzle_delta_.value(), vmap)); + } + return result; +} + +PrimExpr LayoutNode::SwizzleDelta(const Array &input_indices) const { + if (!swizzle_delta_.defined()) { + return IntImm(DataType::Int(32), 0); + } + // Substitute the last InputDim() entries of input_indices into + // swizzle_delta_, matching the convention Forward() uses. + ICHECK_GE(input_indices.size(), InputDim()) + << "SwizzleDelta requires at least " << InputDim() << " indices, but got " + << input_indices.size(); + PrimExpr delta = swizzle_delta_.value(); + size_t offset = input_indices.size() - InputDim(); + for (size_t i = 0; i < InputDim(); ++i) { + delta = + Substitute(delta, {{InputPlaceholder(i), input_indices[offset + i]}}); + } + return delta; } Fragment FragmentNode::Repeat(const Array &repeats, diff --git a/src/layout/layout.h b/src/layout/layout.h index a6890afb10..9eeee73e83 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -102,6 +102,24 @@ class LayoutNode : public Object { virtual bool IsEqual(const LayoutNode *other, bool skip_index = false) const; + /*! + * \brief Get the XOR swizzle column delta on the last input dimension. + * + * For swizzled layouts (quarter/half/full bank) returns the column + * delta caused by the XOR: delta = (c_swizzled - c) * vector_size, + * substituted against the supplied indices. For non-swizzle layouts + * returns 0. Used by the swizzle-swap optimisation in lower_tile_op.cc + * to move the XOR off the LDS-store side and onto the global-load + * side when the target supports buffer_load ... lds direct DMA. + */ + virtual PrimExpr SwizzleDelta(const Array &input_indices) const; + + /*! \brief Whether this layout carries a non-trivial swizzle delta. */ + bool HasSwizzle() const { return swizzle_delta_.defined(); } + + /*! 
\brief Set the swizzle delta expression (called by layout factories). */ + void SetSwizzleDelta(PrimExpr delta) { swizzle_delta_ = delta; } + static void RegisterReflection(); TVM_FFI_DECLARE_OBJECT_INFO("tl.Layout", LayoutNode, Object); static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = @@ -112,6 +130,11 @@ class LayoutNode : public Object { void UpdateAnalyzer(arith::Analyzer *analyzer) const; Array forward_index_; Array input_size_; + /*! + * \brief Optional XOR swizzle delta in terms of InputPlaceholders, set + * by swizzle layout factories and propagated through Expand/Reshape. + */ + Optional swizzle_delta_; }; /*! diff --git a/src/op/builtin.cc b/src/op/builtin.cc index df43fb2434..52a2af9e2b 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -306,6 +306,21 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_make_buffer_resource) + .set_num_inputs(1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_cp_async_lds_rsrc) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(fence_proxy_async) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 6876ee4d8f..30e7ee481b 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -516,6 +516,58 @@ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); */ TVM_DLL const Op &ptx_cp_async(); +/*! + * \brief Marker for an eligible async G2S copy on gfx950+. + * + * Emitted by LowerPTXAsyncCopy in place of ptx_cp_async for 16-byte + * non-predicated shared-memory writes whose LDS index is (post + * swizzle-swap, see lower_tile_op.cc) lane-contiguous. 
The + * HoistBufferResource pass then rewrites each call to + * ptx_cp_async_lds_rsrc with a pre-computed buffer resource + * descriptor + base address; that rsrc form is what codegen emits as + * the buffer_load_dwordx4 ... lds fast path. + * + * If a call survives the rewrite (e.g. an access_ptr the hoister + * can't pattern-match), codegen falls back to the synchronous + * tl::cp_async_gs path -- correct, but no buffer_load_lds win. + * + * ptx_cp_async_lds(dst_access_ptr, src_access_ptr, num_elems) + * + * num_elems is the logical element count (NOT byte width). Lowering + * derives the {4, 8, 16} byte transfer width from the access-ptr dtype. + * Passing this as elements keeps vec-loop folding in vectorize_loop.cc + * (which multiplies the count when it widens a loop) consistent with + * the plain ptx_cp_async path. + */ +TVM_DLL const Op &ptx_cp_async_lds(); + +/*! + * \brief Create a buffer resource descriptor for async G2S LDS copy (gfx950+). + * + * ptx_make_buffer_resource(global_ptr) + * + * Returns an int32x4_t buffer resource descriptor via + * make_wave_buffer_resource (defined in src/tl_templates/hip/copy.h). + */ +TVM_DLL const Op &ptx_make_buffer_resource(); + +/*! + * \brief Truly async G2S copy with pre-computed buffer resource (gfx950+). + * + * Same as ptx_cp_async_lds but takes a pre-hoisted buffer resource + * descriptor + base address to avoid redundant readfirstlane / + * make_wave_buffer_resource calls inside unrolled loops. The + * HoistBufferResource Python pass rewrites ptx_cp_async_lds calls to this + * form once per kernel. + * + * ptx_cp_async_lds_rsrc(dst_access_ptr, src_access_ptr, num_elems, rsrc_var, + * base_var) + * + * num_elems uses the same convention as ptx_cp_async_lds -- logical + * element count, not bytes; lowering converts via the access-ptr dtype. + */ +TVM_DLL const Op &ptx_cp_async_lds_rsrc(); + /*! 
* \brief Pack two b16 value into a b32 value * diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 87f7522cd8..2c1cf7cdd0 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -401,6 +401,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // 3. Non-replicated read buffer // 4. Fully replicated write buffer (backup, may cause issues) // 5. Free inference mode (no source buffer) + // Early chunk-block-aware override: if a written shared buffer has a + // FullBank-style swizzled layout with tc > 1, the default flatten policy + // produces a binding whose wavefront lanes straddle the chunk-block + // boundary, which breaks lane-contiguous LDS WRITEs (buffer_load ... lds). + // Override here BEFORE source_buffer dispatch so the CBA fragment wins. + // Fire at ALL levels so the first level that sees the layout map populated + // wins; subsequent levels short-circuit via loop_layout_inferred_. + // + // gfx950-only: this hook exists to make buffer_load_dwordx4...lds usable, + // and the alternate binding it picks can conflict with MMA fragment + // bindings on NVIDIA / older AMD. Skip on every other target so the + // existing CUDA / pre-CDNA4 layout inference is untouched. + if (!loop_layout_.defined() && TargetIsGfx950(T.target)) { + // Reuse the same vec_size calculation as ComputePlanCandidate. 
+ auto maybe_remapped_root = + IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); + int vector_size = + GetVectorizeSize(maybe_remapped_root, T.analyzer, T.layout_map); + PrimExpr loop_total_size = 1; + for (Stmt l = root_; l.as().has_value(); l = l.as().value()->body) + loop_total_size = loop_total_size * l.as().value()->extent; + while ( + !analyzer_.CanProve(floormod(loop_total_size, T.thread_bounds->extent * + vector_size) == 0) && + vector_size > 1) + vector_size /= 2; + if (auto cba = ComputeChunkBlockAwarePlanCandidate(T, vector_size); + cba.defined()) { + // Only adopt the CBA layout if it doesn't conflict with a fragment + // (e.g. an MMA accumulator like acc_o_l / C_local) that already has + // a layout in T.layout_map. Otherwise the unconditional override + // would force the loop onto a binding incompatible with the + // fragment and ValidateCandidateAgainstFragments would fail later. + if (ValidateCandidateAgainstFragments(cba, T, /*throw_on_error=*/false, + /*check_forward_index=*/false, + /*source_buffer=*/Buffer{})) { + loop_layout_ = cba; + } + } + } if (!loop_layout_.defined() && annotated_layout_unbound_.defined()) { loop_layout_ = annotated_layout_unbound_.value()->BindThreadRange(T.thread_bounds); @@ -705,12 +745,176 @@ Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ << " ############# vector_size = " << vector_size << ", thread_bounds = " << T.thread_bounds << '\n'; + // Chunk-block-aware binding is taken by the early hook in + // ParallelOp::InferLayout (before source_buffer dispatch). By the time + // we reach here loop_layout_ is already set when CBA applies, so no + // need to re-try it. 
auto plan = PlanLoopPartition(root_, vector_size, T.thread_bounds); DLOG(INFO) << "[PlanLoopPartition] candidate = " << plan->DebugOutput() << '\n'; return plan; } +Fragment +ParallelOpNode::ComputeChunkBlockAwarePlanCandidate(const LayoutInferArgs &T, + int vector_size) const { + // 1. Find a written shared buffer with a swizzle layout whose continuous + // (innermost) dim exceeds one CDNA LDS bank cycle (128 bytes). When + // that happens FullBank splits the dim into `tc` planes; the default + // flatten policy puts too many lanes on the continuous dim and one + // wavefront ends up straddling the tc boundary. We compute tc from the + // buffer's last-dim extent + element size to avoid having to interpret + // the layout's output-dim structure (which can vary with pipelining). + Buffer target; + // Continuous (innermost) buffer dim extent. When the parallel loop is + // fused, this comes from buffer->shape.back(); when unfused, it equals + // the extent of the matching loop var. + int64_t cont_ext = 0; + int64_t inner_extent = 1; + // Index into loop_vars_ for the loop var that drives the continuous dim, + // or -1 if the access is `fused_var % cont_ext` (1D fused case). 
+ int split_axis = -1; + constexpr int kBankCycleBytes = 128; + PostOrderVisit(root_, [&](const ObjectRef &obj) { + if (target.defined()) + return; + const auto *store = obj.as(); + if (!store) + return; + const Buffer &buffer = store->buffer; + if (!IsSharedBuffer(buffer)) + return; + if (!T.layout_map.count(buffer)) + return; + Layout layout = T.layout_map[buffer]; + if (layout.as()) + return; + if (store->indices.empty()) + return; + if (buffer->shape.empty()) + return; + auto *last_dim_imm = as_const_int(buffer->shape.back()); + if (!last_dim_imm) + return; + int64_t cont = *last_dim_imm; + int element_bytes = buffer->dtype.bytes(); + if (element_bytes <= 0) + return; + int64_t bank_cycle_elems = kBankCycleBytes / element_bytes; + if (bank_cycle_elems <= 0) + return; + if (cont * element_bytes <= kBankCycleBytes) + return; + if (cont % bank_cycle_elems != 0) + return; + if ((cont / bank_cycle_elems) <= 1) + return; + + // Identify the loop var(s) driving the continuous dim. + PrimExpr last_idx = analyzer_.Simplify(store->indices.back()); + int chosen_axis = -1; + if (auto var_opt = last_idx.as()) { + // Unfused N-D case: bare loop var on the cont dim. + for (int i = 0; i < static_cast(loop_vars_.size()); i++) { + if (loop_vars_[i]->var.same_as(var_opt.value())) { + chosen_axis = i; + break; + } + } + if (chosen_axis < 0) + return; + auto *axis_ext_imm = as_const_int(loop_vars_[chosen_axis]->dom->extent); + if (!axis_ext_imm || *axis_ext_imm != cont) + return; + } else { + // Fused 1D case: index is `fused_var % cont_ext` (after pipelining + // multi-dim accesses get flattened). Require exactly one loop var of + // extent that is a multiple of cont. + if (loop_vars_.size() != 1) + return; + auto *total_ext_imm = as_const_int(loop_vars_[0]->dom->extent); + if (!total_ext_imm || *total_ext_imm % cont != 0) + return; + // Match `fused_var % cont` (with cont equal to the buffer's last dim). 
+ const auto *mod = last_idx.as(); + if (!mod) + return; + auto *mod_imm = as_const_int(mod->b); + if (!mod_imm || *mod_imm != cont) + return; + if (!mod->a.same_as(loop_vars_[0]->var)) + return; + } + + target = buffer; + split_axis = chosen_axis; + cont_ext = cont; + inner_extent = bank_cycle_elems; + }); + if (!target.defined()) + return Fragment(); + + // 2. Build flatten expressed purely in the existing loop vars so the + // resulting Fragment matches root_'s loop_vars and downstream + // PartitionLoop / LowerParallelLoop are unaffected. + ICHECK(!loop_vars_.empty()); + DataType dtype = loop_vars_[0]->var.dtype(); + PrimExpr inner_pe = IntImm(dtype, inner_extent); + PrimExpr flat; + if (split_axis >= 0) { + // Unfused N-D: split the chosen loop var into outer/inner and reorder + // to [outer, ..., inner] before row-major flatten. + PrimExpr split_var = loop_vars_[split_axis]->var; + PrimExpr outer_part = FloorDiv(split_var, inner_pe); + PrimExpr inner_part = FloorMod(split_var, inner_pe); + PrimExpr modified_total = IntImm(dtype, 1); + PrimExpr modified_flat = make_zero(dtype); + for (int i = 0; i < static_cast(loop_vars_.size()); i++) { + PrimExpr ext = (i == split_axis) ? inner_pe : loop_vars_[i]->dom->extent; + PrimExpr v = (i == split_axis) + ? inner_part + : static_cast(loop_vars_[i]->var); + modified_total = modified_total * ext; + modified_flat = modified_flat * ext + v; + } + flat = outer_part * modified_total + modified_flat; + } else { + // Fused 1D: decompose fused_var into (rest, cont_inner_part, c_inner) + // where cont_inner_part = (fused_var % cont)/inner. New flat puts + // n_outer (= cont_inner_part) outermost. 
+ PrimExpr fused = loop_vars_[0]->var; + auto *total_ext_imm = as_const_int(loop_vars_[0]->dom->extent); + ICHECK(total_ext_imm); + int64_t total = *total_ext_imm; + int64_t rest = total / cont_ext; + PrimExpr cont_pe = IntImm(dtype, cont_ext); + PrimExpr c = FloorMod(fused, cont_pe); + PrimExpr rest_part = FloorDiv(fused, cont_pe); + PrimExpr n_outer = FloorDiv(c, inner_pe); + PrimExpr c_inner = FloorMod(c, inner_pe); + PrimExpr rest_pe = IntImm(dtype, rest); + flat = n_outer * (rest_pe * inner_pe) + rest_part * inner_pe + c_inner; + } + + // 3. Apply the same coalesce policy as LoopPartitioner::Partition: + // access_idx = flat / vec_size, thd = access_idx % num_thread, + // idx = (access_idx / num_thread) * vec_size + flat % vec_size. + auto *num_thread_imm = as_const_int(T.thread_bounds->extent); + if (!num_thread_imm) + return Fragment(); // Symbolic thread bounds: fall back to default plan. + PrimExpr vec_pe = IntImm(dtype, vector_size); + PrimExpr num_thread_pe = IntImm(dtype, *num_thread_imm); + PrimExpr access_idx = FloorDiv(flat, vec_pe); + PrimExpr thd = FloorMod(access_idx, num_thread_pe); + PrimExpr idx = + FloorDiv(access_idx, num_thread_pe) * vec_pe + FloorMod(flat, vec_pe); + + Fragment fragment = Fragment(loop_vars_, /*forward_index=*/{idx}, + /*forward_thread=*/thd, + /*thread_replicate=*/IterVar()); + return fragment->BindThreadRange(T.thread_bounds); +} + void ParallelOpNode::BuildReplicationGuardsIfNeeded( const LayoutInferArgs &T, const std::vector &store_shared_global_buffers, diff --git a/src/op/parallel.h b/src/op/parallel.h index 35760d95bc..43301fbc73 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -151,6 +151,17 @@ class ParallelOpNode : public TileOperatorNode { // Compute plan-based loop layout candidate using vectorization and thread // bounds. Fragment ComputePlanCandidate(const LayoutInferArgs &T) const; + // Compute a "chunk-block aware" plan candidate. 
When a written shared buffer + // has a swizzled layout whose outer (tc) dim has extent > 1 (FullBank with + // continuous-bytes > one LDS bank cycle), the default flatten-and-partition + // policy makes one wavefront's lanes straddle the chunk-block boundary, + // which breaks the lane-contiguous LDS WRITE constraint and prevents any + // downstream swizzle-swap. This candidate splits the continuous loop var + // (n = n_outer * inner + n_inner) and reorders to [n_outer, ..., n_inner] + // before flattening, so consecutive lanes stay inside one tc plane. + // Returns an undefined Fragment if no eligible buffer is found. + Fragment ComputeChunkBlockAwarePlanCandidate(const LayoutInferArgs &T, + int vector_size) const; // Add replication guard predicates when needed for cross-thread stores. void BuildReplicationGuardsIfNeeded( const LayoutInferArgs &T, diff --git a/src/tl_templates/hip/copy.h b/src/tl_templates/hip/copy.h index 6142e3fdfb..1c018272b4 100644 --- a/src/tl_templates/hip/copy.h +++ b/src/tl_templates/hip/copy.h @@ -138,4 +138,31 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr, } } +// Variant with pre-hoisted buffer resource descriptor and base address. +// rsrc and rsrc_base_lo are computed once at kernel entry (see the +// HoistBufferResource Python pass) so per-call readfirstlane overhead is +// amortised across the many cp_async_gs_lds_with_rsrc calls in an unrolled +// loop. rsrc_base_lo must equal readfirstlane((uint32_t)(uintptr_t)A) for +// the same A passed to make_wave_buffer_resource that produced rsrc. 
+template +TL_DEVICE void +cp_async_gs_lds_with_rsrc(void *lds_base_ptr, void const *global_base_ptr, + int32x4_t rsrc, uint32_t rsrc_base_lo) { + if constexpr (N == 16) { + uint32_t my_lo = + static_cast(reinterpret_cast(global_base_ptr)); + uint32_t voffset = my_lo - rsrc_base_lo; + uint32_t lds_cur = __builtin_amdgcn_readfirstlane( + static_cast(reinterpret_cast(lds_base_ptr))); + // TODO(benenzhu): here use inline asm is a little bit tricky. + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dwordx4 %1, %2, 0 offen lds;\n\t" + : + : "s"(lds_cur), "v"(voffset), "s"(rsrc) + : "memory"); + } else { + cp_async_gs(lds_base_ptr, global_base_ptr); + } +} + } // namespace tl diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 94cb29ccfd..e422fae557 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -414,7 +414,8 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { } bool IsCPAsyncOp(const Op &op) { - return op == builtin::ptx_cp_async() || op == tl::ptx_cp_async(); + return op == builtin::ptx_cp_async() || op == tl::ptx_cp_async() || + op == tl::ptx_cp_async_lds() || op == tl::ptx_cp_async_lds_rsrc(); } static constexpr int kCPAsyncDstPtrArg = 0; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index dbe0fcb147..4fea7f8066 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -449,7 +449,8 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { if (node->op.same_as(builtin::ptx_cp_async())) { return count * 8; } - ICHECK(node->op.same_as(tl::ptx_cp_async())); + ICHECK(node->op.same_as(tl::ptx_cp_async()) || + node->op.same_as(tl::ptx_cp_async_lds())); auto dst_elem_bits = GetAccessPtrElementBits(node->args[0]); auto src_elem_bits = GetAccessPtrElementBits(node->args[1]); if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { @@ -524,7 +525,8 @@ class 
VectorizePlanner : public arith::IRMutatorWithAnalyzer { buffer_vector_infos_.push_back({Buffer(), vectorize_length, false, {}}); return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else if (node->op.same_as(builtin::ptx_cp_async()) || - node->op.same_as(tl::ptx_cp_async())) { + node->op.same_as(tl::ptx_cp_async()) || + node->op.same_as(tl::ptx_cp_async_lds())) { // builtin::ptx_cp_async stores bytes, while tl::ptx_cp_async stores // logical element counts. In both cases we pick the largest vector width // whose eventual PTX payload is one of {4, 8, 16} bytes. diff --git a/src/transform/lower_ptx_async_copy.cc b/src/transform/lower_ptx_async_copy.cc index 12b5404b42..fe47a8f865 100644 --- a/src/transform/lower_ptx_async_copy.cc +++ b/src/transform/lower_ptx_async_copy.cc @@ -35,9 +35,11 @@ using namespace ffi; class PTXAsyncCopyInjector : public StmtMutator { public: explicit PTXAsyncCopyInjector(bool enable_auto_async_copy, - bool async_without_async_commit_wait) + bool async_without_async_commit_wait, + bool enable_buffer_load_lds = false) : enable_auto_async_copy_(enable_auto_async_copy), - async_without_async_commit_wait_(async_without_async_commit_wait) {} + async_without_async_commit_wait_(async_without_async_commit_wait), + enable_buffer_load_lds_(enable_buffer_load_lds) {} bool InjectedPTXAsyncCopy() const { return injected_ptx_async_copy_; } @@ -124,8 +126,8 @@ class PTXAsyncCopyInjector : public StmtMutator { store, /*dst_base_load=*/BufferLoad(store->buffer, store->indices), /*src_base_load=*/BufferLoad(load->buffer, load->indices), - /*num_elems=*/index_info->per_access_num_elems, predicated, - predicate_value); + /*num_elems=*/index_info->per_access_num_elems, + /*total_bytes=*/index_info->total_bytes, predicated, predicate_value); } Optional> src_base_indices = @@ -145,8 +147,8 @@ class PTXAsyncCopyInjector : public StmtMutator { store, /*dst_base_load=*/BufferLoad(store->buffer, dst_base_indices.value()), 
/*src_base_load=*/BufferLoad(load->buffer, src_base_indices.value()), - /*num_elems=*/index_info->per_access_num_elems, predicated, - predicate_value); + /*num_elems=*/index_info->per_access_num_elems, + /*total_bytes=*/index_info->total_bytes, predicated, predicate_value); } Stmt VisitStmt_(const SeqStmtNode *op) final { @@ -301,6 +303,10 @@ class PTXAsyncCopyInjector : public StmtMutator { PrimExpr dst_index; int index_lanes{1}; int per_access_num_elems{0}; + // Byte width of one vectorized transfer at the final PTX emission point. + // Already factors in `current_vectorized_lanes_`. Used by the gfx950 + // buffer_load...lds routing to confirm the 16-byte gate. + int total_bytes{0}; }; // Synchronization state for injected cp.async runs carried across statements. @@ -420,6 +426,7 @@ class PTXAsyncCopyInjector : public StmtMutator { info.dst_index = dst_index; info.index_lanes = index_lanes; info.per_access_num_elems = effective_lanes; + info.total_bytes = total_bytes; return info; } @@ -488,16 +495,39 @@ class PTXAsyncCopyInjector : public StmtMutator { IntImm(DataType::Int(32), rw_mask)}); } - static Optional - MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, - const BufferLoad &dst_base_load, - const BufferLoad &src_base_load, int num_elems, - bool predicated, const PrimExpr &predicate_value) { + Optional MakeCPAsyncStmtFromLoads(const BufferStoreNode *store, + const BufferLoad &dst_base_load, + const BufferLoad &src_base_load, + int num_elems, int total_bytes, + bool predicated, + const PrimExpr &predicate_value) { PrimExpr dst_access_ptr = MakeAccessPtrFromLoad(dst_base_load, num_elems, /*rw_mask=*/2); PrimExpr src_access_ptr = MakeAccessPtrFromLoad(src_base_load, num_elems, /*rw_mask=*/1); + // gfx950 routing: emit tl::ptx_cp_async_lds when the destination is a + // 16-byte non-predicated shared-memory write. 
Arg 2 carries the logical + // element count (same convention tl::ptx_cp_async uses) so the existing + // vec-loop folding in vectorize_loop.cc widens it correctly when the + // call sits inside a T.vectorized(k) loop. The codegen handler converts + // the logical count back to bytes via GetTileLangCPAsyncTransferBytes. + // If the LDS index carries an XOR swizzle, the swizzle-swap visitor in + // LowerTileOp rewrites the call (subtract SwizzleDelta on LDS, add on + // global) so the destination becomes lane-contiguous; if the swap + // can't produce an affine destination it downgrades back to + // tl::ptx_cp_async, so both paths produce correct code. + if (enable_buffer_load_lds_ && !predicated && total_bytes == 16) { + const std::string dst_scope = store->buffer.scope(); + const bool is_shared = dst_scope == "shared" || dst_scope == "shared.dyn"; + if (is_shared) { + Array lds_args = {dst_access_ptr, src_access_ptr, + PrimExpr(num_elems)}; + return Evaluate( + Call(store->buffer->dtype, tvm::tl::ptx_cp_async_lds(), lds_args)); + } + } + Array cp_async_args; if (predicated) { cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(num_elems), @@ -616,7 +646,9 @@ class PTXAsyncCopyInjector : public StmtMutator { return out; } if (call->op.same_as(builtin::ptx_cp_async()) || - call->op.same_as(tl::ptx_cp_async())) { + call->op.same_as(tl::ptx_cp_async()) || + call->op.same_as(tl::ptx_cp_async_lds()) || + call->op.same_as(tl::ptx_cp_async_lds_rsrc())) { return out; } if (call->op.same_as(builtin::ptx_commit_group())) { @@ -691,6 +723,7 @@ class PTXAsyncCopyInjector : public StmtMutator { bool enable_auto_async_copy_{true}; bool async_without_async_commit_wait_{false}; + bool enable_buffer_load_lds_{false}; int explicit_async_scope_depth_{0}; int current_vectorized_lanes_{1}; std::vector active_vectorized_loops_; @@ -704,9 +737,11 @@ using namespace tirx::transform; PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const Stmt &body, bool enable_auto_async_copy, - bool 
async_without_async_commit_wait) { + bool async_without_async_commit_wait, + bool enable_buffer_load_lds) { PTXAsyncCopyInjector injector(enable_auto_async_copy, - async_without_async_commit_wait); + async_without_async_commit_wait, + enable_buffer_load_lds); Stmt injected = injector(body); return {injector.Finalize(injected), injector.InjectedPTXAsyncCopy()}; } diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index ab54cbbc75..2fe2dad06e 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -1001,6 +1001,160 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return call; } + // gfx950 swizzle-swap for buffer_load_dwordx4 ... lds. On this branch + // the BufferStores that feed A_shared/B_shared are consumed by + // InjectPTXAsyncCopy inside rocm::Copy::Lower before they reach the + // BufferStore visitor, so the swap has to happen on the resulting + // tl::ptx_cp_async_lds Call. We rewrite the dst access_ptr to use + // (Forward(idx) - delta) on the last dim against the remapped shared + // buffer (lane-contiguous LDS writes) and shift the global src + // access_ptr by +delta on the last dim (XOR is self-inverse, so net + // data movement is unchanged). Returning the rewritten Call directly + // prevents the default visitor from re-applying the swizzled layout. 
+ if (op->op.same_as(tl::ptx_cp_async_lds()) && TargetIsRocm(target_) && + op->args.size() == 3) { + auto resolve_load = [&](const PrimExpr &arg) -> const BufferLoadNode * { + const auto *call = arg.as(); + if (!call || !call->op.same_as(tl::access_ptr())) + return nullptr; + const auto *direct = call->args[0].as(); + if (direct) + return direct; + if (const auto *var = call->args[0].as()) { + auto it = let_bindings_.find(Downcast(call->args[0])); + if (it != let_bindings_.end()) { + return it->second.as(); + } + (void)var; + } + return nullptr; + }; + const auto *dst_ap = op->args[0].as(); + const auto *src_ap = op->args[1].as(); + const BufferLoadNode *dst_load = resolve_load(op->args[0]); + const BufferLoadNode *src_load = resolve_load(op->args[1]); + // Candidate gate: do we have everything we need to even attempt + // the swap? If not, fall through to "downgrade" below — we must + // NOT leave the call as ptx_cp_async_lds, because codegen would + // then emit the LDS template against the unmodified (swizzled) + // dst index and produce wrong addresses. + bool m9_candidate = dst_ap && src_ap && dst_load && src_load && + IsSharedBuffer(dst_load->buffer) && + IsGlobalBuffer(src_load->buffer) && + buffer_remap_.count(dst_load->buffer) && + layout_map_.count(dst_load->buffer) && + layout_map_[dst_load->buffer]->HasSwizzle() && + dst_load->indices.size() > 0 && + src_load->indices.size() > 0; + if (!m9_candidate) { + // Can't do the swap. Downgrade so codegen uses the safe + // synchronous cp_async_gs path instead of buffer_load_lds. 
+ Call downgraded(op->dtype, tl::ptx_cp_async(), op->args); + return IRMutatorWithAnalyzer::VisitExpr(downgraded); + } + { + Buffer new_dst_buf = buffer_remap_[dst_load->buffer]; + layout_remap_.Set(new_dst_buf, layout_map_[dst_load->buffer]); + auto layout = layout_map_[dst_load->buffer]; + auto swizzled = layout->Forward(dst_load->indices); + PrimExpr delta = + analyzer_->Simplify(layout->SwizzleDelta(dst_load->indices)); + + Array new_dst_indices(swizzled.begin(), swizzled.end()); + int last_dst = static_cast(new_dst_indices.size()) - 1; + new_dst_indices.Set( + last_dst, analyzer_->Simplify(new_dst_indices[last_dst] - delta)); + + Array new_src_indices(src_load->indices.begin(), + src_load->indices.end()); + int last_src = static_cast(new_src_indices.size()) - 1; + new_src_indices.Set( + last_src, analyzer_->Simplify(new_src_indices[last_src] + delta)); + + // Post-swap linearity guard: the single-dim subtract-delta swap + // only cancels the XOR when the layout's swizzle is confined to + // the last output dim. For layouts (e.g. B's matmul layout in + // some shapes) where the swizzle spreads across multiple output + // dims, the post-swap LDS index is NOT lane-contiguous and + // buffer_load_dwordx4...lds would write to wrong addresses. + // Detect by sampling each new_dst_indices[d] against the actual + // threadIdx.x binding tracked in thread_var_ (set from the + // tirx::attr::thread_extent AttrStmt). Substituting concrete + // lane values and requiring the dependence to be a single + // constant stride catches the case where the LDS index has + // bit-extract terms like `(tx & m) >> s` that buffer_load_lds + // would scatter. If any dim is non-affine in the lane var, bail + // out so the call falls through to the safe ptx_cp_async path. 
+ // + // Use the real binding rather than a name-based heuristic so a + // future rename of the lane var (or any kernel whose lane var + // doesn't happen to be called "tx"/"tid"/"thread*") doesn't + // silently misclassify the call as affine-OK and emit a + // wrong-banks tile. + Var lane_var; + if (thread_var_.defined() && thread_var_->var.defined() && + thread_block_size_ > 1) { + lane_var = thread_var_->var; + } + auto is_affine_in_thread_var = [&](const PrimExpr &e) -> bool { + if (!lane_var.defined()) { + // Serial / no-thread kernel: expression is trivially + // constant w.r.t. the lane and any LDS layout works. + return true; + } + arith::Analyzer post_analyzer; + PrimExpr f0 = post_analyzer.Simplify(tirx::Substitute( + e, Map{{lane_var, IntImm(lane_var->dtype, 0)}})); + PrimExpr f1 = post_analyzer.Simplify(tirx::Substitute( + e, Map{{lane_var, IntImm(lane_var->dtype, 1)}})); + PrimExpr stride = post_analyzer.Simplify(f1 - f0); + const auto *stride_imm = stride.as(); + if (!stride_imm) + return false; + for (int pt = 2; pt < 64; ++pt) { + PrimExpr fk = post_analyzer.Simplify(tirx::Substitute( + e, + Map{{lane_var, IntImm(lane_var->dtype, pt)}})); + PrimExpr actual = post_analyzer.Simplify(fk - f0); + PrimExpr expected = + IntImm(DataType::Int(64), stride_imm->value * pt); + if (!post_analyzer.CanProveEqual(actual, expected)) { + return false; + } + } + return true; + }; + bool all_dims_affine = true; + for (const auto &idx : new_dst_indices) { + if (!is_affine_in_thread_var(idx)) { + all_dims_affine = false; + break; + } + } + if (!all_dims_affine) { + // The swap can't produce a lane-contiguous LDS dst for this + // layout. Downgrade the op from tl::ptx_cp_async_lds to + // tl::ptx_cp_async (same arg shape) so codegen emits the safe + // synchronous cp_async_gs path rather than buffer_load_lds + // with a non-contiguous LDS index. 
Let the default visitor + // recurse from there so the access_ptr children still get the + // ordinary swizzled-layout treatment. + Call downgraded(op->dtype, tl::ptx_cp_async(), op->args); + return IRMutatorWithAnalyzer::VisitExpr(downgraded); + } else { + BufferLoad new_dst_load(new_dst_buf, new_dst_indices); + BufferLoad new_src_load(src_load->buffer, new_src_indices); + PrimExpr new_dst_ap = + Call(dst_ap->dtype, tl::access_ptr(), + {new_dst_load, dst_ap->args[1], dst_ap->args[2]}); + PrimExpr new_src_ap = + Call(src_ap->dtype, tl::access_ptr(), + {new_src_load, src_ap->args[1], src_ap->args[2]}); + return Call(op->dtype, op->op, {new_dst_ap, new_src_ap, op->args[2]}); + } + } + } + // Default: visit normally auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); return call; @@ -1031,9 +1185,58 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto buffer = store->buffer; if (buffer_remap_.count(buffer)) { - auto new_indices = layout_map_[buffer]->Forward(store->indices); auto new_buffer = buffer_remap_[store->buffer]; layout_remap_.Set(new_buffer, layout_map_[store->buffer]); + + // gfx950 buffer_load_dwordx4 ... lds requires LDS destinations be + // lane-contiguous (no XOR). When the shared-buffer layout carries + // an XOR swizzle, move the XOR off the LDS store side and onto + // the global load side (XOR is self-inverse, so the net data + // movement is unchanged). Gated on: ROCm target, non-PTX path, + // shared destination, layout has a swizzle delta, and the store + // value is a direct global BufferLoad (the g2s shape the cp.async + // injector recognises). 
+ if (TargetIsRocm(target_) && !is_ptx_ && IsSharedBuffer(buffer) && + layout_map_[buffer]->HasSwizzle()) { + const BufferLoadNode *load_node = nullptr; + if (auto *load = store->value.as()) { + if (IsGlobalBuffer(load->buffer)) { + load_node = load; + } + } + if (load_node && is_one(layout_map_[buffer]->OutputShape()[0]) && + load_node->indices.size() == store->indices.size()) { + auto swizzled_store = layout_map_[buffer]->Forward(store->indices); + PrimExpr delta = analyzer_->Simplify( + layout_map_[buffer]->SwizzleDelta(store->indices)); + + Array sequential_store(swizzled_store.begin(), + swizzled_store.end()); + int last_out = static_cast(sequential_store.size()) - 1; + sequential_store.Set( + last_out, + analyzer_->Simplify(sequential_store[last_out] - delta)); + + Array reflected(store->indices.begin(), + store->indices.end()); + int last_in = static_cast(reflected.size()) - 1; + reflected.Set(last_in, + analyzer_->Simplify(reflected[last_in] + delta)); + + Array new_load_indices; + for (size_t k = 0; k < load_node->indices.size(); ++k) { + PrimExpr base = + analyzer_->Simplify(load_node->indices[k] - store->indices[k]); + new_load_indices.push_back( + analyzer_->Simplify(base + reflected[k])); + } + + BufferLoad rewritten_load(load_node->buffer, new_load_indices); + return BufferStore(new_buffer, rewritten_load, sequential_store); + } + } + + auto new_indices = layout_map_[buffer]->Forward(store->indices); return BufferStore(new_buffer, store->value, new_indices); } else if (var_remap_.count(buffer->data)) { auto new_buffer = Buffer( diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index d3d64a6c46..b542a81435 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -737,7 +737,8 @@ class SharedMemoryRewriter : public StmtExprMutator { {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); } else if 
(op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())) { + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds())) { ICHECK(op->args.size() == 3U || op->args.size() == 4U) << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " "src_access_ptr, count[, predicate])"; diff --git a/src/transform/ptx_async_copy_injector.h b/src/transform/ptx_async_copy_injector.h index 03a8497daa..b85055430c 100644 --- a/src/transform/ptx_async_copy_injector.h +++ b/src/transform/ptx_async_copy_injector.h @@ -15,10 +15,16 @@ struct PTXAsyncCopyInjectResult { * This is the statement-level entrypoint used by other transforms to apply the * same rewrite as the `tl.LowerPTXAsyncCopy` pass, but scoped to a region * (e.g., a lowered parallel loop) rather than the whole PrimFunc. + * + * `enable_buffer_load_lds` enables the gfx950-specific routing that emits + * tl::ptx_cp_async_lds for eligible 16-byte non-predicated shared-memory- + * destined copies whose LDS index is lane-contiguous (no XOR swizzle). The + * ROCm copy lowering pass passes this flag only when the target is gfx950+. 
*/ PTXAsyncCopyInjectResult InjectPTXAsyncCopy(const tvm::tirx::Stmt &body, bool enable_auto_async_copy, - bool async_without_async_commit_wait = false); + bool async_without_async_commit_wait = false, + bool enable_buffer_load_lds = false); } // namespace tl } // namespace tvm diff --git a/src/transform/thread_storage_sync.cc b/src/transform/thread_storage_sync.cc index fed149e889..22ba2b65c9 100644 --- a/src/transform/thread_storage_sync.cc +++ b/src/transform/thread_storage_sync.cc @@ -1015,7 +1015,9 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor { if (auto opt = op->op.as()) { const Op &call_op = opt.value(); return call_op.same_as(builtin::ptx_cp_async()) || - call_op.same_as(tl::ptx_cp_async()); + call_op.same_as(tl::ptx_cp_async()) || + call_op.same_as(tl::ptx_cp_async_lds()) || + call_op.same_as(tl::ptx_cp_async_lds_rsrc()); } return false; }(); diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index 3c811ebe88..6c556b6dee 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -680,7 +680,9 @@ class TLVectorizer : public StmtMutator, if (op->op.same_as(builtin::ptx_cp_async())) { return scalar_count * 8; } - ICHECK(op->op.same_as(tl::ptx_cp_async())); + ICHECK(op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())); auto dst_elem_bits = GetAccessPtrElementBits(op->args[0]); auto src_elem_bits = GetAccessPtrElementBits(op->args[1]); if (!dst_elem_bits.has_value() || !src_elem_bits.has_value()) { @@ -701,8 +703,12 @@ class TLVectorizer : public StmtMutator, // the final codegen validate the derived PTX byte width. 
PrimExpr MutatePTXCPAsyncExpr_(const CallNode *op) { ICHECK(op->op.same_as(builtin::ptx_cp_async()) || - op->op.same_as(tl::ptx_cp_async())); - if (op->args.size() != 3 && op->args.size() != 4) { + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())); + // 3 or 4 args: dst, src, count, [predicate] (plain cp_async family). + // 5 args: dst, src, count, rsrc_var, base_var (hoisted-resource form). + if (op->args.size() != 3 && op->args.size() != 4 && op->args.size() != 5) { return GetRef(op); } @@ -718,13 +724,27 @@ class TLVectorizer : public StmtMutator, } predicate = pred; } + // For the rsrc form, args[3..4] are the hoisted (rsrc_var, base_var) and + // must be preserved through the rewrite so codegen still sees them. + Array trailing_rsrc_args; + if (op->args.size() == 5) { + trailing_rsrc_args.push_back(VisitExpr(op->args[3])); + trailing_rsrc_args.push_back(VisitExpr(op->args[4])); + } + + auto append_trailing = [&](Array &args) { + if (predicate.defined()) { + args.push_back(predicate.value()); + } + for (const auto &a : trailing_rsrc_args) { + args.push_back(a); + } + }; auto lanes_ptr = as_const_int(var_lanes_); if (!lanes_ptr || *lanes_ptr <= 1) { Array new_args{dst, src, count}; - if (predicate.defined()) { - new_args.push_back(predicate.value()); - } + append_trailing(new_args); if (new_args.same_as(op->args)) { return GetRef(op); } @@ -752,9 +772,7 @@ class TLVectorizer : public StmtMutator, int total_count = static_cast(Downcast(count)->value) * vector_size; Array new_args{dst, src, IntImm(count.dtype(), total_count)}; - if (predicate.defined()) { - new_args.push_back(predicate.value()); - } + append_trailing(new_args); if (new_args.same_as(op->args)) { return GetRef(op); } @@ -792,7 +810,9 @@ class TLVectorizer : public StmtMutator, } else if (op->op.same_as(builtin::tvm_access_ptr())) { return MutateAccessPtrCall_(op); } else if (op->op.same_as(builtin::ptx_cp_async()) || - 
op->op.same_as(tl::ptx_cp_async())) { + op->op.same_as(tl::ptx_cp_async()) || + op->op.same_as(tl::ptx_cp_async_lds()) || + op->op.same_as(tl::ptx_cp_async_lds_rsrc())) { return MutatePTXCPAsyncExpr_(op); } auto optional_op = op->op.as(); diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index de7a9b7033..89a1f47a4f 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -298,6 +298,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.InjectTcgen05Fence()(mod) mod = tilelang.transform.MergeIfStmt()(mod) # NOTE: LowerPTXAsyncCopy is applied earlier (before PipelinePlanning). + # Hoist buffer resource descriptors for the gfx950 buffer_load...lds path. + # No-op on non-gfx950 targets (pass guards on target_is_gfx950). + mod = tilelang.transform.HoistBufferResource()(mod) if allow_warp_specialized(pass_ctx=pass_ctx, target=target): mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) mod = tilelang.transform.MakePackedAPI()(mod) diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 198ccb65f8..8616ebccb3 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -8,6 +8,7 @@ from tvm.ir.transform import PassContext # noqa: F401 from .add_bufstore_wrapper import AddWrapperForSingleBufStore # noqa: F401 from .hoist_broadcast_values import HoistBroadcastValues # noqa: F401 +from .hoist_buffer_resource import HoistBufferResource # noqa: F401 from .decouple_type_cast import DecoupleTypeCast # noqa: F401 diff --git a/tilelang/transform/hoist_buffer_resource.py b/tilelang/transform/hoist_buffer_resource.py new file mode 100644 index 0000000000..1c98758141 --- /dev/null +++ b/tilelang/transform/hoist_buffer_resource.py @@ -0,0 +1,240 @@ +"""Hoist make_wave_buffer_resource descriptors + scale AMD async wait counts. 
+ +On gfx950, the `cp_async_gs_lds_with_rsrc<16>` device template takes a +pre-computed buffer resource descriptor and a pre-computed wave-uniform +base address. Computing those per call would emit 4x readfirstlane plus +the resource bit-cast on every call site. In an unrolled tile-copy loop +the same global buffer is touched many times, so we lift the descriptor +to the kernel prologue once per source buffer and rewrite the calls to +the variant that takes the pre-hoisted pair. + +Second half: AMD vmcnt tracks individual `buffer_load` issues, not the +NVIDIA-style commit groups. `tl::cp_async_wait` lowers to +`s_waitcnt vmcnt(N)`, so a wait-for-N-groups must become +wait-for-(N * loads_per_group) on AMD. On NVIDIA, a cp.async commit +groups every async load issued since the last commit; on AMD we have to +scale the wait count manually. We do that by finding the for-loop that +contains the `ptx_commit_group` call, counting async loads in one +iteration of that loop (multiplied by the constant extents of nested +loops), and rewriting every positive `ptx_wait_group(n)` to +`ptx_wait_group(n * loads_per_group)`. `ptx_wait_group(0)` (wait-all) +stays as `vmcnt(0)`, which is already correct. + +Pipeline order: this pass runs in the OptimizeForTarget phase after +ThreadSync/MergeIfStmt and before MakePackedAPI, which means the +tl::access_ptr calls have already been lowered by `LowerAccessPtr` to +`tir.tvm_access_ptr(ptype, data, offset, extent, rw_mask)`, so the +buffer Var is at args[1] of each access_ptr term. + +This pass is gfx950-only: on every other target it returns the PrimFunc +unchanged.
+""" + +from tvm import tirx as tir +from tvm.tirx import AttrStmt, Call, Evaluate, Var, PrimFunc, stmt_functor +from tvm.tirx.transform import prim_func_pass + +from tilelang.utils.target import target_is_gfx950 + +_op_ptx_cp_async_lds = tir.op.Op.get("tl.ptx_cp_async_lds") +_op_ptx_cp_async_lds_rsrc = tir.op.Op.get("tl.ptx_cp_async_lds_rsrc") +_op_tvm_access_ptr = tir.op.Op.get("tirx.tvm_access_ptr") +_op_ptx_commit_group = tir.op.Op.get("tirx.ptx_commit_group") +_op_ptx_wait_group = tir.op.Op.get("tirx.ptx_wait_group") + + +def _extract_buffer_var(access_ptr_expr): + """Pull the buffer-data Var out of a lowered tvm_access_ptr call. + + After tl.LowerAccessPtr the access pointer is encoded as + ``tvm_access_ptr(ptype, data, offset, extent, rw_mask)`` so args[1] + is the Var of interest. Anything else (e.g. an unlowered tl.access_ptr + or a plain pointer expression) returns None and the call is skipped. + """ + if not isinstance(access_ptr_expr, Call): + return None + if access_ptr_expr.op != _op_tvm_access_ptr: + return None + if len(access_ptr_expr.args) < 2: + return None + data_arg = access_ptr_expr.args[1] + if isinstance(data_arg, Var): + return data_arg + return None + + +def _is_async_load_call(stmt): + if not isinstance(stmt, Evaluate) or not isinstance(stmt.value, Call): + return False + op = stmt.value.op + return op in (_op_ptx_cp_async_lds, _op_ptx_cp_async_lds_rsrc) + + +def _is_commit_call(stmt): + if not isinstance(stmt, Evaluate) or not isinstance(stmt.value, Call): + return False + return stmt.value.op == _op_ptx_commit_group + + +def _contains_commit_call(stmt): + found = [False] + + def _v(s): + if _is_commit_call(s): + found[0] = True + + stmt_functor.post_order_visit(stmt, _v) + return found[0] + + +def _find_for_with_commit(stmt): + """Find the innermost For loop whose body contains a commit call.""" + if isinstance(stmt, tir.For): + inner = _find_for_with_commit(stmt.body) + if inner is not None: + return inner + if 
_contains_commit_call(stmt.body): + return stmt + elif isinstance(stmt, tir.SeqStmt): + for s in stmt.seq: + r = _find_for_with_commit(s) + if r is not None: + return r + elif hasattr(stmt, "body"): + return _find_for_with_commit(stmt.body) + return None + + +def _count_async_loads(stmt, multiplier=1): + if _is_async_load_call(stmt): + return multiplier + if isinstance(stmt, tir.For): + ext = multiplier + if isinstance(stmt.extent, tir.IntImm): + ext = multiplier * stmt.extent.value + return _count_async_loads(stmt.body, ext) + if isinstance(stmt, tir.SeqStmt): + return sum(_count_async_loads(s, multiplier) for s in stmt.seq) + if isinstance(stmt, tir.AttrStmt): + return _count_async_loads(stmt.body, multiplier) + if isinstance(stmt, tir.IfThenElse): + c = _count_async_loads(stmt.then_case, multiplier) + if stmt.else_case is not None: + c = max(c, _count_async_loads(stmt.else_case, multiplier)) + return c + return 0 + + +def _get_loads_per_group(body): + for_node = _find_for_with_commit(body) + if for_node is not None: + return _count_async_loads(for_node.body) + return 0 + + +def _fix_amd_wait_counts(body, loads_per_group): + """Multiply positive ptx_wait_group(n) arguments by loads_per_group. + + Each `tl::cp_async_wait` on AMD lowers to `s_waitcnt vmcnt(N)`, + which counts individual buffer_loads rather than NVIDIA-style commit + groups. wait_group(0) (wait-all) stays unchanged because vmcnt(0) + is already the correct "wait for everything" sentinel. 
+ """ + + def _postorder(op): + if not isinstance(op, Evaluate): + return None + if not isinstance(op.value, Call): + return None + if op.value.op != _op_ptx_wait_group: + return None + if len(op.value.args) != 1: + return None + n_arg = op.value.args[0] + if not isinstance(n_arg, tir.IntImm): + return None + if n_arg.value <= 0: + return None + new_call = Call( + op.value.dtype, + _op_ptx_wait_group, + [tir.IntImm(n_arg.dtype, n_arg.value * loads_per_group)], + ) + return Evaluate(new_call) + + return stmt_functor.ir_transform(body, None, _postorder, ["tirx.Evaluate"]) + + +def _collect_buffer_vars(body): + """Discover unique source buffer Vars referenced by ptx_cp_async_lds calls. + + Returns an ordered dict {buf_var: (rsrc_var, base_var)} so the prologue + AttrStmts emit in a stable order. + """ + buffer_vars = {} + + def _visit(stmt): + if isinstance(stmt, Evaluate) and isinstance(stmt.value, Call) and stmt.value.op == _op_ptx_cp_async_lds: + # ptx_cp_async_lds args: (dst_access_ptr, src_access_ptr, bytes) + buf_var = _extract_buffer_var(stmt.value.args[1]) + if buf_var is not None and buf_var not in buffer_vars: + rsrc_var = Var("__rsrc_" + buf_var.name, dtype="handle") + base_var = Var("__base_" + buf_var.name, dtype="uint32") + buffer_vars[buf_var] = (rsrc_var, base_var) + + stmt_functor.post_order_visit(body, _visit) + return buffer_vars + + +def _rewrite_calls(body, buffer_vars): + """Rewrite ptx_cp_async_lds -> ptx_cp_async_lds_rsrc with hoisted vars.""" + + def _postorder(op): + if isinstance(op, Evaluate) and isinstance(op.value, Call) and op.value.op == _op_ptx_cp_async_lds: + buf_var = _extract_buffer_var(op.value.args[1]) + if buf_var is not None and buf_var in buffer_vars: + rsrc_var, base_var = buffer_vars[buf_var] + new_call = Call( + op.value.dtype, + _op_ptx_cp_async_lds_rsrc, + [ + op.value.args[0], + op.value.args[1], + op.value.args[2], + rsrc_var, + base_var, + ], + ) + return Evaluate(new_call) + return None + + return 
stmt_functor.ir_transform(body, None, _postorder, ["tirx.Evaluate"]) + + +def HoistBufferResource(): + """gfx950: hoist buffer resource descriptors + scale AMD vmcnt waits.""" + + def pass_fn(func: PrimFunc, _mod, _ctx): + target = func.attrs.get("target", None) + if target is None or not target_is_gfx950(target): + return func + + buffer_vars = _collect_buffer_vars(func.body) + if not buffer_vars: + return func + + new_body = _rewrite_calls(func.body, buffer_vars) + + for buf_var, (rsrc_var, base_var) in reversed(list(buffer_vars.items())): + new_body = AttrStmt(base_var, "buffer_base_var", buf_var, new_body) + new_body = AttrStmt(rsrc_var, "buffer_resource_var", buf_var, new_body) + + # AMD wait-count scaling. Only meaningful when there's at least one + # commit group; otherwise loads_per_group is 0 and we skip. + loads_per_group = _get_loads_per_group(new_body) + if loads_per_group > 1: + new_body = _fix_amd_wait_counts(new_body, loads_per_group) + + return func.with_body(new_body) + + return prim_func_pass(pass_fn, opt_level=0)