Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions R/jit.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@
#' allowing in-place operations and reducing memory usage.
#' @param device (`NULL` | `character(1)` | [`PJRTDevice`][pjrt::pjrt_device])\cr
#' The device to use if no input tensors are provided to infer the platform.
#' @param async (`logical(1)`)\cr
#' Whether to use asynchronous execution. When `TRUE` (the default),
#' computations are dispatched without blocking until the result is needed
#' (e.g., when calling `as_array()`). This improves performance by allowing
#' multiple operations to be pipelined.
#' @return (`function`)
#' @export
jit <- function(f, static = character(), cache_size = 100L, donate = character(), device = NULL) {
jit <- function(f, static = character(), cache_size = 100L, donate = character(),
device = NULL, async = TRUE) {
cache <- xlamisc::LRUCache$new(cache_size)
assert_subset(static, formalArgs2(f))
assert_string(device, null.ok = TRUE)

call_xla <- function(exec, out_node, consts_flat, args_flat, is_static_flat, avals_out = NULL) {
args_nonstatic <- args_flat[!is_static_flat]
args_unwrapped <- lapply(args_nonstatic, \(a) a$tensor)
args_unwrapped <- lapply(args_nonstatic, get_buffer)
out_vals <- rlang::exec(
pjrt::pjrt_execute,
exec,
Expand All @@ -38,6 +44,24 @@ jit <- function(f, static = character(), cache_size = 100L, donate = character()
unflatten(out_node, out_vals)
}

# Execute a compiled program asynchronously and wrap the outputs as
# AnvilTensors holding PJRTBufferPromise objects, so the caller only
# blocks when a value is actually materialized (e.g. via as_array()).
#
# exec: compiled PJRT executable.
# out_node: tree structure used to unflatten the flat output list.
# consts_flat: flat list of constant tensors captured during tracing.
# args_flat: flat list of caller arguments (static and non-static).
# is_static_flat: logical mask marking the static entries of args_flat.
# avals_out: optional abstract output values carrying per-output
#   `ambiguous` flags; NULL when that metadata is unavailable.
call_xla_async <- function(exec, out_node, consts_flat, args_flat, is_static_flat, avals_out = NULL) {
  # Static arguments were baked into the compiled program; drop them here.
  args_nonstatic <- args_flat[!is_static_flat]
  # get_buffer() pulls out the underlying buffer without blocking.
  args_unwrapped <- lapply(args_nonstatic, get_buffer)
  out_vals <- rlang::exec(
    pjrt::pjrt_execute_async,
    exec,
    !!!consts_flat,
    !!!args_unwrapped,
    simplify = FALSE
  )
  if (!is.null(avals_out)) {
    # Propagate the ambiguity flag recorded for each output during tracing.
    out_vals <- Map(function(val, aval) nv_tensor_from_promise(val, ambiguous = aval$ambiguous), out_vals, avals_out)
  } else {
    out_vals <- lapply(out_vals, nv_tensor_from_promise)
  }
  unflatten(out_node, out_vals)
}

f_jit <- function() {
args <- as.list(match.call())[-1L]
args <- lapply(args, eval, envir = parent.frame())
Expand Down Expand Up @@ -82,7 +106,8 @@ jit <- function(f, static = character(), cache_size = 100L, donate = character()
class(in_tree) <- c("ListNode", "Node")
cache_hit <- cache$get(list(in_tree, avals_in, platform))
if (!is.null(cache_hit)) {
return(call_xla(cache_hit[[1]], cache_hit[[2]], cache_hit[[3]], args_flat, is_static_flat, cache_hit[[4]]))
call_fn <- if (async) call_xla_async else call_xla
return(call_fn(cache_hit[[1]], cache_hit[[2]], cache_hit[[3]], args_flat, is_static_flat, cache_hit[[4]]))
}
desc <- local_descriptor()
graph <- trace_fn(f, desc = desc, toplevel = TRUE, args_flat = avals_in, in_tree = in_tree)
Expand Down Expand Up @@ -113,7 +138,8 @@ jit <- function(f, static = character(), cache_size = 100L, donate = character()
program <- pjrt_program(src = src, format = "mlir")
exec <- pjrt_compile(program, client = pjrt::pjrt_client(platform))
cache$set(list(in_tree, avals_in, platform), list(exec, out_tree, const_tensors, avals_out))
call_xla(exec, out_tree, const_tensors, args_flat, is_static_flat, avals_out)
call_fn <- if (async) call_xla_async else call_xla
call_fn(exec, out_tree, const_tensors, args_flat, is_static_flat, avals_out)
}
formals(f_jit) <- formals2(f)
f_jit
Expand Down
4 changes: 2 additions & 2 deletions R/stablehlo.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

#' @export
hlo_scalar.AnvilTensor <- function(value, ..., func = NULL) {
stablehlo::hlo_scalar(value$tensor, ..., func = func)
stablehlo::hlo_scalar(await_tensor(value), ..., func = func)
}

#' @export
hlo_tensor.AnvilTensor <- function(value, ..., func = NULL) {
stablehlo::hlo_tensor(value$tensor, ..., func = func)
stablehlo::hlo_tensor(await_tensor(value), ..., func = func)
}

#' @title HloEnv
Expand Down
47 changes: 38 additions & 9 deletions R/tensor.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,33 @@ is_anvil_tensor <- function(x) {
inherits(x, "AnvilTensor")
}

# Pull the underlying PJRT buffer out of an AnvilTensor without blocking.
# Per the comment in the original: a PJRTBufferPromise already carries a
# valid buffer handle, so no synchronization is required here.
get_buffer <- function(x) {
  buf <- x$tensor
  if (!inherits(buf, "PJRTBufferPromise")) {
    return(buf)
  }
  buf$buffer
}

# Block until the tensor's computation has completed, then return the
# resolved PJRT buffer. A tensor that is not a promise is already
# concrete and is returned as-is.
await_tensor <- function(x) {
  buf <- x$tensor
  if (inherits(buf, "PJRTBufferPromise")) {
    return(pjrt::value(buf))
  }
  buf
}

#' Get the underlying PJRT buffer from an AnvilTensor or pass through other values
#' @param x An AnvilTensor or any other value
#' @return The underlying PJRT buffer if x is an AnvilTensor, otherwise x unchanged
#' @keywords internal
unwrap_if_tensor <- function(x) {
if (is_anvil_tensor(x)) {
x$tensor
await_tensor(x)
} else {
x
}
Expand All @@ -81,13 +101,22 @@ ensure_nv_tensor <- function(x, ambiguous = FALSE) {
}
return(x)
}
assert_class(x, "PJRTBuffer")
if (!inherits(x, "PJRTBuffer") && !inherits(x, "PJRTBufferPromise")) {
cli_abort("x must be a PJRTBuffer or PJRTBufferPromise, not {.cls {class(x)}}")
}
structure(
list(tensor = x, ambiguous = ambiguous),
class = "AnvilTensor"
)
}

# Wrap an asynchronous execution result (a PJRTBufferPromise) in an
# AnvilTensor. Unlike ensure_nv_tensor(), no class validation is done,
# so results can be wrapped before the computation has completed.
nv_tensor_from_promise <- function(promise, ambiguous = FALSE) {
  out <- list(tensor = promise, ambiguous = ambiguous)
  class(out) <- "AnvilTensor"
  out
}

#' @rdname AnvilTensor
#' @export
nv_scalar <- function(data, dtype = NULL, device = NULL, ambiguous = NULL) {
Expand Down Expand Up @@ -120,7 +149,7 @@ nv_aten <- function(dtype, shape, ambiguous = FALSE) {

#' @export
dtype.AnvilTensor <- function(x, ...) {
as_dtype(as.character(pjrt::elt_type(x$tensor)))
as_dtype(as.character(pjrt::elt_type(get_buffer(x))))
}

#' @title Get Ambiguity of a Tensor
Expand All @@ -146,28 +175,28 @@ ambiguous.AbstractTensor <- function(x, ...) {

#' @export
shape.AnvilTensor <- function(x, ...) {
tengen::shape(x$tensor)
tengen::shape(get_buffer(x))
}

#' @export
as_array.AnvilTensor <- function(x, ...) {
tengen::as_array(x$tensor)
tengen::as_array(await_tensor(x))
}

#' @export
as_raw.AnvilTensor <- function(x, ...) {
tengen::as_raw(x$tensor)
tengen::as_raw(await_tensor(x))
}

#' @method ndims AnvilTensor
#' @export
ndims.AnvilTensor <- function(x, ...) {
tengen::ndims(x$tensor)
tengen::ndims(get_buffer(x))
}

#' @export
platform.AnvilTensor <- function(x, ...) {
pjrt::platform(x$tensor)
pjrt::platform(get_buffer(x))
}

#' @title Abstract Tensor Class
Expand Down Expand Up @@ -448,7 +477,7 @@ print.AnvilTensor <- function(x, header = TRUE, ...) {
dtype_str <- paste0(as.character(dtype(x)), if (x$ambiguous) "?")
footer <- sprintf("[ %s%s{%s} ]", toupper(platform(x)), dtype_str, paste0(shape(x), collapse = ","))

print(x$tensor, header = FALSE, footer = footer)
print(await_tensor(x), header = FALSE, footer = footer)
invisible(x)
}

Expand Down
168 changes: 168 additions & 0 deletions bench.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Benchmark: Async execution vs nv_while for logistic regression
#
# This benchmark compares three approaches:
# 1. nv_while: Single compiled function with XLA while loop
# 2. R loop + async=TRUE: R loop calling async jit function
# 3. R loop + async=FALSE: R loop calling sync jit function

devtools::load_all(".")

# --- Data Setup ---
set.seed(42)

# Create synthetic data for logistic regression
n <- 1000L
p <- 10L

# Design matrix and ground-truth parameters of the data-generating model.
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
beta_true <- rnorm(p)
alpha_true <- 0.5
# Bernoulli labels drawn from the true logistic model.
probs <- 1 / (1 + exp(-(X %*% beta_true + alpha_true)))
y <- rbinom(n, 1, probs)

# Device-resident copies shared by all three approaches.
X_tensor <- nv_tensor(X, dtype = "f32")
y_tensor <- nv_tensor(y, dtype = "f32", shape = c(n, 1L))

# --- Model Functions ---

# Predicted success probability of the logistic regression model:
# sigmoid(X %*% beta + alpha).
predict_proba <- function(X, beta, alpha) {
  nv_logistic(X %*% beta + alpha)
}

# Mean binary cross-entropy between labels and predicted probabilities.
# Predictions are clamped away from 0 and 1 to keep the logs finite.
binary_cross_entropy <- function(y_true, y_pred) {
  eps <- 1e-7
  p_safe <- nv_clamp(eps, y_pred, 1 - eps)
  term_pos <- y_true * log(p_safe)
  term_neg <- (1 - y_true) * log(1 - p_safe)
  mean(-(term_pos + term_neg))
}

# Full model loss: binary cross-entropy of the predicted probabilities.
model_loss <- function(X, y, beta, alpha) {
  binary_cross_entropy(y, predict_proba(X, beta, alpha))
}

# Gradient of the loss with respect to the model parameters only
# (X and y are treated as constants).
model_loss_grad <- gradient(model_loss, wrt = c("beta", "alpha"))

# --- Approach 1: nv_while (single compiled function) ---
# The entire training loop is compiled once: nv_while expresses the loop
# inside the compiled program, so all iterations run in a single
# executable call. n_epochs is therefore passed in as a tensor.
fit_logreg_while <- jit(function(X, y, beta, alpha, n_epochs, lr) {
  output <- nv_while(
    # Loop state: parameters plus the iteration counter.
    list(beta = beta, alpha = alpha, epoch = nv_scalar(0L)),
    # Condition: keep iterating while epoch < n_epochs.
    \(beta, alpha, epoch) epoch < n_epochs,
    # Body: one gradient-descent step on beta and alpha.
    \(beta, alpha, epoch) {
      grads <- model_loss_grad(X, y, beta, alpha)
      list(
        beta = beta - lr * grads$beta,
        alpha = alpha - lr * grads$alpha,
        epoch = epoch + 1L
      )
    }
  )
  # Drop the loop counter from the returned state.
  list(beta = output$beta, alpha = output$alpha)
})

# --- Approach 2 & 3: R loop with jit step function ---
# Build a trainer that runs one jit-compiled gradient step per R-level
# iteration. `async` selects asynchronous vs synchronous dispatch of the
# compiled step; the beta/alpha buffers are donated so each step may
# reuse their storage.
make_fit_logreg_rloop <- function(async) {
  gd_step <- jit(
    function(X, y, beta, alpha, lr) {
      grads <- model_loss_grad(X, y, beta, alpha)
      list(
        beta = beta - lr * grads$beta,
        alpha = alpha - lr * grads$alpha
      )
    },
    async = async,
    donate = c("beta", "alpha")
  )

  function(X, y, beta, alpha, n_epochs, lr) {
    for (epoch in seq_len(n_epochs)) {
      updated <- gd_step(X, y, beta, alpha, lr)
      beta <- updated$beta
      alpha <- updated$alpha
    }
    list(beta = beta, alpha = alpha)
  }
}

# Instantiate the async and sync variants of the R-loop trainer.
fit_logreg_rloop_async <- make_fit_logreg_rloop(async = TRUE)
fit_logreg_rloop_sync <- make_fit_logreg_rloop(async = FALSE)

# --- Benchmark ---

# Time one fitting approach `n_reps` times. `run_fit` receives fresh
# initial parameters each repetition (buffer donation invalidates them)
# and must return the fitted parameter list; as_array() forces
# evaluation so async dispatch is fully accounted for in the wall-clock
# time. Returns the vector of per-repetition timings in seconds.
time_approach <- function(run_fit, make_init, n_reps) {
  times <- numeric(n_reps)
  for (i in seq_len(n_reps)) {
    init <- make_init()
    t0 <- Sys.time()
    result <- run_fit(init)
    # Force evaluation
    invisible(as_array(result$beta))
    times[i] <- as.numeric(Sys.time() - t0, units = "secs")
  }
  times
}

# Run all three approaches for `n_epochs` epochs, `n_reps` times each,
# print mean/sd timings and speedup ratios, and invisibly return the raw
# timing vectors.
run_benchmark <- function(n_epochs, n_reps = 5) {
  # Helper to create fresh initial parameters (needed because donate invalidates buffers)
  make_init <- function() {
    list(
      beta = nv_tensor(rnorm(p), dtype = "f32", shape = c(p, 1L)),
      alpha = nv_scalar(0, dtype = "f32"),
      lr = nv_scalar(0.1),
      n_epochs_tensor = nv_scalar(as.integer(n_epochs))
    )
  }

  # One closure per approach; keeps warmup and timing loops identical.
  # Note: nv_while takes n_epochs as a tensor, the R loops as a plain int.
  runners <- list(
    nv_while = function(init) fit_logreg_while(X_tensor, y_tensor, init$beta, init$alpha, init$n_epochs_tensor, init$lr),
    async = function(init) fit_logreg_rloop_async(X_tensor, y_tensor, init$beta, init$alpha, n_epochs, init$lr),
    sync = function(init) fit_logreg_rloop_sync(X_tensor, y_tensor, init$beta, init$alpha, n_epochs, init$lr)
  )

  # Warmup runs (trigger tracing/compilation so it is excluded from timing)
  cat("Warming up...\n")
  for (run_fit in runners) {
    init <- make_init()
    invisible(run_fit(init))
  }

  cat(sprintf("\nBenchmarking with %d epochs, %d repetitions...\n\n", n_epochs, n_reps))

  times_while <- time_approach(runners$nv_while, make_init, n_reps)
  times_async <- time_approach(runners$async, make_init, n_reps)
  times_sync <- time_approach(runners$sync, make_init, n_reps)

  # Results
  cat("Results (seconds):\n")
  cat(sprintf(" nv_while: %.4f (sd: %.4f)\n", mean(times_while), sd(times_while)))
  cat(sprintf(" R loop + async: %.4f (sd: %.4f)\n", mean(times_async), sd(times_async)))
  cat(sprintf(" R loop + sync: %.4f (sd: %.4f)\n", mean(times_sync), sd(times_sync)))

  cat("\nSpeedup ratios:\n")
  cat(sprintf(" nv_while vs R loop + sync: %.2fx\n", mean(times_sync) / mean(times_while)))
  cat(sprintf(" nv_while vs R loop + async: %.2fx\n", mean(times_async) / mean(times_while)))
  cat(sprintf(" async vs sync: %.2fx\n", mean(times_sync) / mean(times_async)))

  invisible(list(
    nv_while = times_while,
    async = times_async,
    sync = times_sync
  ))
}

# Run benchmarks with different epoch counts
# More epochs means more R-level dispatch overhead for the loop-based
# approaches, so the gap to nv_while should grow with the epoch count.
cat("=== Small iteration count (100 epochs) ===\n")
run_benchmark(n_epochs = 100)

cat("\n=== Medium iteration count (500 epochs) ===\n")
run_benchmark(n_epochs = 500)

cat("\n=== Large iteration count (1000 epochs) ===\n")
run_benchmark(n_epochs = 1000)
Loading
Loading