Skip to content

Commit 577cb5d

Browse files
committed
Remove inlining for autodiff handling
1 parent c7fc24b commit 577cb5d

File tree

10 files changed

+58
-151
lines changed

10 files changed

+58
-151
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -348,28 +348,10 @@ mod llvm_enzyme {
348348
let mut rustc_ad_attr =
349349
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
350350

351-
let ts2: Vec<TokenTree> = vec![TokenTree::Token(
352-
Token::new(TokenKind::Ident(sym::never, false.into()), span),
353-
Spacing::Joint,
354-
)];
355-
let never_arg = ast::DelimArgs {
356-
dspan: DelimSpan::from_single(span),
357-
delim: ast::token::Delimiter::Parenthesis,
358-
tokens: TokenStream::from_iter(ts2),
359-
};
360-
let inline_item = ast::AttrItem {
361-
unsafety: ast::Safety::Default,
362-
path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
363-
args: ast::AttrArgs::Delimited(never_arg),
364-
tokens: None,
365-
};
366-
let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
367351
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
368352
let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
369-
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
370-
let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
371353

372-
// We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
354+
// We're avoid duplicating the attribute `#[rustc_autodiff]`.
373355
fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
374356
match (attr, item) {
375357
(ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
@@ -388,18 +370,12 @@ mod llvm_enzyme {
388370
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
389371
iitem.attrs.push(attr);
390372
}
391-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
392-
iitem.attrs.push(inline_never.clone());
393-
}
394373
Annotatable::Item(iitem.clone())
395374
}
396375
Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
397376
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
398377
assoc_item.attrs.push(attr);
399378
}
400-
if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
401-
assoc_item.attrs.push(inline_never.clone());
402-
}
403379
Annotatable::AssocItem(assoc_item.clone(), i)
404380
}
405381
Annotatable::Stmt(ref mut stmt) => {
@@ -408,10 +384,6 @@ mod llvm_enzyme {
408384
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
409385
iitem.attrs.push(attr);
410386
}
411-
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
412-
{
413-
iitem.attrs.push(inline_never.clone());
414-
}
415387
}
416388
_ => unreachable!("stmt kind checked previously"),
417389
};

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ use tracing::debug;
1010
use crate::builder::{Builder, PlaceRef, UNNAMED};
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::llvm::AttributePlace::Function;
13+
use crate::llvm;
1414
use crate::llvm::{Metadata, True, Type};
1515
use crate::value::Value;
16-
use crate::{attributes, llvm};
1716

1817
pub(crate) fn adjust_activity_to_abi<'tcx>(
1918
tcx: TyCtxt<'tcx>,
@@ -308,11 +307,6 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
308307
enzyme_ty,
309308
);
310309

311-
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
312-
// do it's work.
313-
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
314-
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
315-
316310
let num_args = llvm::LLVMCountParams(&fn_to_diff);
317311
let mut args = Vec::with_capacity(num_args as usize + 1);
318312
args.push(fn_to_diff);

tests/codegen-llvm/autodiff/batched.rs

Lines changed: 34 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,74 +21,39 @@ fn square(x: &f32) -> f32 {
2121
x * x
2222
}
2323

