Skip to content

Commit 29bd22c

Browse files
committed
added black formatting
1 parent 367ce70 commit 29bd22c

File tree

3 files changed

+63
-30
lines changed

3 files changed

+63
-30
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
OCIDataScienceModelDeployment,
4949
)
5050

51+
5152
class HuggingFaceModelFetcher:
5253
"""
5354
Utility class to fetch model configurations from HuggingFace.
@@ -57,7 +58,7 @@ class HuggingFaceModelFetcher:
5758
def is_huggingface_model_id(cls, model_id: str) -> bool:
5859
if is_valid_ocid(model_id):
5960
return False
60-
hf_pattern = r'^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$'
61+
hf_pattern = r"^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$"
6162
return bool(re.match(hf_pattern, model_id))
6263

6364
@classmethod
@@ -80,12 +81,19 @@ def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
8081
elif response.status_code == 404:
8182
raise AquaValueError(f"Model '{model_id}' not found on HuggingFace.")
8283
elif response.status_code != 200:
83-
raise AquaValueError(f"Failed to fetch config for '{model_id}'. Status: {response.status_code}")
84+
raise AquaValueError(
85+
f"Failed to fetch config for '{model_id}'. Status: {response.status_code}"
86+
)
8487
return response.json()
8588
except requests.RequestException as e:
86-
raise AquaValueError(f"Network error fetching config for {model_id}: {e}") from e
89+
raise AquaValueError(
90+
f"Network error fetching config for {model_id}: {e}"
91+
) from e
8792
except json.JSONDecodeError as e:
88-
raise AquaValueError(f"Invalid config format for model '{model_id}'.") from e
93+
raise AquaValueError(
94+
f"Invalid config format for model '{model_id}'."
95+
) from e
96+
8997

9098
class AquaShapeRecommend:
9199
"""
@@ -135,7 +143,9 @@ def which_shapes(
135143
"""
136144
try:
137145
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
138-
data, model_name = self._get_model_config_and_name(request.model_id, request.compartment_id)
146+
data, model_name = self._get_model_config_and_name(
147+
request.model_id, request.compartment_id
148+
)
139149
llm_config = LLMConfig.from_raw_config(data)
140150
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
141151
llm_config, shapes, model_name
@@ -165,40 +175,55 @@ def which_shapes(
165175

166176
return shape_recommendation_report
167177

168-
def _get_model_config_and_name(self, model_id: str, compartment_id: str) -> (dict, str):
178+
def _get_model_config_and_name(
179+
self, model_id: str, compartment_id: str
180+
) -> (dict, str):
169181
"""
170182
Loads model configuration, handling OCID and Hugging Face model IDs.
171183
"""
172184
if HuggingFaceModelFetcher.is_huggingface_model_id(model_id):
173185
logger.info(f"'{model_id}' identified as a Hugging Face model ID.")
174186
ds_model = self._search_model_in_catalog(model_id, compartment_id)
175187
if ds_model and ds_model.artifact:
176-
logger.info("Loading configuration from existing model catalog artifact.")
188+
logger.info(
189+
"Loading configuration from existing model catalog artifact."
190+
)
177191
try:
178-
return load_config(ds_model.artifact, "config.json"), ds_model.display_name
192+
return (
193+
load_config(ds_model.artifact, "config.json"),
194+
ds_model.display_name,
195+
)
179196
except AquaFileNotFoundError:
180-
logger.warning("config.json not found in artifact, fetching from Hugging Face Hub.")
197+
logger.warning(
198+
"config.json not found in artifact, fetching from Hugging Face Hub."
199+
)
181200
return HuggingFaceModelFetcher.fetch_config_only(model_id), model_id
182201
else:
183202
logger.info(f"'{model_id}' identified as a model OCID.")
184203
ds_model = self._validate_model_ocid(model_id)
185204
return self._get_model_config(ds_model), ds_model.display_name
186205

187-
def _search_model_in_catalog(self, model_id: str, compartment_id: str) -> Optional[DataScienceModel]:
206+
def _search_model_in_catalog(
207+
self, model_id: str, compartment_id: str
208+
) -> Optional[DataScienceModel]:
188209
"""
189210
Searches for a Hugging Face model in the Data Science model catalog by display name.
190211
"""
191212
try:
192213
# This should work since the SDK's list method can filter by display_name.
193-
models = DataScienceModel.list(compartment_id=compartment_id, display_name=model_id)
214+
models = DataScienceModel.list(
215+
compartment_id=compartment_id, display_name=model_id
216+
)
194217
if models:
195218
logger.info(f"Found model '{model_id}' in the Data Science catalog.")
196219
return models[0]
197220
except Exception as e:
198221
logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
199222
return None
200223

201-
def valid_compute_shapes(self, compartment_id: Optional[str] = None) -> List["ComputeShapeSummary"]:
224+
def valid_compute_shapes(
225+
self, compartment_id: Optional[str] = None
226+
) -> List["ComputeShapeSummary"]:
202227
"""
203228
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
204229
@@ -219,7 +244,9 @@ def valid_compute_shapes(self, compartment_id: Optional[str] = None) -> List["Co
219244
environment variables.
220245
"""
221246
if not compartment_id:
222-
compartment_id = os.environ.get("NB_SESSION_COMPARTMENT_OCID") or os.environ.get("PROJECT_COMPARTMENT_OCID")
247+
compartment_id = os.environ.get(
248+
"NB_SESSION_COMPARTMENT_OCID"
249+
) or os.environ.get("PROJECT_COMPARTMENT_OCID")
223250
if compartment_id:
224251
logger.info(f"Using compartment_id from environment: {compartment_id}")
225252

ads/aqua/shaperecommend/shape_report.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class RequestRecommend(BaseModel):
1818
"""
1919

2020
model_id: str = Field(
21-
..., description="The OCID or Hugging Face ID of the model to recommend feasible compute shapes."
21+
...,
22+
description="The OCID or Hugging Face ID of the model to recommend feasible compute shapes.",
2223
)
2324
generate_table: Optional[bool] = (
2425
Field(

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
get_estimator,
2121
)
2222
from ads.aqua.shaperecommend.llm_config import LLMConfig
23-
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend, HuggingFaceModelFetcher
23+
from ads.aqua.shaperecommend.recommend import (
24+
AquaShapeRecommend,
25+
HuggingFaceModelFetcher,
26+
)
2427
from ads.aqua.shaperecommend.shape_report import (
2528
DeploymentParams,
2629
ModelConfig,
@@ -455,16 +458,20 @@ def test_shape_report_pareto_front(self):
455458
assert a and b not in pf
456459
assert len(pf) == 2
457460

461+
458462
class TestHuggingFaceModelFetcher:
459-
@pytest.mark.parametrize("model_id, expected", [
460-
("meta-llama/Llama-2-7b-hf", True),
461-
("mistralai/Mistral-7B-v0.1", True),
462-
("ocid1.datasciencemodel.oc1.iad.xxxxxxxx", False),
463-
])
463+
@pytest.mark.parametrize(
464+
"model_id, expected",
465+
[
466+
("meta-llama/Llama-2-7b-hf", True),
467+
("mistralai/Mistral-7B-v0.1", True),
468+
("ocid1.datasciencemodel.oc1.iad.xxxxxxxx", False),
469+
],
470+
)
464471
def test_is_huggingface_model_id(self, model_id, expected):
465472
assert HuggingFaceModelFetcher.is_huggingface_model_id(model_id) == expected
466473

467-
@patch('requests.get')
474+
@patch("requests.get")
468475
def test_fetch_config_only_success(self, mock_get):
469476
mock_response = MagicMock()
470477
mock_response.status_code = 200
@@ -475,7 +482,7 @@ def test_fetch_config_only_success(self, mock_get):
475482
assert config == {"model_type": "llama"}
476483
mock_get.assert_called_once()
477484

478-
@patch('requests.get')
485+
@patch("requests.get")
479486
def test_fetch_config_only_not_found(self, mock_get):
480487
mock_response = MagicMock()
481488
mock_response.status_code = 404
@@ -487,7 +494,6 @@ def test_fetch_config_only_not_found(self, mock_get):
487494
@patch.dict(os.environ, {"HF_TOKEN": "test_token_123"}, clear=True)
488495
def test_get_hf_token(self):
489496
assert HuggingFaceModelFetcher.get_hf_token() == "test_token_123"
490-
491497

492498
# @pytest.mark.network
493499
# def test_fetch_config_only_real_call_success(self):
@@ -496,27 +502,26 @@ def test_get_hf_token(self):
496502
# This test requires an internet connection.
497503
# """
498504
# model_id = "distilbert-base-uncased"
499-
505+
500506
# try:
501507
# config = HuggingFaceModelFetcher.fetch_config_only(model_id)
502508
# assert isinstance(config, dict)
503509
# assert "model_type" in config
504510
# assert "dim" in config
505511
# except AquaValueError as e:
506512
# pytest.fail(f"Real network call to Hugging Face failed: {e}")
507-
508-
509-
@patch('ads.aqua.shaperecommend.recommend.OCIDataScienceModelDeployment.shapes')
513+
514+
@patch("ads.aqua.shaperecommend.recommend.OCIDataScienceModelDeployment.shapes")
510515
@patch.dict(os.environ, {}, clear=True)
511516
def test_valid_compute_shapes_raises_error_no_compartment(self, mock_oci_shapes):
512517
"""
513518
Tests that valid_compute_shapes raises a ValueError when no compartment ID is
514519
provided and none can be found in the environment.
515520
"""
516521
app = AquaShapeRecommend()
517-
522+
518523
with pytest.raises(AquaValueError, match="A compartment OCID is required"):
519524
app.valid_compute_shapes(compartment_id=None)
520-
525+
521526
# Verify that the OCI SDK was not called because the check failed early
522-
mock_oci_shapes.assert_not_called()
527+
mock_oci_shapes.assert_not_called()

0 commit comments

Comments
 (0)