Skip to content

Commit cd4e4f5

Browse files
authored
Merge branch 'main' into python_3_13
2 parents bc9d2a7 + ddb3094 commit cd4e4f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+2438
-1200
lines changed

ads/aqua/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
get_async_httpx_client,
1515
get_httpx_client,
1616
)
17-
from ads.aqua.common.utils import fetch_service_compartment
1817
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1918

2019
ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL"
@@ -39,7 +38,3 @@ def set_log_level(log_level: str):
3938

4039
if OCI_RESOURCE_PRINCIPAL_VERSION:
4140
set_auth("resource_principal")
42-
43-
ODSC_MODEL_COMPARTMENT_OCID = (
44-
os.environ.get("ODSC_MODEL_COMPARTMENT_OCID") or fetch_service_compartment()
45-
)

ads/aqua/app.py

Lines changed: 133 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@
77
import traceback
88
from dataclasses import fields
99
from datetime import datetime, timedelta
10-
from typing import Any, Dict, Optional, Union
10+
from itertools import chain
11+
from typing import Any, Dict, List, Optional, Union
1112

1213
import oci
1314
from cachetools import TTLCache, cached
14-
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
15+
from oci.data_science.models import (
16+
ContainerSummary,
17+
UpdateModelDetails,
18+
UpdateModelProvenanceDetails,
19+
)
1520

