Skip to content

Commit 2c67718

Browse files
authored
[flang][cuda] Introduce cuf.set_allocator_idx operation (#148717)
1 parent 5eecec8 commit 2c67718

File tree

9 files changed

+120
-2
lines changed

9 files changed

+120
-2
lines changed

flang-rt/lib/cuda/descriptor.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ void RTDEF(CUFDescriptorCheckSection)(
6262
}
6363
}
6464

65+
// Store `index` as the allocator index of the descriptor `desc`.
// A null descriptor is a fatal error: the runtime crashes through a
// Terminator built from `sourceFile` and `sourceLine`.
void RTDEF(CUFSetAllocatorIndex)(
    Descriptor *desc, int index, const char *sourceFile, int sourceLine) {
  if (!desc) {
    Terminator terminator{sourceFile, sourceLine};
    terminator.Crash("descriptor is null");
  }
  desc->SetAllocIdx(index);
}
73+
6574
RT_EXT_API_GROUP_END
6675
}
6776
} // namespace Fortran::runtime::cuda

flang-rt/unittests/Runtime/CUDA/AllocatorCUF.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,13 @@ TEST(AllocatableCUFTest, DescriptorAllocationTest) {
7272
EXPECT_TRUE(desc != nullptr);
7373
RTNAME(CUFFreeDescriptor)(desc);
7474
}
75+
76+
// Check that CUFSetAllocatorIndex overwrites the allocator index stored in a
// descriptor.
TEST(AllocatableCUFTest, CUFSetAllocatorIndex) {
  using Fortran::common::TypeCategory;
  RTNAME(CUFRegisterAllocator)();
  // REAL(4), DEVICE, ALLOCATABLE :: a(:)
  auto a{createAllocatable(TypeCategory::Real, 4)};
  // Before the call the descriptor is expected to use the default allocator.
  EXPECT_EQ(static_cast<int>(kDefaultAllocator), a->GetAllocIdx());
  // The runtime entry point takes a Descriptor *, so hand it the raw pointer
  // rather than the dereferenced object.
  RTNAME(CUFSetAllocatorIndex)(
      a.get(), kDeviceAllocatorPos, __FILE__, __LINE__);
  EXPECT_EQ(static_cast<int>(kDeviceAllocatorPos), a->GetAllocIdx());
}

flang/include/flang/Optimizer/Builder/Runtime/CUDA/Descriptor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void genSyncGlobalDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
3131
void genDescriptorCheckSection(fir::FirOpBuilder &builder, mlir::Location loc,
3232
mlir::Value desc);
3333

34+
/// Generate a call to the CUFSetAllocatorIndex runtime function that stores
/// \p index as the allocator index of the descriptor \p desc.
35+
void genSetAllocatorIndex(fir::FirOpBuilder &builder, mlir::Location loc,
36+
mlir::Value desc, mlir::Value index);
37+
3438
} // namespace fir::runtime::cuda
3539

3640
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_CUDA_DESCRIPTOR_H_

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,4 +388,25 @@ def cuf_StreamCastOp : cuf_Op<"stream_cast", [NoMemoryEffect]> {
388388
let hasVerifier = 1;
389389
}
390390

391+
// Operation that records, in an existing descriptor, the allocator index
// matching a CUDA data attribute.
def cuf_SetAllocatorIndexOp : cuf_Op<"set_allocator_idx", []> {
  let summary = "Set the allocator index in a descriptor";

  let description = [{
    The allocator index in the Fortran descriptor is used to retrieve the
    correct CUDA allocator to allocate the memory on the device.
    In many cases the allocator index is set when the descriptor is created. For
    device components, the descriptor is part of the derived-type itself and
    its allocator index needs to be set after the derived-type is allocated in
    managed memory.
  }];

  // $box is a reference to the descriptor to update (read and written);
  // $data_attr selects which allocator index is stored.
  let arguments = (ins Arg<fir_ReferenceType, "", [MemRead, MemWrite]>:$box,
                       cuf_DataAttributeAttr:$data_attr);

  let assemblyFormat = [{
    $box `:` qualified(type($box)) attr-dict
  }];

  let hasVerifier = 1;
}
411+
391412
#endif // FORTRAN_DIALECT_CUF_CUF_OPS

flang/include/flang/Runtime/CUDA/descriptor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ void RTDECL(CUFSyncGlobalDescriptor)(
4141
void RTDECL(CUFDescriptorCheckSection)(
4242
const Descriptor *, const char *sourceFile = nullptr, int sourceLine = 0);
4343

44+
/// Set the allocator index of the given descriptor to \p index. Crashes,
/// reporting \p sourceFile and \p sourceLine, if the descriptor is null.
45+
void RTDECL(CUFSetAllocatorIndex)(Descriptor *, int index,
46+
const char *sourceFile = nullptr, int sourceLine = 0);
47+
4448
} // extern "C"
4549

4650
} // namespace Fortran::runtime::cuda

