88from intelli .wrappers .openai_wrapper import OpenAIWrapper
99from intelli .wrappers .anthropic_wrapper import AnthropicWrapper
1010from intelli .wrappers .keras_wrapper import KerasWrapper
11+ from intelli .wrappers .nvidia_wrapper import NvidiaWrapper
1112from enum import Enum
1213
1314class ChatProvider (Enum ):
@@ -16,6 +17,7 @@ class ChatProvider(Enum):
1617 MISTRAL = "mistral"
1718 ANTHROPIC = "anthropic"
1819 KERAS = "keras"
20+ NVIDIA = "nvidia"
1921
2022class Chatbot :
2123
@@ -58,6 +60,8 @@ def _initialize_provider(self):
5860 return AnthropicWrapper (self .api_key )
5961 elif self .provider == ChatProvider .KERAS .value :
6062 return KerasWrapper (self .options ['model_name' ], self .options .get ('model_params' , {}))
63+ elif self .provider == ChatProvider .NVIDIA .value :
64+ return NvidiaWrapper (self .api_key )
6165 else :
6266 raise ValueError (f"Unsupported provider: { self .provider } " )
6367
@@ -104,6 +108,13 @@ def _chat_anthropic(self, params):
104108 response = self .wrapper .generate_text (params )
105109
106110 return [message ['text' ] for message in response ['content' ]]
111+
112+ def _chat_nvidia (self , params ):
113+ result = self .wrapper .generate_text (params )
114+ choices = result .get ("choices" , [])
115+ if not choices :
116+ raise Exception ("No choices returned from NVIDIA API" )
117+ return [choices [0 ]["message" ]["content" ]]
107118
108119 def stream (self , chat_input ):
109120 """Streams responses from the selected provider for the given chat input."""
@@ -156,6 +167,20 @@ def _stream_anthropic(self, params):
156167 except json .JSONDecodeError as e :
157168 print ("Error decoding JSON from stream:" , e )
158169
170+ def _stream_nvidia (self , params ):
171+ params ["stream" ] = True
172+ stream = self .wrapper .generate_text_stream (params )
173+ for line in stream :
174+ if line .strip () and line .startswith ("data: " ) and line != "data: [DONE]" :
175+ json_content = line [len ("data: " ):].strip ()
176+ try :
177+ data_chunk = json .loads (json_content )
178+ content = data_chunk .get ("choices" , [{}])[0 ].get ("delta" , {}).get ("content" , "" )
179+ if content :
180+ yield content
181+ except json .JSONDecodeError as e :
182+ print ("Error decoding JSON:" , e )
183+
159184 # helpers
160185 def _parse_openai_responses (self , results ):
161186 responses = []
0 commit comments