2828 cast ,
2929)
3030
31- from pydantic import ValidationError
32-
3331from neo4j_graphrag .message_history import MessageHistory
3432from neo4j_graphrag .types import LLMMessage
3533
3634from ..exceptions import LLMGenerationError
3735from .base import LLMInterface
38- from .rate_limit import RateLimitHandler , rate_limit_handler , async_rate_limit_handler
3936from .types import (
40- BaseMessage ,
4137 LLMResponse ,
42- MessageList ,
4338 ToolCall ,
4439 ToolCallResponse ,
45- SystemMessage ,
46- UserMessage ,
4740)
4841
4942from neo4j_graphrag .tool import Tool
5043
5144if TYPE_CHECKING :
5245 from openai .types .chat import (
5346 ChatCompletionMessageParam ,
54- ChatCompletionToolParam ,
55- )
47+ ChatCompletionToolParam , ChatCompletionUserMessageParam ,
48+ ChatCompletionSystemMessageParam , ChatCompletionAssistantMessageParam ,
49+ )
5650 from openai import OpenAI , AsyncOpenAI
51+ from .rate_limit import RateLimitHandler
5752else :
5853 ChatCompletionMessageParam = Any
5954 ChatCompletionToolParam = Any
6055 OpenAI = Any
6156 AsyncOpenAI = Any
57+ RateLimitHandler = Any
6258
6359
6460class BaseOpenAILLM (LLMInterface , abc .ABC ):
@@ -93,23 +89,26 @@ def __init__(
9389
9490 def get_messages (
9591 self ,
96- input : str ,
97- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
98- system_instruction : Optional [str ] = None ,
92+ messages : list [LLMMessage ],
9993 ) -> Iterable [ChatCompletionMessageParam ]:
100- messages = []
101- if system_instruction :
102- messages .append (SystemMessage (content = system_instruction ).model_dump ())
103- if message_history :
104- if isinstance (message_history , MessageHistory ):
105- message_history = message_history .messages
106- try :
107- MessageList (messages = cast (list [BaseMessage ], message_history ))
108- except ValidationError as e :
109- raise LLMGenerationError (e .errors ()) from e
110- messages .extend (cast (Iterable [dict [str , Any ]], message_history ))
111- messages .append (UserMessage (content = input ).model_dump ())
112- return messages # type: ignore
94+ chat_messages = []
95+ for m in messages :
96+ message_type : ChatCompletionMessageParam
97+ if m ["role" ] == "system" :
98+ message_type = ChatCompletionSystemMessageParam
99+ elif m ["role" ] == "user" :
100+ message_type = ChatCompletionUserMessageParam
101+ elif m ["role" ] == "assistant" :
102+ message_type = ChatCompletionAssistantMessageParam
103+ else :
104+ raise ValueError (f"Unknown message type: { m ['role' ]} " )
105+ chat_messages .append (
106+ message_type (
107+ role = m ["role" ],
108+ content = m ["content" ],
109+ )
110+ )
111+ return chat_messages
113112
114113 def _convert_tool_to_openai_format (self , tool : Tool ) -> Dict [str , Any ]:
115114 """Convert a Tool object to OpenAI's expected format.
@@ -132,21 +131,15 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]:
132131 except AttributeError :
133132 raise LLMGenerationError (f"Tool { tool } is not a valid Tool object" )
134133
135- @rate_limit_handler
136- def invoke (
134+ def _invoke (
137135 self ,
138- input : str ,
139- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
140- system_instruction : Optional [str ] = None ,
136+ input : list [LLMMessage ],
141137 ) -> LLMResponse :
142138 """Sends a text input to the OpenAI chat completion model
143139 and returns the response's content.
144140
145141 Args:
146142 input (str): Text sent to the LLM.
147- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
148- with each message having a specific role assigned.
149- system_instruction (Optional[str]): An option to override the llm system message for this invocation.
150143
151144 Returns:
152145 LLMResponse: The response from OpenAI.
@@ -155,10 +148,8 @@ def invoke(
155148 LLMGenerationError: If anything goes wrong.
156149 """
157150 try :
158- if isinstance (message_history , MessageHistory ):
159- message_history = message_history .messages
160151 response = self .client .chat .completions .create (
161- messages = self .get_messages (input , message_history , system_instruction ),
152+ messages = self .get_messages (input ),
162153 model = self .model_name ,
163154 ** self .model_params ,
164155 )
@@ -167,7 +158,6 @@ def invoke(
167158 except self .openai .OpenAIError as e :
168159 raise LLMGenerationError (e )
169160
170- @rate_limit_handler
171161 def invoke_with_tools (
172162 self ,
173163 input : str ,
@@ -242,21 +232,15 @@ def invoke_with_tools(
242232 except self .openai .OpenAIError as e :
243233 raise LLMGenerationError (e )
244234
245- @async_rate_limit_handler
246- async def ainvoke (
235+ async def _ainvoke (
247236 self ,
248- input : str ,
249- message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
250- system_instruction : Optional [str ] = None ,
237+ input : list [LLMMessage ],
251238 ) -> LLMResponse :
252239 """Asynchronously sends a text input to the OpenAI chat
253240 completion model and returns the response's content.
254241
255242 Args:
256243 input (str): Text sent to the LLM.
257- message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages,
258- with each message having a specific role assigned.
259- system_instruction (Optional[str]): An option to override the llm system message for this invocation.
260244
261245 Returns:
262246 LLMResponse: The response from OpenAI.
@@ -265,10 +249,8 @@ async def ainvoke(
265249 LLMGenerationError: If anything goes wrong.
266250 """
267251 try :
268- if isinstance (message_history , MessageHistory ):
269- message_history = message_history .messages
270252 response = await self .async_client .chat .completions .create (
271- messages = self .get_messages (input , message_history , system_instruction ),
253+ messages = self .get_messages (input ),
272254 model = self .model_name ,
273255 ** self .model_params ,
274256 )
@@ -277,7 +259,6 @@ async def ainvoke(
277259 except self .openai .OpenAIError as e :
278260 raise LLMGenerationError (e )
279261
280- @async_rate_limit_handler
281262 async def ainvoke_with_tools (
282263 self ,
283264 input : str ,
0 commit comments