Skip to content

Added support to edit stacked model deployment #1250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: feature/model_group
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
AquaDeploymentDetail,
ConfigValidationError,
CreateModelDeploymentDetails,
UpdateModelGroupDeploymentDetails,
)
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
from ads.common.object_storage_details import ObjectStorageDetails
Expand All @@ -100,6 +101,9 @@
ModelDeploymentInfrastructure,
ModelDeploymentMode,
)
from ads.model.deployment.model_deployment import (
ModelDeploymentUpdateType,
)
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
from ads.telemetry import telemetry

Expand Down Expand Up @@ -1214,6 +1218,219 @@ def _get_container_type_key(

return container_type_key

@telemetry(entry_point="plugin=deployment&action=update", name="aqua")
def update(
    self,
    model_deployment_id: str,
    update_model_deployment_details: Optional[
        UpdateModelGroupDeploymentDetails
    ] = None,
    **kwargs,
) -> AquaDeployment:
    """Updates an AQUA model group deployment.

    Any field that is omitted (or falsy) in the update details keeps the value
    currently set on the deployment. The update request is submitted with
    ``wait_for_completion=False``.

    Parameters
    ----------
    model_deployment_id: str
        The id of the model deployment to update. The deployment's runtime
        must reference a model group.
    update_model_deployment_details: UpdateModelGroupDeploymentDetails, optional
        An instance of UpdateModelGroupDeploymentDetails containing all optional
        fields for updating a model deployment via Aqua. When not provided, it is
        built from ``kwargs``.
    kwargs:
        display_name (str): The name of the model deployment.
        description (Optional[str]): The description of the deployment.
        models (Optional[List[AquaMultiModelRef]]): List of models for deployment.
        instance_count (int): Number of instances used for deployment.
        log_group_id (Optional[str]): OCI logging group ID for logs.
        access_log_id (Optional[str]): OCID for access logs.
            Only applied together with `log_group_id`.
        predict_log_id (Optional[str]): OCID for prediction logs.
            Only applied together with `log_group_id`.
        bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps.
        web_concurrency (Optional[int]): Number of worker processes/threads for handling requests.
        memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape.
            Only applied together with `ocpus` on a shape whose name ends with "Flex".
        ocpus (Optional[float]): OCPU count for the selected shape.
            Only applied together with `memory_in_gbs` on a shape whose name ends with "Flex".
        private_endpoint_id (Optional[str]): Private endpoint ID for model deployment.
        freeform_tags (Optional[Dict]): Freeform tags for model deployment.
        defined_tags (Optional[Dict]): Defined tags for model deployment.

    Returns
    -------
    AquaDeployment
        An Aqua deployment instance.

    Raises
    ------
    AquaValueError
        If ``kwargs`` fail validation, or if the deployment is not backed by a
        model group.
    """
    if not update_model_deployment_details:
        try:
            update_model_deployment_details = UpdateModelGroupDeploymentDetails(
                **kwargs
            )
        except ValidationError as ex:
            custom_errors = build_pydantic_error_message(ex)
            raise AquaValueError(
                f"Invalid parameters for updating a model group deployment. Error details: {custom_errors}."
            ) from ex

    details = update_model_deployment_details

    model_deployment = ModelDeployment.from_id(model_deployment_id)

    infrastructure = model_deployment.infrastructure
    runtime = model_deployment.runtime

    if not runtime.model_group_id:
        raise AquaValueError(
            "Invalid 'model_deployment_id'. Only model group deployment is supported to update."
        )

    # May create a brand-new model group when `models` is provided.
    model = self._update_model_group(runtime.model_group_id, details)

    (
        infrastructure.with_bandwidth_mbps(
            details.bandwidth_mbps or infrastructure.bandwidth_mbps
        )
        .with_replica(details.instance_count or infrastructure.replica)
        .with_web_concurrency(
            details.web_concurrency or infrastructure.web_concurrency
        )
        .with_private_endpoint_id(
            details.private_endpoint_id or infrastructure.private_endpoint_id
        )
    )

    # Log settings are only applied when the log group accompanies the log id.
    if details.log_group_id and details.access_log_id:
        infrastructure.with_access_log(
            log_group_id=details.log_group_id,
            log_id=details.access_log_id,
        )

    if details.log_group_id and details.predict_log_id:
        infrastructure.with_predict_log(
            log_group_id=details.log_group_id,
            log_id=details.predict_log_id,
        )

    if not details.log_group_id and (
        details.access_log_id or details.predict_log_id
    ):
        # Surface the otherwise silent no-op so callers can correct the request.
        logger.warning(
            "Ignoring 'access_log_id'/'predict_log_id' because 'log_group_id' was not provided."
        )

    if (
        details.memory_in_gbs
        and details.ocpus
        and infrastructure.shape_name.endswith("Flex")
    ):
        infrastructure.with_shape_config_details(
            ocpus=details.ocpus,
            memory_in_gbs=details.memory_in_gbs,
        )
    elif details.memory_in_gbs or details.ocpus:
        # Surface the otherwise silent no-op so callers can correct the request.
        logger.warning(
            "Ignoring 'memory_in_gbs'/'ocpus'. Both values are required and the "
            "deployment shape name must end with 'Flex'."
        )

    update_type = ModelDeploymentUpdateType.ZDT
    # applies LIVE update if model group id has been changed
    if runtime.model_group_id != model.id:
        runtime.with_model_group_id(model.id)
        update_type = ModelDeploymentUpdateType.LIVE

    freeform_tags = details.freeform_tags or model_deployment.freeform_tags
    defined_tags = details.defined_tags or model_deployment.defined_tags

    # configure model deployment and deploy model on container runtime
    (
        model_deployment.with_display_name(
            details.display_name or model_deployment.display_name
        )
        .with_description(details.description or model_deployment.description)
        .with_freeform_tags(**(freeform_tags or {}))
        .with_defined_tags(**(defined_tags or {}))
        .with_infrastructure(infrastructure)
        .with_runtime(runtime)
    )

    logger.info(f"Updating Aqua Model Deployment {model_deployment.id}.")

    model_deployment.update(wait_for_completion=False, update_type=update_type)

    return AquaDeployment.from_oci_model_deployment(
        model_deployment.dsc_model_deployment, self.region
    )

def _update_model_group(
    self,
    model_group_id: str,
    update_model_deployment_details: UpdateModelGroupDeploymentDetails,
) -> DataScienceModelGroup:
    """Returns the model group to deploy, creating a new one when `models` is provided.

    When the update details carry no `models`, the existing model group is
    returned unchanged. Otherwise a new model group is created that copies the
    configuration of the existing one but uses the member models (base model
    plus fine tune weights) described by the single entry in `models`.

    Parameters
    ----------
    model_group_id: str
        The model group id.
    update_model_deployment_details: UpdateModelGroupDeploymentDetails
        An instance of UpdateModelGroupDeploymentDetails containing all optional
        fields for updating a model deployment via Aqua.

    Returns
    -------
    DataScienceModelGroup
        The instance of DataScienceModelGroup.

    Raises
    ------
    AquaValueError
        If `models` does not contain exactly one entry, or if its base model id
        differs from the base model id of the existing model group.
    """
    model_group = DataScienceModelGroup.from_id(model_group_id)

    if not update_model_deployment_details.models:
        return model_group

    # A new model group is created because member models in a ds model group
    # are immutable.
    if len(update_model_deployment_details.models) != 1:
        raise AquaValueError(
            "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
        )

    target_stacked_model = update_model_deployment_details.models[0]
    target_base_model_id = target_stacked_model.model_id
    if model_group.base_model_id != target_base_model_id:
        raise AquaValueError(
            "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
        )

    # add member models; tolerate a missing (None) fine tune weights list
    member_models = [
        {
            "inference_key": fine_tune_weight.model_name,
            "model_id": fine_tune_weight.model_id,
        }
        for fine_tune_weight in (target_stacked_model.fine_tune_weights or [])
    ]
    # add base model
    member_models.append(
        {
            "inference_key": target_stacked_model.model_name,
            "model_id": target_base_model_id,
        }
    )

    # creates a model group with the same configurations from original model group except member models
    return (
        DataScienceModelGroup()
        .with_compartment_id(model_group.compartment_id)
        .with_project_id(model_group.project_id)
        .with_display_name(model_group.display_name)
        .with_description(model_group.description)
        .with_freeform_tags(**(model_group.freeform_tags or {}))
        .with_defined_tags(**(model_group.defined_tags or {}))
        .with_custom_metadata_list(model_group.custom_metadata_list)
        .with_base_model_id(target_base_model_id)
        .with_member_models(member_models)
        .create()
    )

@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
def list(self, **kwargs) -> List["AquaDeployment"]:
"""List Aqua model deployments in a given compartment and under certain project.
Expand Down
55 changes: 55 additions & 0 deletions ads/aqua/modeldeployment/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,58 @@ def validate_ft_model_v2(
class Config:
extra = "allow"
protected_namespaces = ()


class UpdateModelGroupDeploymentDetails(BaseModel):
    """Class for updating Aqua model deployments.

    Every field is optional and defaults to ``None``. Extra, undeclared
    fields are accepted (``extra = "allow"``).
    """

    display_name: Optional[str] = Field(
        None, description="The name of the model deployment."
    )
    description: Optional[str] = Field(
        None, description="The description of the deployment."
    )
    models: Optional[List[AquaMultiModelRef]] = Field(
        None, description="List of models for multimodel deployment."
    )
    instance_count: Optional[int] = Field(
        None, description="Number of instances used for deployment."
    )
    log_group_id: Optional[str] = Field(
        None, description="OCI logging group ID for logs."
    )
    access_log_id: Optional[str] = Field(
        None,
        description="OCID for access logs. "
        "https://docs.oracle.com/en-us/iaas/data-science/using/model_dep_using_logging.htm",
    )
    # The trailing space after "logs." keeps the sentence and the URL separated
    # once the adjacent string literals are concatenated.
    predict_log_id: Optional[str] = Field(
        None,
        description="OCID for prediction logs. "
        "https://docs.oracle.com/en-us/iaas/data-science/using/model_dep_using_logging.htm",
    )
    bandwidth_mbps: Optional[int] = Field(
        None, description="Bandwidth limit on the load balancer in Mbps."
    )
    web_concurrency: Optional[int] = Field(
        None, description="Number of worker processes/threads for handling requests."
    )
    memory_in_gbs: Optional[float] = Field(
        None, description="Memory (in GB) for the selected shape."
    )
    ocpus: Optional[float] = Field(
        None, description="OCPU count for the selected shape."
    )
    private_endpoint_id: Optional[str] = Field(
        None, description="Private endpoint ID for model deployment."
    )
    freeform_tags: Optional[Dict] = Field(
        None, description="Freeform tags for model deployment."
    )
    defined_tags: Optional[Dict] = Field(
        None, description="Defined tags for model deployment."
    )

    class Config:
        extra = "allow"
        protected_namespaces = ()
Loading