Skip to content

Commit 43d0a37

Browse files
authored
Merge branch 'main' into feature/model_group
2 parents a84b58b + 2098b24 commit 43d0a37

File tree

8 files changed

+275
-68
lines changed

8 files changed

+275
-68
lines changed

ads/aqua/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
SUPPORTED_FILE_FORMATS = ["jsonl"]
5656
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
5757

58+
AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59+
5860
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
5961
"datasciencemodel": "models",
6062
"datasciencemodeldeployment": "model-deployments",

ads/aqua/extension/model_handler.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
1212
from ads.aqua.common.errors import AquaRuntimeError
1313
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
14+
from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY
1415
from ads.aqua.extension.base_handler import AquaAPIhandler
1516
from ads.aqua.extension.errors import Errors
1617
from ads.aqua.model import AquaModelApp
1718
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1819
from ads.config import SERVICE
20+
from ads.model import DataScienceModel
1921
from ads.model.common.utils import MetadataArtifactPathType
22+
from ads.model.service.oci_datascience_model import OCIDataScienceModel
2023

2124

2225
class AquaModelHandler(AquaAPIhandler):
@@ -320,26 +323,65 @@ def post(self, *args, **kwargs): # noqa: ARG002
320323
)
321324

322325

323-
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
326+
class AquaModelChatTemplateHandler(AquaAPIhandler):
    """Handler for reading and writing the chat template custom metadata artifact of a model."""

    def get(self, model_id):
        """
        Handles requests for retrieving the chat template from custom metadata of a specified model.
        Expected request format: GET /aqua/models/<model-ocid>/chat-template

        Raises
        ------
        HTTPError
            404 if the model cannot be loaded from `model_id`.
            400 if the request path does not have the expected shape.
        """

        path_list = urlparse(self.request.path).path.strip("/").split("/")
        # Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
        # path_list=['aqua','models','<model-ocid>','chat-template']
        if (
            len(path_list) == 4
            and is_valid_ocid(path_list[2])
            and path_list[3] == "chat-template"
        ):
            try:
                oci_data_science_model = OCIDataScienceModel.from_id(model_id)
            except Exception as e:
                raise HTTPError(
                    404, f"Model not found for id: {model_id}. Details: {str(e)}"
                ) from e
            # Use the shared constant so GET reads the same metadata key POST writes.
            return self.finish(
                oci_data_science_model.get_custom_metadata_artifact(
                    AQUA_CHAT_TEMPLATE_METADATA_KEY
                )
            )

        raise HTTPError(400, f"The request {self.request.path} is invalid.")

    @handle_exceptions
    def post(self, model_id: str):
        """
        Handles POST requests to add a custom chat_template metadata artifact to a model.

        Expected request format:
        POST /aqua/models/<model-ocid>/chat-template
        Body: { "chat_template": "<your_template_string>" }

        Raises
        ------
        HTTPError
            400 if the body is not a JSON object or `chat_template` is missing,
            empty, or not a string.
            404 if the model cannot be loaded from `model_id`.
            500 if the metadata artifact cannot be created.
        """
        try:
            input_body = self.get_json_body()
        except Exception as e:
            raise HTTPError(400, f"Invalid JSON body: {str(e)}") from e

        # A missing body or a non-object JSON value has no `.get`; reject it as a
        # client error instead of failing with an AttributeError.
        if not isinstance(input_body, dict):
            raise HTTPError(400, "Invalid JSON body: expected a JSON object.")

        chat_template = input_body.get("chat_template")
        if not chat_template:
            raise HTTPError(400, "Missing required field: 'chat_template'")
        # The template is encoded to bytes below, which only works for strings.
        if not isinstance(chat_template, str):
            raise HTTPError(400, "Field 'chat_template' must be a string.")

        try:
            data_science_model = DataScienceModel.from_id(model_id)
        except Exception as e:
            raise HTTPError(
                404, f"Model not found for id: {model_id}. Details: {str(e)}"
            ) from e

        try:
            result = data_science_model.create_custom_metadata_artifact(
                metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
                path_type=MetadataArtifactPathType.CONTENT,
                artifact_path_or_content=chat_template.encode(),
            )
        except Exception as e:
            raise HTTPError(
                500, f"Failed to create metadata artifact: {str(e)}"
            ) from e

        return self.finish(result)
384+
343385

