Commit 75b6637

model offload C++ structs through Rust structs
1 parent 6e948da commit 75b6637
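The pattern introduced here (shown in full in the diff below) is to give each C offload-runtime struct a Rust counterpart that owns both the declaration of its LLVM struct type and the construction of its constant field values. The following is only an illustrative, self-contained sketch of that shape; it uses plain Rust values instead of the rustc-internal SimpleCx/LLVM wrappers, and the helper names `layout` and `values` are invented for the example.

// Illustrative sketch only (not part of the commit): the real code builds LLVM
// types and constants through rustc's SimpleCx; plain Rust values are used here
// so the example stands alone.

// Scalar field kinds of the C `__tgt_offload_entry` struct.
#[derive(Debug, Clone, Copy)]
enum FieldTy { I64, I32, I16, Ptr }

// Rust-side model of `__tgt_offload_entry` (see the diff below for the real one).
struct TgtOffloadEntry;

impl TgtOffloadEntry {
    // Analogue of `new_decl`: describe the struct layout once, in field order.
    fn layout() -> Vec<FieldTy> {
        use FieldTy::*;
        // Reserved, Version, Kind, Flags, Addr, SymbolName, Size, Data, AuxAddr
        vec![I64, I16, I16, I32, Ptr, Ptr, I64, I64, Ptr]
    }

    // Analogue of `new`: one value per field, in the same order.
    fn values(region_id: u64, symbol_addr: u64) -> Vec<u64> {
        let (reserved, version, kind, flags) = (0, 1, 1, 0);
        let (size, data, aux_addr) = (0, 0, 0);
        vec![reserved, version, kind, flags, region_id, symbol_addr, size, data, aux_addr]
    }
}

fn main() {
    // The two halves must stay in sync: one value per declared field.
    assert_eq!(TgtOffloadEntry::layout().len(), TgtOffloadEntry::values(1, 2).len());
}

The actual commit applies this split to both __tgt_offload_entry (TgtOffloadEntry::new_decl / new) and __tgt_kernel_arguments (KernelArgsTy::new_decl / new), replacing the free functions add_tgt_offload_entry and gen_tgt_kernel_global.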

File tree

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

1 file changed: +96 -75 lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 96 additions & 75 deletions
@@ -19,7 +19,7 @@ pub(crate) fn handle_gpu_code<'ll>(
     let mut o_types = vec![];
     let mut kernels = vec![];
     let mut region_ids = vec![];
-    let offload_entry_ty = add_tgt_offload_entry(&cx);
+    let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
     for num in 0..9 {
         let kernel = cx.get_function(&format!("kernel_{num}"));
         if let Some(kernel) = kernel {
@@ -54,7 +54,6 @@ fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm
 // FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
 // offloaded was defined.
 fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
-    // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
     let unknown_txt = ";unknown;unknown;0;0;;";
     let c_entry_name = CString::new(unknown_txt).unwrap();
     let c_val = c_entry_name.as_bytes_with_nul();
@@ -79,15 +78,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
     at_one
 }
 
-pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
-    let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let ti16 = cx.type_i16();
-    // For each kernel to run on the gpu, we will later generate one entry of this type.
-    // copied from LLVM
-    // typedef struct {
+struct TgtOffloadEntry {
     // uint64_t Reserved;
     // uint16_t Version;
     // uint16_t Kind;
@@ -97,21 +88,40 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
     // uint64_t Size; Size of the entry info (0 if it is a function)
     // uint64_t Data;
     // void *AuxAddr;
-    // } __tgt_offload_entry;
-    let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
-    cx.set_struct_body(offload_entry_ty, &entry_elements, false);
-    offload_entry_ty
 }
 
-fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
-    let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let tarr = cx.type_array(ti32, 3);
+impl TgtOffloadEntry {
+    pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
+        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
+        let tptr = cx.type_ptr();
+        let ti64 = cx.type_i64();
+        let ti32 = cx.type_i32();
+        let ti16 = cx.type_i16();
+        // For each kernel to run on the gpu, we will later generate one entry of this type.
+        // copied from LLVM
+        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
+        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
+        offload_entry_ty
+    }
+
+    fn new<'ll>(
+        cx: &'ll SimpleCx<'_>,
+        region_id: &'ll Value,
+        llglobal: &'ll Value,
+    ) -> Vec<&'ll Value> {
+        let reserved = cx.get_const_i64(0);
+        let version = cx.get_const_i16(1);
+        let kind = cx.get_const_i16(1);
+        let flags = cx.get_const_i32(0);
+        let size = cx.get_const_i64(0);
+        let data = cx.get_const_i64(0);
+        let aux_addr = cx.const_null(cx.type_ptr());
+        vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
+    }
+}
 
-// Taken from the LLVM APITypes.h declaration:
-//struct KernelArgsTy {
+// Taken from the LLVM APITypes.h declaration:
+struct KernelArgsTy {
     // uint32_t Version = 0; // Version of this struct for ABI compatibility.
     // uint32_t NumArgs = 0; // Number of arguments in each input pointer.
     // void **ArgBasePtrs =
@@ -122,8 +132,8 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     // void **ArgNames = nullptr; // Name of the data for debugging, possibly null.
     // void **ArgMappers = nullptr; // User-defined mappers, possibly null.
     // uint64_t Tripcount =
-    //     0; // Tripcount for the teams / distribute loop, 0 otherwise.
-    // struct {
+    //     0; // Tripcount for the teams / distribute loop, 0 otherwise.
+    // struct {
     //   uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
     //   uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
     //   uint64_t Unused : 62;
@@ -133,12 +143,53 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
     // // The number of threads (for x,y,z dimension).
     // uint32_t ThreadLimit[3] = {0, 0, 0};
     // uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
-    //};
-    let kernel_elements =
-        vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
+}
+
+impl KernelArgsTy {
+    const OFFLOAD_VERSION: u64 = 3;
+    const FLAGS: u64 = 0;
+    const TRIPCOUNT: u64 = 0;
+    fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
+        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
+        let tptr = cx.type_ptr();
+        let ti64 = cx.type_i64();
+        let ti32 = cx.type_i32();
+        let tarr = cx.type_array(ti32, 3);
+
+        let kernel_elements =
+            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
+
+        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
+        kernel_arguments_ty
+    }
 
-    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    kernel_arguments_ty
+    fn new<'ll>(
+        cx: &'ll SimpleCx<'_>,
+        num_args: u64,
+        o_types: &[&'ll Value],
+        geps: [&'ll Value; 3],
+    ) -> [(Align, &'ll Value); 13] {
+        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
+        let eight = Align::EIGHT;
+        let mut values = vec![];
+        values.push((four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)));
+        values.push((four, cx.get_const_i32(num_args)));
+        values.push((eight, geps[0]));
+        values.push((eight, geps[1]));
+        values.push((eight, geps[2]));
+        values.push((eight, o_types[0]));
+        // The next two are debug infos. FIXME(offload): set them
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.const_null(cx.type_ptr())));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)));
+        values.push((eight, cx.get_const_i64(KernelArgsTy::FLAGS)));
+        let ti32 = cx.type_i32();
+        let ci32_0 = cx.get_const_i32(0);
+        values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
+        values.push((four, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
+        values.push((four, cx.get_const_i32(0)));
+        values.try_into().expect("tgt_kernel_arguments construction failed")
+    }
 }
 
 fn gen_tgt_data_mappers<'ll>(
@@ -244,19 +295,10 @@ fn gen_define_handling<'ll>(
     let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
     llvm::set_alignment(llglobal, Align::ONE);
     llvm::set_section(llglobal, c".llvm.rodata.offloading");
-
-    // Not actively used yet, for calling real kernels
     let name = format!(".offloading.entry.kernel_{num}");
 
     // See the __tgt_offload_entry documentation above.
-    let reserved = cx.get_const_i64(0);
-    let version = cx.get_const_i16(1);
-    let kind = cx.get_const_i16(1);
-    let flags = cx.get_const_i32(0);
-    let size = cx.get_const_i64(0);
-    let data = cx.get_const_i64(0);
-    let aux_addr = cx.const_null(cx.type_ptr());
-    let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
+    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
 
     let initializer = crate::common::named_struct(offload_entry_ty, &elems);
     let c_name = CString::new(name).unwrap();
@@ -319,7 +361,7 @@ fn gen_call_handling<'ll>(
     let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
     cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
 
-    let tgt_kernel_decl = gen_tgt_kernel_global(&cx);
+    let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
     let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
 
     let main_fn = cx.get_function("main");
@@ -407,19 +449,19 @@ fn gen_call_handling<'ll>(
         a1: &'ll Value,
         a2: &'ll Value,
         a4: &'ll Value,
-    ) -> (&'ll Value, &'ll Value, &'ll Value) {
+    ) -> [&'ll Value; 3] {
         let i32_0 = cx.get_const_i32(0);
 
         let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-        (gep1, gep2, gep3)
+        [gep1, gep2, gep3]
     }
 
     fn generate_mapper_call<'a, 'll>(
         builder: &mut SBuilder<'a, 'll>,
         cx: &'ll SimpleCx<'ll>,
-        geps: (&'ll Value, &'ll Value, &'ll Value),
+        geps: [&'ll Value; 3],
         o_type: &'ll Value,
         fn_to_call: &'ll Value,
         fn_ty: &'ll Type,
@@ -430,7 +472,7 @@ fn gen_call_handling<'ll>(
         let i64_max = cx.get_const_i64(u64::MAX);
         let num_args = cx.get_const_i32(num_args);
         let args =
-            vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
+            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
         builder.call(fn_ty, fn_to_call, &args, None);
     }
 
@@ -439,36 +481,20 @@ fn gen_call_handling<'ll>(
     let o = o_types[0];
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
+    let values = KernelArgsTy::new(&cx, num_args, o_types, geps);
 
     // Step 3)
-    let mut values = vec![];
-    let offload_version = cx.get_const_i32(3);
-    values.push((4, offload_version));
-    values.push((4, cx.get_const_i32(num_args)));
-    values.push((8, geps.0));
-    values.push((8, geps.1));
-    values.push((8, geps.2));
-    values.push((8, o_types[0]));
-    // The next two are debug infos. FIXME(offload) set them
-    values.push((8, cx.const_null(cx.type_ptr())));
-    values.push((8, cx.const_null(cx.type_ptr())));
-    values.push((8, cx.get_const_i64(0)));
-    values.push((8, cx.get_const_i64(0)));
-    let ti32 = cx.type_i32();
-    let ci32_0 = cx.get_const_i32(0);
-    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
-    values.push((4, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
-    values.push((4, cx.get_const_i32(0)));
-
+    // Here we fill the KernelArgsTy, see the documentation above
     for (i, value) in values.iter().enumerate() {
         let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
-        builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
+        builder.store(value.1, ptr, value.0);
     }
 
     let args = vec![
         s_ident_t,
-        // MAX == -1
-        cx.get_const_i64(u64::MAX),
+        // FIXME(offload) give users a way to select which GPU to use.
+        cx.get_const_i64(u64::MAX), // MAX == -1.
+        // FIXME(offload): Don't hardcode the numbers of threads in the future.
        cx.get_const_i32(2097152),
        cx.get_const_i32(256),
        region_ids[0],
@@ -483,19 +509,14 @@ fn gen_call_handling<'ll>(
     }
 
     // Step 4)
-    //unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
     let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
     generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
 
     builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
 
     drop(builder);
+    // FIXME(offload) The issue is that we right now add a call to the gpu version of the function,
+    // and then delete the call to the CPU version. In the future, we should use an intrinsic which
+    // directly resolves to a call to the GPU version.
     unsafe { llvm::LLVMDeleteFunction(called) };
-
-    // With this we generated the following begin and end mappers. We could easily generate the
-    // update mapper in an update.
-    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
 }
