Skip to content

Commit e9139fe

Browse files
committed
Get args from tuple using fnabi and minor fixes
1 parent 7502ca0 commit e9139fe

File tree

3 files changed

+56
-36
lines changed

3 files changed

+56
-36
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ mod llvm_enzyme {
1616
use rustc_ast::tokenstream::*;
1717
use rustc_ast::visit::AssocCtxt::*;
1818
use rustc_ast::{
19-
self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy,
20-
FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path,
21-
PathSegment, TyKind, Visibility,
19+
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
20+
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
21+
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
2222
};
2323
use rustc_expand::base::{Annotatable, ExtCtxt};
2424
use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -554,10 +554,18 @@ mod llvm_enzyme {
554554
let generic_args = generics
555555
.params
556556
.iter()
557-
.map(|p| {
558-
let path = ast::Path::from_ident(p.ident);
559-
let ty = ecx.ty_path(path);
560-
AngleBracketedArg::Arg(GenericArg::Type(ty))
557+
.filter_map(|p| match &p.kind {
558+
GenericParamKind::Type { .. } => {
559+
let path = ast::Path::from_ident(p.ident);
560+
let ty = ecx.ty_path(path);
561+
Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
562+
}
563+
GenericParamKind::Const { .. } => {
564+
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
565+
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
566+
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
567+
}
568+
GenericParamKind::Lifetime { .. } => None,
561569
})
562570
.collect::<ThinVec<_>>();
563571

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::common::TypeKind;
55
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6-
use rustc_middle::{bug, ty};
76
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
7+
use rustc_middle::{bug, ty};
88
use tracing::debug;
99

1010
use crate::builder::{Builder, PlaceRef, UNNAMED};

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
1717
use rustc_middle::{bug, span_bug};
1818
use rustc_span::{Span, Symbol, sym};
1919
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
20+
use rustc_target::callconv::PassMode;
2021
use rustc_target::spec::PanicStrategy;
2122
use tracing::debug;
2223

@@ -1144,8 +1145,6 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11441145
let ret_ty = sig.output();
11451146
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
11461147

1147-
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]);
1148-
11491148
// Get source, diff, and attrs
11501149
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
11511150
ty::FnDef(def_id, source_params) => (def_id, source_params),
@@ -1163,6 +1162,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11631162
};
11641163
let fn_diff =
11651164
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
1165+
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff);
11661166
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11671167

11681168
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
@@ -1189,39 +1189,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11891189

11901190
fn get_args_from_tuple<'ll, 'tcx>(
11911191
bx: &mut Builder<'_, 'll, 'tcx>,
1192-
op: OperandRef<'tcx, &'ll Value>,
1192+
tuple_op: OperandRef<'tcx, &'ll Value>,
1193+
fn_instance: Instance<'tcx>,
11931194
) -> Vec<&'ll Value> {
1194-
match op.val {
1195-
OperandValue::Ref(ref place_value) => {
1196-
let mut ret_arr = vec![];
1197-
let tuple_place = PlaceRef { val: *place_value, layout: op.layout };
1198-
1199-
for i in 0..tuple_place.layout.layout.0.fields.count() {
1200-
let field_place = tuple_place.project_field(bx, i);
1201-
let field_layout = tuple_place.layout.field(bx, i);
1202-
let field_ty = field_layout.ty;
1203-
let llvm_ty = field_layout.llvm_type(bx.cx);
1204-
1205-
let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align);
1206-
1207-
match field_ty.kind() {
1208-
ty::Ref(_, inner_ty, _) if matches!(inner_ty.kind(), ty::Slice(_)) => {
1209-
let ptr = bx.extract_value(field_val, 0);
1210-
let len = bx.extract_value(field_val, 1);
1211-
ret_arr.push(ptr);
1212-
ret_arr.push(len);
1195+
let cx = bx.cx;
1196+
let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty());
1197+
1198+
match tuple_op.val {
1199+
OperandValue::Immediate(val) => vec![val],
1200+
OperandValue::Pair(v1, v2) => vec![v1, v2],
1201+
OperandValue::Ref(ptr) => {
1202+
let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout };
1203+
1204+
let mut result = Vec::with_capacity(fn_abi.args.len());
1205+
let mut tuple_index = 0;
1206+
1207+
for arg in &fn_abi.args {
1208+
match arg.mode {
1209+
PassMode::Ignore => {}
1210+
PassMode::Direct(_) | PassMode::Cast { .. } => {
1211+
let field = tuple_place.project_field(bx, tuple_index);
1212+
let llvm_ty = field.layout.llvm_type(bx.cx);
1213+
let val = bx.load(llvm_ty, field.val.llval, field.val.align);
1214+
result.push(val);
1215+
tuple_index += 1;
12131216
}
1214-
_ => {
1215-
ret_arr.push(field_val);
1217+
PassMode::Pair(_, _) => {
1218+
let field = tuple_place.project_field(bx, tuple_index);
1219+
let llvm_ty = field.layout.llvm_type(bx.cx);
1220+
let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
1221+
result.push(bx.extract_value(pair_val, 0));
1222+
result.push(bx.extract_value(pair_val, 1));
1223+
tuple_index += 1;
1224+
}
1225+
PassMode::Indirect { .. } => {
1226+
let field = tuple_place.project_field(bx, tuple_index);
1227+
result.push(field.val.llval);
1228+
tuple_index += 1;
12161229
}
12171230
}
12181231
}
12191232

1220-
ret_arr
1233+
result
12211234
}
1222-
OperandValue::Pair(v1, v2) => vec![v1, v2],
1223-
OperandValue::Immediate(v) => vec![v],
1224-
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
1235+
1236+
OperandValue::ZeroSized => bug!("unexpected ZeroSized argument in get_args_from_tuple"),
12251237
}
12261238
}
12271239

0 commit comments

Comments
 (0)