@@ -99,18 +99,23 @@ def which_shapes(
99
99
try :
100
100
shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
101
101
102
- data , model_name = self ._get_model_config_and_name (
103
- model_id = request .model_id ,
104
- )
105
-
106
102
if request .deployment_config :
103
+ if is_valid_ocid (request .model_id ):
104
+ ds_model = self ._get_data_science_model (request .model_id )
105
+ model_name = ds_model .display_name
106
+ else :
107
+ model_name = request .model_id
108
+
107
109
shape_recommendation_report = (
108
110
ShapeRecommendationReport .from_deployment_config (
109
111
request .deployment_config , model_name , shapes
110
112
)
111
113
)
112
114
113
115
else :
116
+ data , model_name = self ._get_model_config_and_name (
117
+ model_id = request .model_id ,
118
+ )
114
119
llm_config = LLMConfig .from_raw_config (data )
115
120
116
121
shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
@@ -394,7 +399,7 @@ def _get_model_config(model: DataScienceModel):
394
399
"""
395
400
396
401
model_task = model .freeform_tags .get ("task" , "" ).lower ()
397
- model_task = re .sub (r"-" , "_" , model_task )
402
+ model_task = re .sub (r"-" , "_" , model_task )
398
403
model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
399
404
400
405
logger .info (f"Current model task type: { model_task } " )
0 commit comments