|
1 | 1 | /* |
2 | 2 | * The MIT License (MIT) |
3 | 3 | * |
4 | | - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. |
| 4 | + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. |
5 | 5 | * |
6 | 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
7 | 7 | * of this software and associated documentation files (the "Software"), to deal |
@@ -4733,4 +4733,88 @@ TEST_CASE(find_concat_different_broadcast_axes) |
4733 | 4733 | EXPECT(m1.sort() == m2.sort()); |
4734 | 4734 | } |
4735 | 4735 |
|
| 4736 | +TEST_CASE(conv_broadcast_input) |
| 4737 | +{ |
| 4738 | + migraphx::shape xs{migraphx::shape::float_type, {64}}; |
| 4739 | + migraphx::shape ws{migraphx::shape::float_type, {64, 64, 3, 3}}; |
| 4740 | + migraphx::module m1; |
| 4741 | + { |
| 4742 | + auto x = m1.add_parameter("x", xs); |
| 4743 | + auto bcast = m1.add_instruction( |
| 4744 | + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 4, 4}}}), x); |
| 4745 | + auto w = m1.add_literal(migraphx::generate_literal(ws, 1)); |
| 4746 | + auto conv = m1.add_instruction(migraphx::make_op("convolution"), bcast, w); |
| 4747 | + m1.add_instruction(pass_op{}, conv); |
| 4748 | + } |
| 4749 | + run_pass(m1); |
| 4750 | + |
| 4751 | + migraphx::module m2; |
| 4752 | + { |
| 4753 | + auto x = m2.add_parameter("x", xs); |
| 4754 | + auto w = m2.add_literal(migraphx::generate_literal(ws, 1)); |
| 4755 | + auto wr = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), w); |
| 4756 | + auto w2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 64}}}), wr); |
| 4757 | + auto wt = |
| 4758 | + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), w2d); |
| 4759 | + auto x2d = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), x); |
| 4760 | + auto dr = m2.add_instruction(migraphx::make_op("dot"), x2d, wt); |
| 4761 | + auto d1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dr); |
| 4762 | + auto r = m2.add_instruction( |
| 4763 | + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 2, 2}}}), d1); |
| 4764 | + m2.add_instruction(pass_op{}, r); |
| 4765 | + } |
| 4766 | + EXPECT(m1.sort() == m2.sort()); |
| 4767 | +} |
| 4768 | + |
| 4769 | +TEST_CASE(conv_multibroadcast_input) |
| 4770 | +{ |
| 4771 | + migraphx::shape xs{migraphx::shape::float_type, {1, 64, 1, 1}}; |
| 4772 | + migraphx::shape ws{migraphx::shape::float_type, {64, 64, 3, 3}}; |
| 4773 | + migraphx::module m1; |
| 4774 | + { |
| 4775 | + auto x = m1.add_parameter("x", xs); |
| 4776 | + auto bcast = m1.add_instruction( |
| 4777 | + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 4, 4}}}), x); |
| 4778 | + auto w = m1.add_literal(migraphx::generate_literal(ws, 1)); |
| 4779 | + auto conv = m1.add_instruction(migraphx::make_op("convolution"), bcast, w); |
| 4780 | + m1.add_instruction(pass_op{}, conv); |
| 4781 | + } |
| 4782 | + run_pass(m1); |
| 4783 | + |
| 4784 | + migraphx::module m2; |
| 4785 | + { |
| 4786 | + auto x = m2.add_parameter("x", xs); |
| 4787 | + auto w = m2.add_literal(migraphx::generate_literal(ws, 1)); |
| 4788 | + auto wr = m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), w); |
| 4789 | + auto w2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 64}}}), wr); |
| 4790 | + auto wt = |
| 4791 | + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), w2d); |
| 4792 | + auto x2d = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 64}}}), x); |
| 4793 | + auto dr = m2.add_instruction(migraphx::make_op("dot"), x2d, wt); |
| 4794 | + auto d1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dr); |
| 4795 | + auto r = m2.add_instruction( |
| 4796 | + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 2, 2}}}), d1); |
| 4797 | + m2.add_instruction(pass_op{}, r); |
| 4798 | + } |
| 4799 | + EXPECT(m1.sort() == m2.sort()); |
| 4800 | +} |
| 4801 | + |
| 4802 | +TEST_CASE(conv_broadcast_input_group) |
| 4803 | +{ |
| 4804 | + migraphx::shape xs{migraphx::shape::float_type, {64}}; |
| 4805 | + migraphx::shape ws{migraphx::shape::float_type, {64, 32, 3, 3}}; |
| 4806 | + migraphx::module m1; |
| 4807 | + { |
| 4808 | + auto x = m1.add_parameter("x", xs); |
| 4809 | + auto bcast = m1.add_instruction( |
| 4810 | + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 64, 4, 4}}}), x); |
| 4811 | + auto w = m1.add_literal(migraphx::generate_literal(ws, 1)); |
| 4812 | + auto conv = m1.add_instruction(migraphx::make_op("convolution", {{"group", 2}}), bcast, w); |
| 4813 | + m1.add_instruction(pass_op{}, conv); |
| 4814 | + } |
| 4815 | + migraphx::module m2 = m1; |
| 4816 | + run_pass(m1); |
| 4817 | + EXPECT(m1.sort() == m2.sort()); |
| 4818 | +} |
| 4819 | + |
4736 | 4820 | int main(int argc, const char* argv[]) { test::run(argc, argv); } |
0 commit comments