From daa138874ca8222f99067fab9d8d9b0fddf05bfc Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 17 Jun 2026 17:28:59 +0000 Subject: [PATCH 1/8] Add find_slice_squeeze matcher to simplify_reshapes Port the find_slice_squeeze matcher from the MLP_prediction_towers branch. This matcher rewrites slice->squeeze->pointwise/reduce into slice->pointwise/reduce->squeeze (unsqueezing the other inputs), which lets the squeeze propagate downstream and parallel slice branches merge back together. Includes the associated unit tests. --- src/simplify_reshapes.cpp | 70 ++++++++++++ test/simplify_reshapes_test.cpp | 191 +++++++++++++++++++++++++++++++- 2 files changed, 258 insertions(+), 3 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 619f52e7291..92366f7b542 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1818,6 +1818,75 @@ struct find_flatten flatten->inputs()); } }; + +// Match slice->squeeze->pw/reduce where the squeeze and slice share the same +// single axis, then rewrite to slice->pw/reduce->squeeze (unsqueezing the +// other inputs). find_op_shape_transform_op propagates the squeeze through +// any downstream op chain, and find_splits in simplify_algebra merges parallel +// branches back together. +struct find_slice_squeeze +{ + auto matcher() const + { + auto match_op = match::any_of(match::pointwise(), match::reduce()); + auto squeeze_slice = match::name("squeeze")( + match::arg(0)(match::name("slice").bind("slice"))) + .bind("squeeze"); + return match_op(match::any_of[match::inputs()](squeeze_slice)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto op_ins = r.result; + auto squeeze = r.instructions["squeeze"]; + auto slice_ins = r.instructions["slice"]; + + auto sq_axes = squeeze->get_operator().to_value()["axes"].to_vector(); + auto sl_axes = slice_ins->get_operator().to_value()["axes"].to_vector(); + if(sq_axes.size() != 1 or sl_axes.size() != 1 or sq_axes.front() != sl_axes.front()) + return; + + auto axis = sq_axes.front(); + + auto inputs = op_ins->inputs(); + for(auto& input : inputs) + { + if(input == squeeze) + input = slice_ins; + else + input = m.insert_instruction( + op_ins, make_op("unsqueeze", {{"axes", {axis}}}), input); + } + + auto op = op_ins->get_operator(); + if(not op.attributes().contains("pointwise")) + { + auto v = op.to_value(); + if(v.contains("axes")) + { + auto op_axes = v["axes"].to_vector(); + + std::transform(op_axes.begin(), + op_axes.end(), + op_axes.begin(), + [&](auto i) { return (i >= axis) ? i + 1 : i; }); + + v["axes"] = op_axes; + op = make_op(op_ins->name(), v); + } + else if(v.contains("axis")) + { + auto a = v["axis"].to(); + v["axis"] = a >= axis ? a + 1 : a; + op = make_op(op_ins->name(), v); + } + } + + auto new_op = m.insert_instruction(op_ins, op, inputs, op_ins->module_inputs()); + auto new_sq = m.insert_instruction(op_ins, squeeze->get_operator(), new_op); + m.replace_instruction(op_ins, new_sq); + } +}; } // namespace void simplify_reshapes::apply(module& m) const @@ -1841,6 +1910,7 @@ void simplify_reshapes::apply(module& m) const find_nested_concat{}, find_transpose_slice{}, find_slice_transpose{}, + find_slice_squeeze{}, find_unary_shape_transforms{}, find_reshape_dot{}, find_mul_add_shape_op_dot{}, diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ae032ce4f35..69e0499f8dc 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -4279,9 +4279,9 @@ TEST_CASE(transpose_slice_non_packed_axis) migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), unsqueeze); auto slice = m2.add_instruction( migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose); - auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); - auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze); - m2.add_return({sqrt}); + auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), slice); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sqrt); + m2.add_return({squeeze}); } EXPECT(m1 == m2); } @@ -5054,6 +5054,191 @@ TEST_CASE(gather_strided_view_elements_mismatch) EXPECT(m1.get_output_shapes() == m2.get_output_shapes()); } +TEST_CASE(slice_squeeze_reduce_sum) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4, 8}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sl); + auto red = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), sq); + m1.add_return({red}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto sl = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), sl); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_reduce_sum_all_axes) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4, 8}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sl); + auto red = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {0, 1}}}), sq); + m1.add_return({red}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto sl = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto red = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), sl); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), red); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_argmin) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4, 8}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sl); + auto argmin = m1.add_instruction(migraphx::make_op("argmin", {{"axis", 1}}), sq); + m1.add_return({argmin}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto sl = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto argmin = m2.add_instruction(migraphx::make_op("argmin", {{"axis", 2}}), sl); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), argmin); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_axis_mismatch) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 2, 4}}; + migraphx::shape bs{migraphx::shape::float_type, {1, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto bias = m1.add_parameter("bias", bs); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sl); + auto add = m1.add_instruction(migraphx::make_op("add"), sq, bias); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto bias = m2.add_parameter("bias", bs); + auto sl = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), bias); + auto add = m2.add_instruction(migraphx::make_op("add"), sl, unsq); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), add); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_multi_axis) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 3, 1}}; + migraphx::shape bs{migraphx::shape::float_type, {3}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto bias = m1.add_parameter("bias", bs); + auto sl = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0, 2}}}), sl); + auto add = m1.add_instruction(migraphx::make_op("add"), sq, bias); + m1.add_return({add}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_binary_two_inputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto add = m1.add_instruction(migraphx::make_op("add"), sq0, sq1); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input); + auto add = m2.add_instruction(migraphx::make_op("add"), s0, s1); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), add); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_binary_different_inputs) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4}}; + migraphx::module m1; + { + auto a = m1.add_parameter("a", s); + auto b = m1.add_parameter("b", s); + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), a); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), b); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto add = m1.add_instruction(migraphx::make_op("add"), sq0, sq1); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto a = m2.add_parameter("a", s); + auto b = m2.add_parameter("b", s); + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), a); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), b); + auto add = m2.add_instruction(migraphx::make_op("add"), s0, s1); + auto sq = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), add); + m2.add_return({sq}); + } + EXPECT(m1.sort() == m2.sort()); +} + TEST_CASE(slice_reshape_multibroadcast_rebase_axis) { migraphx::module m1; From ea45bcec072e363b320e915019b5c9abcfb1e1a0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 17 Jun 2026 17:33:11 +0000 Subject: [PATCH 2/8] Reuse find_op_shape_transform_op::is_reduce in find_slice_squeeze Replace the brittle "not pointwise" check with the shared is_reduce helper from find_op_shape_transform_op so reduce/argmin/argmax detection is precise and consistent with the rest of the pass. --- src/simplify_reshapes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 92366f7b542..81e20578ab6 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1859,7 +1859,7 @@ struct find_slice_squeeze } auto op = op_ins->get_operator(); - if(not op.attributes().contains("pointwise")) + if(find_op_shape_transform_op::is_reduce(op_ins)) { auto v = op.to_value(); if(v.contains("axes")) From 53e3c11872d64afb8438f24263a5130ee6a58761 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 17 Jun 2026 17:33:51 +0000 Subject: [PATCH 3/8] Reuse find_op_shape_transform_op::insert for axis remapping Delegate the reduce/argmin axis remapping in find_slice_squeeze to the shared insert() helper by building a source->common axes map for the unsqueeze. This removes the hand-rolled axis-shifting logic, keeps behavior consistent with find_op_shape_transform_op, and additionally handles layout permutations for free. --- src/simplify_reshapes.cpp | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 81e20578ab6..9225faf0ebd 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1858,31 +1858,16 @@ struct find_slice_squeeze op_ins, make_op("unsqueeze", {{"axes", {axis}}}), input); } - auto op = op_ins->get_operator(); - if(find_op_shape_transform_op::is_reduce(op_ins)) - { - auto v = op.to_value(); - if(v.contains("axes")) - { - auto op_axes = v["axes"].to_vector(); - - std::transform(op_axes.begin(), - op_axes.end(), - op_axes.begin(), - [&](auto i) { return (i >= axis) ? i + 1 : i; }); - - v["axes"] = op_axes; - op = make_op(op_ins->name(), v); - } - else if(v.contains("axis")) - { - auto a = v["axis"].to(); - v["axis"] = a >= axis ? a + 1 : a; - op = make_op(op_ins->name(), v); - } - } - - auto new_op = m.insert_instruction(op_ins, op, inputs, op_ins->module_inputs()); + // Unsqueezing the inputs shifts every axis at or after `axis` up by one. + // Build the source->common axes map and let find_op_shape_transform_op + // handle reduce/argmin/layout axis remapping (pointwise ops are inserted + // unchanged), instead of duplicating that logic here. + auto axis_sz = static_cast(axis); + std::vector> axes_map(squeeze->get_shape().ndim()); + for(std::size_t i = 0; i < axes_map.size(); ++i) + axes_map[i] = {i >= axis_sz ? i + 1 : i}; + + auto new_op = find_op_shape_transform_op::insert(m, op_ins, inputs, axes_map); auto new_sq = m.insert_instruction(op_ins, squeeze->get_operator(), new_op); m.replace_instruction(op_ins, new_sq); } From c95630cab5f715cc53df4faa72b2b9570ab180cf Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 22 Jun 2026 22:26:33 +0000 Subject: [PATCH 4/8] fix cppcheck --- src/simplify_reshapes.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 9225faf0ebd..4e2b7f63e0e 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1862,10 +1862,9 @@ struct find_slice_squeeze // Build the source->common axes map and let find_op_shape_transform_op // handle reduce/argmin/layout axis remapping (pointwise ops are inserted // unchanged), instead of duplicating that logic here. - auto axis_sz = static_cast(axis); std::vector> axes_map(squeeze->get_shape().ndim()); for(std::size_t i = 0; i < axes_map.size(); ++i) - axes_map[i] = {i >= axis_sz ? i + 1 : i}; + axes_map[i] = {i >= axis ? i + 1 : i}; auto new_op = find_op_shape_transform_op::insert(m, op_ins, inputs, axes_map); auto new_sq = m.insert_instruction(op_ins, squeeze->get_operator(), new_op); From 9f776f6b6c9e8057eef307d671e5e3c2e715039e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 22 Jun 2026 22:27:05 +0000 Subject: [PATCH 5/8] Fix format --- src/simplify_reshapes.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 4e2b7f63e0e..77540df7f8c 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -1829,9 +1829,9 @@ struct find_slice_squeeze auto matcher() const { auto match_op = match::any_of(match::pointwise(), match::reduce()); - auto squeeze_slice = match::name("squeeze")( - match::arg(0)(match::name("slice").bind("slice"))) - .bind("squeeze"); + auto squeeze_slice = + match::name("squeeze")(match::arg(0)(match::name("slice").bind("slice"))) + .bind("squeeze"); return match_op(match::any_of[match::inputs()](squeeze_slice)); } @@ -1854,8 +1854,8 @@ struct find_slice_squeeze if(input == squeeze) input = slice_ins; else - input = m.insert_instruction( - op_ins, make_op("unsqueeze", {{"axes", {axis}}}), input); + input = + m.insert_instruction(op_ins, make_op("unsqueeze", {{"axes", {axis}}}), input); } // Unsqueezing the inputs shifts every axis at or after `axis` up by one. From 26c7a6e3df031e6decb51f01ffc447dfe96a2c56 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 23 Jun 2026 03:58:33 +0000 Subject: [PATCH 6/8] Add optimize module tests --- test/optimize_module_test.cpp | 159 ++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp index fb59a36725b..374baf84b27 100644 --- a/test/optimize_module_test.cpp +++ b/test/optimize_module_test.cpp @@ -185,4 +185,163 @@ TEST_CASE(mul_add_transpose_dot) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(slice_squeeze_pw_unary) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0); + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1); + m1.add_return({rel0, rel1}); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto relu = m2.add_instruction(migraphx::make_op("relu"), input); + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + m2.add_return({sq0, sq1}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_pw_unary_3d) +{ + migraphx::shape s{migraphx::shape::float_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0); + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1); + auto s2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), input); + auto sq2 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s2); + auto rel2 = m1.add_instruction(migraphx::make_op("relu"), sq2); + m1.add_return({rel0, rel1, rel2}); + } + run_pass(m1); + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto relu = m2.add_instruction(migraphx::make_op("relu"), input); + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto s2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), relu); + auto sq2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s2); + m2.add_return({sq0, sq1, sq2}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(slice_squeeze_pw_binary_const) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 4}}; + migraphx::shape bs{migraphx::shape::float_type, {4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto b0 = m1.add_literal(migraphx::generate_literal(bs, 0)); + auto b1 = m1.add_literal(migraphx::generate_literal(bs, 1)); + + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto add0 = m1.add_instruction(migraphx::make_op("add"), sq0, b0); + + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + auto add1 = m1.add_instruction(migraphx::make_op("add"), sq1, b1); + + m1.add_return({add0, add1}); + } + run_pass(m1); + + // propagate_constant folds unsqueeze+concat of literals into one literal + migraphx::literal stacked_lit; + { + migraphx::module tmp; + auto b0 = tmp.add_literal(migraphx::generate_literal(bs, 0)); + auto b1 = tmp.add_literal(migraphx::generate_literal(bs, 1)); + auto bu0 = tmp.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), b0); + auto bu1 = tmp.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), b1); + auto cat = tmp.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), bu0, bu1); + auto ev = cat->eval(); + stacked_lit = migraphx::literal(ev.get_shape(), ev.data()); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto stacked = m2.add_literal(stacked_lit); + auto add = m2.add_instruction(migraphx::make_op("add"), input, stacked); + + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add); + auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s0); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add); + auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), s1); + + m2.add_return({sq0, sq1}); + } + EXPECT(m1.sort() == m2.sort()); +} +TEST_CASE(slice_squeeze_non_zero_axis) +{ + migraphx::shape s{migraphx::shape::float_type, {3, 2, 4}}; + migraphx::module m1; + { + auto input = m1.add_parameter("input", s); + auto s0 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); + auto sq0 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s0); + auto rel0 = m1.add_instruction(migraphx::make_op("relu"), sq0); + auto s1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), input); + auto sq1 = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s1); + auto rel1 = m1.add_instruction(migraphx::make_op("relu"), sq1); + m1.add_return({rel0, rel1}); + } + run_pass(m1); + migraphx::module m2; + { + auto input = m2.add_parameter("input", s); + auto relu = m2.add_instruction(migraphx::make_op("relu"), input); + auto s0 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu); + auto sq0 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s0); + auto s1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu); + auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), s1); + m2.add_return({sq0, sq1}); + } + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 8d99e7d10e034bd86fe5d1e091bf5d0ed7663196 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 23 Jun 2026 04:13:51 +0000 Subject: [PATCH 7/8] fix license --- test/optimize_module_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize_module_test.cpp b/test/optimize_module_test.cpp index 374baf84b27..298d81d4027 100644 --- a/test/optimize_module_test.cpp +++ b/test/optimize_module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 960f180b76ec828c2b9c734728d0056e3ac1e3b9 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Tue, 23 Jun 2026 04:15:35 +0000 Subject: [PATCH 8/8] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47fec6d3b38..82ac58b6c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ Full documentation for MIGraphX is available at * Added driver warnings when inputs dimensions and/or values are not set (#4850). * Added documentation for using debug symbols (#4945). * Added `--log-stdout` flag to migraphx-driver to log to stdout instead of stderr (#4959). +* Added slice squeeze matcher to propogate squeeze downstream and allow for parallel branches to merge together (#5004) ### Changed