Skip to content

Commit f7d654e

Browse files
v-Golubev authored and maxnick committed
[TESTS] MoECompressedWeightsSubgraphTest class
1 parent db4eb6d commit f7d654e

File tree

6 files changed

+349
-188
lines changed

6 files changed

+349
-188
lines changed

src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/batch_gather_matmul_compressed.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,6 @@ void BatchGatherMatmulCompressed::validate_and_infer_types() {
6666
NODE_VALIDATION_CHECK(this,
6767
ov::is_type<ov::op::v0::Constant>(weight_zero_points),
6868
"Input weight_zero_points must be a Constant node.");
69-
70-
// check wight_scales and weight_zero_points are either per channel or per tensor
71-
const auto& weight_scales_shape = weight_scales->get_output_partial_shape(0);
72-
const auto& weight_zero_points_shape = weight_zero_points->get_output_partial_shape(0);
73-
auto weight_shape = get_input_partial_shape(1);
74-
75-
using ov::op::AutoBroadcastType;
76-
NODE_VALIDATION_CHECK(
77-
this,
78-
PartialShape::broadcast_merge_into(weight_shape, weight_scales_shape, AutoBroadcastType::NUMPY) &&
79-
PartialShape::broadcast_merge_into(weight_shape, weight_zero_points_shape, AutoBroadcastType::NUMPY),
80-
"Input weight_scales and weight_zero_points shapes are not compatible with weight shape.");
81-
8269
BatchGatherMatmul::validate_and_infer_types();
8370
}
8471
} // namespace ov::intel_cpu

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/moe.cpp

Lines changed: 132 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -31,36 +31,29 @@ inline std::ostream& operator<<(std::ostream& os, const MoEType& type) {
3131
}
3232
}
3333

