48
48
OCIDataScienceModelDeployment ,
49
49
)
50
50
51
+
51
52
class HuggingFaceModelFetcher :
52
53
"""
53
54
Utility class to fetch model configurations from HuggingFace.
@@ -57,7 +58,7 @@ class HuggingFaceModelFetcher:
57
58
def is_huggingface_model_id(cls, model_id: str) -> bool:
    """
    Tell whether ``model_id`` has the shape of a HuggingFace model ID.

    Anything recognised by ``is_valid_ocid`` is rejected up front. Otherwise
    the ID must be either a single name or ``<namespace>/<name>``, where the
    namespace may contain letters, digits, ``_`` and ``-`` and the name may
    additionally contain ``.``.

    Parameters
    ----------
    model_id : str
        Identifier to classify.

    Returns
    -------
    bool
        True when the whole string matches the HuggingFace ID pattern and is
        not an OCID, False otherwise.
    """
    if is_valid_ocid(model_id):
        return False
    # The pattern must start directly at the first character: any literal
    # text in front of the start anchor makes it impossible to match.
    # ``fullmatch`` replaces the ``^...$`` anchors so that an ID followed by
    # a trailing newline is rejected as well (``$`` alone would accept it).
    hf_pattern = r"[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?"
    return bool(re.fullmatch(hf_pattern, model_id))
62
63
63
64
@classmethod
@@ -80,12 +81,19 @@ def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
80
81
elif response .status_code == 404 :
81
82
raise AquaValueError (f"Model '{ model_id } ' not found on HuggingFace." )
82
83
elif response .status_code != 200 :
83
- raise AquaValueError (f"Failed to fetch config for '{ model_id } '. Status: { response .status_code } " )
84
+ raise AquaValueError (
85
+ f"Failed to fetch config for '{ model_id } '. Status: { response .status_code } "
86
+ )
84
87
return response .json ()
85
88
except requests .RequestException as e :
86
- raise AquaValueError (f"Network error fetching config for { model_id } : { e } " ) from e
89
+ raise AquaValueError (
90
+ f"Network error fetching config for { model_id } : { e } "
91
+ ) from e
87
92
except json .JSONDecodeError as e :
88
- raise AquaValueError (f"Invalid config format for model '{ model_id } '." ) from e
93
+ raise AquaValueError (
94
+ f"Invalid config format for model '{ model_id } '."
95
+ ) from e
96
+
89
97
90
98
class AquaShapeRecommend :
91
99
"""
@@ -135,7 +143,9 @@ def which_shapes(
135
143
"""
136
144
try :
137
145
shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
138
- data , model_name = self ._get_model_config_and_name (request .model_id , request .compartment_id )
146
+ data , model_name = self ._get_model_config_and_name (
147
+ request .model_id , request .compartment_id
148
+ )
139
149
llm_config = LLMConfig .from_raw_config (data )
140
150
shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
141
151
llm_config , shapes , model_name
@@ -165,40 +175,55 @@ def which_shapes(
165
175
166
176
return shape_recommendation_report
167
177
168
def _get_model_config_and_name(
    self, model_id: str, compartment_id: str
) -> (dict, str):
    """
    Loads model configuration, handling OCID and Hugging Face model IDs.

    Resolution order:

    1. If ``model_id`` is not a Hugging Face model ID it is treated as a
       model OCID: the model is validated and its configuration is read
       through ``_get_model_config``.
    2. Otherwise the model catalog is searched for a model with that name.
       When a match with an artifact exists, ``config.json`` is loaded from
       the artifact.
    3. If there is no such catalog entry, or the artifact has no
       ``config.json``, the configuration is fetched from Hugging Face.

    Parameters
    ----------
    model_id : str
        Either a model OCID or a Hugging Face model ID.
    compartment_id : str
        Compartment passed to the catalog search for Hugging Face IDs.

    Returns
    -------
    (dict, str)
        The raw model configuration and the model name. The name is the
        catalog model's ``display_name`` when the configuration comes from
        the catalog, and ``model_id`` itself when it comes from Hugging Face.
    """
    # NOTE(review): the return annotation is a tuple literal rather than a
    # typing construct; it is kept as-is to leave the signature untouched.
    if not HuggingFaceModelFetcher.is_huggingface_model_id(model_id):
        logger.info(f"'{model_id}' identified as a model OCID.")
        ds_model = self._validate_model_ocid(model_id)
        return self._get_model_config(ds_model), ds_model.display_name

    logger.info(f"'{model_id}' identified as a Hugging Face model ID.")
    ds_model = self._search_model_in_catalog(model_id, compartment_id)
    if ds_model and ds_model.artifact:
        logger.info("Loading configuration from existing model catalog artifact.")
        try:
            return (
                load_config(ds_model.artifact, "config.json"),
                ds_model.display_name,
            )
        except AquaFileNotFoundError:
            # A missing config.json is not fatal: fall through to the Hub.
            logger.warning(
                "config.json not found in artifact, fetching from Hugging Face Hub."
            )
    return HuggingFaceModelFetcher.fetch_config_only(model_id), model_id
186
205
187
def _search_model_in_catalog(
    self, model_id: str, compartment_id: str
) -> Optional[DataScienceModel]:
    """
    Searches for a Hugging Face model in the Data Science model catalog by display name.

    Parameters
    ----------
    model_id : str
        Hugging Face model ID, used as the ``display_name`` filter.
    compartment_id : str
        Compartment in which to list models.

    Returns
    -------
    Optional[DataScienceModel]
        The first model returned by the listing, or None when the listing is
        empty or the lookup raises.
    """
    try:
        # This should work since the SDK's list method can filter by display_name.
        models = DataScienceModel.list(
            compartment_id=compartment_id, display_name=model_id
        )
        if models:
            logger.info(f"Found model '{model_id}' in the Data Science catalog.")
            return models[0]
    except Exception as e:
        # Deliberately best-effort: a failed lookup is logged and reported
        # to the caller as "not found" rather than propagated.
        logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
    return None
200
223
201
- def valid_compute_shapes (self , compartment_id : Optional [str ] = None ) -> List ["ComputeShapeSummary" ]:
224
+ def valid_compute_shapes (
225
+ self , compartment_id : Optional [str ] = None
226
+ ) -> List ["ComputeShapeSummary" ]:
202
227
"""
203
228
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
204
229
@@ -219,7 +244,9 @@ def valid_compute_shapes(self, compartment_id: Optional[str] = None) -> List["Co
219
244
environment variables.
220
245
"""
221
246
if not compartment_id :
222
- compartment_id = os .environ .get ("NB_SESSION_COMPARTMENT_OCID" ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
247
+ compartment_id = os .environ .get (
248
+ "NB_SESSION_COMPARTMENT_OCID"
249
+ ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
223
250
if compartment_id :
224
251
logger .info (f"Using compartment_id from environment: { compartment_id } " )
225
252
0 commit comments