82
82
AquaDeploymentDetail ,
83
83
ConfigValidationError ,
84
84
CreateModelDeploymentDetails ,
85
+ ModelDeploymentDetails ,
86
+ UpdateModelDeploymentDetails ,
85
87
)
86
88
from ads .aqua .modeldeployment .model_group_config import ModelGroupConfig
87
89
from ads .aqua .shaperecommend .recommend import AquaShapeRecommend
110
112
ModelDeploymentInfrastructure ,
111
113
ModelDeploymentMode ,
112
114
)
115
+ from ads .model .deployment .model_deployment import (
116
+ ModelDeploymentUpdateType ,
117
+ )
113
118
from ads .model .model_metadata import ModelCustomMetadata , ModelCustomMetadataItem
114
119
from ads .telemetry import telemetry
115
120
@@ -397,14 +402,14 @@ def create(
397
402
398
403
def _validate_input_models (
399
404
self ,
400
- create_deployment_details : CreateModelDeploymentDetails ,
405
+ deployment_details : ModelDeploymentDetails ,
401
406
):
402
- """Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
407
+ """Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
403
408
# Collect all unique model IDs (including fine-tuned models)
404
409
source_model_ids = list (
405
410
{
406
411
model_id
407
- for model in create_deployment_details .models
412
+ for model in deployment_details .models
408
413
for model_id in model .all_model_ids ()
409
414
}
410
415
)
@@ -415,7 +420,7 @@ def _validate_input_models(
415
420
source_models = self .get_multi_source (source_model_ids ) or {}
416
421
417
422
try :
418
- create_deployment_details .validate_input_models (model_details = source_models )
423
+ deployment_details .validate_input_models (model_details = source_models )
419
424
except ConfigValidationError as err :
420
425
raise AquaValueError (f"{ err } " ) from err
421
426
@@ -1249,6 +1254,219 @@ def _get_container_type_key(
1249
1254
1250
1255
return container_type_key
1251
1256
1257
+ @telemetry (entry_point = "plugin=deployment&action=update" , name = "aqua" )
1258
+ def update (
1259
+ self ,
1260
+ model_deployment_id : str ,
1261
+ update_model_deployment_details : Optional [UpdateModelDeploymentDetails ] = None ,
1262
+ ** kwargs ,
1263
+ ) -> AquaDeployment :
1264
+ """Updates a AQUA model group deployment.
1265
+
1266
+ Args:
1267
+ update_model_deployment_details : UpdateModelDeploymentDetails, optional
1268
+ An instance of UpdateModelDeploymentDetails containing all optional
1269
+ fields for updating a model deployment via Aqua.
1270
+ kwargs:
1271
+ display_name (str): The name of the model deployment.
1272
+ description (Optional[str]): The description of the deployment.
1273
+ models (Optional[List[AquaMultiModelRef]]): List of models for deployment.
1274
+ instance_count (int): Number of instances used for deployment.
1275
+ log_group_id (Optional[str]): OCI logging group ID for logs.
1276
+ access_log_id (Optional[str]): OCID for access logs.
1277
+ predict_log_id (Optional[str]): OCID for prediction logs.
1278
+ bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps.
1279
+ web_concurrency (Optional[int]): Number of worker processes/threads for handling requests.
1280
+ memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape.
1281
+ ocpus (Optional[float]): OCPU count for the selected shape.
1282
+ freeform_tags (Optional[Dict]): Freeform tags for model deployment.
1283
+ defined_tags (Optional[Dict]): Defined tags for model deployment.
1284
+
1285
+ Returns
1286
+ -------
1287
+ AquaDeployment
1288
+ An Aqua deployment instance.
1289
+ """
1290
+ if not update_model_deployment_details :
1291
+ try :
1292
+ update_model_deployment_details = UpdateModelDeploymentDetails (** kwargs )
1293
+ except ValidationError as ex :
1294
+ custom_errors = build_pydantic_error_message (ex )
1295
+ raise AquaValueError (
1296
+ f"Invalid parameters for updating a model group deployment. Error details: { custom_errors } ."
1297
+ ) from ex
1298
+
1299
+ model_deployment = ModelDeployment .from_id (model_deployment_id )
1300
+
1301
+ infrastructure = model_deployment .infrastructure
1302
+ runtime = model_deployment .runtime
1303
+
1304
+ if not runtime .model_group_id :
1305
+ raise AquaValueError (
1306
+ "Invalid 'model_deployment_id'. Only model group deployment is supported to update."
1307
+ )
1308
+
1309
+ # updates model group if fine tuned weights changed.
1310
+ model = self ._update_model_group (
1311
+ runtime .model_group_id , update_model_deployment_details
1312
+ )
1313
+
1314
+ # updates model group deployment infrastructure
1315
+ (
1316
+ infrastructure .with_bandwidth_mbps (
1317
+ update_model_deployment_details .bandwidth_mbps
1318
+ or infrastructure .bandwidth_mbps
1319
+ )
1320
+ .with_replica (
1321
+ update_model_deployment_details .instance_count or infrastructure .replica
1322
+ )
1323
+ .with_web_concurrency (
1324
+ update_model_deployment_details .web_concurrency
1325
+ or infrastructure .web_concurrency
1326
+ )
1327
+ )
1328
+
1329
+ if (
1330
+ update_model_deployment_details .log_group_id
1331
+ and update_model_deployment_details .access_log_id
1332
+ ):
1333
+ infrastructure .with_access_log (
1334
+ log_group_id = update_model_deployment_details .log_group_id ,
1335
+ log_id = update_model_deployment_details .access_log_id ,
1336
+ )
1337
+
1338
+ if (
1339
+ update_model_deployment_details .log_group_id
1340
+ and update_model_deployment_details .predict_log_id
1341
+ ):
1342
+ infrastructure .with_predict_log (
1343
+ log_group_id = update_model_deployment_details .log_group_id ,
1344
+ log_id = update_model_deployment_details .predict_log_id ,
1345
+ )
1346
+
1347
+ if (
1348
+ update_model_deployment_details .memory_in_gbs
1349
+ and update_model_deployment_details .ocpus
1350
+ and infrastructure .shape_name .endswith ("Flex" )
1351
+ ):
1352
+ infrastructure .with_shape_config_details (
1353
+ ocpus = update_model_deployment_details .ocpus ,
1354
+ memory_in_gbs = update_model_deployment_details .memory_in_gbs ,
1355
+ )
1356
+
1357
+ # applies ZDT as default type to update parameters if model group id hasn't been changed
1358
+ update_type = ModelDeploymentUpdateType .ZDT
1359
+ # applies LIVE update if model group id has been changed
1360
+ if runtime .model_group_id != model .id :
1361
+ runtime .with_model_group_id (model .id )
1362
+ update_type = ModelDeploymentUpdateType .LIVE
1363
+
1364
+ freeform_tags = (
1365
+ update_model_deployment_details .freeform_tags
1366
+ or model_deployment .freeform_tags
1367
+ )
1368
+ defined_tags = (
1369
+ update_model_deployment_details .defined_tags
1370
+ or model_deployment .defined_tags
1371
+ )
1372
+
1373
+ # updates model group deployment
1374
+ (
1375
+ model_deployment .with_display_name (
1376
+ update_model_deployment_details .display_name
1377
+ or model_deployment .display_name
1378
+ )
1379
+ .with_description (
1380
+ update_model_deployment_details .description
1381
+ or model_deployment .description
1382
+ )
1383
+ .with_freeform_tags (** (freeform_tags or {}))
1384
+ .with_defined_tags (** (defined_tags or {}))
1385
+ .with_infrastructure (infrastructure )
1386
+ .with_runtime (runtime )
1387
+ )
1388
+
1389
+ model_deployment .update (wait_for_completion = False , update_type = update_type )
1390
+
1391
+ logger .info (f"Updating Aqua Model Deployment { model_deployment .id } ." )
1392
+
1393
+ return AquaDeployment .from_oci_model_deployment (
1394
+ model_deployment .dsc_model_deployment , self .region
1395
+ )
1396
+
1397
+ def _update_model_group (
1398
+ self ,
1399
+ model_group_id : str ,
1400
+ update_model_deployment_details : UpdateModelDeploymentDetails ,
1401
+ ) -> DataScienceModelGroup :
1402
+ """Creates a new model group if fine tuned weights changed.
1403
+
1404
+ Parameters
1405
+ ----------
1406
+ model_group_id: str
1407
+ The model group id.
1408
+ update_model_deployment_details: UpdateModelDeploymentDetails
1409
+ An instance of UpdateModelDeploymentDetails containing all optional
1410
+ fields for updating a model deployment via Aqua.
1411
+
1412
+ Returns
1413
+ -------
1414
+ DataScienceModelGroup
1415
+ The instance of DataScienceModelGroup.
1416
+ """
1417
+ model_group = DataScienceModelGroup .from_id (model_group_id )
1418
+ # create a new model group if fine tune weights changed as member models in ds model group is inmutable
1419
+ if update_model_deployment_details .models :
1420
+ if len (update_model_deployment_details .models ) != 1 :
1421
+ raise AquaValueError (
1422
+ "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
1423
+ )
1424
+ # validates input base and fine tune models
1425
+ self ._validate_input_models (update_model_deployment_details )
1426
+ target_stacked_model = update_model_deployment_details .models [0 ]
1427
+ target_base_model_id = target_stacked_model .model_id
1428
+ if model_group .base_model_id != target_base_model_id :
1429
+ raise AquaValueError (
1430
+ "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
1431
+ )
1432
+
1433
+ # add member models
1434
+ member_models = [
1435
+ {
1436
+ "inference_key" : fine_tune_weight .model_name ,
1437
+ "model_id" : fine_tune_weight .model_id ,
1438
+ }
1439
+ for fine_tune_weight in target_stacked_model .fine_tune_weights
1440
+ ]
1441
+ # add base model
1442
+ member_models .append (
1443
+ {
1444
+ "inference_key" : target_stacked_model .model_name ,
1445
+ "model_id" : target_base_model_id ,
1446
+ }
1447
+ )
1448
+
1449
+ # creates a model group with the same configurations from original model group except member models
1450
+ model_group = (
1451
+ DataScienceModelGroup ()
1452
+ .with_compartment_id (model_group .compartment_id )
1453
+ .with_project_id (model_group .project_id )
1454
+ .with_display_name (model_group .display_name )
1455
+ .with_description (model_group .description )
1456
+ .with_freeform_tags (** (model_group .freeform_tags or {}))
1457
+ .with_defined_tags (** (model_group .defined_tags or {}))
1458
+ .with_custom_metadata_list (model_group .custom_metadata_list )
1459
+ .with_base_model_id (target_base_model_id )
1460
+ .with_member_models (member_models )
1461
+ .create ()
1462
+ )
1463
+
1464
+ logger .info (
1465
+ f"Model group of base model { target_base_model_id } has been updated: { model_group .id } ."
1466
+ )
1467
+
1468
+ return model_group
1469
+
1252
1470
@telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
1253
1471
def list (self , ** kwargs ) -> List ["AquaDeployment" ]:
1254
1472
"""List Aqua model deployments in a given compartment and under certain project.
0 commit comments