Skip to content

Commit aca26e7

Browse files
authored
[AQUA][STMD] Added support to edit stacked model deployment (#1265)
2 parents 129cd8f + c476ecb commit aca26e7

File tree

6 files changed

+702
-254
lines changed

6 files changed

+702
-254
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,15 @@ def post(self, *args, **kwargs): # noqa: ARG002
119119
if not input_data:
120120
raise HTTPError(400, Errors.NO_INPUT_DATA)
121121

122-
self.finish(AquaDeploymentApp().create(**input_data))
122+
model_deployment_id = input_data.pop("model_deployment_id", None)
123+
if model_deployment_id:
124+
self.finish(
125+
AquaDeploymentApp().update(
126+
model_deployment_id=model_deployment_id, **input_data
127+
)
128+
)
129+
else:
130+
self.finish(AquaDeploymentApp().create(**input_data))
123131

124132
def read(self, id):
125133
"""Read the information of an Aqua model deployment."""

ads/aqua/modeldeployment/deployment.py

Lines changed: 222 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@
8282
AquaDeploymentDetail,
8383
ConfigValidationError,
8484
CreateModelDeploymentDetails,
85+
ModelDeploymentDetails,
86+
UpdateModelDeploymentDetails,
8587
)
8688
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
8789
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
@@ -110,6 +112,9 @@
110112
ModelDeploymentInfrastructure,
111113
ModelDeploymentMode,
112114
)
115+
from ads.model.deployment.model_deployment import (
116+
ModelDeploymentUpdateType,
117+
)
113118
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
114119
from ads.telemetry import telemetry
115120

