diff --git a/src/prompto/apis/ollama/ollama_utils.py b/src/prompto/apis/ollama/ollama_utils.py index 807cbdf7..735e65ef 100644 --- a/src/prompto/apis/ollama/ollama_utils.py +++ b/src/prompto/apis/ollama/ollama_utils.py @@ -1,5 +1,7 @@ ollama_chat_roles = set(["system", "user", "assistant"]) +from ollama import ChatResponse, GenerateResponse + def process_response(response: dict) -> str: """ @@ -25,5 +27,9 @@ def process_response(response: dict) -> str: "Unsupported response format. " f"No 'response' or 'message' key found in response: {response}" ) + elif isinstance(response, ChatResponse): + return response.message.content if response.message else "" + elif isinstance(response, GenerateResponse): + return response.response if response.response else "" else: raise ValueError(f"Unsupported response type: {type(response)}")