Skip to content

Commit f007d51

Browse files
committed
fixed test case for deployment handler
1 parent e8e36db commit f007d51

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from parameterized import parameterized
1414

1515
import ads.aqua
16+
from ads.aqua.modeldeployment.entities import AquaDeploymentDetail
1617
import ads.config
1718
from ads.aqua.extension.deployment_handler import (
1819
AquaDeploymentHandler,
@@ -264,28 +265,22 @@ def test_post(self, mock_get_model_deployment_response):
264265

265266

266267
class AquaModelListHandlerTestCase(unittest.TestCase):
267-
default_params = ["--seed 42", "--trust-remote-code"]
268+
default_params = {
269+
"data": [{"id": "id", "object": "object", "owned_by": "openAI", "created": 124}]
270+
}
268271

269272
@patch.object(IPythonHandler, "__init__")
270273
def setUp(self, ipython_init_mock) -> None:
271274
ipython_init_mock.return_value = None
272-
self.test_instance = AquaModelListHandler(MagicMock(), MagicMock())
275+
self.aqua_model_list_handler = AquaModelListHandler(MagicMock(), MagicMock())
276+
self.aqua_model_list_handler._headers = MagicMock()
273277

278+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.get")
274279
@patch("notebook.base.handlers.APIHandler.finish")
275-
# @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_deployment_default_params")
276-
def test_get_model_list(self, mock_get_model_list_default_params, mock_finish):
280+
def test_get_model_list(self, mock_get, mock_finish):
277281
"""Test to check the handler get method to return model list."""
278282

279-
mock_get_model_list_default_params.return_value = self.default_params
283+
mock_get.return_value = MagicMock(id="test_model_id")
280284
mock_finish.side_effect = lambda x: x
281-
282-
# args = {"instance_shape": TestDataset.INSTANCE_SHAPE}
283-
# self.test_instance.get_argument = MagicMock(
284-
# side_effect=lambda arg, default=None : args.get(arg, default)
285-
# )
286-
result = self.test_instance.get(model_id="test_model_id")
287-
self.assertCountEqual(result["data"], self.default_params)
288-
289-
mock_get_model_list_default_params.assert_called_with(
290-
model_id="test_model_id",
291-
)
285+
result = self.aqua_model_list_handler.get(model_id="test_model_id")
286+
mock_get.assert_called()

0 commit comments

Comments
 (0)