diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index 5d77428ad..103423b62 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -36,7 +36,7 @@ use serde_json::json; use crate::{ agent::{Agent, AgentBuilder, AgentBuilderSimple}, - completion::{Completion, CompletionError, CompletionModel, ToolDefinition}, + completion::{Completion, CompletionError, CompletionModel, ToolDefinition, Usage}, message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction}, tool::Tool, wasm_compat::{WasmCompatSend, WasmCompatSync}, @@ -81,7 +81,7 @@ where pub async fn extract( &self, text: impl Into + WasmCompatSend, - ) -> Result { + ) -> Result, ExtractionError> { let mut last_error = None; let text_message = text.into(); @@ -114,7 +114,7 @@ where &self, text: impl Into + WasmCompatSend, chat_history: Vec, - ) -> Result { + ) -> Result, ExtractionError> { let mut last_error = None; let text_message = text.into(); @@ -125,7 +125,7 @@ where ); let attempt_text = text_message.clone(); match self.extract_json(attempt_text, chat_history.clone()).await { - Ok(data) => return Ok(data), + Ok(res) => return Ok(res), Err(e) => { tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying..."); last_error = Some(e); @@ -141,7 +141,7 @@ where &self, text: impl Into + WasmCompatSend, messages: Vec, - ) -> Result { + ) -> Result, ExtractionError> { let response = self.agent.completion(text, messages).await?.send().await?; if !response.choice.iter().any(|x| { @@ -193,7 +193,12 @@ where return Err(ExtractionError::NoData); }; - Ok(serde_json::from_value(raw_data)?) + let data = serde_json::from_value(raw_data)?; + + Ok(ExtractResponse { + data, + usage: response.usage, + }) } pub async fn get_inner(&self) -> &Agent { @@ -205,6 +210,25 @@ where } } +/// An extraction result. +pub struct ExtractResponse { + data: T, + usage: Usage, +} + +impl ExtractResponse +where + T: Clone, +{ + pub fn data(&self) -> T { + self.data.clone() + } + + pub fn usage(&self) -> Usage { + self.usage.clone() + } +} + /// Builder for the Extractor pub struct ExtractorBuilder where