@@ -31,36 +31,29 @@ inline std::ostream& operator<<(std::ostream& os, const MoEType& type) {
3131 }
3232}
3333
34- typedef std::tuple<MoePatternParams,
35- MoEType, // MoE builder type
36- ov::test::ElementType, // weights precision
37- ov::test::ElementType, // decompression precision
38- ov::test::ElementType, // scale precision
39- bool , // use weight decompression
40- DecompressionType, // decompression multiply type
41- DecompressionType, // decompression subtract type
42- bool , // reshape on decompression constants
43- int , // decompression_group_size
44- ov::AnyMap> // additional config
45- MoeTestParams;
46-
47- class MoeSubgraphTest : public testing ::WithParamInterface<MoeTestParams>,
34+ using MoeTestParams = std::tuple<MoePatternParams,
35+ MoEType, // MoE builder type
36+ ov::AnyMap>; // additional config
37+
38+ using MoeCompressedWeightsTestParams = std::tuple<MoePatternParams,
39+ MoEType, // MoE builder type
40+ ov::test::ElementType, // weights precision
41+ ov::test::ElementType, // decompression precision
42+ ov::test::ElementType, // scale precision
43+ DecompressionType, // decompression multiply type
44+ DecompressionType, // decompression subtract type
45+ bool , // reshape on decompression constants
46+ int , // decompression_group_size
47+ ov::AnyMap, // additional config
48+ bool >; // use_matmul_decompression_impl
49+
50+ class MoESubgraphTest : public testing ::WithParamInterface<MoeTestParams>,
4851 virtual public SubgraphBaseTest,
4952 public CpuTestWithFusing {
5053public:
51- static std::string getTestCaseName (const testing::TestParamInfo<MoeTestParams>& obj) {
52- const auto & [moe_params,
53- moe_type,
54- weights_precision,
55- decompression_precision,
56- scale_precision,
57- use_weight_decompression,
58- decompression_multiply_type,
59- decompression_subtract_type,
60- reshape_on_decompression,
61- decompression_group_size,
62- additional_config] = obj.param ;
63-
54+ static std::string generateBaseMoeTestName (const MoePatternParams& moe_params,
55+ const MoEType& moe_type,
56+ const ov::AnyMap& additional_config) {
6457 std::ostringstream result;
6558 result << " IS=" << ov::test::utils::partialShape2str ({moe_params.data_shape .first }) << " _" ;
6659 result << " TS=" ;
@@ -72,16 +65,6 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
7265 result << " intermediate_size=" << moe_params.intermediate_size << " _" ;
7366 result << " moe_type=" << moe_type << " _" ;
7467
75- if (use_weight_decompression) {
76- result << " WP=" << weights_precision << " _" ;
77- result << " DP=" << decompression_precision << " _" ;
78- result << " SP=" << scale_precision << " _" ;
79- result << " DM=" << decompression_multiply_type << " _" ;
80- result << " DS=" << decompression_subtract_type << " _" ;
81- result << " RD=" << reshape_on_decompression << " _" ;
82- result << " GS=" << decompression_group_size << " _" ;
83- }
84-
8568 result << " config=(" ;
8669 for (const auto & configEntry : additional_config) {
8770 result << configEntry.first << " =" << configEntry.second .as <std::string>() << " _" ;
@@ -90,6 +73,58 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
9073
9174 return result.str ();
9275 }
76+ static std::string getTestCaseName (const testing::TestParamInfo<MoeTestParams>& obj) {
77+ const auto & [moe_params, moe_type, additional_config] = obj.param ;
78+ return generateBaseMoeTestName (moe_params, moe_type, additional_config);
79+ }
80+
81+ protected:
82+ void SetUp () override {
83+ targetDevice = ov::test::utils::DEVICE_CPU;
84+ const auto & [shape_params, moe_type, additional_config] = GetParam ();
85+
86+ configuration.insert (additional_config.begin (), additional_config.end ());
87+ init_input_shapes ({shape_params.data_shape });
88+ inType = outType = ov::element::f32 ;
89+
90+ if (moe_type == MoEType::MoE2GeMM) {
91+ function = initMoE2GeMMSubgraph (shape_params, ov::element::f32 , ov::element::f32 );
92+ } else if (moe_type == MoEType::MoE3GeMM) {
93+ function = initMoE3GeMMSubgraph (shape_params, ov::element::f32 , ov::element::f32 );
94+ } else {
95+ OPENVINO_THROW (" Unsupported MoEType" );
96+ }
97+ }
98+ };
99+
100+ class MoECompressedWeightsSubgraphTest : public testing ::WithParamInterface<MoeCompressedWeightsTestParams>,
101+ virtual public SubgraphBaseTest,
102+ public CpuTestWithFusing {
103+ public:
104+ static std::string getTestCaseName (const testing::TestParamInfo<MoeCompressedWeightsTestParams>& obj) {
105+ const auto & [moe_params,
106+ moe_type,
107+ weights_precision,
108+ decompression_precision,
109+ scale_precision,
110+ decompression_multiply_type,
111+ decompression_subtract_type,
112+ reshape_on_decompression,
113+ decompression_group_size,
114+ additional_config,
115+ use_matmul_decompression_impl] = obj.param ;
116+ std::ostringstream result;
117+ result << MoESubgraphTest::generateBaseMoeTestName (moe_params, moe_type, additional_config) << " _" ;
118+ result << " _WP=" << weights_precision << " _" ;
119+ result << " DP=" << decompression_precision << " _" ;
120+ result << " SP=" << scale_precision << " _" ;
121+ result << " DM=" << decompression_multiply_type << " _" ;
122+ result << " DS=" << decompression_subtract_type << " _" ;
123+ result << " RD=" << reshape_on_decompression << " _" ;
124+ result << " GS=" << decompression_group_size << " _" ;
125+ result << " use_matmul_decompression_impl=" << use_matmul_decompression_impl << " _" ;
126+ return result.str ();
127+ }
93128
94129protected:
95130 void SetUp () override {
@@ -100,12 +135,12 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
100135 weights_precision,
101136 decompression_precision,
102137 scale_precision,
103- use_weight_decompression,
104138 decompression_multiply_type,
105139 decompression_subtract_type,
106140 reshape_on_decompression,
107141 decompression_group_size,
108- additional_config] = GetParam ();
142+ additional_config,
143+ use_matmul_decompression_impl] = GetParam ();
109144
110145 configuration.insert (additional_config.begin (), additional_config.end ());
111146 init_input_shapes ({shape_params.data_shape });
@@ -115,9 +150,9 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
115150 function = initMoE2GeMMSubgraph (shape_params,
116151 ov::element::f32 ,
117152 weights_precision,
153+ true ,
118154 decompression_precision,
119- ov::element::f32 ,
120- use_weight_decompression,
155+ scale_precision,
121156 decompression_multiply_type,
122157 decompression_subtract_type,
123158 reshape_on_decompression,
@@ -126,9 +161,9 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
126161 function = initMoE3GeMMSubgraph (shape_params,
127162 ov::element::f32 ,
128163 weights_precision,
164+ true ,
129165 decompression_precision,
130- ov::element::f32 ,
131- use_weight_decompression,
166+ scale_precision,
132167 decompression_multiply_type,
133168 decompression_subtract_type,
134169 reshape_on_decompression,
@@ -137,14 +172,42 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
137172 OPENVINO_THROW (" Unsupported MoEType" );
138173 }
139174 }
175+
176+ void check_results () {
177+ const auto & test_param = GetParam ();
178+ const ov::element::Type compressed_weights_precision = std::get<2 >(test_param);
179+ const bool use_matmul_decompression_impl = std::get<10 >(test_param);
180+
181+ const auto runtime_model = compiledModel.get_runtime_model ();
182+ const auto result = runtime_model->get_result ();
183+ auto batch_gather_mm = result->get_input_node_shared_ptr (0 );
184+
185+ auto type = batch_gather_mm->get_rt_info ().at (ov::exec_model_info::LAYER_TYPE).as <std::string>();
186+ if (type == " Reorder" || type == " Convert" || type == " Subgraph" )
187+ batch_gather_mm = batch_gather_mm->get_input_node_shared_ptr (0 );
188+
189+ type = batch_gather_mm->get_rt_info ().at (ov::exec_model_info::LAYER_TYPE).as <std::string>();
190+ EXPECT_EQ (type, " BatchGatherMatmul" );
191+
192+ const auto & expected_weights_precision =
193+ use_matmul_decompression_impl ? compressed_weights_precision : batch_gather_mm->get_input_element_type (0 );
194+ EXPECT_EQ (batch_gather_mm->get_input_element_type (1 ), expected_weights_precision);
195+ }
140196};
141197
198+ TEST_P (MoESubgraphTest, CompareWithRefs) {
199+ SKIP_IF_CURRENT_TEST_IS_DISABLED ()
200+ run ();
201+ }
202+
203+ TEST_P (MoECompressedWeightsSubgraphTest, CompareWithRefs) {
204+ SKIP_IF_CURRENT_TEST_IS_DISABLED ()
205+ run ();
206+ check_results ();
207+ }
208+
142209namespace {
143- // Test parameter generation
144- const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32 };
145- const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8 , ov::element::u4, ov::element::i4};
146210const std::vector<MoEType> moe_types = {MoEType::MoE2GeMM, MoEType::MoE3GeMM};
147-
148211const std::vector<MoePatternParams> moe_params = {
149212 {
150213 {{-1 , -1 , 2048 }, {{2 , 15 , 2048 }, {2 , 1 , 2048 }, {3 , 8 , 2048 }}}, // data_shape,
@@ -165,41 +228,36 @@ const ov::AnyMap additional_config_bf16 = {{ov::hint::inference_precision.name()
165228
166229} // namespace
167230
168- // Basic FP32 tests
169- INSTANTIATE_TEST_SUITE_P (smoke_MoeSubgraph_basic,
170- MoeSubgraphTest,
231+ INSTANTIATE_TEST_SUITE_P (smoke_MoESubgraph_basic,
232+ MoESubgraphTest,
171233 ::testing::Combine (::testing::ValuesIn(moe_params),
172234 ::testing::ValuesIn(moe_types),
173- ::testing::Values(ov::element::f32 ),
174- ::testing::Values(ov::element::f32 ),
175- ::testing::Values(ov::element::f32 ),
176- ::testing::Values(false ),
177- ::testing::Values(DecompressionType::full),
178- ::testing::Values(DecompressionType::full),
179- ::testing::Values(false ),
180- ::testing::Values(0 ),
181235 ::testing::Values(additional_config_basic)),
182- MoeSubgraphTest ::getTestCaseName);
236+ MoESubgraphTest ::getTestCaseName);
183237
184- // BF16 inference precision tests
185- INSTANTIATE_TEST_SUITE_P (smoke_MoeSubgraph_bf16,
186- MoeSubgraphTest,
238+ INSTANTIATE_TEST_SUITE_P (smoke_MoESubgraph_bf16,
239+ MoESubgraphTest,
187240 ::testing::Combine (::testing::ValuesIn(moe_params),
188241 ::testing::ValuesIn(moe_types),
242+ ::testing::Values(additional_config_bf16)),
243+ MoESubgraphTest::getTestCaseName);
244+
245+ const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32 };
246+ const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8 , ov::element::u4, ov::element::i4};
247+
248+ INSTANTIATE_TEST_SUITE_P (smoke_MoeCompressedWeights,
249+ MoECompressedWeightsSubgraphTest,
250+ ::testing::Combine (::testing::ValuesIn(moe_params),
251+ ::testing::ValuesIn(moe_types),
252+ ::testing::ValuesIn(weights_precisions),
253+ ::testing::ValuesIn(decompression_precisions),
189254 ::testing::Values(ov::element::f32 ),
190- ::testing::Values(ov::element::f32 ),
191- ::testing::Values(ov::element::f32 ),
192- ::testing::Values(false ),
193255 ::testing::Values(DecompressionType::full),
194256 ::testing::Values(DecompressionType::full),
195- ::testing::Values(false ),
196- ::testing::Values(0 ),
197- ::testing::Values(additional_config_bf16)),
198- MoeSubgraphTest::getTestCaseName);
199-
200- TEST_P (MoeSubgraphTest, CompareWithRefs) {
201- SKIP_IF_CURRENT_TEST_IS_DISABLED ()
202- run ();
203- }
257+ ::testing::Values(false ), // reshape on decompression
258+ ::testing::Values(16 ), // decompression group size
259+ ::testing::Values(additional_config_basic),
260+ ::testing::Values(true )), // use_matmul_decompression_impl
261+ MoECompressedWeightsSubgraphTest::getTestCaseName);
204262
205263} // namespace ov::test
0 commit comments