Skip to content

Commit 190e2d8

Browse files
committed
Updated pr.
1 parent 0797da8 commit 190e2d8

File tree

5 files changed

+34
-35
lines changed

5 files changed

+34
-35
lines changed

ads/aqua/model/model.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class AquaModelApp(AquaApp):
141141
@telemetry(entry_point="plugin=model&action=create", name="aqua")
142142
def create(
143143
self,
144-
model_id: Union[str, AquaMultiModelRef],
144+
model: Union[str, AquaMultiModelRef],
145145
project_id: Optional[str] = None,
146146
compartment_id: Optional[str] = None,
147147
freeform_tags: Optional[Dict] = None,
@@ -153,7 +153,7 @@ def create(
153153
154154
Parameters
155155
----------
156-
model_id : Union[str, AquaMultiModelRef]
156+
model : Union[str, AquaMultiModelRef]
157157
The model ID as a string or a AquaMultiModelRef instance to be deployed.
158158
project_id : Optional[str]
159159
The project ID for the custom model.
@@ -171,11 +171,11 @@ def create(
171171
The instance of DataScienceModel or DataScienceModelGroup.
172172
"""
173173
fine_tune_weights = []
174-
if isinstance(model_id, AquaMultiModelRef):
175-
fine_tune_weights = model_id.fine_tune_weights
176-
model_id = model_id.model_id
174+
if isinstance(model, AquaMultiModelRef):
175+
fine_tune_weights = model.fine_tune_weights
176+
model = model.model_id
177177

178-
service_model = DataScienceModel.from_id(model_id)
178+
service_model = DataScienceModel.from_id(model)
179179
target_project = project_id or PROJECT_OCID
180180
target_compartment = compartment_id or COMPARTMENT_OCID
181181

@@ -192,7 +192,7 @@ def create(
192192
custom_model = None
193193
if fine_tune_weights:
194194
custom_model = self._create_model_group(
195-
model_id=model_id,
195+
model_id=model,
196196
compartment_id=target_compartment,
197197
project_id=target_project,
198198
freeform_tags=combined_freeform_tags,
@@ -202,18 +202,16 @@ def create(
202202
)
203203

204204
logger.info(
205-
f"Aqua Model Group {custom_model.id} created with the service model {model_id}."
205+
f"Aqua Model Group {custom_model.id} created with the service model {model}."
206206
)
207207
else:
208208
# Skip model copying if it is registered model or fine-tuned model
209209
if (
210-
service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
211-
is not None
212-
or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG)
213-
is not None
210+
Tags.BASE_MODEL_CUSTOM in service_model.freeform_tags
211+
or Tags.AQUA_FINE_TUNED_MODEL_TAG in service_model.freeform_tags
214212
):
215213
logger.info(
216-
f"Aqua Model {model_id} already exists in the user's compartment."
214+
f"Aqua Model {model} already exists in the user's compartment."
217215
"Skipped copying."
218216
)
219217
return service_model
@@ -227,7 +225,7 @@ def create(
227225
**kwargs,
228226
)
229227
logger.info(
230-
f"Aqua Model {custom_model.id} created with the service model {model_id}."
228+
f"Aqua Model {custom_model.id} created with the service model {model}."
231229
)
232230

233231
# Track unique models that were created in the user's compartment

ads/aqua/modeldeployment/constants.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
This module contains constants used in Aqua Model Deployment.
1010
"""
1111

12+
from ads.common.extended_enum import ExtendedEnum
13+
1214
DEFAULT_WAIT_TIME = 12000
1315
DEFAULT_POLL_INTERVAL = 10
14-
DEFAULT_DEPLOYMENT_TYPE = "MODEL_STACK"
16+
17+
18+
class DeploymentType(ExtendedEnum):
19+
STACKED = "STACKED"
20+
MULTI = "MULTI"

ads/aqua/modeldeployment/deployment.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464
MultiModelDeploymentConfigLoader,
6565
)
6666
from ads.aqua.modeldeployment.constants import (
67-
DEFAULT_DEPLOYMENT_TYPE,
6867
DEFAULT_POLL_INTERVAL,
6968
DEFAULT_WAIT_TIME,
69+
DeploymentType,
7070
)
7171
from ads.aqua.modeldeployment.entities import (
7272
AquaDeployment,
@@ -216,30 +216,24 @@ def create(
216216
model_app = AquaModelApp()
217217
if (
218218
create_deployment_details.model_id
219-
or create_deployment_details.deployment_type
219+
or create_deployment_details.deployment_type == DeploymentType.STACKED
220220
):
221-
model_id = create_deployment_details.model_id
222-
if not model_id:
221+
model = create_deployment_details.model_id
222+
if not model:
223223
if len(create_deployment_details.models) != 1:
224224
raise AquaValueError(
225225
"Invalid 'models' provided. Only one base model is required for model stack deployment."
226226
)
227-
if create_deployment_details.deployment_type != DEFAULT_DEPLOYMENT_TYPE:
228-
raise AquaValueError(
229-
f"Invalid 'deployment_type' provided. Only {DEFAULT_DEPLOYMENT_TYPE} is supported for model stack deployment."
230-
)
231-
model_id = create_deployment_details.models[0]
227+
model = create_deployment_details.models[0]
232228

233-
service_model_id = (
234-
model_id if isinstance(model_id, str) else model_id.model_id
235-
)
229+
service_model_id = model if isinstance(model, str) else model.model_id
236230
logger.debug(
237231
f"Single model ({service_model_id}) provided. "
238232
"Delegating to single model creation method."
239233
)
240234

241235
aqua_model = model_app.create(
242-
model_id=model_id,
236+
model=model,
243237
compartment_id=compartment_id,
244238
project_id=project_id,
245239
freeform_tags=freeform_tags,
@@ -250,6 +244,7 @@ def create(
250244
create_deployment_details=create_deployment_details,
251245
container_config=container_config,
252246
)
247+
# TODO: add multi model validation from deployment_type
253248
else:
254249
# Collect all unique model IDs (including fine-tuned models)
255250
source_model_ids = list(

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,7 +1515,7 @@ def test_create_deployment_for_foundation_model(
15151515
)
15161516

15171517
mock_create.assert_called_with(
1518-
model_id=TestDataset.MODEL_ID,
1518+
model=TestDataset.MODEL_ID,
15191519
compartment_id=TestDataset.USER_COMPARTMENT_ID,
15201520
project_id=TestDataset.USER_PROJECT_ID,
15211521
freeform_tags=freeform_tags,
@@ -1611,7 +1611,7 @@ def test_create_deployment_for_fine_tuned_model(
16111611
)
16121612

16131613
mock_create.assert_called_with(
1614-
model_id=TestDataset.MODEL_ID,
1614+
model=TestDataset.MODEL_ID,
16151615
compartment_id=TestDataset.USER_COMPARTMENT_ID,
16161616
project_id=TestDataset.USER_PROJECT_ID,
16171617
freeform_tags=None,
@@ -1707,7 +1707,7 @@ def test_create_deployment_for_gguf_model(
17071707
)
17081708

17091709
mock_create.assert_called_with(
1710-
model_id=TestDataset.MODEL_ID,
1710+
model=TestDataset.MODEL_ID,
17111711
compartment_id=TestDataset.USER_COMPARTMENT_ID,
17121712
project_id=TestDataset.USER_PROJECT_ID,
17131713
freeform_tags=None,
@@ -1810,7 +1810,7 @@ def test_create_deployment_for_tei_byoc_embedding_model(
18101810
)
18111811

18121812
mock_create.assert_called_with(
1813-
model_id=TestDataset.MODEL_ID,
1813+
model=TestDataset.MODEL_ID,
18141814
compartment_id=TestDataset.USER_COMPARTMENT_ID,
18151815
project_id=TestDataset.USER_PROJECT_ID,
18161816
freeform_tags=None,
@@ -1928,7 +1928,7 @@ def test_create_deployment_for_stack_model(
19281928
predict_log_id="ocid1.log.oc1.<region>.<OCID>",
19291929
freeform_tags=freeform_tags,
19301930
defined_tags=defined_tags,
1931-
deployment_type="MODEL_STACK",
1931+
deployment_type="STACKED",
19321932
)
19331933

19341934
mock_create.assert_called()

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create):
418418

419419
# will not copy service model
420420
self.app.create(
421-
model_id="test_model_id",
421+
model="test_model_id",
422422
project_id="test_project_id",
423423
compartment_id="test_compartment_id",
424424
)
@@ -433,7 +433,7 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create):
433433
mock_model.freeform_tags.pop(Tags.BASE_MODEL_CUSTOM)
434434
# will copy service model
435435
model = self.app.create(
436-
model_id="test_model_id",
436+
model="test_model_id",
437437
project_id="test_project_id",
438438
compartment_id="test_compartment_id",
439439
)

0 commit comments

Comments
 (0)