Skip to content

Commit 0797da8

Browse files
committed
Updated pr.
1 parent 41f543e commit 0797da8

File tree

3 files changed

+88
-44
lines changed

3 files changed

+88
-44
lines changed

ads/aqua/model/model.py

Lines changed: 82 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,11 @@ def create(
170170
Union[DataScienceModel, DataScienceModelGroup]
171171
The instance of DataScienceModel or DataScienceModelGroup.
172172
"""
173-
fine_tune_weights = (
174-
model_id.fine_tune_weights
175-
if isinstance(model_id, AquaMultiModelRef)
176-
else []
177-
)
178-
model_id = (
179-
model_id.model_id if isinstance(model_id, AquaMultiModelRef) else model_id
180-
)
173+
fine_tune_weights = []
174+
if isinstance(model_id, AquaMultiModelRef):
175+
fine_tune_weights = model_id.fine_tune_weights
176+
model_id = model_id.model_id
177+
181178
service_model = DataScienceModel.from_id(model_id)
182179
target_project = project_id or PROJECT_OCID
183180
target_compartment = compartment_id or COMPARTMENT_OCID
@@ -194,26 +191,14 @@ def create(
194191

195192
custom_model = None
196193
if fine_tune_weights:
197-
custom_model = (
198-
DataScienceModelGroup()
199-
.with_compartment_id(target_compartment)
200-
.with_project_id(target_project)
201-
.with_display_name(service_model.display_name)
202-
.with_description(service_model.description)
203-
.with_freeform_tags(**combined_freeform_tags)
204-
.with_defined_tags(**combined_defined_tags)
205-
.with_custom_metadata_list(service_model.custom_metadata_list)
206-
.with_base_model_id(model_id)
207-
.with_member_models(
208-
[
209-
{
210-
"inference_key": fine_tune_weight.model_name,
211-
"model_id": fine_tune_weight.model_id,
212-
}
213-
for fine_tune_weight in fine_tune_weights
214-
]
215-
)
216-
.create()
194+
custom_model = self._create_model_group(
195+
model_id=model_id,
196+
compartment_id=target_compartment,
197+
project_id=target_project,
198+
freeform_tags=combined_freeform_tags,
199+
defined_tags=combined_defined_tags,
200+
fine_tune_weights=fine_tune_weights,
201+
service_model=service_model,
217202
)
218203

219204
logger.info(
@@ -233,21 +218,13 @@ def create(
233218
)
234219
return service_model
235220

236-
custom_model = (
237-
DataScienceModel()
238-
.with_compartment_id(target_compartment)
239-
.with_project_id(target_project)
240-
.with_model_file_description(
241-
json_dict=service_model.model_file_description
242-
)
243-
.with_display_name(service_model.display_name)
244-
.with_description(service_model.description)
245-
.with_freeform_tags(**combined_freeform_tags)
246-
.with_defined_tags(**combined_defined_tags)
247-
.with_custom_metadata_list(service_model.custom_metadata_list)
248-
.with_defined_metadata_list(service_model.defined_metadata_list)
249-
.with_provenance_metadata(service_model.provenance_metadata)
250-
.create(model_by_reference=True, **kwargs)
221+
custom_model = self._create_model(
222+
compartment_id=target_compartment,
223+
project_id=target_project,
224+
freeform_tags=combined_freeform_tags,
225+
defined_tags=combined_defined_tags,
226+
service_model=service_model,
227+
**kwargs,
251228
)
252229
logger.info(
253230
f"Aqua Model {custom_model.id} created with the service model {model_id}."
@@ -262,6 +239,68 @@ def create(
262239

263240
return custom_model
264241

242+
def _create_model(
243+
self,
244+
compartment_id: str,
245+
project_id: str,
246+
freeform_tags: Dict,
247+
defined_tags: Dict,
248+
service_model: DataScienceModel,
249+
**kwargs,
250+
):
251+
"""Creates a data science model by reference."""
252+
custom_model = (
253+
DataScienceModel()
254+
.with_compartment_id(compartment_id)
255+
.with_project_id(project_id)
256+
.with_model_file_description(json_dict=service_model.model_file_description)
257+
.with_display_name(service_model.display_name)
258+
.with_description(service_model.description)
259+
.with_freeform_tags(**freeform_tags)
260+
.with_defined_tags(**defined_tags)
261+
.with_custom_metadata_list(service_model.custom_metadata_list)
262+
.with_defined_metadata_list(service_model.defined_metadata_list)
263+
.with_provenance_metadata(service_model.provenance_metadata)
264+
.create(model_by_reference=True, **kwargs)
265+
)
266+
267+
return custom_model
268+
269+
def _create_model_group(
270+
self,
271+
model_id: str,
272+
compartment_id: str,
273+
project_id: str,
274+
freeform_tags: Dict,
275+
defined_tags: Dict,
276+
fine_tune_weights: List,
277+
service_model: DataScienceModel,
278+
):
279+
"""Creates a data science model group."""
280+
custom_model = (
281+
DataScienceModelGroup()
282+
.with_compartment_id(compartment_id)
283+
.with_project_id(project_id)
284+
.with_display_name(service_model.display_name)
285+
.with_description(service_model.description)
286+
.with_freeform_tags(**freeform_tags)
287+
.with_defined_tags(**defined_tags)
288+
.with_custom_metadata_list(service_model.custom_metadata_list)
289+
.with_base_model_id(model_id)
290+
.with_member_models(
291+
[
292+
{
293+
"inference_key": fine_tune_weight.model_name,
294+
"model_id": fine_tune_weight.model_id,
295+
}
296+
for fine_tune_weight in fine_tune_weights
297+
]
298+
)
299+
.create()
300+
)
301+
302+
return custom_model
303+
265304
@telemetry(entry_point="plugin=model&action=create", name="aqua")
266305
def create_multi(
267306
self,

ads/aqua/modeldeployment/deployment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,18 @@ def create(
216216
model_app = AquaModelApp()
217217
if (
218218
create_deployment_details.model_id
219-
or create_deployment_details.deployment_type == DEFAULT_DEPLOYMENT_TYPE
219+
or create_deployment_details.deployment_type
220220
):
221221
model_id = create_deployment_details.model_id
222222
if not model_id:
223223
if len(create_deployment_details.models) != 1:
224224
raise AquaValueError(
225225
"Invalid 'models' provided. Only one base model is required for model stack deployment."
226226
)
227+
if create_deployment_details.deployment_type != DEFAULT_DEPLOYMENT_TYPE:
228+
raise AquaValueError(
229+
f"Invalid 'deployment_type' provided. Only {DEFAULT_DEPLOYMENT_TYPE} is supported for model stack deployment."
230+
)
227231
model_id = create_deployment_details.models[0]
228232

229233
service_model_id = (

ads/model/datascience_model_group.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def _build_model_group_details(self) -> dict:
541541
)
542542

543543
build_model_group_details = copy.deepcopy(self._spec)
544+
# pop out the unrequired specs for building `CreateModelGroupDetails` or `UpdateModelGroupDetails`.
544545
build_model_group_details.pop(self.CONST_CUSTOM_METADATA_LIST, None)
545546
build_model_group_details.pop(self.CONST_MEMBER_MODELS, None)
546547
build_model_group_details.pop(self.CONST_BASE_MODEL_ID, None)

0 commit comments

Comments
 (0)