24-
// d_square2
25-
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
26-
// CHECK-NEXT: start:
27-
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
28-
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
29-
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
30-
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
31-
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
32-
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
33-
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
34-
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
35-
// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl"
36-
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
37-
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
38-
// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1"
39-
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
40-
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
41-
// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2"
42-
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
43-
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
44-
// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3"
45-
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
46-
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
47-
// CHECK-NEXT: ret [4 x float] %15
48-
// CHECK-NEXT: }
49-
50-
// d_square3, the extra float is the original return value (x * x)
51-
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
52-
// CHECK-NEXT: start:
53-
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
54-
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
55-
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
56-
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
57-
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
58-
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
59-
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
60-
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
61-
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
62-
// CHECK-NEXT: %4 = fadd fast float %"_2'ipl", %"_2'ipl"
63-
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
64-
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
65-
// CHECK-NEXT: %7 = fadd fast float %"_2'ipl1", %"_2'ipl1"
66-
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
67-
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
68-
// CHECK-NEXT: %10 = fadd fast float %"_2'ipl2", %"_2'ipl2"
69-
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
70-
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
71-
// CHECK-NEXT: %13 = fadd fast float %"_2'ipl3", %"_2'ipl3"
72-
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
73-
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
74-
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
75-
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
76-
// CHECK-NEXT: ret { float, [4 x float] } %17
77-
// CHECK-NEXT: }
78-
7924
fn main() {
8025
let x = std::hint::black_box(3.0);
26+
27+
// square(&x)
28+
// CHECK: %_0.i = fmul float %_2.i, %_2.i
29+
// CHECK-NEXT: store float %_0.i, ptr %output, align 4
8130
let output = square(&x);
8231
dbg!(&output);
8332
assert_eq!(9.0, output);
33+
34+
// square(&x)
35+
// CHECK: %_2.i26 = load float, ptr %x, align 4
36+
// CHECK-NEXT: %_0.i27 = fmul float %_2.i26, %_2.i26
8437
dbg!(square(&x));
8538

8639
let mut df_dx1 = 1.0;
8740
let mut df_dx2 = 2.0;
8841
let mut df_dx3 = 3.0;
8942
let mut df_dx4 = 0.0;
43+
44+
// [o1, o2, o3, o4] (o4 is being optimized away as its smth * 0.0)
45+
// CHECK: %x.val = load float, ptr %x, align 4
46+
// CHECK-NEXT: %13 = fmul fast float %x.val, 2.000000e+00
47+
// CHECK-NEXT: %14 = fmul fast float %x.val, 4.000000e+00
48+
// CHECK-NEXT: %15 = fmul fast float %x.val, 6.000000e+00
9049
let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
9150
dbg!(o1, o2, o3, o4);
51+
52+
// [output2, o1, o2, o3, o4] (o4 is being optimized away as its smth * 0.0)
53+
// CHECK: %_0.i45 = fmul float %x.val35, %x.val35
54+
// CHECK-NEXT: %40 = fmul fast float %x.val35, 2.000000e+00
55+
// CHECK-NEXT: %41 = fmul fast float %x.val35, 4.000000e+00
56+
// CHECK-NEXT: %42 = fmul fast float %x.val35, 6.000000e+00
9257
let [output2, o1, o2, o3, o4] =
9358
d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
9459
dbg!(o1, o2, o3, o4);
@@ -101,8 +66,22 @@ fn main() {
10166
assert_eq!(2.0, df_dx2);
10267
assert_eq!(3.0, df_dx3);
10368
assert_eq!(0.0, df_dx4);
69+
70+
// d_square3(&x, &mut df_dx1)
71+
// CHECK: %x.val39 = load float, ptr %x, align 4
72+
// CHECK-NEXT: %72 = fmul fast float %x.val39, 2.000000e+00
10473
assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
74+
75+
// d_square3(&x, &mut df_dx2)
76+
// CHECK: %74 = fmul fast float %x.val39, 4.000000e+00
77+
// CHECK-NEXT: store float %74, ptr %_191, align 4
10578
assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
79+
80+
// d_square3(&x, &mut df_dx3)
81+
// CHECK: %76 = fmul fast float %x.val39, 6.000000e+00
82+
// CHECK-NEXT: store float %76, ptr %_200, align 4
10683
assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
84+
85+
// d_square3(&x, &mut df_dx3) is being optimized away as it's smth * 0.0
10786
assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
10887
}

tests/codegen-llvm/autodiff/generic.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,13 @@ fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
1010
*x * *x
1111
}
1212

