|
11 | 11 | from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
|
12 | 12 | from ads.aqua.common.errors import AquaRuntimeError
|
13 | 13 | from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
|
| 14 | +from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY |
14 | 15 | from ads.aqua.extension.base_handler import AquaAPIhandler
|
15 | 16 | from ads.aqua.extension.errors import Errors
|
16 | 17 | from ads.aqua.model import AquaModelApp
|
17 | 18 | from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
|
18 | 19 | from ads.config import SERVICE
|
| 20 | +from ads.model import DataScienceModel |
19 | 21 | from ads.model.common.utils import MetadataArtifactPathType
|
| 22 | +from ads.model.service.oci_datascience_model import OCIDataScienceModel |
20 | 23 |
|
21 | 24 |
|
22 | 25 | class AquaModelHandler(AquaAPIhandler):
|
@@ -320,26 +323,65 @@ def post(self, *args, **kwargs): # noqa: ARG002
|
320 | 323 | )
|
321 | 324 |
|
322 | 325 |
|
323 |
| -class AquaModelTokenizerConfigHandler(AquaAPIhandler): |
| 326 | +class AquaModelChatTemplateHandler(AquaAPIhandler): |
324 | 327 | def get(self, model_id):
|
325 | 328 | """
|
326 |
| - Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model. |
327 |
| - Expected request format: GET /aqua/models/<model-ocid>/tokenizer |
| 329 | + Handles requests for retrieving the chat template from custom metadata of a specified model. |
| 330 | + Expected request format: GET /aqua/models/<model-ocid>/chat-template |
328 | 331 |
|
329 | 332 | """
|
330 | 333 |
|
331 | 334 | path_list = urlparse(self.request.path).path.strip("/").split("/")
|
332 |
| - # Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer |
333 |
| - # path_list=['aqua','models','<model-ocid>','tokenizer'] |
| 335 | + # Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template |
| 336 | + # path_list=['aqua','models','<model-ocid>','chat-template'] |
334 | 337 | if (
|
335 | 338 | len(path_list) == 4
|
336 | 339 | and is_valid_ocid(path_list[2])
|
337 |
| - and path_list[3] == "tokenizer" |
| 340 | + and path_list[3] == "chat-template" |
338 | 341 | ):
|
339 |
| - return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id)) |
| 342 | + try: |
| 343 | + oci_data_science_model = OCIDataScienceModel.from_id(model_id) |
| 344 | + except Exception as e: |
| 345 | + raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}") |
| 346 | + return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template")) |
340 | 347 |
|
341 | 348 | raise HTTPError(400, f"The request {self.request.path} is invalid.")
|
342 | 349 |
|
| 350 | + @handle_exceptions |
| 351 | + def post(self, model_id: str): |
| 352 | + """ |
| 353 | + Handles POST requests to add a custom chat_template metadata artifact to a model. |
| 354 | +
|
| 355 | + Expected request format: |
| 356 | + POST /aqua/models/<model-ocid>/chat-template |
| 357 | + Body: { "chat_template": "<your_template_string>" } |
| 358 | +
|
| 359 | + """ |
| 360 | + try: |
| 361 | + input_body = self.get_json_body() |
| 362 | + except Exception as e: |
| 363 | + raise HTTPError(400, f"Invalid JSON body: {str(e)}") |
| 364 | + |
| 365 | + chat_template = input_body.get("chat_template") |
| 366 | + if not chat_template: |
| 367 | + raise HTTPError(400, "Missing required field: 'chat_template'") |
| 368 | + |
| 369 | + try: |
| 370 | + data_science_model = DataScienceModel.from_id(model_id) |
| 371 | + except Exception as e: |
| 372 | + raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}") |
| 373 | + |
| 374 | + try: |
| 375 | + result = data_science_model.create_custom_metadata_artifact( |
| 376 | + metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY, |
| 377 | + path_type=MetadataArtifactPathType.CONTENT, |
| 378 | + artifact_path_or_content=chat_template.encode() |
| 379 | + ) |
| 380 | + except Exception as e: |
| 381 | + raise HTTPError(500, f"Failed to create metadata artifact: {str(e)}") |
| 382 | + |
| 383 | + return self.finish(result) |
| 384 | + |
343 | 385 |
|
344 | 386 | class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
|
345 | 387 | """
|
@@ -381,7 +423,7 @@ def post(self, model_id: str, metadata_key: str):
|
381 | 423 | ("model/?([^/]*)", AquaModelHandler),
|
382 | 424 | ("model/?([^/]*)/license", AquaModelLicenseHandler),
|
383 | 425 | ("model/?([^/]*)/readme", AquaModelReadmeHandler),
|
384 |
| - ("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler), |
| 426 | + ("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler), |
385 | 427 | ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
|
386 | 428 | (
|
387 | 429 | "model/?([^/]*)/definedMetadata/?([^/]*)",
|
|
0 commit comments