Skip to content

Commit e9aaef2

Browse files
authored
List models using open ai api (#1234)
2 parents 9c1095e + 8942499 commit e9aaef2

File tree

5 files changed

+92
-41
lines changed

5 files changed

+92
-41
lines changed

ads/aqua/client/client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,19 @@ def embeddings(
582582
payload = {**(payload or {}), "input": input}
583583
return self._request(payload=payload, headers=headers)
584584

585+
def fetch_data(self) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
586+
"""Fetch Data in json format by sending a request to the endpoint.
587+
588+
Args:
589+
590+
Returns:
591+
Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: The server's response, typically including the data in JSON format.
592+
"""
593+
# headers = {"Content-Type", "application/json"}
594+
response = self._client.get(self.endpoint)
595+
json_response = response.json()
596+
return json_response
597+
585598

586599
class AsyncClient(BaseClient):
587600
"""

ads/aqua/extension/deployment_handler.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,37 @@ def post(self, *args, **kwargs): # noqa: ARG002
373373
)
374374

375375

376+
class AquaModelListHandler(AquaAPIhandler):
377+
"""Handler for Aqua model list params REST APIs.
378+
379+
Methods
380+
-------
381+
get(self, *args, **kwargs)
382+
Validates parameters for the given model id.
383+
"""
384+
385+
@handle_exceptions
386+
def get(self, model_deployment_id):
387+
"""
388+
Handles get model list for the Active Model Deployment
389+
Raises
390+
------
391+
HTTPError
392+
Raises HTTPError if inputs are missing or are invalid
393+
"""
394+
395+
self.set_header("Content-Type", "application/json")
396+
endpoint: str = ""
397+
model_deployment = AquaDeploymentApp().get(model_deployment_id)
398+
endpoint = model_deployment.endpoint.rstrip("/") + "/predict/v1/models"
399+
aqua_client = Client(endpoint=endpoint)
400+
try:
401+
list_model_result = aqua_client.fetch_data()
402+
return self.finish(list_model_result)
403+
except Exception as ex:
404+
raise HTTPError(500, str(ex))
405+
406+
376407
__handlers__ = [
377408
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
378409
("deployments/config/?([^/]*)", AquaDeploymentHandler),
@@ -381,4 +412,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
381412
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
382413
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
383414
("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
415+
("deployments/models/list/?([^/]*)", AquaModelListHandler),
384416
]

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def test__load_default_properties(self, mock_from_ocid):
372372
ModelDeploymentInfrastructure.CONST_SHAPE_CONFIG_DETAILS: {
373373
"ocpus": 10.0,
374374
"memory_in_gbs": 36.0,
375+
"cpu_baseline": None,
375376
},
376377
ModelDeploymentInfrastructure.CONST_REPLICA: 1,
377378
}
@@ -886,7 +887,7 @@ def test_model_deployment_from_dict(self):
886887
def test_update_model_deployment_details(self, mock_create):
887888
dsc_model = MagicMock()
888889
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
889-
mock_create.return_value = dsc_model
890+
mock_create.return_value = dsc_model
890891
model_deployment = self.initialize_model_deployment()
891892
update_model_deployment_details = (
892893
model_deployment._update_model_deployment_details()
@@ -1127,9 +1128,7 @@ def test_from_ocid(self, mock_from_ocid):
11271128
"create_model_deployment",
11281129
)
11291130
@patch.object(DataScienceModel, "create")
1130-
def test_deploy(
1131-
self, mock_create, mock_create_model_deployment, mock_sync
1132-
):
1131+
def test_deploy(self, mock_create, mock_create_model_deployment, mock_sync):
11331132
dsc_model = MagicMock()
11341133
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
11351134
mock_create.return_value = dsc_model
@@ -1346,44 +1345,35 @@ def test_update_spec(self):
13461345
model_deployment = self.initialize_model_deployment()
13471346
model_deployment._update_spec(
13481347
display_name="test_updated_name",
1349-
freeform_tags={"test_updated_key":"test_updated_value"},
1350-
access_log={
1351-
"log_id": "test_updated_access_log_id"
1352-
},
1353-
predict_log={
1354-
"log_group_id": "test_updated_predict_log_group_id"
1355-
},
1356-
shape_config_details={
1357-
"ocpus": 100,
1358-
"memoryInGBs": 200
1359-
},
1348+
freeform_tags={"test_updated_key": "test_updated_value"},
1349+
access_log={"log_id": "test_updated_access_log_id"},
1350+
predict_log={"log_group_id": "test_updated_predict_log_group_id"},
1351+
shape_config_details={"ocpus": 100, "memoryInGBs": 200},
13601352
replica=20,
13611353
image="test_updated_image",
1362-
env={
1363-
"test_updated_env_key":"test_updated_env_value"
1364-
}
1354+
env={"test_updated_env_key": "test_updated_env_value"},
13651355
)
13661356

13671357
assert model_deployment.display_name == "test_updated_name"
13681358
assert model_deployment.freeform_tags == {
1369-
"test_updated_key":"test_updated_value"
1359+
"test_updated_key": "test_updated_value"
13701360
}
13711361
assert model_deployment.infrastructure.access_log == {
13721362
"logId": "test_updated_access_log_id",
1373-
"logGroupId": "fakeid.loggroup.oc1.iad.xxx"
1363+
"logGroupId": "fakeid.loggroup.oc1.iad.xxx",
13741364
}
13751365
assert model_deployment.infrastructure.predict_log == {
13761366
"logId": "fakeid.log.oc1.iad.xxx",
1377-
"logGroupId": "test_updated_predict_log_group_id"
1367+
"logGroupId": "test_updated_predict_log_group_id",
13781368
}
13791369
assert model_deployment.infrastructure.shape_config_details == {
13801370
"ocpus": 100,
1381-
"memoryInGBs": 200
1371+
"memoryInGBs": 200,
13821372
}
13831373
assert model_deployment.infrastructure.replica == 20
13841374
assert model_deployment.runtime.image == "test_updated_image"
13851375
assert model_deployment.runtime.env == {
1386-
"test_updated_env_key":"test_updated_env_value"
1376+
"test_updated_env_key": "test_updated_env_value"
13871377
}
13881378

13891379
@patch.object(OCIDataScienceMixin, "sync")
@@ -1393,18 +1383,14 @@ def test_update_spec(self):
13931383
)
13941384
@patch.object(DataScienceModel, "create")
13951385
def test_model_deployment_with_large_size_artifact(
1396-
self,
1397-
mock_create,
1398-
mock_create_model_deployment,
1399-
mock_sync
1386+
self, mock_create, mock_create_model_deployment, mock_sync
14001387
):
14011388
dsc_model = MagicMock()
14021389
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
14031390
mock_create.return_value = dsc_model
14041391
model_deployment = self.initialize_model_deployment()
14051392
(
1406-
model_deployment.runtime
1407-
.with_auth({"test_key":"test_value"})
1393+
model_deployment.runtime.with_auth({"test_key": "test_value"})
14081394
.with_region("test_region")
14091395
.with_overwrite_existing_artifact(True)
14101396
.with_remove_existing_artifact(True)
@@ -1425,18 +1411,18 @@ def test_model_deployment_with_large_size_artifact(
14251411
mock_create_model_deployment.return_value = response
14261412
model_deployment = self.initialize_model_deployment()
14271413
model_deployment.set_spec(model_deployment.CONST_ID, "test_model_deployment_id")
1428-
1414+
14291415
create_model_deployment_details = (
14301416
model_deployment._build_model_deployment_details()
14311417
)
14321418
model_deployment.deploy(wait_for_completion=False)
14331419
mock_create.assert_called_with(
14341420
bucket_uri="test_bucket_uri",
1435-
auth={"test_key":"test_value"},
1421+
auth={"test_key": "test_value"},
14361422
region="test_region",
14371423
overwrite_existing_artifact=True,
14381424
remove_existing_artifact=True,
1439-
timeout=100
1425+
timeout=100,
14401426
)
14411427
mock_create_model_deployment.assert_called_with(create_model_deployment_details)
14421428
mock_sync.assert_called()

tests/unitary/default_setup/pipeline/test_pipeline.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
nb_session_ocid="ocid1.datasciencenotebooksession.oc1.iad..<unique_ocid>",
3333
shape_name="VM.Standard.E3.Flex",
3434
block_storage_size_in_gbs=100,
35-
shape_config_details={"ocpus": 1, "memory_in_gbs": 16},
35+
shape_config_details={"ocpus": 1.0, "memory_in_gbs": 16.0, "cpu_baseline": None},
3636
)
3737
PIPELINE_OCID = "ocid.xxx.datasciencepipeline.<unique_ocid>"
3838

@@ -334,10 +334,8 @@ def test_pipeline_define(self):
334334
"jobId": "TestJobIdOne",
335335
"description": "Test description one",
336336
"commandLineArguments": "ARGUMENT --KEY VALUE",
337-
"environmentVariables": {
338-
"ENV": "VALUE"
339-
},
340-
"maximumRuntimeInMinutes": 20
337+
"environmentVariables": {"ENV": "VALUE"},
338+
"maximumRuntimeInMinutes": 20,
341339
},
342340
},
343341
{
@@ -1066,10 +1064,8 @@ def test_pipeline_to_dict(self):
10661064
"jobId": "TestJobIdOne",
10671065
"description": "Test description one",
10681066
"commandLineArguments": "ARGUMENT --KEY VALUE",
1069-
"environmentVariables": {
1070-
"ENV": "VALUE"
1071-
},
1072-
"maximumRuntimeInMinutes": 20
1067+
"environmentVariables": {"ENV": "VALUE"},
1068+
"maximumRuntimeInMinutes": 20,
10731069
},
10741070
},
10751071
{

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
from parameterized import parameterized
1414

1515
import ads.aqua
16+
from ads.aqua.modeldeployment.entities import AquaDeploymentDetail
1617
import ads.config
1718
from ads.aqua.extension.deployment_handler import (
1819
AquaDeploymentHandler,
1920
AquaDeploymentParamsHandler,
2021
AquaDeploymentStreamingInferenceHandler,
22+
AquaModelListHandler,
2123
)
2224

2325

@@ -260,3 +262,25 @@ def test_post(self, mock_get_model_deployment_response):
260262
self.handler.write.assert_any_call("chunk1")
261263
self.handler.write.assert_any_call("chunk2")
262264
self.handler.finish.assert_called_once()
265+
266+
267+
class AquaModelListHandlerTestCase(unittest.TestCase):
268+
default_params = {
269+
"data": [{"id": "id", "object": "object", "owned_by": "openAI", "created": 124}]
270+
}
271+
272+
@patch.object(IPythonHandler, "__init__")
273+
def setUp(self, ipython_init_mock) -> None:
274+
ipython_init_mock.return_value = None
275+
self.aqua_model_list_handler = AquaModelListHandler(MagicMock(), MagicMock())
276+
self.aqua_model_list_handler._headers = MagicMock()
277+
278+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.get")
279+
@patch("notebook.base.handlers.APIHandler.finish")
280+
def test_get_model_list(self, mock_get, mock_finish):
281+
"""Test to check the handler get method to return model list."""
282+
283+
mock_get.return_value = MagicMock(id="test_model_id")
284+
mock_finish.side_effect = lambda x: x
285+
result = self.aqua_model_list_handler.get(model_id="test_model_id")
286+
mock_get.assert_called()

0 commit comments

Comments
 (0)