Skip to content

Commit 8554da2

Browse files
committed
added support for no compartment id provided and unit test for the same
1 parent e78ca08 commit 8554da2

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _search_model_in_catalog(self, model_id: str, compartment_id: str) -> Option
198198
logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
199199
return None
200200

201-
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
201+
def valid_compute_shapes(self, compartment_id: Optional[str] = None) -> List["ComputeShapeSummary"]:
202202
"""
203203
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
204204
@@ -214,9 +214,22 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary
214214
215215
Raises
216216
------
217-
ValueError
218-
If the file cannot be opened, parsed, or the 'shapes' key is missing.
217+
AquaValueError
218+
If a compartment_id is not provided and cannot be found in the
219+
environment variables.
219220
"""
221+
if not compartment_id:
222+
compartment_id = os.environ.get("NB_SESSION_COMPARTMENT_OCID") or os.environ.get("PROJECT_COMPARTMENT_OCID")
223+
if compartment_id:
224+
logger.info(f"Using compartment_id from environment: {compartment_id}")
225+
226+
if not compartment_id:
227+
raise AquaValueError(
228+
"A compartment OCID is required to list available shapes. "
229+
"Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
230+
"or 'PROJECT_COMPARTMENT_OCID' environment variable."
231+
)
232+
220233
oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
221234
set_user_shapes = {shape.name: shape for shape in oci_shapes}
222235

tests/unitary/with_extras/aqua/test_recommend.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,4 +503,20 @@ def test_fetch_config_only_real_call_success(self):
503503
assert "model_type" in config
504504
assert "dim" in config
505505
except AquaValueError as e:
506-
pytest.fail(f"Real network call to Hugging Face failed: {e}")
506+
pytest.fail(f"Real network call to Hugging Face failed: {e}")
507+
508+
509+
@patch('ads.aqua.shaperecommend.recommend.OCIDataScienceModelDeployment.shapes')
510+
@patch.dict(os.environ, {}, clear=True)
511+
def test_valid_compute_shapes_raises_error_no_compartment(self, mock_oci_shapes):
512+
"""
513+
Tests that valid_compute_shapes raises a ValueError when no compartment ID is
514+
provided and none can be found in the environment.
515+
"""
516+
app = AquaShapeRecommend()
517+
518+
with pytest.raises(AquaValueError, match="A compartment OCID is required"):
519+
app.valid_compute_shapes(compartment_id=None)
520+
521+
# Verify that the OCI SDK was not called because the check failed early
522+
mock_oci_shapes.assert_not_called()

0 commit comments

Comments
 (0)