5
5
import json
6
6
import os
7
7
import traceback
8
+ from concurrent .futures import ThreadPoolExecutor
8
9
from dataclasses import fields
9
10
from datetime import datetime , timedelta
10
11
from itertools import chain
22
23
from ads .aqua import logger
23
24
from ads .aqua .common .entities import ModelConfigResult
24
25
from ads .aqua .common .enums import ConfigFolder , Tags
25
- from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
26
+ from ads .aqua .common .errors import AquaValueError
26
27
from ads .aqua .common .utils import (
27
28
_is_valid_mvs ,
28
29
get_artifact_path ,
58
59
class AquaApp :
59
60
"""Base Aqua App to contain common components."""
60
61
62
+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63
+
61
64
@telemetry (name = "aqua" )
62
65
def __init__ (self ) -> None :
63
66
if OCI_RESOURCE_PRINCIPAL_VERSION :
64
67
set_auth ("resource_principal" )
65
68
self ._auth = default_signer ({"service_endpoint" : OCI_ODSC_SERVICE_ENDPOINT })
66
69
self .ds_client = oc .OCIClientFactory (** self ._auth ).data_science
70
+ self .compute_client = oc .OCIClientFactory (** default_signer ()).compute
67
71
self .logging_client = oc .OCIClientFactory (** default_signer ()).logging_management
68
72
self .identity_client = oc .OCIClientFactory (** default_signer ()).identity
69
73
self .region = extract_region (self ._auth )
@@ -127,20 +131,69 @@ def update_model_provenance(
127
131
update_model_provenance_details = update_model_provenance_details ,
128
132
)
129
133
130
- # TODO: refactor model evaluation implementation to use it.
131
134
@staticmethod
132
135
def get_source (source_id : str ) -> Union [ModelDeployment , DataScienceModel ]:
133
- if is_valid_ocid (source_id ):
134
- if "datasciencemodeldeployment" in source_id :
135
- return ModelDeployment .from_id (source_id )
136
- elif "datasciencemodel" in source_id :
137
- return DataScienceModel .from_id (source_id )
136
+ """
137
+ Fetches a model or model deployment based on the provided OCID.
138
+
139
+ Parameters
140
+ ----------
141
+ source_id : str
142
+ OCID of the Data Science model or model deployment.
143
+
144
+ Returns
145
+ -------
146
+ Union[ModelDeployment, DataScienceModel]
147
+ The corresponding resource object.
138
148
149
+ Raises
150
+ ------
151
+ AquaValueError
152
+ If the OCID is invalid or unsupported.
153
+ """
154
+ logger .debug (f"Resolving source for ID: { source_id } " )
155
+ if not is_valid_ocid (source_id ):
156
+ logger .error (f"Invalid OCID format: { source_id } " )
157
+ raise AquaValueError (
158
+ f"Invalid source ID: { source_id } . Please provide a valid model or model deployment OCID."
159
+ )
160
+
161
+ if "datasciencemodeldeployment" in source_id :
162
+ logger .debug (f"Identified as ModelDeployment OCID: { source_id } " )
163
+ return ModelDeployment .from_id (source_id )
164
+
165
+ if "datasciencemodel" in source_id :
166
+ logger .debug (f"Identified as DataScienceModel OCID: { source_id } " )
167
+ return DataScienceModel .from_id (source_id )
168
+
169
+ logger .error (f"Unrecognized OCID type: { source_id } " )
139
170
raise AquaValueError (
140
- f"Invalid source { source_id } . "
141
- "Specify either a model or model deployment id."
171
+ f"Unsupported source ID type: { source_id } . Must be a model or model deployment OCID."
142
172
)
143
173
174
+ def get_multi_source (
175
+ self ,
176
+ ids : List [str ],
177
+ ) -> Dict [str , Union [ModelDeployment , DataScienceModel ]]:
178
+ """
179
+ Retrieves multiple DataScience resources concurrently.
180
+
181
+ Parameters
182
+ ----------
183
+ ids : List[str]
184
+ A list of DataScience OCIDs.
185
+
186
+ Returns
187
+ -------
188
+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189
+ A mapping from OCID to the corresponding resolved resource object.
190
+ """
191
+ logger .debug (f"Fetching { ids } sources in parallel." )
192
+ with ThreadPoolExecutor (max_workers = self .MAX_WORKERS ) as executor :
193
+ results = list (executor .map (self .get_source , ids ))
194
+
195
+ return dict (zip (ids , results ))
196
+
144
197
# TODO: refactor model evaluation implementation to use it.
145
198
@staticmethod
146
199
def create_model_version_set (
@@ -283,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
283
336
logger .info (f"Artifact not found in model { model_id } ." )
284
337
return False
285
338
339
+ @cached (cache = TTLCache (maxsize = 5 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
286
340
def get_config_from_metadata (
287
- self , model_id : str , metadata_key : str
341
+ self ,
342
+ model_id : str ,
343
+ metadata_key : str ,
288
344
) -> ModelConfigResult :
289
345
"""Gets the config for the given Aqua model from model catalog metadata content.
290
346
@@ -299,8 +355,9 @@ def get_config_from_metadata(
299
355
ModelConfigResult
300
356
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301
357
"""
302
- config = {}
358
+ config : Dict [ str , Any ] = {}
303
359
oci_model = self .ds_client .get_model (model_id ).data
360
+
304
361
try :
305
362
config = self .ds_client .get_model_defined_metadatum_artifact_content (
306
363
model_id , metadata_key
@@ -320,7 +377,7 @@ def get_config_from_metadata(
320
377
)
321
378
return ModelConfigResult (config = config , model_details = oci_model )
322
379
323
- @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
380
+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
324
381
def get_config (
325
382
self ,
326
383
model_id : str ,
@@ -345,8 +402,10 @@ def get_config(
345
402
ModelConfigResult
346
403
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
347
404
"""
348
- config_folder = config_folder or ConfigFolder . CONFIG
405
+ config : Dict [ str , Any ] = {}
349
406
oci_model = self .ds_client .get_model (model_id ).data
407
+
408
+ config_folder = config_folder or ConfigFolder .CONFIG
350
409
oci_aqua = (
351
410
(
352
411
Tags .AQUA_TAG in oci_model .freeform_tags
@@ -356,9 +415,9 @@ def get_config(
356
415
else False
357
416
)
358
417
if not oci_aqua :
359
- raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
418
+ logger .debug (f"Target model { oci_model .id } is not an Aqua model." )
419
+ return ModelConfigResult (config = config , model_details = oci_model )
360
420
361
- config : Dict [str , Any ] = {}
362
421
artifact_path = get_artifact_path (oci_model .custom_metadata_list )
363
422
if not artifact_path :
364
423
logger .debug (
0 commit comments