Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Full documentation for MIGraphX is available at
* Added driver warnings when inputs dimensions and/or values are not set (#4850).
* Added documentation for using debug symbols (#4945).
* Added `--log-stdout` flag to migraphx-driver to log to stdout instead of stderr (#4959).
* Added slice squeeze matcher to propogate squeeze downstream and allow for parallel branches to merge together (#5004)

### Changed

Expand Down
54 changes: 54 additions & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,59 @@ struct find_flatten
flatten->inputs());
}
};

// Match slice->squeeze->pw/reduce where the squeeze and slice share the same
// single axis, then rewrite to slice->pw/reduce->squeeze (unsqueezing the
// other inputs). find_op_shape_transform_op propagates the squeeze through
// any downstream op chain, and find_splits in simplify_algebra merges parallel
// branches back together.
struct find_slice_squeeze
{
auto matcher() const
{
auto match_op = match::any_of(match::pointwise(), match::reduce());
auto squeeze_slice =
match::name("squeeze")(match::arg(0)(match::name("slice").bind("slice")))
.bind("squeeze");
return match_op(match::any_of[match::inputs()](squeeze_slice));
}

void apply(module& m, const match::matcher_result& r) const
{
auto op_ins = r.result;
auto squeeze = r.instructions["squeeze"];
auto slice_ins = r.instructions["slice"];

auto sq_axes = squeeze->get_operator().to_value()["axes"].to_vector<int64_t>();
auto sl_axes = slice_ins->get_operator().to_value()["axes"].to_vector<int64_t>();
if(sq_axes.size() != 1 or sl_axes.size() != 1 or sq_axes.front() != sl_axes.front())
return;

auto axis = sq_axes.front();

auto inputs = op_ins->inputs();
for(auto& input : inputs)
{
if(input == squeeze)
input = slice_ins;
else
input =
m.insert_instruction(op_ins, make_op("unsqueeze", {{"axes", {axis}}}), input);
}

// Unsqueezing the inputs shifts every axis at or after `axis` up by one.
// Build the source->common axes map and let find_op_shape_transform_op
// handle reduce/argmin/layout axis remapping (pointwise ops are inserted
// unchanged), instead of duplicating that logic here.
std::vector<std::vector<std::size_t>> axes_map(squeeze->get_shape().ndim());
for(std::size_t i = 0; i < axes_map.size(); ++i)
axes_map[i] = {i >= axis ? i + 1 : i};

auto new_op = find_op_shape_transform_op::insert(m, op_ins, inputs, axes_map);
auto new_sq = m.insert_instruction(op_ins, squeeze->get_operator(), new_op);
m.replace_instruction(op_ins, new_sq);
}
};
} // namespace

void simplify_reshapes::apply(module& m) const
Expand All @@ -1841,6 +1894,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_concat{},
find_transpose_slice{},
find_slice_transpose{},
find_slice_squeeze{},
find_unary_shape_transforms{},
find_reshape_dot{},
find_mul_add_shape_op_dot{},
Expand Down
161 changes: 160 additions & 1 deletion test/optimize_module_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -185,4 +185,163 @@ TEST_CASE(mul_add_transpose_dot)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(slice_squeeze_pw_unary)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto s0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input);
auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0);
auto s1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input);
auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1);
m1.add_return({rel0, rel1});
}
run_pass(m1);

migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), input);
auto s0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
m2.add_return({sq0, sq1});
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(slice_squeeze_pw_unary_3d)
{
migraphx::shape s{migraphx::shape::float_type, {3, 2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto s0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input);
auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0);
auto s1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input);
auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1);
auto s2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), input);
auto sq2 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s2);
auto rel2 = m1.add_instruction(migraphx::make_op("relu"), sq2);
m1.add_return({rel0, rel1, rel2});
}
run_pass(m1);

migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), input);
auto s0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
auto s2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), relu);
auto sq2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s2);
m2.add_return({sq0, sq1, sq2});
}
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(slice_squeeze_pw_binary_const)
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
migraphx::shape bs{migraphx::shape::float_type, {4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto b0 = m1.add_literal(migraphx::generate_literal(bs, 0));
auto b1 = m1.add_literal(migraphx::generate_literal(bs, 1));

auto s0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input);
auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto add0 = m1.add_instruction(migraphx::make_op("add"), sq0, b0);

auto s1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input);
auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), sq1, b1);

m1.add_return({add0, add1});
}
run_pass(m1);

// propagate_constant folds unsqueeze+concat of literals into one literal
migraphx::literal stacked_lit;
{
migraphx::module tmp;
auto b0 = tmp.add_literal(migraphx::generate_literal(bs, 0));
auto b1 = tmp.add_literal(migraphx::generate_literal(bs, 1));
auto bu0 = tmp.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), b0);
auto bu1 = tmp.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), b1);
auto cat = tmp.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), bu0, bu1);
auto ev = cat->eval();
stacked_lit = migraphx::literal(ev.get_shape(), ev.data());
}

migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto stacked = m2.add_literal(stacked_lit);
auto add = m2.add_instruction(migraphx::make_op("add"), input, stacked);

auto s0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add);
auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0);
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);

m2.add_return({sq0, sq1});
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(slice_squeeze_non_zero_axis)
{
migraphx::shape s{migraphx::shape::float_type, {3, 2, 4}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto s0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s0);
auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0);
auto s1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input);
auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s1);
auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1);
m1.add_return({rel0, rel1});
}
run_pass(m1);
migraphx::module m2;
{
auto input = m2.add_parameter("input", s);
auto relu = m2.add_instruction(migraphx::make_op("relu"), input);
auto s0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s0);
auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s1);
m2.add_return({sq0, sq1});
}
EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading
Loading