Skip to content
Draft
9 changes: 7 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ Imports:
jsonlite,
later (>= 1.4.0),
lifecycle,
promises (>= 1.3.1),
otel (>= 0.2.0),
promises (>= 1.3.3.9004),
R6,
rlang (>= 1.1.0),
S7 (>= 0.2.0)
Expand All @@ -52,13 +53,16 @@ Suggests:
withr
VignetteBuilder:
knitr
Remotes:
rstudio/promises
Config/Needs/check: r-lib/otelsdk
Config/Needs/website: tidyverse/tidytemplate, rmarkdown
Config/testthat/edition: 3
Config/testthat/parallel: true
Config/testthat/start-first: chat, provider*
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
Collate:
'utils-S7.R'
'types.R'
Expand All @@ -83,6 +87,7 @@ Collate:
'import-standalone-types-check.R'
'interpolate.R'
'live.R'
'otel.R'
'parallel-chat.R'
'params.R'
'provider-anthropic.R'
Expand Down
29 changes: 23 additions & 6 deletions R/chat-tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ on_load({
echo = "none",
on_tool_request = function(request) invisible(),
on_tool_result = function(result) invisible(),
yield_request = FALSE
yield_request = FALSE,
parent_ospan = NULL
) {
tool_requests <- extract_tool_requests(turn)

Expand All @@ -51,7 +52,7 @@ on_load({
next
}

result <- invoke_tool(request)
result <- invoke_tool(request, parent_ospan = parent_ospan)

if (promises::is.promise(result@value)) {
cli::cli_abort(
Expand All @@ -78,7 +79,8 @@ on_load({
echo = "none",
on_tool_request = function(request) invisible(),
on_tool_result = function(result) invisible(),
yield_request = FALSE
yield_request = FALSE,
parent_ospan = NULL
) {
tool_requests <- extract_tool_requests(turn)

Expand All @@ -94,7 +96,7 @@ on_load({
return(rejected)
}

result <- coro::await(invoke_tool_async(request))
result <- coro::await(invoke_tool_async(request, parent_ospan))

maybe_echo_tool(result, echo = echo)
on_tool_result(result)
Expand Down Expand Up @@ -144,7 +146,7 @@ new_tool_result <- function(request, result = NULL, error = NULL) {
}

# Also need to handle edge cases: https://platform.openai.com/docs/guides/function-calling/edge-cases
invoke_tool <- function(request) {
invoke_tool <- function(request, parent_ospan = NULL) {
if (is.null(request@tool)) {
return(new_tool_result(request, error = "Unknown tool"))
}
Expand All @@ -155,19 +157,25 @@ invoke_tool <- function(request) {
return(args)
}

tool_ospan <- start_local_active_tool_ospan(
request,
parent_ospan = parent_ospan
)

tryCatch(
{
result <- do.call(request@tool, args)
new_tool_result(request, result)
},
error = function(e) {
record_tool_ospan_error(tool_ospan, e)
new_tool_result(request, error = e)
}
)
}

on_load(
invoke_tool_async <- coro::async(function(request) {
invoke_tool_async <- coro::async(function(request, parent_ospan = NULL) {
if (is.null(request@tool)) {
return(new_tool_result(request, error = "Unknown tool"))
}
Expand All @@ -178,12 +186,21 @@ on_load(
return(args)
}

tool_ospan <- start_local_active_tool_ospan(
request,
parent_ospan = parent_ospan
)
# Must activate the span in a promise domain so that it propagates to
# async calls made by the tool function.
activate_and_cleanup_ospan(tool_ospan, ospan_promise_domain = TRUE)

tryCatch(
{
result <- await(do.call(request@tool, args))
new_tool_result(request, result)
},
error = function(e) {
record_tool_ospan_error(tool_ospan, e)
new_tool_result(request, error = e)
}
)
Expand Down
95 changes: 72 additions & 23 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,15 @@ Chat <- R6::R6Class(
tool_errors <- list()
withr::defer(warn_tool_errors(tool_errors))

agent_ospan <- local_agent_ospan(private$provider)

while (!is.null(user_turn)) {
assistant_chunks <- private$submit_turns(
user_turn,
stream = stream,
echo = echo,
yield_as_content = yield_as_content
yield_as_content = yield_as_content,
parent_ospan = agent_ospan
)
for (chunk in assistant_chunks) {
yield(chunk)
Expand All @@ -511,7 +514,8 @@ Chat <- R6::R6Class(
echo = echo,
on_tool_request = private$callback_on_tool_request$invoke,
on_tool_result = private$callback_on_tool_result$invoke,
yield_request = yield_as_content
yield_request = yield_as_content,
parent_ospan = agent_ospan
)

tool_results <- list()
Expand Down Expand Up @@ -550,12 +554,15 @@ Chat <- R6::R6Class(
tool_errors <- list()
withr::defer(warn_tool_errors(tool_errors))

agent_ospan <- local_agent_ospan(private$provider)

while (!is.null(user_turn)) {
assistant_chunks <- private$submit_turns_async(
user_turn,
stream = stream,
echo = echo,
yield_as_content = yield_as_content
yield_as_content = yield_as_content,
parent_ospan = agent_ospan
)
for (chunk in await_each(assistant_chunks)) {
yield(chunk)
Expand All @@ -570,11 +577,12 @@ Chat <- R6::R6Class(
echo = echo,
on_tool_request = private$callback_on_tool_request$invoke_async,
on_tool_result = private$callback_on_tool_result$invoke_async,
yield_request = yield_as_content
yield_request = yield_as_content,
parent_ospan = agent_ospan
)
if (tool_mode == "sequential") {
tool_results <- list()
for (tool_step in coro::await_each(tool_calls)) {
for (tool_step in await_each(tool_calls)) {
if (yield_as_content) {
yield(tool_step)
}
Expand All @@ -583,7 +591,9 @@ Chat <- R6::R6Class(
}
}
} else {
# otel::with_active_span(agent_ospan, {
tool_results <- coro::collect(tool_calls)
# })
if (yield_as_content) {
# Filter out and yield tool requests before awaiting tool results
is_request <- map_lgl(tool_results, is_tool_request)
Expand Down Expand Up @@ -620,19 +630,31 @@ Chat <- R6::R6Class(
stream,
echo,
type = NULL,
yield_as_content = FALSE
yield_as_content = FALSE,
parent_ospan = NULL
) {
if (echo == "all") {
cat_line(format(user_turn), prefix = "> ")
}

response <- chat_perform(
provider = private$provider,
mode = if (stream) "stream" else "value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
type = type
chat_ospan <- local_chat_ospan(
private$provider,
parent_ospan = parent_ospan
)

promises::with_ospan_promise_domain({
otel::with_active_span(chat_ospan, {
response <- chat_perform(
provider = private$provider,
mode = if (stream) "stream" else "value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
type = type,
parent_ospan = chat_ospan
)
})
})

emit <- emitter(echo)
any_text <- FALSE

Expand All @@ -652,9 +674,15 @@ Chat <- R6::R6Class(

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_type = !is.null(type))
record_chat_ospan_status(chat_ospan, result)
turn <- value_turn(
private$provider,
result,
has_type = !is.null(type)
)
turn <- match_tools(turn, private$tools)
} else {
record_chat_ospan_status(chat_ospan, response)
turn <- value_turn(
private$provider,
response,
Expand Down Expand Up @@ -705,21 +733,32 @@ Chat <- R6::R6Class(
stream,
echo,
type = NULL,
yield_as_content = FALSE
yield_as_content = FALSE,
parent_ospan = NULL
) {
response <- chat_perform(
provider = private$provider,
mode = if (stream) "async-stream" else "async-value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
type = type
chat_ospan <- local_chat_ospan(
private$provider,
parent_ospan = parent_ospan
)

promises::with_ospan_promise_domain({
otel::with_active_span(chat_ospan, {
response <- chat_perform(
provider = private$provider,
mode = if (stream) "async-stream" else "async-value",
turns = c(private$.turns, list(user_turn)),
tools = if (is.null(type)) private$tools,
type = type,
parent_ospan = chat_ospan
)
})
})
emit <- emitter(echo)
any_text <- FALSE

if (stream) {
result <- NULL
for (chunk in await_each(response)) {
for (chunk in coro::await_each(response)) {
text <- stream_text(private$provider, chunk)
if (!is.null(text)) {
emit(text)
Expand All @@ -733,11 +772,21 @@ Chat <- R6::R6Class(

result <- stream_merge_chunks(private$provider, result, chunk)
}
turn <- value_turn(private$provider, result, has_type = !is.null(type))
record_chat_ospan_status(chat_ospan, result)
turn <- value_turn(
private$provider,
result,
has_type = !is.null(type)
)
} else {
result <- await(response)

turn <- value_turn(private$provider, result, has_type = !is.null(type))
record_chat_ospan_status(chat_ospan, result)
turn <- value_turn(
private$provider,
result,
has_type = !is.null(type)
)
text <- turn@text
if (!is.null(text)) {
emit(text)
Expand Down
30 changes: 25 additions & 5 deletions R/httr2.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ chat_perform <- function(
mode = c("value", "stream", "async-stream", "async-value"),
turns,
tools = NULL,
type = NULL
type = NULL,
parent_ospan = NULL
) {
mode <- arg_match(mode)
stream <- mode %in% c("stream", "async-stream")
Expand All @@ -23,9 +24,13 @@ chat_perform <- function(
switch(
mode,
"value" = chat_perform_value(provider, req),
"stream" = chat_perform_stream(provider, req),
"stream" = chat_perform_stream(provider, req, parent_ospan = parent_ospan),
"async-value" = chat_perform_async_value(provider, req),
"async-stream" = chat_perform_async_stream(provider, req)
"async-stream" = chat_perform_async_stream(
provider,
req,
parent_ospan = parent_ospan
)
)
}

Expand All @@ -34,7 +39,14 @@ chat_perform_value <- function(provider, req) {
}

on_load(
chat_perform_stream <- coro::generator(function(provider, req) {
chat_perform_stream <- coro::generator(function(
provider,
req,
parent_ospan = NULL
) {
if (!is.null(parent_ospan)) {
otel::local_active_span(parent_ospan)
}
resp <- req_perform_connection(req)
on.exit(close(resp))

Expand All @@ -55,7 +67,15 @@ chat_perform_async_value <- function(provider, req) {
}

on_load(
chat_perform_async_stream <- coro::async_generator(function(provider, req) {
chat_perform_async_stream <- coro::async_generator(function(
provider,
req,
parent_ospan = NULL
) {
if (!is.null(parent_ospan)) {
promises::local_ospan_promise_domain()
otel::local_active_span(parent_ospan)
}
resp <- req_perform_connection(req, blocking = FALSE)
on.exit(close(resp))

Expand Down
Loading
Loading