34-
typedef std::tuple<MoePatternParams,
35-
MoEType, // MoE builder type
36-
ov::test::ElementType, // weights precision
37-
ov::test::ElementType, // decompression precision
38-
ov::test::ElementType, // scale precision
39-
bool, // use weight decompression
40-
DecompressionType, // decompression multiply type
41-
DecompressionType, // decompression subtract type
42-
bool, // reshape on decompression constants
43-
int, // decompression_group_size
44-
ov::AnyMap> // additional config
45-
MoeTestParams;
46-
47-
class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
34+
using MoeTestParams = std::tuple<MoePatternParams,
35+
MoEType, // MoE builder type
36+
ov::AnyMap>; // additional config
37+
38+
using MoeCompressedWeightsTestParams = std::tuple<MoePatternParams,
39+
MoEType, // MoE builder type
40+
ov::test::ElementType, // weights precision
41+
ov::test::ElementType, // decompression precision
42+
ov::test::ElementType, // scale precision
43+
DecompressionType, // decompression multiply type
44+
DecompressionType, // decompression subtract type
45+
bool, // reshape on decompression constants
46+
int, // decompression_group_size
47+
ov::AnyMap, // additional config
48+
bool>; // use_matmul_decompression_impl
49+
50+
class MoESubgraphTest : public testing::WithParamInterface<MoeTestParams>,
4851
virtual public SubgraphBaseTest,
4952
public CpuTestWithFusing {
5053
public:
51-
static std::string getTestCaseName(const testing::TestParamInfo<MoeTestParams>& obj) {
52-
const auto& [moe_params,
53-
moe_type,
54-
weights_precision,
55-
decompression_precision,
56-
scale_precision,
57-
use_weight_decompression,
58-
decompression_multiply_type,
59-
decompression_subtract_type,
60-
reshape_on_decompression,
61-
decompression_group_size,
62-
additional_config] = obj.param;
63-
54+
static std::string generateBaseMoeTestName(const MoePatternParams& moe_params,
55+
const MoEType& moe_type,
56+
const ov::AnyMap& additional_config) {
6457
std::ostringstream result;
6558
result << "IS=" << ov::test::utils::partialShape2str({moe_params.data_shape.first}) << "_";
6659
result << "TS=";
@@ -72,16 +65,6 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
7265
result << "intermediate_size=" << moe_params.intermediate_size << "_";
7366
result << "moe_type=" << moe_type << "_";
7467

75-
if (use_weight_decompression) {
76-
result << "WP=" << weights_precision << "_";
77-
result << "DP=" << decompression_precision << "_";
78-
result << "SP=" << scale_precision << "_";
79-
result << "DM=" << decompression_multiply_type << "_";
80-
result << "DS=" << decompression_subtract_type << "_";
81-
result << "RD=" << reshape_on_decompression << "_";
82-
result << "GS=" << decompression_group_size << "_";
83-
}
84-
8568
result << "config=(";
8669
for (const auto& configEntry : additional_config) {
8770
result << configEntry.first << "=" << configEntry.second.as<std::string>() << "_";
@@ -90,6 +73,58 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
9073

9174
return result.str();
9275
}
76+
static std::string getTestCaseName(const testing::TestParamInfo<MoeTestParams>& obj) {
77+
const auto& [moe_params, moe_type, additional_config] = obj.param;
78+
return generateBaseMoeTestName(moe_params, moe_type, additional_config);
79+
}
80+
81+
protected:
82+
void SetUp() override {
83+
targetDevice = ov::test::utils::DEVICE_CPU;
84+
const auto& [shape_params, moe_type, additional_config] = GetParam();
85+
86+
configuration.insert(additional_config.begin(), additional_config.end());
87+
init_input_shapes({shape_params.data_shape});
88+
inType = outType = ov::element::f32;
89+
90+
if (moe_type == MoEType::MoE2GeMM) {
91+
function = initMoE2GeMMSubgraph(shape_params, ov::element::f32, ov::element::f32);
92+
} else if (moe_type == MoEType::MoE3GeMM) {
93+
function = initMoE3GeMMSubgraph(shape_params, ov::element::f32, ov::element::f32);
94+
} else {
95+
OPENVINO_THROW("Unsupported MoEType");
96+
}
97+
}
98+
};
99+
100+
class MoECompressedWeightsSubgraphTest : public testing::WithParamInterface<MoeCompressedWeightsTestParams>,
101+
virtual public SubgraphBaseTest,
102+
public CpuTestWithFusing {
103+
public:
104+
static std::string getTestCaseName(const testing::TestParamInfo<MoeCompressedWeightsTestParams>& obj) {
105+
const auto& [moe_params,
106+
moe_type,
107+
weights_precision,
108+
decompression_precision,
109+
scale_precision,
110+
decompression_multiply_type,
111+
decompression_subtract_type,
112+
reshape_on_decompression,
113+
decompression_group_size,
114+
additional_config,
115+
use_matmul_decompression_impl] = obj.param;
116+
std::ostringstream result;
117+
result << MoESubgraphTest::generateBaseMoeTestName(moe_params, moe_type, additional_config) << "_";
118+
result << "_WP=" << weights_precision << "_";
119+
result << "DP=" << decompression_precision << "_";
120+
result << "SP=" << scale_precision << "_";
121+
result << "DM=" << decompression_multiply_type << "_";
122+
result << "DS=" << decompression_subtract_type << "_";
123+
result << "RD=" << reshape_on_decompression << "_";
124+
result << "GS=" << decompression_group_size << "_";
125+
result << "use_matmul_decompression_impl=" << use_matmul_decompression_impl << "_";
126+
return result.str();
127+
}
93128

94129
protected:
95130
void SetUp() override {
@@ -100,12 +135,12 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
100135
weights_precision,
101136
decompression_precision,
102137
scale_precision,
103-
use_weight_decompression,
104138
decompression_multiply_type,
105139
decompression_subtract_type,
106140
reshape_on_decompression,
107141
decompression_group_size,
108-
additional_config] = GetParam();
142+
additional_config,
143+
use_matmul_decompression_impl] = GetParam();
109144

110145
configuration.insert(additional_config.begin(), additional_config.end());
111146
init_input_shapes({shape_params.data_shape});
@@ -115,9 +150,9 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
115150
function = initMoE2GeMMSubgraph(shape_params,
116151
ov::element::f32,
117152
weights_precision,
153+
true,
118154
decompression_precision,
119-
ov::element::f32,
120-
use_weight_decompression,
155+
scale_precision,
121156
decompression_multiply_type,
122157
decompression_subtract_type,
123158
reshape_on_decompression,
@@ -126,9 +161,9 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
126161
function = initMoE3GeMMSubgraph(shape_params,
127162
ov::element::f32,
128163
weights_precision,
164+
true,
129165
decompression_precision,
130-
ov::element::f32,
131-
use_weight_decompression,
166+
scale_precision,
132167
decompression_multiply_type,
133168
decompression_subtract_type,
134169
reshape_on_decompression,
@@ -137,14 +172,42 @@ class MoeSubgraphTest : public testing::WithParamInterface<MoeTestParams>,
137172
OPENVINO_THROW("Unsupported MoEType");
138173
}
139174
}
175+
176+
void check_results() {
177+
const auto& test_param = GetParam();
178+
const ov::element::Type compressed_weights_precision = std::get<2>(test_param);
179+
const bool use_matmul_decompression_impl = std::get<10>(test_param);
180+
181+
const auto runtime_model = compiledModel.get_runtime_model();
182+
const auto result = runtime_model->get_result();
183+
auto batch_gather_mm = result->get_input_node_shared_ptr(0);
184+
185+
auto type = batch_gather_mm->get_rt_info().at(ov::exec_model_info::LAYER_TYPE).as<std::string>();
186+
if (type == "Reorder" || type == "Convert" || type == "Subgraph")
187+
batch_gather_mm = batch_gather_mm->get_input_node_shared_ptr(0);
188+
189+
type = batch_gather_mm->get_rt_info().at(ov::exec_model_info::LAYER_TYPE).as<std::string>();
190+
EXPECT_EQ(type, "BatchGatherMatmul");
191+
192+
const auto& expected_weights_precision =
193+
use_matmul_decompression_impl ? compressed_weights_precision : batch_gather_mm->get_input_element_type(0);
194+
EXPECT_EQ(batch_gather_mm->get_input_element_type(1), expected_weights_precision);
195+
}
140196
};
141197

198+
TEST_P(MoESubgraphTest, CompareWithRefs) {
199+
SKIP_IF_CURRENT_TEST_IS_DISABLED()
200+
run();
201+
}
202+
203+
TEST_P(MoECompressedWeightsSubgraphTest, CompareWithRefs) {
204+
SKIP_IF_CURRENT_TEST_IS_DISABLED()
205+
run();
206+
check_results();
207+
}
208+
142209
namespace {
143-
// Test parameter generation
144-
const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32};
145-
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::u4, ov::element::i4};
146210
const std::vector<MoEType> moe_types = {MoEType::MoE2GeMM, MoEType::MoE3GeMM};
147-
148211
const std::vector<MoePatternParams> moe_params = {
149212
{
150213
{{-1, -1, 2048}, {{2, 15, 2048}, {2, 1, 2048}, {3, 8, 2048}}}, // data_shape,
@@ -165,41 +228,36 @@ const ov::AnyMap additional_config_bf16 = {{ov::hint::inference_precision.name()
165228

166229
} // namespace
167230

168-
// Basic FP32 tests
169-
INSTANTIATE_TEST_SUITE_P(smoke_MoeSubgraph_basic,
170-
MoeSubgraphTest,
231+
INSTANTIATE_TEST_SUITE_P(smoke_MoESubgraph_basic,
232+
MoESubgraphTest,
171233
::testing::Combine(::testing::ValuesIn(moe_params),
172234
::testing::ValuesIn(moe_types),
173-
::testing::Values(ov::element::f32),
174-
::testing::Values(ov::element::f32),
175-
::testing::Values(ov::element::f32),
176-
::testing::Values(false),
177-
::testing::Values(DecompressionType::full),
178-
::testing::Values(DecompressionType::full),
179-
::testing::Values(false),
180-
::testing::Values(0),
181235
::testing::Values(additional_config_basic)),
182-
MoeSubgraphTest::getTestCaseName);
236+
MoESubgraphTest::getTestCaseName);
183237

184-
// BF16 inference precision tests
185-
INSTANTIATE_TEST_SUITE_P(smoke_MoeSubgraph_bf16,
186-
MoeSubgraphTest,
238+
INSTANTIATE_TEST_SUITE_P(smoke_MoESubgraph_bf16,
239+
MoESubgraphTest,
187240
::testing::Combine(::testing::ValuesIn(moe_params),
188241
::testing::ValuesIn(moe_types),
242+
::testing::Values(additional_config_bf16)),
243+
MoESubgraphTest::getTestCaseName);
244+
245+
const std::vector<ov::test::ElementType> decompression_precisions = {ov::element::f32};
246+
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8, ov::element::u4, ov::element::i4};
247+
248+
INSTANTIATE_TEST_SUITE_P(smoke_MoeCompressedWeights,
249+
MoECompressedWeightsSubgraphTest,
250+
::testing::Combine(::testing::ValuesIn(moe_params),
251+
::testing::ValuesIn(moe_types),
252+
::testing::ValuesIn(weights_precisions),
253+
::testing::ValuesIn(decompression_precisions),
189254
::testing::Values(ov::element::f32),
190-
::testing::Values(ov::element::f32),
191-
::testing::Values(ov::element::f32),
192-
::testing::Values(false),
193255
::testing::Values(DecompressionType::full),
194256
::testing::Values(DecompressionType::full),
195-
::testing::Values(false),
196-
::testing::Values(0),
197-
::testing::Values(additional_config_bf16)),
198-
MoeSubgraphTest::getTestCaseName);
199-
200-
TEST_P(MoeSubgraphTest, CompareWithRefs) {
201-
SKIP_IF_CURRENT_TEST_IS_DISABLED()
202-
run();
203-
}
257+
::testing::Values(false), // reshape on decompression
258+
::testing::Values(16), // decompression group size
259+
::testing::Values(additional_config_basic),
260+
::testing::Values(true)), // use_matmul_decompression_impl
261+
MoECompressedWeightsSubgraphTest::getTestCaseName);
204262

205263
} // namespace ov::test

