diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index c645a0c82..4545a3863 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ - +import json from typing import Optional from urllib.parse import urlparse @@ -341,9 +341,13 @@ def get(self, model_id): ): try: oci_data_science_model = OCIDataScienceModel.from_id(model_id) + chat_template = oci_data_science_model.get_custom_metadata_artifact("chat_template") + chat_template = chat_template.decode("utf-8") + + return self.finish(json.dumps({"chat_template": chat_template})) + except Exception as e: - raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}") - return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template")) + raise HTTPError(404, f"Failed to fetch chat template for model_id={model_id}. Details: {str(e)}") raise HTTPError(400, f"The request {self.request.path} is invalid.") diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index e0938e1c1..ffa6a8444 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*-- # Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import json from unicodedata import category from unittest import TestCase from unittest.mock import MagicMock, patch, ANY @@ -272,11 +273,13 @@ def test_get_valid_path(self, mock_urlparse, mock_from_id): mock_urlparse.return_value = request_path model_mock = MagicMock() - model_mock.get_custom_metadata_artifact.return_value = "chat_template_string" + model_mock.get_custom_metadata_artifact.return_value = b"chat_template_string" mock_from_id.return_value = model_mock self.model_chat_template_handler.get(model_id="test_model_id") - self.model_chat_template_handler.finish.assert_called_with("chat_template_string") + self.model_chat_template_handler.finish.assert_called_with( + json.dumps({"chat_template": "chat_template_string"}) + ) model_mock.get_custom_metadata_artifact.assert_called_with("chat_template") @patch("ads.aqua.extension.model_handler.urlparse") @@ -361,7 +364,7 @@ def test_post_model_not_found(self, mock_write_error, mock_from_id): _, exc_instance, _ = exc_info assert isinstance(exc_instance, HTTPError) assert exc_instance.status_code == 404 - assert "Model not found" in str(exc_instance) + assert "Model not found for id" in str(exc_instance) class TestAquaHuggingFaceHandler: