Skip to content

TypeTree support in autodiff #144197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
19 changes: 16 additions & 3 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};
Expand Down Expand Up @@ -85,6 +86,9 @@ pub struct AutoDiffItem {
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
// Type Tree support
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
Expand Down Expand Up @@ -276,14 +280,23 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
AutoDiffItem { source, target, attrs: self }
/// Consumes these attributes and bundles them into an `AutoDiffItem`.
///
/// * `source` - stored as the item's `source` field.
/// * `target` - stored as the item's `target` field.
/// * `inputs` - one `TypeTree` per input, stored as the item's `inputs` field.
/// * `output` - `TypeTree` stored as the item's `output` field.
///
/// `self` becomes the item's `attrs` field.
pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use gccjit::{
use rustc_abi as abi;
use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout, WrappingRange};
use rustc_apfloat::{Float, Round, Status, ieee};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::{
AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, TypeKind,
Expand Down Expand Up @@ -1368,6 +1369,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
_src_align: Align,
size: RValue<'gcc>,
flags: MemFlags,
_tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_size_t(), false);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);

bx.lifetime_end(scratch, scratch_size);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}
Expand Down
10 changes: 4 additions & 6 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,13 +682,11 @@ pub(crate) fn run_pass_manager(
for function in cx.get_functions() {
let enzyme_marker = "enzyme_marker";
if attributes::has_string_attr(function, enzyme_marker) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);
// Remove 'noinline' if present (it should be there in most cases)
if attributes::has_attr(function, Function, llvm::AttributeKind::NoInline) {
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
}

attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, enzyme_marker);

assert!(
Expand Down
16 changes: 14 additions & 2 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};

use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

Expand Down Expand Up @@ -31,6 +32,7 @@ use tracing::{debug, instrument};

use crate::abi::FnAbiLlvmExt;
use crate::attributes;
use crate::builder::autodiff::add_tt;
use crate::common::Funclet;
use crate::context::{CodegenCx, FullCx, GenericCx, SCx};
use crate::llvm::{
Expand Down Expand Up @@ -1105,11 +1107,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
Expand All @@ -1118,7 +1121,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};

// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

Expand Down
142 changes: 141 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::os::raw::{c_char, c_uint};
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
Expand All @@ -14,7 +16,7 @@ use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{Metadata, True};
use crate::llvm::{Metadata, True, TypeTree};
use crate::value::Value;
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};

Expand Down Expand Up @@ -512,3 +514,141 @@ pub(crate) fn differentiate<'ll>(

Ok(())
}

/// Builds Enzyme's `TypeTree` from a Rust-side type tree (`rustc_ast::expand::typetree`).
///
/// Every entry of `rust_typetree` is translated on its own and the results are merged, in
/// order, into one tree:
///
/// 1. the entry's `kind` is mapped to the `llvm::CConcreteType` of the same name,
/// 2. a tree holding just that concrete type is created with `TypeTree::from_type`,
/// 3. unless the entry's offset is `-1`, that tree is moved with `TypeTree::shift`,
/// 4. the result is merged into the accumulated tree.
///
/// * `rust_typetree` - the tree to convert; consumed.
/// * `data_layout` - data-layout string handed to `TypeTree::shift`.
/// * `llcx` - LLVM context handed to `TypeTree::from_type`.
///
/// Returns the merged tree; for an empty input this is just `TypeTree::new()`.
#[cfg(llvm_enzyme)]
fn to_enzyme_typetree(
    rust_typetree: RustTypeTree,
    data_layout: &str,
    llcx: &llvm::Context,
) -> TypeTree {
    use rustc_ast::expand::typetree::Kind;

    rust_typetree.0.into_iter().fold(TypeTree::new(), |merged, entry| {
        // One-to-one mapping between the Rust-side kind and Enzyme's concrete type.
        let concrete = match entry.kind {
            Kind::Anything => llvm::CConcreteType::DT_Anything,
            Kind::Integer => llvm::CConcreteType::DT_Integer,
            Kind::Pointer => llvm::CConcreteType::DT_Pointer,
            Kind::Half => llvm::CConcreteType::DT_Half,
            Kind::Float => llvm::CConcreteType::DT_Float,
            Kind::Double => llvm::CConcreteType::DT_Double,
            Kind::Unknown => llvm::CConcreteType::DT_Unknown,
        };

        let leaf = TypeTree::from_type(concrete, llcx);

        let placed = match entry.offset {
            // -1 is the "everywhere / no specific offset" marker: keep the tree unshifted.
            -1 => leaf,
            // Any other value is a concrete offset the tree has to be moved to.
            offset => leaf.shift(data_layout, offset, entry.size as isize, 0),
        };

        merged.merge(placed)
    })
}

