Skip to content

Commit 6420b0b

Browse files
committed
Add test for autodiff abi handling
1 parent 2801f9a commit 6420b0b

File tree

1 file changed

+331
-0
lines changed

1 file changed

+331
-0
lines changed
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
//@ revisions: debug release
2+
3+
//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
4+
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
5+
//@ no-prefer-dynamic
6+
//@ needs-enzyme
7+
8+
// This test checks that Rust types are lowered to LLVM-IR types in a way
9+
// we expect and Enzyme can handle. We explicitly check release mode to
10+
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
11+
// into forms that Enzyme can't process correctly.
12+
13+
#![feature(autodiff)]
14+
15+
use std::autodiff::{autodiff_forward, autodiff_reverse};
16+
17+
#[derive(Copy, Clone)]
18+
struct Input {
19+
x: f32,
20+
y: f32,
21+
}
22+
23+
#[derive(Copy, Clone)]
24+
struct Wrapper {
25+
z: f32,
26+
}
27+
28+
#[derive(Copy, Clone)]
29+
struct NestedInput {
30+
x: f32,
31+
y: Wrapper,
32+
}
33+
34+
fn square(x: f32) -> f32 {
35+
x * x
36+
}
37+
38+
// CHECK: ; abi_handling::f1
39+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
40+
// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E
41+
// debug-SAME: (ptr align 4 %x)
42+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E
43+
// release-SAME: (float %x.0.val, float %x.4.val)
44+
#[autodiff_forward(df1, Dual, Dual)]
45+
fn f1(x: &[f32; 2]) -> f32 {
46+
x[0] + x[1]
47+
}
48+
49+
// CHECK: ; abi_handling::f2
50+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
51+
// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E
52+
// debug-SAME: (ptr %f, float %x)
53+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E
54+
// release-SAME: (float noundef %x)
55+
#[autodiff_reverse(df2, Const, Active, Active)]
56+
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
57+
f(x)
58+
}
59+
60+
// CHECK: ; abi_handling::f3
61+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
62+
// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
63+
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
64+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
65+
// release-SAME: (float %x.0.val)
66+
#[autodiff_forward(df3, Dual, Dual, Dual)]
67+
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
68+
*x * *y
69+
}
70+
71+
// CHECK: ; abi_handling::f4
72+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
73+
// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
74+
// debug-SAME: (float %x.0, float %x.1)
75+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
76+
// release-SAME: (float noundef %x.0, float noundef %x.1)
77+
#[autodiff_forward(df4, Dual, Dual)]
78+
fn f4(x: (f32, f32)) -> f32 {
79+
x.0 * x.1
80+
}
81+
82+
// CHECK: ; abi_handling::f5
83+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
84+
// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
85+
// debug-SAME: (float %i.0, float %i.1)
86+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
87+
// release-SAME: (float noundef %i.0, float noundef %i.1)
88+
#[autodiff_forward(df5, Dual, Dual)]
89+
fn f5(i: Input) -> f32 {
90+
i.x + i.y
91+
}
92+
93+
// CHECK: ; abi_handling::f6
94+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
95+
// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE
96+
// debug-SAME: (float %i.0, float %i.1)
97+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE
98+
// release-SAME: (float noundef %i.0, float noundef %i.1)
99+
#[autodiff_forward(df6, Dual, Dual)]
100+
fn f6(i: NestedInput) -> f32 {
101+
i.x + i.y.z * i.y.z
102+
}
103+
104+
// CHECK: ; abi_handling::f7
105+
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
106+
// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
107+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
108+
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
109+
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
110+
#[autodiff_forward(df7, Dual, Dual)]
111+
fn f7(x: (&f32, &f32)) -> f32 {
112+
x.0 * x.1
113+
}
114+
115+
// df1
116+
// release: define internal fastcc { float, float }
117+
// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
118+
// release-SAME: (float %x.0.val, float %x.4.val)
119+
// release-NEXT: start:
120+
// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val
121+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
122+
// release-NEXT: %1 = insertvalue { float, float } %0, float 1.000000e+00, 1
123+
// release-NEXT: ret { float, float } %1
124+
// release-NEXT: }
125+
126+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
127+
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'")
128+
// debug-NEXT: start:
129+
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0
130+
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0
131+
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4
132+
// debug-NEXT: %_2 = load float, ptr %0, align 4
133+
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1
134+
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1
135+
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4
136+
// debug-NEXT: %_5 = load float, ptr %1, align 4
137+
// debug-NEXT: %_0 = fadd float %_2, %_5
138+
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl"
139+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
140+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
141+
// debug-NEXT: ret { float, float } %4
142+
// debug-NEXT: }
143+
144+
// df2
145+
// release: define internal fastcc { float, float }
146+
// release-SAME: @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
147+
// release-SAME: (float noundef %x)
148+
// release-NEXT: invertstart:
149+
// release-NEXT: %_0.i = fmul float %x, %x
150+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0
151+
// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1
152+
// release-NEXT: ret { float, float } %1
153+
// release-NEXT: }
154+
155+
// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
156+
// debug-SAME: (ptr %f, float %x, float %differeturn)
157+
// debug-NEXT: start:
158+
// debug-NEXT: %"x'de" = alloca float, align 4
159+
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4
160+
// debug-NEXT: %toreturn = alloca float, align 4
161+
// debug-NEXT: %_0 = call float %f(float %x)
162+
// debug-NEXT: store float %_0, ptr %toreturn, align 4
163+
// debug-NEXT: br label %invertstart
164+
// debug-EMPTY:
165+
// debug-NEXT: invertstart: ; preds = %start
166+
// debug-NEXT: %retreload = load float, ptr %toreturn, align 4
167+
// debug-NEXT: %0 = load float, ptr %"x'de", align 4
168+
// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0
169+
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
170+
// debug-NEXT: ret { float, float } %2
171+
// debug-NEXT: }
172+
173+
// df3
174+
// release: define internal fastcc { float, float }
175+
// release-SAME: @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
176+
// release-SAME: (float %x.0.val)
177+
// release-NEXT: start:
178+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0
179+
// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1
180+
// release-NEXT: ret { float, float } %1
181+
// release-NEXT: }
182+
183+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
184+
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'")
185+
// debug-NEXT: start:
186+
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4
187+
// debug-NEXT: %_3 = load float, ptr %x, align 4
188+
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4
189+
// debug-NEXT: %_4 = load float, ptr %y, align 4
190+
// debug-NEXT: %_0 = fmul float %_3, %_4
191+
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4
192+
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3
193+
// debug-NEXT: %2 = fadd fast float %0, %1
194+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
195+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
196+
// debug-NEXT: ret { float, float } %4
197+
// debug-NEXT: }
198+
199+
// df4
200+
// release: define internal fastcc { float, float }
201+
// release-SAME: @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
202+
// release-SAME: (float noundef %x.0, float %"x.0'")
203+
// release-NEXT: start:
204+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0
205+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1
206+
// release-NEXT: ret { float, float } %1
207+
// release-NEXT: }
208+
209+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
210+
// debug-SAME: (float %x.0, float %"x.0'", float %x.1, float %"x.1'")
211+
// debug-NEXT: start:
212+
// debug-NEXT: %_0 = fmul float %x.0, %x.1
213+
// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1
214+
// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0
215+
// debug-NEXT: %2 = fadd fast float %0, %1
216+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
217+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
218+
// debug-NEXT: ret { float, float } %4
219+
// debug-NEXT: }
220+
221+
// df5
222+
// release: define internal fastcc { float, float }
223+
// release-SAME: @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
224+
// release-SAME: (float noundef %i.0, float %"i.0'")
225+
// release-NEXT: start:
226+
// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00
227+
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
228+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1
229+
// release-NEXT: ret { float, float } %1
230+
// release-NEXT: }
231+
232+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
233+
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
234+
// debug-NEXT: start:
235+
// debug-NEXT: %_0 = fadd float %i.0, %i.1
236+
// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'"
237+
// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0
238+
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
239+
// debug-NEXT: ret { float, float } %2
240+
// debug-NEXT: }
241+
242+
// df6
243+
// release: define internal fastcc { float, float }
244+
// release-SAME: @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
245+
// release-SAME: (float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'")
246+
// release-NEXT: start:
247+
// release-NEXT: %_3 = fmul float %i.1, %i.1
248+
// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'"
249+
// release-NEXT: %1 = fmul fast float %0, %i.1
250+
// release-NEXT: %_0 = fadd float %i.0, %_3
251+
// release-NEXT: %2 = fadd fast float %"i.0'", %1
252+
// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
253+
// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
254+
// release-NEXT: ret { float, float } %4
255+
// release-NEXT: }
256+
257+
// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
258+
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
259+
// debug-NEXT: start:
260+
// debug-NEXT: %_3 = fmul float %i.1, %i.1
261+
// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1
262+
// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1
263+
// debug-NEXT: %2 = fadd fast float %0, %1
264+
// debug-NEXT: %_0 = fadd float %i.0, %_3
265+
// debug-NEXT: %3 = fadd fast float %"i.0'", %2
266+
// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0
267+
// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1
268+
// debug-NEXT: ret { float, float } %5
269+
// debug-NEXT: }
270+
271+
// df7
272+
// release: define internal fastcc { float, float }
273+
// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
274+
// release-SAME: (float %x.0.0.val, float %"x.0'.0.val")
275+
// release-NEXT: start:
276+
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0
277+
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1
278+
// release-NEXT: ret { float, float } %1
279+
// release-NEXT: }
280+
281+
// debug: define internal { float, float }
282+
// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
283+
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'")
284+
// debug-NEXT: start:
285+
// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}"
286+
// debug-NEXT: %1 = extractvalue { float, float } %0, 0
287+
// debug-NEXT: %2 = extractvalue { float, float } %0, 1
288+
// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0
289+
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
290+
// debug-NEXT: ret { float, float } %4
291+
// debug-NEXT: }
292+
293+
fn main() {
294+
let x = std::hint::black_box(2.0);
295+
let y = std::hint::black_box(3.0);
296+
let z = std::hint::black_box(4.0);
297+
static Y: f32 = std::hint::black_box(3.2);
298+
299+
let in_f1 = [x, y];
300+
dbg!(f1(&in_f1));
301+
let res_f1 = df1(&in_f1, &[1.0, 0.0]);
302+
dbg!(res_f1);
303+
304+
dbg!(f2(square, x));
305+
let res_f2 = df2(square, x, 1.0);
306+
dbg!(res_f2);
307+
308+
dbg!(f3(&x, &Y));
309+
let res_f3 = df3(&x, &Y, &1.0, &0.0);
310+
dbg!(res_f3);
311+
312+
let in_f4 = (x, y);
313+
dbg!(f4(in_f4));
314+
let res_f4 = df4(in_f4, (1.0, 0.0));
315+
dbg!(res_f4);
316+
317+
let in_f5 = Input { x, y };
318+
dbg!(f5(in_f5));
319+
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
320+
dbg!(res_f5);
321+
322+
let in_f6 = NestedInput { x, y: Wrapper { z: y } };
323+
dbg!(f6(in_f6));
324+
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
325+
dbg!(res_f6);
326+
327+
let in_f7 = (&x, &y);
328+
dbg!(f7(in_f7));
329+
let res_f7 = df7(in_f7, (&1.0, &0.0));
330+
dbg!(res_f7);
331+
}

0 commit comments

Comments
 (0)