diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 2f918faaf752b..4a9915caed6bf 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,6 +6,7 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; +use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -85,6 +86,9 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, + // Type Tree support + pub inputs: Vec<TypeTree>, + pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -276,14 +280,23 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { - AutoDiffItem { source, target, attrs: self } + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec<TypeTree>, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } } } +/// Renders the source and target symbols, then the attributes and the input/output type trees. impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs) + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) } } diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index a4ec4bf8deac4..99e5a4bccca3c 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -10,6 +10,7 @@ use gccjit::{ use rustc_abi as abi; use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout, WrappingRange}; use rustc_apfloat::{Float,
Round, Status, ieee}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::MemFlags; use rustc_codegen_ssa::common::{ AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, TypeKind, @@ -1368,6 +1369,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { _src_align: Align, size: RValue<'gcc>, flags: MemFlags, + _tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_size_t(), false); diff --git a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs index 0753ac1aeb84e..ed576c9fd1405 100644 --- a/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs +++ b/compiler/rustc_codegen_gcc/src/intrinsic/mod.rs @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(self.layout.size.bytes()), MemFlags::empty(), + None, ); bx.lifetime_end(scratch, scratch_size); diff --git a/compiler/rustc_codegen_llvm/src/abi.rs b/compiler/rustc_codegen_llvm/src/abi.rs index 009e7e2487b66..3ee66ec1cec77 100644 --- a/compiler/rustc_codegen_llvm/src/abi.rs +++ b/compiler/rustc_codegen_llvm/src/abi.rs @@ -238,6 +238,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> { scratch_align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); bx.lifetime_end(llscratch, scratch_size); } diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 84302009da999..590ce00202c70 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -682,13 +682,11 @@ pub(crate) fn run_pass_manager( for function in cx.get_functions() { let enzyme_marker = "enzyme_marker"; if attributes::has_string_attr(function, enzyme_marker) { - // Sanity check: Ensure 'noinline' is present before replacing it. 
- assert!( - attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), - "Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" - ); + // Remove 'noinline' if present (it should be there in most cases) + if attributes::has_attr(function, Function, llvm::AttributeKind::NoInline) { + attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); + } - attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); attributes::remove_string_attr_from_llfn(function, enzyme_marker); assert!( diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0ade9edb0d2ea..45a0dca491d99 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow}; use std::ops::Deref; use std::{iter, ptr}; +use rustc_ast::expand::typetree::FncTree; pub(crate) mod autodiff; pub(crate) mod gpu_offload; @@ -31,6 +32,7 @@ use tracing::{debug, instrument}; use crate::abi::FnAbiLlvmExt; use crate::attributes; +use crate::builder::autodiff::add_tt; use crate::common::Funclet; use crate::context::{CodegenCx, FullCx, GenericCx, SCx}; use crate::llvm::{ @@ -1105,11 +1107,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align: Align, size: &'ll Value, flags: MemFlags, + tt: Option, ) { assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); let size = self.intcast(size, self.type_isize(), false); let is_volatile = flags.contains(MemFlags::VOLATILE); - unsafe { + let memcpy = unsafe { llvm::LLVMRustBuildMemCpy( self.llbuilder, dst, @@ -1118,7 +1121,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { src_align.bytes() as c_uint, size, is_volatile, - ); + ) + }; + + // TypeTree metadata for memcpy is especially important: when Enzyme encounters + // a memcpy during autodiff, it needs to know the structure of 
the data being + // copied to properly track derivatives. For example, copying an array of floats + // vs. copying a struct with mixed types requires different derivative handling. + // The TypeTree tells Enzyme exactly what memory layout to expect. + if let Some(tt) = tt { + add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 829b3c513c258..80a91b9c7797b 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,6 +1,8 @@ +use std::os::raw::{c_char, c_uint}; use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; @@ -14,7 +16,7 @@ use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, TypeTree}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; @@ -512,3 +514,141 @@ pub(crate) fn differentiate<'ll>( Ok(()) } + +/// Converts a Rust TypeTree to Enzyme's internal TypeTree format +/// +/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) +/// and converts it to Enzyme's internal C++ TypeTree representation that +/// Enzyme can understand during differentiation analysis. 
+#[cfg(llvm_enzyme)] +fn to_enzyme_typetree( + rust_typetree: RustTypeTree, + data_layout: &str, + llcx: &llvm::Context, +) -> TypeTree { + // Start with an empty TypeTree + let mut enzyme_tt = TypeTree::new(); + + // Convert each Type in the Rust TypeTree to Enzyme format + for rust_type in rust_typetree.0 { + let concrete_type = match rust_type.kind { + rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, + rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, + rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, + rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, + rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, + rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, + rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, + }; + + // Create a TypeTree for this specific type + let type_tt = TypeTree::from_type(concrete_type, llcx); + + // Apply offset if specified + let type_tt = if rust_type.offset == -1 { + type_tt // -1 means everywhere/no specific offset + } else { + // Apply specific offset positioning + type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) + }; + + // Merge this type into the main TypeTree + enzyme_tt = enzyme_tt.merge(type_tt); + } + + enzyme_tt +} + +#[cfg(not(llvm_enzyme))] +#[allow(dead_code)] +fn to_enzyme_typetree( + _rust_typetree: RustTypeTree, + _data_layout: &str, + _llcx: &llvm::Context, +) -> ! { + unimplemented!("TypeTree conversion not available without llvm_enzyme support") +} + +// Attaches TypeTree information to LLVM function as enzyme_type attributes. 
+#[cfg(llvm_enzyme)] +pub(crate) fn add_tt<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + fn_def: &'ll Value, + tt: FncTree, +) { + let inputs = tt.args; + let ret_tt: RustTypeTree = tt.ret; + + // Get LLVM data layout string for TypeTree conversion + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + // Attribute name that Enzyme recognizes for TypeTree information + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + + // Attach TypeTree attributes to each input parameter + // Enzyme uses these to understand parameter memory layouts during differentiation + for (i, input) in inputs.iter().enumerate() { + unsafe { + // Convert Rust TypeTree to Enzyme's internal format + let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + + // Serialize TypeTree to string format that Enzyme can parse + let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + + // Create LLVM string attribute with TypeTree information + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + + // Attach attribute to the specific function parameter + // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + + // Free the C string to prevent memory leaks + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + } + } + + // Attach TypeTree attribute to the return type + // Enzyme needs this to understand how to handle return value derivatives + unsafe { + let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = 
llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + + let ret_attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + + // Attach to function return type + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + + // Free the C string + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + } +} + +// Fallback implementation when Enzyme is not available +#[cfg(not(llvm_enzyme))] +pub(crate) fn add_tt<'ll>( + _llmod: &'ll llvm::Module, + _llcx: &'ll llvm::Context, + _fn_def: &'ll Value, + _tt: FncTree, +) { + unimplemented!() +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 56d756e52cce1..a827d6234c18a 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -3,9 +3,35 @@ use libc::{c_char, c_uint}; use super::MetadataKindId; -use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value}; +use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value}; use crate::llvm::{Bool, Builder}; +// TypeTree types +pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree; + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub(crate) struct EnzymeTypeTree { + _unused: [u8; 0], +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +#[allow(non_camel_case_types)] +pub(crate) enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} + +pub(crate) struct TypeTree { + pub(crate) inner: CTypeTreeRef, +} + #[link(name = "llvm-wrapper", kind = "static")] unsafe extern "C" { // Enzyme @@ -68,10 +94,33 @@ pub(crate) mod Enzyme_AD { use libc::c_void; + use super::{CConcreteType, CTypeTreeRef, Context}; + unsafe extern "C" { pub(crate) 
fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); } + + // TypeTree functions + unsafe extern "C" { + pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef; + pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); + pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + pub(crate) fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; + pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + } + unsafe extern "C" { static mut EnzymePrintPerf: c_void; static mut EnzymePrintActivity: c_void; @@ -141,6 +190,57 @@ pub(crate) use self::Fallback_AD::*; pub(crate) mod Fallback_AD { #![allow(unused_variables)] + use libc::c_char; + + use super::{CConcreteType, CTypeTreeRef, Context}; + + // TypeTree function fallbacks + pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } + + pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { + unimplemented!() + } + + pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { + unimplemented!() + } + + pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { + unimplemented!() + } + + pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { + unimplemented!() + } + + pub(crate) fn 
EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } + + pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { + unimplemented!() + } + + pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { + unimplemented!() + } + pub(crate) fn set_inline(val: bool) { unimplemented!() } @@ -169,3 +269,83 @@ pub(crate) mod Fallback_AD { unimplemented!() } } + +impl TypeTree { + pub(crate) fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + TypeTree { inner } + } + + pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + TypeTree { inner } + } + + pub(crate) fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + self + } + + #[must_use] + pub(crate) fn shift( + self, + layout: &str, + offset: isize, + max_size: isize, + add_offset: usize, + ) -> Self { + let layout = std::ffi::CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ); + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl std::fmt::Display for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { + EnzymeTypeTreeToStringFree(ptr); + } + + Ok(()) + } +} + +impl std::fmt::Debug for TypeTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for 
TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index edfb29dd1be72..ea617cd0f0c87 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2673,4 +2673,5 @@ unsafe extern "C" { pub(crate) fn LLVMRustSetNoSanitizeAddress(Global: &Value); pub(crate) fn LLVMRustSetNoSanitizeHWAddress(Global: &Value); + } diff --git a/compiler/rustc_codegen_llvm/src/va_arg.rs b/compiler/rustc_codegen_llvm/src/va_arg.rs index ce079f3cb0af1..99d72b63a10ce 100644 --- a/compiler/rustc_codegen_llvm/src/va_arg.rs +++ b/compiler/rustc_codegen_llvm/src/va_arg.rs @@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>( src_align, bx.const_u32(layout.layout.size().bytes() as u32), MemFlags::empty(), + None, ); tmp } else { diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index bde63fd501aa2..2b0aa3db74844 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -1550,6 +1550,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { align, bx.const_usize(copy_bytes), MemFlags::empty(), + None, ); // ...and then load it with the ABI type. 
llval = load_cast(bx, cast, llscratch, scratch_align); diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index fc95f62b4a43d..e96f2a01aed62 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>( if allow_overlap { bx.memmove(dst, align, src, align, size, flags); } else { - bx.memcpy(dst, align, src, align, size, flags); + bx.memcpy(dst, align, src, align, size, flags, None); } } diff --git a/compiler/rustc_codegen_ssa/src/mir/statement.rs b/compiler/rustc_codegen_ssa/src/mir/statement.rs index f164e0f912373..0a50d7f18dbef 100644 --- a/compiler/rustc_codegen_ssa/src/mir/statement.rs +++ b/compiler/rustc_codegen_ssa/src/mir/statement.rs @@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let align = pointee_layout.align; let dst = dst_val.immediate(); let src = src_val.immediate(); - bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty()); + bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None); } mir::StatementKind::FakeRead(..) | mir::StatementKind::Retag { .. 
} diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 979456a6ba70f..fc040070935bb 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,6 +2,7 @@ use std::assert_matches::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::FncTree; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; @@ -424,6 +425,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + tt: Option, ); fn memmove( &mut self, @@ -480,7 +482,7 @@ pub trait BuilderMethods<'a, 'tcx>: temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { let bytes = self.const_usize(layout.size.bytes()); - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); } } diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index f36ae83165319..9fbea69959c6b 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -44,7 +44,6 @@ pub struct UnsupportedUnion { pub ty_name: String, } -// FIXME(autodiff): I should get used somewhere #[derive(Diagnostic)] #[diag(middle_autodiff_unsafe_inner_const_ref)] pub struct AutodiffUnsafeInnerConstRef<'tcx> { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index a7cde2ad48547..8040ccf92f284 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -10,6 +10,7 @@ //! 
["The `ty` module: representing types"]: https://rustc-dev-guide.rust-lang.org/ty.html #![allow(rustc::usage_of_ty_tykind)] +#![allow(unused_imports)] use std::assert_matches::assert_matches; use std::fmt::Debug; @@ -24,7 +25,10 @@ pub use assoc::*; pub use generic_args::{GenericArgKind, TermKind, *}; pub use generics::*; pub use intrinsic::IntrinsicDef; -use rustc_abi::{Align, FieldIdx, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx}; +use rustc_abi::{ + Align, FieldIdx, FieldsShape, Integer, IntegerType, ReprFlags, ReprOptions, VariantIdx, +}; +use rustc_ast::expand::typetree::{FncTree, Kind, Type, TypeTree}; use rustc_ast::node_id::NodeMap; pub use rustc_ast_ir::{Movability, Mutability, try_visit}; use rustc_attr_data_structures::{AttributeKind, StrippedCfgItem, find_attr}; @@ -63,7 +67,7 @@ pub use rustc_type_ir::solve::SizedTraitKind; pub use rustc_type_ir::*; #[allow(hidden_glob_reexports, unused_imports)] use rustc_type_ir::{InferCtxtLike, Interner}; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; pub use vtable::*; use {rustc_ast as ast, rustc_attr_data_structures as attr, rustc_hir as hir}; @@ -112,7 +116,7 @@ pub use self::typeck_results::{ Rust2024IncompatiblePatInfo, TypeckResults, UserType, UserTypeAnnotationIndex, UserTypeKind, }; pub use self::visit::*; -use crate::error::{OpaqueHiddenTypeMismatch, TypeMismatchReason}; +use crate::error::{AutodiffUnsafeInnerConstRef, OpaqueHiddenTypeMismatch, TypeMismatchReason}; use crate::metadata::ModChild; use crate::middle::privacy::EffectiveVisibilities; use crate::mir::{Body, CoroutineLayout, CoroutineSavedLocal, SourceInfo}; @@ -222,6 +226,9 @@ pub struct ResolverAstLowering { pub disambiguator: DisambiguatorState, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, + pub trait_map: NodeMap>, /// List functions and methods for which lifetime elision was successful. 
pub lifetime_elision_allowed: FxHashSet, @@ -2340,3 +2347,332 @@ mod size_asserts { static_assert_size!(WithCachedTypeInfo>, 48); // tidy-alphabetical-end } + +/// Builds the `TypeTree` of `ty` and nests it under a single outer pointer entry. +pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + let mut visited = vec![]; + let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; + return TypeTree(vec![tt]); +} + +use rustc_ast::expand::autodiff_attrs::DiffActivity; + +// This function combines three tasks. To avoid traversing each type 3x, we combine them. +// 1. Create a TypeTree from a Ty. This is the main task. +// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM +// lowering. E.g. fat ptr are going to introduce an extra int. +// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an +// autodiff macro on top). Here we want to make sure that shadows are mutable internally. +// We know the outermost ref/ptr indirection is mutable - we generate it like that. +// We now have to make sure that inner ptr/ref are mutable too, or issue a warning. +// Not an error, because it only causes issues if they are actually read, which we don't check +// yet. We should add such analysis to reliably either issue an error or accept without warning. +// If there only were some research to do that... +pub fn fnc_typetrees<'tcx>( + tcx: TyCtxt<'tcx>, + fn_ty: Ty<'tcx>, + da: &mut Vec<DiffActivity>, + span: Option<Span>, +) -> FncTree { + if !fn_ty.is_fn() { + return FncTree { args: vec![], ret: TypeTree::new() }; + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout.
+ // Recommended by compiler-errors: + // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); + + let mut new_activities = vec![]; + let mut new_positions = vec![]; + let mut visited = vec![]; + let mut args = vec![]; + for (i, ty) in x.inputs().iter().enumerate() { + // We care about safety checks, if an argument gets duplicated and we write into the + // shadow. That's equivalent to Duplicated or DuplicatedOnly. + let safety = if !da.is_empty() { + assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len()); + // If we have Activities, we also have spans + assert!(span.is_some()); + match da[i] { + DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true, + _ => false, + } + } else { + false + }; + + visited.clear(); + if ty.is_raw_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do with fn ptr?"); + } + let (inner_ty, _is_mut) = match ty.kind() { + ty::RawPtr(inner_ty, mutbl) => (*inner_ty, *mutbl == hir::Mutability::Mut), + ty::Ref(_, inner_ty, mutbl) => (*inner_ty, *mutbl == hir::Mutability::Mut), + _ => { + let inner_ty = ty.builtin_deref(true).unwrap(); + (inner_ty, false) // Box - assume not mutable for now + } + }; + if inner_ty.is_slice() { + // We know that the length will be passed as extra arg. + let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + args.push(TypeTree(vec![tt])); + let i64_tt = + Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() }; + args.push(TypeTree(vec![i64_tt])); + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice gets duplicated, we want to know to later check the + // size.
So we mark the new size argument as FakeActivitySize. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => DiffActivity::FakeActivitySize(None), + DiffActivity::Const => DiffActivity::Const, + _ => panic!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + trace!("ABI MATCHING!"); + continue; + } + } + let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span); + args.push(arg_tt); + } + + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } + + visited.clear(); + let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span); + + FncTree { args, ret } +} + +/// Recursively builds the `TypeTree` describing `ty`. +/// +/// `visited` holds the types currently being expanded, so a recursive type yields an empty +/// tree; every return after the initial push must pop that entry again. +fn typetree_from_ty<'a>( + ty: Ty<'a>, + tcx: TyCtxt<'a>, + depth: usize, + safety: bool, + visited: &mut Vec<Ty<'a>>, + span: Option<Span>, +) -> TypeTree { + if depth > 20 { + trace!("depth > 20 for ty: {}", &ty); + } + if visited.contains(&ty) { + // recursive type + trace!("recursive type: {}", &ty); + return TypeTree::new(); + } + visited.push(ty); + + if ty.is_raw_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do with fn ptr?"); + } + + let (inner_ty, is_mut) = match ty.kind() { + ty::RawPtr(inner_ty, mutbl) => (*inner_ty, *mutbl == hir::Mutability::Mut), + ty::Ref(_, inner_ty, mutbl) => (*inner_ty, *mutbl == hir::Mutability::Mut), + _ => { + let inner_ty = ty.builtin_deref(true).unwrap(); + (inner_ty, false) // Box - assume not mutable for now + } + }; + + // Now account for inner mutability.
+ if !is_mut && depth > 0 && safety { + // If we have mutability, we also have a span + assert!(span.is_some()); + let span = span.unwrap(); + + tcx.sess.dcx().emit_warn(AutodiffUnsafeInnerConstRef { span, ty }); + } + + //visited.push(inner_ty); + let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + visited.pop(); + return TypeTree(vec![tt]); + } + + if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() { + visited.pop(); + return TypeTree::new(); + } + + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + } else { + panic!("scalar that is neither integral nor floating point"); + }; + visited.pop(); + return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]); + } + + let typing_env = TypingEnv::fully_monomorphized(); + let layout = tcx.layout_of(typing_env.as_query_input(ty)); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() && !ty.is_simd() { + let adt_def = ty.ty_adt_def().unwrap(); + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + FieldsShape::Array { .. } => { + visited.pop(); + return TypeTree::new(); + } //e.g.
core::arch::x86_64::__m128i + FieldsShape::Union(_) => { + visited.pop(); + return TypeTree::new(); + } + FieldsShape::Primitive => { + visited.pop(); + return TypeTree::new(); + } + }; + + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let typing_env = TypingEnv::fully_monomorphized(); + let field_ty: Ty<'_> = tcx.normalize_erasing_regions(typing_env, field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + //visited.push(field_ty); + let mut child = + typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + Some(child) + }) + .flatten() + .collect::<Vec<_>>(); + + visited.pop(); + let ret_tt = TypeTree(fields); + return ret_tt; + } else if adt_def.is_enum() { + // Enzyme can't represent enums, so let it figure it out itself, without seeding + // typetree + //unimplemented!("adt that is an enum"); + } else { + //let ty_name = tcx.def_path_debug_str(adt_def.did()); + //tcx.sess.emit_fatal(UnsupportedUnion { ty_name }); + } + } + + if ty.is_simd() { + trace!("simd"); + let (_size, inner_ty) = ty.simd_size_and_type(tcx); + //visited.push(inner_ty); + let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); + //let tt = TypeTree( + // std::iter::repeat(subtt) + // .take(*count as usize) + // .enumerate() + // .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + // .flatten() + // .collect(), + //); + visited.pop(); + return TypeTree::new(); + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size =
max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + if (*count as usize) == 0 { + visited.pop(); + return TypeTree::new(); + } + let sub_ty = ty.builtin_index().unwrap(); + //visited.push(sub_ty); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); + + // calculate size of subtree + let typing_env = TypingEnv::fully_monomorphized(); + let size = tcx.layout_of(typing_env.as_query_input(sub_ty)).unwrap().size.bytes() as usize; + let tt = TypeTree( + std::iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + visited.pop(); + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + //visited.push(sub_ty); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); + + visited.pop(); + return subtt; + } + + visited.pop(); + TypeTree::new() +} diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 22d593b80b895..868c29d5df6eb 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -2,7 +2,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::bug; use rustc_middle::mir::mono::MonoItem; -use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; +use rustc_middle::ty::{ + self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv, fnc_typetrees, +}; use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use tracing::{debug, trace}; @@ -127,8 +129,14 @@ pub(crate) fn find_autodiff_source_functions<'tcx>( let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); let mut new_target_attrs = target_attrs.clone(); - new_target_attrs.input_activity = input_activities; - let itm =
new_target_attrs.into_item(symb, target_symbol); + new_target_attrs.input_activity = input_activities.clone(); + + // generate typetrees for the function + let span = tcx.def_span(inst.def_id()); + let fnc_tree = fnc_typetrees(tcx, fn_ty, &mut input_activities, Some(span)); + let (inputs, output) = (fnc_tree.args, fnc_tree.ret); + + let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output); autodiff_items.push(itm); } diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 0d41a822e8a8d..3b7f004193487 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1219,6 +1219,9 @@ pub struct Resolver<'ra, 'tcx> { // that were encountered during resolution. These names are used to generate item names // for APITs, so we don't want to leak details of resolution into these names. impl_trait_names: FxHashMap, + + /// Mapping of autodiff function IDs + autodiff_map: FxHashMap, } /// This provides memory for the rest of the crate. 
The `'ra` lifetime that is @@ -1597,6 +1600,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> { current_crate_outer_attr_insert_span, mods_with_parse_errors: Default::default(), impl_trait_names: Default::default(), + autodiff_map: Default::default(), }; let root_parent_scope = ParentScope::module(graph_root, &resolver); @@ -1719,6 +1723,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> { .map(|(k, f)| (k, f.key())) .collect(), disambiguator: self.disambiguator, + autodiff_map: self.autodiff_map, trait_map: self.trait_map, lifetime_elision_allowed: self.lifetime_elision_allowed, lint_buffer: Steal::new(self.lint_buffer), diff --git a/src/llvm-project b/src/llvm-project index e8a2ffcf322f4..d3c793b025645 160000 --- a/src/llvm-project +++ b/src/llvm-project @@ -1 +1 @@ -Subproject commit e8a2ffcf322f45b8dce82c65ab27a3e2430a6b51 +Subproject commit d3c793b025645a4565ac59aceb30d2d116ff1a41 diff --git a/src/tools/enzyme b/src/tools/enzyme index 2cccfba93c165..b5098d515d5e1 160000 --- a/src/tools/enzyme +++ b/src/tools/enzyme @@ -1 +1 @@ -Subproject commit 2cccfba93c1650f26f1cf8be8aa875a7c1d23fb3 +Subproject commit b5098d515d5e1bd0f5470553bc0d18da9794ca8b diff --git a/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.check b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.check new file mode 100644 index 0000000000000..e2f8daafd62a4 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.check @@ -0,0 +1,9 @@ +CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double}:{} + +CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double} + +CHECK-DAG: load double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fmul 
double{{.*}}: {[-1]:Float@double} + +CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double} \ No newline at end of file diff --git a/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.rs b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.rs new file mode 100644 index 0000000000000..bba1ecf1cb550 --- /dev/null +++ b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/memcpy.rs @@ -0,0 +1,31 @@ +#![feature(autodiff)] + +use std::autodiff::autodiff_reverse; +use std::ptr; + +#[autodiff_reverse(d_test_memcpy, Duplicated, Active)] +#[no_mangle] +fn test_memcpy(input: &[f64; 8]) -> f64 { + let mut local_data = [0.0f64; 8]; + + unsafe { + ptr::copy_nonoverlapping(input.as_ptr(), local_data.as_mut_ptr(), 8); + } + + let mut result = 0.0; + for i in 0..8 { + result += local_data[i] * local_data[i]; + } + + result +} + +fn main() { + let input = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mut d_input = [0.0; 8]; + let result = test_memcpy(&input); + let result_d = d_test_memcpy(&input, &mut d_input, 1.0); + + assert_eq!(result, result_d); + println!("Memcpy test passed: result = {}", result); +} diff --git a/tests/run-make/autodiff/type-trees/type-analysis/memcpy/rmake.rs b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/rmake.rs new file mode 100644 index 0000000000000..b4c65974cb5bf --- /dev/null +++ b/tests/run-make/autodiff/type-trees/type-analysis/memcpy/rmake.rs @@ -0,0 +1,28 @@ +//@ needs-enzyme +//@ ignore-cross-compile + +use std::fs; + +use run_make_support::{llvm_filecheck, rfs, rustc}; + +fn main() { + // Compile the Rust file with the required flags, capturing both stdout and stderr + let output = rustc() + .input("memcpy.rs") + .arg("-Zautodiff=Enable,PrintTAFn=test_memcpy") + .arg("-Zautodiff=NoPostopt") + .opt_level("3") + .arg("-Clto=fat") + .arg("-g") + .run(); + + let stdout = output.stdout_utf8(); + let stderr = output.stderr_utf8(); + + // Write the outputs to files + rfs::write("memcpy.stdout", stdout); + 
rfs::write("memcpy.stderr", stderr); + + // Run FileCheck on the stdout using the check file + llvm_filecheck().patterns("memcpy.check").stdin_buf(rfs::read("memcpy.stdout")).run(); +}