344386
class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
345387
"""
@@ -381,7 +423,7 @@ def post(self, model_id: str, metadata_key: str):
381423
("model/?([^/]*)", AquaModelHandler),
382424
("model/?([^/]*)/license", AquaModelLicenseHandler),
383425
("model/?([^/]*)/readme", AquaModelReadmeHandler),
384-
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
426+
("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler),
385427
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
386428
(
387429
"model/?([^/]*)/definedMetadata/?([^/]*)",

ads/aqua/model/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class ModelTask(ExtendedEnum):
2626
TEXT_GENERATION = "text-generation"
2727
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
2828
IMAGE_TO_TEXT = "image-to-text"
29+
TIME_SERIES_FORECASTING = "time-series-forecasting"
2930

3031

3132
class FineTuningMetricCategories(ExtendedEnum):

ads/aqua/modeldeployment/deployment.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,14 @@ def create(
218218
freeform_tags=freeform_tags,
219219
defined_tags=defined_tags,
220220
)
221+
task_tag = aqua_model.freeform_tags.get(Tags.TASK, UNKNOWN)
222+
if (
223+
task_tag == ModelTask.TIME_SERIES_FORECASTING
224+
or task_tag == ModelTask.TIME_SERIES_FORECASTING.replace("-", "_")
225+
):
226+
create_deployment_details.env_var.update(
227+
{Tags.TASK.upper(): ModelTask.TIME_SERIES_FORECASTING}
228+
)
221229
return self._create(
222230
aqua_model=aqua_model,
223231
create_deployment_details=create_deployment_details,
@@ -1077,14 +1085,16 @@ def _create_deployment(
10771085
).deploy(wait_for_completion=False)
10781086

10791087
deployment_id = deployment.id
1088+
10801089
logger.info(
10811090
f"Aqua model deployment {deployment_id} created for model {aqua_model_id}. Work request Id is {deployment.dsc_model_deployment.workflow_req_id}"
10821091
)
1092+
status_list = []
10831093

10841094
progress_thread = threading.Thread(
10851095
target=self.get_deployment_status,
10861096
args=(
1087-
deployment_id,
1097+
deployment,
10881098
deployment.dsc_model_deployment.workflow_req_id,
10891099
model_type,
10901100
model_name,
@@ -1604,7 +1614,7 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
16041614

16051615
def get_deployment_status(
16061616
self,
1607-
model_deployment_id: str,
1617+
deployment: ModelDeployment,
16081618
work_request_id: str,
16091619
model_type: str,
16101620
model_name: str,
@@ -1626,37 +1636,60 @@ def get_deployment_status(
16261636
AquaDeployment
16271637
An Aqua deployment instance.
16281638
"""
1629-
ocid = get_ocid_substring(model_deployment_id, key_len=8)
1630-
telemetry_kwargs = {"ocid": ocid}
1631-
1639+
ocid = get_ocid_substring(deployment.id, key_len=8)
16321640
data_science_work_request: DataScienceWorkRequest = DataScienceWorkRequest(
16331641
work_request_id
16341642
)
1635-
16361643
try:
16371644
data_science_work_request.wait_work_request(
16381645
progress_bar_description="Creating model deployment",
16391646
max_wait_time=DEFAULT_WAIT_TIME,
16401647
poll_interval=DEFAULT_POLL_INTERVAL,
16411648
)
16421649
except Exception:
1650+
status = ""
1651+
logs = deployment.show_logs().sort_values(by="time", ascending=False)
1652+
1653+
if logs and len(logs) > 0:
1654+
status = logs.iloc[0]["message"]
1655+
1656+
status = re.sub(r"[^a-zA-Z0-9]", " ", status)
1657+
16431658
if data_science_work_request._error_message:
16441659
error_str = ""
16451660
for error in data_science_work_request._error_message:
16461661
error_str = error_str + " " + error.message
16471662

1648-
self.telemetry.record_event(
1649-
category=f"aqua/{model_type}/deployment/status",
1650-
action="FAILED",
1651-
detail=error_str,
1652-
value=model_name,
1653-
**telemetry_kwargs,
1654-
)
1663+
error_str = re.sub(r"[^a-zA-Z0-9]", " ", error_str)
1664+
telemetry_kwargs = {
1665+
"ocid": ocid,
1666+
"model_name": model_name,
1667+
"work_request_error": error_str,
1668+
"status": status,
1669+
}
1670+
1671+
self.telemetry.record_event(
1672+
category=f"aqua/{model_type}/deployment/status",
1673+
action="FAILED",
1674+
**telemetry_kwargs,
1675+
)
1676+
else:
1677+
telemetry_kwargs = {
1678+
"ocid": ocid,
1679+
"model_name": model_name,
1680+
"status": status,
1681+
}
1682+
1683+
self.telemetry.record_event(
1684+
category=f"aqua/{model_type}/deployment/status",
1685+
action="FAILED",
1686+
**telemetry_kwargs,
1687+
)
16551688

