Skip to content

Commit 308e492

Browse files
authored
Merge branch 'main' into python_3_13
2 parents cd4e4f5 + 45f7c84 commit 308e492

File tree

116 files changed

+6309
-2136
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+6309
-2136
lines changed

ads/aqua/app.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77
import traceback
8+
from concurrent.futures import ThreadPoolExecutor
89
from dataclasses import fields
910
from datetime import datetime, timedelta
1011
from itertools import chain
@@ -22,7 +23,7 @@
2223
from ads.aqua import logger
2324
from ads.aqua.common.entities import ModelConfigResult
2425
from ads.aqua.common.enums import ConfigFolder, Tags
25-
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
26+
from ads.aqua.common.errors import AquaValueError
2627
from ads.aqua.common.utils import (
2728
_is_valid_mvs,
2829
get_artifact_path,
@@ -58,12 +59,15 @@
5859
class AquaApp:
5960
"""Base Aqua App to contain common components."""
6061

62+
MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164
@telemetry(name="aqua")
6265
def __init__(self) -> None:
6366
if OCI_RESOURCE_PRINCIPAL_VERSION:
6467
set_auth("resource_principal")
6568
self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
6669
self.ds_client = oc.OCIClientFactory(**self._auth).data_science
70+
self.compute_client = oc.OCIClientFactory(**default_signer()).compute
6771
self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
6872
self.identity_client = oc.OCIClientFactory(**default_signer()).identity
6973
self.region = extract_region(self._auth)
@@ -127,20 +131,69 @@ def update_model_provenance(
127131
update_model_provenance_details=update_model_provenance_details,
128132
)
129133

130-
# TODO: refactor model evaluation implementation to use it.
131134
@staticmethod
132135
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
133-
if is_valid_ocid(source_id):
134-
if "datasciencemodeldeployment" in source_id:
135-
return ModelDeployment.from_id(source_id)
136-
elif "datasciencemodel" in source_id:
137-
return DataScienceModel.from_id(source_id)
136+
"""
137+
Fetches a model or model deployment based on the provided OCID.
138+
139+
Parameters
140+
----------
141+
source_id : str
142+
OCID of the Data Science model or model deployment.
143+
144+
Returns
145+
-------
146+
Union[ModelDeployment, DataScienceModel]
147+
The corresponding resource object.
138148
149+
Raises
150+
------
151+
AquaValueError
152+
If the OCID is invalid or unsupported.
153+
"""
154+
logger.debug(f"Resolving source for ID: {source_id}")
155+
if not is_valid_ocid(source_id):
156+
logger.error(f"Invalid OCID format: {source_id}")
157+
raise AquaValueError(
158+
f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159+
)
160+
161+
if "datasciencemodeldeployment" in source_id:
162+
logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163+
return ModelDeployment.from_id(source_id)
164+
165+
if "datasciencemodel" in source_id:
166+
logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167+
return DataScienceModel.from_id(source_id)
168+
169+
logger.error(f"Unrecognized OCID type: {source_id}")
139170
raise AquaValueError(
140-
f"Invalid source {source_id}. "
141-
"Specify either a model or model deployment id."
171+
f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
142172
)
143173

174+
def get_multi_source(
175+
self,
176+
ids: List[str],
177+
) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178+
"""
179+
Retrieves multiple DataScience resources concurrently.
180+
181+
Parameters
182+
----------
183+
ids : List[str]
184+
A list of DataScience OCIDs.
185+
186+
Returns
187+
-------
188+
Dict[str, Union[ModelDeployment, DataScienceModel]]
189+
A mapping from OCID to the corresponding resolved resource object.
190+
"""
191+
logger.debug(f"Fetching {ids} sources in parallel.")
192+
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193+
results = list(executor.map(self.get_source, ids))
194+
195+
return dict(zip(ids, results))
196+
144197
# TODO: refactor model evaluation implementation to use it.
145198
@staticmethod
146199
def create_model_version_set(
@@ -283,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
283336
logger.info(f"Artifact not found in model {model_id}.")
284337
return False
285338

339+
@cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
286340
def get_config_from_metadata(
287-
self, model_id: str, metadata_key: str
341+
self,
342+
model_id: str,
343+
metadata_key: str,
288344
) -> ModelConfigResult:
289345
"""Gets the config for the given Aqua model from model catalog metadata content.
290346
@@ -299,8 +355,9 @@ def get_config_from_metadata(
299355
ModelConfigResult
300356
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301357
"""
302-
config = {}
358+
config: Dict[str, Any] = {}
303359
oci_model = self.ds_client.get_model(model_id).data
360+
304361
try:
305362
config = self.ds_client.get_model_defined_metadatum_artifact_content(
306363
model_id, metadata_key
@@ -320,7 +377,7 @@ def get_config_from_metadata(
320377
)
321378
return ModelConfigResult(config=config, model_details=oci_model)
322379

323-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
380+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
324381
def get_config(
325382
self,
326383
model_id: str,
@@ -345,8 +402,10 @@ def get_config(
345402
ModelConfigResult
346403
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
347404
"""
348-
config_folder = config_folder or ConfigFolder.CONFIG
405+
config: Dict[str, Any] = {}
349406
oci_model = self.ds_client.get_model(model_id).data
407+
408+
config_folder = config_folder or ConfigFolder.CONFIG
350409
oci_aqua = (
351410
(
352411
Tags.AQUA_TAG in oci_model.freeform_tags
@@ -356,9 +415,9 @@ def get_config(
356415
else False
357416
)
358417
if not oci_aqua:
359-
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
418+
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
419+
return ModelConfigResult(config=config, model_details=oci_model)
360420

361-
config: Dict[str, Any] = {}
362421
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
363422
if not artifact_path:
364423
logger.debug(

ads/aqua/cli.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ads.aqua.finetuning import AquaFineTuningApp
1515
from ads.aqua.model import AquaModelApp
1616
from ads.aqua.modeldeployment import AquaDeploymentApp
17+
from ads.aqua.verify_policies import AquaVerifyPoliciesApp
1718
from ads.common.utils import LOG_LEVELS
1819

1920

@@ -29,6 +30,7 @@ class AquaCommand:
2930
fine_tuning = AquaFineTuningApp
3031
deployment = AquaDeploymentApp
3132
evaluation = AquaEvaluationApp
33+
verify_policies = AquaVerifyPoliciesApp
3234

3335
def __init__(
3436
self,
@@ -94,3 +96,18 @@ def _validate_value(flag, value):
9496
"If you intend to chain a function call to the result, please separate the "
9597
"flag and the subsequent function call with separator `-`."
9698
)
99+
100+
@staticmethod
101+
def install():
102+
"""Install ADS Aqua Extension from wheel file. Set enviroment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.
103+
104+
Return
105+
------
106+
int:
107+
Installatation status.
108+
"""
109+
import subprocess
110+
111+
wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
112+
status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
113+
return status.check_returncode

ads/aqua/client/client.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
6161

6262
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the HttpxOCIAuth instance.
64+
Initializes the authentication handler with the given or default OCI signer.
6565
66-
Args:
67-
signer (oci.signer.Signer): The OCI signer to use for signing requests.
66+
Parameters
67+
----------
68+
signer : oci.signer.Signer, optional
69+
The OCI signer instance to use. If None, a default signer will be retrieved.
6870
"""
69-
70-
self.signer = signer or authutil.default_signer().get("signer")
71+
try:
72+
self.signer = signer or authutil.default_signer().get("signer")
73+
if not self.signer:
74+
raise ValueError("OCI signer could not be initialized.")
75+
except Exception as e:
76+
logger.error("Failed to initialize OCI signer: %s", e)
77+
raise
7178

7279
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7380
"""
@@ -80,21 +87,31 @@ def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
8087
httpx.Request: The signed HTTPX request.
8188
"""
8289
# Create a requests.Request object from the HTTPX request
83-
req = requests.Request(
84-
method=request.method,
85-
url=str(request.url),
86-
headers=dict(request.headers),
87-
data=request.content,
88-
)
89-
prepared_request = req.prepare()
90+
try:
91+
req = requests.Request(
92+
method=request.method,
93+
url=str(request.url),
94+
headers=dict(request.headers),
95+
data=request.content,
96+
)
97+
prepared_request = req.prepare()
98+
self.signer.do_request_sign(prepared_request)
99+
100+
# Replace headers on the original HTTPX request with signed headers
101+
request.headers.update(prepared_request.headers)
102+
logger.debug("Successfully signed request to %s", request.url)
90103

91-
# Sign the request using the OCI Signer
92-
self.signer.do_request_sign(prepared_request)
104+
# Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
105+
if (
106+
request.method in ["GET", "DELETE"]
107+
and "content-length" not in request.headers
108+
):
109+
request.headers["content-length"] = "0"
93110

94-
# Update the original HTTPX request with the signed headers
95-
request.headers.update(prepared_request.headers)
111+
except Exception as e:
112+
logger.error("Failed to sign request to %s: %s", request.url, e)
113+
raise
96114

97-
# Proceed with the request
98115
yield request
99116

100117

@@ -330,8 +347,8 @@ def _prepare_headers(
330347
"Content-Type": "application/json",
331348
"Accept": "text/event-stream" if stream else "application/json",
332349
}
333-
if stream:
334-
default_headers["enable-streaming"] = "true"
350+
# if stream:
351+
# default_headers["enable-streaming"] = "true"
335352
if headers:
336353
default_headers.update(headers)
337354

@@ -495,7 +512,7 @@ def generate(
495512
prompt: str,
496513
payload: Optional[Dict[str, Any]] = None,
497514
headers: Optional[Dict[str, str]] = None,
498-
stream: bool = True,
515+
stream: bool = False,
499516
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500517
"""
501518
Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ def chat(
521538
messages: List[Dict[str, Any]],
522539
payload: Optional[Dict[str, Any]] = None,
523540
headers: Optional[Dict[str, str]] = None,
524-
stream: bool = True,
541+
stream: bool = False,
525542
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526543
"""
527544
Perform a chat interaction with the model.

0 commit comments

Comments
 (0)