Skip to content

Commit 49e0008

Browse files
committed
Closing UNB-2603 Improve imports
1 parent 3227493 commit 49e0008

File tree

1 file changed

+51
-54
lines changed

1 file changed

+51
-54
lines changed

unboxapi/__init__.py

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,16 @@
77
import uuid
88
from typing import Callable, List, Optional
99

10+
import marshmallow as ma
1011
import pandas as pd
11-
from bentoml.saved_bundle.bundler import _write_bento_content_to_dir
12-
from bentoml.utils.tempdir import TempDirectory
13-
from marshmallow import ValidationError
12+
from bentoml.saved_bundle import bundler
13+
from bentoml.utils import tempdir
1414

15-
from .api import Api
15+
from . import api, exceptions, schemas, utils
1616
from .datasets import Dataset
17-
from .exceptions import (
18-
UnboxDatasetInconsistencyError,
19-
UnboxDuplicateTask,
20-
UnboxResourceError,
21-
UnboxSubscriptionPlanException,
22-
UnboxValidationError,
23-
)
2417
from .models import Model, ModelType, create_template_model
2518
from .projects import Project
26-
from .schemas import DatasetSchema, ModelSchema, ProjectSchema
2719
from .tasks import TaskType
28-
from .utils import HidePrints
2920
from .version import __version__ # noqa: F401
3021

3122

@@ -46,7 +37,7 @@ class UnboxClient(object):
4637
"""
4738

4839
def __init__(self, api_key: str = None):
49-
self.api = Api(api_key)
40+
self.api = api.Api(api_key)
5041
self.subscription_plan = self.api.get_request("me/subscription-plan")
5142

5243
def create_project(
@@ -94,11 +85,13 @@ def create_project(
9485
obj:`add_dataframe` for detailed examples.
9586
"""
9687
# ----------------------------- Schema validation ---------------------------- #
97-
project_schema = ProjectSchema()
88+
project_schema = schemas.ProjectSchema()
9889
try:
9990
project_schema.load({"name": name, "description": description})
100-
except ValidationError as err:
101-
raise UnboxValidationError(self._format_error_message(err)) from None
91+
except ma.ValidationError as err:
92+
raise exceptions.UnboxValidationError(
93+
self._format_error_message(err)
94+
) from None
10295

