Skip to content

Commit 5088b85

Browse files
committed
modified unit tests
1 parent 287b406 commit 5088b85

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import json
66
import os
7+
import re
78
import shutil
89
from typing import Dict, List, Optional, Tuple, Union
910

10-
from pydantic import ValidationError
11-
from rich.table import Table
12-
1311
from huggingface_hub import hf_hub_download
1412
from huggingface_hub.utils import HfHubHTTPError
13+
from pydantic import ValidationError
14+
from rich.table import Table
1515

1616
from ads.aqua.app import logger
1717
from ads.aqua.common.entities import ComputeShapeSummary
@@ -22,15 +22,15 @@
2222
)
2323
from ads.aqua.common.utils import (
2424
build_pydantic_error_message,
25+
format_hf_custom_error_message,
2526
get_resource_type,
2627
is_valid_ocid,
2728
load_config,
2829
load_gpu_shapes_index,
29-
format_hf_custom_error_message,
3030
)
3131
from ads.aqua.shaperecommend.constants import (
32-
BITSANDBYTES,
3332
BITS_AND_BYTES_4BIT,
33+
BITSANDBYTES,
3434
SAFETENSORS,
3535
SHAPE_MAP,
3636
TEXT_GENERATION,
@@ -98,14 +98,10 @@ def which_shapes(
9898
"""
9999
try:
100100
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
101-
101+
102102
data, model_name = self._get_model_config_and_name(
103103
model_id=request.model_id,
104104
)
105-
llm_config = LLMConfig.from_raw_config(data)
106-
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
107-
llm_config, shapes, model_name
108-
)
109105

110106
if request.deployment_config:
111107
shape_recommendation_report = (
@@ -115,16 +111,11 @@ def which_shapes(
115111
)
116112

117113
else:
118-
ds_model = self._get_data_science_model(request.model_id)
119-
120-
data = self._get_model_config(ds_model)
121-
122114
llm_config = LLMConfig.from_raw_config(data)
123115

124116
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
125117
llm_config, shapes, model_name
126118
)
127-
128119

129120
if request.generate_table and shape_recommendation_report.recommendations:
130121
shape_recommendation_report = self._rich_diff_table(
@@ -174,8 +165,8 @@ def _get_model_config_and_name(
174165
"""
175166
if is_valid_ocid(model_id):
176167
logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
177-
ds_model = self._validate_model_ocid(model_id)
178-
config = self._fetch_hf_config(model_id)
168+
ds_model = self._get_data_science_model(model_id)
169+
config = self._get_model_config(ds_model)
179170
model_name = ds_model.display_name
180171
else:
181172
logger.info(
@@ -403,6 +394,7 @@ def _get_model_config(model: DataScienceModel):
403394
"""
404395

405396
model_task = model.freeform_tags.get("task", "").lower()
397+
model_task = re.sub(r"-", "_", model_task)
406398
model_format = model.freeform_tags.get("model_format", "").lower()
407399

408400
logger.info(f"Current model task type: {model_task}")

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,18 +436,19 @@ def test_which_shapes_valid_from_file(
436436
)[1],
437437
)
438438

439-
raw = load_config(config_file)
439+
mock_raw_config = load_config(config_file)
440+
mock_ds_model_name = mock_model.display_name
440441

441442
if service_managed_model:
442-
config = AquaDeploymentConfig(**raw)
443+
config = AquaDeploymentConfig(**mock_raw_config)
443444

444445
request = RequestRecommend(
445446
model_id="ocid1.datasciencemodel.oc1.TEST",
446447
generate_table=False,
447448
deployment_config=config,
448449
)
449450
else:
450-
monkeypatch.setattr(app, "_get_model_config", lambda _: raw)
451+
monkeypatch.setattr(app, "_get_model_config_and_name", lambda _: (mock_ds_model_name, mock_raw_config))
451452

452453
request = RequestRecommend(
453454
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False

0 commit comments

Comments
 (0)