From 9ab05195561f651884f752942f43fc7317aabbc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 4 Jun 2025 08:10:51 +0000 Subject: [PATCH 01/33] Lower autodiff functions using instrinsics --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 3 +++ compiler/rustc_hir_analysis/src/check/intrinsic.rs | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 7b27e496986ae..ba0f4b32ae7eb 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -189,6 +189,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } + _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + return Err(ty::Instance::new_raw(def_id, instance.args)); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 6e5fe3823ab51..2e9ee43896225 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -171,6 +171,8 @@ pub(crate) fn check_intrinsic_type( } }; + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -198,6 +200,17 @@ pub(crate) fn check_intrinsic_type( let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { + _ if has_autodiff => { + let sig = tcx.fn_sig(intrinsic_id.to_def_id()); + let sig = sig.skip_binder(); + let n_tps = generics.own_counts().types; + let n_cts = generics.own_counts().consts; + + let inputs = sig.skip_binder().inputs().to_vec(); + let output = sig.skip_binder().output(); + + (n_tps, n_cts, inputs, output) + } sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), From 4bc9e4f8770012fa3b755346e0e7a13484f5b9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 5 Jun 2025 17:46:21 +0000 Subject: [PATCH 02/33] Macro expansion with `rustc_intrinsic` WARNING: ad function defined in traits are broken --- compiler/rustc_builtin_macros/src/autodiff.rs | 20 ++- tests/pretty/autodiff/autodiff_forward.pp | 136 +++++------------- tests/pretty/autodiff/autodiff_forward.rs | 1 + tests/pretty/autodiff/autodiff_reverse.pp | 43 ++---- tests/pretty/autodiff/autodiff_reverse.rs | 5 +- tests/pretty/autodiff/inherent_impl.pp | 3 +- tests/pretty/autodiff/inherent_impl.rs | 1 + 7 files changed, 67 insertions(+), 142 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index c784477833279..374a2f9d47ee7 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -330,7 +330,9 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // UNUSED + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); @@ -342,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: Some(d_body), + body: None, define_opaque: None, }); let mut rustc_ad_attr = @@ -429,12 +431,18 @@ mod llvm_enzyme { tokens: ts, }); + let rustc_intrinsic_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr, intrinsic_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -444,13 +452,15 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index a2525abc83207..787c2e517492c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -36,78 +37,44 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[inline(never)] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f64, f64)>::default()) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f2(x, y)) -} +#[rustc_intrinsic] +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[inline(never)] -pub fn df4() -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4() -> (); #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[inline(never)] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((by_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -115,84 +82,47 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[inline(never)] -pub fn df6() -> DoesNotImplDefault { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f6()); - ::core::hint::black_box(()); - ::core::hint::black_box(f6()) -} +#[rustc_intrinsic] +pub fn df6() -> DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[inline(never)] -pub fn df7(x: f32) -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f7(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df7(x: f32) -> (); #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[inline(never)] +#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 5usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 5usize]>::default()) -} +-> [f32; 5usize]; #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[inline(never)] +#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 4usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 4usize]>::default()) -} +-> [f32; 4usize]; #[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[inline(never)] -fn f8_1(x: &f32, bx_0: &f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) -} +#[rustc_intrinsic] +fn f8_1(x: &f32, bx_0: &f32) -> f32; pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[inline(never)] - fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f32, f32)>::default()) - } + #[rustc_intrinsic] + fn d_inner_2(x: f32, bx_0: f32) + -> (f32, f32); #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[inline(never)] - fn d_inner_1(x: f32, bx_0: f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) - } + #[rustc_intrinsic] + fn d_inner_1(x: f32, bx_0: f32) + -> f32; } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[inline(never)] +#[rustc_intrinsic] pub fn d_square + - Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10::(x)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10::(x)) -} +Copy>(x: &T, dx_0: &mut T, dret: T) -> T; fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e23a1b3e241e9..b003d87dccfa7 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index e67c3443ddef1..6f368c74f1a26 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -29,58 +30,36 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f1(x, y)) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[inline(never)] -pub fn df2() { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df2(); #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[inline(never)] -pub fn df4(x: f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4(x: f32); #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[inline(never)] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dy_0)); -} +#[rustc_intrinsic] +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index d37e5e3eb4cec..fc95ba2e5a63e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp @@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 { unimplemented!() } -enum Foo { Reverse } +enum Foo { + Reverse, +} use Foo::Reverse; // What happens if we already have Reverse in type (enum variant decl) and value (enum variant // constructor) namespace? > It's expected to work normally. diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index d18061b2dbdef..4bc8dac0dc758 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -31,7 +32,7 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[inline(never)] + #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { unsafe { asm!("NOP", options(pure, nomem)); }; ::core::hint::black_box(self.f(x)); diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 11ff209f9d89e..9f00ff5eb02c1 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp From 6eb931c88687f80468621066dd340f71e5ef8af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 17 Jun 2025 19:52:00 +0000 Subject: [PATCH 03/33] Lowering draft --- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 374a2f9d47ee7..9df0497f8e060 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -344,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: None, + body: None, // This leads to an error when the ad function is inside a traits define_opaque: None, }); let mut rustc_ad_attr = diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index ba0f4b32ae7eb..2ead51d7b9b40 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -190,6 +190,21 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + // NOTE(Sa4dUs): This is a hacky way to get the autodiff items + // so we can focus on the lowering of the intrinsic call + + // `diff_items` is empty even when autodiff is enabled, and if we're here, + // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr + let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + + // this shouldn't happen? + if diff_items.is_empty() { + bug!("no autodiff items found for {def_id:?}"); + } + + // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + + // Just gen the fallback body for now return Err(ty::Instance::new_raw(def_id, instance.args)); } sym::is_val_statically_known => { From 4a3203e084734fce25e8dfefcb540c44f4f25521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 23 Jun 2025 12:17:53 +0000 Subject: [PATCH 04/33] Naive impl of intrinsic codegen Note(Sa4dUs): Most tests are still broken due to `sret` and how funcs are searched in the current logic --- .../src/builder/autodiff.rs | 62 ++++++------------- compiler/rustc_codegen_llvm/src/intrinsic.rs | 62 +++++++++++++++---- tests/codegen-llvm/autodiff/scalar.rs | 1 + 3 files changed, 70 insertions(+), 55 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 829b3c513c258..25bbc350d31c6 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,13 +3,13 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -18,7 +18,7 @@ use crate::llvm::{Metadata, True}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; -fn get_params(fnc: &Value) -> Vec<&Value> { +fn _get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; let mut fnc_args: Vec<&Value> = vec![]; fnc_args.reserve(param_num); @@ -48,9 +48,9 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], @@ -288,11 +288,14 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_fn: &'ll Value, + fn_args: &[OperandRef<'tcx, &'ll Value>], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. let mut ad_name: String = match attrs.mode { @@ -365,14 +368,6 @@ fn generate_enzyme_call<'ll>( let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); @@ -388,10 +383,10 @@ fn generate_enzyme_call<'ll>( } let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); match_args_from_caller_to_enzyme( &cx, - &builder, + builder, attrs.width, &mut args, &attrs.input_activity, @@ -399,29 +394,9 @@ fn generate_enzyme_call<'ll>( has_sret, ); - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. - // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. - let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); + builder.store_to_place(call, dest.val); if cx.val_ty(call) == cx.type_void() || has_sret { if has_sret { @@ -444,10 +419,10 @@ fn generate_enzyme_call<'ll>( llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); - } else { - builder.ret(call); } + builder.store_to_place(call, dest.val); + // Let's crash in case that we messed something up above and generated invalid IR. llvm::LLVMRustVerifyFunction( outer_fn, @@ -461,6 +436,7 @@ pub(crate) fn differentiate<'ll>( cgcx: &CodegenContext, diff_items: Vec, ) -> Result<(), FatalError> { + // TODO(Sa4dUs): delete all this logic for item in &diff_items { trace!("{}", item); } @@ -480,7 +456,7 @@ pub(crate) fn differentiate<'ll>( for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { + let Some(_fn_def) = fn_def else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -492,7 +468,7 @@ pub(crate) fn differentiate<'ll>( }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { + let Some(_fn_target) = fn_target else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -503,7 +479,7 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 2ead51d7b9b40..affd0074ea256 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,17 +9,19 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -189,23 +191,59 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } - _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + _ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => { // NOTE(Sa4dUs): This is a hacky way to get the autodiff items // so we can focus on the lowering of the intrinsic call + let mut source_id = None; + let mut diff_attrs = None; + let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect(); + + // Hacky way of getting primal-diff pair, only works for code with 1 autodiff call + for target_id in &items { + let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else { + continue; + }; - // `diff_items` is empty even when autodiff is enabled, and if we're here, - // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr - let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + if target_attrs.is_source() { + source_id = Some(*target_id); + } else { + diff_attrs = Some(target_attrs); + } + } - // this shouldn't happen? - if diff_items.is_empty() { - bug!("no autodiff items found for {def_id:?}"); + if source_id.is_none() || diff_attrs.is_none() { + bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}"); } - // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + let diff_attrs = diff_attrs.unwrap().clone(); + + // Get source fn + let source_id = source_id.unwrap(); + let fn_source = Instance::mono(tcx, source_id); + let source_symbol = + symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + // Declare target fn + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty()); + let outer_fn: &'ll Value = + self.cx.declare_fn(&target_symbol, fn_abi, Some(instance)); + + // Build body + generate_enzyme_call( + self, + self.cx, + fn_to_diff, + outer_fn, + args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore + diff_attrs.clone(), + result, + ); - // Just gen the fallback body for now - return Err(ty::Instance::new_raw(def_id, instance.args)); + return Ok(()); } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index 096b4209e84ad..c2bca7e9c81ef 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; From 63a0cfdb317a70003c449fd7a1d61b6e7bd12bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 24 Jun 2025 14:00:26 +0000 Subject: [PATCH 05/33] Feature intrinsics in cg tests --- tests/codegen-llvm/autodiff/batched.rs | 1 + tests/codegen-llvm/autodiff/generic.rs | 1 + tests/codegen-llvm/autodiff/identical_fnc.rs | 1 + tests/codegen-llvm/autodiff/inline.rs | 1 + tests/codegen-llvm/autodiff/sret.rs | 28 ++++++++++---------- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index d27aed50e6cc4..88a1de9994c8a 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -10,6 +10,7 @@ // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_forward; diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index 2f674079be021..af9706c621208 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index 1c25b3d09ab0d..ff8e6c74a6b34 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -10,6 +10,7 @@ // We also explicetly test that we keep running merge_function after AD, by checking for two // identical function calls in the LLVM-IR, while having two different calls in the Rust code. #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs index 65bed170207cc..5db69b960343c 100644 --- a/tests/codegen-llvm/autodiff/inline.rs +++ b/tests/codegen-llvm/autodiff/inline.rs @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index d2fa85e3e3787..67f68fc053cc4 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -8,6 +8,7 @@ // We therefore use this test to verify some of our sret handling. #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; @@ -17,26 +18,25 @@ fn primal(x: f32, y: f32) -> f64 { (x * x * y) as f64 } -// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) -// CHECK-NEXT:start: -// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) -// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 -// CHECK-NEXT: store double %.elt, ptr %_0, align 8 -// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 -// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 -// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 -// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 -// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 -// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 -// CHECK-NEXT: ret void -// CHECK-NEXT:} +// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y) +// CHECK-NEXT: invertstart: +// CHECK-NEXT: %_4 = fmul float %x, %x +// CHECK-NEXT: %_3 = fmul float %_4, %y +// CHECK-NEXT: %_0 = fpext float %_3 to double +// CHECK-NEXT: %0 = fadd fast float %y, %y +// CHECK-NEXT: %1 = fmul fast float %0, %x +// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0 +// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1 +// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2 +// CHECK-NEXT: ret { double, float, float } %4 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); let y = std::hint::black_box(2.5); let scalar = std::hint::black_box(1.0); let (r1, r2, r3) = df(x, y, scalar); - // 3*3*1.5 = 22.5 + // 3*3*2.5 = 22.5 assert_eq!(r1, 22.5); // 2*x*y = 2*3*2.5 = 15.0 assert_eq!(r2, 15.0); From 02395d7c9016bca75534367d6789f7930d03ed96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 29 Jun 2025 18:13:11 +0000 Subject: [PATCH 06/33] Remove `sret` logic --- .../src/builder/autodiff.rs | 252 ++++-------------- compiler/rustc_codegen_llvm/src/context.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 21 +- .../rustc_hir_analysis/src/check/intrinsic.rs | 7 +- 4 files changed, 74 insertions(+), 208 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 25bbc350d31c6..c72a567a30136 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -14,7 +14,7 @@ use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, Type}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; @@ -29,7 +29,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> { fnc_args } -fn has_sret(fnc: &Value) -> bool { +fn _has_sret(fnc: &Value) -> bool { let num_args = llvm::LLVMCountParams(fnc) as usize; if num_args == 0 { false @@ -55,7 +55,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], outer_args: &[&'ll llvm::Value], - has_sret: bool, ) { debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level @@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - if has_sret { - // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the - // inner function will still return something. We increase our outer_pos by one, - // and once we're done with all other args we will take the return of the inner call and - // update the sret pointer with it - outer_pos = 1; - } - - let enzyme_const = cx.create_metadata(b"enzyme_const"); - let enzyme_out = cx.create_metadata(b"enzyme_out"); - let enzyme_dup = cx.create_metadata(b"enzyme_dup"); - let enzyme_dupv = cx.create_metadata(b"enzyme_dupv"); - let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed"); - let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv"); + let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); + let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); + let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); + let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); + let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( } } -// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input -// arguments. We do however need to declare them with their correct return type. -// We already figured the correct return type out in our frontend, when generating the outer_fn, -// so we can now just go ahead and use that. This is not always trivial, e.g. because sret. -// Beyond sret, this article describes our challenges nicely: -// -// I.e. (i32, f32) will get merged into i64, but we don't handle that yet. -fn compute_enzyme_fn_ty<'ll>( - cx: &SimpleCx<'ll>, - attrs: &AutoDiffAttrs, - fn_to_diff: &'ll Value, - outer_fn: &'ll Value, -) -> &'ll llvm::Type { - let fn_ty = cx.get_type_of_global(outer_fn); - let mut ret_ty = cx.get_return_type(fn_ty); - - let has_sret = has_sret(outer_fn); - - if has_sret { - // Now we don't just forward the return type, so we have to figure it out based on the - // primal return type, in combination with the autodiff settings. - let fn_ty = cx.get_type_of_global(fn_to_diff); - let inner_ret_ty = cx.get_return_type(fn_ty); - - let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) }; - if inner_ret_ty == void_ty { - // This indicates that even the inner function has an sret. - // Right now I only look for an sret in the outer function. - // This *probably* needs some extra handling, but I never ran - // into such a case. So I'll wait for user reports to have a test case. - bug!("sret in inner function"); - } - - if attrs.width == 1 { - // Enzyme returns a struct of style: - // `{ original_ret(if requested), float, float, ... }` - let mut struct_elements = vec![]; - if attrs.has_primal_ret() { - struct_elements.push(inner_ret_ty); - } - // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, - // and therefore part of the return struct. - let param_tys = cx.func_params_types(fn_ty); - for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { - if matches!(act, DiffActivity::Active) { - // Now find the float type at position i based on the fn_ty, - // to know what (f16/f32/f64/...) to add to the struct. - struct_elements.push(param_ty); - } - } - ret_ty = cx.type_struct(&struct_elements, false); - } else { - // First we check if we also have to deal with the primal return. - match attrs.mode { - DiffMode::Forward => match attrs.ret_activity { - DiffActivity::Dual => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) }; - ret_ty = arr_ty; - } - DiffActivity::DualOnly => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) }; - ret_ty = arr_ty; - } - DiffActivity::Const => { - todo!("Not sure, do we need to do something here?"); - } - _ => { - bug!("unreachable"); - } - }, - DiffMode::Reverse => { - todo!("Handle sret for reverse mode"); - } - _ => { - bug!("unreachable"); - } - } - } - } - - // LLVM can figure out the input types on it's own, so we take a shortcut here. - unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) } -} - /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -292,7 +197,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, - outer_fn: &'ll Value, + outer_name: &str, + ret_ty: &'ll Type, fn_args: &[OperandRef<'tcx, &'ll Value>], attrs: AutoDiffAttrs, dest: PlaceRef<'tcx, &'ll Value>, @@ -305,11 +211,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( } .to_string(); - // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. - let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::str::from_utf8(&name).unwrap(); - ad_name.push_str(outer_fn_name); + ad_name.push_str(outer_name); // Let us assume the user wrote the following function square: // @@ -320,13 +224,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // ret double %0 // } // ``` - // - // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`. - // Our macro generates the following placeholder code (slightly simplified): - // - // ```llvm // define double @dsquare(double %x) { - // ; placeholder code // return 0.0; // } // ``` @@ -343,92 +241,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // ret double %0 // } // ``` - unsafe { - let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn); - - // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and - // think a bit more about what should go here. - let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_simple_fn( - cx, - &ad_name, - llvm::CallConv::try_from(cc).expect("invalid callconv"), - llvm::UnnamedAddr::No, - llvm::Visibility::Default, - enzyme_ty, - ); - - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - - // We add a made-up attribute just such that we can recognize it after AD to update - // (no)-inline attributes. We'll then also remove this attribute. - let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); - attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - - let num_args = llvm::LLVMCountParams(&fn_to_diff); - let mut args = Vec::with_capacity(num_args as usize + 1); - args.push(fn_to_diff); - - let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return"); - if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { - args.push(cx.get_metadata_value(enzyme_primal_ret)); - } - if attrs.width > 1 { - let enzyme_width = cx.create_metadata(b"enzyme_width"); - args.push(cx.get_metadata_value(enzyme_width)); - args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); - } - - let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - match_args_from_caller_to_enzyme( - &cx, - builder, - attrs.width, - &mut args, - &attrs.input_activity, - &outer_args, - has_sret, - ); - - let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - - builder.store_to_place(call, dest.val); + let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }; + + // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and + // think a bit more about what should go here. + // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now + let cc = 8; + let ad_fn = declare_simple_fn( + cx, + &ad_name, + llvm::CallConv::try_from(cc).expect("invalid callconv"), + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + enzyme_ty, + ); + + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. + let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { + args.push(cx.get_metadata_value(enzyme_primal_ret)); + } + if attrs.width > 1 { + let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + args.push(cx.get_metadata_value(enzyme_width)); + args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); + } - if cx.val_ty(call) == cx.type_void() || has_sret { - if has_sret { - // This is what we already have in our outer_fn (shortened): - // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) { - // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>) - // - // store [4 x double] %7, ptr %0, align 8 - // ret void - // } + let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - // now store the result of the enzyme call into the sret pointer. - let sret_ptr = outer_args[0]; - let call_ty = cx.val_ty(call); - if attrs.width == 1 { - assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); - } else { - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); - } - llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); - } - builder.ret_void(); - } + match_args_from_caller_to_enzyme( + &cx, + builder, + attrs.width, + &mut args, + &attrs.input_activity, + &outer_args, + ); - builder.store_to_place(call, dest.val); + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Let's crash in case that we messed something up above and generated invalid IR. - llvm::LLVMRustVerifyFunction( - outer_fn, - llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, - ); - } + builder.store_to_place(call, dest.val); } pub(crate) fn differentiate<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index ee77774c68832..25e72ecc1cbe6 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -654,7 +654,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type { + pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { assert_eq!(self.type_kind(ty), TypeKind::Function); unsafe { llvm::LLVMGetReturnType(ty) } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index affd0074ea256..e436df62d3c80 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -176,10 +176,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; + let callee_ty = instance.ty(tcx, self.typing_env()); - let name = tcx.item_name(instance.def_id()); let fn_args = instance.args; + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); + let ret_ty = sig.output(); + let name = tcx.item_name(instance.def_id()); + + let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let simple = call_simple_intrinsic(self, name, args); let llval = match name { _ if simple.is_some() => simple.unwrap(), @@ -225,20 +232,14 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - // Declare target fn - let target_symbol = - symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty()); - let outer_fn: &'ll Value = - self.cx.declare_fn(&target_symbol, fn_abi, Some(instance)); - // Build body generate_enzyme_call( self, self.cx, fn_to_diff, - outer_fn, - args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore + name.as_str(), + llret_ty, + args, diff_attrs.clone(), result, ); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 2e9ee43896225..5c03b2ce33e64 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -197,7 +197,12 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); + // FIXME(Sa4dUs): Get the actual safety level of the diff function + let safety = if has_autodiff { + hir::Safety::Safe + } else { + intrinsic_operation_unsafety(tcx, intrinsic_id) + }; let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { _ if has_autodiff => { From 2fc199d049fcdde8f225f03842b9ac8cfe8217f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 7 Jul 2025 17:24:17 +0000 Subject: [PATCH 07/33] Move logic to a dedicated `enzyme_autodiff` intrinsic --- compiler/rustc_builtin_macros/src/autodiff.rs | 143 ++++++++++++++++-- .../src/builder/autodiff.rs | 8 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 67 ++++---- .../rustc_hir_analysis/src/check/intrinsic.rs | 2 + compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/mod.rs | 4 + 6 files changed, 180 insertions(+), 45 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 9df0497f8e060..c00e659bb311a 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -331,20 +331,23 @@ mod llvm_enzyme { .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - // UNUSED + // TODO(Sa4dUs): Remove this and all the related logic let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); + let d_body = + call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics, + generics: generics.clone(), contract: None, - body: None, // This leads to an error when the ad function is inside a traits + body: Some(d_body), define_opaque: None, }); let mut rustc_ad_attr = @@ -431,10 +434,7 @@ mod llvm_enzyme { tokens: ts, }); - let rustc_intrinsic_attr = - P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + let vis_clone = vis.clone(); let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); @@ -442,7 +442,7 @@ mod llvm_enzyme { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, intrinsic_attr], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -452,15 +452,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = - ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = - ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { @@ -474,7 +472,9 @@ mod llvm_enzyme { } }; - return vec![orig_annotatable, d_annotatable]; + let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); + + return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -495,6 +495,123 @@ mod llvm_enzyme { ty } + // Generate `enzyme_autodiff` intrinsic call + // ``` + // std::intrinsics::enzyme_autodiff(source, diff, (args)) + // ``` + fn call_enzyme_autodiff( + ecx: &ExtCtxt<'_>, + primal: Ident, + diff: Ident, + span: Span, + d_sig: &FnSig, + ) -> P { + let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + + let tuple_expr = ecx.expr_tuple( + span, + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::>() + .into(), + ); + + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::from_str("enzyme_autodiff"), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + ); + + let block = ecx.block_expr(call_expr); + + block + } + + // Generate dummy const to prevent primal function + // from being optimized away before applying enzyme + // ``` + // const _: () = + // { + // #[used] + // pub static DUMMY_PTR: fn_type = primal_fn; + // }; + // ``` + fn gen_dummy_const( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: FnSig, + generics: Generics, + vis: Visibility, + ) -> Annotatable { + // #[used] + let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let used_attr = outer_normal_attr(&used_attr, new_id, span); + + // static DUMMY_PTR: = + let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); + let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { + safety: sig.header.safety, + ext: sig.header.ext, + generic_params: generics.params, + decl: sig.decl, + decl_span: sig.span, + })); + let static_ty = ecx.ty(span, fn_ptr_ty); + + let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); + let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { + ident: static_ident, + ty: static_ty, + safety: ast::Safety::Default, + mutability: ast::Mutability::Not, + expr: Some(static_expr), + define_opaque: None, + })); + + let static_item = ast::Item { + attrs: thin_vec![used_attr], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: static_item_kind, + tokens: None, + }; + + let block_expr = ecx.expr_block(Box::new(ast::Block { + stmts: thin_vec![ecx.stmt_item(span, P(static_item))], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Default, + span, + tokens: None, + })); + + let const_item = ecx.item_const( + span, + Ident::from_str_and_span("_", span), + ecx.ty(span, ast::TyKind::Tup(thin_vec![])), + block_expr, + ); + + Annotatable::Item(const_item) + } + // Will generate a body of the type: // ``` // { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index c72a567a30136..66c34fbcfb181 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -9,7 +9,7 @@ use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED}; +use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -199,7 +199,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( fn_to_diff: &'ll Value, outer_name: &str, ret_ty: &'ll Type, - fn_args: &[OperandRef<'tcx, &'ll Value>], + fn_args: &[&'ll Value], attrs: AutoDiffAttrs, dest: PlaceRef<'tcx, &'ll Value>, ) { @@ -275,15 +275,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); } - let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - match_args_from_caller_to_enzyme( &cx, builder, attrs.width, &mut args, &attrs.input_activity, - &outer_args, + fn_args, ); let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index e436df62d3c80..c16b9b33cf0c6 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,6 +3,7 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -198,48 +199,60 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } - _ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => { - // NOTE(Sa4dUs): This is a hacky way to get the autodiff items - // so we can focus on the lowering of the intrinsic call - let mut source_id = None; - let mut diff_attrs = None; - let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect(); - - // Hacky way of getting primal-diff pair, only works for code with 1 autodiff call - for target_id in &items { - let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else { - continue; - }; + sym::enzyme_autodiff => { + let val_arr: Vec<&'ll Value> = match args[2].val { + crate::intrinsic::OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; - if target_attrs.is_source() { - source_id = Some(*target_id); - } else { - diff_attrs = Some(target_attrs); - } - } + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(self, i); + let field_layout = tuple_place.layout.field(self, i); + let llvm_ty = field_layout.llvm_type(self.cx); - if source_id.is_none() || diff_attrs.is_none() { - bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}"); - } + let field_val = + self.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } - let diff_attrs = diff_attrs.unwrap().clone(); + ret_arr + } + crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + }; - // Get source fn - let source_id = source_id.unwrap(); - let fn_source = Instance::mono(tcx, source_id); + // Get source, diff, and attrs + let source_id = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_source = Instance::mono(tcx, *source_id); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + let diff_id = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_diff = Instance::mono(tcx, *diff_id); + let diff_symbol = + symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + // Build body generate_enzyme_call( self, self.cx, fn_to_diff, - name.as_str(), + &diff_symbol, llret_ty, - args, + &val_arr, diff_attrs.clone(), result, ); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 5c03b2ce33e64..19a481ab3419a 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -135,6 +135,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::round_ties_even_f32 | sym::round_ties_even_f64 | sym::round_ties_even_f128 + | sym::enzyme_autodiff | sym::const_eval_select => hir::Safety::Safe, _ => hir::Safety::Unsafe, }; @@ -216,6 +217,7 @@ pub(crate) fn check_intrinsic_type( (n_tps, n_cts, inputs, output) } + sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index d54175548e30e..05c72f5e81e3b 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -915,6 +915,7 @@ symbols! { enumerate_method, env, env_CFG_RELEASE: env!("CFG_RELEASE"), + enzyme_autodiff, eprint_macro, eprintln_macro, eq, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 106cc725fee2c..e209930bbcb35 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3163,6 +3163,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn enzyme_autodiff(f: F, df: G, args: T) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] From 929684304c948ea487f4e18362274e6384593d83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 7 Jul 2025 17:29:42 +0000 Subject: [PATCH 08/33] Remove attr checking from hir_analysis --- .../rustc_hir_analysis/src/check/intrinsic.rs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 19a481ab3419a..9692c82fbfcda 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -198,25 +198,8 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - // FIXME(Sa4dUs): Get the actual safety level of the diff function - let safety = if has_autodiff { - hir::Safety::Safe - } else { - intrinsic_operation_unsafety(tcx, intrinsic_id) - }; let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { - _ if has_autodiff => { - let sig = tcx.fn_sig(intrinsic_id.to_def_id()); - let sig = sig.skip_binder(); - let n_tps = generics.own_counts().types; - let n_cts = generics.own_counts().consts; - - let inputs = sig.skip_binder().inputs().to_vec(); - let output = sig.skip_binder().output(); - - (n_tps, n_cts, inputs, output) - } sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), From d5611e4968d7c6ad32226cd7145785339b4dec3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 8 Jul 2025 14:20:24 +0000 Subject: [PATCH 09/33] FIx generics error when passing fn as param to intrinsic --- compiler/rustc_builtin_macros/src/autodiff.rs | 51 ++++++++++++++++--- .../rustc_hir_analysis/src/check/intrinsic.rs | 3 +- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index c00e659bb311a..26aa69cf0f806 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,8 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind, + FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, + PathSegment, QSelf, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -337,8 +338,14 @@ mod llvm_enzyme { &generics, ); - let d_body = - call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + let d_body = call_enzyme_autodiff( + ecx, + primal, + first_ident(&meta_item_vec[0]), + span, + &d_sig, + &generics, + ); // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { @@ -505,9 +512,10 @@ mod llvm_enzyme { diff: Ident, span: Span, d_sig: &FnSig, + generics: &Generics, ) -> P { - let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); - let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span); let tuple_expr = ecx.expr_tuple( span, @@ -542,6 +550,37 @@ mod llvm_enzyme { block } + // Generate turbofish expression from fn name and generics + // Given `foo` and ``, gen `foo::` + fn gen_turbofish_expr( + ecx: &ExtCtxt<'_>, + ident: Ident, + generics: &Generics, + span: Span, + ) -> P { + let generic_args = generics + .params + .iter() + .map(|p| { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + AngleBracketedArg::Arg(GenericArg::Type(ty)) + }) + .collect::>(); + + let args = AngleBracketedArgs { span, args: generic_args }; + + let segment = PathSegment { + ident, + id: ast::DUMMY_NODE_ID, + args: Some(P(GenericArgs::AngleBracketed(args))), + }; + + let path = Path { span, segments: thin_vec![segment], tokens: None }; + + ecx.expr_path(path) + } + // Generate dummy const to prevent primal function // from being optimized away before applying enzyme // ``` diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 9692c82fbfcda..bca73b2135ccc 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -172,8 +172,6 @@ pub(crate) fn check_intrinsic_type( } }; - let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); - let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -198,6 +196,7 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; + let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), From 74152345e219d31943f7ab6fce9a27df0984723e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 8 Jul 2025 17:38:54 +0000 Subject: [PATCH 10/33] Use Instance::new_raw instead of Instance::mono Note(Sa4dUs): `cg/generic.rs` test is passing with some tweaks --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index c16b9b33cf0c6..95a01369ae59f 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -224,21 +224,21 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { }; // Get source, diff, and attrs - let source_id = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, _) => def_id, + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), _ => bug!("invalid args"), }; - let fn_source = Instance::mono(tcx, *source_id); + let fn_source = Instance::new_raw(*source_id, source_args); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - let diff_id = match fn_args.into_type_list(tcx)[1].kind() { - ty::FnDef(def_id, _) => def_id, + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), _ => bug!("invalid args"), }; - let fn_diff = Instance::mono(tcx, *diff_id); + let fn_diff = Instance::new_raw(*diff_id, diff_args); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); From 3f64b82362095453e87342983905ab3c962732f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 10 Jul 2025 18:26:38 +0000 Subject: [PATCH 11/33] Hacky fix for issues at trait calls --- compiler/rustc_builtin_macros/src/autodiff.rs | 300 ++++-------------- .../src/builder/autodiff.rs | 3 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 149 +++++---- 3 files changed, 146 insertions(+), 306 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 26aa69cf0f806..bc52a62f73399 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,9 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind, - FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, - PathSegment, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy, + FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, + PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -74,10 +74,12 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { + fn extract_item_info( + iitem: &P, + ) -> Option<(Visibility, FnSig, Ident, Generics, bool)> { match &iitem.kind { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false)) } _ => None, } @@ -229,16 +231,20 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. // these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics)) = (match &item { + let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) - } + Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + *of_trait, + )), _ => None, }, _ => None, @@ -333,18 +339,21 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); // TODO(Sa4dUs): Remove this and all the related logic - let _d_body = gen_enzyme_body( - ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, - &generics, - ); - - let d_body = call_enzyme_autodiff( + let d_body = gen_enzyme_body( ecx, + &x, + n_active, + &sig, + &d_sig, primal, - first_ident(&meta_item_vec[0]), + &new_args, span, - &d_sig, + sig_span, + idents, + errored, + first_ident(&meta_item_vec[0]), &generics, + impl_of_trait, ); // The first element of it is the name of the function to be generated @@ -441,8 +450,6 @@ mod llvm_enzyme { tokens: ts, }); - let vis_clone = vis.clone(); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { @@ -479,9 +486,7 @@ mod llvm_enzyme { } }; - let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); - - return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; + return vec![orig_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -513,9 +518,10 @@ mod llvm_enzyme { span: Span, d_sig: &FnSig, generics: &Generics, - ) -> P { - let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span); - let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span); + is_impl: bool, + ) -> rustc_ast::Stmt { + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); let tuple_expr = ecx.expr_tuple( span, @@ -545,9 +551,7 @@ mod llvm_enzyme { vec![primal_path_expr, diff_path_expr, tuple_expr].into(), ); - let block = ecx.block_expr(call_expr); - - block + ecx.stmt_expr(call_expr) } // Generate turbofish expression from fn name and generics @@ -557,6 +561,7 @@ mod llvm_enzyme { ident: Ident, generics: &Generics, span: Span, + is_impl: bool, ) -> P { let generic_args = generics .params @@ -568,7 +573,7 @@ mod llvm_enzyme { }) .collect::>(); - let args = AngleBracketedArgs { span, args: generic_args }; + let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args }; let segment = PathSegment { ident, @@ -576,79 +581,18 @@ mod llvm_enzyme { args: Some(P(GenericArgs::AngleBracketed(args))), }; - let path = Path { span, segments: thin_vec![segment], tokens: None }; - - ecx.expr_path(path) - } - - // Generate dummy const to prevent primal function - // from being optimized away before applying enzyme - // ``` - // const _: () = - // { - // #[used] - // pub static DUMMY_PTR: fn_type = primal_fn; - // }; - // ``` - fn gen_dummy_const( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - sig: FnSig, - generics: Generics, - vis: Visibility, - ) -> Annotatable { - // #[used] - let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let used_attr = outer_normal_attr(&used_attr, new_id, span); - - // static DUMMY_PTR: = - let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); - let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { - safety: sig.header.safety, - ext: sig.header.ext, - generic_params: generics.params, - decl: sig.decl, - decl_span: sig.span, - })); - let static_ty = ecx.ty(span, fn_ptr_ty); - - let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); - let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { - ident: static_ident, - ty: static_ty, - safety: ast::Safety::Default, - mutability: ast::Mutability::Not, - expr: Some(static_expr), - define_opaque: None, - })); - - let static_item = ast::Item { - attrs: thin_vec![used_attr], - id: ast::DUMMY_NODE_ID, - span, - vis, - kind: static_item_kind, - tokens: None, + let segments = if is_impl { + thin_vec![ + PathSegment { ident: Ident::from_str("Foo"), id: ast::DUMMY_NODE_ID, args: None }, + segment, + ] + } else { + thin_vec![segment] }; - let block_expr = ecx.expr_block(Box::new(ast::Block { - stmts: thin_vec![ecx.stmt_item(span, P(static_item))], - id: ast::DUMMY_NODE_ID, - rules: ast::BlockCheckMode::Default, - span, - tokens: None, - })); - - let const_item = ecx.item_const( - span, - Ident::from_str_and_span("_", span), - ecx.ty(span, ast::TyKind::Tup(thin_vec![])), - block_expr, - ); + let path = Path { span, segments, tokens: None }; - Annotatable::Item(const_item) + ecx.expr_path(path) } // Will generate a body of the type: @@ -666,33 +610,14 @@ mod llvm_enzyme { ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - new_names: &[String], - sig_span: Span, + _new_names: &[String], + _sig_span: Span, new_decl_span: Span, idents: &[Ident], errored: bool, generics: &Generics, ) -> (P, P, P, P) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let noop = ast::InlineAsm { - asm_macro: ast::AsmMacro::Asm, - template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())], - template_strs: Box::new([]), - operands: vec![], - clobber_abis: vec![], - options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, - line_spans: vec![], - }; - let noop_expr = ecx.expr_asm(span, P(noop)); - let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); - let unsf_block = ast::Block { - stmts: thin_vec![ecx.stmt_semi(noop_expr)], - id: ast::DUMMY_NODE_ID, - tokens: None, - rules: unsf, - span, - }; - let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( @@ -700,25 +625,13 @@ mod llvm_enzyme { blackbox_call_expr.clone(), thin_vec![primal_call.clone()], ); - let tup_args = new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(); - - let black_box_remaining_args = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - thin_vec![ecx.expr_tuple(sig_span, tup_args)], - ); let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(unsf_expr)); // This uses primal args which won't be available if we errored before if !errored { body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); } - body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); (body, primal_call, black_box_primal_call, blackbox_call_expr) } @@ -733,9 +646,9 @@ mod llvm_enzyme { /// from optimizing any arguments away. fn gen_enzyme_body( ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, + _x: &AutoDiffAttrs, + _n_active: u32, + _sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, new_names: &[String], @@ -743,19 +656,15 @@ mod llvm_enzyme { sig_span: Span, idents: Vec, errored: bool, + diff_ident: Ident, generics: &Generics, + is_impl: bool, ) -> P { let new_decl_span = d_sig.span; - // Just adding some default inline-asm and black_box usages to prevent early inlining - // and optimizations which alter the function signature. - // - // The bb_primal_call is the black_box call of the primal function. We keep it around, - // since it has the convenient property of returning the type of the primal function, - // Remember, we only care to match types here. - // No matter which return we pick, we always wrap it into a std::hint::black_box call, - // to prevent rustc from propagating it into the caller. - let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( + // Add a call to the primal function to prevent it from being inlined + // and call `enzyme_autodiff` intrinsic (this also covers the return type) + let (mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper( ecx, span, primal, @@ -767,98 +676,15 @@ mod llvm_enzyme { generics, ); - if !has_ret(&d_sig.decl.output) { - // there is no return type that we have to match, () works fine. - return body; - } - - // Everything from here onwards just tries to fulfil the return type. Fun! - - // having an active-only return means we'll drop the original return type. - // So that can be treated identical to not having one in the first place. - let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); - - if primal_ret && n_active == 0 && x.mode.is_rev() { - // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(bb_primal_call)); - return body; - } - - if !primal_ret && n_active == 1 { - // Again no tuple return, so return default float val. - let ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - let arg = ty.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - body.stmts.push(ecx.stmt_expr(default_call_expr)); - return body; - } - - let mut exprs: P = primal_call; - let d_ret_ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - if x.mode.is_fwd() { - // Fwd mode is easy. If the return activity is Const, we support arbitrary types. - // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. - // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. - // In all three cases, we can return `std::hint::black_box(::default())`. - if x.ret_activity == DiffActivity::Const { - // Here we call the primal function, since our dummy function has the same return - // type due to the Const return activity. - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; - let y = ExprKind::Path( - Some(P(q)), - ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), - ); - let default_call_expr = ecx.expr(span, y); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); - } - } else if x.mode.is_rev() { - if x.width == 1 { - // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. - match d_ret_ty.kind { - TyKind::Tup(ref args) => { - // We have a tuple return type. We need to create a tuple of the same size - // and fill it with default values. - let mut exprs2 = thin_vec![exprs]; - for arg in args.iter().skip(1) { - let arg = arg.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs2.push(default_call_expr); - } - exprs = ecx.expr_tuple(new_decl_span, exprs2); - } - _ => { - // Interestingly, even the `-> ArbitraryType` case - // ends up getting matched and handled correctly above, - // so we don't have to handle any other case for now. - panic!("Unsupported return type: {:?}", d_ret_ty); - } - } - } - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - unreachable!("Unsupported mode: {:?}", x.mode); - } - - body.stmts.push(ecx.stmt_expr(exprs)); + body.stmts.push(call_enzyme_autodiff( + ecx, + primal, + diff_ident, + new_decl_span, + d_sig, + generics, + is_impl, + )); body } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 66c34fbcfb181..a5fae701b45ef 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -245,8 +245,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and // think a bit more about what should go here. - // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now - let cc = 8; + let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) }; let ad_fn = declare_simple_fn( cx, &ad_name, diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 95a01369ae59f..873614a3d7f92 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,11 +9,11 @@ use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphizati use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; -use rustc_hir as hir; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; @@ -177,16 +177,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; - let callee_ty = instance.ty(tcx, self.typing_env()); - let fn_args = instance.args; - - let sig = callee_ty.fn_sig(tcx); - let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); - let ret_ty = sig.output(); let name = tcx.item_name(instance.def_id()); - - let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let fn_args = instance.args; let simple = call_simple_intrinsic(self, name, args); let llval = match name { @@ -200,63 +193,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } sym::enzyme_autodiff => { - let val_arr: Vec<&'ll Value> = match args[2].val { - crate::intrinsic::OperandValue::Ref(ref place_value) => { - let mut ret_arr = vec![]; - let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; - - for i in 0..tuple_place.layout.layout.0.fields.count() { - let field_place = tuple_place.project_field(self, i); - let field_layout = tuple_place.layout.field(self, i); - let llvm_ty = field_layout.llvm_type(self.cx); - - let field_val = - self.load(llvm_ty, field_place.val.llval, field_place.val.align); - - ret_arr.push(field_val) - } - - ret_arr - } - crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], - OperandValue::Immediate(v) => vec![v], - OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), - }; - - // Get source, diff, and attrs - let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, source_params) => (def_id, source_params), - _ => bug!("invalid args"), - }; - let fn_source = Instance::new_raw(*source_id, source_args); - let source_symbol = - symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); - let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); - let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - - let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { - ty::FnDef(def_id, diff_args) => (def_id, diff_args), - _ => bug!("invalid args"), - }; - let fn_diff = Instance::new_raw(*diff_id, diff_args); - let diff_symbol = - symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - - let diff_attrs = autodiff_attrs(tcx, *diff_id); - let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; - - // Build body - generate_enzyme_call( - self, - self.cx, - fn_to_diff, - &diff_symbol, - llret_ty, - &val_arr, - diff_attrs.clone(), - result, - ); - + codegen_enzyme_autodiff(self, tcx, instance, args, result); return Ok(()); } sym::is_val_statically_known => { @@ -1183,6 +1120,84 @@ fn get_rust_try_fn<'a, 'll, 'tcx>( rust_try } +fn codegen_enzyme_autodiff<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + args: &[OperandRef<'tcx, &'ll Value>], + result: PlaceRef<'tcx, &'ll Value>, +) { + let fn_args = instance.args; + let callee_ty = instance.ty(tcx, bx.typing_env()); + + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), sig); + + let ret_ty = sig.output(); + let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); + + let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]); + + // Get source, diff, and attrs + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), + _ => bug!("invalid args"), + }; + let fn_source = Instance::new_raw(*source_id, source_args); + let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), + _ => bug!("invalid args"), + }; + let fn_diff = Instance::new_raw(*diff_id, diff_args); + let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + // Build body + generate_enzyme_call( + bx, + bx.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); +} + +fn get_args_from_tuple<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + op: OperandRef<'tcx, &'ll Value>, +) -> Vec<&'ll Value> { + match op.val { + OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: op.layout }; + + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(bx, i); + let field_layout = tuple_place.layout.field(bx, i); + let llvm_ty = field_layout.llvm_type(bx.cx); + + let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } + + ret_arr + } + OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + } +} + fn generic_simd_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, name: Symbol, From 1b7a5401217edcc59c288246051b8e7515b07017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 11 Jul 2025 14:54:49 +0000 Subject: [PATCH 12/33] Fix how fns where being retrieved at intrinsic cg --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 873614a3d7f92..d14fae073a4b9 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1143,7 +1143,8 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( ty::FnDef(def_id, source_params) => (def_id, source_params), _ => bug!("invalid args"), }; - let fn_source = Instance::new_raw(*source_id, source_args); + let fn_source = + Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap(); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; @@ -1152,10 +1153,11 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( ty::FnDef(def_id, diff_args) => (def_id, diff_args), _ => bug!("invalid args"), }; - let fn_diff = Instance::new_raw(*diff_id, diff_args); + let fn_diff = + Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let diff_attrs = autodiff_attrs(tcx, *diff_id); + let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; // Build body From cde3c2f3f80b8bf0748a946e67d1ef90551c357a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 11 Jul 2025 18:10:45 +0000 Subject: [PATCH 13/33] Use Self instead of Foo placeholder --- compiler/rustc_builtin_macros/src/autodiff.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index bc52a62f73399..0a3c630ae98f8 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -583,7 +583,7 @@ mod llvm_enzyme { let segments = if is_impl { thin_vec![ - PathSegment { ident: Ident::from_str("Foo"), id: ast::DUMMY_NODE_ID, args: None }, + PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None }, segment, ] } else { @@ -630,7 +630,7 @@ mod llvm_enzyme { // This uses primal args which won't be available if we errored before if !errored { - body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); + body.stmts.push(ecx.stmt_semi(primal_call.clone())); } (body, primal_call, black_box_primal_call, blackbox_call_expr) From 5b3909e057a56cacc5f534ca20e3887d47895e30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sat, 12 Jul 2025 10:28:06 +0000 Subject: [PATCH 14/33] Remove unused code --- compiler/rustc_builtin_macros/src/autodiff.rs | 69 +++--------- .../src/builder/autodiff.rs | 104 ++---------------- compiler/rustc_codegen_llvm/src/errors.rs | 3 + 3 files changed, 27 insertions(+), 149 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 0a3c630ae98f8..e199f40153ede 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -262,7 +262,6 @@ mod llvm_enzyme { }; let has_ret = has_ret(&sig.decl.output); - let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -331,24 +330,13 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let n_active: u32 = x - .input_activity - .iter() - .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) - .count() as u32; - let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - // TODO(Sa4dUs): Remove this and all the related logic let d_body = gen_enzyme_body( ecx, - &x, - n_active, - &sig, &d_sig, primal, - &new_args, span, - sig_span, idents, errored, first_ident(&meta_item_vec[0]), @@ -361,7 +349,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: generics.clone(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -542,7 +530,7 @@ mod llvm_enzyme { vec![ Ident::from_str("std"), Ident::from_str("intrinsics"), - Ident::from_str("enzyme_autodiff"), + Ident::with_dummy_span(sym::enzyme_autodiff), ], ); let call_expr = ecx.expr_call( @@ -555,7 +543,7 @@ mod llvm_enzyme { } // Generate turbofish expression from fn name and generics - // Given `foo` and ``, gen `foo::` + // Given `foo` and `` params, gen `foo::` fn gen_turbofish_expr( ecx: &ExtCtxt<'_>, ident: Ident, @@ -597,35 +585,19 @@ mod llvm_enzyme { // Will generate a body of the type: // ``` - // { - // unsafe { - // asm!("NOP"); - // } - // ::core::hint::black_box(primal(args)); - // ::core::hint::black_box((args, ret)); - // + // primal(args); + // std::intrinsics::enzyme_autodiff(primal, diff, (args)) // } // ``` fn init_body_helper( ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - _new_names: &[String], - _sig_span: Span, - new_decl_span: Span, idents: &[Ident], errored: bool, generics: &Generics, - ) -> (P, P, P, P) { - let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + ) -> P { let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let black_box_primal_call = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); - let mut body = ecx.block(span, ThinVec::new()); // This uses primal args which won't be available if we errored before @@ -633,7 +605,7 @@ mod llvm_enzyme { body.stmts.push(ecx.stmt_semi(primal_call.clone())); } - (body, primal_call, black_box_primal_call, blackbox_call_expr) + body } /// We only want this function to type-check, since we will replace the body @@ -646,14 +618,9 @@ mod llvm_enzyme { /// from optimizing any arguments away. fn gen_enzyme_body( ecx: &ExtCtxt<'_>, - _x: &AutoDiffAttrs, - _n_active: u32, - _sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, - new_names: &[String], span: Span, - sig_span: Span, idents: Vec, errored: bool, diff_ident: Ident, @@ -664,17 +631,7 @@ mod llvm_enzyme { // Add a call to the primal function to prevent it from being inlined // and call `enzyme_autodiff` intrinsic (this also covers the return type) - let (mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper( - ecx, - span, - primal, - new_names, - sig_span, - new_decl_span, - &idents, - errored, - generics, - ); + let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics); body.stmts.push(call_enzyme_autodiff( ecx, @@ -771,7 +728,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec, Vec, bool) { + ) -> (ast::FnSig, Vec, bool) { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -783,7 +740,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], vec![], true); + return (sig.clone(), vec![], true); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -826,7 +783,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. - return (sig.clone(), new_inputs, idents, true); + return (sig.clone(), idents, true); } let unsafe_activities = x @@ -1034,7 +991,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, new_inputs, idents, false) + (d_sig, idents, false) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index a5fae701b45ef..602d779ad61ba 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,42 +1,18 @@ use std::ptr; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; -use rustc_codegen_ssa::ModuleCodegen; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_errors::FatalError; use rustc_middle::bug; -use tracing::{debug, trace}; +use tracing::debug; -use crate::back::write::llvm_err; use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; - -fn _get_params(fnc: &Value) -> Vec<&Value> { - let param_num = llvm::LLVMCountParams(fnc) as usize; - let mut fnc_args: Vec<&Value> = vec![]; - fnc_args.reserve(param_num); - unsafe { - llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); - fnc_args.set_len(param_num); - } - fnc_args -} - -fn _has_sret(fnc: &Value) -> bool { - let num_args = llvm::LLVMCountParams(fnc) as usize; - if num_args == 0 { - false - } else { - unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, llvm::AttributeKind::StructRet) } - } -} +use crate::{attributes, llvm}; // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the // original inputs, as well as metadata and the additional shadow arguments. @@ -66,12 +42,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); - let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); - let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); - let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); - let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); - let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); + let enzyme_const = cx.create_metadata(b"enzyme_const"); + let enzyme_out = cx.create_metadata(b"enzyme_out"); + let enzyme_dup = cx.create_metadata(b"enzyme_dup"); + let enzyme_dupv = cx.create_metadata(b"enzyme_dupv"); + let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed"); + let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv"); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -264,12 +240,12 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); - let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return"); if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { args.push(cx.get_metadata_value(enzyme_primal_ret)); } if attrs.width > 1 { - let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + let enzyme_width = cx.create_metadata(b"enzyme_width"); args.push(cx.get_metadata_value(enzyme_width)); args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); } @@ -287,61 +263,3 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder.store_to_place(call, dest.val); } - -pub(crate) fn differentiate<'ll>( - module: &'ll ModuleCodegen, - cgcx: &CodegenContext, - diff_items: Vec, -) -> Result<(), FatalError> { - // TODO(Sa4dUs): delete all this logic - for item in &diff_items { - trace!("{}", item); - } - - let diag_handler = cgcx.create_dcx(); - - let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); - - // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? - if !diff_items.is_empty() - && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) - { - return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable)); - } - - // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme. - for item in diff_items.iter() { - let name = item.source.clone(); - let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(_fn_def) = fn_def else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find source function".to_owned(), - }, - )); - }; - debug!(?item.target); - let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(_fn_target) = fn_target else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find target function".to_owned(), - }, - )); - }; - - // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); - } - - // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts - - trace!("done with differentiate()"); - - Ok(()) -} diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 627b0c9ff3b33..2ba84e4622416 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -32,9 +32,12 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +// TODO(Sa4dUs): we will need to reintroduce these errors somewhere +/* #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_enable)] pub(crate) struct AutoDiffWithoutEnable; +*/ #[derive(Diagnostic)] #[diag(codegen_llvm_lto_bitcode_from_rlib)] From 9ce6d1607f02555bf0a51b0881b74cb893ff937b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 14 Jul 2025 09:01:56 +0000 Subject: [PATCH 15/33] Remove primal call and collect it in mono instead --- compiler/rustc_builtin_macros/src/autodiff.rs | 12 ++---- compiler/rustc_monomorphize/src/collector.rs | 7 ++++ .../src/collector/autodiff.rs | 38 +++++++++++++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 compiler/rustc_monomorphize/src/collector/autodiff.rs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index e199f40153ede..23ebacaaf0d07 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -594,17 +594,11 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], - errored: bool, + _errored: bool, generics: &Generics, ) -> P { - let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let mut body = ecx.block(span, ThinVec::new()); - - // This uses primal args which won't be available if we errored before - if !errored { - body.stmts.push(ecx.stmt_semi(primal_call.clone())); - } - + let _primal_call = gen_primal_call(ecx, span, primal, idents, generics); + let body = ecx.block(span, ThinVec::new()); body } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 35b80a9b96f4d..3808ae6185099 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -205,6 +205,8 @@ //! this is not implemented however: a mono item will be produced //! regardless of whether it is actually needed or not. +mod autodiff; + use std::cell::OnceCell; use std::path::PathBuf; @@ -237,6 +239,8 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan}; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; +#[cfg(llvm_enzyme)] +use crate::collector::autodiff::collect_enzyme_autodiff_source_fn; use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit}; #[derive(PartialEq)] @@ -913,6 +917,9 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { + #[cfg(llvm_enzyme)] + collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output); + if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will // be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs new file mode 100644 index 0000000000000..d062302ae53a6 --- /dev/null +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -0,0 +1,38 @@ +use rustc_middle::bug; +use rustc_middle::ty::{self, IntrinsicDef, TyCtxt}; +use tracing::debug; + +use crate::collector::{MonoItems, create_fn_mono_item}; + +pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + intrinsic: IntrinsicDef, + output: &mut MonoItems<'tcx>, +) { + if intrinsic.name != rustc_span::sym::enzyme_autodiff { + return; + }; + + debug!("enzyme_autodiff found"); + let (primal, span) = match instance.args[0].kind() { + rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() { + ty::FnDef(def_id, substs) => { + let span = tcx.def_span(def_id); + let instance = ty::Instance::expect_resolve( + tcx, + ty::TypingEnv::non_body_analysis(tcx, def_id), + *def_id, + substs, + span, + ); + + (instance, span) + } + _ => bug!("expected function"), + }, + _ => bug!("expected type"), + }; + + output.push(create_fn_mono_item(tcx, primal, span)); +} From 4e2a0057bb5aa96216064d0631fcfbc1158adeea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 15 Jul 2025 08:37:28 +0000 Subject: [PATCH 16/33] Update codegen tests --- tests/codegen-llvm/autodiff/batched.rs | 72 +++++++++----------- tests/codegen-llvm/autodiff/generic.rs | 2 +- tests/codegen-llvm/autodiff/identical_fnc.rs | 8 +-- tests/codegen-llvm/autodiff/inline.rs | 2 +- tests/codegen-llvm/autodiff/scalar.rs | 2 +- tests/codegen-llvm/autodiff/sret.rs | 2 +- tests/codegen-llvm/autodiff/trait.rs | 31 +++++++++ 7 files changed, 70 insertions(+), 49 deletions(-) create mode 100644 tests/codegen-llvm/autodiff/trait.rs diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index 88a1de9994c8a..5e94c7bb9b8e6 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -10,7 +10,7 @@ // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_forward; @@ -22,7 +22,7 @@ fn square(x: &f32) -> f32 { x * x } -// d_sqaure2 +// d_square2 // CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") // CHECK-NEXT: start: // CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 @@ -33,24 +33,20 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: ret [4 x float] %19 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: ret [4 x float] %15 +// CHECK-NEXT: } // d_square3, the extra float is the original return value (x * x) // CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") @@ -64,26 +60,22 @@ fn square(x: &f32) -> f32 { // CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 // CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 // CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val -// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0 -// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1 -// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2 -// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3 -// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7 -// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0 -// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer -// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10 -// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0 -// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0 -// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1 -// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1 -// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2 -// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2 -// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3 -// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3 -// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0 -// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1 -// CHECK-NEXT: ret { float, [4 x float] } %21 -// CHECK-NEXT: } +// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" +// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val +// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 +// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" +// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val +// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 +// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" +// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val +// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 +// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" +// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val +// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 +// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0 +// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1 +// CHECK-NEXT: ret { float, [4 x float] } %17 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index af9706c621208..9553ef3760e74 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -2,7 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index ff8e6c74a6b34..1b4edf1d954b4 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -10,7 +10,7 @@ // We also explicetly test that we keep running merge_function after AD, by checking for two // identical function calls in the LLVM-IR, while having two different calls in the Rust code. #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; @@ -30,10 +30,8 @@ fn square2(x: &f64) -> f64 { // CHECK-NEXT:start: // CHECK-NOT:br // CHECK-NOT:ret -// CHECK:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1) -// CHECK-NEXT:; call identical_fnc::d_square -// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2) +// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx1) +// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx2) fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs index 5db69b960343c..1aa1b8a912be1 100644 --- a/tests/codegen-llvm/autodiff/inline.rs +++ b/tests/codegen-llvm/autodiff/inline.rs @@ -3,7 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index c2bca7e9c81ef..745b03ee0ed8f 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -2,7 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index 67f68fc053cc4..e2272fd4df7d3 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -8,7 +8,7 @@ // We therefore use this test to verify some of our sret handling. #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/trait.rs b/tests/codegen-llvm/autodiff/trait.rs new file mode 100644 index 0000000000000..988e9145087b2 --- /dev/null +++ b/tests/codegen-llvm/autodiff/trait.rs @@ -0,0 +1,31 @@ +//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// Just check it does not crash for now +// CHECK: ; +#![feature(autodiff)] +#![feature(core_intrinsics)] + +use std::autodiff::autodiff_reverse; + +struct Foo { + a: f64, +} + +trait MyTrait { + fn f(&self, x: f64) -> f64; + fn df(&self, x: f64, seed: f64) -> (f64, f64); +} + +impl MyTrait for Foo { + #[autodiff_reverse(df, Const, Active, Active)] + fn f(&self, x: f64) -> f64 { + self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) + } +} + +fn main() { + let foo = Foo { a: 3.0f64 }; + dbg!(foo.df(1.0, 1.0)); +} From c55e7ca7bdbe09be171809954ebcac7de0dd1e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 22 Jul 2025 10:38:36 +0000 Subject: [PATCH 17/33] Handle slices when extracting args from tuple --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 24 +++++++++++++++---- compiler/rustc_monomorphize/src/collector.rs | 4 ++-- .../src/collector/autodiff.rs | 23 +++++++++++++----- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index d14fae073a4b9..fb838615990c3 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,7 +3,6 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; -use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -1157,7 +1156,13 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); + // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here + // in a O(1) step + let diff_attrs = tcx + .collect_and_partition_mono_items(()) + .autodiff_items + .iter() + .find(|item| item.target == diff_symbol); let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; // Build body @@ -1168,7 +1173,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( &diff_symbol, llret_ty, &val_arr, - diff_attrs.clone(), + diff_attrs.attrs.clone(), result, ); } @@ -1185,11 +1190,22 @@ fn get_args_from_tuple<'ll, 'tcx>( for i in 0..tuple_place.layout.layout.0.fields.count() { let field_place = tuple_place.project_field(bx, i); let field_layout = tuple_place.layout.field(bx, i); + let field_ty = field_layout.ty; let llvm_ty = field_layout.llvm_type(bx.cx); let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align); - ret_arr.push(field_val) + match field_ty.kind() { + ty::Ref(_, inner_ty, _) if matches!(inner_ty.kind(), ty::Slice(_)) => { + let ptr = bx.extract_value(field_val, 0); + let len = bx.extract_value(field_val, 1); + ret_arr.push(ptr); + ret_arr.push(len); + } + _ => { + ret_arr.push(field_val); + } + } } ret_arr diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 3808ae6185099..74e67165afa79 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -240,7 +240,7 @@ use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; #[cfg(llvm_enzyme)] -use crate::collector::autodiff::collect_enzyme_autodiff_source_fn; +use crate::collector::autodiff::collect_enzyme_autodiff_fn; use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit}; #[derive(PartialEq)] @@ -918,7 +918,7 @@ fn visit_instance_use<'tcx>( } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { #[cfg(llvm_enzyme)] - collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output); + collect_enzyme_autodiff_fn(tcx, instance, intrinsic, output); if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs index d062302ae53a6..3c5d768b79d4b 100644 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -1,10 +1,13 @@ use rustc_middle::bug; -use rustc_middle::ty::{self, IntrinsicDef, TyCtxt}; -use tracing::debug; +use rustc_middle::ty::{self, GenericArg, IntrinsicDef, TyCtxt}; use crate::collector::{MonoItems, create_fn_mono_item}; -pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( +// Here, we force both primal and diff function to be collected in +// mono so this does not interfere in `enzyme_autodiff` intrinsics +// codegen process. If they are unused, they will be removed later and +// won't be present at LLVM-IR. +pub(crate) fn collect_enzyme_autodiff_fn<'tcx>( tcx: TyCtxt<'tcx>, instance: ty::Instance<'tcx>, intrinsic: IntrinsicDef, @@ -14,8 +17,16 @@ pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( return; }; - debug!("enzyme_autodiff found"); - let (primal, span) = match instance.args[0].kind() { + collect_autodiff_fn_from_arg(instance.args[0], tcx, output); + collect_autodiff_fn_from_arg(instance.args[1], tcx, output); +} + +fn collect_autodiff_fn_from_arg<'tcx>( + arg: GenericArg<'tcx>, + tcx: TyCtxt<'tcx>, + output: &mut MonoItems<'tcx>, +) { + let (instance, span) = match arg.kind() { rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() { ty::FnDef(def_id, substs) => { let span = tcx.def_span(def_id); @@ -34,5 +45,5 @@ pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( _ => bug!("expected type"), }; - output.push(create_fn_mono_item(tcx, primal, span)); + output.push(create_fn_mono_item(tcx, instance, span)); } From 70503d7194918031b3a74539580bbbd64ef7ef43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 23 Jul 2025 11:04:31 +0000 Subject: [PATCH 18/33] Do not depend on mono anymore --- .../src/builder/autodiff.rs | 79 ++++++++++++++++++- compiler/rustc_codegen_llvm/src/intrinsic.rs | 23 +++--- 2 files changed, 90 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 602d779ad61ba..c88ccdeda9e65 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,7 +3,8 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_middle::bug; +use rustc_middle::{bug, ty}; +use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; @@ -14,6 +15,82 @@ use crate::llvm::{Metadata, True, Type}; use crate::value::Value; use crate::{attributes, llvm}; +pub(crate) fn adjust_activity_to_abi<'tcx>( + tcx: TyCtxt<'tcx>, + fn_ty: Ty<'tcx>, + da: &mut Vec, +) { + if !matches!(fn_ty.kind(), ty::FnDef(..)) { + bug!("expected fn def for autodiff, got {:?}", fn_ty); + } + + // We don't actually pass the types back into the type system. + // All we do is decide how to handle the arguments. + let sig = fn_ty.fn_sig(tcx).skip_binder(); + + let mut new_activities = vec![]; + let mut new_positions = vec![]; + for (i, ty) in sig.inputs().iter().enumerate() { + if let Some(inner_ty) = ty.builtin_deref(true) { + if inner_ty.is_slice() { + // Now we need to figure out the size of each slice element in memory to allow + // safety checks and usability improvements in the backend. + let sty = match inner_ty.builtin_index() { + Some(sty) => sty, + None => { + panic!("slice element type unknown"); + } + }; + let pci = PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: sty, + }; + + let layout = tcx.layout_of(pci); + let elem_size = match layout { + Ok(layout) => layout.size, + Err(_) => { + bug!("autodiff failed to compute slice element size"); + } + }; + let elem_size: u32 = elem_size.bytes() as u32; + + // We know that the length will be passed as extra arg. + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + // There is one FakeActivitySize per slice, so for convenience we store the + // slice element size in bytes in it. We will use the size in the backend. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::Dualv + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => { + DiffActivity::FakeActivitySize(Some(elem_size)) + } + DiffActivity::Const => DiffActivity::Const, + _ => bug!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + + continue; + } + } + } + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } +} + // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the // original inputs, as well as metadata and the additional shadow arguments. // This function matches the arguments from the outer function to the inner enzyme call. diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index fb838615990c3..a3f6d74f32b92 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,6 +3,7 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -12,7 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE; use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; @@ -21,7 +22,7 @@ use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; -use crate::builder::autodiff::generate_enzyme_call; +use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -1156,14 +1157,14 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - // TODO(Sa4dUs): Store autodiff items in a single pass and just get them here - // in a O(1) step - let diff_attrs = tcx - .collect_and_partition_mono_items(()) - .autodiff_items - .iter() - .find(|item| item.target == diff_symbol); - let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); + let Some(mut diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + adjust_activity_to_abi( + tcx, + fn_source.ty(tcx, TypingEnv::fully_monomorphized()), + &mut diff_attrs.input_activity, + ); // Build body generate_enzyme_call( @@ -1173,7 +1174,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( &diff_symbol, llret_ty, &val_arr, - diff_attrs.attrs.clone(), + diff_attrs.clone(), result, ); } From 09bb20e028f1521d82ffd2d85812bf6b644a4ed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 23 Jul 2025 14:35:12 +0000 Subject: [PATCH 19/33] Get args from tuple using fnabi and minor fixes --- compiler/rustc_builtin_macros/src/autodiff.rs | 22 ++++-- .../src/builder/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 68 +++++++++++-------- 3 files changed, 56 insertions(+), 36 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 23ebacaaf0d07..1ab6b9c8572d3 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,9 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy, - FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, - PathSegment, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode, + FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind, + MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -554,10 +554,18 @@ mod llvm_enzyme { let generic_args = generics .params .iter() - .map(|p| { - let path = ast::Path::from_ident(p.ident); - let ty = ecx.ty_path(path); - AngleBracketedArg::Arg(GenericArg::Type(ty)) + .filter_map(|p| match &p.kind { + GenericParamKind::Type { .. } => { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + Some(AngleBracketedArg::Arg(GenericArg::Type(ty))) + } + GenericParamKind::Const { .. } => { + let expr = ecx.expr_path(ast::Path::from_ident(p.ident)); + let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr }; + Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const))) + } + GenericParamKind::Lifetime { .. } => None, }) .collect::>(); diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index c88ccdeda9e65..347908c50aae1 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,8 +3,8 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_middle::{bug, ty}; use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; +use rustc_middle::{bug, ty}; use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index a3f6d74f32b92..42f0a4c8966d4 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -17,6 +17,7 @@ use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; +use rustc_target::callconv::PassMode; use rustc_target::spec::PanicStrategy; use tracing::debug; @@ -1136,8 +1137,6 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( let ret_ty = sig.output(); let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); - let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]); - // Get source, diff, and attrs let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { ty::FnDef(def_id, source_params) => (def_id, source_params), @@ -1155,6 +1154,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( }; let fn_diff = Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); + let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); @@ -1181,39 +1181,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( fn get_args_from_tuple<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, - op: OperandRef<'tcx, &'ll Value>, + tuple_op: OperandRef<'tcx, &'ll Value>, + fn_instance: Instance<'tcx>, ) -> Vec<&'ll Value> { - match op.val { - OperandValue::Ref(ref place_value) => { - let mut ret_arr = vec![]; - let tuple_place = PlaceRef { val: *place_value, layout: op.layout }; - - for i in 0..tuple_place.layout.layout.0.fields.count() { - let field_place = tuple_place.project_field(bx, i); - let field_layout = tuple_place.layout.field(bx, i); - let field_ty = field_layout.ty; - let llvm_ty = field_layout.llvm_type(bx.cx); - - let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align); - - match field_ty.kind() { - ty::Ref(_, inner_ty, _) if matches!(inner_ty.kind(), ty::Slice(_)) => { - let ptr = bx.extract_value(field_val, 0); - let len = bx.extract_value(field_val, 1); - ret_arr.push(ptr); - ret_arr.push(len); + let cx = bx.cx; + let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty()); + + match tuple_op.val { + OperandValue::Immediate(val) => vec![val], + OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Ref(ptr) => { + let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout }; + + let mut result = Vec::with_capacity(fn_abi.args.len()); + let mut tuple_index = 0; + + for arg in &fn_abi.args { + match arg.mode { + PassMode::Ignore => {} + PassMode::Direct(_) | PassMode::Cast { .. } => { + let field = tuple_place.project_field(bx, tuple_index); + let llvm_ty = field.layout.llvm_type(bx.cx); + let val = bx.load(llvm_ty, field.val.llval, field.val.align); + result.push(val); + tuple_index += 1; } - _ => { - ret_arr.push(field_val); + PassMode::Pair(_, _) => { + let field = tuple_place.project_field(bx, tuple_index); + let llvm_ty = field.layout.llvm_type(bx.cx); + let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align); + result.push(bx.extract_value(pair_val, 0)); + result.push(bx.extract_value(pair_val, 1)); + tuple_index += 1; + } + PassMode::Indirect { .. } => { + let field = tuple_place.project_field(bx, tuple_index); + result.push(field.val.llval); + tuple_index += 1; } } } - ret_arr + result } - OperandValue::Pair(v1, v2) => vec![v1, v2], - OperandValue::Immediate(v) => vec![v], - OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + + OperandValue::ZeroSized => bug!("unexpected ZeroSized argument in get_args_from_tuple"), } } From 410867755835439a30f7a20f6387f1b4cf459cc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 13:35:50 +0000 Subject: [PATCH 20/33] Remove dead code --- Cargo.lock | 2 - compiler/rustc_codegen_llvm/src/lib.rs | 6 - compiler/rustc_codegen_ssa/src/back/write.rs | 18 +-- compiler/rustc_codegen_ssa/src/base.rs | 7 +- .../rustc_codegen_ssa/src/codegen_attrs.rs | 2 +- .../rustc_codegen_ssa/src/traits/write.rs | 2 - compiler/rustc_middle/src/mir/mono.rs | 2 - compiler/rustc_monomorphize/Cargo.toml | 2 - .../src/collector/autodiff.rs | 3 +- .../rustc_monomorphize/src/partitioning.rs | 34 +---- .../src/partitioning/autodiff.rs | 143 ------------------ 11 files changed, 9 insertions(+), 212 deletions(-) delete mode 100644 compiler/rustc_monomorphize/src/partitioning/autodiff.rs diff --git a/Cargo.lock b/Cargo.lock index d9cfda17ad929..aa8154420ae13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4307,7 +4307,6 @@ name = "rustc_monomorphize" version = "0.0.0" dependencies = [ "rustc_abi", - "rustc_ast", "rustc_data_structures", "rustc_errors", "rustc_fluent_macro", @@ -4316,7 +4315,6 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", - "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index ca84b6de8b11a..79e80db6f5554 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -30,7 +30,6 @@ use context::SimpleCx; use errors::ParseTargetMachineConfig; use llvm_util::target_config; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -173,15 +172,10 @@ impl WriteBackendMethods for LlvmCodegenBackend { exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_fncs: Vec, ) -> Result, FatalError> { let mut module = back::lto::run_fat(cgcx, exported_symbols_for_lto, each_linked_rlib_for_lto, modules)?; - if !diff_fncs.is_empty() { - builder::autodiff::differentiate(&module, cgcx, diff_fncs)?; - } - let dcx = cgcx.create_dcx(); let dcx = dcx.handle(); back::lto::run_pass_manager(cgcx, dcx, &mut module, false)?; diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index aa29afb7f5b11..2e8122798d169 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -7,7 +7,6 @@ use std::{fs, io, mem, str, thread}; use rustc_abi::Size; use rustc_ast::attr; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::FxIndexMap; use rustc_data_structures::jobserver::{self, Acquired}; use rustc_data_structures::memmap::Mmap; @@ -38,7 +37,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use crate::back::lto::check_lto_allowed; -use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; +use crate::errors::ErrorCreatingRemarkDir; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -454,7 +453,6 @@ pub(crate) fn start_async_codegen( backend: B, tcx: TyCtxt<'_>, target_cpu: String, - autodiff_items: &[AutoDiffItem], ) -> OngoingCodegen { let (coordinator_send, coordinator_receive) = channel(); @@ -473,7 +471,6 @@ pub(crate) fn start_async_codegen( backend.clone(), tcx, &crate_info, - autodiff_items, shared_emitter, codegen_worker_send, coordinator_receive, @@ -728,7 +725,6 @@ pub(crate) enum WorkItem { each_linked_rlib_for_lto: Vec, needs_fat_lto: Vec>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, - autodiff: Vec, }, /// Performs thin-LTO on the given module. ThinLto(lto::ThinModule), @@ -1001,7 +997,6 @@ fn execute_fat_lto_work_item( each_linked_rlib_for_lto: &[PathBuf], mut needs_fat_lto: Vec>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, - autodiff: Vec, module_config: &ModuleConfig, ) -> Result, FatalError> { for (module, wp) in import_only_modules { @@ -1013,7 +1008,6 @@ fn execute_fat_lto_work_item( exported_symbols_for_lto, each_linked_rlib_for_lto, needs_fat_lto, - autodiff, )?; let module = B::codegen(cgcx, module, module_config)?; Ok(WorkItemResult::Finished(module)) @@ -1105,7 +1099,6 @@ fn start_executing_work( backend: B, tcx: TyCtxt<'_>, crate_info: &CrateInfo, - autodiff_items: &[AutoDiffItem], shared_emitter: SharedEmitter, codegen_worker_send: Sender, coordinator_receive: Receiver>, @@ -1115,7 +1108,6 @@ fn start_executing_work( ) -> thread::JoinHandle> { let coordinator_send = tx_to_llvm_workers; let sess = tcx.sess; - let autodiff_items = autodiff_items.to_vec(); let mut each_linked_rlib_for_lto = Vec::new(); let mut each_linked_rlib_file_for_lto = Vec::new(); @@ -1448,7 +1440,6 @@ fn start_executing_work( each_linked_rlib_for_lto: each_linked_rlib_file_for_lto, needs_fat_lto, import_only_modules, - autodiff: autodiff_items.clone(), }, 0, )); @@ -1456,11 +1447,6 @@ fn start_executing_work( helper.request_token(); } } else { - if !autodiff_items.is_empty() { - let dcx = cgcx.create_dcx(); - dcx.handle().emit_fatal(AutodiffWithoutLto {}); - } - for (work, cost) in generate_thin_lto_work( &cgcx, &exported_symbols_for_lto, @@ -1795,7 +1781,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>( each_linked_rlib_for_lto, needs_fat_lto, import_only_modules, - autodiff, } => { let _timer = cgcx .prof @@ -1806,7 +1791,6 @@ fn spawn_work<'a, B: ExtraBackendMethods>( &each_linked_rlib_for_lto, needs_fat_lto, import_only_modules, - autodiff, module_config, ) } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index b4556ced0b3fb..b483c01da59e5 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -647,7 +647,7 @@ pub fn codegen_crate( ) -> OngoingCodegen { // Skip crate items and just output metadata in -Z no-codegen mode. if tcx.sess.opts.unstable_opts.no_codegen || !tcx.sess.opts.output_types.should_codegen() { - let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu, &[]); + let ongoing_codegen = start_async_codegen(backend, tcx, target_cpu); ongoing_codegen.codegen_finished(tcx); @@ -665,8 +665,7 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let MonoItemPartitions { codegen_units, autodiff_items, .. } = - tcx.collect_and_partition_mono_items(()); + let MonoItemPartitions { codegen_units, .. } = tcx.collect_and_partition_mono_items(()); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -679,7 +678,7 @@ pub fn codegen_crate( } } - let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu, autodiff_items); + let ongoing_codegen = start_async_codegen(backend.clone(), tcx, target_cpu); // Codegen an allocator shim, if necessary. if let Some(kind) = allocator_kind_for_codegen(tcx) { diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 7f54a47327af8..3ea48fa219924 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -624,7 +624,7 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option { /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never /// panic, unless we introduced a bug when parsing the autodiff macro. //FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed. -fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { +pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index f391c198e1a10..c29ad90735b7b 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,6 +1,5 @@ use std::path::PathBuf; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -23,7 +22,6 @@ pub trait WriteBackendMethods: Clone + 'static { exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_fncs: Vec, ) -> Result, FatalError>; /// Performs thin LTO by performing necessary global analysis and returning two /// lists, one of the modules that need optimization and another for modules that diff --git a/compiler/rustc_middle/src/mir/mono.rs b/compiler/rustc_middle/src/mir/mono.rs index e5864660575c5..613f0f7a90cd1 100644 --- a/compiler/rustc_middle/src/mir/mono.rs +++ b/compiler/rustc_middle/src/mir/mono.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; use std::fmt; use std::hash::Hash; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::base_n::{BaseNString, CASE_INSENSITIVE, ToBaseN}; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::FxIndexMap; @@ -337,7 +336,6 @@ impl ToStableHashKey> for MonoItem<'_> { pub struct MonoItemPartitions<'tcx> { pub codegen_units: &'tcx [CodegenUnit<'tcx>], pub all_mono_items: &'tcx DefIdSet, - pub autodiff_items: &'tcx [AutoDiffItem], } #[derive(Debug, HashStable)] diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index 0ed5b4fc0d09c..09a55f0b5f8da 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -6,7 +6,6 @@ edition = "2024" [dependencies] # tidy-alphabetical-start rustc_abi = { path = "../rustc_abi" } -rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } @@ -15,7 +14,6 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } -rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs index 3c5d768b79d4b..f388f3779a289 100644 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -18,7 +18,6 @@ pub(crate) fn collect_enzyme_autodiff_fn<'tcx>( }; collect_autodiff_fn_from_arg(instance.args[0], tcx, output); - collect_autodiff_fn_from_arg(instance.args[1], tcx, output); } fn collect_autodiff_fn_from_arg<'tcx>( @@ -27,7 +26,7 @@ fn collect_autodiff_fn_from_arg<'tcx>( output: &mut MonoItems<'tcx>, ) { let (instance, span) = match arg.kind() { - rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() { + ty::GenericArgKind::Type(ty) => match ty.kind() { ty::FnDef(def_id, substs) => { let span = tcx.def_span(def_id); let instance = ty::Instance::expect_resolve( diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index d76b27d9970b6..8aa7c3eea619c 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -92,8 +92,6 @@ //! source-level module, functions from the same module will be available for //! inlining, even when they are not marked `#[inline]`. -mod autodiff; - use std::cmp; use std::collections::hash_map::Entry; use std::fs::{self, File}; @@ -251,17 +249,7 @@ where always_export_generics, ); - // We can't differentiate a function that got inlined. - let autodiff_active = cfg!(llvm_enzyme) - && matches!(mono_item, MonoItem::Fn(_)) - && cx - .tcx - .codegen_fn_attrs(mono_item.def_id()) - .autodiff_item - .as_ref() - .is_some_and(|ad| ad.is_active()); - - if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { + if visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1156,27 +1144,15 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - #[cfg(not(llvm_enzyme))] - let autodiff_mono_items: Vec<_> = vec![]; - #[cfg(llvm_enzyme)] - let mut autodiff_mono_items: Vec<_> = vec![]; let mono_items: DefIdSet = items .iter() .filter_map(|mono_item| match *mono_item { - MonoItem::Fn(ref instance) => { - #[cfg(llvm_enzyme)] - autodiff_mono_items.push((mono_item, instance)); - Some(instance.def_id()) - } + MonoItem::Fn(ref instance) => Some(instance.def_id()), MonoItem::Static(def_id) => Some(def_id), _ => None, }) .collect(); - let autodiff_items = - autodiff::find_autodiff_source_functions(tcx, &usage_map, autodiff_mono_items); - let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); - // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats && let Err(err) = @@ -1234,11 +1210,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> MonoItemPartitio } } - MonoItemPartitions { - all_mono_items: tcx.arena.alloc(mono_items), - codegen_units, - autodiff_items, - } + MonoItemPartitions { all_mono_items: tcx.arena.alloc(mono_items), codegen_units } } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs deleted file mode 100644 index 22d593b80b895..0000000000000 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ /dev/null @@ -1,143 +0,0 @@ -use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; -use rustc_hir::def_id::LOCAL_CRATE; -use rustc_middle::bug; -use rustc_middle::mir::mono::MonoItem; -use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; -use rustc_symbol_mangling::symbol_name_for_instance_in_crate; -use tracing::{debug, trace}; - -use crate::partitioning::UsageMap; - -fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec) { - if !matches!(fn_ty.kind(), ty::FnDef(..)) { - bug!("expected fn def for autodiff, got {:?}", fn_ty); - } - - // We don't actually pass the types back into the type system. - // All we do is decide how to handle the arguments. - let sig = fn_ty.fn_sig(tcx).skip_binder(); - - let mut new_activities = vec![]; - let mut new_positions = vec![]; - for (i, ty) in sig.inputs().iter().enumerate() { - if let Some(inner_ty) = ty.builtin_deref(true) { - if inner_ty.is_slice() { - // Now we need to figure out the size of each slice element in memory to allow - // safety checks and usability improvements in the backend. - let sty = match inner_ty.builtin_index() { - Some(sty) => sty, - None => { - panic!("slice element type unknown"); - } - }; - let pci = PseudoCanonicalInput { - typing_env: TypingEnv::fully_monomorphized(), - value: sty, - }; - - let layout = tcx.layout_of(pci); - let elem_size = match layout { - Ok(layout) => layout.size, - Err(_) => { - bug!("autodiff failed to compute slice element size"); - } - }; - let elem_size: u32 = elem_size.bytes() as u32; - - // We know that the length will be passed as extra arg. - if !da.is_empty() { - // We are looking at a slice. The length of that slice will become an - // extra integer on llvm level. Integers are always const. - // However, if the slice get's duplicated, we want to know to later check the - // size. So we mark the new size argument as FakeActivitySize. - // There is one FakeActivitySize per slice, so for convenience we store the - // slice element size in bytes in it. We will use the size in the backend. - let activity = match da[i] { - DiffActivity::DualOnly - | DiffActivity::Dual - | DiffActivity::Dualv - | DiffActivity::DuplicatedOnly - | DiffActivity::Duplicated => { - DiffActivity::FakeActivitySize(Some(elem_size)) - } - DiffActivity::Const => DiffActivity::Const, - _ => bug!("unexpected activity for ptr/ref"), - }; - new_activities.push(activity); - new_positions.push(i + 1); - } - - continue; - } - } - } - // now add the extra activities coming from slices - // Reverse order to not invalidate the indices - for _ in 0..new_activities.len() { - let pos = new_positions.pop().unwrap(); - let activity = new_activities.pop().unwrap(); - da.insert(pos, activity); - } -} - -pub(crate) fn find_autodiff_source_functions<'tcx>( - tcx: TyCtxt<'tcx>, - usage_map: &UsageMap<'tcx>, - autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>, -) -> Vec { - let mut autodiff_items: Vec = vec![]; - for (item, instance) in autodiff_mono_items { - let target_id = instance.def_id(); - let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item; - let Some(target_attrs) = cg_fn_attr else { - continue; - }; - let mut input_activities: Vec = target_attrs.input_activity.clone(); - if target_attrs.is_source() { - trace!("source found: {:?}", target_id); - } - if !target_attrs.apply_autodiff() { - continue; - } - - let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - - let source = - usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { - MonoItem::Fn(ref instance_s) => { - let source_id = instance_s.def_id(); - if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item - && ad.is_active() - { - return Some(instance_s); - } - None - } - _ => None, - }); - let inst = match source { - Some(source) => source, - None => continue, - }; - - debug!("source_id: {:?}", inst.def_id()); - let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized()); - assert!(fn_ty.is_fn()); - adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); - let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); - - let mut new_target_attrs = target_attrs.clone(); - new_target_attrs.input_activity = input_activities; - let itm = new_target_attrs.into_item(symb, target_symbol); - autodiff_items.push(itm); - } - - if !autodiff_items.is_empty() { - trace!("AUTODIFF ITEMS EXIST"); - for item in &mut *autodiff_items { - trace!("{}", &item); - } - } - - autodiff_items -} From e347e0092a27ec40b7760f5dcf97df074120ce41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 13:48:34 +0000 Subject: [PATCH 21/33] Reintroduce autodiff enable and lto errors --- compiler/rustc_codegen_llvm/src/errors.rs | 3 --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 10 ++++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 2ba84e4622416..627b0c9ff3b33 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -32,12 +32,9 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } -// TODO(Sa4dUs): we will need to reintroduce these errors somewhere -/* #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_enable)] pub(crate) struct AutoDiffWithoutEnable; -*/ #[derive(Diagnostic)] #[diag(codegen_llvm_lto_bitcode_from_rlib)] diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 42f0a4c8966d4..38b348ea73633 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -15,6 +15,7 @@ use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; +use rustc_session::config::Lto; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::callconv::PassMode; @@ -25,6 +26,7 @@ use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; use crate::context::CodegenCx; +use crate::errors::{AutoDiffWithoutEnable, AutoDiffWithoutLTO}; use crate::llvm::{self, Metadata}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; @@ -1128,6 +1130,14 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( args: &[OperandRef<'tcx, &'ll Value>], result: PlaceRef<'tcx, &'ll Value>, ) { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable); + } + + if tcx.sess.lto() != Lto::Fat { + let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutLTO); + } + let fn_args = instance.args; let callee_ty = instance.ty(tcx, bx.typing_env()); From ba32958f6cb37c4878e54634d34f3cacb2ca9929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 16:29:11 +0000 Subject: [PATCH 22/33] Minor fixes after rebase --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 7 +------ triagebot.toml | 3 --- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 38b348ea73633..10a2872baa35b 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -15,7 +15,6 @@ use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; -use rustc_session::config::Lto; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::callconv::PassMode; @@ -26,7 +25,7 @@ use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; use crate::context::CodegenCx; -use crate::errors::{AutoDiffWithoutEnable, AutoDiffWithoutLTO}; +use crate::errors::AutoDiffWithoutEnable; use crate::llvm::{self, Metadata}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; @@ -1134,10 +1133,6 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutEnable); } - if tcx.sess.lto() != Lto::Fat { - let _ = tcx.dcx().emit_almost_fatal(AutoDiffWithoutLTO); - } - let fn_args = instance.args; let callee_ty = instance.ty(tcx, bx.typing_env()); diff --git a/triagebot.toml b/triagebot.toml index 168815465b614..8ad448981bab7 100644 --- a/triagebot.toml +++ b/triagebot.toml @@ -282,7 +282,6 @@ trigger_files = [ "src/tools/enzyme", "src/doc/unstable-book/src/compiler-flags/autodiff.md", "compiler/rustc_ast/src/expand/autodiff_attrs.rs", - "compiler/rustc_monomorphize/src/partitioning/autodiff.rs", "compiler/rustc_codegen_llvm/src/builder/autodiff.rs", "compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs", ] @@ -1278,8 +1277,6 @@ cc = ["@ZuseZ4"] cc = ["@ZuseZ4"] [mentions."compiler/rustc_builtin_macros/src/autodiff.rs"] cc = ["@ZuseZ4"] -[mentions."compiler/rustc_monomorphize/src/partitioning/autodiff.rs"] -cc = ["@ZuseZ4"] [mentions."compiler/rustc_codegen_llvm/src/builder/autodiff.rs"] cc = ["@ZuseZ4"] [mentions."compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs"] From ce2da2bc28b4b7b1391b52f6d2a1bc59adb9ed57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 17:38:26 +0000 Subject: [PATCH 23/33] Add rest of test fixes --- .../src/builder/autodiff.rs | 2 +- tests/codegen-llvm/autodiff/generic.rs | 16 ++-- tests/codegen-llvm/autodiff/identical_fnc.rs | 6 +- tests/codegen-llvm/autodiff/inline.rs | 24 ------ tests/codegen-llvm/autodiffv2.rs | 8 +- tests/pretty/autodiff/autodiff_forward.pp | 79 +++++++++++-------- tests/pretty/autodiff/autodiff_forward.rs | 2 +- tests/pretty/autodiff/autodiff_reverse.pp | 23 +++--- tests/pretty/autodiff/autodiff_reverse.rs | 2 +- tests/pretty/autodiff/inherent_impl.pp | 9 +-- tests/pretty/autodiff/inherent_impl.rs | 2 +- tests/ui/autodiff/autodiff_illegal.rs | 1 + 12 files changed, 82 insertions(+), 92 deletions(-) delete mode 100644 tests/codegen-llvm/autodiff/inline.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 347908c50aae1..56116959a6223 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -276,7 +276,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // %0 = fmul double %x, %x // ret double %0 // } - // ``` + // // define double @dsquare(double %x) { // return 0.0; // } diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index 9553ef3760e74..da17c55b1427f 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -11,23 +11,23 @@ fn square + Copy>(x: &T) -> T { *x * *x } -// Ensure that `d_square::` code is generated even if `square::` was never called +// Ensure that `d_square::` code is generated // // CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: -// CHECK-NEXT: define internal {{.*}} double +// CHECK-NEXT: ; Function Attrs: {{.*}} +// CHECK-NEXT: define internal {{.*}} float // CHECK-NEXT: start: // CHECK-NOT: ret -// CHECK: fmul double +// CHECK: fmul float -// Ensure that `d_square::` code is generated +// Ensure that `d_square::` code is generated even if `square::` was never called // // CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: {{.*}} -// CHECK-NEXT: define internal {{.*}} float +// CHECK-NEXT: ; Function Attrs: +// CHECK-NEXT: define internal {{.*}} double // CHECK-NEXT: start: // CHECK-NOT: ret -// CHECK: fmul float +// CHECK: fmul double fn main() { let xf32: f32 = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index 1b4edf1d954b4..d847a637e7ad9 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -26,12 +26,12 @@ fn square2(x: &f64) -> f64 { // CHECK:; identical_fnc::main // CHECK-NEXT:; Function Attrs: -// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E() +// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E() // CHECK-NEXT:start: // CHECK-NOT:br // CHECK-NOT:ret -// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx1) -// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17hdfa1c645848284b7E(double %x.val, ptr %dx2) +// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx1) +// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx2) fn main() { let x = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/inline.rs b/tests/codegen-llvm/autodiff/inline.rs deleted file mode 100644 index 1aa1b8a912be1..0000000000000 --- a/tests/codegen-llvm/autodiff/inline.rs +++ /dev/null @@ -1,24 +0,0 @@ -//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt -//@ no-prefer-dynamic -//@ needs-enzyme - -#![feature(autodiff)] -#![feature(core_intrinsics)] - -use std::autodiff::autodiff_reverse; - -#[autodiff_reverse(d_square, Duplicated, Active)] -fn square(x: &f64) -> f64 { - x * x -} - -// CHECK: ; inline::d_square -// CHECK-NEXT: ; Function Attrs: alwaysinline -// CHECK-NOT: noinline -// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE -fn main() { - let x = std::hint::black_box(3.0); - let mut dx1 = std::hint::black_box(1.0); - let _ = d_square(&x, &mut dx1, 1.0); - assert_eq!(dx1, 6.0); -} diff --git a/tests/codegen-llvm/autodiffv2.rs b/tests/codegen-llvm/autodiffv2.rs index a40d19d3be3a8..ca34d67365084 100644 --- a/tests/codegen-llvm/autodiffv2.rs +++ b/tests/codegen-llvm/autodiffv2.rs @@ -25,13 +25,15 @@ // in our frontend and in the llvm backend to avoid these issues. #![feature(autodiff)] +#![feature(core_intrinsics)] -use std::autodiff::autodiff; +use std::autodiff::autodiff_forward; +// CHECK: ; #[no_mangle] //#[autodiff(d_square1, Forward, Dual, Dual)] -#[autodiff(d_square2, Forward, 4, Dualv, Dualv)] -#[autodiff(d_square3, Forward, 4, Dual, Dual)] +#[autodiff_forward(d_square2, 4, Dualv, Dualv)] +#[autodiff_forward(d_square3, 4, Dual, Dual)] fn square(x: &[f32], y: &mut [f32]) { assert!(x.len() >= 4); assert!(y.len() >= 5); diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 787c2e517492c..89305a1d9351a 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,7 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -37,44 +37,49 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[rustc_intrinsic] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); +pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { + std::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y)) +} #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[rustc_intrinsic] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y)) +} #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[rustc_intrinsic] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y)) +} #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[rustc_intrinsic] -pub fn df4() -> (); +pub fn df4() -> () { std::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ()) } #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[rustc_intrinsic] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f5::<>, df5_y::<>, (x, y, by_0)) +} #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[rustc_intrinsic] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f5::<>, df5_x::<>, (x, bx_0, y)) +} #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[rustc_intrinsic] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret)) +} struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -82,47 +87,55 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[rustc_intrinsic] -pub fn df6() -> DoesNotImplDefault; +pub fn df6() -> DoesNotImplDefault { + std::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ()) +} #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[rustc_intrinsic] -pub fn df7(x: f32) -> (); +pub fn df7(x: f32) -> () { + std::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,)) +} #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) --> [f32; 5usize]; + -> [f32; 5usize] { + std::intrinsics::enzyme_autodiff(f8::<>, f8_3::<>, + (x, bx_0, bx_1, bx_2, bx_3)) +} #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) --> [f32; 4usize]; + -> [f32; 4usize] { + std::intrinsics::enzyme_autodiff(f8::<>, f8_2::<>, + (x, bx_0, bx_1, bx_2, bx_3)) +} #[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[rustc_intrinsic] -fn f8_1(x: &f32, bx_0: &f32) -> f32; +fn f8_1(x: &f32, bx_0: &f32) -> f32 { + std::intrinsics::enzyme_autodiff(f8::<>, f8_1::<>, (x, bx_0)) +} pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[rustc_intrinsic] - fn d_inner_2(x: f32, bx_0: f32) - -> (f32, f32); + fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { + std::intrinsics::enzyme_autodiff(inner::<>, d_inner_2::<>, (x, bx_0)) + } #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[rustc_intrinsic] - fn d_inner_1(x: f32, bx_0: f32) - -> f32; + fn d_inner_1(x: f32, bx_0: f32) -> f32 { + std::intrinsics::enzyme_autodiff(inner::<>, d_inner_1::<>, (x, bx_0)) + } } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[rustc_intrinsic] pub fn d_square + -Copy>(x: &T, dx_0: &mut T, dret: T) -> T; + Copy>(x: &T, dx_0: &mut T, dret: T) -> T { + std::intrinsics::enzyme_autodiff(f10::, d_square::, (x, dx_0, dret)) +} fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index b003d87dccfa7..e763b6382d5ab 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,7 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index 6f368c74f1a26..a79fd3ada978a 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,7 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -30,36 +30,37 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[rustc_intrinsic] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, dx_0, y, dret)) +} #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[rustc_intrinsic] -pub fn df2(); +pub fn df2() { std::intrinsics::enzyme_autodiff(f2::<>, df2::<>, ()) } #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[rustc_intrinsic] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { + std::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, dx_0, y, dret)) +} enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[rustc_intrinsic] -pub fn df4(x: f32); +pub fn df4(x: f32) { std::intrinsics::enzyme_autodiff(f4::<>, df4::<>, (x,)) } #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[rustc_intrinsic] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { + std::intrinsics::enzyme_autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0)) +} fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index fc95ba2e5a63e..be5a38401f13e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,7 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index 4bc8dac0dc758..f9e91955d7283 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,7 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -32,11 +32,8 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(self.f(x)); - ::core::hint::black_box((dret,)); - ::core::hint::black_box((self.f(x), f64::default())) + std::intrinsics::enzyme_autodiff(Self::f::<>, Self::df::<>, + (self, x, dret)) } } diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 9f00ff5eb02c1..4941bb75d8259 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,7 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(intrinsics)] +#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp diff --git a/tests/ui/autodiff/autodiff_illegal.rs b/tests/ui/autodiff/autodiff_illegal.rs index a53b6d5e58981..1b0651f0d05b6 100644 --- a/tests/ui/autodiff/autodiff_illegal.rs +++ b/tests/ui/autodiff/autodiff_illegal.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_illegal.pp From fde1300fd33512548be808134e8241d5e66e4bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 17:55:34 +0000 Subject: [PATCH 24/33] Remove cfg enzyme for ci --- compiler/rustc_monomorphize/src/collector.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 74e67165afa79..92a976258a35d 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -239,7 +239,6 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan}; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; -#[cfg(llvm_enzyme)] use crate::collector::autodiff::collect_enzyme_autodiff_fn; use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit}; @@ -917,7 +916,6 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { - #[cfg(llvm_enzyme)] collect_enzyme_autodiff_fn(tcx, instance, intrinsic, output); if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { From 886559b17e373dd72449f5c7b4702a2c0a8a695d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 25 Jul 2025 18:19:51 +0000 Subject: [PATCH 25/33] FIx cg_gcc --- compiler/rustc_codegen_gcc/src/lib.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/compiler/rustc_codegen_gcc/src/lib.rs b/compiler/rustc_codegen_gcc/src/lib.rs index 613315f77a6b3..9e08d8f297223 100644 --- a/compiler/rustc_codegen_gcc/src/lib.rs +++ b/compiler/rustc_codegen_gcc/src/lib.rs @@ -92,7 +92,6 @@ use gccjit::{CType, Context, OptimizationLevel}; #[cfg(feature = "master")] use gccjit::{TargetInfo, Version}; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn, @@ -362,12 +361,7 @@ impl WriteBackendMethods for GccCodegenBackend { _exported_symbols_for_lto: &[String], each_linked_rlib_for_lto: &[PathBuf], modules: Vec>, - diff_fncs: Vec, ) -> Result, FatalError> { - if !diff_fncs.is_empty() { - unimplemented!(); - } - back::lto::run_fat(cgcx, each_linked_rlib_for_lto, modules) } From e8e00813e2eea5d32dcdc288c26418a1210f0647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 27 Jul 2025 11:26:05 +0000 Subject: [PATCH 26/33] Macro expansion cleanup --- compiler/rustc_builtin_macros/src/autodiff.rs | 157 +++--------------- tests/pretty/autodiff/autodiff_forward.pp | 36 ++-- tests/pretty/autodiff/autodiff_reverse.pp | 12 +- tests/pretty/autodiff/inherent_impl.pp | 2 +- 4 files changed, 56 insertions(+), 151 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 1ab6b9c8572d3..e78b64b63a199 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -21,7 +21,7 @@ mod llvm_enzyme { MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; - use rustc_span::{Ident, Span, Symbol, kw, sym}; + use rustc_span::{Ident, Span, Symbol, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; @@ -183,11 +183,8 @@ mod llvm_enzyme { } /// We expand the autodiff macro to generate a new placeholder function which passes - /// type-checking and can be called by users. The function body of the placeholder function will - /// later be replaced on LLVM-IR level, so the design of the body is less important and for now - /// should just prevent early inlining and optimizations which alter the function signature. - /// The exact signature of the generated function depends on the configuration provided by the - /// user, but here is an example: + /// type-checking and can be called by users. The exact signature of the generated function + /// depends on the configuration provided by the user, but here is an example: /// /// ``` /// #[autodiff(cos_box, Reverse, Duplicated, Active)] @@ -203,14 +200,8 @@ mod llvm_enzyme { /// f32::sin(**x) /// } /// #[rustc_autodiff(Reverse, Duplicated, Active)] - /// #[inline(never)] /// fn cos_box(x: &Box, dx: &mut Box, dret: f32) -> f32 { - /// unsafe { - /// asm!("NOP"); - /// }; - /// ::core::hint::black_box(sin(x)); - /// ::core::hint::black_box((dx, dret)); - /// ::core::hint::black_box(sin(x)) + /// std::intrinsics::enzyme_autodiff(sin::<>, cos_box::<>, (x, dx, dret)) /// } /// ``` /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked @@ -330,22 +321,20 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let d_sig = gen_enzyme_decl(ecx, &sig, &x, span); let d_body = gen_enzyme_body( ecx, &d_sig, primal, span, - idents, - errored, first_ident(&meta_item_vec[0]), &generics, impl_of_trait, ); // The first element of it is the name of the function to be generated - let asdf = Box::new(ast::Fn { + let d_fn = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), @@ -442,7 +431,7 @@ mod llvm_enzyme { let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { - let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); + let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn); let d_fn = P(ast::AssocItem { attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, @@ -454,13 +443,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(d_fn)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { @@ -525,14 +514,8 @@ mod llvm_enzyme { .into(), ); - let enzyme_path = ecx.path( - span, - vec![ - Ident::from_str("std"), - Ident::from_str("intrinsics"), - Ident::with_dummy_span(sym::enzyme_autodiff), - ], - ); + let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::enzyme_autodiff]); + let enzyme_path = ecx.path(span, enzyme_path_idents); let call_expr = ecx.expr_call( span, ecx.expr_path(enzyme_path), @@ -591,25 +574,6 @@ mod llvm_enzyme { ecx.expr_path(path) } - // Will generate a body of the type: - // ``` - // primal(args); - // std::intrinsics::enzyme_autodiff(primal, diff, (args)) - // } - // ``` - fn init_body_helper( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - idents: &[Ident], - _errored: bool, - generics: &Generics, - ) -> P { - let _primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let body = ecx.block(span, ThinVec::new()); - body - } - /// We only want this function to type-check, since we will replace the body /// later on llvm level. Using `loop {}` does not cover all return types anymore, /// so instead we manually build something that should pass the type checker. @@ -623,8 +587,6 @@ mod llvm_enzyme { d_sig: &ast::FnSig, primal: Ident, span: Span, - idents: Vec, - errored: bool, diff_ident: Ident, generics: &Generics, is_impl: bool, @@ -633,87 +595,22 @@ mod llvm_enzyme { // Add a call to the primal function to prevent it from being inlined // and call `enzyme_autodiff` intrinsic (this also covers the return type) - let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics); - - body.stmts.push(call_enzyme_autodiff( - ecx, - primal, - diff_ident, - new_decl_span, - d_sig, - generics, - is_impl, - )); + let body = ecx.block( + span, + thin_vec![call_enzyme_autodiff( + ecx, + primal, + diff_ident, + new_decl_span, + d_sig, + generics, + is_impl, + )], + ); body } - fn gen_primal_call( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - idents: &[Ident], - generics: &Generics, - ) -> P { - let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower; - - if has_self { - let args: ThinVec<_> = - idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let self_expr = ecx.expr_self(span); - ecx.expr_method_call(span, self_expr, primal, args) - } else { - let args: ThinVec<_> = - idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect(); - let mut primal_path = ecx.path_ident(span, primal); - - let is_generic = !generics.params.is_empty(); - - match (is_generic, primal_path.segments.last_mut()) { - (true, Some(function_path)) => { - let primal_generic_types = generics - .params - .iter() - .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. })); - - let generated_generic_types = primal_generic_types - .map(|type_param| { - let generic_param = TyKind::Path( - None, - ast::Path { - span, - segments: thin_vec![ast::PathSegment { - ident: type_param.ident, - args: None, - id: ast::DUMMY_NODE_ID, - }], - tokens: None, - }, - ); - - ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty { - id: type_param.id, - span, - kind: generic_param, - tokens: None, - }))) - }) - .collect(); - - function_path.args = - Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs { - span, - args: generated_generic_types, - }))); - } - _ => {} - } - - let primal_call_expr = ecx.expr_path(primal_path); - ecx.expr_call(span, primal_call_expr, args) - } - } - // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer. // Active arguments must be scalars. Their shadow argument is added to the return type (and will be @@ -730,7 +627,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec, bool) { + ) -> ast::FnSig { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -742,7 +639,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], true); + return sig.clone(); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -785,7 +682,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. - return (sig.clone(), idents, true); + return sig.clone(); } let unsafe_activities = x @@ -993,7 +890,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, idents, false) + d_sig } } diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 89305a1d9351a..baf774f1fb911 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -38,7 +38,7 @@ } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - std::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y)) + ::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y)) } #[rustc_autodiff] #[inline(never)] @@ -47,7 +47,7 @@ } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y)) + ::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y)) } #[rustc_autodiff] #[inline(never)] @@ -56,13 +56,15 @@ } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y)) + ::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y)) } #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -pub fn df4() -> () { std::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ()) } +pub fn df4() -> () { + ::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ()) +} #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { @@ -70,15 +72,16 @@ } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f5::<>, df5_y::<>, (x, y, by_0)) + ::core::intrinsics::enzyme_autodiff(f5::<>, df5_y::<>, (x, y, by_0)) } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f5::<>, df5_x::<>, (x, bx_0, y)) + ::core::intrinsics::enzyme_autodiff(f5::<>, df5_x::<>, (x, bx_0, y)) } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f5::<>, df5_rev::<>, (x, dx_0, y, dret)) + ::core::intrinsics::enzyme_autodiff(f5::<>, df5_rev::<>, + (x, dx_0, y, dret)) } struct DoesNotImplDefault; #[rustc_autodiff] @@ -88,14 +91,14 @@ } #[rustc_autodiff(Forward, 1, Const)] pub fn df6() -> DoesNotImplDefault { - std::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ()) + ::core::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ()) } #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] pub fn df7(x: f32) -> () { - std::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,)) + ::core::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,)) } #[no_mangle] #[rustc_autodiff] @@ -104,18 +107,18 @@ #[rustc_autodiff(Forward, 4, Dual, Dual)] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) -> [f32; 5usize] { - std::intrinsics::enzyme_autodiff(f8::<>, f8_3::<>, + ::core::intrinsics::enzyme_autodiff(f8::<>, f8_3::<>, (x, bx_0, bx_1, bx_2, bx_3)) } #[rustc_autodiff(Forward, 4, Dual, DualOnly)] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) -> [f32; 4usize] { - std::intrinsics::enzyme_autodiff(f8::<>, f8_2::<>, + ::core::intrinsics::enzyme_autodiff(f8::<>, f8_2::<>, (x, bx_0, bx_1, bx_2, bx_3)) } #[rustc_autodiff(Forward, 1, Dual, DualOnly)] fn f8_1(x: &f32, bx_0: &f32) -> f32 { - std::intrinsics::enzyme_autodiff(f8::<>, f8_1::<>, (x, bx_0)) + ::core::intrinsics::enzyme_autodiff(f8::<>, f8_1::<>, (x, bx_0)) } pub fn f9() { #[rustc_autodiff] @@ -123,11 +126,13 @@ fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - std::intrinsics::enzyme_autodiff(inner::<>, d_inner_2::<>, (x, bx_0)) + ::core::intrinsics::enzyme_autodiff(inner::<>, d_inner_2::<>, + (x, bx_0)) } #[rustc_autodiff(Forward, 1, Dual, DualOnly)] fn d_inner_1(x: f32, bx_0: f32) -> f32 { - std::intrinsics::enzyme_autodiff(inner::<>, d_inner_1::<>, (x, bx_0)) + ::core::intrinsics::enzyme_autodiff(inner::<>, d_inner_1::<>, + (x, bx_0)) } } #[rustc_autodiff] @@ -136,6 +141,7 @@ #[rustc_autodiff(Reverse, 1, Duplicated, Active)] pub fn d_square + Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - std::intrinsics::enzyme_autodiff(f10::, d_square::, (x, dx_0, dret)) + ::core::intrinsics::enzyme_autodiff(f10::, d_square::, + (x, dx_0, dret)) } fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index a79fd3ada978a..476752be4bde6 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -31,13 +31,13 @@ } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, dx_0, y, dret)) + ::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, dx_0, y, dret)) } #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -pub fn df2() { std::intrinsics::enzyme_autodiff(f2::<>, df2::<>, ()) } +pub fn df2() { ::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, ()) } #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { @@ -45,7 +45,7 @@ } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - std::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, dx_0, y, dret)) + ::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, dx_0, y, dret)) } enum Foo { Reverse, } use Foo::Reverse; @@ -53,7 +53,9 @@ #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -pub fn df4(x: f32) { std::intrinsics::enzyme_autodiff(f4::<>, df4::<>, (x,)) } +pub fn df4(x: f32) { + ::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, (x,)) +} #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { @@ -61,6 +63,6 @@ } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - std::intrinsics::enzyme_autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0)) + ::core::intrinsics::enzyme_autodiff(f5::<>, df5::<>, (x, dx_0, y, dy_0)) } fn main() {} diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index f9e91955d7283..31e2f9f799fae 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -33,7 +33,7 @@ } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] fn df(&self, x: f64, dret: f64) -> (f64, f64) { - std::intrinsics::enzyme_autodiff(Self::f::<>, Self::df::<>, + ::core::intrinsics::enzyme_autodiff(Self::f::<>, Self::df::<>, (self, x, dret)) } } From 333069278a4016bb8e5808a363ed23d9a6aa06ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 27 Jul 2025 15:08:49 +0000 Subject: [PATCH 27/33] Remove dead code --- compiler/rustc_codegen_llvm/src/context.rs | 5 ----- compiler/rustc_codegen_ssa/src/codegen_attrs.rs | 8 -------- compiler/rustc_middle/src/middle/codegen_fn_attrs.rs | 4 ---- 3 files changed, 17 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 25e72ecc1cbe6..54f45d1649ddd 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -8,7 +8,6 @@ use std::str; use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx}; use rustc_codegen_ssa::back::versioned_llvm_target; use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh}; -use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::errors as ssa_errors; use rustc_codegen_ssa::traits::*; use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN}; @@ -654,10 +653,6 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { - assert_eq!(self.type_kind(ty), TypeKind::Function); - unsafe { llvm::LLVMGetReturnType(ty) } - } pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type { unsafe { llvm::LLVMGlobalGetValueType(val) } } diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 3ea48fa219924..cfd314330e881 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -210,14 +210,6 @@ fn process_builtin_attrs( let mut interesting_spans = InterestingAttributeDiagnosticSpans::default(); let rust_target_features = tcx.rust_target_features(LOCAL_CRATE); - // If our rustc version supports autodiff/enzyme, then we call our handler - // to check for any `#[rustc_autodiff(...)]` attributes. - // FIXME(jdonszelmann): merge with loop below - if cfg!(llvm_enzyme) { - let ad = autodiff_attrs(tcx, did.into()); - codegen_fn_attrs.autodiff_item = ad; - } - for attr in attrs.iter() { if let hir::Attribute::Parsed(p) = attr { match p { diff --git a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs index 94384e64afd15..3b290d90da7a8 100644 --- a/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs +++ b/compiler/rustc_middle/src/middle/codegen_fn_attrs.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use rustc_abi::Align; -use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_hir::attrs::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_macros::{HashStable, TyDecodable, TyEncodable}; use rustc_span::Symbol; @@ -75,8 +74,6 @@ pub struct CodegenFnAttrs { /// The `#[patchable_function_entry(...)]` attribute. Indicates how many nops should be around /// the function entry. pub patchable_function_entry: Option, - /// For the `#[autodiff]` macros. - pub autodiff_item: Option, } #[derive(Copy, Clone, Debug, TyEncodable, TyDecodable, HashStable)] @@ -182,7 +179,6 @@ impl CodegenFnAttrs { instruction_set: None, alignment: None, patchable_function_entry: None, - autodiff_item: None, } } From c56d57f8233da9aa773bf0c90dd4340b4390d650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 28 Jul 2025 07:22:52 +0000 Subject: [PATCH 28/33] Add `enzyme_autodiff` doc comment --- library/core/src/intrinsics/mod.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index e209930bbcb35..95ee692beff1d 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3163,6 +3163,17 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +/// Generates the LLVM body for the automatic differentiation of `f` using Enzyme, +/// with `df` as the derivative function and `args` as its arguments. +/// +/// Used internally as the body of `df` when expanding the `#[autodiff_forward]` +/// and `#[autodiff_reverse]` attribute macros. +/// +/// Type Parameters: +/// - `F`: The original function to differentiate. Must be a function item. +/// - `G`: The derivative function. Must be a function item. +/// - `T`: A tuple of arguments passed to `df`. +/// - `R`: The return type of the derivative function. #[rustc_nounwind] #[rustc_intrinsic] pub const fn enzyme_autodiff(f: F, df: G, args: T) -> R; From 7f0aca5e8a1c4e1610ae66528920a459fade0890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 28 Jul 2025 16:30:21 +0000 Subject: [PATCH 29/33] Add expansion example to intrinsic docs --- library/core/src/intrinsics/mod.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 95ee692beff1d..eb4dd74c697c5 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3174,6 +3174,29 @@ pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; /// - `G`: The derivative function. Must be a function item. /// - `T`: A tuple of arguments passed to `df`. /// - `R`: The return type of the derivative function. +/// +/// This shows where the `enzyme_autodiff` intrinsic is used during macro expansion: +/// +/// ```rust,ignore (macro example) +/// #[autodiff_forward(df1, Dual, Const, Dual)] +/// pub fn f1(x: &[f64], y: f64) -> f64 { +/// unimplemented!() +/// } +/// ``` +/// +/// expands to: +/// +/// ```rust,ignore (macro example) +/// #[rustc_autodiff] +/// #[inline(never)] +/// pub fn f1(x: &[f64], y: f64) -> f64 { +/// ::core::panicking::panic("not implemented") +/// } +/// #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] +/// pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { +/// ::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y)) +/// } +/// ``` #[rustc_nounwind] #[rustc_intrinsic] pub const fn enzyme_autodiff(f: F, df: G, args: T) -> R; From 5673cb838cc3e66e18541843028d0d5706d84dd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 29 Jul 2025 12:46:22 +0000 Subject: [PATCH 30/33] Better error handling --- compiler/rustc_builtin_macros/src/autodiff.rs | 14 +++--- compiler/rustc_codegen_llvm/src/intrinsic.rs | 46 +++++++++++++++---- .../src/collector/autodiff.rs | 4 +- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index e78b64b63a199..7a9d27afe3a22 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -74,12 +74,10 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info( - iitem: &P, - ) -> Option<(Visibility, FnSig, Ident, Generics, bool)> { + fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { match &iitem.kind { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false)) + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) } _ => None, } @@ -223,9 +221,13 @@ mod llvm_enzyme { // parameters. // these will be used to generate the differentiated version of the function let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { - Annotatable::Item(iitem) => extract_item_info(iitem), + Annotatable::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } Annotatable::Stmt(stmt) => match &stmt.kind { - ast::StmtKind::Item(iitem) => extract_item_info(iitem), + ast::StmtKind::Item(iitem) => { + extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false)) + } _ => None, }, Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 10a2872baa35b..b1b55236cee85 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1145,25 +1145,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( // Get source, diff, and attrs let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { ty::FnDef(def_id, source_params) => (def_id, source_params), - _ => bug!("invalid args"), + _ => bug!("invalid autodiff intrinsic args"), + }; + + let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific autodiff instance", + source_id, + source_args + ), + Err(_) => { + // An error has already been emitted + return; + } }; - let fn_source = - Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap(); + let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); - let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); - let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else { + bug!("could not find source function") + }; let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { ty::FnDef(def_id, diff_args) => (def_id, diff_args), _ => bug!("invalid args"), }; - let fn_diff = - Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); - let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff); + + let fn_diff = match Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific autodiff instance", + diff_id, + diff_args + ), + Err(_) => { + // An error has already been emitted + return; + } + }; + + let val_arr = get_args_from_tuple(bx, args[2], fn_diff); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); - let Some(mut diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else { + bug!("could not find autodiff attrs") + }; adjust_activity_to_abi( tcx, diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs index f388f3779a289..37b1375ed446d 100644 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -39,9 +39,9 @@ fn collect_autodiff_fn_from_arg<'tcx>( (instance, span) } - _ => bug!("expected function"), + _ => bug!("expected autodiff function"), }, - _ => bug!("expected type"), + _ => bug!("expected type when matching autodiff arg"), }; output.push(create_fn_mono_item(tcx, instance, span)); From 68df0653056afea2ebccb91367817e50b5be4df2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 29 Jul 2025 17:58:27 +0000 Subject: [PATCH 31/33] Allow `core_intrinsics` when `autodiff` is enabled --- library/core/src/macros/mod.rs | 2 ++ tests/codegen-llvm/{ => autodiff}/autodiffv2.rs | 1 - tests/codegen-llvm/autodiff/batched.rs | 1 - tests/codegen-llvm/autodiff/generic.rs | 1 - tests/codegen-llvm/autodiff/identical_fnc.rs | 1 - tests/codegen-llvm/autodiff/scalar.rs | 1 - tests/codegen-llvm/autodiff/sret.rs | 1 - tests/codegen-llvm/autodiff/trait.rs | 1 - tests/pretty/autodiff/autodiff_forward.pp | 1 - tests/pretty/autodiff/autodiff_forward.rs | 1 - tests/pretty/autodiff/autodiff_reverse.pp | 1 - tests/pretty/autodiff/autodiff_reverse.rs | 1 - tests/pretty/autodiff/inherent_impl.pp | 1 - tests/pretty/autodiff/inherent_impl.rs | 1 - tests/ui/autodiff/autodiff_illegal.rs | 1 - 15 files changed, 2 insertions(+), 14 deletions(-) rename tests/codegen-llvm/{ => autodiff}/autodiffv2.rs (99%) diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index 3d57da63683b2..061427285631d 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1491,6 +1491,7 @@ pub(crate) mod builtin { /// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities. #[unstable(feature = "autodiff", issue = "124509")] #[allow_internal_unstable(rustc_attrs)] + #[allow_internal_unstable(core_intrinsics)] #[rustc_builtin_macro] pub macro autodiff_forward($item:item) { /* compiler built-in */ @@ -1509,6 +1510,7 @@ pub(crate) mod builtin { /// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities. #[unstable(feature = "autodiff", issue = "124509")] #[allow_internal_unstable(rustc_attrs)] + #[allow_internal_unstable(core_intrinsics)] #[rustc_builtin_macro] pub macro autodiff_reverse($item:item) { /* compiler built-in */ diff --git a/tests/codegen-llvm/autodiffv2.rs b/tests/codegen-llvm/autodiff/autodiffv2.rs similarity index 99% rename from tests/codegen-llvm/autodiffv2.rs rename to tests/codegen-llvm/autodiff/autodiffv2.rs index ca34d67365084..85aed6a183b63 100644 --- a/tests/codegen-llvm/autodiffv2.rs +++ b/tests/codegen-llvm/autodiff/autodiffv2.rs @@ -25,7 +25,6 @@ // in our frontend and in the llvm backend to avoid these issues. #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_forward; diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index 5e94c7bb9b8e6..665bebf17d53f 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -10,7 +10,6 @@ // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_forward; diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index da17c55b1427f..995fdf5d90ed5 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -2,7 +2,6 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index d847a637e7ad9..894906f067ba7 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -10,7 +10,6 @@ // We also explicetly test that we keep running merge_function after AD, by checking for two // identical function calls in the LLVM-IR, while having two different calls in the Rust code. #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index 745b03ee0ed8f..096b4209e84ad 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -2,7 +2,6 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index e2272fd4df7d3..d8451b0eb2d5a 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -8,7 +8,6 @@ // We therefore use this test to verify some of our sret handling. #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen-llvm/autodiff/trait.rs b/tests/codegen-llvm/autodiff/trait.rs index 988e9145087b2..701f3a9e843bd 100644 --- a/tests/codegen-llvm/autodiff/trait.rs +++ b/tests/codegen-llvm/autodiff/trait.rs @@ -5,7 +5,6 @@ // Just check it does not crash for now // CHECK: ; #![feature(autodiff)] -#![feature(core_intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index baf774f1fb911..4b24feea362d7 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,7 +3,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e763b6382d5ab..e23a1b3e241e9 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,7 +1,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index 476752be4bde6..c4d87e6d47767 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,7 +3,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index be5a38401f13e..c50b81d7780d0 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,7 +1,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index 31e2f9f799fae..dd551df2a2caa 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,7 +3,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 4941bb75d8259..11ff209f9d89e 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,7 +1,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp diff --git a/tests/ui/autodiff/autodiff_illegal.rs b/tests/ui/autodiff/autodiff_illegal.rs index 1b0651f0d05b6..a53b6d5e58981 100644 --- a/tests/ui/autodiff/autodiff_illegal.rs +++ b/tests/ui/autodiff/autodiff_illegal.rs @@ -1,7 +1,6 @@ //@ needs-enzyme #![feature(autodiff)] -#![feature(core_intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_illegal.pp From b79b337d76e69bab6ba6bf839b64de8ecf1aa1c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 30 Jul 2025 07:18:16 +0000 Subject: [PATCH 32/33] Remove autodiff limitations subsection --- src/doc/rustc-dev-guide/src/SUMMARY.md | 1 - .../src/autodiff/limitations.md | 27 ------------------- 2 files changed, 28 deletions(-) delete mode 100644 src/doc/rustc-dev-guide/src/autodiff/limitations.md diff --git a/src/doc/rustc-dev-guide/src/SUMMARY.md b/src/doc/rustc-dev-guide/src/SUMMARY.md index e3c0d50fcc737..8518f0033d068 100644 --- a/src/doc/rustc-dev-guide/src/SUMMARY.md +++ b/src/doc/rustc-dev-guide/src/SUMMARY.md @@ -107,7 +107,6 @@ - [Installation](./autodiff/installation.md) - [How to debug](./autodiff/debugging.md) - [Autodiff flags](./autodiff/flags.md) - - [Current limitations](./autodiff/limitations.md) # Source Code Representation diff --git a/src/doc/rustc-dev-guide/src/autodiff/limitations.md b/src/doc/rustc-dev-guide/src/autodiff/limitations.md deleted file mode 100644 index 90afbd51f3fd9..0000000000000 --- a/src/doc/rustc-dev-guide/src/autodiff/limitations.md +++ /dev/null @@ -1,27 +0,0 @@ -# Current limitations - -## Safety and Soundness - -Enzyme currently assumes that the user passes shadow arguments (`dx`, `dy`, ...) of appropriate size. Under Reverse Mode, we additionally assume that shadow arguments are mutable. In Reverse Mode we adjust the outermost pointer or reference to be mutable. Therefore `&f32` will receive the shadow type `&mut f32`. However, we do not check length for other types than slices (e.g. enums, Vec). We also do not enforce mutability of inner references, but will warn if we recognize them. We do intend to add additional checks over time. - -## ABI adjustments - -In some cases, a function parameter might get lowered in a way that we currently don't handle correctly, leading to a compile time type mismatch in the `rustc_codegen_llvm` backend. Here are some [examples](https://github.com/EnzymeAD/rust/issues/105). - -## Compile Times - -Enzyme will often achieve excellent runtime performance, but might increase your compile time by a large factor. For Rust, we already have made significant improvements and have a list of further improvements planed - please reach out if you have time to help here. - -### Type Analysis - -Most of the times, Type Analysis (TA) is the reason of large (>5x) compile time increases when using Enzyme. This poster explains why we need to run Type Analysis in the bottom left part: [Poster Link](https://c.wsmoses.com/posters/Enzyme-llvmdev.pdf). - -We intend to increase the number of locations where we pass down Type information based on Rust types, which in turn will reduce the number of locations where Enzyme has to run Type Analysis, which will help compile times. - -### Duplicated Optimizations - -The key reason for Enzyme offering often excellent performance is that Enzyme differentiates already optimized LLVM-IR. However, we also (have to) run LLVM's optimization pipeline after differentiating, to make sure that the code which Enzyme generates is optimized properly. As a result you should have excellent runtime performance (please fill an issue if not), but at a compile time cost for running optimizations twice. - -### Fat-LTO - -The usage of `#[autodiff(...)]` currently requires compiling your project with Fat-LTO. We technically only need LTO if the function being differentiated calls functions in other compilation units. Therefore, other solutions are possible, but this is the most simple one to get started. From b7d7b733573daaa9f038e35afd8a47493e1f3a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 30 Jul 2025 13:21:30 +0000 Subject: [PATCH 33/33] Remove inlining for autodiff handling --- compiler/rustc_builtin_macros/src/autodiff.rs | 30 +------ .../src/builder/autodiff.rs | 8 +- tests/codegen-llvm/autodiff/batched.rs | 89 +++++++------------ tests/codegen-llvm/autodiff/generic.rs | 16 +--- tests/codegen-llvm/autodiff/identical_fnc.rs | 16 ++-- tests/codegen-llvm/autodiff/scalar.rs | 13 +-- tests/codegen-llvm/autodiff/sret.rs | 21 ++--- tests/pretty/autodiff/autodiff_forward.pp | 10 --- tests/pretty/autodiff/autodiff_reverse.pp | 5 -- tests/pretty/autodiff/inherent_impl.pp | 1 - 10 files changed, 58 insertions(+), 151 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 7a9d27afe3a22..8622df657eda9 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -348,28 +348,10 @@ mod llvm_enzyme { let mut rustc_ad_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff))); - let ts2: Vec = vec![TokenTree::Token( - Token::new(TokenKind::Ident(sym::never, false.into()), span), - Spacing::Joint, - )]; - let never_arg = ast::DelimArgs { - dspan: DelimSpan::from_single(span), - delim: ast::token::Delimiter::Parenthesis, - tokens: TokenStream::from_iter(ts2), - }; - let inline_item = ast::AttrItem { - unsafety: ast::Safety::Default, - path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)), - args: ast::AttrArgs::Delimited(never_arg), - tokens: None, - }; - let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None }); let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let attr = outer_normal_attr(&rustc_ad_attr, new_id, span); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let inline_never = outer_normal_attr(&inline_never_attr, new_id, span); - // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`. + // We're avoid duplicating the attribute `#[rustc_autodiff]`. fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool { match (attr, item) { (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => { @@ -388,18 +370,12 @@ mod llvm_enzyme { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } - if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { - iitem.attrs.push(inline_never.clone()); - } Annotatable::Item(iitem.clone()) } Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => { if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { assoc_item.attrs.push(attr); } - if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) { - assoc_item.attrs.push(inline_never.clone()); - } Annotatable::AssocItem(assoc_item.clone(), i) } Annotatable::Stmt(ref mut stmt) => { @@ -408,10 +384,6 @@ mod llvm_enzyme { if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) { iitem.attrs.push(attr); } - if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) - { - iitem.attrs.push(inline_never.clone()); - } } _ => unreachable!("stmt kind checked previously"), }; diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 56116959a6223..e2df3265f6f7d 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -10,10 +10,9 @@ use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::llvm::AttributePlace::Function; +use crate::llvm; use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{attributes, llvm}; pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, @@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( enzyme_ty, ); - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); diff --git a/tests/codegen-llvm/autodiff/batched.rs b/tests/codegen-llvm/autodiff/batched.rs index 665bebf17d53f..2104d2028493a 100644 --- a/tests/codegen-llvm/autodiff/batched.rs +++ b/tests/codegen-llvm/autodiff/batched.rs @@ -21,74 +21,39 @@ fn square(x: &f32) -> f32 { x * x } -// d_square2 -// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'") -// CHECK-NEXT: start: -// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 -// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 -// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 -// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 -// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 -// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 -// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 -// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" -// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val -// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 -// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" -// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val -// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 -// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" -// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val -// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 -// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" -// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val -// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 -// CHECK-NEXT: ret [4 x float] %15 -// CHECK-NEXT: } - -// d_square3, the extra float is the original return value (x * x) -// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'") -// CHECK-NEXT: start: -// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0 -// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4 -// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1 -// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4 -// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2 -// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4 -// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3 -// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4 -// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val -// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl" -// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val -// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0 -// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1" -// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val -// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1 -// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2" -// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val -// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2 -// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3" -// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val -// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3 -// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0 -// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1 -// CHECK-NEXT: ret { float, [4 x float] } %17 -// CHECK-NEXT: } - fn main() { let x = std::hint::black_box(3.0); + + // square(&x) + // CHECK: %_0.i = fmul float %_2.i, %_2.i + // CHECK-NEXT: store float %_0.i, ptr %output, align 4 let output = square(&x); dbg!(&output); assert_eq!(9.0, output); + + // square(&x) + // CHECK: %_2.i26 = load float, ptr %x, align 4 + // CHECK-NEXT: %_0.i27 = fmul float %_2.i26, %_2.i26 dbg!(square(&x)); let mut df_dx1 = 1.0; let mut df_dx2 = 2.0; let mut df_dx3 = 3.0; let mut df_dx4 = 0.0; + + // [o1, o2, o3, o4] (o4 is being optimized away as its smth * 0.0) + // CHECK: %x.val = load float, ptr %x, align 4 + // CHECK-NEXT: %13 = fmul fast float %x.val, 2.000000e+00 + // CHECK-NEXT: %14 = fmul fast float %x.val, 4.000000e+00 + // CHECK-NEXT: %15 = fmul fast float %x.val, 6.000000e+00 let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); dbg!(o1, o2, o3, o4); + + // [output2, o1, o2, o3, o4] (o4 is being optimized away as its smth * 0.0) + // CHECK: %_0.i45 = fmul float %x.val35, %x.val35 + // CHECK-NEXT: %40 = fmul fast float %x.val35, 2.000000e+00 + // CHECK-NEXT: %41 = fmul fast float %x.val35, 4.000000e+00 + // CHECK-NEXT: %42 = fmul fast float %x.val35, 6.000000e+00 let [output2, o1, o2, o3, o4] = d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4); dbg!(o1, o2, o3, o4); @@ -101,8 +66,22 @@ fn main() { assert_eq!(2.0, df_dx2); assert_eq!(3.0, df_dx3); assert_eq!(0.0, df_dx4); + + // d_square3(&x, &mut df_dx1) + // CHECK: %x.val39 = load float, ptr %x, align 4 + // CHECK-NEXT: %72 = fmul fast float %x.val39, 2.000000e+00 assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1); + + // d_square3(&x, &mut df_dx2) + // CHECK: %74 = fmul fast float %x.val39, 4.000000e+00 + // CHECK-NEXT: store float %74, ptr %_191, align 4 assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2); + + // d_square3(&x, &mut df_dx3) + // CHECK: %76 = fmul fast float %x.val39, 6.000000e+00 + // CHECK-NEXT: store float %76, ptr %_200, align 4 assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3); + + // d_square3(&x, &mut df_dx3) is being optimized away as it's smth * 0.0 assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4); } diff --git a/tests/codegen-llvm/autodiff/generic.rs b/tests/codegen-llvm/autodiff/generic.rs index 995fdf5d90ed5..603ca849e689e 100644 --- a/tests/codegen-llvm/autodiff/generic.rs +++ b/tests/codegen-llvm/autodiff/generic.rs @@ -10,23 +10,13 @@ fn square + Copy>(x: &T) -> T { *x * *x } -// Ensure that `d_square::` code is generated +// Ensure that `square::` code is generated // -// CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: {{.*}} -// CHECK-NEXT: define internal {{.*}} float -// CHECK-NEXT: start: -// CHECK-NOT: ret -// CHECK: fmul float +// CHECK: %1 = fmul float %xf32, %xf32 // Ensure that `d_square::` code is generated even if `square::` was never called // -// CHECK: ; generic::square -// CHECK-NEXT: ; Function Attrs: -// CHECK-NEXT: define internal {{.*}} double -// CHECK-NEXT: start: -// CHECK-NOT: ret -// CHECK: fmul double +// CHECK: define internal { double } @diffe_ZN7generic6square17he5c855620985cd59E fn main() { let xf32: f32 = std::hint::black_box(3.0); diff --git a/tests/codegen-llvm/autodiff/identical_fnc.rs b/tests/codegen-llvm/autodiff/identical_fnc.rs index 894906f067ba7..e3b5db291cfbb 100644 --- a/tests/codegen-llvm/autodiff/identical_fnc.rs +++ b/tests/codegen-llvm/autodiff/identical_fnc.rs @@ -23,15 +23,13 @@ fn square2(x: &f64) -> f64 { x * x } -// CHECK:; identical_fnc::main -// CHECK-NEXT:; Function Attrs: -// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E() -// CHECK-NEXT:start: -// CHECK-NOT:br -// CHECK-NOT:ret -// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx1) -// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx2) - +// CHECK: %0 = fadd fast double %x.val, %x.val +// CHECK-NEXT: %1 = load double, ptr %dx1, align 8 +// CHECK-NEXT: %2 = fadd fast double %1, %0 +// CHECK-NEXT: store double %2, ptr %dx1, align 8 +// CHECK-NEXT: %3 = load double, ptr %dx2, align 8 +// CHECK-NEXT: %4 = fadd fast double %3, %0 +// CHECK-NEXT: store double %4, ptr %dx2, align 8 fn main() { let x = std::hint::black_box(3.0); let mut dx1 = std::hint::black_box(1.0); diff --git a/tests/codegen-llvm/autodiff/scalar.rs b/tests/codegen-llvm/autodiff/scalar.rs index 096b4209e84ad..ca6bf76aeb12b 100644 --- a/tests/codegen-llvm/autodiff/scalar.rs +++ b/tests/codegen-llvm/autodiff/scalar.rs @@ -11,16 +11,11 @@ fn square(x: &f64) -> f64 { x * x } -// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'" -// CHECK-NEXT:invertstart: -// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val -// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val -// CHECK-NEXT: %1 = load double, ptr %"x'", align 8 -// CHECK-NEXT: %2 = fadd fast double %1, %0 -// CHECK-NEXT: store double %2, ptr %"x'", align 8 -// CHECK-NEXT: ret double %_0 -// CHECK-NEXT:} +// square +// CHECK: %_0.i = fmul double %_2.i, %_2.i +// d_square +// CHECK: %0 = fadd fast double %_2.i, %_2.i fn main() { let x = std::hint::black_box(3.0); let output = square(&x); diff --git a/tests/codegen-llvm/autodiff/sret.rs b/tests/codegen-llvm/autodiff/sret.rs index d8451b0eb2d5a..658055a561a1a 100644 --- a/tests/codegen-llvm/autodiff/sret.rs +++ b/tests/codegen-llvm/autodiff/sret.rs @@ -17,19 +17,14 @@ fn primal(x: f32, y: f32) -> f64 { (x * x * y) as f64 } -// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y) -// CHECK-NEXT: invertstart: -// CHECK-NEXT: %_4 = fmul float %x, %x -// CHECK-NEXT: %_3 = fmul float %_4, %y -// CHECK-NEXT: %_0 = fpext float %_3 to double -// CHECK-NEXT: %0 = fadd fast float %y, %y -// CHECK-NEXT: %1 = fmul fast float %0, %x -// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0 -// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1 -// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2 -// CHECK-NEXT: ret { double, float, float } %4 -// CHECK-NEXT: } - +// CHECK: %_4.i = fmul float %x, %x +// CHECK-NEXT: %_3.i = fmul float %_4.i, %y +// CHECK-NEXT: %_0.i = fpext float %_3.i to double +// CHECK-NEXT: %3 = fadd fast float %y, %y +// CHECK-NEXT: %4 = fmul fast float %3, %x +// CHECK-NEXT: store double %_0.i, ptr %r1, align 8 +// CHECK-NEXT: store float %4, ptr %r2, align 4 +// CHECK-NEXT: store float %_4.i, ptr %r3, align 4 fn main() { let x = std::hint::black_box(3.0); let y = std::hint::black_box(2.5); diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index 4b24feea362d7..e337f4ac21202 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -16,7 +16,6 @@ use std::autodiff::{autodiff_forward, autodiff_reverse}; #[rustc_autodiff] -#[inline(never)] pub fn f1(x: &[f64], y: f64) -> f64 { @@ -40,7 +39,6 @@ ::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y)) } #[rustc_autodiff] -#[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } @@ -49,7 +47,6 @@ ::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y)) } #[rustc_autodiff] -#[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } @@ -58,14 +55,12 @@ ::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y)) } #[rustc_autodiff] -#[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] pub fn df4() -> () { ::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ()) } #[rustc_autodiff] -#[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } @@ -84,7 +79,6 @@ } struct DoesNotImplDefault; #[rustc_autodiff] -#[inline(never)] pub fn f6() -> DoesNotImplDefault { ::core::panicking::panic("not implemented") } @@ -93,7 +87,6 @@ ::core::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ()) } #[rustc_autodiff] -#[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] pub fn df7(x: f32) -> () { @@ -101,7 +94,6 @@ } #[no_mangle] #[rustc_autodiff] -#[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) @@ -121,7 +113,6 @@ } pub fn f9() { #[rustc_autodiff] - #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { @@ -135,7 +126,6 @@ } } #[rustc_autodiff] -#[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] pub fn d_square + diff --git a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index c4d87e6d47767..9fd9be7b1a87a 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -16,7 +16,6 @@ use std::autodiff::autodiff_reverse; #[rustc_autodiff] -#[inline(never)] pub fn f1(x: &[f64], y: f64) -> f64 { // Not the most interesting derivative, but who are we to judge @@ -33,12 +32,10 @@ ::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, dx_0, y, dret)) } #[rustc_autodiff] -#[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] pub fn df2() { ::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, ()) } #[rustc_autodiff] -#[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } @@ -49,14 +46,12 @@ enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] -#[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] pub fn df4(x: f32) { ::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, (x,)) } #[rustc_autodiff] -#[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index dd551df2a2caa..36ff402f8bcca 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -26,7 +26,6 @@ impl MyTrait for Foo { #[rustc_autodiff] - #[inline(never)] fn f(&self, x: f64) -> f64 { self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) }