
Commit 6549c75

Completes OPEN-3481 Introduce support for baseline models on the Python API
1 parent 76cb18f commit 6549c75

File tree

6 files changed: +245 -33 lines changed

docs/source/reference/upload.rst

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ Add models and datasets
    OpenlayerClient.add_model
    OpenlayerClient.add_dataset
    OpenlayerClient.add_dataframe
+   OpenlayerClient.add_baseline_model

 Version control flow
 --------------------

openlayer/__init__.py

Lines changed: 109 additions & 18 deletions
@@ -16,6 +16,7 @@
 from .version import __version__  # noqa: F401

 OPENLAYER_DIR = os.path.join(os.path.expanduser("~"), ".openlayer")
+VALID_RESOURCE_NAMES = {"baseline-model", "model", "training", "validation"}


 class OpenlayerClient(object):
@@ -441,6 +442,85 @@ def add_model(
             force=force,
         )

+    def add_baseline_model(
+        self,
+        project_id: int,
+        task_type: TaskType,
+        model_config_file_path: Optional[str] = None,
+        force: bool = False,
+    ):
+        """
+        **Coming soon...**
+
+        Add a baseline model to the project.
+
+        Baseline models should be added together with training and validation
+        sets. A model will then be trained on the platform using AutoML, using
+        the parameters provided in the model config file.
+
+        .. important::
+            This feature is experimental and currently under development. Only
+            tabular classification tasks are supported for now.
+
+        Parameters
+        ----------
+        model_config_file_path : str, optional
+            Path to the model configuration YAML file. If not provided, the default
+            model config will be used.
+
+            .. admonition:: What's in the model config file?
+
+                For baseline models, the YAML file may contain:
+
+                - ``ensembleSize`` : int, default 10
+                    Number of models in the ensemble.
+                - ``randomSeed`` : int, default 42
+                    Random seed to be used for model training.
+                - ``timeout`` : int, default 60
+                    Maximum time (in seconds) to train all the models.
+                - ``perRunLimit`` : int, optional
+                    Maximum time (in seconds) to train each model.
+                - ``metadata`` : Dict[str, any], default {}
+                    Dictionary containing metadata about the model. This is the
+                    metadata that will be displayed on the Openlayer platform.
+        force : bool, optional
+            Whether to force the addition of the baseline model to the project.
+            If set to True, any existing staged baseline model will be overwritten.
+        """
+        if task_type is not TaskType.TabularClassification:
+            raise exceptions.OpenlayerException(
+                "Only tabular classification is supported for baseline models for now."
+            )
+
+        # Validate the baseline model
+        baseline_model_validator = validators.BaselineModelValidator(
+            model_config_file_path=model_config_file_path,
+        )
+        failed_validations = baseline_model_validator.validate()
+
+        if failed_validations:
+            raise exceptions.OpenlayerValidationError(
+                "There are issues with the baseline model. \n"
+                "Make sure to fix all of the issues listed above before the upload.",
+            ) from None
+
+        # Load model config and augment with defaults
+        model_config = {}
+        if model_config_file_path is not None:
+            model_config = utils.read_yaml(model_config_file_path)
+        model_data = schemas.BaselineModelSchema().load(model_config)
+
+        # Copy relevant resources to temp directory
+        with tempfile.TemporaryDirectory() as temp_dir:
+            utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")
+
+            self._stage_resource(
+                resource_name="baseline-model",
+                resource_dir=temp_dir,
+                project_id=project_id,
+                force=force,
+            )
+
     def add_dataset(
         self,
         file_path: str,
@@ -1034,7 +1114,6 @@ def status(self, project_id: int):
            :obj:`commit` method).
         """
         project_dir = f"{OPENLAYER_DIR}/{project_id}/staging"
-        valid_resource_names = ["model", "training", "validation"]

         if not os.listdir(project_dir):
             print(
@@ -1046,7 +1125,7 @@ def status(self, project_id: int):
         if not os.path.exists(f"{project_dir}/commit.yaml"):
             print("The following resources are staged, waiting to be committed:")
             for file in os.listdir(project_dir):
-                if file in valid_resource_names:
+                if file in VALID_RESOURCE_NAMES:
                     print(f"\t - {file}")
             print("Use the `commit` method to add a commit message to your changes.")
             return
@@ -1055,7 +1134,7 @@ def status(self, project_id: int):
            commit = yaml.safe_load(commit_file)
         print("The following resources are committed, waiting to be pushed:")
         for file in os.listdir(project_dir):
-            if file != "commit.yaml":
+            if file in VALID_RESOURCE_NAMES:
                 print(f"\t - {file}")
         print(f"Commit message from {commit['date']}:")
         print(f"\t {commit['message']}")
@@ -1128,31 +1207,43 @@ def _stage_resource(
         force : bool
             Whether to overwrite the resource if it already exists in the staging area.
         """
-        if resource_name not in ["model", "training", "validation"]:
+        if resource_name not in VALID_RESOURCE_NAMES:
             raise ValueError(
-                f"Resource name must be one of 'model', 'training', or 'validation',"
-                f" but got {resource_name}."
+                f"Resource name must be one of 'baseline-model', 'model', 'training', or 'validation',"
+                f" but got '{resource_name}'."
             )

-        staging_dir = f"{OPENLAYER_DIR}/{project_id}/staging/{resource_name}"
+        project_dir = f"{OPENLAYER_DIR}/{project_id}/staging"
+
+        resources_staged = utils.list_resources_in_bundle(project_dir)

-        # Append 'dataset' to the end of the resource name for the prints
-        if resource_name in ["training", "validation"]:
-            resource_name += " dataset"
+        if resource_name == "model" and "baseline-model" in resources_staged:
+            raise exceptions.OpenlayerException(
+                "Trying to stage a `model` when there is a `baseline-model` already staged."
+                + " You can either add a `model` or a `baseline-model`, but not both at the same time."
+                + " Please remove one of them from the staging area using the `restore` method."
+            ) from None

-        if os.path.exists(staging_dir):
-            print(f"Found an existing {resource_name} staged.")
-            overwrite = "n"
+        if resource_name == "baseline-model" and "model" in resources_staged:
+            raise exceptions.OpenlayerException(
+                "Trying to stage a `baseline-model` when there is a `model` already staged."
+                + " You can either add a `model` or a `baseline-model`, but not both at the same time."
+                + " Please remove one of them from the staging area using the `restore` method."
+            ) from None
+
+        if resource_name in resources_staged:
+            print(f"Found an existing `{resource_name}` resource staged.")

+            overwrite = "n"
             if not force:
                 overwrite = input("Do you want to overwrite it? [y/n] ")
             if overwrite.lower() == "y" or force:
-                print(f"Overwriting previously staged {resource_name}...")
-                shutil.rmtree(staging_dir)
+                print(f"Overwriting previously staged `{resource_name}` resource...")
+                shutil.rmtree(project_dir + "/" + resource_name)
             else:
-                print(f"Keeping the existing {resource_name} staged.")
+                print(f"Keeping the existing `{resource_name}` resource staged.")
                 return

-        shutil.copytree(resource_dir, staging_dir)
+        shutil.copytree(resource_dir, project_dir + "/" + resource_name)

-        print(f"Staged the {resource_name}!")
+        print(f"Staged the `{resource_name}` resource!")

openlayer/projects.py

Lines changed: 9 additions & 0 deletions
@@ -43,6 +43,15 @@ def add_model(
             *args, project_id=self.id, task_type=tasks.TaskType(self.taskType), **kwargs
         )

+    def add_baseline_model(
+        self,
+        *args,
+        **kwargs,
+    ):
+        return self.client.add_baseline_model(
+            *args, project_id=self.id, task_type=tasks.TaskType(self.taskType), **kwargs
+        )
+
     def add_dataset(
         self,
         *args,
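The `Project` wrapper injects `project_id` and `task_type` from the project itself, so the call reduces to the config path. How the project object is obtained is an assumption here (for example via the client's project loading/creation helpers), not something this diff shows.

# Hypothetical sketch of the Project-level convenience call; obtaining `project` is assumed.
project = client.load_project(name="Churn prediction")  # assumed helper, not part of this diff

# project_id and task_type are forwarded automatically by the wrapper.
project.add_baseline_model(model_config_file_path="model_config.yaml")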

openlayer/schemas.py

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,16 @@
 )

 # ---------------------------------- Schemas --------------------------------- #
+class BaselineModelSchema(ma.Schema):
+    """Schema for baseline models."""
+
+    ensembleSize = ma.fields.Int(load_default=10)
+    metadata = ma.fields.Dict(allow_none=True, load_default={})
+    perRunLimit = ma.fields.Int(load_default=None, allow_none=True)
+    randomSeed = ma.fields.Int(load_default=42)
+    timeout = ma.fields.Int(load_default=60)
+
+
 class CommitSchema(ma.Schema):
     """Schema for commits."""
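A quick sketch of what the schema does on load: missing keys are filled with the `load_default` values, which is how `add_baseline_model` builds the config it stages. The partial config below is illustrative.

# Illustrative sketch of BaselineModelSchema filling in defaults via marshmallow's load_default.
from openlayer import schemas

model_data = schemas.BaselineModelSchema().load({"timeout": 120})
print(model_data)
# Roughly: {'ensembleSize': 10, 'metadata': {}, 'perRunLimit': None, 'randomSeed': 42, 'timeout': 120}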

openlayer/utils.py

Lines changed: 20 additions & 0 deletions
@@ -89,3 +89,23 @@ def get_exception_stacktrace(err: Exception):
        str: the stacktrace of the most recent exception.
    """
    return "".join(traceback.format_exception(type(err), err, err.__traceback__))
+
+
+def list_resources_in_bundle(bundle_path: str) -> list:
+    """Lists the resources in the bundle.
+
+    Args:
+        bundle_path (str): the path to the bundle.
+
+    Returns:
+        list: the list of resources in the bundle.
+    """
+    # TODO: factor out list of valid resources
+    VALID_RESOURCES = {"baseline-model", "model", "training", "validation"}
+
+    resources = []
+
+    for resource in os.listdir(bundle_path):
+        if resource in VALID_RESOURCES:
+            resources.append(resource)
+    return resources
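A small sketch of the new helper's behavior: only recognized resource directories are returned, so files such as `commit.yaml` never show up in the list. The directory layout below is made up for the example.

# Illustrative sketch of list_resources_in_bundle; the staging layout here is fabricated.
import os
import tempfile

from openlayer import utils

with tempfile.TemporaryDirectory() as bundle_path:
    for name in ("training", "validation", "baseline-model"):
        os.mkdir(os.path.join(bundle_path, name))
    open(os.path.join(bundle_path, "commit.yaml"), "w").close()

    # commit.yaml is ignored; only valid resource names are returned.
    print(sorted(utils.list_resources_in_bundle(bundle_path)))
    # ['baseline-model', 'training', 'validation']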

openlayer/validators.py

Lines changed: 96 additions & 15 deletions
@@ -24,6 +24,68 @@
 from . import schemas, utils


+class BaselineModelValidator:
+    """Validates the baseline model.
+
+    Parameters
+    ----------
+    model_config_file_path : Optional[str], optional
+        The path to the model config file, by default None
+    """
+
+    def __init__(self, model_config_file_path: Optional[str] = None):
+        self.model_config_file_path = model_config_file_path
+
+    def _validate_model_config(self):
+        """Validates the model config file."""
+        model_config_failed_validations = []
+
+        # File existence check
+        if self.model_config_file_path:
+            if not os.path.isfile(os.path.expanduser(self.model_config_file_path)):
+                model_config_failed_validations.append(
+                    f"File `{self.model_config_file_path}` does not exist."
+                )
+            else:
+                with open(self.model_config_file_path, "r") as stream:
+                    self.model_config = yaml.safe_load(stream)
+
+                if self.model_config:
+                    baseline_model_schema = schemas.BaselineModelSchema()
+                    try:
+                        baseline_model_schema.load(self.model_config)
+                    except ma.ValidationError as err:
+                        model_config_failed_validations.extend(
+                            _format_marshmallow_error_message(err)
+                        )
+
+        # Print results of the validation
+        if model_config_failed_validations:
+            print("Baseline model config failed validations: \n")
+            _list_failed_validation_messages(model_config_failed_validations)
+
+        # Add the `model_config.yaml` failed validations to the list of all failed validations
+        self.failed_validations.extend(model_config_failed_validations)
+
+    def validate(self) -> List[str]:
+        """Validates the baseline model.
+
+        Returns
+        -------
+        List[str]
+            The list of failed validations.
+        """
+        self.failed_validations = []
+
+        if self.model_config_file_path:
+            self._validate_model_config()
+
+        if not self.failed_validations:
+            print("All baseline model validations passed!")
+
+        return self.failed_validations
+
+
 class CommitBundleValidator:
     """Validates the commit bundle prior to push.

@@ -44,7 +106,7 @@ def __init__(
         skip_dataset_validation: bool = False,
     ):
         self.bundle_path = bundle_path
-        self._bundle_resources = self._list_resources_in_bundle()
+        self._bundle_resources = utils.list_resources_in_bundle(bundle_path)
         self._skip_model_validation = skip_model_validation
         self._skip_dataset_validation = skip_dataset_validation
         self.failed_validations = []
@@ -55,8 +117,10 @@ def _validate_bundle_state(self):
        This includes:
        - When a "model" is included, you always need to provide predictions for both
          "validation" and "training" (regardless of artifact or no artifact).
-        - When a "model" is not included, you always need to NOT upload predictions with
-          one exception:
+        - When a "baseline-model" is included, you always need to provide a "training"
+          and a "validation" set without predictions.
+        - When neither a "model" nor a "baseline-model" is included, you always need to
+          NOT upload predictions, with one exception:
            - "validation" set only in bundle, which means the predictions are for the
              previous model version.
        """
@@ -95,6 +159,24 @@ def _validate_bundle_state(self):
                    "training and a validation sets with predictions in the column "
                    "`predictions_column_name`."
                )
+        elif "baseline-model" in self._bundle_resources:
+            if (
+                "training" not in self._bundle_resources
+                or "validation" not in self._bundle_resources
+            ):
+                bundle_state_failed_validations.append(
+                    "To push a baseline model to the platform, you must provide "
+                    "training and validation sets."
+                )
+            elif (
+                training_predictions_column_name is not None
+                and validation_predictions_column_name is not None
+            ):
+                bundle_state_failed_validations.append(
+                    "To push a baseline model to the platform, you must provide "
+                    "training and validation sets without predictions in the column "
+                    "`predictions_column_name`."
+                )
         else:
             if (
                 "training" in self._bundle_resources
@@ -142,6 +224,17 @@ def _validate_bundle_resources(self):
                validation_set_validator.validate()
            )

+        if (
+            "baseline-model" in self._bundle_resources
+            and not self._skip_model_validation
+        ):
+            baseline_model_validator = BaselineModelValidator(
+                model_config_file_path=f"{self.bundle_path}/baseline-model/model_config.yaml"
+            )
+            bundle_resources_failed_validations.extend(
+                baseline_model_validator.validate()
+            )
+
         if "model" in self._bundle_resources and not self._skip_model_validation:
             model_files = os.listdir(f"{self.bundle_path}/model")
             # Shell model
@@ -183,18 +276,6 @@ def _validate_bundle_resources(self):
        # Add the bundle resources failed validations to the list of all failed validations
        self.failed_validations.extend(bundle_resources_failed_validations)

-    def _list_resources_in_bundle(self) -> List[str]:
-        """Lists the resources in a commit bundle."""
-        # TODO: factor out list of valid resources
-        VALID_RESOURCES = ["model", "training", "validation"]
-
-        resources = []
-
-        for resource in os.listdir(self.bundle_path):
-            if resource in VALID_RESOURCES:
-                resources.append(resource)
-        return resources
-
     def _load_dataset_from_bundle(self, label: str) -> pd.DataFrame:
         """Loads a dataset from a commit bundle.
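To close, a hedged sketch of running the new validator on its own, outside of `add_baseline_model`; the config written here is illustrative.

# Sketch of using BaselineModelValidator directly; the config content is illustrative.
import tempfile

from openlayer import validators

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as config_file:
    config_file.write("ensembleSize: 5\ntimeout: 120\n")

failed = validators.BaselineModelValidator(
    model_config_file_path=config_file.name
).validate()

# An empty list means the config passed all checks; otherwise each failed validation
# is printed by the validator and also returned here.
print(failed)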