/// Stand-in for `to_enzyme_typetree` in builds without `llvm_enzyme`.
///
/// Takes the same parameters as the Enzyme-enabled version but never returns (`-> !`):
/// calling it panics through `unimplemented!`.
#[cfg(not(llvm_enzyme))]
// NOTE(review): the only caller visible here is the `cfg(llvm_enzyme)` version of `add_tt`,
// so this stub appears to be unused in this configuration; hence the lint allowance. Confirm.
#[allow(dead_code)]
fn to_enzyme_typetree(
_rust_typetree: RustTypeTree,
_data_layout: &str,
_llcx: &llvm::Context,
) -> ! {
unimplemented!("TypeTree conversion not available without llvm_enzyme support")
}

/// Attaches TypeTree information to an LLVM value as `enzyme_type` string attributes.
///
/// One attribute is attached per entry of `tt.args`, at `AttributePlace::Argument(i)` for the
/// `i`-th entry, and one for `tt.ret` at `AttributePlace::ReturnValue`. The attribute value is
/// the string Enzyme produces for the converted type tree.
///
/// * `llmod` - module whose data-layout string is used for the TypeTree conversion.
/// * `llcx` - LLVM context used to create the type trees and the attributes.
/// * `fn_def` - value the attributes are applied to.
/// * `tt` - argument and return type trees; consumed.
///
/// # Panics
///
/// Panics if LLVM returns a data-layout string that is not valid UTF-8, or if `tt.args` has
/// more entries than fit in a `u32`.
#[cfg(llvm_enzyme)]
pub(crate) fn add_tt<'ll>(
    llmod: &'ll llvm::Module,
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    tt: FncTree,
) {
    let inputs = tt.args;
    let ret_tt: RustTypeTree = tt.ret;

    // Get LLVM data layout string for TypeTree conversion.
    // SAFETY: `llmod` is a live module reference.
    // NOTE(review): assumes LLVM returns a non-null, NUL-terminated string that stays valid
    // for as long as the module is borrowed here; confirm against the LLVM-C contract.
    let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(llmod) };
    let llvm_data_layout = unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }
        .to_str()
        .expect("got a non-UTF8 data-layout from LLVM");

    // Attribute name that Enzyme recognizes for TypeTree information.
    let c_attr_name = std::ffi::CString::new("enzyme_type").unwrap();

    // Attach one attribute per input parameter. `inputs` is owned, so it is consumed here
    // instead of cloning every entry.
    for (i, input) in inputs.into_iter().enumerate() {
        // `AttributePlace::Argument` carries a `u32`; refuse to silently truncate the index.
        let arg_idx = u32::try_from(i).expect("argument index does not fit in u32");
        attach_enzyme_type_attr(
            llcx,
            fn_def,
            llvm::AttributePlace::Argument(arg_idx),
            &c_attr_name,
            llvm_data_layout,
            input,
        );
    }

    // Attach the attribute describing the return value.
    attach_enzyme_type_attr(
        llcx,
        fn_def,
        llvm::AttributePlace::ReturnValue,
        &c_attr_name,
        llvm_data_layout,
        ret_tt,
    );
}

/// Converts `tt` to Enzyme's representation, serializes it to a string and applies it to
/// `fn_def` at `place` as a string attribute named `attr_name`.
#[cfg(llvm_enzyme)]
fn attach_enzyme_type_attr<'ll>(
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    place: llvm::AttributePlace,
    attr_name: &std::ffi::CStr,
    data_layout: &str,
    tt: RustTypeTree,
) {
    // Must stay alive until after the serialization call below.
    let enzyme_tt = to_enzyme_typetree(tt, data_layout, llcx);
    let name_len = attr_name.to_bytes().len() as c_uint;

    // SAFETY: the name and value pointers are passed together with their byte lengths, and
    // the serialized string is released exactly once, after its last use, with the matching
    // Enzyme free function.
    // NOTE(review): assumes `EnzymeTypeTreeToString` returns a valid NUL-terminated string
    // and that `LLVMCreateStringAttribute` copies its inputs; confirm.
    unsafe {
        let raw = llvm::EnzymeTypeTreeToString(enzyme_tt.inner);
        let serialized = std::ffi::CStr::from_ptr(raw);

        let attr = llvm::LLVMCreateStringAttribute(
            llcx,
            attr_name.as_ptr(),
            name_len,
            serialized.as_ptr(),
            serialized.to_bytes().len() as c_uint,
        );

        attributes::apply_to_llfn(fn_def, place, &[attr]);

        // Free the C string to prevent memory leaks.
        llvm::EnzymeTypeTreeToStringFree(serialized.as_ptr());
    }
}

/// Fallback for `add_tt` in builds without `llvm_enzyme`.
///
/// Keeps the same signature as the Enzyme-enabled version so callers compile in both
/// configurations, but ignores every argument and panics through `unimplemented!` if it is
/// ever reached.
// NOTE(review): the LLVM `memcpy` builder calls `add_tt` whenever it is handed `Some(tt)`;
// presumably no caller passes a type tree when Enzyme is disabled. Confirm.
#[cfg(not(llvm_enzyme))]
pub(crate) fn add_tt<'ll>(
_llmod: &'ll llvm::Module,
_llcx: &'ll llvm::Context,
_fn_def: &'ll Value,
_tt: FncTree,
) {
unimplemented!()
}
Loading
Loading