13-
// Ensure that `d_square::<f32>` code is generated
13+
// Ensure that `square::<f32>` code is generated
1414
//
15-
// CHECK: ; generic::square
16-
// CHECK-NEXT: ; Function Attrs: {{.*}}
17-
// CHECK-NEXT: define internal {{.*}} float
18-
// CHECK-NEXT: start:
19-
// CHECK-NOT: ret
20-
// CHECK: fmul float
15+
// CHECK: %1 = fmul float %xf32, %xf32
2116

2217
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
2318
//
24-
// CHECK: ; generic::square
25-
// CHECK-NEXT: ; Function Attrs:
26-
// CHECK-NEXT: define internal {{.*}} double
27-
// CHECK-NEXT: start:
28-
// CHECK-NOT: ret
29-
// CHECK: fmul double
19+
// CHECK: define internal { double } @diffe_ZN7generic6square17he5c855620985cd59E
3020

3121
fn main() {
3222
let xf32: f32 = std::hint::black_box(3.0);

tests/codegen-llvm/autodiff/identical_fnc.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,13 @@ fn square2(x: &f64) -> f64 {
2323
x * x
2424
}
2525

26-
// CHECK:; identical_fnc::main
27-
// CHECK-NEXT:; Function Attrs:
28-
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
29-
// CHECK-NEXT:start:
30-
// CHECK-NOT:br
31-
// CHECK-NOT:ret
32-
// CHECK:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx1)
33-
// CHECK-NEXT:call fastcc void @diffe_ZN13identical_fnc6square17h67c6eccd3051fb4cE(double %x.val, ptr %dx2)
34-
26+
// CHECK: %0 = fadd fast double %x.val, %x.val
27+
// CHECK-NEXT: %1 = load double, ptr %dx1, align 8
28+
// CHECK-NEXT: %2 = fadd fast double %1, %0
29+
// CHECK-NEXT: store double %2, ptr %dx1, align 8
30+
// CHECK-NEXT: %3 = load double, ptr %dx2, align 8
31+
// CHECK-NEXT: %4 = fadd fast double %3, %0
32+
// CHECK-NEXT: store double %4, ptr %dx2, align 8
3533
fn main() {
3634
let x = std::hint::black_box(3.0);
3735
let mut dx1 = std::hint::black_box(1.0);

tests/codegen-llvm/autodiff/scalar.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,11 @@ fn square(x: &f64) -> f64 {
1111
x * x
1212
}
1313

14-
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
15-
// CHECK-NEXT:invertstart:
16-
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
17-
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
18-
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
19-
// CHECK-NEXT: %2 = fadd fast double %1, %0
20-
// CHECK-NEXT: store double %2, ptr %"x'", align 8
21-
// CHECK-NEXT: ret double %_0
22-
// CHECK-NEXT:}
14+
// square
15+
// CHECK: %_0.i = fmul double %_2.i, %_2.i
2316

17+
// d_square
18+
// CHECK: %0 = fadd fast double %_2.i, %_2.i
2419
fn main() {
2520
let x = std::hint::black_box(3.0);
2621
let output = square(&x);

tests/codegen-llvm/autodiff/sret.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,14 @@ fn primal(x: f32, y: f32) -> f64 {
1717
(x * x * y) as f64
1818
}
1919

20-
// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y)
21-
// CHECK-NEXT: invertstart:
22-
// CHECK-NEXT: %_4 = fmul float %x, %x
23-
// CHECK-NEXT: %_3 = fmul float %_4, %y
24-
// CHECK-NEXT: %_0 = fpext float %_3 to double
25-
// CHECK-NEXT: %0 = fadd fast float %y, %y
26-
// CHECK-NEXT: %1 = fmul fast float %0, %x
27-
// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0
28-
// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1
29-
// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2
30-
// CHECK-NEXT: ret { double, float, float } %4
31-
// CHECK-NEXT: }
32-
20+
// CHECK: %_4.i = fmul float %x, %x
21+
// CHECK-NEXT: %_3.i = fmul float %_4.i, %y
22+
// CHECK-NEXT: %_0.i = fpext float %_3.i to double
23+
// CHECK-NEXT: %3 = fadd fast float %y, %y
24+
// CHECK-NEXT: %4 = fmul fast float %3, %x
25+
// CHECK-NEXT: store double %_0.i, ptr %r1, align 8
26+
// CHECK-NEXT: store float %4, ptr %r2, align 4
27+
// CHECK-NEXT: store float %_4.i, ptr %r3, align 4
3328
fn main() {
3429
let x = std::hint::black_box(3.0);
3530
let y = std::hint::black_box(2.5);

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
use std::autodiff::{autodiff_forward, autodiff_reverse};
1717

1818
#[rustc_autodiff]
19-
#[inline(never)]
2019
pub fn f1(x: &[f64], y: f64) -> f64 {
2120

2221

@@ -40,7 +39,6 @@
4039
::core::intrinsics::enzyme_autodiff(f1::<>, df1::<>, (x, bx_0, y))
4140
}
4241
#[rustc_autodiff]
43-
#[inline(never)]
4442
pub fn f2(x: &[f64], y: f64) -> f64 {
4543
::core::panicking::panic("not implemented")
4644
}
@@ -49,7 +47,6 @@
4947
::core::intrinsics::enzyme_autodiff(f2::<>, df2::<>, (x, bx_0, y))
5048
}
5149
#[rustc_autodiff]
52-
#[inline(never)]
5350
pub fn f3(x: &[f64], y: f64) -> f64 {
5451
::core::panicking::panic("not implemented")
5552
}
@@ -58,14 +55,12 @@
5855
::core::intrinsics::enzyme_autodiff(f3::<>, df3::<>, (x, bx_0, y))
5956
}
6057
#[rustc_autodiff]
61-
#[inline(never)]
6258
pub fn f4() {}
6359
#[rustc_autodiff(Forward, 1, None)]
6460
pub fn df4() -> () {
6561
::core::intrinsics::enzyme_autodiff(f4::<>, df4::<>, ())
6662
}
6763
#[rustc_autodiff]
68-
#[inline(never)]
6964
pub fn f5(x: &[f64], y: f64) -> f64 {
7065
::core::panicking::panic("not implemented")
7166
}
@@ -84,7 +79,6 @@
8479
}
8580
struct DoesNotImplDefault;
8681
#[rustc_autodiff]
87-
#[inline(never)]
8882
pub fn f6() -> DoesNotImplDefault {
8983
::core::panicking::panic("not implemented")
9084
}
@@ -93,15 +87,13 @@
9387
::core::intrinsics::enzyme_autodiff(f6::<>, df6::<>, ())
9488
}
9589
#[rustc_autodiff]
96-
#[inline(never)]
9790
pub fn f7(x: f32) -> () {}
9891
#[rustc_autodiff(Forward, 1, Const, None)]
9992
pub fn df7(x: f32) -> () {
10093
::core::intrinsics::enzyme_autodiff(f7::<>, df7::<>, (x,))
10194
}
10295
#[no_mangle]
10396
#[rustc_autodiff]
104-
#[inline(never)]
10597
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
10698
#[rustc_autodiff(Forward, 4, Dual, Dual)]
10799
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
@@ -121,7 +113,6 @@
121113
}
122114
pub fn f9() {
123115
#[rustc_autodiff]
124-
#[inline(never)]
125116
fn inner(x: f32) -> f32 { x * x }
126117
#[rustc_autodiff(Forward, 1, Dual, Dual)]
127118
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
@@ -135,7 +126,6 @@
135126
}
136127
}
137128
#[rustc_autodiff]
138-
#[inline(never)]
139129
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
140130
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
141131
pub fn d_square<T: std::ops::Mul<Output = T> +

0 commit comments

Comments
 (0)