@@ -147,8 +147,8 @@ def _validate_bundle_state(self):
147
147
"""Checks whether the bundle is in a valid state.
148
148
149
149
This includes:
150
- - When a "model" is included, you always need to provide predictions for both
151
- "validation" and "training" (regardless of artifact or no artifact) .
150
+ - When a "model" (shell or full) is included, you always need to provide predictions for both
151
+ "validation" and "training".
152
152
- When a "baseline-model" is included, you always need to provide a "training"
153
153
and "validation" set without predictions.
154
154
- When a "model" nor a "baseline-model" are included, you always need to NOT
@@ -186,33 +186,35 @@ def _validate_bundle_state(self):
186
186
)
187
187
188
188
if "model" in self ._bundle_resources :
189
+ model_config = self ._load_model_config_from_bundle ()
190
+ model_type = model_config .get ("modelType" )
189
191
if (
190
192
training_predictions_column_name is None
191
193
or validation_predictions_column_name is None
192
- ):
194
+ ) and model_type != "baseline" :
193
195
bundle_state_failed_validations .append (
194
196
"To push a model to the platform, you must provide "
195
197
"training and a validation sets with predictions in the column "
196
198
"`predictions_column_name`."
197
199
)
198
- elif "baseline-model" in self . _bundle_resources :
199
- if (
200
- "training" not in self ._bundle_resources
201
- or "validation" not in self ._bundle_resources
202
- ):
203
- bundle_state_failed_validations .append (
204
- "To push a baseline model to the platform, you must provide "
205
- "training and validation sets."
206
- )
207
- elif (
208
- training_predictions_column_name is not None
209
- and validation_predictions_column_name is not None
210
- ):
211
- bundle_state_failed_validations .append (
212
- "To push a baseline model to the platform, you must not provide "
213
- "training and a validation sets without predictions in the column "
214
- "`predictions_column_name`."
215
- )
200
+ if model_type == "baseline" :
201
+ if (
202
+ "training" not in self ._bundle_resources
203
+ or "validation" not in self ._bundle_resources
204
+ ):
205
+ bundle_state_failed_validations .append (
206
+ "To push a baseline model to the platform, you must provide "
207
+ "training and validation sets."
208
+ )
209
+ elif (
210
+ training_predictions_column_name is not None
211
+ and validation_predictions_column_name is not None
212
+ ):
213
+ bundle_state_failed_validations .append (
214
+ "To push a baseline model to the platform, you must provide "
215
+ "training and validation sets without predictions in the column "
216
+ "`predictions_column_name`."
217
+ )
216
218
else :
217
219
if (
218
220
"training" in self ._bundle_resources
@@ -260,26 +262,15 @@ def _validate_bundle_resources(self):
260
262
validation_set_validator .validate ()
261
263
)
262
264
263
- if (
264
- "baseline-model" in self ._bundle_resources
265
- and not self ._skip_model_validation
266
- ):
267
- baseline_model_validator = BaselineModelValidator (
268
- model_config_file_path = f"{ self .bundle_path } /baseline-model/model_config.yaml"
269
- )
270
- bundle_resources_failed_validations .extend (
271
- baseline_model_validator .validate ()
272
- )
273
-
274
265
if "model" in self ._bundle_resources and not self ._skip_model_validation :
275
- model_files = os .listdir (f"{ self .bundle_path } /model" )
276
- # Shell model
277
- if len (model_files ) == 1 :
266
+ model_config_file_path = f"{ self .bundle_path } /model/model_config.yaml"
267
+ model_config = self ._load_model_config_from_bundle ()
268
+
269
+ if model_config ["modelType" ] == "shell" :
278
270
model_validator = ModelValidator (
279
- model_config_file_path = f" { self . bundle_path } /model/model_config.yaml"
271
+ model_config_file_path = model_config_file_path
280
272
)
281
- # Model package
282
- else :
273
+ elif model_config ["modelType" ] == "full" :
283
274
# Use data from the validation as test data
284
275
validation_dataset_df = self ._load_dataset_from_bundle ("validation" )
285
276
validation_dataset_config = self ._load_dataset_config_from_bundle (
@@ -298,12 +289,21 @@ def _validate_bundle_resources(self):
298
289
].head ()
299
290
300
291
model_validator = ModelValidator (
301
- model_config_file_path = f" { self . bundle_path } /model/model_config.yaml" ,
292
+ model_config_file_path = model_config_file_path ,
302
293
model_package_dir = f"{ self .bundle_path } /model" ,
303
294
sample_data = sample_data ,
304
295
use_runner = self ._use_runner ,
305
296
)
306
- bundle_resources_failed_validations .extend (model_validator .validate ())
297
+ elif model_config ["modelType" ] == "baseline" :
298
+ model_validator = BaselineModelValidator (
299
+ model_config_file_path = model_config_file_path
300
+ )
301
+ else :
302
+ raise ValueError (
303
+ f"Invalid model type: { model_config ['modelType' ]} . "
304
+ "The model type must be one of 'shell', 'full' or 'baseline'."
305
+ )
306
+ bundle_resources_failed_validations .extend (model_validator .validate ())
307
307
308
308
# Add the bundle resources failed validations to the list of all failed validations
309
309
self .failed_validations .extend (bundle_resources_failed_validations )
@@ -347,6 +347,21 @@ def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
347
347
348
348
return dataset_config
349
349
350
+ def _load_model_config_from_bundle (self ) -> Dict [str , Any ]:
351
+ """Loads a model config from a commit bundle.
352
+
353
+ Returns
354
+ -------
355
+ Dict[str, Any]
356
+ The model config.
357
+ """
358
+ model_config_file_path = f"{ self .bundle_path } /model/model_config.yaml"
359
+
360
+ with open (model_config_file_path , "r" , encoding = "UTF-8" ) as stream :
361
+ model_config = yaml .safe_load (stream )
362
+
363
+ return model_config
364
+
350
365
def validate (self ) -> List [str ]:
351
366
"""Validates the commit bundle.
352
367
0 commit comments