4
4
5
5
import json
6
6
import os
7
+ import re
7
8
import shutil
8
9
from typing import Dict , List , Optional , Tuple , Union
9
10
10
- from pydantic import ValidationError
11
- from rich .table import Table
12
-
13
11
from huggingface_hub import hf_hub_download
14
12
from huggingface_hub .utils import HfHubHTTPError
13
+ from pydantic import ValidationError
14
+ from rich .table import Table
15
15
16
16
from ads .aqua .app import logger
17
17
from ads .aqua .common .entities import ComputeShapeSummary
22
22
)
23
23
from ads .aqua .common .utils import (
24
24
build_pydantic_error_message ,
25
+ format_hf_custom_error_message ,
25
26
get_resource_type ,
26
27
is_valid_ocid ,
27
28
load_config ,
28
29
load_gpu_shapes_index ,
29
- format_hf_custom_error_message ,
30
30
)
31
31
from ads .aqua .shaperecommend .constants import (
32
- BITSANDBYTES ,
33
32
BITS_AND_BYTES_4BIT ,
33
+ BITSANDBYTES ,
34
34
SAFETENSORS ,
35
35
SHAPE_MAP ,
36
36
TEXT_GENERATION ,
@@ -98,14 +98,10 @@ def which_shapes(
98
98
"""
99
99
try :
100
100
shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
101
-
101
+
102
102
data , model_name = self ._get_model_config_and_name (
103
103
model_id = request .model_id ,
104
104
)
105
- llm_config = LLMConfig .from_raw_config (data )
106
- shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
107
- llm_config , shapes , model_name
108
- )
109
105
110
106
if request .deployment_config :
111
107
shape_recommendation_report = (
@@ -115,16 +111,11 @@ def which_shapes(
115
111
)
116
112
117
113
else :
118
- ds_model = self ._get_data_science_model (request .model_id )
119
-
120
- data = self ._get_model_config (ds_model )
121
-
122
114
llm_config = LLMConfig .from_raw_config (data )
123
115
124
116
shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
125
117
llm_config , shapes , model_name
126
118
)
127
-
128
119
129
120
if request .generate_table and shape_recommendation_report .recommendations :
130
121
shape_recommendation_report = self ._rich_diff_table (
@@ -174,8 +165,8 @@ def _get_model_config_and_name(
174
165
"""
175
166
if is_valid_ocid (model_id ):
176
167
logger .info (f"Detected OCID: Fetching OCI model config for '{ model_id } '." )
177
- ds_model = self ._validate_model_ocid (model_id )
178
- config = self ._fetch_hf_config ( model_id )
168
+ ds_model = self ._get_data_science_model (model_id )
169
+ config = self ._get_model_config ( ds_model )
179
170
model_name = ds_model .display_name
180
171
else :
181
172
logger .info (
@@ -403,6 +394,7 @@ def _get_model_config(model: DataScienceModel):
403
394
"""
404
395
405
396
model_task = model .freeform_tags .get ("task" , "" ).lower ()
397
+ model_task = re .sub (r"-" , "_" , model_task )
406
398
model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
407
399
408
400
logger .info (f"Current model task type: { model_task } " )
0 commit comments