Skip to content

Commit 7b05b8a

Browse files
gustavocidornelas and whoseoyster
authored and committed
Refactor production data publishing and remove debugging statements
1 parent a55bc0b commit 7b05b8a

File tree

2 files changed

+120
-112
lines changed

2 files changed

+120
-112
lines changed

openlayer/__init__.py

Lines changed: 96 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
project.status()
2222
project.push()
2323
"""
24+
import copy
2425
import os
2526
import shutil
2627
import tarfile
@@ -29,7 +30,7 @@
2930
import urllib.parse
3031
import uuid
3132
import warnings
32-
from typing import Dict, Optional, Tuple
33+
from typing import Dict, List, Optional, Tuple, Union
3334

3435
import pandas as pd
3536
import yaml
@@ -1073,74 +1074,50 @@ def upload_reference_dataframe(
10731074
dataset_config_file_path=dataset_config_file_path,
10741075
task_type=task_type,
10751076
)
1076-
1077-
def send_stream_data(
1077+
1078+
def stream_data(
10781079
self,
10791080
inference_pipeline_id: str,
10801081
task_type: TaskType,
1081-
stream_df: pd.DataFrame,
1082+
stream_data: Union[Dict[str, any], List[Dict[str, any]]],
10821083
stream_config: Optional[Dict[str, any]] = None,
10831084
stream_config_file_path: Optional[str] = None,
1084-
verbose: bool = True,
10851085
) -> None:
1086-
"""Publishes a batch of production data to the Openlayer platform."""
1087-
if stream_config is None and stream_config_file_path is None:
1086+
"""Streams production data to the Openlayer platform."""
1087+
if not isinstance(stream_data, (dict, list)):
10881088
raise ValueError(
1089-
"Either `batch_config` or `batch_config_file_path` must be" " provided."
1089+
"stream_data must be a dictionary or a list of dictionaries."
10901090
)
1091-
if stream_config_file_path is not None and not os.path.exists(
1092-
stream_config_file_path
1093-
):
1094-
raise exceptions.OpenlayerValidationError(
1095-
f"Stream config file path {stream_config_file_path} does not exist."
1096-
) from None
1097-
elif stream_config_file_path is not None:
1098-
stream_config = utils.read_yaml(stream_config_file_path)
1099-
1100-
stream_config_to_validate = dict(stream_config)
1101-
stream_config_to_validate["label"] = "production"
1091+
if isinstance(stream_data, dict):
1092+
stream_data = [stream_data]
11021093

1103-
# Validate stream of data
1104-
stream_validator = dataset_validators.get_validator(
1094+
stream_df = pd.DataFrame(stream_data)
1095+
stream_config = self._validate_production_data_and_load_config(
11051096
task_type=task_type,
1106-
dataset_config=stream_config_to_validate,
1107-
dataset_config_file_path=stream_config_file_path,
1108-
dataset_df=stream_df,
1097+
config=stream_config,
1098+
config_file_path=stream_config_file_path,
1099+
df=stream_df,
11091100
)
1110-
failed_validations = stream_validator.validate()
1111-
1112-
if failed_validations:
1113-
raise exceptions.OpenlayerValidationError(
1114-
"There are issues with the stream of data and its config. \n"
1115-
"Make sure to fix all of the issues listed above before the upload.",
1116-
) from None
1117-
1118-
# Load dataset config and augment with defaults
1119-
stream_data = dict(stream_config)
1120-
1121-
# Add default columns if not present
1122-
columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
1123-
for column in columns_to_add:
1124-
if stream_data.get(column) is None:
1125-
stream_data, stream_df = self._add_default_column(
1126-
config=stream_data, df=stream_df, column_name=column
1127-
)
1128-
1129-
1101+
stream_config, stream_df = self._add_default_columns(
1102+
config=stream_config, df=stream_df
1103+
)
1104+
stream_config = self._strip_read_only_fields(stream_config)
11301105
body = {
1131-
"config": stream_data,
1106+
"config": stream_config,
11321107
"rows": stream_df.to_dict(orient="records"),
11331108
}
1134-
1135-
print("This is the body!")
1136-
print(body)
11371109
self.api.post_request(
11381110
endpoint=f"inference-pipelines/{inference_pipeline_id}/data-stream",
11391111
body=body,
11401112
)
1113+
print("Stream published!")
11411114

1142-
if verbose:
1143-
print("Stream published!")
1115+
def _strip_read_only_fields(self, config: Dict[str, any]) -> Dict[str, any]:
1116+
"""Strips read-only fields from the config."""
1117+
stripped_config = copy.deepcopy(config)
1118+
for field in {"columnNames", "label"}:
1119+
stripped_config.pop(field, None)
1120+
return stripped_config
11441121

11451122
def publish_batch_data(
11461123
self,
@@ -1151,54 +1128,29 @@ def publish_batch_data(
11511128
batch_config_file_path: Optional[str] = None,
11521129
) -> None:
11531130
"""Publishes a batch of production data to the Openlayer platform."""
1154-
if batch_config is None and batch_config_file_path is None:
1155-
raise ValueError(
1156-
"Either `batch_config` or `batch_config_file_path` must be" " provided."
1157-
)
1158-
if batch_config_file_path is not None and not os.path.exists(
1159-
batch_config_file_path
1160-
):
1161-
raise exceptions.OpenlayerValidationError(
1162-
f"Batch config file path {batch_config_file_path} does not exist."
1163-
) from None
1164-
elif batch_config_file_path is not None:
1165-
batch_config = utils.read_yaml(batch_config_file_path)
1166-
1167-
batch_config["label"] = "production"
1168-
1169-
# Validate batch of data
1170-
batch_validator = dataset_validators.get_validator(
1131+
batch_config = self._validate_production_data_and_load_config(
11711132
task_type=task_type,
1172-
dataset_config=batch_config,
1173-
dataset_config_file_path=batch_config_file_path,
1174-
dataset_df=batch_df,
1133+
config=batch_config,
1134+
config_file_path=batch_config_file_path,
1135+
df=batch_df,
1136+
)
1137+
batch_config, batch_df = self._add_default_columns(
1138+
config=batch_config, df=batch_df
11751139
)
1176-
failed_validations = batch_validator.validate()
11771140

1178-
if failed_validations:
1179-
raise exceptions.OpenlayerValidationError(
1180-
"There are issues with the batch of data and its config. \n"
1181-
"Make sure to fix all of the issues listed above before the upload.",
1182-
) from None
1141+
# Add column names if missing
1142+
if batch_config.get("columnNames") is None:
1143+
batch_config["columnNames"] = list(batch_df.columns)
11831144

1184-
# Add default columns if not present
1185-
if batch_data.get("columnNames") is None:
1186-
batch_data["columnNames"] = list(batch_df.columns)
1187-
columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
1188-
for column in columns_to_add:
1189-
if batch_data.get(column) is None:
1190-
batch_data, batch_df = self._add_default_column(
1191-
config=batch_data, df=batch_df, column_name=column
1192-
)
11931145
# Get min and max timestamps
1194-
earliest_timestamp = batch_df[batch_data["timestampColumnName"]].min()
1195-
latest_timestamp = batch_df[batch_data["timestampColumnName"]].max()
1146+
earliest_timestamp = batch_df[batch_config["timestampColumnName"]].min()
1147+
latest_timestamp = batch_df[batch_config["timestampColumnName"]].max()
11961148
batch_row_count = len(batch_df)
11971149

11981150
with tempfile.TemporaryDirectory() as tmp_dir:
11991151
# Copy save files to tmp dir
12001152
batch_df.to_csv(f"{tmp_dir}/dataset.csv", index=False)
1201-
utils.write_yaml(batch_data, f"{tmp_dir}/dataset_config.yaml")
1153+
utils.write_yaml(batch_config, f"{tmp_dir}/dataset_config.yaml")
12021154

12031155
tar_file_path = os.path.join(tmp_dir, "tarfile")
12041156
with tarfile.open(tar_file_path, mode="w:gz") as tar:
@@ -1234,9 +1186,64 @@ def publish_batch_data(
12341186
),
12351187
presigned_url_query_params=presigned_url_query_params,
12361188
)
1237-
12381189
print("Data published!")
12391190

1191+
def _validate_production_data_and_load_config(
1192+
self,
1193+
task_type: tasks.TaskType,
1194+
config: Dict[str, any],
1195+
config_file_path: str,
1196+
df: pd.DataFrame,
1197+
) -> Dict[str, any]:
1198+
"""Validates the production data and its config and returns a valid config
1199+
populated with the default values."""
1200+
if config is None and config_file_path is None:
1201+
raise ValueError(
1202+
"Either the config or the config file path must be provided."
1203+
)
1204+
if config_file_path is not None and not os.path.exists(config_file_path):
1205+
raise exceptions.OpenlayerValidationError(
1206+
f"The file specified by the config file path {config_file_path} does"
1207+
" not exist."
1208+
) from None
1209+
elif config_file_path is not None:
1210+
config = utils.read_yaml(config_file_path)
1211+
1212+
# Force label to be production
1213+
config["label"] = "production"
1214+
1215+
# Validate batch of data
1216+
validator = dataset_validators.get_validator(
1217+
task_type=task_type,
1218+
dataset_config=config,
1219+
dataset_config_file_path=config_file_path,
1220+
dataset_df=df,
1221+
)
1222+
failed_validations = validator.validate()
1223+
1224+
if failed_validations:
1225+
raise exceptions.OpenlayerValidationError(
1226+
"There are issues with the data and its config. \n"
1227+
"Make sure to fix all of the issues listed above before the upload.",
1228+
) from None
1229+
1230+
config = DatasetSchema().load({"task_type": task_type.value, **config})
1231+
1232+
return config
1233+
1234+
def _add_default_columns(
1235+
self, config: Dict[str, any], df: pd.DataFrame
1236+
) -> Tuple[Dict[str, any], pd.DataFrame]:
1237+
"""Adds the default columns if not present and returns the updated config and
1238+
dataframe."""
1239+
columns_to_add = {"timestampColumnName", "inferenceIdColumnName"}
1240+
for column in columns_to_add:
1241+
if config.get(column) is None:
1242+
config, df = self._add_default_column(
1243+
config=config, df=df, column_name=column
1244+
)
1245+
return config, df
1246+
12401247
def _add_default_column(
12411248
self, config: Dict[str, any], df: pd.DataFrame, column_name: str
12421249
) -> Tuple[Dict[str, any], pd.DataFrame]:

openlayer/inference_pipelines.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -238,44 +238,45 @@ def upload_reference_dataframe(
238238
task_type=self.taskType,
239239
**kwargs,
240240
)
241-
242-
def send_stream_data(self, *args, **kwargs):
243-
"""Publishes a stream of production data to the Openlayer platform.
241+
242+
def stream_data(self, *args, **kwargs):
243+
"""Streams production data to the Openlayer platform.
244244
245245
Parameters
246246
----------
247-
stream_df : pd.DataFrame
248-
Dataframe containing the batch of production data.
247+
stream_data: Union[Dict[str, any], List[Dict[str, any]]]
248+
Dictionary or list of dictionaries containing the production data. E.g.,
249+
``{'CreditScore': 618, 'Geography': 'France', 'Balance': 321.92}``.
249250
stream_config : Dict[str, any], optional
250-
Dictionary containing the batch configuration. This is not needed if
251-
``batch_config_file_path`` is provided.
251+
Dictionary containing the stream configuration. This is not needed if
252+
``stream_config_file_path`` is provided.
252253
253254
.. admonition:: What's in the config?
254255
255-
The configuration for a batch of data depends on the :obj:`TaskType`.
256+
The configuration for a stream of data depends on the :obj:`TaskType`.
256257
Refer to the `How to write dataset configs guides <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
257258
for details. These configurations are
258-
the same for development and batches of production data.
259+
the same for development and production data.
259260
260261
stream_config_file_path : str
261262
Path to the configuration YAML file. This is not needed if
262-
``batch_config`` is provided.
263+
``stream_config`` is provided.
263264
264265
.. admonition:: What's in the config file?
265266
266-
The configuration for a batch of data depends on the :obj:`TaskType`.
267+
The configuration for a stream of data depends on the :obj:`TaskType`.
267268
Refer to the `How to write dataset configs guides <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
268269
for details. These configurations are
269-
the same for development and batches of production data.
270+
the same for development and production data.
270271
271272
Notes
272273
-----
273-
Production data usually has a column with the inference timestamps. This
274-
column is specified in the ``timestampsColumnName`` of the batch config file,
274+
Production data usually contains the inference timestamps. This
275+
column is specified in the ``timestampsColumnName`` of the stream config file,
275276
and it should contain timestamps in the **UNIX format in seconds**.
276277
277-
Production data also usually has a column with the prediction IDs. This
278-
column is specified in the ``inferenceIdColumnName`` of the batch config file.
278+
Production data also usually contains the prediction IDs. This
279+
column is specified in the ``inferenceIdColumnName`` of the stream config file.
279280
This column is particularly important when the ground truths are not available
280281
during inference time, and they are updated later.
281282
@@ -298,16 +299,16 @@ def send_stream_data(self, *args, **kwargs):
298299
... name="XGBoost model inference pipeline",
299300
... )
300301
301-
With the ``InferencePipeline`` object retrieved, you can publish a batch
302-
of production data -- in this example, stored in a pandas dataframe
303-
called ``df`` -- with:
302+
With the ``InferencePipeline`` object retrieved, you can stream
303+
production data -- in this example, stored in a dictionary called
304+
``stream_data`` -- with:
304305
305-
>>> inference_pipeline.send_stream_data(
306-
... batch_df=df,
307-
... batch_config=config,
306+
>>> inference_pipeline.stream_data(
307+
... stream_data=stream_data,
308+
... stream_config=config,
308309
... )
309310
"""
310-
return self.client.send_stream_data(
311+
return self.client.stream_data(
311312
*args,
312313
inference_pipeline_id=self.id,
313314
task_type=self.taskType,

0 commit comments

Comments (0)