flang/lib/Optimizer/Builder/Runtime/CUDA/Descriptor.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,18 @@ void fir::runtime::cuda::genDescriptorCheckSection(fir::FirOpBuilder &builder,
4747
builder, loc, fTy, desc, sourceFile, sourceLine)};
4848
builder.create<fir::CallOp>(loc, func, args);
4949
}
50+
51+
/// Emit a fir.call to the CUFSetAllocatorIndex runtime entry point, passing
/// the descriptor \p desc, the allocator \p index, and the file name and line
/// number derived from \p loc.
void fir::runtime::cuda::genSetAllocatorIndex(fir::FirOpBuilder &builder,
                                              mlir::Location loc,
                                              mlir::Value desc,
                                              mlir::Value index) {
  auto callee =
      fir::runtime::getRuntimeFunc<mkRTKey(CUFSetAllocatorIndex)>(loc, builder);
  mlir::FunctionType calleeTy = callee.getFunctionType();
  // The last two runtime parameters carry the source position; the line
  // number is materialized with the type of the callee's fourth input.
  mlir::Value fileName = fir::factory::locationToFilename(builder, loc);
  mlir::Value lineNo =
      fir::factory::locationToLineNo(builder, loc, calleeTy.getInput(3));
  llvm::SmallVector<mlir::Value> callArgs{fir::runtime::createArguments(
      builder, loc, calleeTy, desc, index, fileName, lineNo)};
  builder.create<fir::CallOp>(loc, callee, callArgs);
}

flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,17 @@ llvm::LogicalResult cuf::StreamCastOp::verify() {
345345
return checkStreamType(*this);
346346
}
347347

348+
//===----------------------------------------------------------------------===//
349+
// SetAllocatorIndexOp
350+
//===----------------------------------------------------------------------===//
351+
352+
/// Verify that the `box` operand, once its reference type is unwrapped, is a
/// fir::BaseBoxType; emit an op error otherwise.
llvm::LogicalResult cuf::SetAllocatorIndexOp::verify() {
  mlir::Type unwrappedTy = fir::unwrapRefType(getBox().getType());
  if (mlir::isa<fir::BaseBoxType>(unwrappedTy))
    return mlir::success();
  return emitOpError("expect box to be a reference to class or box type value");
}
358+
348359
// Tablegen operators
349360

350361
#define GET_OP_CLASSES

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "flang/Runtime/CUDA/memory.h"
2323
#include "flang/Runtime/CUDA/pointer.h"
2424
#include "flang/Runtime/allocatable.h"
25+
#include "flang/Runtime/allocator-registry-consts.h"
2526
#include "flang/Support/Fortran.h"
2627
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2728
#include "mlir/Dialect/DLTI/DLTI.h"
@@ -923,6 +924,34 @@ struct CUFSyncDescriptorOpConversion
923924
}
924925
};
925926

927+
struct CUFSetAllocatorIndexOpConversion
928+
: public mlir::OpRewritePattern<cuf::SetAllocatorIndexOp> {
929+
using OpRewritePattern::OpRewritePattern;
930+
931+
mlir::LogicalResult
932+
matchAndRewrite(cuf::SetAllocatorIndexOp op,
933+
mlir::PatternRewriter &rewriter) const override {
934+
auto mod = op->getParentOfType<mlir::ModuleOp>();
935+
fir::FirOpBuilder builder(rewriter, mod);
936+
mlir::Location loc = op.getLoc();
937+
int idx = kDefaultAllocator;
938+
if (op.getDataAttr() == cuf::DataAttribute::Device) {
939+
idx = kDeviceAllocatorPos;
940+
} else if (op.getDataAttr() == cuf::DataAttribute::Managed) {
941+
idx = kManagedAllocatorPos;
942+
} else if (op.getDataAttr() == cuf::DataAttribute::Unified) {
943+
idx = kUnifiedAllocatorPos;
944+
} else if (op.getDataAttr() == cuf::DataAttribute::Pinned) {
945+
idx = kPinnedAllocatorPos;
946+
}
947+
mlir::Value index =
948+
builder.createIntegerConstant(loc, builder.getI32Type(), idx);
949+
fir::runtime::cuda::genSetAllocatorIndex(builder, loc, op.getBox(), index);
950+
op.erase();
951+
return mlir::success();
952+
}
953+
};
954+
926955
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
927956
public:
928957
void runOnOperation() override {
@@ -984,8 +1013,8 @@ void cuf::populateCUFToFIRConversionPatterns(
9841013
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
9851014
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
9861015
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
987-
CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
988-
patterns.getContext());
1016+
CUFFreeOpConversion, CUFSyncDescriptorOpConversion,
1017+
CUFSetAllocatorIndexOpConversion>(patterns.getContext());
9891018
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
9901019
&dl, &converter);
9911020
patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(

flang/test/Fir/CUDA/cuda-alloc-free.fir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,19 @@ func.func @_QQalloc_char() attributes {fir.bindc_name = "alloc_char"} {
9494
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
9595
// CHECK: fir.call @_FortranACUFMemAlloc(%[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) {cuf.data_attr = #cuf.cuda<device>} : (i64, i32, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
9696

97+
98+
// Test input: allocate a derived-type object with the managed data attribute,
// take the address of its box component `a2`, and set that descriptor's
// allocator index with the device data attribute.
func.func @_QQsetalloc() {
99+
%0 = cuf.alloc !fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}> {bindc_name = "d1", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEd1"} -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
100+
%1 = fir.coordinate_of %0, a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
101+
cuf.set_allocator_idx %1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
102+
return
103+
}
104+
105+
// CHECK-LABEL: func.func @_QQsetalloc() {
106+
// CHECK: %[[DT:.*]] = fir.call @_FortranACUFMemAlloc
107+
// CHECK: %[[CONV:.*]] = fir.convert %[[DT]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>
108+
// CHECK: %[[COMP:.*]] = fir.coordinate_of %[[CONV]], a2 : (!fir.ref<!fir.type<_QMm1Tdt1{a2:!fir.box<!fir.heap<!fir.array<?xf32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
109+
// CHECK: %[[DESC:.*]] = fir.convert %[[COMP]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
110+
// CHECK: fir.call @_FortranACUFSetAllocatorIndex(%[[DESC]], %c2{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
111+
97112
} // end module

0 commit comments

Comments
 (0)