Skip to content

Commit 42eaac8

Browse files
committed
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into ins_debug_symbols
2 parents e66c848 + f481878 commit 42eaac8

File tree

8 files changed

+178
-14
lines changed

8 files changed

+178
-14
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Full documentation for MIGraphX is available at
2626

2727
### Optimized
2828

29+
* Added a new pass that replaces a convolution whose input is a constant broadcast with a reduced GEMM, which improves model compilation time (#4621).
30+
2931
### Removed
3032

3133
## MIGraphX 2.15 for ROCm 7.2.0

docs/sphinx/requirements.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
rocm-docs-core==1.31.3
1+
rocm-docs-core==1.32.0
22
sphinx-collapse

docs/sphinx/requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ click==8.1.7
6161
# sphinx-external-toc
6262
comm==0.2.2
6363
# via ipykernel
64-
cryptography==44.0.1
64+
cryptography==46.0.5
6565
# via pyjwt
6666
debugpy==1.8.12
6767
# via ipykernel
@@ -212,7 +212,7 @@ requests==2.32.4
212212
# via
213213
# pygithub
214214
# sphinx
215-
rocm-docs-core==1.31.3
215+
rocm-docs-core==1.32.0
216216
# via -r requirements.in
217217
rpds-py==0.22.3
218218
# via
@@ -285,8 +285,9 @@ traitlets==5.14.3
285285
# matplotlib-inline
286286
# nbclient
287287
# nbformat
288-
typing-extensions==4.12.2
288+
typing-extensions==4.15.0
289289
# via
290+
# cryptography
290291
# ipython
291292
# myst-nb
292293
# pydata-sphinx-theme

src/propagate_constant.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ static bool is_const_ins(instruction_ref ins, const std::unordered_set<std::stri
5959
skip_ops.find(ins->name()) == skip_ops.end();
6060
}
6161

62-
static argument as_packed(const argument& c)
62+
static literal as_packed(const argument& c)
6363
{
6464
if(c.get_shape().packed())
65-
return c;
65+
return {c.get_shape(), c.data()};
6666
auto s = c.get_shape().with_lens(c.get_shape().lens());
67-
argument result;
68-
c.visit([&](auto x) { result = literal{s, x.begin(), x.end()}.get_argument(); });
67+
literal result;
68+
c.visit([&](auto x) { result = literal{s, x.begin(), x.end()}; });
6969
return result;
7070
}
7171

@@ -98,11 +98,16 @@ void propagate_constant::apply(module& m) const
9898

9999
// Compute literals in parallel
100100
std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
101-
std::vector<argument> literals(const_instrs_vec.size());
101+
std::vector<literal> literals(const_instrs_vec.size());
102102
std::size_t grainsize = 1;
103+
#ifdef _WIN32
104+
grainsize = std::max<std::size_t>(
105+
const_instrs_vec.size() / (std::thread::hardware_concurrency() / 2), 1);
106+
#else
103107
#if !MIGRAPHX_HAS_EXECUTORS
104108
std::size_t n = std::max<std::size_t>(2048 / std::thread::hardware_concurrency(), 1);
105109
grainsize = const_instrs_vec.size() / n;
110+
#endif
106111
#endif
107112
simple_par_for(const_instrs_vec.size(), grainsize, [&](const auto i) {
108113
literals[i] = as_packed(const_instrs_vec[i]->eval());
@@ -128,7 +133,7 @@ void propagate_constant::apply(module& m) const
128133
}
129134
assert(literals[i].get_shape().lens() == const_instrs_vec[i]->get_shape().lens());
130135
assert(literals[i].get_shape().bytes() <= const_instrs_vec[i]->get_shape().bytes());
131-
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
136+
auto l = m.add_literal(literals[i]);
132137
m.replace_instruction(const_instrs_vec[i], l);
133138
}
134139
}

src/simplify_algebra.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,11 +2160,83 @@ struct find_split_transpose
21602160
}
21612161
};
21622162

