2
2
# Copyright (c) 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
5
- #!/usr/bin/env python
6
- # Copyright (c) 2025 Oracle and/or its affiliates.
7
- # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
8
-
9
- import shutil
10
- import os
11
5
import json
12
- from typing import List , Union , Optional , Dict , Any , Tuple
6
+ import os
7
+ import shutil
8
+ from typing import Dict , List , Optional , Tuple , Union
13
9
14
10
from pydantic import ValidationError
15
11
from rich .table import Table
16
12
17
13
from huggingface_hub import hf_hub_download
18
14
from huggingface_hub .utils import HfHubHTTPError
15
+
19
16
from ads .aqua .app import logger
20
17
from ads .aqua .common .entities import ComputeShapeSummary
21
18
from ads .aqua .common .errors import (
25
22
)
26
23
from ads .aqua .common .utils import (
27
24
build_pydantic_error_message ,
25
+ get_resource_type ,
26
+ is_valid_ocid ,
28
27
load_config ,
29
28
load_gpu_shapes_index ,
30
- is_valid_ocid ,
31
- get_resource_type ,
29
+ format_hf_custom_error_message ,
32
30
)
33
31
from ads .aqua .shaperecommend .constants import (
34
- BITS_AND_BYTES_4BIT ,
35
32
BITSANDBYTES ,
33
+ BITS_AND_BYTES_4BIT ,
36
34
SAFETENSORS ,
37
35
SHAPE_MAP ,
38
36
TEXT_GENERATION ,
46
44
ShapeRecommendationReport ,
47
45
ShapeReport ,
48
46
)
47
+ from ads .config import COMPARTMENT_OCID
49
48
from ads .model .datascience_model import DataScienceModel
50
49
from ads .model .service .oci_datascience_model_deployment import (
51
50
OCIDataScienceModelDeployment ,
@@ -100,9 +99,6 @@ def which_shapes(
100
99
"""
101
100
try :
102
101
shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
103
- # data, model_name = self._get_model_config_and_name(
104
- # model_id=request.model_id, compartment_id=request.compartment_id
105
- # )
106
102
data , model_name = self ._get_model_config_and_name (
107
103
model_id = request .model_id ,
108
104
)
@@ -158,41 +154,18 @@ def _get_model_config_and_name(
158
154
- The display name for the model.
159
155
"""
160
156
if is_valid_ocid (model_id ):
161
- logger .info (f"' { model_id } ' identified as a model OCID ." )
157
+ logger .info (f"Detected OCID: Fetching OCI model config for ' { model_id } ' ." )
162
158
ds_model = self ._validate_model_ocid (model_id )
163
- return self ._get_model_config (ds_model ), ds_model .display_name
159
+ config = self ._fetch_hf_config (model_id )
160
+ model_name = ds_model .display_name
161
+ else :
162
+ logger .info (
163
+ f"Assuming Hugging Face model ID: Fetching config for '{ model_id } '."
164
+ )
165
+ config = self ._fetch_hf_config (model_id )
166
+ model_name = model_id
164
167
165
- logger .info (
166
- f"'{ model_id } ' is not an OCID, treating as a Hugging Face model ID."
167
- )
168
- # if not compartment_id:
169
- # compartment_id = os.environ.get(
170
- # "NB_SESSION_COMPARTMENT_OCID"
171
- # ) or os.environ.get("PROJECT_COMPARTMENT_OCID")
172
- # if compartment_id:
173
- # logger.info(f"Using compartment_id from environment: {compartment_id}")
174
- # if not compartment_id:
175
- # raise AquaValueError(
176
- # "A compartment OCID is required to list available shapes. "
177
- # "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
178
- # "or 'PROJECT_COMPARTMENT_OCID' environment variable."
179
- # "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
180
- # )
181
-
182
- # ds_model = self._search_model_in_catalog(model_id, compartment_id)
183
- # if ds_model:
184
- # logger.info("Loading configuration from existing model catalog artifact.")
185
- # try:
186
- # return (
187
- # self._get_model_config(ds_model),
188
- # ds_model.display_name,
189
- # )
190
- # except AquaFileNotFoundError:
191
- # logger.warning(
192
- # "config.json not found in artifact, fetching from Hugging Face Hub."
193
- # )
194
-
195
- return self ._fetch_hf_config (model_id ), model_id
168
+ return config , model_name
196
169
197
170
def _fetch_hf_config (self , model_id : str ) -> Dict :
198
171
"""
@@ -204,38 +177,7 @@ def _fetch_hf_config(self, model_id: str) -> Dict:
204
177
with open (config_path , "r" , encoding = "utf-8" ) as f :
205
178
return json .load (f )
206
179
except HfHubHTTPError as e :
207
- if "401" in str (e ):
208
- raise AquaValueError (
209
- f"Model '{ model_id } ' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>"
210
- )
211
- elif "404" in str (e ) or "not found" in str (e ).lower ():
212
- raise AquaValueError (
213
- f"Model '{ model_id } ' not found on HuggingFace. Please check the name for typos."
214
- )
215
- raise AquaValueError (
216
- f"Failed to download config for '{ model_id } ': { e } "
217
- ) from e
218
-
219
- # def _search_model_in_catalog(
220
- # self, model_id: str, compartment_id: str
221
- # ) -> Optional[DataScienceModel]:
222
- # """
223
- # Searches for a model in the Data Science catalog by its display name.
224
- # """
225
- # try:
226
- # models = DataScienceModel.list(
227
- # compartment_id=compartment_id, display_name=model_id
228
- # )
229
- # if len(models) > 1:
230
- # logger.warning(
231
- # f"Found multiple models with the name '{model_id}'. Using the first one found."
232
- # )
233
- # if models:
234
- # logger.info(f"Found model '{model_id}' in the Data Science catalog.")
235
- # return models[0]
236
- # except Exception as e:
237
- # logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
238
- # return None
180
+ format_hf_custom_error_message (e )
239
181
240
182
def valid_compute_shapes (
241
183
self , compartment_id : Optional [str ] = None
@@ -260,18 +202,16 @@ def valid_compute_shapes(
260
202
environment variables.
261
203
"""
262
204
if not compartment_id :
263
- compartment_id = os .environ .get (
264
- "NB_SESSION_COMPARTMENT_OCID"
265
- ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
205
+ compartment_id = COMPARTMENT_OCID
266
206
if compartment_id :
267
207
logger .info (f"Using compartment_id from environment: { compartment_id } " )
268
208
269
209
if not compartment_id :
270
210
raise AquaValueError (
271
211
"A compartment OCID is required to list available shapes. "
272
- "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
273
- "or 'PROJECT_COMPARTMENT_OCID' environment variable. "
274
- "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
212
+ "Please specify it using the --compartment_id parameter. \n \n "
213
+ "Example: \n "
214
+ 'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
275
215
)
276
216
277
217
oci_shapes = OCIDataScienceModelDeployment .shapes (compartment_id = compartment_id )
0 commit comments