Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ class Entity : public mlir::Value {
bool isProcedurePointer() const {
return hlfir::isFortranProcedurePointerType(getType());
}
bool isVolatile() const {
if (auto iface = getIfVariableInterface()) {
if (auto attrs = iface.getFortranAttrs()) {
return bitEnumContainsAny(
attrs.value(), fir::FortranVariableFlagsEnum::fortran_volatile);
}
}
return false;
}
bool isBoxAddressOrValue() const {
return hlfir::isBoxAddressOrValueType(getType());
}
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/CodeGen/CGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def fircg_XArrayCoorOp : fircg_Op<"ext_array_coor", [AttrSizedOperandSegments]>
Variadic<AnyCoordinateType>:$indices,
Variadic<AnyIntegerType>:$lenParams
);
let results = (outs fir_ReferenceType);
let results = (outs AnyReferenceType);

let assemblyFormat = [{
$memref (`(`$shape^`)`)? (`origin` $shift^)? (`[`$slice^`]`)?
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Dialect/FIRAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ include "mlir/IR/EnumAttr.td"

class fir_Attr<string name> : AttrDef<FIROpsDialect, name>;

def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">;
def FIRnoAttributes : I32BitEnumAttrCaseNone<"None">;
def FIRallocatable : I32BitEnumAttrCaseBit<"allocatable", 0>;
def FIRasynchronous : I32BitEnumAttrCaseBit<"asynchronous", 1>;
def FIRbind_c : I32BitEnumAttrCaseBit<"bind_c", 2>;
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,7 @@ def fir_ArrayCoorOp : fir_Op<"array_coor",
Variadic<AnyIntegerType>:$typeparams
);

let results = (outs fir_ReferenceType);
let results = (outs AnyReferenceType);

let assemblyFormat = [{
$memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams`
Expand Down
60 changes: 50 additions & 10 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,36 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
static mlir::Type get(mlir::Type t, bool isVolatile);
}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

def fir_VolatileReferenceType : FIR_Type<"VolatileReference", "volatile_ref"> {
let summary = "Volatile reference to an entity type";

let description = [{
The type of a volatile reference to an entity in memory.
}];

let parameters = (ins "mlir::Type":$eleTy);

let builders = [TypeBuilderWithInferredContext<
(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
}]>,
];

let extraClassDeclaration = [{
mlir::Type getElementType() const { return getEleTy(); }
}];

let genVerifyDecl = 1;
let assemblyFormat = "`<` $eleTy `>`";
}

def fir_ShapeType : FIR_Type<"Shape", "shape"> {
let summary = "shape of a multidimensional array object";

Expand Down Expand Up @@ -598,18 +622,28 @@ def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
fir_VectorType.predicate, IsTupleTypePred, fir_CharacterType.predicate]>,
"any composite">;

def AnyReferenceType : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_VolatileReferenceType.predicate]>,
"any reference type">;

// Reference types
def AnyReferenceLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate,
fir_LLVMPointerType.predicate]>, "any reference">;
def AnyReferenceLike
: TypeConstraint<
Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate,
fir_LLVMPointerType.predicate]>,
"any reference">;

def FuncType : TypeConstraint<FunctionType.predicate, "function type">;

def AnyCodeOrDataRefLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
FunctionType.predicate]>, "any code or data reference">;

def RefOrLLVMPtr : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_LLVMPointerType.predicate]>, "fir.ref or fir.llvm_ptr">;
def RefOrLLVMPtr
: TypeConstraint<
Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate,
fir_LLVMPointerType.predicate]>,
"fir.ref or fir.llvm_ptr">;

def AnyBoxLike : TypeConstraint<Or<[fir_BoxType.predicate,
fir_BoxCharType.predicate, fir_BoxProcType.predicate,
Expand All @@ -621,9 +655,12 @@ def BoxOrClassType : TypeConstraint<Or<[fir_BoxType.predicate,
def AnyRefOrBoxLike : TypeConstraint<Or<[AnyReferenceLike.predicate,
AnyBoxLike.predicate, FunctionType.predicate]>,
"any reference or box like">;
def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate,
IsBaseBoxTypePred]>, "any reference or box">;
def AnyRefOrBox
: TypeConstraint<
Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate,
IsBaseBoxTypePred]>,
"any reference or box">;
def AnyRefOrBoxType : Type<AnyRefOrBox.predicate, "any legal ref or box type">;

def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
Expand Down Expand Up @@ -651,8 +688,11 @@ def AnyCoordinateLike : TypeConstraint<Or<[AnySignlessInteger.predicate,
def AnyCoordinateType : Type<AnyCoordinateLike.predicate, "coordinate type">;

// The legal types of global symbols
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
FunctionType.predicate]>, "any addressable">;
def AnyAddressableLike
: TypeConstraint<
Or<[fir_ReferenceType.predicate, fir_VolatileReferenceType.predicate,
FunctionType.predicate]>,
"any addressable">;

def ArrayOrBoxOrRecord : TypeConstraint<Or<[fir_SequenceType.predicate,
IsBaseBoxTypePred, fir_RecordType.predicate]>,
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def hlfir_DeclareOp : hlfir_Op<"declare", [AttrSizedOperandSegments,

/// Given a FIR memory type, and information about non default lower
/// bounds, get the related HLFIR variable type.
static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds);
static mlir::Type getHLFIRVariableType(mlir::Type type, bool hasLowerBounds, bool isVolatile=false);
}];

let hasVerifier = 1;
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/CallInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ class Fortran::lower::CallInterfaceImpl {
if (obj.attrs.test(Attrs::Value))
isValueAttr = true; // TODO: do we want an mlir::Attribute as well?
if (obj.attrs.test(Attrs::Volatile)) {
TODO(loc, "VOLATILE in procedure interface");
// TODO(loc, "VOLATILE in procedure interface");
addMLIRAttr(fir::getVolatileAttrName());
}
// obj.attrs.test(Attrs::Asynchronous) does not impact the way the argument
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@

namespace {

/// Determine if a given symbol has the VOLATILE attribute.
static bool isVolatileSymbol(const Fortran::semantics::Symbol &symbol) {
return symbol.GetUltimate().attrs().test(Fortran::semantics::Attr::VOLATILE);
}

/// Lower Designators to HLFIR.
class HlfirDesignatorBuilder {
private:
Expand Down Expand Up @@ -223,6 +228,18 @@ class HlfirDesignatorBuilder {
designatorNode, getConverter().getFoldingContext(),
/*namedConstantSectionsAreAlwaysContiguous=*/false))
return fir::BoxType::get(resultValueType);

// Check if this should be a volatile reference
if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::SymbolRef>) {
if (isVolatileSymbol(designatorNode.get()))
return fir::VolatileReferenceType::get(resultValueType);
} else if constexpr (std::is_same_v<std::decay_t<T>,
Fortran::evaluate::Component>) {
if (isVolatileSymbol(designatorNode.GetLastSymbol()))
return fir::VolatileReferenceType::get(resultValueType);
}

// Other designators can be handled as raw addresses.
return fir::ReferenceType::get(resultValueType);
}
Expand Down
9 changes: 7 additions & 2 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,11 +756,13 @@ std::pair<mlir::Value, mlir::Value> hlfir::genVariableFirBaseShapeAndParams(
auto params = fir::getTypeParams(exv);
typeParams.append(params.begin(), params.end());
}
if (entity.isScalar())
if (entity.isScalar()) {
return {fir::getBase(exv), mlir::Value{}};
if (auto variableInterface = entity.getIfVariableInterface())
}
if (auto variableInterface = entity.getIfVariableInterface()) {
return {fir::getBase(exv),
asEmboxShape(loc, builder, exv, variableInterface.getShape())};
}
return {fir::getBase(exv), builder.createShape(loc, exv)};
}

Expand Down Expand Up @@ -809,6 +811,9 @@ mlir::Type hlfir::getVariableElementType(hlfir::Entity variable) {
} else if (fir::isRecordWithTypeParameters(eleTy)) {
return fir::BoxType::get(eleTy);
}
if (variable.isVolatile()) {
return fir::VolatileReferenceType::get(eleTy);
}
return fir::ReferenceType::get(eleTy);
}

Expand Down
16 changes: 11 additions & 5 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,6 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
auto fromTy = convertType(fromFirTy);
auto toTy = convertType(toFirTy);
mlir::Value op0 = adaptor.getOperands()[0];

if (fromFirTy == toFirTy) {
rewriter.replaceOp(convert, op0);
return mlir::success();
Expand Down Expand Up @@ -3217,6 +3216,9 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

mlir::Type originalLoadTy = load.getMemref().getType();
const bool isVolatile =
mlir::isa<fir::VolatileReferenceType>(originalLoadTy);
mlir::Type llvmLoadTy = convertObjectType(load.getType());
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(load.getType())) {
// fir.box is a special case because it is considered an ssa value in
Expand Down Expand Up @@ -3256,7 +3258,7 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
rewriter.replaceOp(load, newBoxStorage);
} else {
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
load.getLoc(), llvmLoadTy, adaptor.getOperands(), load->getAttrs());
load.getLoc(), llvmLoadTy, adaptor.getOperands()[0], 0, isVolatile);
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
loadOp.setTBAATags(*optionalTag);
else
Expand Down Expand Up @@ -3531,6 +3533,9 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::Location loc = store.getLoc();
mlir::Type storeTy = store.getValue().getType();
mlir::Type originalStoreTy = store.getMemref().getType();
const bool isVolatile =
mlir::isa<fir::VolatileReferenceType>(originalStoreTy);
mlir::Value llvmValue = adaptor.getValue();
mlir::Value llvmMemref = adaptor.getMemref();
mlir::LLVM::AliasAnalysisOpInterface newOp;
Expand All @@ -3541,10 +3546,11 @@ struct StoreOpConversion : public fir::FIROpConversion<fir::StoreOp> {
TypePair boxTypePair{boxTy, llvmBoxTy};
mlir::Value boxSize =
computeBoxSize(loc, boxTypePair, llvmValue, rewriter);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(
loc, llvmMemref, llvmValue, boxSize, /*isVolatile=*/false);
newOp = rewriter.create<mlir::LLVM::MemcpyOp>(loc, llvmMemref, llvmValue,
boxSize, isVolatile);
} else {
newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref);
newOp = rewriter.create<mlir::LLVM::StoreOp>(loc, llvmValue, llvmMemref,
0, isVolatile, false);
}
if (std::optional<mlir::ArrayAttr> optionalTag = store.getTbaa())
newOp.setTBAATags(*optionalTag);
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
});
addConversion(
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
addConversion(
[&](fir::VolatileReferenceType ref) { return convertPointerLike(ref); });
addConversion([&](fir::SequenceType sequence) {
return convertSequenceType(sequence);
});
Expand Down
20 changes: 11 additions & 9 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ llvm::LogicalResult fir::AllocaOp::verify() {
if (verifyTypeParamCount(getInType(), numLenParams()))
return emitOpError("LEN params do not correspond to type");
mlir::Type outType = getType();
if (!mlir::isa<fir::ReferenceType>(outType))
if (!mlir::isa<fir::ReferenceType, fir::VolatileReferenceType>(outType))
return emitOpError("must be a !fir.ref type");
return mlir::success();
}
Expand Down Expand Up @@ -305,8 +305,8 @@ static mlir::Type wrapAllocMemResultType(mlir::Type intype) {
// Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER
// 8.5.3 note 1 prohibits ALLOCATABLE procedures as well
// FIR semantics: one may not allocate a memory reference value
if (mlir::isa<fir::ReferenceType, fir::HeapType, fir::PointerType,
mlir::FunctionType>(intype))
if (mlir::isa<fir::ReferenceType, fir::VolatileReferenceType, fir::HeapType,
fir::PointerType, mlir::FunctionType>(intype))
return {};
return fir::HeapType::get(intype);
}
Expand Down Expand Up @@ -441,8 +441,9 @@ llvm::LogicalResult fir::ArrayCoorOp::verify() {
if (sliceTy.getRank() != arrDim)
return emitOpError("rank of dimension in slice mismatched");
}
if (!validTypeParams(getMemref().getType(), getTypeparams()))
if (!validTypeParams(getMemref().getType(), getTypeparams())) {
return emitOpError("invalid type parameters");
}

return mlir::success();
}
Expand Down Expand Up @@ -823,8 +824,8 @@ void fir::ArrayCoorOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//

static mlir::Type adjustedElementType(mlir::Type t) {
if (auto ty = mlir::dyn_cast<fir::ReferenceType>(t)) {
auto eleTy = ty.getEleTy();
if (fir::isa_ref_type(t)) {
mlir::Type eleTy = fir::dyn_cast_ptrEleTy(t);
if (fir::isa_char(eleTy))
return eleTy;
if (fir::isa_derived(eleTy))
Expand Down Expand Up @@ -1364,9 +1365,10 @@ bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) {
}

bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) {
return mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType,
fir::LLVMPointerType, mlir::MemRefType, mlir::FunctionType,
fir::TypeDescType, mlir::LLVM::LLVMPointerType>(ty);
return mlir::isa<fir::ReferenceType, fir::VolatileReferenceType,
fir::PointerType, fir::HeapType, fir::LLVMPointerType,
mlir::MemRefType, mlir::FunctionType, fir::TypeDescType,
mlir::LLVM::LLVMPointerType>(ty);
}

static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) {
Expand Down
Loading