diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 7b5d379de..f028d0381 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -77,6 +77,7 @@ AquaDeploymentDetail, ConfigValidationError, CreateModelDeploymentDetails, + UpdateModelGroupDeploymentDetails, ) from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig from ads.common.object_storage_details import ObjectStorageDetails @@ -100,6 +101,9 @@ ModelDeploymentInfrastructure, ModelDeploymentMode, ) +from ads.model.deployment.model_deployment import ( + ModelDeploymentUpdateType, +) from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem from ads.telemetry import telemetry @@ -1214,6 +1218,219 @@ def _get_container_type_key( return container_type_key + @telemetry(entry_point="plugin=deployment&action=update", name="aqua") + def update( + self, + model_deployment_id: str, + update_model_deployment_details: Optional[ + UpdateModelGroupDeploymentDetails + ] = None, + **kwargs, + ) -> AquaDeployment: + """Updates a AQUA model group deployment. + + Args: + update_model_deployment_details : UpdateModelGroupDeploymentDetails, optional + An instance of UpdateModelGroupDeploymentDetails containing all optional + fields for updating a model deployment via Aqua. + kwargs: + display_name (str): The name of the model deployment. + description (Optional[str]): The description of the deployment. + models (Optional[List[AquaMultiModelRef]]): List of models for deployment. + instance_count (int): Number of instances used for deployment. + log_group_id (Optional[str]): OCI logging group ID for logs. + access_log_id (Optional[str]): OCID for access logs. + predict_log_id (Optional[str]): OCID for prediction logs. + bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps. + web_concurrency (Optional[int]): Number of worker processes/threads for handling requests. + memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape. + ocpus (Optional[float]): OCPU count for the selected shape. + private_endpoint_id (Optional[str]): Private endpoint ID for model deployment. + freeform_tags (Optional[Dict]): Freeform tags for model deployment. + defined_tags (Optional[Dict]): Defined tags for model deployment. + + Returns + ------- + AquaDeployment + An Aqua deployment instance. + """ + if not update_model_deployment_details: + try: + update_model_deployment_details = UpdateModelGroupDeploymentDetails( + **kwargs + ) + except ValidationError as ex: + custom_errors = build_pydantic_error_message(ex) + raise AquaValueError( + f"Invalid parameters for updating a model group deployment. Error details: {custom_errors}." + ) from ex + + model_deployment = ModelDeployment.from_id(model_deployment_id) + + infrastructure = model_deployment.infrastructure + runtime = model_deployment.runtime + + if not runtime.model_group_id: + raise AquaValueError( + "Invalid 'model_deployment_id'. Only model group deployment is supported to update." + ) + + model = self._update_model_group( + runtime.model_group_id, update_model_deployment_details + ) + + ( + infrastructure.with_bandwidth_mbps( + update_model_deployment_details.bandwidth_mbps + or infrastructure.bandwidth_mbps + ) + .with_replica( + update_model_deployment_details.instance_count or infrastructure.replica + ) + .with_web_concurrency( + update_model_deployment_details.web_concurrency + or infrastructure.web_concurrency + ) + .with_private_endpoint_id( + update_model_deployment_details.private_endpoint_id + or infrastructure.private_endpoint_id + ) + ) + + if ( + update_model_deployment_details.log_group_id + and update_model_deployment_details.access_log_id + ): + infrastructure.with_access_log( + log_group_id=update_model_deployment_details.log_group_id, + log_id=update_model_deployment_details.access_log_id, + ) + + if ( + update_model_deployment_details.log_group_id + and update_model_deployment_details.predict_log_id + ): + infrastructure.with_predict_log( + log_group_id=update_model_deployment_details.log_group_id, + log_id=update_model_deployment_details.predict_log_id, + ) + + if ( + update_model_deployment_details.memory_in_gbs + and update_model_deployment_details.ocpus + and infrastructure.shape_name.endswith("Flex") + ): + infrastructure.with_shape_config_details( + ocpus=update_model_deployment_details.ocpus, + memory_in_gbs=update_model_deployment_details.memory_in_gbs, + ) + + update_type = ModelDeploymentUpdateType.ZDT + # applies LIVE update if model group id has been changed + if runtime.model_group_id != model.id: + runtime.with_model_group_id(model.id) + update_type = ModelDeploymentUpdateType.LIVE + + freeform_tags = ( + update_model_deployment_details.freeform_tags + or model_deployment.freeform_tags + ) + defined_tags = ( + update_model_deployment_details.defined_tags + or model_deployment.defined_tags + ) + + # configure model deployment and deploy model on container runtime + ( + model_deployment.with_display_name( + update_model_deployment_details.display_name + or model_deployment.display_name + ) + .with_description( + update_model_deployment_details.description + or model_deployment.description + ) + .with_freeform_tags(**(freeform_tags or {})) + .with_defined_tags(**(defined_tags or {})) + .with_infrastructure(infrastructure) + .with_runtime(runtime) + ) + + model_deployment.update(wait_for_completion=False, update_type=update_type) + + logger.info(f"Updating Aqua Model Deployment {model_deployment.id}.") + + return AquaDeployment.from_oci_model_deployment( + model_deployment.dsc_model_deployment, self.region + ) + + def _update_model_group( + self, + model_group_id: str, + update_model_deployment_details: UpdateModelGroupDeploymentDetails, + ) -> DataScienceModelGroup: + """Creates a new model group if fine tuned weights changed. + + Parameters + ---------- + model_group_id: str + The model group id. + update_model_deployment_details: UpdateModelGroupDeploymentDetails + An instance of UpdateModelGroupDeploymentDetails containing all optional + fields for updating a model deployment via Aqua. + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + model_group = DataScienceModelGroup.from_id(model_group_id) + # create a new model group if fine tune weights changed as member models in ds model group is inmutable + if update_model_deployment_details.models: + if len(update_model_deployment_details.models) != 1: + raise AquaValueError( + "Invalid 'models' provided. Only one base model is required for updating model stack deployment." + ) + target_stacked_model = update_model_deployment_details.models[0] + target_base_model_id = target_stacked_model.model_id + if model_group.base_model_id != target_base_model_id: + raise AquaValueError( + "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment." + ) + + # add member models + member_models = [ + { + "inference_key": fine_tune_weight.model_name, + "model_id": fine_tune_weight.model_id, + } + for fine_tune_weight in target_stacked_model.fine_tune_weights + ] + # add base model + member_models.append( + { + "inference_key": target_stacked_model.model_name, + "model_id": target_base_model_id, + } + ) + + # creates a model group with the same configurations from original model group except member models + model_group = ( + DataScienceModelGroup() + .with_compartment_id(model_group.compartment_id) + .with_project_id(model_group.project_id) + .with_display_name(model_group.display_name) + .with_description(model_group.description) + .with_freeform_tags(**(model_group.freeform_tags or {})) + .with_defined_tags(**(model_group.defined_tags or {})) + .with_custom_metadata_list(model_group.custom_metadata_list) + .with_base_model_id(target_base_model_id) + .with_member_models(member_models) + .create() + ) + + return model_group + @telemetry(entry_point="plugin=deployment&action=list", name="aqua") def list(self, **kwargs) -> List["AquaDeployment"]: """List Aqua model deployments in a given compartment and under certain project. diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index cbcb499ad..f409ed283 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -716,3 +716,58 @@ def validate_ft_model_v2( class Config: extra = "allow" protected_namespaces = () + + +class UpdateModelGroupDeploymentDetails(BaseModel): + """Class for updating Aqua model deployments.""" + + display_name: Optional[str] = Field( + None, description="The name of the model deployment." + ) + description: Optional[str] = Field( + None, description="The description of the deployment." + ) + models: Optional[List[AquaMultiModelRef]] = Field( + None, description="List of models for multimodel deployment." + ) + instance_count: Optional[int] = Field( + None, description="Number of instances used for deployment." + ) + log_group_id: Optional[str] = Field( + None, description="OCI logging group ID for logs." + ) + access_log_id: Optional[str] = Field( + None, + description="OCID for access logs. " + "https://docs.oracle.com/en-us/iaas/data-science/using/model_dep_using_logging.htm", + ) + predict_log_id: Optional[str] = Field( + None, + description="OCID for prediction logs." + "https://docs.oracle.com/en-us/iaas/data-science/using/model_dep_using_logging.htm", + ) + bandwidth_mbps: Optional[int] = Field( + None, description="Bandwidth limit on the load balancer in Mbps." + ) + web_concurrency: Optional[int] = Field( + None, description="Number of worker processes/threads for handling requests." + ) + memory_in_gbs: Optional[float] = Field( + None, description="Memory (in GB) for the selected shape." + ) + ocpus: Optional[float] = Field( + None, description="OCPU count for the selected shape." + ) + private_endpoint_id: Optional[str] = Field( + None, description="Private endpoint ID for model deployment." + ) + freeform_tags: Optional[Dict] = Field( + None, description="Freeform tags for model deployment." + ) + defined_tags: Optional[Dict] = Field( + None, description="Defined tags for model deployment." + ) + + class Config: + extra = "allow" + protected_namespaces = () diff --git a/ads/model/deployment/model_deployment.py b/ads/model/deployment/model_deployment.py index b8452963c..684a3bc9b 100644 --- a/ads/model/deployment/model_deployment.py +++ b/ads/model/deployment/model_deployment.py @@ -86,6 +86,11 @@ class ModelDeploymentType: MODEL_GROUP = "MODEL_GROUP" +class ModelDeploymentUpdateType: + ZDT = "ZDT" + LIVE = "LIVE" + + class LogNotConfiguredError(Exception): # pragma: no cover pass @@ -237,6 +242,7 @@ class ModelDeployment(Builder): CONST_LIFECYCLE_STATE = "lifecycleState" CONST_LIFECYCLE_DETAILS = "lifecycleDetails" CONST_TIME_CREATED = "timeCreated" + CONST_UPDATE_TYPE = "updateType" attribute_map = { CONST_ID: "id", @@ -677,6 +683,7 @@ def update( wait_for_completion: bool = True, max_wait_time: int = DEFAULT_WAIT_TIME, poll_interval: int = DEFAULT_POLL_INTERVAL, + update_type: str = ModelDeploymentUpdateType.ZDT, **kwargs, ): """Updates a model deployment @@ -699,6 +706,9 @@ def update( Negative implies infinite wait time. poll_interval: int Poll interval in seconds (Defaults to 10). + update_type: str + Update type for the deployment. Allowed values: ['ZDT', 'LIVE']. + 'LIVE' update can only be used if model group id has been changed. kwargs: dict @@ -722,7 +732,9 @@ def update( update_model_deployment_details = ( updated_properties.to_update_deployment() if properties or updated_properties.oci_model_deployment or kwargs - else self._update_model_deployment_details(**kwargs) + else self._update_model_deployment_details( + update_type=update_type, **kwargs + ) ) response = self.dsc_model_deployment.update( @@ -1488,16 +1500,22 @@ def _extract_from_oci_model( The ModelDeploymentInfrastructure or ModelDeploymentRuntime instance. """ for infra_attr, dsc_attr in dsc_instance.payload_attribute_map.items(): - value = get_value(oci_model_instance, dsc_attr) - if value: - if infra_attr not in sub_level: - dsc_instance._spec[infra_attr] = value - else: - dsc_instance._spec[infra_attr] = {} - for sub_infra_attr, sub_dsc_attr in sub_level[infra_attr].items(): - sub_value = get_value(value, sub_dsc_attr) - if sub_value: - dsc_instance._spec[infra_attr][sub_infra_attr] = sub_value + dsc_attr = dsc_attr if isinstance(dsc_attr, list) else [dsc_attr] + for dsc_value in dsc_attr: + value = get_value(oci_model_instance, dsc_value) + if value: + if infra_attr not in sub_level: + dsc_instance._spec[infra_attr] = value + else: + dsc_instance._spec[infra_attr] = {} + for sub_infra_attr, sub_dsc_attr in sub_level[ + infra_attr + ].items(): + sub_value = get_value(value, sub_dsc_attr) + if sub_value: + dsc_instance._spec[infra_attr][sub_infra_attr] = ( + sub_value + ) return dsc_instance def _build_model_deployment_details(self) -> CreateModelDeploymentDetails: @@ -1533,10 +1551,16 @@ def _build_model_deployment_details(self) -> CreateModelDeploymentDetails: ).to_oci_model(CreateModelDeploymentDetails) def _update_model_deployment_details( - self, **kwargs + self, update_type: str, **kwargs ) -> UpdateModelDeploymentDetails: """Builds UpdateModelDeploymentDetails from model deployment instance. + Parameters + ---------- + update_type: str + Update type for the deployment. Allowed values: ['ZDT', 'LIVE']. + 'LIVE' update can only be used if model group id has been changed. + Returns ------- UpdateModelDeploymentDetails @@ -1555,9 +1579,13 @@ def _update_model_deployment_details( self.infrastructure.CONST_MODEL_DEPLOYMENT_CONFIG_DETAILS: self._build_model_deployment_configuration_details(), self.infrastructure.CONST_CATEGORY_LOG_DETAILS: self._build_category_log_details(), } - return OCIDataScienceModelDeployment( - **update_model_deployment_details - ).to_oci_model(UpdateModelDeploymentDetails) + + update_model_deployment_details[ + self.infrastructure.CONST_MODEL_DEPLOYMENT_CONFIG_DETAILS + ][self.CONST_UPDATE_TYPE] = update_type + return UpdateModelDeploymentDetails( + **ads_utils.batch_convert_case(update_model_deployment_details, "snake") + ) def _update_spec(self, **kwargs) -> "ModelDeployment": """Updates model deployment specs from kwargs. diff --git a/ads/model/deployment/model_deployment_infrastructure.py b/ads/model/deployment/model_deployment_infrastructure.py index fb81b8920..0f29df1ea 100644 --- a/ads/model/deployment/model_deployment_infrastructure.py +++ b/ads/model/deployment/model_deployment_infrastructure.py @@ -1,7 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2023, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/from typing import Dict @@ -187,19 +186,44 @@ class ModelDeploymentInfrastructure(Builder): "model_deployment_configuration_details.model_configuration_details" ) + MODEL_INFRA_CONFIG_DETAILS_PATH = ( + "model_deployment_configuration_details.infrastructure_configuration_details" + ) + payload_attribute_map = { CONST_PROJECT_ID: "project_id", CONST_COMPARTMENT_ID: "compartment_id", - CONST_SHAPE_NAME: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.instance_shape_name", - CONST_SHAPE_CONFIG_DETAILS: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.model_deployment_instance_shape_config_details", - CONST_SUBNET_ID: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.subnet_id", - CONST_PRIVATE_ENDPOINT_ID: f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.private_endpoint_id", - CONST_REPLICA: f"{MODEL_CONFIG_DETAILS_PATH}.scaling_policy.instance_count", - CONST_BANDWIDTH_MBPS: f"{MODEL_CONFIG_DETAILS_PATH}.bandwidth_mbps", + CONST_SHAPE_NAME: [ + f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.instance_shape_name", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.instance_configuration.instance_shape_name", + ], + CONST_SHAPE_CONFIG_DETAILS: [ + f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.model_deployment_instance_shape_config_details", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.instance_configuration.model_deployment_instance_shape_config_details", + ], + CONST_SUBNET_ID: [ + f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.subnet_id", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.instance_configuration.subnet_id", + ], + CONST_PRIVATE_ENDPOINT_ID: [ + f"{MODEL_CONFIG_DETAILS_PATH}.instance_configuration.private_endpoint_id", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.instance_configuration.private_endpoint_id", + ], + CONST_REPLICA: [ + f"{MODEL_CONFIG_DETAILS_PATH}.scaling_policy.instance_count", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.scaling_policy.instance_count", + ], + CONST_BANDWIDTH_MBPS: [ + f"{MODEL_CONFIG_DETAILS_PATH}.bandwidth_mbps", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.bandwidth_mbps", + ], CONST_ACCESS_LOG: "category_log_details.access", CONST_PREDICT_LOG: "category_log_details.predict", CONST_DEPLOYMENT_TYPE: "model_deployment_configuration_details.deployment_type", - CONST_POLICY_TYPE: f"{MODEL_CONFIG_DETAILS_PATH}.scaling_policy.policy_type", + CONST_POLICY_TYPE: [ + f"{MODEL_CONFIG_DETAILS_PATH}.scaling_policy.policy_type", + f"{MODEL_INFRA_CONFIG_DETAILS_PATH}.scaling_policy.policy_type", + ], } sub_level_attribute_maps = { @@ -621,7 +645,9 @@ def subnet_id(self) -> str: """ return self.get_spec(self.CONST_SUBNET_ID, None) - def with_private_endpoint_id(self, private_endpoint_id: str) -> "ModelDeploymentInfrastructure": + def with_private_endpoint_id( + self, private_endpoint_id: str + ) -> "ModelDeploymentInfrastructure": """Sets the private endpoint id of model deployment. Parameters