Skip to content

Commit 88e68ef

Browse files
committed
divide configs for model types
1 parent 90aaad9 commit 88e68ef

File tree

5 files changed

+66
-78
lines changed

5 files changed

+66
-78
lines changed

src/plugins/intel_npu/tests/functional/behavior/npuw/test_engine/models/model_builder.cpp

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ void ModelBuilder::clear() {
13761376
m_name_idx = 0;
13771377
}
13781378

1379-
ov::Output<ov::Node> ModelBuilder::setup_position_ids(ModelConfig& config, const ov::Output<ov::Node>& seq_source) {
1379+
ov::Output<ov::Node> ModelBuilder::setup_position_ids(LLMConfig& config, const ov::Output<ov::Node>& seq_source) {
13801380
OPENVINO_ASSERT(!(config.internal_position_ids && config.position_ids.get_node()),
13811381
"internal_position_ids and position_ids are mutually exclusive");
13821382
ov::Output<ov::Node> position_ids_output;
@@ -1418,36 +1418,11 @@ std::shared_ptr<ov::Model> ModelBuilder::make_model(const ov::Output<ov::Node>&
14181418
return std::make_shared<ov::Model>(ov::OutputVector{res->output(0)}, m_sinks, model_name);
14191419
}
14201420

1421-
std::shared_ptr<ov::Model> ModelBuilder::build_model(const ModelConfig& config_in) {
1422-
OPENVINO_ASSERT(
1423-
static_cast<int>(config_in.use_conv_features) + static_cast<int>(config_in.use_cross_attention) + static_cast<int>(config_in.use_token_type_embedding) <= 1,
1424-
"At most one structural dispatch flag may be set");
1425-
1426-
// Fill in norm/ffn defaults from actual config sizes when the caller left them empty.
1427-
ModelConfig config = config_in;
1428-
if (!config.norm) {
1429-
config.norm = LayerNorm(config.hidden_size, config.precision);
1430-
}
1431-
if (!config.ffn) {
1432-
config.ffn = SwiGLU(config.hidden_size, config.intermediate_size, config.precision, config.weight);
1433-
}
1434-
1435-
if (config.use_conv_features) {
1436-
return build_whisper_encoder(config);
1437-
}
1438-
if (config.use_cross_attention) {
1439-
return build_whisper_decoder(config);
1440-
}
1441-
if (config.use_token_type_embedding) {
1442-
return build_embedding_encoder(config);
1443-
}
1444-
return build_llm(config);
1445-
}
1446-
1447-
std::shared_ptr<ov::Model> ModelBuilder::build_llm(const ModelConfig& config_in) {
1421+
std::shared_ptr<ov::Model> ModelBuilder::build_llm(const LLMConfig& config_in) {
14481422
clear();
14491423

1450-
ModelConfig config = config_in;
1424+
LLMConfig config = config_in;
1425+
config.finalize_defaults();
14511426
const auto prec = config.precision;
14521427

14531428
auto attention_mask = parameter(ov::element::i64, ov::PartialShape{-1, -1}, "attention_mask");
@@ -1565,8 +1540,10 @@ std::shared_ptr<ov::Model> ModelBuilder::build_llm(const ModelConfig& config_in)
15651540
return make_model(final_norm, "last_hidden_state", model_name);
15661541
}
15671542

1568-
std::shared_ptr<ov::Model> ModelBuilder::build_whisper_encoder(const ModelConfig& config) {
1543+
std::shared_ptr<ov::Model> ModelBuilder::build_whisper_encoder(const WhisperEncoderConfig& config_in) {
15691544
clear();
1545+
WhisperEncoderConfig config = config_in;
1546+
config.finalize_defaults();
15701547
const auto prec = config.precision;
15711548
const auto d = config.hidden_size;
15721549

@@ -1813,8 +1790,10 @@ static ov::Output<ov::Node> make_whisper_positional_embedding(const ov::Output<o
18131790
return hidden_states->output(0);
18141791
}
18151792

1816-
std::shared_ptr<ov::Model> ModelBuilder::build_whisper_decoder(const ModelConfig& config) {
1793+
std::shared_ptr<ov::Model> ModelBuilder::build_whisper_decoder(const WhisperDecoderConfig& config_in) {
18171794
clear();
1795+
WhisperDecoderConfig config = config_in;
1796+
config.finalize_defaults();
18181797
const auto prec = config.precision;
18191798
const auto d = config.hidden_size;
18201799
const auto heads = config.num_heads;
@@ -1948,8 +1927,10 @@ std::shared_ptr<ov::Model> ModelBuilder::build_whisper_decoder(const ModelConfig
19481927
return make_model(logits_out, "logits", "synthetic_whisper_decoder");
19491928
}
19501929

1951-
std::shared_ptr<ov::Model> ModelBuilder::build_embedding_encoder(const ModelConfig& config) {
1930+
std::shared_ptr<ov::Model> ModelBuilder::build_embedding_encoder(const BertConfig& config_in) {
19521931
clear();
1932+
BertConfig config = config_in;
1933+
config.finalize_defaults();
19531934

19541935
const auto prec = config.precision;
19551936
const auto hs = config.hidden_size;

src/plugins/intel_npu/tests/functional/behavior/npuw/test_engine/models/model_builder.hpp

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,8 @@ ov::Output<ov::Node> make_post_norm_layer(const ov::Output<ov::Node>& input,
367367
return normed2;
368368
}
369369

370-
/// Unified config for all model types. build_model() dispatches on structural flags.
371-
/// NOTE: weight MUST be declared before lm_head_weight/norm/ffn (C++ member init order).
372-
struct ModelConfig {
370+
struct BaseModelConfig {
371+
// Common parameters
373372
size_t hidden_size = 64;
374373
size_t num_heads = 4;
375374
size_t head_dim = 16;
@@ -378,16 +377,6 @@ struct ModelConfig {
378377
size_t vocab_size = 1000;
379378
size_t num_layers = 10;
380379

381-
bool use_kv_cache = true;
382-
bool use_inputs_embeds = false;
383-
bool internal_position_ids = false;
384-
385-
// Structural flags — build_model() dispatches on these
386-
bool use_conv_features = false;
387-
bool use_cross_attention = false;
388-
bool use_token_type_embedding = false;
389-
bool pre_norm = true;
390-
391380
ov::element::Type precision = ov::element::f32;
392381

393382
WeightFn weight = FP32Weight{};
@@ -400,32 +389,56 @@ struct ModelConfig {
400389
ov::Output<ov::Node> position_ids; ///< Empty = auto-creates 2D Parameter + HalfRotationRoPE
401390
NormFn qk_norm;
402391

403-
// Whisper-specific
404-
size_t encoder_layers = 0; ///< 0 = use num_layers
405-
size_t decoder_layers = 0; ///< 0 = use num_layers
406-
size_t num_mel_bins = 80;
407-
size_t max_source_positions = 1500;
408-
size_t max_target_positions = 448;
392+
BaseModelConfig() : lm_head_weight(weight) {}
409393

410-
// BERT/Encoder-specific
411-
size_t max_position_embeddings = 512;
412-
size_t type_vocab_size = 2;
394+
virtual ~BaseModelConfig() = default;
413395

414-
ModelConfig() : lm_head_weight(weight) {}
396+
/// Fill in norm/ffn defaults from actual config sizes when the caller left them empty.
397+
void finalize_defaults() {
398+
if (!norm) {
399+
norm = LayerNorm(hidden_size, precision);
400+
}
401+
if (!ffn) {
402+
ffn = SwiGLU(hidden_size, intermediate_size, precision, weight);
403+
}
404+
}
415405

416406
size_t get_kv_heads() const {
417407
return num_kv_heads == 0 ? num_heads : num_kv_heads;
418408
}
409+
};
410+
411+
struct LLMConfig : public BaseModelConfig {
412+
bool use_kv_cache = true;
413+
bool use_inputs_embeds = false;
414+
bool internal_position_ids = false; ///< embedding model
415+
bool pre_norm = true;
416+
};
417+
418+
struct WhisperEncoderConfig : public BaseModelConfig {
419+
size_t encoder_layers = 0;
420+
size_t num_mel_bins = 80;
421+
size_t max_source_positions = 1500;
419422

420423
size_t get_encoder_layers() const {
421424
return encoder_layers == 0 ? num_layers : encoder_layers;
422425
}
426+
};
427+
428+
struct WhisperDecoderConfig : public BaseModelConfig {
429+
size_t decoder_layers = 0;
430+
size_t max_target_positions = 448;
423431

424432
size_t get_decoder_layers() const {
425433
return decoder_layers == 0 ? num_layers : decoder_layers;
426434
}
427435
};
428436

437+
struct BertConfig : public BaseModelConfig {
438+
size_t max_position_embeddings = 512;
439+
size_t type_vocab_size = 2;
440+
};
441+
429442
class ModelBuilder {
430443
public:
431444
ModelBuilder() = default;
@@ -450,19 +463,16 @@ class ModelBuilder {
450463
const ov::PartialShape& shape,
451464
const std::string& name);
452465

453-
/// Unified entry point. Dispatches on config structural flags.
454-
std::shared_ptr<ov::Model> build_model(const ModelConfig& config);
466+
std::shared_ptr<ov::Model> build_llm(const LLMConfig& config);
467+
std::shared_ptr<ov::Model> build_whisper_encoder(const WhisperEncoderConfig& config);
468+
std::shared_ptr<ov::Model> build_whisper_decoder(const WhisperDecoderConfig& config);
469+
std::shared_ptr<ov::Model> build_embedding_encoder(const BertConfig& config);
455470

456471
void clear();
457472

458473
private:
459-
std::shared_ptr<ov::Model> build_llm(const ModelConfig& config);
460-
std::shared_ptr<ov::Model> build_whisper_encoder(const ModelConfig& config);
461-
std::shared_ptr<ov::Model> build_whisper_decoder(const ModelConfig& config);
462-
std::shared_ptr<ov::Model> build_embedding_encoder(const ModelConfig& config);
463-
464474
/// May auto-create HalfRotationRoPE on config.rope (hence non-const ref).
465-
ov::Output<ov::Node> setup_position_ids(ModelConfig& config, const ov::Output<ov::Node>& seq_source);
475+
ov::Output<ov::Node> setup_position_ids(LLMConfig& config, const ov::Output<ov::Node>& seq_source);
466476

467477
std::shared_ptr<ov::Model> make_model(const ov::Output<ov::Node>& output,
468478
const std::string& result_name,

src/plugins/intel_npu/tests/unit/npuw/llm_test_helpers.hpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919

2020
namespace ov::test::npuw {
2121

22-
inline ModelConfig make_llm_test_model_config() {
23-
ModelConfig cfg;
22+
template <typename Config = LLMConfig>
23+
inline Config make_test_model_config() {
24+
Config cfg;
2425
cfg.num_layers = 2;
2526
cfg.hidden_size = 64;
2627
cfg.num_heads = 4;
@@ -32,21 +33,17 @@ inline ModelConfig make_llm_test_model_config() {
3233

3334
inline std::shared_ptr<ov::Model> build_llm_test_model() {
3435
ModelBuilder mb;
35-
return mb.build_model(make_llm_test_model_config());
36+
return mb.build_llm(make_test_model_config());
3637
}
3738

3839
inline std::shared_ptr<ov::Model> build_whisper_decoder_test_model() {
39-
auto cfg = make_llm_test_model_config();
40-
cfg.use_cross_attention = true;
4140
ModelBuilder mb;
42-
return mb.build_model(cfg);
41+
return mb.build_whisper_decoder(make_test_model_config<WhisperDecoderConfig>());
4342
}
4443

4544
inline std::shared_ptr<ov::Model> build_embedding_test_model() {
46-
auto cfg = make_llm_test_model_config();
47-
cfg.use_token_type_embedding = true;
4845
ModelBuilder mb;
49-
return mb.build_model(cfg);
46+
return mb.build_embedding_encoder(make_test_model_config<BertConfig>());
5047
}
5148

5249
class NullPlugin : public ov::IPlugin {

src/plugins/intel_npu/tests/unit/npuw/online_partitioning.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "partitioning/online/snapshot.hpp"
1919

2020
using ov::test::npuw::ModelBuilder;
21-
using ov::test::npuw::ModelConfig;
21+
using ov::test::npuw::LLMConfig;
2222

2323
namespace {
2424

@@ -675,14 +675,14 @@ INSTANTIATE_TEST_SUITE_P(OnlinePartitioningTest,
675675
// always exposes its boundary Add in getInputs(), detects the mask mismatch, and
676676
// correctly sets irregular_io=true regardless of hash order.
677677
TEST(OnlinePartitioningTest, IsRegularParameterCase_PrefillModel_InputsEmbeds) {
678-
ModelConfig config;
678+
LLMConfig config;
679679
config.num_layers = 4;
680680
config.hidden_size = 64;
681681
config.use_inputs_embeds = true; // layer 0's residual Add reads inputs_embeds (ov::Parameter)
682682
config.use_kv_cache = true;
683683

684684
ModelBuilder mb;
685-
auto model = mb.build_model(config);
685+
auto model = mb.build_llm(config);
686686

687687
// Partitioning requires a static-shape stateless model, matching the real
688688
// production path in LLMCompiledModel before getPartitioning() is called.

src/plugins/intel_npu/tests/unit/npuw/partitioning_options_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "pyramid_attention.hpp"
2424

2525
using ov::test::npuw::ModelBuilder;
26-
using ov::test::npuw::ModelConfig;
26+
using ov::test::npuw::LLMConfig;
2727

2828
namespace {
2929

@@ -68,7 +68,7 @@ std::shared_ptr<ov::Model> build_unary_chain_model() {
6868
}
6969

7070
std::shared_ptr<ov::Model> build_static_llm_model(const int64_t query_len, const int64_t past_len) {
71-
ModelConfig config;
71+
LLMConfig config;
7272
config.num_layers = 4;
7373
config.hidden_size = 64;
7474
config.num_heads = 4;
@@ -77,7 +77,7 @@ std::shared_ptr<ov::Model> build_static_llm_model(const int64_t query_len, const
7777
config.vocab_size = 256;
7878

7979
ModelBuilder mb;
80-
auto model = mb.build_model(config);
80+
auto model = mb.build_llm(config);
8181

8282
ov::pass::StatefulToStateless().run_on_model(model);
8383
model = model->clone();

0 commit comments

Comments
 (0)