Skip to content

Commit 190ed2e

Browse files
gustavocidornelas authored and whoseoyster committed
Completes OPEN-3563 Add model property indicating if baseline / shell / non-shell and Closes OPEN-3616 Replace 'baseline-model' resource name in favor of common 'model' resource
1 parent 69d2d17 commit 190ed2e

File tree

3 files changed

+64
-59
lines changed

3 files changed

+64
-59
lines changed

openlayer/__init__.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .version import __version__ # noqa: F401
1717

1818
OPENLAYER_DIR = os.path.join(os.path.expanduser("~"), ".openlayer")
19-
VALID_RESOURCE_NAMES = {"baseline-model", "model", "training", "validation"}
19+
VALID_RESOURCE_NAMES = {"model", "training", "validation"}
2020

2121

2222
class OpenlayerClient(object):
@@ -411,6 +411,9 @@ def add_model(
411411
if model_package_dir:
412412
shutil.copytree(model_package_dir, temp_dir, dirs_exist_ok=True)
413413
utils.write_python_version(temp_dir)
414+
model_data["modelType"] = "full"
415+
else:
416+
model_data["modelType"] = "shell"
414417

415418
utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")
416419

@@ -479,14 +482,15 @@ def add_baseline_model(
479482
model_config = {}
480483
if model_config_file_path is not None:
481484
model_config = utils.read_yaml(model_config_file_path)
485+
model_config["modelType"] = "baseline"
482486
model_data = BaselineModelSchema().load(model_config)
483487

484488
# Copy relevant resources to temp directory
485489
with tempfile.TemporaryDirectory() as temp_dir:
486490
utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")
487491

488492
self._stage_resource(
489-
resource_name="baseline-model",
493+
resource_name="model",
490494
resource_dir=temp_dir,
491495
project_id=project_id,
492496
force=force,
@@ -1182,30 +1186,14 @@ def _stage_resource(
11821186
"""
11831187
if resource_name not in VALID_RESOURCE_NAMES:
11841188
raise ValueError(
1185-
"Resource name must be one of 'baseline-model', 'model', 'training', or"
1189+
"Resource name must be one of 'model', 'training', or"
11861190
f" 'validation', but got '{resource_name}'."
11871191
)
11881192

11891193
project_dir = f"{OPENLAYER_DIR}/{project_id}/staging"
11901194

11911195
resources_staged = utils.list_resources_in_bundle(project_dir)
11921196

1193-
if resource_name == "model" and "baseline-model" in resources_staged:
1194-
raise exceptions.OpenlayerException(
1195-
"Trying to stage a `model` when there is a `baseline-model` already staged."
1196-
+ " You can either add a `model` or a `baseline-model`, but not both at the"
1197-
+ " same time. Please remove one of them from the staging area using the"
1198-
+ " `restore` method."
1199-
) from None
1200-
1201-
if resource_name == "baseline-model" and "model" in resources_staged:
1202-
raise exceptions.OpenlayerException(
1203-
"Trying to stage a `baseline-model` when there is a `model` already staged."
1204-
+ " You can either add a `model` or a `baseline-model`, but not both at the"
1205-
+ " same time. Please remove one of them from the staging area using the"
1206-
+ " `restore` method."
1207-
) from None
1208-
12091197
if resource_name in resources_staged:
12101198
print(f"Found an existing `{resource_name}` resource staged.")
12111199

openlayer/schemas.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class BaselineModelSchema(ma.Schema):
2323
"""Schema for baseline models."""
2424

2525
metadata = ma.fields.Dict(allow_none=True, load_default={})
26+
modelType = ma.fields.Str()
2627

2728

2829
class CommitSchema(ma.Schema):
@@ -119,6 +120,7 @@ class ModelSchema(ma.Schema):
119120
allow_none=True,
120121
load_default={},
121122
)
123+
modelType = ma.fields.Str()
122124
architectureType = ma.fields.Str(
123125
validate=ma.validate.OneOf(
124126
[model_framework.value for model_framework in ModelType],

openlayer/validators.py

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ def _validate_bundle_state(self):
147147
"""Checks whether the bundle is in a valid state.
148148
149149
This includes:
150-
- When a "model" is included, you always need to provide predictions for both
151-
"validation" and "training" (regardless of artifact or no artifact).
150+
- When a "model" (shell or full) is included, you always need to provide predictions for both
151+
"validation" and "training".
152152
- When a "baseline-model" is included, you always need to provide a "training"
153153
and "validation" set without predictions.
154154
- When a "model" nor a "baseline-model" are included, you always need to NOT
@@ -186,33 +186,35 @@ def _validate_bundle_state(self):
186186
)
187187

188188
if "model" in self._bundle_resources:
189+
model_config = self._load_model_config_from_bundle()
190+
model_type = model_config.get("modelType")
189191
if (
190192
training_predictions_column_name is None
191193
or validation_predictions_column_name is None
192-
):
194+
) and model_type != "baseline":
193195
bundle_state_failed_validations.append(
194196
"To push a model to the platform, you must provide "
195197
"training and a validation sets with predictions in the column "
196198
"`predictions_column_name`."
197199
)
198-
elif "baseline-model" in self._bundle_resources:
199-
if (
200-
"training" not in self._bundle_resources
201-
or "validation" not in self._bundle_resources
202-
):
203-
bundle_state_failed_validations.append(
204-
"To push a baseline model to the platform, you must provide "
205-
"training and validation sets."
206-
)
207-
elif (
208-
training_predictions_column_name is not None
209-
and validation_predictions_column_name is not None
210-
):
211-
bundle_state_failed_validations.append(
212-
"To push a baseline model to the platform, you must not provide "
213-
"training and a validation sets without predictions in the column "
214-
"`predictions_column_name`."
215-
)
200+
if model_type == "baseline":
201+
if (
202+
"training" not in self._bundle_resources
203+
or "validation" not in self._bundle_resources
204+
):
205+
bundle_state_failed_validations.append(
206+
"To push a baseline model to the platform, you must provide "
207+
"training and validation sets."
208+
)
209+
elif (
210+
training_predictions_column_name is not None
211+
and validation_predictions_column_name is not None
212+
):
213+
bundle_state_failed_validations.append(
214+
"To push a baseline model to the platform, you must provide "
215+
"training and validation sets without predictions in the column "
216+
"`predictions_column_name`."
217+
)
216218
else:
217219
if (
218220
"training" in self._bundle_resources
@@ -260,26 +262,15 @@ def _validate_bundle_resources(self):
260262
validation_set_validator.validate()
261263
)
262264

263-
if (
264-
"baseline-model" in self._bundle_resources
265-
and not self._skip_model_validation
266-
):
267-
baseline_model_validator = BaselineModelValidator(
268-
model_config_file_path=f"{self.bundle_path}/baseline-model/model_config.yaml"
269-
)
270-
bundle_resources_failed_validations.extend(
271-
baseline_model_validator.validate()
272-
)
273-
274265
if "model" in self._bundle_resources and not self._skip_model_validation:
275-
model_files = os.listdir(f"{self.bundle_path}/model")
276-
# Shell model
277-
if len(model_files) == 1:
266+
model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
267+
model_config = self._load_model_config_from_bundle()
268+
269+
if model_config["modelType"] == "shell":
278270
model_validator = ModelValidator(
279-
model_config_file_path=f"{self.bundle_path}/model/model_config.yaml"
271+
model_config_file_path=model_config_file_path
280272
)
281-
# Model package
282-
else:
273+
elif model_config["modelType"] == "full":
283274
# Use data from the validation as test data
284275
validation_dataset_df = self._load_dataset_from_bundle("validation")
285276
validation_dataset_config = self._load_dataset_config_from_bundle(
@@ -298,12 +289,21 @@ def _validate_bundle_resources(self):
298289
].head()
299290

300291
model_validator = ModelValidator(
301-
model_config_file_path=f"{self.bundle_path}/model/model_config.yaml",
292+
model_config_file_path=model_config_file_path,
302293
model_package_dir=f"{self.bundle_path}/model",
303294
sample_data=sample_data,
304295
use_runner=self._use_runner,
305296
)
306-
bundle_resources_failed_validations.extend(model_validator.validate())
297+
elif model_config["modelType"] == "baseline":
298+
model_validator = BaselineModelValidator(
299+
model_config_file_path=model_config_file_path
300+
)
301+
else:
302+
raise ValueError(
303+
f"Invalid model type: {model_config['modelType']}. "
304+
"The model type must be one of 'shell', 'full' or 'baseline'."
305+
)
306+
bundle_resources_failed_validations.extend(model_validator.validate())
307307

308308
# Add the bundle resources failed validations to the list of all failed validations
309309
self.failed_validations.extend(bundle_resources_failed_validations)
@@ -347,6 +347,21 @@ def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
347347

348348
return dataset_config
349349

350+
def _load_model_config_from_bundle(self) -> Dict[str, Any]:
    """Loads a model config from a commit bundle.

    The config is read from ``<bundle_path>/model/model_config.yaml``.

    Returns
    -------
    Dict[str, Any]
        The model config. An empty dict is returned when the YAML file
        contains no document, so callers can always use mapping
        operations such as ``.get("modelType")`` on the result.
    """
    model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"

    with open(model_config_file_path, "r", encoding="UTF-8") as stream:
        model_config = yaml.safe_load(stream)

    # `yaml.safe_load` yields None for an empty file; normalise that to an
    # empty mapping to honour the declared return type.
    if model_config is None:
        return {}

    return model_config
364+
350365
def validate(self) -> List[str]:
351366
"""Validates the commit bundle.
352367

0 commit comments

Comments
 (0)