1621
from ads import set_auth
1722
from ads.aqua import logger
@@ -24,6 +29,11 @@
2429
is_valid_ocid,
2530
load_config,
2631
)
32+
from ads.aqua.config.container_config import (
33+
AquaContainerConfig,
34+
AquaContainerConfigItem,
35+
)
36+
from ads.aqua.constants import SERVICE_MANAGED_CONTAINER_URI_SCHEME
2737
from ads.common import oci_client as oc
2838
from ads.common.auth import default_signer
2939
from ads.common.utils import UNKNOWN, extract_region, is_path_exists
@@ -240,7 +250,9 @@ def create_model_catalog(
240250
.with_custom_metadata_list(model_custom_metadata)
241251
.with_defined_metadata_list(model_taxonomy_metadata)
242252
.with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
243-
.with_defined_tags(**(defined_tags or {})) # Create defined tags when a model is created.
253+
.with_defined_tags(
254+
**(defined_tags or {})
255+
) # Create defined tags when a model is created.
244256
.create(
245257
**kwargs,
246258
)
@@ -271,6 +283,43 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
271283
logger.info(f"Artifact not found in model {model_id}.")
272284
return False
273285

286+
def get_config_from_metadata(
287+
self, model_id: str, metadata_key: str
288+
) -> ModelConfigResult:
289+
"""Gets the config for the given Aqua model from model catalog metadata content.
290+
291+
Parameters
292+
----------
293+
model_id: str
294+
The OCID of the Aqua model.
295+
metadata_key: str
296+
The metadata key name where artifact content is stored
297+
Returns
298+
-------
299+
ModelConfigResult
300+
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301+
"""
302+
config = {}
303+
oci_model = self.ds_client.get_model(model_id).data
304+
try:
305+
config = self.ds_client.get_model_defined_metadatum_artifact_content(
306+
model_id, metadata_key
307+
).data.content.decode("utf-8")
308+
return ModelConfigResult(config=json.loads(config), model_details=oci_model)
309+
except UnicodeDecodeError as ex:
310+
logger.error(
311+
f"Failed to decode content for '{metadata_key}' in defined metadata for model '{model_id}' : {ex}"
312+
)
313+
except json.JSONDecodeError as ex:
314+
logger.error(
315+
f"Invalid JSON format for '{metadata_key}' in defined metadata for model '{model_id}' : {ex}"
316+
)
317+
except Exception as ex:
318+
logger.error(
319+
f"Failed to retrieve defined metadata key '{metadata_key}' for model '{model_id}': {ex}"
320+
)
321+
return ModelConfigResult(config=config, model_details=oci_model)
322+
274323
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
275324
def get_config(
276325
self,
@@ -310,22 +359,7 @@ def get_config(
310359
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
311360

312361
config: Dict[str, Any] = {}
313-
314-
# if the current model has a service model tag, then
315-
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
316-
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
317-
logger.info(
318-
f"Base model found for the model: {oci_model.id}. "
319-
f"Loading {config_file_name} for base model {base_model_ocid}."
320-
)
321-
if config_folder == ConfigFolder.ARTIFACT:
322-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
323-
else:
324-
base_model = self.ds_client.get_model(base_model_ocid).data
325-
artifact_path = get_artifact_path(base_model.custom_metadata_list)
326-
else:
327-
logger.info(f"Loading {config_file_name} for model {oci_model.id}...")
328-
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
362+
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
329363
if not artifact_path:
330364
logger.debug(
331365
f"Failed to get artifact path from custom metadata for the model: {model_id}"
@@ -340,7 +374,7 @@ def get_config(
340374
config_file_path = os.path.join(config_path, config_file_name)
341375
if is_path_exists(config_file_path):
342376
try:
343-
logger.debug(
377+
logger.info(
344378
f"Loading config: `{config_file_name}` from `{config_path}`"
345379
)
346380
config = load_config(
@@ -361,6 +395,85 @@ def get_config(
361395

362396
return ModelConfigResult(config=config, model_details=oci_model)
363397

398+
def get_container_image(self, container_type: str = None) -> str:
399+
"""
400+
Gets the latest smc container complete image name from the given container type.
401+
402+
Parameters
403+
----------
404+
container_type: str
405+
type of container, can be either odsc-vllm-serving, odsc-llm-fine-tuning, odsc-llm-evaluate
406+
407+
Returns
408+
-------
409+
str:
410+
A complete container name along with version. ex: dsmc://odsc-vllm-serving:0.7.4.1
411+
"""
412+
413+
containers = self.list_service_containers()
414+
container = next(
415+
(c for c in containers if c.is_latest and c.family_name == container_type),
416+
None,
417+
)
418+
if not container:
419+
raise AquaValueError(f"Invalid container type : {container_type}")
420+
container_image = (
421+
SERVICE_MANAGED_CONTAINER_URI_SCHEME
422+
+ container.container_name
423+
+ ":"
424+
+ container.tag
425+
)
426+
return container_image
427+
428+
@cached(cache=TTLCache(maxsize=20, ttl=timedelta(minutes=30), timer=datetime.now))
429+
def list_service_containers(self) -> List[ContainerSummary]:
430+
"""
431+
List containers from containers.conf in OCI Datascience control plane
432+
"""
433+
containers = self.ds_client.list_containers().data
434+
return containers
435+
436+
def get_container_config(self) -> AquaContainerConfig:
437+
"""
438+
Fetches latest containers from containers.conf in OCI Datascience control plane
439+
440+
Returns
441+
-------
442+
AquaContainerConfig
443+
An Object that contains latest container info for the given container family
444+
445+
"""
446+
return AquaContainerConfig.from_service_config(
447+
service_containers=self.list_service_containers()
448+
)
449+
450+
def get_container_config_item(
451+
self, container_family: str
452+
) -> AquaContainerConfigItem:
453+
"""
454+
Fetches latest container for given container_family_name from containers.conf in OCI Datascience control plane
455+
456+
Returns
457+
-------
458+
AquaContainerConfigItem
459+
An Object that contains latest container info for the given container family
460+
461+
"""
462+
463+
aqua_container_config = self.get_container_config()
464+
inference_config = aqua_container_config.inference.values()
465+
ft_config = aqua_container_config.finetune.values()
466+
eval_config = aqua_container_config.evaluate.values()
467+
container = next(
468+
(
469+
container
470+
for container in chain(inference_config, ft_config, eval_config)
471+
if container.family.lower() == container_family.lower()
472+
),
473+
None,
474+
)
475+
return container
476+
364477
@property
365478
def telemetry(self):
366479
if not self._telemetry:

ads/aqua/cli.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65
import os
76

87
from ads.aqua import (
98
ENV_VAR_LOG_LEVEL,
10-
ODSC_MODEL_COMPARTMENT_OCID,
119
logger,
1210
set_log_level,
1311
)
14-
from ads.aqua.common.errors import AquaCLIError, AquaConfigError
12+
from ads.aqua.common.errors import AquaCLIError
1513
from ads.aqua.evaluation import AquaEvaluationApp
1614
from ads.aqua.finetuning import AquaFineTuningApp
1715
from ads.aqua.model import AquaModelApp
1816
from ads.aqua.modeldeployment import AquaDeploymentApp
1917
from ads.common.utils import LOG_LEVELS
20-
from ads.config import NB_SESSION_OCID
2118

2219

2320
class AquaCommand:
@@ -82,16 +79,6 @@ def __init__(
8279

8380
set_log_level(aqua_log_level)
8481

85-
if not ODSC_MODEL_COMPARTMENT_OCID:
86-
if NB_SESSION_OCID:
87-
raise AquaConfigError(
88-
f"Aqua is not available for the notebook session {NB_SESSION_OCID}. For more information, "
89-
f"please refer to the documentation."
90-
)
91-
raise AquaConfigError(
92-
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
93-
)
94-
9582
@staticmethod
9683
def _validate_value(flag, value):
9784
"""Check if the given value for bool flag is valid.

ads/aqua/common/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class AquaMultiModelRef(Serializable):
151151
The name of the model.
152152
gpu_count : Optional[int]
153153
Number of GPUs required for deployment.
154+
model_task : Optional[str]
155+
The task that model operates on. Supported tasks are in MultiModelSupportedTaskType
154156
env_var : Optional[Dict[str, Any]]
155157
Optional environment variables to override during deployment.
156158
artifact_location : Optional[str]
@@ -162,6 +164,7 @@ class AquaMultiModelRef(Serializable):
162164
gpu_count: Optional[int] = Field(
163165
None, description="The gpu count allocation for the model."
164166
)
167+
model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
165168
env_var: Optional[dict] = Field(
166169
default_factory=dict, description="The environment variables of the model."
167170
)

0 commit comments

Comments
 (0)