src/tests/functional/plugin/shared/include/shared_test_classes/subgraph/moe_builders.hpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <iostream>
88
#include <memory>
9+
#include <optional>
910
#include "openvino/core/model.hpp"
1011
#include "openvino/core/type/element_type.hpp"
1112
#include "shared_test_classes/base/ov_subgraph.hpp"
@@ -24,24 +25,24 @@ struct MoePatternParams {
2425
std::shared_ptr<ov::Model> initMoE2GeMMSubgraph(const MoePatternParams& moe_params,
2526
const ov::element::Type data_precision,
2627
const ov::element::Type weights_precision,
27-
const ov::element::Type decompression_precision,
28-
const ov::element::Type scale_precision,
29-
const bool use_weight_decompression,
30-
const DecompressionType decompression_multiply_type,
31-
const DecompressionType decompression_subtract_type,
32-
const bool reshape_on_decompression,
33-
const int decompression_group_size);
28+
const bool use_weight_decompression = false,
29+
const std::optional<ov::element::Type> decompression_precision = std::nullopt,
30+
const std::optional<ov::element::Type> scale_precision = std::nullopt,
31+
const std::optional<DecompressionType> decompression_multiply_type = std::nullopt,
32+
const std::optional<DecompressionType> decompression_subtract_type = std::nullopt,
33+
const std::optional<bool> reshape_on_decompression = std::nullopt,
34+
const std::optional<int> decompression_group_size = std::nullopt);
3435

3536
std::shared_ptr<ov::Model> initMoE3GeMMSubgraph(const MoePatternParams& moe_params,
3637
const ov::element::Type data_precision,
3738
const ov::element::Type weights_precision,
38-
const ov::element::Type decompression_precision,
39-
const ov::element::Type scale_precision,
40-
const bool use_weight_decompression,
41-
const DecompressionType decompression_multiply_type,
42-
const DecompressionType decompression_subtract_type,
43-
const bool reshape_on_decompression,
44-
const int decompression_group_size);
39+
const bool use_weight_decompression = false,
40+
const std::optional<ov::element::Type> decompression_precision = std::nullopt,
41+
const std::optional<ov::element::Type> scale_precision = std::nullopt,
42+
const std::optional<DecompressionType> decompression_multiply_type = std::nullopt,
43+
const std::optional<DecompressionType> decompression_subtract_type = std::nullopt,
44+
const std::optional<bool> reshape_on_decompression = std::nullopt,
45+
const std::optional<int> decompression_group_size = std::nullopt);
4546

4647
} // namespace test
4748
} // namespace ov

0 commit comments

Comments
 (0)