@@ -397,14 +402,14 @@ def create(
397402

398403
def _validate_input_models(
399404
self,
400-
create_deployment_details: CreateModelDeploymentDetails,
405+
deployment_details: ModelDeploymentDetails,
401406
):
402-
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
407+
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
403408
# Collect all unique model IDs (including fine-tuned models)
404409
source_model_ids = list(
405410
{
406411
model_id
407-
for model in create_deployment_details.models
412+
for model in deployment_details.models
408413
for model_id in model.all_model_ids()
409414
}
410415
)
@@ -415,7 +420,7 @@ def _validate_input_models(
415420
source_models = self.get_multi_source(source_model_ids) or {}
416421

417422
try:
418-
create_deployment_details.validate_input_models(model_details=source_models)
423+
deployment_details.validate_input_models(model_details=source_models)
419424
except ConfigValidationError as err:
420425
raise AquaValueError(f"{err}") from err
421426

@@ -1249,6 +1254,219 @@ def _get_container_type_key(
12491254

12501255
return container_type_key
12511256

1257+
@telemetry(entry_point="plugin=deployment&action=update", name="aqua")
def update(
    self,
    model_deployment_id: str,
    update_model_deployment_details: Optional[UpdateModelDeploymentDetails] = None,
    **kwargs,
) -> AquaDeployment:
    """Apply changes to an existing AQUA model group deployment.

    Any field of the update details that is unset (or otherwise falsy) keeps
    the value currently stored on the deployment.

    Parameters
    ----------
    model_deployment_id: str
        Identifier of the model deployment to update. The deployment's runtime
        must reference a model group.
    update_model_deployment_details: Optional[UpdateModelDeploymentDetails]
        The requested changes. When omitted, it is built from ``kwargs``.
    **kwargs
        Used to construct ``UpdateModelDeploymentDetails`` when no instance is
        passed. Fields consumed by this method:

        display_name, description, models, instance_count, log_group_id,
        access_log_id, predict_log_id, bandwidth_mbps, web_concurrency,
        memory_in_gbs, ocpus, freeform_tags, defined_tags.

    Returns
    -------
    AquaDeployment
        An Aqua deployment instance built from the updated model deployment.

    Raises
    ------
    AquaValueError
        If ``kwargs`` fail validation, or if the deployment's runtime has no
        model group id.
    """
    details = update_model_deployment_details
    if not details:
        try:
            details = UpdateModelDeploymentDetails(**kwargs)
        except ValidationError as ex:
            error_summary = build_pydantic_error_message(ex)
            raise AquaValueError(
                f"Invalid parameters for updating a model group deployment. Error details: {error_summary}."
            ) from ex

    model_deployment = ModelDeployment.from_id(model_deployment_id)
    infrastructure = model_deployment.infrastructure
    runtime = model_deployment.runtime

    current_group_id = runtime.model_group_id
    if not current_group_id:
        raise AquaValueError(
            "Invalid 'model_deployment_id'. Only model group deployment is supported to update."
        )

    # Resolves the model group to deploy; a different group comes back when
    # the request carries 'models'.
    model_group = self._update_model_group(current_group_id, details)

    # Infrastructure sizing: requested value first, current value as fallback.
    sized = infrastructure.with_bandwidth_mbps(
        details.bandwidth_mbps or infrastructure.bandwidth_mbps
    )
    sized = sized.with_replica(details.instance_count or infrastructure.replica)
    sized.with_web_concurrency(
        details.web_concurrency or infrastructure.web_concurrency
    )

    # Logging is only reconfigured when the log group comes with a log id.
    log_group_id = details.log_group_id
    if log_group_id:
        if details.access_log_id:
            infrastructure.with_access_log(
                log_group_id=log_group_id,
                log_id=details.access_log_id,
            )
        if details.predict_log_id:
            infrastructure.with_predict_log(
                log_group_id=log_group_id,
                log_id=details.predict_log_id,
            )

    # Shape config is applied only when both values are given and the shape
    # name ends with "Flex".
    memory_in_gbs = details.memory_in_gbs
    ocpus = details.ocpus
    if memory_in_gbs and ocpus and infrastructure.shape_name.endswith("Flex"):
        infrastructure.with_shape_config_details(
            ocpus=ocpus,
            memory_in_gbs=memory_in_gbs,
        )

    # ZDT when the model group is unchanged; LIVE once the deployment has to
    # point at a different model group.
    if current_group_id == model_group.id:
        update_type = ModelDeploymentUpdateType.ZDT
    else:
        runtime.with_model_group_id(model_group.id)
        update_type = ModelDeploymentUpdateType.LIVE

    freeform_tags = details.freeform_tags or model_deployment.freeform_tags or {}
    defined_tags = details.defined_tags or model_deployment.defined_tags or {}

    # Stage every change on the deployment object before submitting it.
    staged = model_deployment.with_display_name(
        details.display_name or model_deployment.display_name
    )
    staged = staged.with_description(
        details.description or model_deployment.description
    )
    staged = staged.with_freeform_tags(**freeform_tags)
    staged = staged.with_defined_tags(**defined_tags)
    staged = staged.with_infrastructure(infrastructure)
    staged.with_runtime(runtime)

    # Submitted without waiting for the update to complete.
    model_deployment.update(wait_for_completion=False, update_type=update_type)

    logger.info(f"Updating Aqua Model Deployment {model_deployment.id}.")

    return AquaDeployment.from_oci_model_deployment(
        model_deployment.dsc_model_deployment, self.region
    )
1396+
1397+
def _update_model_group(
    self,
    model_group_id: str,
    update_model_deployment_details: UpdateModelDeploymentDetails,
) -> DataScienceModelGroup:
    """Resolve the model group a stacked deployment should point at.

    When the update request carries no ``models``, the existing model group is
    returned unchanged. Otherwise a new model group is created that copies the
    configuration of the existing one and replaces its member models with the
    requested base model plus its fine tuned weights.

    Parameters
    ----------
    model_group_id: str
        The model group id.
    update_model_deployment_details: UpdateModelDeploymentDetails
        An instance of UpdateModelDeploymentDetails containing all optional
        fields for updating a model deployment via Aqua.

    Returns
    -------
    DataScienceModelGroup
        The existing model group, or the newly created one.

    Raises
    ------
    AquaValueError
        If ``models`` does not contain exactly one entry, if the input models
        fail validation, or if the requested base model id differs from the
        base model id of the existing model group.
    """
    model_group = DataScienceModelGroup.from_id(model_group_id)

    requested_models = update_model_deployment_details.models
    if not requested_models:
        # Nothing model-related was requested; keep the current group.
        return model_group

    # A new model group is created because member models of a data science
    # model group are immutable.
    if len(requested_models) != 1:
        raise AquaValueError(
            "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
        )

    # validates input base and fine tune models
    self._validate_input_models(update_model_deployment_details)

    target_stacked_model = requested_models[0]
    target_base_model_id = target_stacked_model.model_id
    if model_group.base_model_id != target_base_model_id:
        raise AquaValueError(
            "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
        )

    # Fine tuned weights first. `fine_tune_weights` may be unset (None) when
    # only the base model is supplied, so fall back to an empty list instead
    # of failing on iteration.
    member_models = [
        {
            "inference_key": fine_tune_weight.model_name,
            "model_id": fine_tune_weight.model_id,
        }
        for fine_tune_weight in (target_stacked_model.fine_tune_weights or [])
    ]
    # The base model is always a member of the group.
    member_models.append(
        {
            "inference_key": target_stacked_model.model_name,
            "model_id": target_base_model_id,
        }
    )

    # creates a model group with the same configurations from original model group except member models
    model_group = (
        DataScienceModelGroup()
        .with_compartment_id(model_group.compartment_id)
        .with_project_id(model_group.project_id)
        .with_display_name(model_group.display_name)
        .with_description(model_group.description)
        .with_freeform_tags(**(model_group.freeform_tags or {}))
        .with_defined_tags(**(model_group.defined_tags or {}))
        .with_custom_metadata_list(model_group.custom_metadata_list)
        .with_base_model_id(target_base_model_id)
        .with_member_models(member_models)
        .create()
    )

    logger.info(
        f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
    )

    return model_group
1469+
12521470
@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
12531471
def list(self, **kwargs) -> List["AquaDeployment"]:
12541472
"""List Aqua model deployments in a given compartment and under certain project.

0 commit comments

Comments
 (0)