From 083808d195ce71a83e6150a5f264b4dd2180c318 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 17 Jul 2025 15:14:31 -0700 Subject: [PATCH] [flang][cuda] Support device component in a pointer or allocatable derived-type --- flang/lib/Lower/ConvertVariable.cpp | 26 ++++++++++++++-- flang/test/Lower/CUDA/cuda-set-allocator.cuf | 32 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index 6c4516686f9d0..23d87d7b83c06 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -814,6 +814,10 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, if (auto boxTy = mlir::dyn_cast(baseTy)) baseTy = boxTy.getEleTy(); baseTy = fir::unwrapRefType(baseTy); + + if (mlir::isa(baseTy)) + TODO(loc, "array of derived-type with device component"); + auto recTy = mlir::dyn_cast(fir::unwrapSequenceType(baseTy)); assert(recTy && "expected fir::RecordType"); @@ -824,7 +828,7 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, if (Fortran::semantics::IsDeviceAllocatable(sym)) { unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); mlir::Type fieldTy; - std::vector coordinates; + llvm::SmallVector coordinates; if (fieldIdx != std::numeric_limits::max()) { // Field found in the base record type. @@ -867,8 +871,24 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, TODO(loc, "device resident component in complex derived-type " "hierarchy"); - mlir::Value comp = builder.create( - loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates); + mlir::Value base = fir::getBase(exv); + mlir::Value comp; + if (mlir::isa(fir::unwrapRefType(base.getType()))) { + mlir::Value box = builder.create(loc, base); + mlir::Value addr = builder.create(loc, box); + llvm::SmallVector lenParams; + assert(coordinates.size() == 1 && "expect one coordinate"); + auto field = mlir::dyn_cast( + coordinates[0].getDefiningOp()); + comp = builder.create( + loc, builder.getRefType(fieldTy), addr, + /*component=*/field.getFieldName(), + /*componentShape=*/mlir::Value{}, + hlfir::DesignateOp::Subscripts{}); + } else { + comp = builder.create( + loc, builder.getRefType(fieldTy), base, coordinates); + } cuf::DataAttributeAttr dataAttr = Fortran::lower::translateSymbolCUFDataAttribute( builder.getContext(), sym); diff --git a/flang/test/Lower/CUDA/cuda-set-allocator.cuf b/flang/test/Lower/CUDA/cuda-set-allocator.cuf index ee89ea38a3fc7..e3bb181f65398 100644 --- a/flang/test/Lower/CUDA/cuda-set-allocator.cuf +++ b/flang/test/Lower/CUDA/cuda-set-allocator.cuf @@ -21,4 +21,36 @@ contains ! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> ! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref>>> {data_attr = #cuf.cuda} + subroutine sub2() + type(ty_device), pointer :: d1 + end subroutine + +! CHECK-LABEL: func.func @_QMm1Psub2() +! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.box>>,y:i32,z:!fir.box>>}>>> {bindc_name = "d1", data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub2Ed1"} -> !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMm1Fsub2Ed1"} : (!fir.ref>>,y:i32,z:!fir.box>>}>>>>) -> (!fir.ref>>,y:i32,z:!fir.box>>}>>>>, !fir.ref>>,y:i32,z:!fir.box>>}>>>>) +! CHECK: %[[LOAD1:.*]] = fir.load %[[DECL]]#0 : !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[ADDR1:.*]] = fir.box_addr %[[LOAD1]] : (!fir.box>>,y:i32,z:!fir.box>>}>>>) -> !fir.ptr>>,y:i32,z:!fir.box>>}>> +! CHECK: %[[DESIGNATE1:.*]] = hlfir.designate %[[ADDR1]]{"x"} : (!fir.ptr>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: cuf.set_allocator_idx %[[DESIGNATE1]] : !fir.ref>>> {data_attr = #cuf.cuda} +! CHECK: %[[LOAD2:.*]] = fir.load %[[DECL]]#0 : !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[ADDR2:.*]] = fir.box_addr %[[LOAD2]] : (!fir.box>>,y:i32,z:!fir.box>>}>>>) -> !fir.ptr>>,y:i32,z:!fir.box>>}>> +! CHECK: %[[DESIGNATE2:.*]] = hlfir.designate %[[ADDR2]]{"z"} : (!fir.ptr>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: cuf.set_allocator_idx %[[DESIGNATE2]] : !fir.ref>>> {data_attr = #cuf.cuda} + + subroutine sub3() + type(ty_device), allocatable :: d1 + end subroutine + +! CHECK-LABEL: func.func @_QMm1Psub3() +! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.box>>,y:i32,z:!fir.box>>}>>> {bindc_name = "d1", data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub3Ed1"} -> !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QMm1Fsub3Ed1"} : (!fir.ref>>,y:i32,z:!fir.box>>}>>>>) -> (!fir.ref>>,y:i32,z:!fir.box>>}>>>>, !fir.ref>>,y:i32,z:!fir.box>>}>>>>) +! CHECK: %[[LOAD1:.*]] = fir.load %[[DECL]]#0 : !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[ADDR1:.*]] = fir.box_addr %[[LOAD1]] : (!fir.box>>,y:i32,z:!fir.box>>}>>>) -> !fir.heap>>,y:i32,z:!fir.box>>}>> +! CHECK: %[[DESIGNATE1:.*]] = hlfir.designate %[[ADDR1]]{"x"} : (!fir.heap>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: cuf.set_allocator_idx %[[DESIGNATE1]] : !fir.ref>>> {data_attr = #cuf.cuda} +! CHECK: %[[LOAD2:.*]] = fir.load %[[DECL]]#0 : !fir.ref>>,y:i32,z:!fir.box>>}>>>> +! CHECK: %[[ADDR2:.*]] = fir.box_addr %[[LOAD2]] : (!fir.box>>,y:i32,z:!fir.box>>}>>>) -> !fir.heap>>,y:i32,z:!fir.box>>}>> +! CHECK: %[[DESIGNATE2:.*]] = hlfir.designate %[[ADDR2]]{"z"} : (!fir.heap>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: cuf.set_allocator_idx %[[DESIGNATE2]] : !fir.ref>>> {data_attr = #cuf.cuda} + end module