Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions runtimes/huggingface/mlserver_huggingface/codecs/conversation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import List, Any
from mlserver.codecs.base import InputCodec, register_input_codec
from mlserver.types import RequestInput, ResponseOutput, Parameters
from transformers.pipelines import Conversation
from mlserver.codecs.lists import is_list_of
from .utils import json_decode, json_encode
from .utils import json_decode, json_encode, get_conversation_class


Conversation = get_conversation_class()


@register_input_codec
Expand All @@ -16,12 +18,14 @@ class HuggingfaceConversationCodec(InputCodec):

@classmethod
def can_encode(cls, payload: Any) -> bool:
    """Return True when *payload* is a list of Conversation objects.

    Returns False outright if the optional Conversation class could not
    be imported (newer transformers releases removed it).
    """
    if Conversation is None:
        return False
    return is_list_of(payload, Conversation)

@classmethod
def encode_output(
cls, name: str, payload: List[Conversation], use_bytes: bool = True, **kwargs
cls, name: str, payload: List[Any], use_bytes: bool = True, **kwargs
) -> ResponseOutput:
if Conversation is None:
raise ImportError("transformers.pipelines.Conversation is not available.")
encoded = [json_encode(item, use_bytes=use_bytes) for item in payload]
shape = [len(encoded), 1]
return ResponseOutput(
Expand All @@ -41,8 +45,10 @@ def decode_output(cls, response_output: ResponseOutput) -> List[Any]:

@classmethod
def encode_input(
cls, name: str, payload: List[Conversation], use_bytes: bool = True, **kwargs
cls, name: str, payload: List[Any], use_bytes: bool = True, **kwargs
) -> RequestInput:
if Conversation is None:
raise ImportError("transformers.pipelines.Conversation is not available.")
output = cls.encode_output(name, payload, use_bytes)
return RequestInput(
name=output.name,
Expand All @@ -55,6 +61,6 @@ def encode_input(
)

@classmethod
def decode_input(cls, request_input: RequestInput) -> List[Any]:
    """Decode each packed item in the request payload via json_decode."""
    decoded = []
    for raw_item in request_input.data:
        decoded.append(json_decode(raw_item))
    return decoded
25 changes: 19 additions & 6 deletions runtimes/huggingface/mlserver_huggingface/codecs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@
import base64
import numpy as np
from PIL import Image, ImageChops
from transformers.pipelines import Conversation
from mlserver.codecs.json import JSONEncoderWithArray

IMAGE_PREFIX = "data:image/"
DEFAULT_IMAGE_FORMAT = "PNG"


class HuggingfaceJSONEncoder(JSONEncoderWithArray):
def get_conversation_class():
    """Return ``transformers.pipelines.Conversation`` when available.

    The Conversation class was removed from recent ``transformers``
    releases, so the import is attempted lazily; ``None`` signals to
    callers that conversation support is unavailable.
    """
    try:
        from transformers.pipelines import Conversation as _Conversation
    except ImportError:
        # transformers missing, or a version without Conversation.
        return None
    return _Conversation

Conversation = get_conversation_class()


class HuggingfaceJSONEncoder(JSONEncoderWithArray):
def default(self, obj):
if isinstance(obj, Image.Image):
buf = io.BytesIO()
Expand All @@ -24,7 +33,7 @@ def default(self, obj):
+ ";base64,"
+ base64.b64encode(buf.getvalue()).decode()
)
elif isinstance(obj, Conversation):
elif Conversation and isinstance(obj, Conversation):
return {
"uuid": str(obj.uuid),
"past_user_inputs": obj.past_user_inputs,
Expand Down Expand Up @@ -66,7 +75,7 @@ def do(cls, raw):

@classmethod
def convert_conversation(cls, d: Dict[str, Any]):
if set(d.keys()) == conversation_keys:
if Conversation and set(d.keys()) == conversation_keys:
return Conversation(
text=d["new_user_input"],
conversation_id=d["uuid"],
Expand All @@ -83,8 +92,12 @@ def convert_dict(cls, d: Dict[str, Any]):
tmp = {}
for k, v in d.items():
if isinstance(v, dict):
if set(v.keys()) == conversation_keys:
tmp[k] = Conversation(text=v["new_user_input"])
if set(v.keys()) == conversation_keys:
tmp[k] = (
Conversation(text=v["new_user_input"])
if Conversation
else v
)
else:
tmp[k] = cls.convert_dict(v)
elif isinstance(v, list):
Expand Down