diff --git a/src/sasctl/_services/model_management.py b/src/sasctl/_services/model_management.py index e91b1aa6..e78d2549 100644 --- a/src/sasctl/_services/model_management.py +++ b/src/sasctl/_services/model_management.py @@ -28,7 +28,13 @@ class ModelManagement(Service): # TODO: set ds2MultiType @classmethod def publish_model( - cls, model, destination, name=None, force=False, reload_model_table=False + cls, + model, + destination, + model_version="latest", + name=None, + force=False, + reload_model_table=False, ): """ @@ -38,6 +44,8 @@ def publish_model( The name or id of the model, or a dictionary representation of the model. destination : str Name of destination to publish the model to. + model_version : str or dict, optional + Provide the id, name, or dictionary representation of the version to publish. Defaults to 'latest'. name : str, optional Provide a custom name for the published model. Defaults to None. force : bool, optional @@ -68,6 +76,23 @@ def publish_model( # TODO: Verify allowed formats by destination type. # As of 19w04 MAS throws HTTP 500 if name is in invalid format. 
+ if model_version != "latest": + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version_name = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError("Model version is not recognized.") + elif isinstance(model_version, str) and cls.is_uuid(model_version): + model_version_name = mr.get_model_or_version(model, model_version)[ + "modelVersionName" + ] + else: + model_version_name = model_version + else: + model_version_name = "" + model_name = name or "{}_{}".format( model_obj["name"].replace(" ", ""), model_obj["id"] ).replace("-", "") @@ -79,6 +104,7 @@ def publish_model( { "modelName": mp._publish_name(model_name), "sourceUri": model_uri.get("uri"), + "modelVersionID": model_version_name, "publishLevel": "model", } ], @@ -104,6 +130,7 @@ def create_performance_definition( table_prefix, project=None, models=None, + modelVersions=None, library_name="Public", name=None, description=None, @@ -136,6 +163,9 @@ def create_performance_definition( The name or id of the model(s), or a dictionary representation of the model(s). For multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all models in the project specified will be used. Defaults to None. + modelVersions: str, list, optional + The name of the model version(s) for models used in the performance definition. If no model versions + are specified, all models will use the latest version. Defaults to None. library_name : str The library containing the input data, default is 'Public'. name : str, optional @@ -238,11 +268,44 @@ def create_performance_definition( "Project %s must have the 'predictionVariable' " "property set." 
% project.name ) + # Build the list of model IDs, optionally suffixed with version names. + if not modelVersions: + updated_models = [model.id for model in models] + else: + updated_models = [] + if not isinstance(modelVersions, list): + modelVersions = [modelVersions] + + if len(models) < len(modelVersions): + raise ValueError( + "There are too many versions for the amount of models specified." + ) + + modelVersions = modelVersions + [""] * (len(models) - len(modelVersions)) + for model, modelVersionName in zip(models, modelVersions): + + if ( + isinstance(modelVersionName, dict) + and "modelVersionName" in modelVersionName + ): + + modelVersionName = modelVersionName["modelVersionName"] + elif ( + isinstance(modelVersionName, dict) + and "modelVersionName" not in modelVersionName + ): + + raise ValueError("Model version is not recognized.") + + if modelVersionName != "": + updated_models.append(model.id + ":" + modelVersionName) + else: + updated_models.append(model.id) request = { "projectId": project.id, "name": name or project.name + " Performance", - "modelIds": [model.id for model in models], + "modelIds": updated_models, "championMonitored": monitor_champion, "challengerMonitored": monitor_challenger, "maxBins": max_bins, @@ -279,7 +342,6 @@ create_performance_definition( for v in project.get("variables", []) if v.get("role") == "output" ] - return cls.post( "/performanceTasks", json=request, diff --git a/src/sasctl/_services/model_publish.py b/src/sasctl/_services/model_publish.py index c3fa225f..90f665ad 100644 --- a/src/sasctl/_services/model_publish.py +++ b/src/sasctl/_services/model_publish.py @@ -10,6 +10,7 @@ from .model_repository import ModelRepository from .service import Service +from ..utils.decorators import deprecated class ModelPublish(Service): @@ -90,7 +91,8 @@ def delete_destination(cls, item): return cls.delete("/destinations/{name}".format(name=item)) @classmethod + @deprecated("Use publish_model in model_management.py instead.", "1.11.5") def publish_model(cls, model, destination, 
name=None, code=None, notes=None): """Publish a model to an existing publishing destination. diff --git a/src/sasctl/_services/score_definitions.py b/src/sasctl/_services/score_definitions.py index 2c05611f..f37cfb6b 100644 --- a/src/sasctl/_services/score_definitions.py +++ b/src/sasctl/_services/score_definitions.py @@ -69,7 +69,7 @@ def create_score_definition( library_name: str, optional The library within the CAS server the table exists in. Defaults to "Public". model_version: str, optional - The user-chosen version of the model with the specified model_id. Defaults to "latest". + The user-chosen version of the model with the specified model version name. Defaults to latest version. Returns ------- @@ -116,7 +116,7 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table and not table_file: raise HTTPError( - f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." + "This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist." ) elif not table and table_file: cls._cas_management.upload_file( @@ -125,16 +125,40 @@ def create_score_definition( table = cls._cas_management.get_table(table_name, library_name, server_name) if not table: raise HTTPError( - f"The file failed to upload properly or another error occurred." + "The file failed to upload properly or another error occurred." ) # Checks if the inputted table exists, and if not, uploads a file to create a new table + if model_version != "latest": + + if isinstance(model_version, dict) and "modelVersionName" in model_version: + model_version = model_version["modelVersionName"] + elif ( + isinstance(model_version, dict) + and "modelVersionName" not in model_version + ): + raise ValueError( + "Model version cannot be found. Please check the inputted model version." 
+ ) + elif isinstance(model_version, str) and cls.is_uuid(model_version): + # Resolve a version UUID to its version name. + model_version = cls._model_repository.get_model_or_version( + model_id, model_version + )["modelVersionName"] + else: + pass  # model_version is already a version name string + + object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}" + + else: + object_uri = f"/modelManagement/models/{model_id}" + save_score_def = { "name": model_name, # used to be score_def_name "description": description, "objectDescriptor": { - "uri": f"/modelManagement/models/{model_id}", - "name": f"{model_name}({model_version})", + "uri": object_uri, + "name": f"{model_name} ({model_version})", "type": f"{object_descriptor_type}", }, "inputData": { @@ -149,7 +173,7 @@ create_score_definition( "projectUri": f"/modelRepository/projects/{model_project_id}", "projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}", "publishDestination": "", - "versionedModel": f"{model_name}({model_version})", + "versionedModel": f"{model_name} ({model_version})", }, "mappings": inputMapping, } diff --git a/tests/unit/test_model_management.py b/tests/unit/test_model_management.py index fbd4fc36..834b0ecc 100644 --- a/tests/unit/test_model_management.py +++ b/tests/unit/test_model_management.py @@ -23,6 +23,8 @@ def test_create_performance_definition(): RestObj({"name": "Test Model 2", "id": "67890", "projectId": PROJECT["id"]}), ] USER = "username" + VERSION_MOCK = {"modelVersionName": "1.0"} + VERSION_MOCK_NONAME = {} with mock.patch("sasctl.core.Session._get_authorization_token"): current_session("example.com", USER, "password") @@ -111,6 +113,32 @@ def test_create_performance_definition(): table_prefix="TestData", ) + with pytest.raises(ValueError): + # Model versions exceed models + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0", "3.0"], + 
library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + with pytest.raises(ValueError): + # Model version dictionary missing modelVersionName + get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=VERSION_MOCK_NONAME, + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + get_project.return_value = copy.deepcopy(PROJECT) get_project.return_value["targetVariable"] = "target" get_project.return_value["targetLevel"] = "interval" @@ -125,21 +153,68 @@ def test_create_performance_definition(): monitor_challenger=True, monitor_champion=True, ) + url, data = post_models.call_args + assert post_models.call_count == 1 + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True - assert post_models.call_count == 1 - url, data = post_models.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] - assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + 
get_model.side_effect = copy.deepcopy(MODELS) + _ = mm.create_performance_definition( + # One model version as a string name + models=["model1", "model2"], + modelVersions="1.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_models.call_count == 2 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=["1.0", "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 3 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] + + get_model.side_effect = copy.deepcopy(MODELS) + # List of dictionary type and string type model versions + _ = mm.create_performance_definition( + models=["model1", "model2"], + modelVersions=[VERSION_MOCK, "2.0"], + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + assert post_models.call_count == 4 + url, data = post_models.call_args + assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"] + assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"] with mock.patch( "sasctl._services.model_management.ModelManagement" ".post" @@ -160,20 +235,39 @@ def test_create_performance_definition(): monitor_champion=True, ) - assert post_project.call_count == 1 - url, data = post_project.call_args - - assert PROJECT["id"] == data["json"]["projectId"] - assert MODELS[0]["id"] in data["json"]["modelIds"] - assert MODELS[1]["id"] in data["json"]["modelIds"] - assert "TestLibrary" == 
data["json"]["dataLibrary"] - assert "TestData" == data["json"]["dataPrefix"] - assert "cas-shared-default" == data["json"]["casServerId"] - assert data["json"]["name"] - assert data["json"]["description"] - assert data["json"]["maxBins"] == 3 - assert data["json"]["championMonitored"] is True - assert data["json"]["challengerMonitored"] is True + # one extra test for project with version id + + assert post_project.call_count == 1 + url, data = post_project.call_args + + assert PROJECT["id"] == data["json"]["projectId"] + assert MODELS[0]["id"] in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] + assert "TestLibrary" == data["json"]["dataLibrary"] + assert "TestData" == data["json"]["dataPrefix"] + assert "cas-shared-default" == data["json"]["casServerId"] + assert data["json"]["name"] + assert data["json"]["description"] + assert data["json"]["maxBins"] == 3 + assert data["json"]["championMonitored"] is True + assert data["json"]["challengerMonitored"] is True + + get_model.side_effect = copy.deepcopy(MODELS) + # Project with model version + _ = mm.create_performance_definition( + project="project", + modelVersions="2.0", + library_name="TestLibrary", + table_prefix="TestData", + max_bins=3, + monitor_challenger=True, + monitor_champion=True, + ) + + assert post_project.call_count == 2 + url, data = post_project.call_args + assert f"{MODELS[0]['id']}:2.0" in data["json"]["modelIds"] + assert MODELS[1]["id"] in data["json"]["modelIds"] def test_table_prefix_format(): with pytest.raises(ValueError): diff --git a/tests/unit/test_score_definitions.py b/tests/unit/test_score_definitions.py index d1210866..075b316c 100644 --- a/tests/unit/test_score_definitions.py +++ b/tests/unit/test_score_definitions.py @@ -63,89 +63,166 @@ def test_create_score_definition(): "sasctl._services.cas_management.CASManagement.upload_file" ) as upload_file: with mock.patch( - "sasctl._services.score_definitions.ScoreDefinitions.post" - ) as post: - # Invalid 
model id test case - get_model.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - # Valid model id but invalid table name with no table_file argument test case - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - } - get_model.return_value = get_model_mock - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - - # Invalid table name with a table_file argument that doesn't work test case - get_table.return_value = None - upload_file.return_value = None - get_table.return_value = None - with pytest.raises(HTTPError): - sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - - # Valid table_file argument that successfully creates a table test case - get_table.return_value = None - upload_file.return_value = RestObj - get_table_mock = {"tableName": "test_table"} - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Valid table_name argument test case - get_table.return_value = get_table_mock - response = sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - table_file="test_path", - ) - assert response - - # Checking response with inputVariables in model elements - get_model_mock = { - "id": "12345", - "projectId": "54321", - "projectVersionId": "67890", - "name": "test_model", - "inputVariables": [ - {"name": "first"}, - {"name": "second"}, - {"name": "third"}, - ], - } - get_model.return_value = get_model_mock - get_table.return_value = get_table_mock - response = 
sd.create_score_definition( - score_def_name="test_create_sd", - model="12345", - table_name="test_table", - ) - assert response - assert post.call_count == 3 - - data = post.call_args - json_data = json.loads(data.kwargs["data"]) - assert json_data["mappings"] != [] + "sasctl._services.model_repository.ModelRepository.get_model_or_version" + ) as get_model_or_version: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.score_definitions.ScoreDefinitions.post" + ) as post: + + # Invalid model id test case + get_model.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + # Valid model id but invalid table name with no table_file argument test case + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + } + get_model.return_value = get_model_mock + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + + # Invalid table name with a table_file argument that doesn't work test case + get_table.return_value = None + upload_file.return_value = None + get_table.return_value = None + with pytest.raises(HTTPError): + sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + + # Valid table_file argument that successfully creates a table test case + get_table.return_value = None + upload_file.return_value = RestObj + get_table_mock = {"tableName": "test_table"} + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Valid table_name argument test case + 
get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + table_file="test_path", + ) + assert response + + # Checking response with inputVariables in model elements + get_model_mock = { + "id": "12345", + "projectId": "54321", + "projectVersionId": "67890", + "name": "test_model", + "inputVariables": [ + {"name": "first"}, + {"name": "second"}, + {"name": "third"}, + ], + } + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + ) + assert response + assert post.call_count == 3 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert json_data["mappings"] != [] + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (latest)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (latest)" + ) + + # Model version dictionary with no model version name + with pytest.raises(ValueError): + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version={}, + ) + + # Model version as a model version name string, not UUID + get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = False + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="1.0", + ) + assert response + assert post.call_count == 4 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + ) + + # Model version as a dictionary with model version name key + get_version_mock = { + "id": "3456", + "modelVersionName": "1.0", + } + 
get_model.return_value = get_model_mock + get_table.return_value = get_table_mock + is_uuid.return_value = True + get_model_or_version.return_value = get_version_mock + response = sd.create_score_definition( + score_def_name="test_create_sd", + model="12345", + table_name="test_table", + model_version="3456", + ) + assert response + assert post.call_count == 5 + + data = post.call_args + json_data = json.loads(data.kwargs["data"]) + assert ( + json_data["objectDescriptor"]["name"] + == "test_model (1.0)" + ) + assert ( + json_data["properties"]["versionedModel"] + == "test_model (1.0)" + )