TypeTree support in autodiff #144197


Status: Open, 16 commits to merge into base: master

KMJ-007 (Contributor) commented Jul 19, 2025

TypeTrees for Autodiff

What are TypeTrees?

TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how a type is laid out in memory so that it can compute derivatives efficiently.

Structure

TypeTree(Vec<Type>)

Type {
    offset: isize,  // byte offset (-1 = everywhere)
    size: usize,    // size in bytes
    kind: Kind,     // Float, Integer, Pointer, etc.
    child: TypeTree // nested structure
}

Example: fn compute(x: &f32, data: &[f32]) -> f32

Input 0: x: &f32

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: 0, size: 4, kind: Float,
        child: TypeTree::new()
    }])
}])

Input 1: data: &[f32]

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1, size: 4, kind: Float,  // -1 = all elements
        child: TypeTree::new()
    }])
}])

Output: f32

TypeTree(vec![Type {
    offset: 0, size: 4, kind: Float,
    child: TypeTree::new()
}])
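
The three trees above can be assembled in ordinary Rust. The following is a minimal sketch, assuming the `Type`/`TypeTree` shape shown under "Structure"; the `Kind` variants listed and the helpers `float_at`, `pointer_to`, and `compute_trees` are hypothetical names used only for illustration and are not taken from the PR's code.

```rust
// Sketch only: these definitions mirror the shape described in this PR,
// they are not the actual rustc types.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
enum Kind {
    Float,
    Integer,
    Pointer,
}

#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
struct Type {
    offset: isize,   // byte offset (-1 = everywhere)
    size: usize,     // size in bytes
    kind: Kind,      // what kind of data lives here
    child: TypeTree, // nested structure (empty for leaves)
}

#[derive(Clone, Debug, PartialEq, Default)]
struct TypeTree(Vec<Type>);

impl TypeTree {
    fn new() -> Self {
        TypeTree(Vec::new())
    }
}

// Hypothetical helper: a leaf f32 at the given offset.
fn float_at(offset: isize) -> Type {
    Type { offset, size: 4, kind: Kind::Float, child: TypeTree::new() }
}

// Hypothetical helper: an 8-byte pointer whose pointee is described by `child`.
fn pointer_to(child: TypeTree) -> Type {
    Type { offset: 0, size: 8, kind: Kind::Pointer, child }
}

/// Trees for `fn compute(x: &f32, data: &[f32]) -> f32`,
/// returned as (input 0, input 1, output).
fn compute_trees() -> (TypeTree, TypeTree, TypeTree) {
    let x = TypeTree(vec![pointer_to(TypeTree(vec![float_at(0)]))]);
    let data = TypeTree(vec![pointer_to(TypeTree(vec![float_at(-1)]))]);
    let ret = TypeTree(vec![float_at(0)]);
    (x, data, ret)
}

fn main() {
    let (x, data, ret) = compute_trees();
    println!("{x:?}\n{data:?}\n{ret:?}");
}
```

The only difference between the two inputs is the offset of the inner float: `0` for the single `f32` behind `x`, `-1` for every element behind `data`.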

Why Needed?

  • Enzyme can't deduce complex type layouts from LLVM IR
  • Prevents slow memory pattern analysis
  • Enables correct derivative computation for nested structures
  • Tells Enzyme which bytes are differentiable vs metadata

What Enzyme Does With This Information:

Without TypeTrees (current state):

; Enzyme sees generic LLVM IR:
define float @distance(i8* %p1, i8* %p2) {
  ; Has to guess what these pointers point to
  ; Slow analysis of all memory operations
  ; May miss optimization opportunities
}

With TypeTrees (our goal):

// Enzyme knows:
// - %p1 points to struct with f32 at +0, f32 at +4, i32 at +8
// - Only the f32 fields need derivatives
// - Can generate efficient derivative code directly
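
To make "only the f32 fields need derivatives" concrete, here is a small sketch that filters a flat layout description for its float entries. It only illustrates the idea; how Enzyme consumes the information internally is not shown in this PR description, and `float_fields` and `point_layout` are hypothetical names.

```rust
// Sketch only: a flattened layout entry, without the nested `child` tree.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum Kind {
    Float,
    Integer,
}

struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
}

/// Returns (offset, size) for every entry that holds floating-point data,
/// i.e. the byte ranges that would carry derivatives.
fn float_fields(entries: &[Type]) -> Vec<(isize, usize)> {
    entries
        .iter()
        .filter(|t| t.kind == Kind::Float)
        .map(|t| (t.offset, t.size))
        .collect()
}

/// Layout of the pointee described above: f32 at +0, f32 at +4, i32 at +8.
fn point_layout() -> Vec<Type> {
    vec![
        Type { offset: 0, size: 4, kind: Kind::Float },
        Type { offset: 4, size: 4, kind: Kind::Float },
        Type { offset: 8, size: 4, kind: Kind::Integer },
    ]
}

fn main() {
    // The i32 at +8 is skipped: it is metadata, not differentiable data.
    println!("{:?}", float_fields(&point_layout()));
}
```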

TypeTrees - Offset and -1 Explained

Type Structure

Type {
    offset: isize,  // WHERE this type starts
    size: usize,    // HOW BIG this type is
    kind: Kind,     // WHAT KIND of data (Float, Int, Pointer)
    child: TypeTree // WHAT'S INSIDE (for pointers/containers)
}

Offset Values

Regular Offset (0, 4, 8, etc.)

Specific byte position within a structure

struct Point {
    x: f32,  // offset 0, size 4
    y: f32,  // offset 4, size 4
    id: i32, // offset 8, size 4
}

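TypeTree for &Point (a pointer whose child describes the fields of Point):

TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![
        Type { offset: 0, size: 4, kind: Float, child: TypeTree::new() },   // x at byte 0
        Type { offset: 4, size: 4, kind: Float, child: TypeTree::new() },   // y at byte 4
        Type { offset: 8, size: 4, kind: Integer, child: TypeTree::new() }  // id at byte 8
    ])
}])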
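
The byte offsets listed for `Point` can be checked directly with `std::mem::offset_of!`. This sketch adds `#[repr(C)]`, which is an assumption on my part: without it, Rust does not guarantee field order, so the offsets 0, 4, 8 would not be guaranteed. `point_offsets` is a hypothetical helper name.

```rust
use std::mem::{offset_of, size_of};

// `repr(C)` pins the field order; the default Rust representation
// is free to reorder fields.
#[allow(dead_code)]
#[repr(C)]
struct Point {
    x: f32,
    y: f32,
    id: i32,
}

/// Byte offsets of `x`, `y`, and `id` inside `Point`.
fn point_offsets() -> [usize; 3] {
    [
        offset_of!(Point, x),
        offset_of!(Point, y),
        offset_of!(Point, id),
    ]
}

fn main() {
    println!("offsets: {:?}", point_offsets());
    println!("size of Point: {}", size_of::<Point>());
}
```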

Offset -1 (Special: "Everywhere")

Means "this pattern repeats for ALL elements"

Example 1: Array [f32; 100]

TypeTree(vec![Type {
    offset: -1,   // ALL positions
    size: 4,      // each f32 is 4 bytes
    kind: Float,  // every element is float
    child: TypeTree::new()
}])

This replaces 100 separate Type entries with offsets 0, 4, 8, 12, ..., 396.
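
One way to read offset -1 is as shorthand for "one entry at every multiple of the element size". The sketch below expands that shorthand into concrete offsets for a region of known size. It illustrates the meaning only: nothing here suggests Enzyme performs such an expansion, and `expand_offsets` is a hypothetical helper. It assumes `elem_size` is non-zero.

```rust
/// Expands an entry's offset into the concrete byte offsets it covers
/// within a region of `total_bytes`.
/// `-1` means "every element"; any other value is a single position.
/// Assumes `elem_size > 0` (`step_by` panics on a zero step).
fn expand_offsets(offset: isize, elem_size: usize, total_bytes: usize) -> Vec<usize> {
    if offset == -1 {
        (0..total_bytes).step_by(elem_size).collect()
    } else {
        vec![offset as usize]
    }
}

fn main() {
    // [f32; 100] occupies 400 bytes, with one f32 every 4 bytes.
    let offsets = expand_offsets(-1, 4, 400);
    println!(
        "{} entries, first = {}, last = {}",
        offsets.len(),
        offsets[0],
        offsets[offsets.len() - 1]
    );
}
```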

Example 2: Slice &[i32]

// Pointer to slice data
TypeTree(vec![Type {
    offset: 0, size: 8, kind: Pointer,
    child: TypeTree(vec![Type {
        offset: -1,    // ALL slice elements
        size: 4,       // each i32 is 4 bytes
        kind: Integer,
        child: TypeTree::new()
    }])
}])
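
The tree above describes the 8-byte data pointer. On the Rust side, a `&[i32]` is a fat pointer that also carries a length, so it is two machine words wide. The sketch below shows this with `size_of`. How this PR treats the length word is not stated in the description, so the sketch makes no claim about it.

```rust
use std::mem::size_of;

/// Number of machine words occupied by a slice reference.
fn slice_ref_words() -> usize {
    size_of::<&[i32]>() / size_of::<usize>()
}

fn main() {
    let data = [1i32, 2, 3];
    let s: &[i32] = &data;

    // A thin reference is one word; a slice reference is pointer + length.
    println!("size_of::<&i32>()   = {}", size_of::<&i32>());
    println!("size_of::<&[i32]>() = {}", size_of::<&[i32]>());
    println!("words in &[i32]     = {}", slice_ref_words());
    println!("data pointer = {:p}, len = {}", s.as_ptr(), s.len());
}
```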

Example 3: Mixed Structure

struct Container {
    header: i64,       // offset 0
    data: [f32; 1000], // offset 8; its elements are described with offset -1
}

TypeTree(vec![
    Type { offset: 0, size: 8, kind: Integer, child: TypeTree::new() }, // header
    Type { offset: 8, size: 4000, kind: Pointer,                        // data
        child: TypeTree(vec![Type {
            offset: -1, size: 4, kind: Float,  // ALL array elements
            child: TypeTree::new()
        }])
    }
])

@rustbot rustbot added F-autodiff `#![feature(autodiff)]` S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. T-compiler Relevant to the compiler team, which will review and decide on the PR/issue. labels Jul 19, 2025
@rustbot rustbot added the A-LLVM Area: Code generation parts specific to LLVM. Both correctness bugs and optimization-related issues. label Jul 19, 2025
KMJ-007 (Contributor, Author) commented Jul 19, 2025

Currently, I have implemented this only for memcpy.

KMJ-007 (Contributor, Author) commented Jul 19, 2025

r? @ZuseZ4

@KMJ-007 KMJ-007 marked this pull request as ready for review July 19, 2025 23:50
@rustbot rustbot added S-waiting-on-review Status: Awaiting review from the assignee but also interested parties. and removed S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. labels Jul 19, 2025
rustbot (Collaborator) commented Jul 19, 2025

Some changes occurred in compiler/rustc_ast/src/expand/autodiff_attrs.rs

cc @ZuseZ4

Some changes occurred in compiler/rustc_codegen_llvm/src/builder/autodiff.rs

cc @ZuseZ4

Some changes occurred in compiler/rustc_codegen_ssa

cc @WaffleLapkin

Some changes occurred in compiler/rustc_monomorphize/src/partitioning/autodiff.rs

cc @ZuseZ4

rustbot (Collaborator) commented Jul 20, 2025

Some changes occurred in compiler/rustc_codegen_gcc

cc @antoyo, @GuillaumeGomez

KMJ-007 (Contributor, Author) commented Jul 21, 2025

CI is failing; I'm fixing it!

// For now, we just ignore the TypeTree since gcc backend doesn't support autodiff yet
// When autodiff support is added to gcc backend, this should attach TypeTree information
// as function attributes similar to how LLVM backend does it.
}
// TODO(antoyo): handle aligns and is_volatile.
Contributor

Could you please move this comment right above the call to new_call()?

ZuseZ4 (Member) commented Jul 21, 2025

@KMJ-007 The TODO might be a bit too ambitious. I don't think anyone is currently working on autodiff for GCC, so I would just remove it and the code change. Also, even if they want to add it one day, they should rather do it from scratch to be able to learn some lessons from Enzyme, so I'm not sure if TypeTrees would even be a thing for them.

Contributor

Oh, I was talking about the existing TODO:

// TODO(antoyo): handle aligns and is_volatile.

I would like it to stay where it was before this change.

Contributor Author

> @KMJ-007 The TODO might be a bit too ambitious, I don't think anyone is currently working on autodiff for GCC, so I would just remove it and the code change. Also, even if they want to add it one day, they should rather do it from scratch to be able to learn some lessons from Enzyme, so I'm not sure if TypeTrees would even be a thing for them.

I added it temporarily because CI was failing for GCC. I think I shouldn't have changed memcpy for GCC, but somewhere the two backends share the same trait or struct. I don't remember exactly where, so I will look into it again and rework the changes so that we neither break nor modify the GCC backend.

rustbot (Collaborator) commented Jul 23, 2025

Some changes occurred in src/tools/enzyme

cc @ZuseZ4

@rustbot rustbot added has-merge-commits PR has merge commits, merge with caution. S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. labels Jul 23, 2025
KMJ-007 added 3 commits July 23, 2025 17:37
Signed-off-by: Karan Janthe <[email protected]>
KMJ-007 added 6 commits July 23, 2025 17:37
rustbot (Collaborator) commented Jul 23, 2025

⚠️ Warning ⚠️

  • Some commits in this PR modify submodules.

@rustbot rustbot removed has-merge-commits PR has merge commits, merge with caution. S-waiting-on-author Status: This is awaiting some action (such as code changes or more information) from the author. labels Jul 23, 2025
@rustbot rustbot added the A-run-make Area: port run-make Makefiles to rmake.rs label Jul 23, 2025
rustbot (Collaborator) commented Jul 23, 2025

Some changes occurred in compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

cc @ZuseZ4

This PR modifies run-make tests.

cc @jieyouxu

@KMJ-007 KMJ-007 requested a review from ZuseZ4 July 24, 2025 04:57
rust-log-analyzer (Collaborator)

The job pr-check-2 failed! Check out the build log.

Possible cause of the failure (guessed by this bot):
[RUSTC-TIMING] rustc_mir_build test:false 3.396
    Checking rustc_mir_transform v0.0.0 (/checkout/compiler/rustc_mir_transform)
[RUSTC-TIMING] rustc_passes test:false 1.854
    Checking rustc_borrowck v0.0.0 (/checkout/compiler/rustc_borrowck)
error: unused imports: `c_char` and `c_uint`
 --> compiler/rustc_codegen_llvm/src/builder/autodiff.rs:1:20
  |
1 | use std::os::raw::{c_char, c_uint};
  |                    ^^^^^^  ^^^^^^
  |
  = note: `-D unused-imports` implied by `-D warnings`
  = help: to override `-D warnings` add `#[allow(unused_imports)]`

---

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:275:21
    |
275 |         let inner = unsafe { EnzymeNewTypeTree() };
    |                     ^^^^^^ unnecessary `unsafe` block
    |
    = note: `-D unused-unsafe` implied by `-D warnings`
    = help: to override `-D warnings` add `#[allow(unused_unsafe)]`

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:280:21
    |
280 |         let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) };
    |                     ^^^^^^ unnecessary `unsafe` block

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:285:9
    |
---

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:318:21
    |
318 |         let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) };
    |                     ^^^^^^ unnecessary `unsafe` block

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:325:19
    |
325 |         let ptr = unsafe { EnzymeTypeTreeToString(self.inner) };
    |                   ^^^^^^ unnecessary `unsafe` block

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:333:9
    |
333 |         unsafe {
    |         ^^^^^^ unnecessary `unsafe` block

error: unnecessary `unsafe` block
   --> compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs:349:9
    |
349 |         unsafe { EnzymeFreeTypeTree(self.inner) }
    |         ^^^^^^ unnecessary `unsafe` block

[RUSTC-TIMING] rustc_codegen_llvm test:false 3.424
error: could not compile `rustc_codegen_llvm` (lib) due to 10 previous errors
warning: build failed, waiting for other jobs to finish...

Labels: A-LLVM, A-run-make, F-autodiff, S-waiting-on-review, T-compiler
5 participants