@@ -17,6 +17,7 @@ use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
17
17
use rustc_middle:: { bug, span_bug} ;
18
18
use rustc_span:: { Span , Symbol , sym} ;
19
19
use rustc_symbol_mangling:: { mangle_internal_symbol, symbol_name_for_instance_in_crate} ;
20
+ use rustc_target:: callconv:: PassMode ;
20
21
use rustc_target:: spec:: PanicStrategy ;
21
22
use tracing:: debug;
22
23
@@ -1144,8 +1145,6 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
1144
1145
let ret_ty = sig. output ( ) ;
1145
1146
let llret_ty = bx. layout_of ( ret_ty) . llvm_type ( bx) ;
1146
1147
1147
- let val_arr: Vec < & ' ll Value > = get_args_from_tuple ( bx, args[ 2 ] ) ;
1148
-
1149
1148
// Get source, diff, and attrs
1150
1149
let ( source_id, source_args) = match fn_args. into_type_list ( tcx) [ 0 ] . kind ( ) {
1151
1150
ty:: FnDef ( def_id, source_params) => ( def_id, source_params) ,
@@ -1163,6 +1162,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
1163
1162
} ;
1164
1163
let fn_diff =
1165
1164
Instance :: try_resolve ( tcx, bx. cx . typing_env ( ) , * diff_id, diff_args) . unwrap ( ) . unwrap ( ) ;
1165
+ let val_arr: Vec < & ' ll Value > = get_args_from_tuple ( bx, args[ 2 ] , fn_diff) ;
1166
1166
let diff_symbol = symbol_name_for_instance_in_crate ( tcx, fn_diff. clone ( ) , LOCAL_CRATE ) ;
1167
1167
1168
1168
let diff_attrs = autodiff_attrs ( tcx, fn_diff. def_id ( ) ) ;
@@ -1189,39 +1189,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
1189
1189
1190
1190
fn get_args_from_tuple < ' ll , ' tcx > (
1191
1191
bx : & mut Builder < ' _ , ' ll , ' tcx > ,
1192
- op : OperandRef < ' tcx , & ' ll Value > ,
1192
+ tuple_op : OperandRef < ' tcx , & ' ll Value > ,
1193
+ fn_instance : Instance < ' tcx > ,
1193
1194
) -> Vec < & ' ll Value > {
1194
- match op. val {
1195
- OperandValue :: Ref ( ref place_value) => {
1196
- let mut ret_arr = vec ! [ ] ;
1197
- let tuple_place = PlaceRef { val : * place_value, layout : op. layout } ;
1198
-
1199
- for i in 0 ..tuple_place. layout . layout . 0 . fields . count ( ) {
1200
- let field_place = tuple_place. project_field ( bx, i) ;
1201
- let field_layout = tuple_place. layout . field ( bx, i) ;
1202
- let field_ty = field_layout. ty ;
1203
- let llvm_ty = field_layout. llvm_type ( bx. cx ) ;
1204
-
1205
- let field_val = bx. load ( llvm_ty, field_place. val . llval , field_place. val . align ) ;
1206
-
1207
- match field_ty. kind ( ) {
1208
- ty:: Ref ( _, inner_ty, _) if matches ! ( inner_ty. kind( ) , ty:: Slice ( _) ) => {
1209
- let ptr = bx. extract_value ( field_val, 0 ) ;
1210
- let len = bx. extract_value ( field_val, 1 ) ;
1211
- ret_arr. push ( ptr) ;
1212
- ret_arr. push ( len) ;
1195
+ let cx = bx. cx ;
1196
+ let fn_abi = cx. fn_abi_of_instance ( fn_instance, ty:: List :: empty ( ) ) ;
1197
+
1198
+ match tuple_op. val {
1199
+ OperandValue :: Immediate ( val) => vec ! [ val] ,
1200
+ OperandValue :: Pair ( v1, v2) => vec ! [ v1, v2] ,
1201
+ OperandValue :: Ref ( ptr) => {
1202
+ let tuple_place = PlaceRef { val : ptr, layout : tuple_op. layout } ;
1203
+
1204
+ let mut result = Vec :: with_capacity ( fn_abi. args . len ( ) ) ;
1205
+ let mut tuple_index = 0 ;
1206
+
1207
+ for arg in & fn_abi. args {
1208
+ match arg. mode {
1209
+ PassMode :: Ignore => { }
1210
+ PassMode :: Direct ( _) | PassMode :: Cast { .. } => {
1211
+ let field = tuple_place. project_field ( bx, tuple_index) ;
1212
+ let llvm_ty = field. layout . llvm_type ( bx. cx ) ;
1213
+ let val = bx. load ( llvm_ty, field. val . llval , field. val . align ) ;
1214
+ result. push ( val) ;
1215
+ tuple_index += 1 ;
1213
1216
}
1214
- _ => {
1215
- ret_arr. push ( field_val) ;
1217
+ PassMode :: Pair ( _, _) => {
1218
+ let field = tuple_place. project_field ( bx, tuple_index) ;
1219
+ let llvm_ty = field. layout . llvm_type ( bx. cx ) ;
1220
+ let pair_val = bx. load ( llvm_ty, field. val . llval , field. val . align ) ;
1221
+ result. push ( bx. extract_value ( pair_val, 0 ) ) ;
1222
+ result. push ( bx. extract_value ( pair_val, 1 ) ) ;
1223
+ tuple_index += 1 ;
1224
+ }
1225
+ PassMode :: Indirect { .. } => {
1226
+ let field = tuple_place. project_field ( bx, tuple_index) ;
1227
+ result. push ( field. val . llval ) ;
1228
+ tuple_index += 1 ;
1216
1229
}
1217
1230
}
1218
1231
}
1219
1232
1220
- ret_arr
1233
+ result
1221
1234
}
1222
- OperandValue :: Pair ( v1, v2) => vec ! [ v1, v2] ,
1223
- OperandValue :: Immediate ( v) => vec ! [ v] ,
1224
- OperandValue :: ZeroSized => bug ! ( "unexpected `ZeroSized` arg" ) ,
1235
+
1236
+ OperandValue :: ZeroSized => bug ! ( "unexpected ZeroSized argument in get_args_from_tuple" ) ,
1225
1237
}
1226
1238
}
1227
1239
0 commit comments