16561689
else:
1657-
self.telemetry.record_event_async(
1690+
telemetry_kwargs = {"ocid": ocid, "model_name": model_name}
1691+
self.telemetry.record_event(
16581692
category=f"aqua/{model_type}/deployment/status",
16591693
action="SUCCEEDED",
1660-
value=model_name,
16611694
**telemetry_kwargs,
16621695
)

ads/common/oci_logging.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

43
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import datetime
87
import logging
98
import time
10-
from typing import Dict, Union, List
9+
from typing import Dict, List, Union
1110

11+
import oci.exceptions
1212
import oci.logging
1313
import oci.loggingsearch
14-
import oci.exceptions
14+
1515
from ads.common.decorator.utils import class_or_instance_method
1616
from ads.common.oci_mixin import OCIModelMixin, OCIWorkRequestMixin
1717
from ads.common.oci_resource import OCIResource, ResourceNotFoundError
1818

19-
2019
logger = logging.getLogger(__name__)
2120

2221
# Maximum number of log records to be returned by default.
@@ -862,9 +861,7 @@ def tail(
862861
time_start=time_start,
863862
log_filter=log_filter,
864863
)
865-
self._print(
866-
sorted(tail_logs, key=lambda log: log["time"])
867-
)
864+
self._print(sorted(tail_logs, key=lambda log: log["time"]))
868865

869866
def head(
870867
self,

ads/model/deployment/model_deployment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,8 @@ def watch(
757757
log_filter : str, optional
758758
Expression for filtering the logs. This will be the WHERE clause of the query.
759759
Defaults to None.
760+
status_list : List[str], optional
761+
List of model deployment statuses. Used to store the statuses collected from the logs.
760762
761763
Returns
762764
-------

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,29 +2376,84 @@ def test_validate_multimodel_deployment_feasibility_positive_single(
23762376
"test_data/deployment/aqua_summary_multi_model_single.json",
23772377
)
23782378

2379-
def test_get_deployment_status(self):
2379+
def test_get_deployment_status_success(self):
    """A work request that completes is waited on once and reported as SUCCEEDED."""
    model_deployment = copy.deepcopy(TestDataset.model_deployment_object[0])
    work_request_id = "fakeid.workrequest.oc1.iad.xxx"
    model_type = "custom"
    model_name = "model_name"

    # Telemetry is patched so the test neither depends on nor calls the real client.
    with patch(
        "ads.telemetry.client.TelemetryClient.record_event"
    ) as mock_record_event, patch(
        "ads.model.service.oci_datascience_model_deployment.DataScienceWorkRequest.__init__",
        return_value=None,
    ) as mock_ds_work_request, patch(
        "ads.model.service.oci_datascience_model_deployment.DataScienceWorkRequest.wait_work_request"
    ) as mock_wait:
        self.app.get_deployment_status(
            oci.data_science.models.ModelDeploymentSummary(**model_deployment),
            work_request_id,
            model_type,
            model_name,
        )

        mock_ds_work_request.assert_called_once_with(work_request_id)
        mock_wait.assert_called_once_with(
            progress_bar_description="Creating model deployment",
            max_wait_time=DEFAULT_WAIT_TIME,
            poll_interval=DEFAULT_POLL_INTERVAL,
        )

        mock_record_event.assert_called_once()
        _, kwargs = mock_record_event.call_args
        self.assertEqual(kwargs["category"], f"aqua/{model_type}/deployment/status")
        self.assertEqual(kwargs["action"], "SUCCEEDED")
2405+
2406+
def raise_exception(*args, **kwargs):
    """Side-effect stub: ignore every argument and always raise a generic failure."""
    failure = Exception("Work request failed")
    raise failure
2408+
2409+
def test_get_deployment_status_failed(self):
    """A failing work request is reported as FAILED with the error and latest log status."""
    work_request_id = "fakeid.workrequest.oc1.iad.xxx"
    model_type = "custom"
    model_name = "model_name"
    with patch(
        "ads.telemetry.client.TelemetryClient.record_event"
    ) as mock_record_event, patch(
        "ads.aqua.modeldeployment.deployment.DataScienceWorkRequest"
    ) as mock_ds_work_request_class, patch(
        "ads.model.deployment.model_deployment.ModelDeployment.show_logs"
    ) as mock_show_log:
        mock_ds_work_request_instance = MagicMock()
        mock_ds_work_request_class.return_value = mock_ds_work_request_instance

        mock_ds_work_request_instance._error_message = [
            MagicMock(message="Some error occurred")
        ]

        mock_ds_work_request_instance.wait_work_request.side_effect = (
            self.raise_exception
        )

        # Stand-in for the logs table: non-empty by both `.empty` and `len()`,
        # so the latest-message branch runs whichever check the code uses.
        logs_df = MagicMock()
        logs_df.sort_values.return_value = logs_df
        logs_df.empty = False
        logs_df.__len__.return_value = 1
        logs_df.iloc.__getitem__.return_value = {
            "message": "Error: deployment failed!"
        }
        mock_show_log.return_value = logs_df

        self.app.get_deployment_status(
            ModelDeployment(),
            work_request_id,
            model_type,
            model_name,
        )
        mock_record_event.assert_called_once()
        _, kwargs = mock_record_event.call_args
        self.assertEqual(kwargs["category"], f"aqua/{model_type}/deployment/status")
        self.assertEqual(kwargs["action"], "FAILED")
        self.assertIn("work_request_error", kwargs)
        self.assertIn("status", kwargs)
        self.assertIn("ocid", kwargs)
        self.assertIn("model_name", kwargs)
        # Non-alphanumeric characters are replaced by spaces before reporting,
        # so compare word by word.
        self.assertEqual(
            kwargs["work_request_error"].split(), ["Some", "error", "occurred"]
        )
        self.assertEqual(kwargs["status"].split(), ["Error", "deployment", "failed"])

        mock_ds_work_request_class.assert_called_once_with(work_request_id)
24022457

24032458

24042459
class TestBaseModelSpec:

0 commit comments

Comments
 (0)