10396
endpoint = "projects"
10497
payload = dict(name=name, description=description, taskType=task_type.value)
@@ -195,7 +188,7 @@ def create_or_load_project(
195188
return self.create_project(
196189
name=name, task_type=task_type, description=description
197190
)
198-
except UnboxDuplicateTask:
191+
except exceptions.UnboxDuplicateTask:
199192
return self.load_project(name)
200193

201194
def add_model(
@@ -502,17 +495,17 @@ def add_model(
502495
TaskType.TabularClassification,
503496
TaskType.TextClassification,
504497
]:
505-
raise UnboxValidationError(
498+
raise exceptions.UnboxValidationError(
506499
"`task_type` must be either TaskType.TabularClassification or "
507500
"TaskType.TextClassification. \n"
508501
) from None
509502
if model_type not in [model_framework for model_framework in ModelType]:
510-
raise UnboxValidationError(
503+
raise exceptions.UnboxValidationError(
511504
"`model_type` must be one of the supported ModelTypes. Check out "
512505
"our API reference for a full list "
513506
"https://reference.unbox.ai/reference/api/unboxapi.ModelType.html. \n"
514507
) from None
515-
model_schema = ModelSchema()
508+
model_schema = schemas.ModelSchema()
516509
try:
517510
model_schema.load(
518511
{
@@ -530,27 +523,29 @@ def add_model(
530523
"dependent_dir": dependent_dir,
531524
}
532525
)
533-
except ValidationError as err:
534-
raise UnboxValidationError(self._format_error_message(err)) from None
526+
except ma.ValidationError as err:
527+
raise exceptions.UnboxValidationError(
528+
self._format_error_message(err)
529+
) from None
535530

536531
# --------------------------- Resource validations --------------------------- #
537532
# Requirements check
538533
if requirements_txt_file and not os.path.isfile(
539534
os.path.expanduser(requirements_txt_file)
540535
):
541-
raise UnboxResourceError(
536+
raise exceptions.UnboxResourceError(
542537
f"File `{requirements_txt_file}` does not exist. \n"
543538
) from None
544539

545540
# Setup script
546541
if setup_script and not os.path.isfile(os.path.expanduser(setup_script)):
547-
raise UnboxResourceError(
542+
raise exceptions.UnboxResourceError(
548543
f"File `{setup_script}` does not exist. \n"
549544
) from None
550545

551546
# Dependent dir
552547
if dependent_dir and dependent_dir == os.getcwd():
553-
raise UnboxResourceError(
548+
raise exceptions.UnboxResourceError(
554549
"`dependent_dir` cannot be the working directory. \n",
555550
mitigation="Make sure that the specified `dependent_dir` is different "
556551
f"from `{os.getcwd()}`.",
@@ -559,13 +554,13 @@ def add_model(
559554
# Training set
560555
if task_type in [TaskType.TabularClassification, TaskType.TabularRegression]:
561556
if len(train_sample_df.index) < 100:
562-
raise UnboxResourceError(
557+
raise exceptions.UnboxResourceError(
563558
context="There's an issue with the specified `train_sample_df`. \n",
564559
message=f"Only {len(train_sample_df.index)} rows were found. \n",
565560
mitigation="Make sure to upload a training sample with 100+ rows.",
566561
) from None
567562
if train_sample_df.isnull().values.any():
568-
raise UnboxResourceError(
563+
raise exceptions.UnboxResourceError(
569564
context="There's an issue with the specified `train_sample_df`. \n",
570565
message=f"The `train_sample_df` contains null values, which is "
571566
"currently not supported. \n",
@@ -579,14 +574,14 @@ def add_model(
579574

580575
# predict_proba
581576
if not isinstance(function, Callable):
582-
raise UnboxValidationError(
577+
raise exceptions.UnboxValidationError(
583578
f"- `{function}` specified as `function` is not callable. \n"
584579
) from None
585580

586581
user_args = function.__code__.co_varnames[: function.__code__.co_argcount][2:]
587582
kwarg_keys = tuple(kwargs)
588583
if user_args != kwarg_keys:
589-
raise UnboxResourceError(
584+
raise exceptions.UnboxResourceError(
590585
context="There's an issue with the specified `function`. \n",
591586
message=f"Your function's additional args {user_args} do not match the "
592587
f"kwargs you specified {kwarg_keys}. \n",
@@ -601,20 +596,20 @@ def add_model(
601596
TaskType.TabularRegression,
602597
]:
603598
test_input = train_sample_df[:3][feature_names].to_numpy()
604-
with HidePrints():
599+
with utils.HidePrints():
605600
function(model, test_input, **kwargs)
606601
else:
607602
test_input = [
608603
"Unbox is great!",
609604
"Let's see if this function is ready for some error analysis",
610605
]
611-
with HidePrints():
606+
with utils.HidePrints():
612607
function(model, test_input, **kwargs)
613608
except Exception as e:
614609
exception_stack = "".join(
615610
traceback.format_exception(type(e), e, e.__traceback__)
616611
)
617-
raise UnboxResourceError(
612+
raise exceptions.UnboxResourceError(
618613
context="There's an issue with the specified `function`. \n",
619614
message=f"It is failing with the following error: \n"
620615
f"{exception_stack}",
@@ -626,7 +621,7 @@ def add_model(
626621
# Transformers resources
627622
if model_type is ModelType.transformers:
628623
if "tokenizer" not in kwargs:
629-
raise UnboxResourceError(
624+
raise exceptions.UnboxResourceError(
630625
context="There's a missing kwarg for the specified model type. \n",
631626
message="`tokenizer` must be specified in kwargs when using a "
632627
"transformers model. \n",
@@ -648,7 +643,7 @@ def add_model(
648643
for feature in feature_names + [train_sample_label_column_name]
649644
if feature not in headers
650645
]
651-
raise UnboxDatasetInconsistencyError(
646+
raise exceptions.UnboxDatasetInconsistencyError(
652647
f"Features {features_not_in_dataset} specified in `feature_names` "
653648
"are not on the training sample. \n"
654649
) from None
@@ -660,13 +655,13 @@ def add_model(
660655
]
661656
for value, field in required_fields:
662657
if value is None:
663-
raise UnboxDatasetInconsistencyError(
658+
raise exceptions.UnboxDatasetInconsistencyError(
664659
message=f"TabularClassification task missing `{field}`.\n",
665660
mitigation=f"Make sure to specify `{field}` for tabular "
666661
"classification tasks.",
667662
) from None
668663

669-
with TempDirectory() as dir:
664+
with tempdir.TempDirectory() as dir:
670665
bento_service = create_template_model(
671666
model_type,
672667
task_type,
@@ -686,9 +681,9 @@ def add_model(
686681
bento_service.pack("function", function)
687682
bento_service.pack("kwargs", kwargs)
688683

689-
with TempDirectory() as temp_dir:
684+
with tempdir.TempDirectory() as temp_dir:
690685
print("Bundling model and artifacts...")
691-
_write_bento_content_to_dir(bento_service, temp_dir)
686+
bundler._write_bento_content_to_dir(bento_service, temp_dir)
692687

693688
if model_type is ModelType.rasa:
694689
dependent_dir = model.model_metadata.model_dir
@@ -715,7 +710,7 @@ def add_model(
715710
)
716711

717712
# Tar the model bundle with its artifacts and upload
718-
with TempDirectory() as tarfile_dir:
713+
with tempdir.TempDirectory() as tarfile_dir:
719714
tarfile_path = f"{tarfile_dir}/model"
720715

721716
with tarfile.open(tarfile_path, mode="w:gz") as tar:
@@ -899,11 +894,11 @@ def add_dataset(
899894
TaskType.TabularClassification,
900895
TaskType.TextClassification,
901896
]:
902-
raise UnboxValidationError(
897+
raise exceptions.UnboxValidationError(
903898
"`task_type` must be either TaskType.TabularClassification or "
904899
"TaskType.TextClassification. \n"
905900
) from None
906-
dataset_schema = DatasetSchema()
901+
dataset_schema = schemas.DatasetSchema()
907902
try:
908903
dataset_schema.load(
909904
{
@@ -920,14 +915,16 @@ def add_dataset(
920915
"categorical_feature_names": categorical_feature_names,
921916
}
922917
)
923-
except ValidationError as err:
924-
raise UnboxValidationError(self._format_error_message(err)) from None
918+
except ma.ValidationError as err:
919+
raise exceptions.UnboxValidationError(
920+
self._format_error_message(err)
921+
) from None
925922

926923
# --------------------------- Resource validations --------------------------- #
927924
exp_file_path = os.path.expanduser(file_path)
928925
object_name = "original.csv"
929926
if not os.path.isfile(exp_file_path):
930-
raise UnboxResourceError(
927+
raise exceptions.UnboxResourceError(
931928
f"File at path `{file_path}` does not contain the dataset. \n"
932929
) from None
933930

@@ -939,7 +936,7 @@ def add_dataset(
939936
df = pd.read_csv(file_path, sep=sep)
940937

941938
if df.isnull().values.any():
942-
raise UnboxResourceError(
939+
raise exceptions.UnboxResourceError(
943940
context="There's an issue with the specified dataset. \n",
944941
message="The dataset contains null values, which is currently "
945942
"not supported. \n",
@@ -951,14 +948,14 @@ def add_dataset(
951948
try:
952949
headers.index(label_column_name)
953950
except ValueError:
954-
raise UnboxDatasetInconsistencyError(
951+
raise exceptions.UnboxDatasetInconsistencyError(
955952
f"`{label_column_name}` specified as `label_column_name` is not "
956953
"in the dataset. \n"
957954
) from None
958955

959956
dataset_classes = list(df[label_column_name].unique())
960957
if len(dataset_classes) > len(class_names):
961-
raise UnboxDatasetInconsistencyError(
958+
raise exceptions.UnboxDatasetInconsistencyError(
962959
f"There are {len(dataset_classes)} classes represented in the dataset, "
963960
f"but only {len(class_names)} items in your `class_names`. \n",
964961
mitigation=f"Make sure that there are at most {len(class_names)} "
@@ -973,15 +970,15 @@ def add_dataset(
973970
headers.index(feature_name)
974971
except ValueError:
975972
if text_column_name:
976-
raise UnboxDatasetInconsistencyError(
973+
raise exceptions.UnboxDatasetInconsistencyError(
977974
f"`{text_column_name}` specified as `text_column_name` is not in "
978975
"the dataset. \n"
979976
) from None
980977
else:
981978
features_not_in_dataset = [
982979
feature for feature in feature_names if feature not in headers
983980
]
984-
raise UnboxDatasetInconsistencyError(
981+
raise exceptions.UnboxDatasetInconsistencyError(
985982
f"Features {features_not_in_dataset} specified in `feature_names` "
986983
"are not in the dataset. \n"
987984
) from None
@@ -991,22 +988,22 @@ def add_dataset(
991988
if tag_column_name:
992989
headers.index(tag_column_name)
993990
except ValueError:
994-
raise UnboxDatasetInconsistencyError(
991+
raise exceptions.UnboxDatasetInconsistencyError(
995992
f"`{tag_column_name}` specified as `tag_column_name` is not in "
996993
"the dataset. \n"
997994
) from None
998995

999996
# ----------------------- Subscription plan validations ---------------------- #
1000997
if row_count > self.subscription_plan["datasetSize"]:
1001-
raise UnboxSubscriptionPlanException(
998+
raise exceptions.UnboxSubscriptionPlanException(
1002999
f"The dataset you are trying to upload contains {row_count} rows, "
10031000
"which exceeds your plan's limit of "
10041001
f"{self.subscription_plan['datasetSize']}. \n"
10051002
) from None
10061003
if task_type == TaskType.TextClassification:
10071004
max_text_size = df[text_column_name].str.len().max()
10081005
if max_text_size > 100000:
1009-
raise UnboxSubscriptionPlanException(
1006+
raise exceptions.UnboxSubscriptionPlanException(
10101007
"The dataset you are trying to upload contains rows with "
10111008
f"{max_text_size} characters, which exceeds the 100,000 character "
10121009
"limit."
@@ -1182,7 +1179,7 @@ def add_dataframe(
11821179
"""
11831180
# --------------------------- Resource validations --------------------------- #
11841181
if not isinstance(df, pd.DataFrame):
1185-
raise UnboxValidationError(
1182+
raise exceptions.UnboxValidationError(
11861183
f"- `df` is a `{type(df)}`, but it must be of type `pd.DataFrame`. \n"
11871184
) from None
11881185
with tempfile.TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)