Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions rig-bedrock/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,9 @@ impl CompletionModel {
}

if !text.is_empty() {
yield Ok(RawStreamingChoice::Reasoning {
yield Ok(RawStreamingChoice::ReasoningDelta {
reasoning: text.clone(),
id: None,
signature: None,
})
}
},
Expand Down
4 changes: 4 additions & 0 deletions rig-core/src/agent/prompt_request/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ where
yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id, signature })));
did_call_tool = false;
},
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, id }) => {
yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id }));
did_call_tool = false;
},
Ok(StreamedAssistantContent::Final(final_resp)) => {
if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; };
if is_text_response {
Expand Down
11 changes: 5 additions & 6 deletions rig-core/src/providers/anthropic/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,9 @@ fn handle_event(
state.thinking.push_str(thinking);
}

Some(Ok(RawStreamingChoice::Reasoning {
Some(Ok(RawStreamingChoice::ReasoningDelta {
id: None,
reasoning: thinking.clone(),
signature: None,
}))
}
ContentDelta::SignatureDelta { signature } => {
Expand Down Expand Up @@ -504,11 +503,11 @@ mod tests {
let choice = result.unwrap().unwrap();

match choice {
RawStreamingChoice::Reasoning { id, reasoning, .. } => {
RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
assert_eq!(id, None);
assert_eq!(reasoning, "Analyzing the request...");
}
_ => panic!("Expected Reasoning choice"),
_ => panic!("Expected ReasoningDelta choice"),
}

// Verify thinking state was updated
Expand Down Expand Up @@ -584,10 +583,10 @@ mod tests {
let choice = result.unwrap().unwrap();

match choice {
RawStreamingChoice::Reasoning { reasoning, .. } => {
RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
assert_eq!(reasoning, "Thinking while tool is active...");
}
_ => panic!("Expected Reasoning choice"),
_ => panic!("Expected ReasoningDelta choice"),
}

// Tool call state should remain unchanged
Expand Down
5 changes: 2 additions & 3 deletions rig-core/src/providers/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -879,10 +879,9 @@ where

// DeepSeek-specific reasoning stream
if let Some(content) = &delta.reasoning_content {
yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
reasoning: content.to_string(),
yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
id: None,
signature: None,
reasoning: content.to_string()
});
}

Expand Down
5 changes: 4 additions & 1 deletion rig-core/src/providers/gemini/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ where
thought: Some(true),
..
} => {
yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning: text.clone(), id: None, signature: None });
yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
id: None,
reasoning: text.clone(),
});
},
Part {
part: PartKind::Text(text),
Expand Down
3 changes: 1 addition & 2 deletions rig-core/src/providers/groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,10 +780,9 @@ where
if let Some(choice) = data.choices.first() {
match &choice.delta {
StreamingDelta::Reasoning { reasoning } => {
yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
id: None,
reasoning: reasoning.to_string(),
signature: None,
});
}

Expand Down
5 changes: 2 additions & 3 deletions rig-core/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,10 +738,9 @@ where
if let Some(thinking_content) = thinking
&& !thinking_content.is_empty() {
thinking_response += &thinking_content;
yield RawStreamingChoice::Reasoning {
reasoning: thinking_content,
yield RawStreamingChoice::ReasoningDelta {
id: None,
signature: None,
reasoning: thinking_content,
};
}

Expand Down
13 changes: 10 additions & 3 deletions rig-core/src/providers/openai/responses_api/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ pub enum ItemChunkKind {
ReasoningSummaryPartAdded(SummaryPartChunk),
#[serde(rename = "response.reasoning_summary_part.done")]
ReasoningSummaryPartDone(SummaryPartChunk),
#[serde(rename = "response.reasoning_summary_text.added")]
ReasoningSummaryTextAdded(SummaryTextChunk),
#[serde(rename = "response.reasoning_summary_text.delta")]
ReasoningSummaryTextDelta(SummaryTextChunk),
#[serde(rename = "response.reasoning_summary_text.done")]
ReasoningSummaryTextDone(SummaryTextChunk),
}
Expand Down Expand Up @@ -295,7 +295,11 @@ where
})
.collect::<Vec<String>>()
.join("\n");
yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()), signature: None })
yield Ok(streaming::RawStreamingChoice::Reasoning {
id: Some(id.to_string()),
reasoning,
signature: None,
})
}
_ => continue
}
Expand All @@ -304,6 +308,9 @@ where
combined_text.push_str(&delta.delta);
yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
}
ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
}
ItemChunkKind::RefusalDelta(delta) => {
combined_text.push_str(&delta.delta);
yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
Expand Down
3 changes: 1 addition & 2 deletions rig-core/src/providers/openrouter/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,9 @@ where

// Streamed reasoning content
if let Some(reasoning) = &delta.reasoning && !reasoning.is_empty() {
yield Ok(streaming::RawStreamingChoice::Reasoning {
yield Ok(streaming::RawStreamingChoice::ReasoningDelta {
reasoning: reasoning.clone(),
id: None,
signature: None,
});
}

Expand Down
40 changes: 33 additions & 7 deletions rig-core/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,17 @@ where
},
/// A tool call partial/delta
ToolCallDelta { id: String, delta: String },
/// A reasoning chunk
/// A reasoning (in its entirety)
Reasoning {
id: Option<String>,
reasoning: String,
signature: Option<String>,
},
/// A reasoning partial/delta
ReasoningDelta {
id: Option<String>,
reasoning: String,
},

/// The final response object, must be yielded if you want the
/// `response` field to be populated on the `StreamingCompletionResponse`
Expand Down Expand Up @@ -231,15 +236,19 @@ where
id,
reasoning,
signature,
} => {
} => Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
id,
reasoning: vec![reasoning],
signature,
})))),
RawStreamingChoice::ReasoningDelta { id, reasoning } => {
// Forward the streaming tokens to the outer stream
// and concat the text together
stream.reasoning = format!("{}{}", stream.reasoning, reasoning);
Poll::Ready(Some(Ok(StreamedAssistantContent::Reasoning(Reasoning {
Poll::Ready(Some(Ok(StreamedAssistantContent::ReasoningDelta {
id,
reasoning: vec![reasoning],
signature,
}))))
reasoning,
})))
}
RawStreamingChoice::ToolCall {
id,
Expand Down Expand Up @@ -362,6 +371,12 @@ impl<R: Clone + Unpin + GetTokenUsage> Stream for StreamingResultDyn<R> {
reasoning,
signature,
}))),
RawStreamingChoice::ReasoningDelta { id, reasoning } => {
Poll::Ready(Some(Ok(RawStreamingChoice::ReasoningDelta {
id,
reasoning,
})))
}
RawStreamingChoice::ToolCall {
id,
name,
Expand Down Expand Up @@ -514,6 +529,10 @@ mod tests {
print!("{reasoning}");
std::io::Write::flush(&mut std::io::stdout()).unwrap();
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
println!("Reasoning delta: {reasoning}");
chunk_count += 1;
}
Err(e) => {
eprintln!("Error: {e:?}");
break;
Expand Down Expand Up @@ -555,8 +574,15 @@ mod tests {
pub enum StreamedAssistantContent<R> {
Text(Text),
ToolCall(ToolCall),
ToolCallDelta { id: String, delta: String },
ToolCallDelta {
id: String,
delta: String,
},
Reasoning(Reasoning),
ReasoningDelta {
id: Option<String>,
reasoning: String,
},
Final(R),
}

Expand Down