@@ -7,6 +7,7 @@ use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
7
7
8
8
use rig:: {
9
9
completion:: { self , CompletionError , CompletionRequest } ,
10
+ message:: { ImageMediaType , MimeType } ,
10
11
providers:: openai:: Message ,
11
12
OneOrMany ,
12
13
} ;
@@ -117,6 +118,71 @@ pub struct CompletionModel {
117
118
pub model : String ,
118
119
}
119
120
121
+ fn user_text_to_json ( content : rig:: message:: UserContent ) -> serde_json:: Value {
122
+ match content {
123
+ rig:: message:: UserContent :: Text ( text) => json ! ( {
124
+ "role" : "user" ,
125
+ "content" : text. text,
126
+ } ) ,
127
+ _ => unreachable ! ( ) ,
128
+ }
129
+ }
130
+
131
+ fn user_content_to_json (
132
+ content : rig:: message:: UserContent ,
133
+ ) -> Result < serde_json:: Value , CompletionError > {
134
+ match content {
135
+ rig:: message:: UserContent :: Text ( text) => Ok ( json ! ( {
136
+ "type" : "text" ,
137
+ "text" : text. text
138
+ } ) ) ,
139
+ rig:: message:: UserContent :: Image ( image) => Ok ( json ! ( {
140
+ "type" : "image_url" ,
141
+ "image_url" : {
142
+ "url" : format!( "data:{};base64,{}" , image. media_type. unwrap_or( ImageMediaType :: PNG ) . to_mime_type( ) , image. data) ,
143
+ }
144
+ } ) ) ,
145
+ rig:: message:: UserContent :: Audio ( _) => Err ( CompletionError :: RequestError (
146
+ "Audio is not supported" . into ( ) ,
147
+ ) ) ,
148
+ rig:: message:: UserContent :: Document ( _) => Err ( CompletionError :: RequestError (
149
+ "Document is not supported" . into ( ) ,
150
+ ) ) ,
151
+ rig:: message:: UserContent :: ToolResult ( _) => unreachable ! ( ) ,
152
+ }
153
+ }
154
+
155
+ fn tool_content_to_json (
156
+ content : Vec < rig:: message:: UserContent > ,
157
+ ) -> Result < serde_json:: Value , CompletionError > {
158
+ let mut str_content = String :: new ( ) ;
159
+ let mut tool_id = String :: new ( ) ;
160
+
161
+ for content in content. into_iter ( ) {
162
+ match content {
163
+ rig:: message:: UserContent :: ToolResult ( tool_result) => {
164
+ tool_id = tool_result. id ;
165
+ str_content = tool_result
166
+ . content
167
+ . iter ( )
168
+ . map ( |c| match c {
169
+ rig:: message:: ToolResultContent :: Text ( text) => text. text . clone ( ) ,
170
+ // ignore image content
171
+ _ => "" . to_string ( ) ,
172
+ } )
173
+ . collect :: < Vec < _ > > ( )
174
+ . join ( "" ) ;
175
+ }
176
+ _ => unreachable ! ( ) ,
177
+ }
178
+ }
179
+ Ok ( json ! ( {
180
+ "role" : "tool" ,
181
+ "content" : str_content,
182
+ "tool_call_id" : tool_id,
183
+ } ) )
184
+ }
185
+
120
186
impl CompletionModel {
121
187
pub fn new ( client : Client , model : & str ) -> Self {
122
188
Self {
@@ -130,64 +196,103 @@ impl CompletionModel {
130
196
completion_request : CompletionRequest ,
131
197
) -> Result < Value , CompletionError > {
132
198
// Add preamble to chat history (if available)
133
- let mut full_history: Vec < Message > = match & completion_request. preamble {
134
- Some ( preamble) => vec ! [ Message :: system( preamble) ] ,
199
+ let mut full_history: Vec < serde_json:: Value > = match & completion_request. preamble {
200
+ Some ( preamble) => vec ! [ json!( {
201
+ "role" : "system" ,
202
+ "content" : preamble,
203
+ } ) ] ,
135
204
None => vec ! [ ] ,
136
205
} ;
137
206
138
207
// Convert existing chat history
139
- let chat_history: Vec < Message > = completion_request
140
- . chat_history
141
- . into_iter ( )
142
- . map ( |message| message. try_into ( ) )
143
- . collect :: < Result < Vec < Vec < Message > > , _ > > ( ) ?
144
- . into_iter ( )
145
- . flatten ( )
146
- . collect ( ) ;
147
-
148
- // Combine all messages into a single history
149
- full_history. extend ( chat_history) ;
150
- let messages: Vec < Value > = full_history
151
- . into_iter ( )
152
- . map ( |ref m| match m {
153
- Message :: Assistant {
154
- content,
155
- refusal : _,
156
- audio : _,
157
- name : _,
158
- tool_calls,
159
- } => {
160
- if !tool_calls. is_empty ( ) {
161
- json ! ( {
162
- "role" : "assistant" ,
163
- "content" : null,
164
- "tool_calls" : tool_calls,
165
- } )
208
+ for message in completion_request. chat_history . into_iter ( ) {
209
+ match message {
210
+ rig:: message:: Message :: User { content } => {
211
+ if content. len ( ) == 1
212
+ && matches ! ( content. first( ) , rig:: message:: UserContent :: Text ( _) )
213
+ {
214
+ full_history. push ( user_text_to_json ( content. first ( ) ) ) ;
215
+ } else if content
216
+ . iter ( )
217
+ . any ( |c| matches ! ( c, rig:: message:: UserContent :: ToolResult ( _) ) )
218
+ {
219
+ let ( tool_content, user_content) =
220
+ content. into_iter ( ) . partition :: < Vec < _ > , _ > ( |c| {
221
+ matches ! ( c, rig:: message:: UserContent :: ToolResult ( _) )
222
+ } ) ;
223
+ full_history. push ( tool_content_to_json ( tool_content. clone ( ) ) ?) ;
224
+ for tool_content in tool_content. into_iter ( ) {
225
+ match tool_content {
226
+ rig:: message:: UserContent :: ToolResult ( result) => {
227
+ for tool_result_content in result. content . into_iter ( ) {
228
+ match tool_result_content {
229
+ rig:: message:: ToolResultContent :: Image ( image) => {
230
+ full_history. push ( json ! ( {
231
+ "role" : "user" ,
232
+ "content" : [ {
233
+ "type" : "image_url" ,
234
+ "image_url" : {
235
+ "url" : format!( "data:{};base64,{}" , image. media_type. unwrap_or( ImageMediaType :: PNG ) . to_mime_type( ) , image. data) ,
236
+ }
237
+ } ]
238
+ } ) ) ;
239
+ }
240
+ _ => { }
241
+ }
242
+ }
243
+ }
244
+ _ => unreachable ! ( ) ,
245
+ }
246
+ }
247
+ if !user_content. is_empty ( ) {
248
+ if user_content. len ( ) == 1 {
249
+ full_history
250
+ . push ( user_text_to_json ( user_content. first ( ) . unwrap ( ) . clone ( ) ) ) ;
251
+ } else {
252
+ let user_content = user_content
253
+ . into_iter ( )
254
+ . map ( user_content_to_json)
255
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
256
+ full_history
257
+ . push ( json ! ( { "role" : "user" , "content" : user_content} ) ) ;
258
+ }
259
+ }
166
260
} else {
167
- json ! ( {
168
- "role" : "assistant" ,
169
- "content" : match content. first( ) . unwrap( ) {
170
- AssistantContent :: Text { text } => text,
171
- _ => "" ,
172
- } ,
173
- } )
261
+ let content = content
262
+ . into_iter ( )
263
+ . map ( user_content_to_json)
264
+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
265
+ full_history. push ( json ! ( { "role" : "user" , "content" : content} ) ) ;
174
266
}
175
267
}
176
- Message :: ToolResult {
177
- tool_call_id,
178
- content,
179
- } => {
180
- let content = json ! ( content. first( ) ) ;
181
- let text = content. as_object ( ) . unwrap ( ) . get ( "text" ) . unwrap ( ) ;
182
- json ! ( {
183
- "role" : "tool" ,
184
- "content" : text,
185
- "tool_call_id" : tool_call_id,
186
- } )
268
+ rig:: message:: Message :: Assistant { content } => {
269
+ for content in content {
270
+ match content {
271
+ rig:: message:: AssistantContent :: Text ( text) => {
272
+ full_history. push ( json ! ( {
273
+ "role" : "assistant" ,
274
+ "content" : text. text
275
+ } ) ) ;
276
+ }
277
+ rig:: message:: AssistantContent :: ToolCall ( tool_call) => {
278
+ full_history. push ( json ! ( {
279
+ "role" : "assistant" ,
280
+ "content" : null,
281
+ "tool_calls" : [ {
282
+ "id" : tool_call. id,
283
+ "type" : "function" ,
284
+ "function" : {
285
+ "name" : tool_call. function. name,
286
+ "arguments" : tool_call. function. arguments. to_string( )
287
+ }
288
+ } ]
289
+ } ) ) ;
290
+ }
291
+ }
292
+ }
187
293
}
188
- _ => json ! ( m) ,
189
- } )
190
- . collect ( ) ;
294
+ } ;
295
+ }
191
296
192
297
let tools = completion_request
193
298
. tools
@@ -201,7 +306,7 @@ impl CompletionModel {
201
306
. collect :: < Vec < _ > > ( ) ;
202
307
let request = json ! ( {
203
308
"model" : self . model,
204
- "messages" : messages ,
309
+ "messages" : full_history ,
205
310
"tools" : tools,
206
311
"temperature" : completion_request. temperature,
207
312
} ) ;
0 commit comments