Skip to content

Commit 669c21c

Browse files
authored
Merge branch 'main' into fix/ODSC-77609
2 parents 653b82d + 9e141cb commit 669c21c

File tree

3 files changed

+138
-18
lines changed

3 files changed

+138
-18
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
# Copyright (c) 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import json
6+
import os
7+
import re
58
import shutil
6-
from typing import List, Union
9+
from typing import Dict, List, Optional, Tuple, Union
710

11+
from huggingface_hub import hf_hub_download
12+
from huggingface_hub.utils import HfHubHTTPError
813
from pydantic import ValidationError
914
from rich.table import Table
1015

@@ -17,7 +22,9 @@
1722
)
1823
from ads.aqua.common.utils import (
1924
build_pydantic_error_message,
25+
format_hf_custom_error_message,
2026
get_resource_type,
27+
is_valid_ocid,
2128
load_config,
2229
load_gpu_shapes_index,
2330
)
@@ -37,6 +44,7 @@
3744
ShapeRecommendationReport,
3845
ShapeReport,
3946
)
47+
from ads.config import COMPARTMENT_OCID
4048
from ads.model.datascience_model import DataScienceModel
4149
from ads.model.service.oci_datascience_model_deployment import (
4250
OCIDataScienceModelDeployment,
@@ -91,20 +99,23 @@ def which_shapes(
9199
try:
92100
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
93101

94-
ds_model = self._get_data_science_model(request.model_id)
95-
96-
model_name = ds_model.display_name if ds_model.display_name else ""
97-
98102
if request.deployment_config:
103+
if is_valid_ocid(request.model_id):
104+
ds_model = self._get_data_science_model(request.model_id)
105+
model_name = ds_model.display_name
106+
else:
107+
model_name = request.model_id
108+
99109
shape_recommendation_report = (
100110
ShapeRecommendationReport.from_deployment_config(
101111
request.deployment_config, model_name, shapes
102112
)
103113
)
104114

105115
else:
106-
data = self._get_model_config(ds_model)
107-
116+
data, model_name = self._get_model_config_and_name(
117+
model_id=request.model_id,
118+
)
108119
llm_config = LLMConfig.from_raw_config(data)
109120

110121
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
@@ -135,7 +146,57 @@ def which_shapes(
135146

136147
return shape_recommendation_report
137148

138-
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
149+
def _get_model_config_and_name(
    self,
    model_id: str,
) -> Tuple[Dict, str]:
    """
    Loads the model configuration and a display name for ``model_id``.

    If ``model_id`` is a valid OCID, the configuration is read from the
    corresponding OCI Data Science model. Otherwise ``model_id`` is treated
    as a Hugging Face Hub ID and its configuration is fetched from the Hub.

    Parameters
    ----------
    model_id : str
        The model OCID or Hugging Face model ID.

    Returns
    -------
    Tuple[Dict, str]
        A tuple containing:
        - The model configuration dictionary.
        - The display name for the model: the OCI model's display name
          (empty string when it is not set) or the Hugging Face model ID.
    """
    if is_valid_ocid(model_id):
        logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
        ds_model = self._get_data_science_model(model_id)
        config = self._get_model_config(ds_model)
        # display_name may be unset; fall back to "" so the declared
        # ``str`` return type always holds for callers.
        model_name = ds_model.display_name or ""
    else:
        logger.info(
            f"Assuming Hugging Face model ID: Fetching config for '{model_id}'."
        )
        config = self._fetch_hf_config(model_id)
        model_name = model_id

    return config, model_name
184+
185+
def _fetch_hf_config(self, model_id: str) -> Dict:
    """
    Fetches ``config.json`` for a Hugging Face Hub model and parses it.

    The file is obtained through ``hf_hub_download`` and decoded as UTF-8
    JSON.

    Parameters
    ----------
    model_id : str
        The Hugging Face Hub repository ID of the model.

    Returns
    -------
    Dict
        The parsed contents of the model's ``config.json``.

    Notes
    -----
    An ``HfHubHTTPError`` raised by the download is passed to
    ``format_hf_custom_error_message``. That helper presumably raises a
    user-facing error (TODO confirm); if it returns normally, this method
    returns ``None``.
    """
    try:
        local_config_path = hf_hub_download(repo_id=model_id, filename="config.json")
        with open(local_config_path, "r", encoding="utf-8") as config_file:
            raw_json = config_file.read()
        return json.loads(raw_json)
    except HfHubHTTPError as hub_error:
        format_hf_custom_error_message(hub_error)
    return None
196+
197+
def valid_compute_shapes(
198+
self, compartment_id: Optional[str] = None
199+
) -> List["ComputeShapeSummary"]:
139200
"""
140201
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
141202
@@ -151,9 +212,23 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary
151212
152213
Raises
153214
------
154-
ValueError
155-
If the file cannot be opened, parsed, or the 'shapes' key is missing.
215+
AquaValueError
216+
If a compartment_id is not provided and cannot be found in the
217+
environment variables.
156218
"""
219+
if not compartment_id:
220+
compartment_id = COMPARTMENT_OCID
221+
if compartment_id:
222+
logger.info(f"Using compartment_id from environment: {compartment_id}")
223+
224+
if not compartment_id:
225+
raise AquaValueError(
226+
"A compartment OCID is required to list available shapes. "
227+
"Please specify it using the --compartment_id parameter.\n\n"
228+
"Example:\n"
229+
'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
230+
)
231+
157232
oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
158233
set_user_shapes = {shape.name: shape for shape in oci_shapes}
159234

@@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel):
324399
"""
325400

326401
model_task = model.freeform_tags.get("task", "").lower()
402+
model_task = re.sub(r"-", "_", model_task)
327403
model_format = model.freeform_tags.get("model_format", "").lower()
328404

329405
logger.info(f"Current model task type: {model_task}")

ads/aqua/shaperecommend/shape_report.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class RequestRecommend(BaseModel):
2929
"""
3030

3131
model_id: str = Field(
32-
..., description="The OCID of the model to recommend feasible compute shapes."
32+
...,
33+
description="The OCID or Hugging Face ID of the model to recommend feasible compute shapes.",
3334
)
3435
generate_table: Optional[bool] = (
3536
Field(

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import json
88
import os
99
import re
10-
from unittest.mock import MagicMock
11-
10+
from unittest.mock import MagicMock, mock_open
11+
from unittest.mock import patch
1212
import pytest
1313

1414
from ads.aqua.common.entities import ComputeShapeSummary
15-
from ads.aqua.common.errors import AquaRecommendationError
15+
16+
from ads.aqua.common.errors import AquaRecommendationError, AquaValueError
1617
from ads.aqua.modeldeployment.config_loader import AquaDeploymentConfig
1718
from ads.aqua.shaperecommend.estimator import (
1819
LlamaMemoryEstimator,
@@ -21,7 +22,9 @@
2122
get_estimator,
2223
)
2324
from ads.aqua.shaperecommend.llm_config import LLMConfig
24-
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
25+
from ads.aqua.shaperecommend.recommend import (
26+
AquaShapeRecommend,
27+
)
2528
from ads.aqua.shaperecommend.shape_report import (
2629
DeploymentParams,
2730
ModelConfig,
@@ -275,6 +278,41 @@ def create(config_file=""):
275278

276279

277280
class TestAquaShapeRecommend:
281+
282+
@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("builtins.open", new_callable=mock_open)
def test_fetch_hf_config_success(self, mock_file, mock_download):
    """The config downloaded from Hugging Face is parsed and returned."""
    recommender = AquaShapeRecommend()
    repo_id = "test/model"
    hf_config = {"model_type": "llama", "hidden_size": 4096}

    # The download is stubbed to a fake local path, and the patched open()
    # hands back the serialized config when read.
    mock_download.return_value = "/fake/path/config.json"
    mock_file.return_value.read.return_value = json.dumps(hf_config)

    fetched = recommender._fetch_hf_config(repo_id)

    assert fetched == hf_config
    mock_download.assert_called_once_with(repo_id=repo_id, filename="config.json")
298+
299+
@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("ads.aqua.shaperecommend.recommend.format_hf_custom_error_message")
def test_fetch_hf_config_http_error(self, mock_format_error, mock_download):
    """A failed Hugging Face download is routed to the error formatter."""
    from huggingface_hub.utils import HfHubHTTPError

    recommender = AquaShapeRecommend()
    hub_failure = HfHubHTTPError("Model not found")
    mock_download.side_effect = hub_failure

    # The error formatter is mocked here, so nothing is raised and the
    # method falls through, returning None.
    outcome = recommender._fetch_hf_config("nonexistent/model")

    assert outcome is None
    mock_format_error.assert_called_once_with(hub_failure)
315+
278316
@pytest.mark.parametrize(
279317
"config, expected_recs, expected_troubleshoot",
280318
[
@@ -398,18 +436,23 @@ def test_which_shapes_valid_from_file(
398436
)[1],
399437
)
400438

401-
raw = load_config(config_file)
439+
mock_raw_config = load_config(config_file)
440+
mock_ds_model_name = mock_model.display_name
402441

403442
if service_managed_model:
404-
config = AquaDeploymentConfig(**raw)
443+
config = AquaDeploymentConfig(**mock_raw_config)
405444

406445
request = RequestRecommend(
407446
model_id="ocid1.datasciencemodel.oc1.TEST",
408447
generate_table=False,
409448
deployment_config=config,
410449
)
411450
else:
412-
monkeypatch.setattr(app, "_get_model_config", lambda _: raw)
451+
monkeypatch.setattr(
452+
app,
453+
"_get_model_config_and_name",
454+
lambda model_id: (mock_raw_config, mock_ds_model_name),
455+
)
413456

414457
request = RequestRecommend(
415458
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False

0 commit comments

Comments
 (0)