2163+
// When a convolution's input is a spatially-broadcast constant (e.g. a bias
2164+
// vector broadcast to [N, IC, H, W] with stride-0 spatial dims), the full
2165+
// spatial convolution is redundant. Replace it with:
2166+
// W_reduced[oc,ic] = sum_{kh,kw} W[oc,ic,kh,kw] (reduce_sum)
2167+
// result = dot(input_2d, W_reduced^T) (tiny GEMM)
2168+
// multibroadcast result to the original output shape
2169+
struct find_conv_broadcast_input
2170+
{
2171+
auto matcher() const
2172+
{
2173+
return match::name("convolution")(match::args(
2174+
match::name("broadcast", "multibroadcast")(match::args(match::any().bind("x")))
2175+
.bind("bcast"),
2176+
match::is_constant().bind("w")));
2177+
}
2178+
2179+
void apply(module& m, const match::matcher_result& r) const
2180+
{
2181+
auto ins = r.result;
2182+
auto x_ins = r.instructions["x"];
2183+
auto w_ins = r.instructions["w"];
2184+
2185+
if(ins->get_operator().to_value()["group"].to<int>() != 1)
2186+
return;
2187+
2188+
const auto& x_shape = x_ins->get_shape();
2189+
const auto& w_shape = w_ins->get_shape();
2190+
2191+
const auto& x_lens = x_shape.lens();
2192+
if(x_lens.size() > 2 and
2193+
std::any_of(x_lens.begin() + 2, x_lens.end(), [](auto l) { return l != 1; }))
2194+
return;
2195+
2196+
auto oc = w_shape.lens()[0];
2197+
auto ic = w_shape.lens()[1];
2198+
2199+
auto out_lens = ins->get_shape().lens();
2200+
auto n = out_lens[0];
2201+
2202+
if(x_shape.elements() != n * ic)
2203+
return;
2204+
2205+
auto ndim = w_shape.ndim();
2206+
std::vector<int64_t> spatial_axes(ndim - 2);
2207+
std::iota(spatial_axes.begin(), spatial_axes.end(), 2);
2208+
2209+
auto w_reduced =
2210+
m.insert_instruction(ins, make_op("reduce_sum", {{"axes", spatial_axes}}), w_ins);
2211+
auto w_2d = m.insert_instruction(
2212+
ins, make_op("reshape", {{"dims", std::vector<std::size_t>{oc, ic}}}), w_reduced);
2213+
auto w_t = m.insert_instruction(
2214+
ins, make_op("transpose", {{"permutation", std::vector<int64_t>{1, 0}}}), w_2d);
2215+
2216+
instruction_ref x_2d;
2217+
if(x_shape.ndim() == 1 and n == 1)
2218+
x_2d = m.insert_instruction(
2219+
ins, make_op("unsqueeze", {{"axes", std::vector<int64_t>{0}}}), x_ins);
2220+
else
2221+
x_2d = m.insert_instruction(
2222+
ins, make_op("reshape", {{"dims", std::vector<std::size_t>{n, ic}}}), x_ins);
2223+
2224+
auto dot_result = m.insert_instruction(ins, make_op("dot"), x_2d, w_t);
2225+
2226+
auto dot_1d = m.insert_instruction(
2227+
ins, make_op("squeeze", {{"axes", std::vector<int64_t>{0}}}), dot_result);
2228+
2229+
m.replace_instruction(
2230+
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", out_lens}}), dot_1d);
2231+
}
2232+
};
2233+
21632234
void simplify_algebra::apply(module& m) const
21642235
{
21652236
// Run simplifications multiple times
21662237
m.repeat_while_changes(8, [&] {
21672238
match::find_matches(m,
2239+
find_conv_broadcast_input{},
21682240
find_inner_broadcast{},
21692241
find_dot_broadcast{},
21702242
find_double_add_lit_broadcast{},

test/py/requirements-onnx.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#####################################################################################
2424
onnx==1.18.0;python_version>="3.11"
2525
onnx==1.14.1;python_version<"3.11"
26-
protobuf==4.25.8
26+
protobuf==5.29.6
2727
numpy==1.26.4;python_version>="3.11"
2828
numpy==1.21.6;python_version<"3.11"
2929
packaging==23.0

test/simplify_algebra_test.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
* The MIT License (MIT)
33
*
4-
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -4733,4 +4733,88 @@ TEST_CASE(find_concat_different_broadcast_axes)
47334733
EXPECT(m1.sort() == m2.sort());
47344734
}
47354735

4736+
TEST_CASE(conv_broadcast_input)
4737+
{
4738+
migraphx::shape xs{migraphx::shape::float_type, {64}};
4739+
migraphx::shape ws{migraphx::shape::float_type, {64, 64, 3, 3}};
4740+
migraphx::module m1;
4741+
{
4742+
auto x = m1.add_parameter("x", xs);
4743+
auto bcast = m1.add_instruction(
4744+
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 4, 4}}}), x);
4745+
auto w = m1.add_literal(migraphx::generate_literal(ws, 1));
4746+
auto conv = m1.add_instruction(migraphx::make_op("convolution"), bcast, w);
4747+
m1.add_instruction(pass_op{}, conv);
4748+
}
4749+
run_pass(m1);
4750+
4751+
migraphx::module m2;
4752+
{
4753+
auto x = m2.add_parameter("x", xs);
4754+
auto w = m2.add_literal(migraphx::generate_literal(ws, 1));
4755+
auto wr = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), w);
4756+
auto w2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 64}}}), wr);
4757+
auto wt =
4758+
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), w2d);
4759+
auto x2d = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), x);
4760+
auto dr = m2.add_instruction(migraphx::make_op("dot"), x2d, wt);
4761+
auto d1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dr);
4762+
auto r = m2.add_instruction(
4763+
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 2, 2}}}), d1);
4764+
m2.add_instruction(pass_op{}, r);
4765+
}
4766+
EXPECT(m1.sort() == m2.sort());
4767+
}
4768+
4769+
TEST_CASE(conv_multibroadcast_input)
4770+
{
4771+
migraphx::shape xs{migraphx::shape::float_type, {1, 64, 1, 1}};
4772+
migraphx::shape ws{migraphx::shape::float_type, {64, 64, 3, 3}};
4773+
migraphx::module m1;
4774+
{
4775+
auto x = m1.add_parameter("x", xs);
4776+
auto bcast = m1.add_instruction(
4777+
migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 4, 4}}}), x);
4778+
auto w = m1.add_literal(migraphx::generate_literal(ws, 1));
4779+
auto conv = m1.add_instruction(migraphx::make_op("convolution"), bcast, w);
4780+
m1.add_instruction(pass_op{}, conv);
4781+
}
4782+
run_pass(m1);
4783+
4784+
migraphx::module m2;
4785+
{
4786+
auto x = m2.add_parameter("x", xs);
4787+
auto w = m2.add_literal(migraphx::generate_literal(ws, 1));
4788+
auto wr = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), w);
4789+
auto w2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 64}}}), wr);
4790+
auto wt =
4791+
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), w2d);
4792+
auto x2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 64}}}), x);
4793+
auto dr = m2.add_instruction(migraphx::make_op("dot"), x2d, wt);
4794+
auto d1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dr);
4795+
auto r = m2.add_instruction(
4796+
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 2, 2}}}), d1);
4797+
m2.add_instruction(pass_op{}, r);
4798+
}
4799+
EXPECT(m1.sort() == m2.sort());
4800+
}
4801+
4802+
TEST_CASE(conv_broadcast_input_group)
4803+
{
4804+
migraphx::shape xs{migraphx::shape::float_type, {64}};
4805+
migraphx::shape ws{migraphx::shape::float_type, {64, 32, 3, 3}};
4806+
migraphx::module m1;
4807+
{
4808+
auto x = m1.add_parameter("x", xs);
4809+
auto bcast = m1.add_instruction(
4810+
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 4, 4}}}), x);
4811+
auto w = m1.add_literal(migraphx::generate_literal(ws, 1));
4812+
auto conv = m1.add_instruction(migraphx::make_op("convolution", {{"group", 2}}), bcast, w);
4813+
m1.add_instruction(pass_op{}, conv);
4814+
}
4815+
migraphx::module m2 = m1;
4816+
run_pass(m1);
4817+
EXPECT(m1.sort() == m2.sort());
4818+
}
4819+
47364820
int main(int argc, const char* argv[]) { test::run(argc, argv); }

tools/requirements-py.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#####################################################################################
22
# The MIT License (MIT)
33
#
4-
# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
4+
# Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
55
#
66
# Permission is hereby granted, free of charge, to any person obtaining a copy
77
# of this software and associated documentation files (the "Software"), to deal
@@ -29,4 +29,4 @@ typing==3.7.4
2929
pytest==6.0.1
3030
packaging==23.0
3131
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
32-
protobuf==4.25.8
32+
protobuf==6.33.5

0 commit comments

Comments
 (0)