Skip to content

Commit 2a9d843

Browse files
committed
fixed unit tests
1 parent 5088b85 commit 2a9d843

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,23 @@ def which_shapes(
9999
try:
100100
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
101101

102-
data, model_name = self._get_model_config_and_name(
103-
model_id=request.model_id,
104-
)
105-
106102
if request.deployment_config:
103+
if is_valid_ocid(request.model_id):
104+
ds_model = self._get_data_science_model(request.model_id)
105+
model_name = ds_model.display_name
106+
else:
107+
model_name = request.model_id
108+
107109
shape_recommendation_report = (
108110
ShapeRecommendationReport.from_deployment_config(
109111
request.deployment_config, model_name, shapes
110112
)
111113
)
112114

113115
else:
116+
data, model_name = self._get_model_config_and_name(
117+
model_id=request.model_id,
118+
)
114119
llm_config = LLMConfig.from_raw_config(data)
115120

116121
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
@@ -394,7 +399,7 @@ def _get_model_config(model: DataScienceModel):
394399
"""
395400

396401
model_task = model.freeform_tags.get("task", "").lower()
397-
model_task = re.sub(r"-", "_", model_task)
402+
model_task = re.sub(r"-", "_", model_task)
398403
model_format = model.freeform_tags.get("model_format", "").lower()
399404

400405
logger.info(f"Current model task type: {model_task}")

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,11 @@ def test_which_shapes_valid_from_file(
448448
deployment_config=config,
449449
)
450450
else:
451-
monkeypatch.setattr(app, "_get_model_config_and_name", lambda _: (mock_ds_model_name, mock_raw_config))
451+
monkeypatch.setattr(
452+
app,
453+
"_get_model_config_and_name",
454+
lambda model_id: (mock_raw_config, mock_ds_model_name),
455+
)
452456

453457
request = RequestRecommend(
454458
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False

0 commit comments

Comments
 (0)