-
Notifications
You must be signed in to change notification settings - Fork 60
[AQUA] Add Hugging Face model support to Shape Recommender #1262
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
a1c0e12
2126fd0
b24938e
0ca295e
ff37df9
05183f9
911a6af
4c2c6a2
d2bf709
e78ca08
8554da2
76288dd
367ce70
29bd22c
efe0953
19ddce1
3d935bb
e324143
2251e3c
c3b5afa
d68e501
79d6f5f
c840a39
a362351
49962ff
7d02bf1
a1bde9f
86b257b
287b406
5088b85
2a9d843
8edc0e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,12 +2,20 @@ | |
| # Copyright (c) 2025 Oracle and/or its affiliates. | ||
| # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
|
|
||
| #!/usr/bin/env python | ||
|
||
| # Copyright (c) 2025 Oracle and/or its affiliates. | ||
| # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
|
|
||
| import shutil | ||
| from typing import List, Union | ||
| import os | ||
|
||
| import json | ||
| from typing import List, Union, Optional, Dict, Any, Tuple | ||
|
||
|
|
||
| from pydantic import ValidationError | ||
| from rich.table import Table | ||
|
|
||
| from huggingface_hub import hf_hub_download | ||
| from huggingface_hub.utils import HfHubHTTPError | ||
| from ads.aqua.app import logger | ||
| from ads.aqua.common.entities import ComputeShapeSummary | ||
| from ads.aqua.common.errors import ( | ||
|
|
@@ -17,9 +25,10 @@ | |
| ) | ||
| from ads.aqua.common.utils import ( | ||
| build_pydantic_error_message, | ||
| get_resource_type, | ||
| load_config, | ||
| load_gpu_shapes_index, | ||
| is_valid_ocid, | ||
| get_resource_type, | ||
| ) | ||
| from ads.aqua.shaperecommend.constants import ( | ||
| BITS_AND_BYTES_4BIT, | ||
|
|
@@ -91,14 +100,13 @@ def which_shapes( | |
| """ | ||
| try: | ||
| shapes = self.valid_compute_shapes(compartment_id=request.compartment_id) | ||
|
|
||
| ds_model = self._validate_model_ocid(request.model_id) | ||
| data = self._get_model_config(ds_model) | ||
|
|
||
| # data, model_name = self._get_model_config_and_name( | ||
|
||
| # model_id=request.model_id, compartment_id=request.compartment_id | ||
| # ) | ||
| data, model_name = self._get_model_config_and_name( | ||
| model_id=request.model_id, | ||
| ) | ||
| llm_config = LLMConfig.from_raw_config(data) | ||
|
|
||
| model_name = ds_model.display_name if ds_model.display_name else "" | ||
|
|
||
| shape_recommendation_report = self._summarize_shapes_for_seq_lens( | ||
| llm_config, shapes, model_name | ||
| ) | ||
|
|
@@ -127,7 +135,111 @@ def which_shapes( | |
|
|
||
| return shape_recommendation_report | ||
|
|
||
| def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]: | ||
| def _get_model_config_and_name( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT: It would be cleaner if we had a return at the end: |
||
| self, | ||
| model_id: str, | ||
| ) -> Tuple[Dict, str]: | ||
| """ | ||
| Loads model configuration by trying OCID logic first, then falling back | ||
| to treating the model_id as a Hugging Face Hub ID. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model_id : str | ||
| The model OCID or Hugging Face model ID. | ||
| # compartment_id : Optional[str] | ||
| # The compartment OCID, used for searching the model catalog. | ||
|
|
||
| Returns | ||
| ------- | ||
| Tuple[Dict, str] | ||
| A tuple containing: | ||
| - The model configuration dictionary. | ||
| - The display name for the model. | ||
| """ | ||
| if is_valid_ocid(model_id): | ||
| logger.info(f"'{model_id}' identified as a model OCID.") | ||
|
||
| ds_model = self._validate_model_ocid(model_id) | ||
| return self._get_model_config(ds_model), ds_model.display_name | ||
|
|
||
| logger.info( | ||
| f"'{model_id}' is not an OCID, treating as a Hugging Face model ID." | ||
| ) | ||
| # if not compartment_id: | ||
|
||
| # compartment_id = os.environ.get( | ||
| # "NB_SESSION_COMPARTMENT_OCID" | ||
| # ) or os.environ.get("PROJECT_COMPARTMENT_OCID") | ||
| # if compartment_id: | ||
| # logger.info(f"Using compartment_id from environment: {compartment_id}") | ||
| # if not compartment_id: | ||
| # raise AquaValueError( | ||
| # "A compartment OCID is required to list available shapes. " | ||
| # "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' " | ||
| # "or 'PROJECT_COMPARTMENT_OCID' environment variable." | ||
| # "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>" | ||
| # ) | ||
|
|
||
| # ds_model = self._search_model_in_catalog(model_id, compartment_id) | ||
| # if ds_model: | ||
| # logger.info("Loading configuration from existing model catalog artifact.") | ||
| # try: | ||
| # return ( | ||
| # self._get_model_config(ds_model), | ||
| # ds_model.display_name, | ||
| # ) | ||
| # except AquaFileNotFoundError: | ||
| # logger.warning( | ||
| # "config.json not found in artifact, fetching from Hugging Face Hub." | ||
| # ) | ||
|
|
||
| return self._fetch_hf_config(model_id), model_id | ||
|
|
||
| def _fetch_hf_config(self, model_id: str) -> Dict: | ||
| """ | ||
| Downloads a model's config.json from Hugging Face Hub using the | ||
| huggingface_hub library. | ||
| """ | ||
Aryanag2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| try: | ||
| config_path = hf_hub_download(repo_id=model_id, filename="config.json") | ||
| with open(config_path, "r", encoding="utf-8") as f: | ||
| return json.load(f) | ||
| except HfHubHTTPError as e: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please check the existing - |
||
| if "401" in str(e): | ||
| raise AquaValueError( | ||
|
||
| f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>" | ||
| ) | ||
| elif "404" in str(e) or "not found" in str(e).lower(): | ||
| raise AquaValueError( | ||
| f"Model '{model_id}' not found on HuggingFace. Please check the name for typos." | ||
| ) | ||
| raise AquaValueError( | ||
| f"Failed to download config for '{model_id}': {e}" | ||
| ) from e | ||
|
|
||
| # def _search_model_in_catalog( | ||
| # self, model_id: str, compartment_id: str | ||
| # ) -> Optional[DataScienceModel]: | ||
| # """ | ||
| # Searches for a model in the Data Science catalog by its display name. | ||
| # """ | ||
| # try: | ||
| # models = DataScienceModel.list( | ||
| # compartment_id=compartment_id, display_name=model_id | ||
| # ) | ||
| # if len(models) > 1: | ||
| # logger.warning( | ||
| # f"Found multiple models with the name '{model_id}'. Using the first one found." | ||
| # ) | ||
| # if models: | ||
| # logger.info(f"Found model '{model_id}' in the Data Science catalog.") | ||
| # return models[0] | ||
| # except Exception as e: | ||
| # logger.warning(f"Could not search for model '{model_id}' in catalog: {e}") | ||
| # return None | ||
|
|
||
| def valid_compute_shapes( | ||
| self, compartment_id: Optional[str] = None | ||
| ) -> List["ComputeShapeSummary"]: | ||
| """ | ||
| Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file. | ||
|
|
||
|
|
@@ -143,9 +255,25 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary | |
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| If the file cannot be opened, parsed, or the 'shapes' key is missing. | ||
| AquaValueError | ||
| If a compartment_id is not provided and cannot be found in the | ||
| environment variables. | ||
| """ | ||
| if not compartment_id: | ||
| compartment_id = os.environ.get( | ||
|
||
| "NB_SESSION_COMPARTMENT_OCID" | ||
| ) or os.environ.get("PROJECT_COMPARTMENT_OCID") | ||
| if compartment_id: | ||
| logger.info(f"Using compartment_id from environment: {compartment_id}") | ||
|
|
||
| if not compartment_id: | ||
| raise AquaValueError( | ||
| "A compartment OCID is required to list available shapes. " | ||
| "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' " | ||
|
||
| "or 'PROJECT_COMPARTMENT_OCID' environment variable." | ||
| "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>" | ||
|
||
| ) | ||
|
|
||
| oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id) | ||
| set_user_shapes = {shape.name: shape for shape in oci_shapes} | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like it is not used anymore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done