Skip to content

Commit 49962ff

Browse files
committed
resolving comments
1 parent a362351 commit 49962ff

File tree

2 files changed

+24
-86
lines changed

2 files changed

+24
-86
lines changed

ads/aqua/shaperecommend/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,3 @@
114114
"ARM": "CPU",
115115
"UNKNOWN_ENUM_VALUE": "N/A",
116116
}
117-
118-
HUGGINGFACE_CONFIG_URL = "https://huggingface.co/{model_id}/resolve/main/config.json"

ads/aqua/shaperecommend/recommend.py

Lines changed: 24 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,17 @@
22
# Copyright (c) 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
#!/usr/bin/env python
6-
# Copyright (c) 2025 Oracle and/or its affiliates.
7-
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
8-
9-
import shutil
10-
import os
115
import json
12-
from typing import List, Union, Optional, Dict, Any, Tuple
6+
import os
7+
import shutil
8+
from typing import Dict, List, Optional, Tuple, Union
139

1410
from pydantic import ValidationError
1511
from rich.table import Table
1612

1713
from huggingface_hub import hf_hub_download
1814
from huggingface_hub.utils import HfHubHTTPError
15+
1916
from ads.aqua.app import logger
2017
from ads.aqua.common.entities import ComputeShapeSummary
2118
from ads.aqua.common.errors import (
@@ -25,14 +22,15 @@
2522
)
2623
from ads.aqua.common.utils import (
2724
build_pydantic_error_message,
25+
get_resource_type,
26+
is_valid_ocid,
2827
load_config,
2928
load_gpu_shapes_index,
30-
is_valid_ocid,
31-
get_resource_type,
29+
format_hf_custom_error_message,
3230
)
3331
from ads.aqua.shaperecommend.constants import (
34-
BITS_AND_BYTES_4BIT,
3532
BITSANDBYTES,
33+
BITS_AND_BYTES_4BIT,
3634
SAFETENSORS,
3735
SHAPE_MAP,
3836
TEXT_GENERATION,
@@ -46,6 +44,7 @@
4644
ShapeRecommendationReport,
4745
ShapeReport,
4846
)
47+
from ads.config import COMPARTMENT_OCID
4948
from ads.model.datascience_model import DataScienceModel
5049
from ads.model.service.oci_datascience_model_deployment import (
5150
OCIDataScienceModelDeployment,
@@ -100,9 +99,6 @@ def which_shapes(
10099
"""
101100
try:
102101
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
103-
# data, model_name = self._get_model_config_and_name(
104-
# model_id=request.model_id, compartment_id=request.compartment_id
105-
# )
106102
data, model_name = self._get_model_config_and_name(
107103
model_id=request.model_id,
108104
)
@@ -158,41 +154,18 @@ def _get_model_config_and_name(
158154
- The display name for the model.
159155
"""
160156
if is_valid_ocid(model_id):
161-
logger.info(f"'{model_id}' identified as a model OCID.")
157+
logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
162158
ds_model = self._validate_model_ocid(model_id)
163-
return self._get_model_config(ds_model), ds_model.display_name
159+
config = self._fetch_hf_config(model_id)
160+
model_name = ds_model.display_name
161+
else:
162+
logger.info(
163+
f"Assuming Hugging Face model ID: Fetching config for '{model_id}'."
164+
)
165+
config = self._fetch_hf_config(model_id)
166+
model_name = model_id
164167

165-
logger.info(
166-
f"'{model_id}' is not an OCID, treating as a Hugging Face model ID."
167-
)
168-
# if not compartment_id:
169-
# compartment_id = os.environ.get(
170-
# "NB_SESSION_COMPARTMENT_OCID"
171-
# ) or os.environ.get("PROJECT_COMPARTMENT_OCID")
172-
# if compartment_id:
173-
# logger.info(f"Using compartment_id from environment: {compartment_id}")
174-
# if not compartment_id:
175-
# raise AquaValueError(
176-
# "A compartment OCID is required to list available shapes. "
177-
# "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
178-
# "or 'PROJECT_COMPARTMENT_OCID' environment variable."
179-
# "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
180-
# )
181-
182-
# ds_model = self._search_model_in_catalog(model_id, compartment_id)
183-
# if ds_model:
184-
# logger.info("Loading configuration from existing model catalog artifact.")
185-
# try:
186-
# return (
187-
# self._get_model_config(ds_model),
188-
# ds_model.display_name,
189-
# )
190-
# except AquaFileNotFoundError:
191-
# logger.warning(
192-
# "config.json not found in artifact, fetching from Hugging Face Hub."
193-
# )
194-
195-
return self._fetch_hf_config(model_id), model_id
168+
return config, model_name
196169

197170
def _fetch_hf_config(self, model_id: str) -> Dict:
198171
"""
@@ -204,38 +177,7 @@ def _fetch_hf_config(self, model_id: str) -> Dict:
204177
with open(config_path, "r", encoding="utf-8") as f:
205178
return json.load(f)
206179
except HfHubHTTPError as e:
207-
if "401" in str(e):
208-
raise AquaValueError(
209-
f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>"
210-
)
211-
elif "404" in str(e) or "not found" in str(e).lower():
212-
raise AquaValueError(
213-
f"Model '{model_id}' not found on HuggingFace. Please check the name for typos."
214-
)
215-
raise AquaValueError(
216-
f"Failed to download config for '{model_id}': {e}"
217-
) from e
218-
219-
# def _search_model_in_catalog(
220-
# self, model_id: str, compartment_id: str
221-
# ) -> Optional[DataScienceModel]:
222-
# """
223-
# Searches for a model in the Data Science catalog by its display name.
224-
# """
225-
# try:
226-
# models = DataScienceModel.list(
227-
# compartment_id=compartment_id, display_name=model_id
228-
# )
229-
# if len(models) > 1:
230-
# logger.warning(
231-
# f"Found multiple models with the name '{model_id}'. Using the first one found."
232-
# )
233-
# if models:
234-
# logger.info(f"Found model '{model_id}' in the Data Science catalog.")
235-
# return models[0]
236-
# except Exception as e:
237-
# logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
238-
# return None
180+
format_hf_custom_error_message(e)
239181

240182
def valid_compute_shapes(
241183
self, compartment_id: Optional[str] = None
@@ -260,18 +202,16 @@ def valid_compute_shapes(
260202
environment variables.
261203
"""
262204
if not compartment_id:
263-
compartment_id = os.environ.get(
264-
"NB_SESSION_COMPARTMENT_OCID"
265-
) or os.environ.get("PROJECT_COMPARTMENT_OCID")
205+
compartment_id = COMPARTMENT_OCID
266206
if compartment_id:
267207
logger.info(f"Using compartment_id from environment: {compartment_id}")
268208

269209
if not compartment_id:
270210
raise AquaValueError(
271211
"A compartment OCID is required to list available shapes. "
272-
"Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
273-
"or 'PROJECT_COMPARTMENT_OCID' environment variable."
274-
"cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
212+
"Please specify it using the --compartment_id parameter.\n\n"
213+
"Example:\n"
214+
'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
275215
)
276216

277217
oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)

0 commit comments

Comments
 (0)