2
2
# Copyright (c) 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
5
+ import json
6
+ import os
7
+ import re
5
8
import shutil
6
- from typing import List , Union
9
+ from typing import Dict , List , Optional , Tuple , Union
7
10
11
+ from huggingface_hub import hf_hub_download
12
+ from huggingface_hub .utils import HfHubHTTPError
8
13
from pydantic import ValidationError
9
14
from rich .table import Table
10
15
17
22
)
18
23
from ads .aqua .common .utils import (
19
24
build_pydantic_error_message ,
25
+ format_hf_custom_error_message ,
20
26
get_resource_type ,
27
+ is_valid_ocid ,
21
28
load_config ,
22
29
load_gpu_shapes_index ,
23
30
)
37
44
ShapeRecommendationReport ,
38
45
ShapeReport ,
39
46
)
47
+ from ads .config import COMPARTMENT_OCID
40
48
from ads .model .datascience_model import DataScienceModel
41
49
from ads .model .service .oci_datascience_model_deployment import (
42
50
OCIDataScienceModelDeployment ,
@@ -91,20 +99,23 @@ def which_shapes(
91
99
try :
92
100
shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
93
101
94
- ds_model = self ._get_data_science_model (request .model_id )
95
-
96
- model_name = ds_model .display_name if ds_model .display_name else ""
97
-
98
102
if request .deployment_config :
103
+ if is_valid_ocid (request .model_id ):
104
+ ds_model = self ._get_data_science_model (request .model_id )
105
+ model_name = ds_model .display_name
106
+ else :
107
+ model_name = request .model_id
108
+
99
109
shape_recommendation_report = (
100
110
ShapeRecommendationReport .from_deployment_config (
101
111
request .deployment_config , model_name , shapes
102
112
)
103
113
)
104
114
105
115
else :
106
- data = self ._get_model_config (ds_model )
107
-
116
+ data , model_name = self ._get_model_config_and_name (
117
+ model_id = request .model_id ,
118
+ )
108
119
llm_config = LLMConfig .from_raw_config (data )
109
120
110
121
shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
@@ -135,7 +146,57 @@ def which_shapes(
135
146
136
147
return shape_recommendation_report
137
148
138
- def valid_compute_shapes (self , compartment_id : str ) -> List ["ComputeShapeSummary" ]:
149
+ def _get_model_config_and_name (
150
+ self ,
151
+ model_id : str ,
152
+ ) -> Tuple [Dict , str ]:
153
+ """
154
+ Loads model configuration by trying OCID logic first, then falling back
155
+ to treating the model_id as a Hugging Face Hub ID.
156
+
157
+ Parameters
158
+ ----------
159
+ model_id : str
160
+ The model OCID or Hugging Face model ID.
161
+ # compartment_id : Optional[str]
162
+ # The compartment OCID, used for searching the model catalog.
163
+
164
+ Returns
165
+ -------
166
+ Tuple[Dict, str]
167
+ A tuple containing:
168
+ - The model configuration dictionary.
169
+ - The display name for the model.
170
+ """
171
+ if is_valid_ocid (model_id ):
172
+ logger .info (f"Detected OCID: Fetching OCI model config for '{ model_id } '." )
173
+ ds_model = self ._get_data_science_model (model_id )
174
+ config = self ._get_model_config (ds_model )
175
+ model_name = ds_model .display_name
176
+ else :
177
+ logger .info (
178
+ f"Assuming Hugging Face model ID: Fetching config for '{ model_id } '."
179
+ )
180
+ config = self ._fetch_hf_config (model_id )
181
+ model_name = model_id
182
+
183
+ return config , model_name
184
+
185
+ def _fetch_hf_config (self , model_id : str ) -> Dict :
186
+ """
187
+ Downloads a model's config.json from Hugging Face Hub using the
188
+ huggingface_hub library.
189
+ """
190
+ try :
191
+ config_path = hf_hub_download (repo_id = model_id , filename = "config.json" )
192
+ with open (config_path , "r" , encoding = "utf-8" ) as f :
193
+ return json .load (f )
194
+ except HfHubHTTPError as e :
195
+ format_hf_custom_error_message (e )
196
+
197
+ def valid_compute_shapes (
198
+ self , compartment_id : Optional [str ] = None
199
+ ) -> List ["ComputeShapeSummary" ]:
139
200
"""
140
201
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
141
202
@@ -151,9 +212,23 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary
151
212
152
213
Raises
153
214
------
154
- ValueError
155
- If the file cannot be opened, parsed, or the 'shapes' key is missing.
215
+ AquaValueError
216
+ If a compartment_id is not provided and cannot be found in the
217
+ environment variables.
156
218
"""
219
+ if not compartment_id :
220
+ compartment_id = COMPARTMENT_OCID
221
+ if compartment_id :
222
+ logger .info (f"Using compartment_id from environment: { compartment_id } " )
223
+
224
+ if not compartment_id :
225
+ raise AquaValueError (
226
+ "A compartment OCID is required to list available shapes. "
227
+ "Please specify it using the --compartment_id parameter.\n \n "
228
+ "Example:\n "
229
+ 'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
230
+ )
231
+
157
232
oci_shapes = OCIDataScienceModelDeployment .shapes (compartment_id = compartment_id )
158
233
set_user_shapes = {shape .name : shape for shape in oci_shapes }
159
234
@@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel):
324
399
"""
325
400
326
401
model_task = model .freeform_tags .get ("task" , "" ).lower ()
402
+ model_task = re .sub (r"-" , "_" , model_task )
327
403
model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
328
404
329
405
logger .info (f"Current model task type: { model_task } " )
0 commit comments