diff --git a/flang/include/flang/Semantics/semantics.h b/flang/include/flang/Semantics/semantics.h index 0dbca51ee0dcf..12220cc7f0edd 100644 --- a/flang/include/flang/Semantics/semantics.h +++ b/flang/include/flang/Semantics/semantics.h @@ -162,7 +162,6 @@ class SemanticsContext { warningsAreErrors_ = x; return *this; } - SemanticsContext &set_debugModuleWriter(bool x) { debugModuleWriter_ = x; return *this; diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp index b01147606a99c..9b48432e049b9 100644 --- a/flang/lib/Semantics/check-cuda.cpp +++ b/flang/lib/Semantics/check-cuda.cpp @@ -761,14 +761,13 @@ void CUDAChecker::Enter(const parser::AssignmentStmt &x) { // legal. if (nbLhs == 0 && nbRhs > 1) { context_.Say(lhsLoc, - "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US); + "More than one reference to a CUDA object on the right hand side of the assignment"_err_en_US); } - if (Fortran::evaluate::HasCUDADeviceAttrs(assign->lhs) && - Fortran::evaluate::HasCUDAImplicitTransfer(assign->rhs)) { + if (evaluate::HasCUDADeviceAttrs(assign->lhs) && + evaluate::HasCUDAImplicitTransfer(assign->rhs)) { if (GetNbOfCUDAManagedOrUnifiedSymbols(assign->lhs) == 1 && - GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && - GetNbOfCUDADeviceSymbols(assign->rhs) == 1) { + GetNbOfCUDAManagedOrUnifiedSymbols(assign->rhs) == 1 && nbRhs == 1) { return; // This is a special case handled on the host. } context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US); diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index a2f2906af10b8..d769f221b1983 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -2081,7 +2081,7 @@ static bool ConflictsWithIntrinsicAssignment(const Procedure &proc) { } static bool ConflictsWithIntrinsicOperator( - const GenericKind &kind, const Procedure &proc) { + const GenericKind &kind, const Procedure &proc, SemanticsContext &context) { if (!kind.IsIntrinsicOperator()) { return false; } @@ -2167,7 +2167,7 @@ bool CheckHelper::CheckDefinedOperator(SourceName opName, GenericKind kind, } } else if (!checkDefinedOperatorArgs(opName, specific, proc)) { return false; // error was reported - } else if (ConflictsWithIntrinsicOperator(kind, proc)) { + } else if (ConflictsWithIntrinsicOperator(kind, proc, context_)) { msg = "%s function '%s' conflicts with intrinsic operator"_err_en_US; } if (msg) { diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp index 14473724f0f40..92dbe0e5da11c 100644 --- a/flang/lib/Semantics/expression.cpp +++ b/flang/lib/Semantics/expression.cpp @@ -165,10 +165,17 @@ class ArgumentAnalyzer { bool CheckForNullPointer(const char *where = "as an operand here"); bool CheckForAssumedRank(const char *where = "as an operand here"); + bool AnyCUDADeviceData() const; + // Returns true if an interface has been defined for an intrinsic operator + // with one or more device operands. + bool HasDeviceDefinedIntrinsicOpOverride(const char *) const; + template bool HasDeviceDefinedIntrinsicOpOverride(E opr) const { + return HasDeviceDefinedIntrinsicOpOverride( + context_.context().languageFeatures().GetNames(opr)); + } + // Find and return a user-defined operator or report an error. // The provided message is used if there is no such operator. - // If a definedOpSymbolPtr is provided, the caller must check - // for its accessibility. MaybeExpr TryDefinedOp( const char *, parser::MessageFixedText, bool isUserOp = false); template @@ -183,6 +190,8 @@ class ArgumentAnalyzer { void Dump(llvm::raw_ostream &); private: + bool HasDeviceDefinedIntrinsicOpOverride( + const std::vector &) const; MaybeExpr TryDefinedOp( const std::vector &, parser::MessageFixedText); MaybeExpr TryBoundOp(const Symbol &, int passIndex); @@ -202,7 +211,7 @@ class ArgumentAnalyzer { void SayNoMatch( const std::string &, bool isAssignment = false, bool isAmbiguous = false); std::string TypeAsFortran(std::size_t); - bool AnyUntypedOrMissingOperand(); + bool AnyUntypedOrMissingOperand() const; ExpressionAnalyzer &context_; ActualArguments actuals_; @@ -4497,13 +4506,20 @@ void ArgumentAnalyzer::Analyze( bool ArgumentAnalyzer::IsIntrinsicRelational(RelationalOperator opr, const DynamicType &leftType, const DynamicType &rightType) const { CHECK(actuals_.size() == 2); - return semantics::IsIntrinsicRelational( - opr, leftType, GetRank(0), rightType, GetRank(1)); + return !(context_.context().languageFeatures().IsEnabled( + common::LanguageFeature::CUDA) && + HasDeviceDefinedIntrinsicOpOverride(opr)) && + semantics::IsIntrinsicRelational( + opr, leftType, GetRank(0), rightType, GetRank(1)); } bool ArgumentAnalyzer::IsIntrinsicNumeric(NumericOperator opr) const { std::optional leftType{GetType(0)}; - if (actuals_.size() == 1) { + if (context_.context().languageFeatures().IsEnabled( + common::LanguageFeature::CUDA) && + HasDeviceDefinedIntrinsicOpOverride(AsFortran(opr))) { + return false; + } else if (actuals_.size() == 1) { if (IsBOZLiteral(0)) { return opr == NumericOperator::Add; // unary '+' } else { @@ -4617,6 +4633,53 @@ bool ArgumentAnalyzer::CheckForAssumedRank(const char *where) { return true; } +bool ArgumentAnalyzer::AnyCUDADeviceData() const { + for (const std::optional &arg : actuals_) { + if (arg) { + if (const Expr *expr{arg->UnwrapExpr()}) { + if (HasCUDADeviceAttrs(*expr)) { + return true; + } + } + } + } + return false; +} + +// Some operations can be defined with explicit non-type-bound interfaces +// that would erroneously conflict with intrinsic operations in their +// types and ranks but have one or more dummy arguments with the DEVICE +// attribute. +bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( + const char *opr) const { + if (AnyCUDADeviceData() && !AnyUntypedOrMissingOperand()) { + std::string oprNameString{"operator("s + opr + ')'}; + parser::CharBlock oprName{oprNameString}; + parser::Messages buffer; + auto restorer{context_.GetContextualMessages().SetMessages(buffer)}; + const auto &scope{context_.context().FindScope(source_)}; + if (Symbol * generic{scope.FindSymbol(oprName)}) { + parser::Name name{generic->name(), generic}; + const Symbol *resultSymbol{nullptr}; + if (context_.AnalyzeDefinedOp( + name, ActualArguments{actuals_}, resultSymbol)) { + return true; + } + } + } + return false; +} + +bool ArgumentAnalyzer::HasDeviceDefinedIntrinsicOpOverride( + const std::vector &oprNames) const { + for (const char *opr : oprNames) { + if (HasDeviceDefinedIntrinsicOpOverride(opr)) { + return true; + } + } + return false; +} + MaybeExpr ArgumentAnalyzer::TryDefinedOp( const char *opr, parser::MessageFixedText error, bool isUserOp) { if (AnyUntypedOrMissingOperand()) { @@ -5135,7 +5198,7 @@ std::string ArgumentAnalyzer::TypeAsFortran(std::size_t i) { } } -bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() { +bool ArgumentAnalyzer::AnyUntypedOrMissingOperand() const { for (const auto &actual : actuals_) { if (!actual || (!actual->GetType() && !IsBareNullPointer(actual->UnwrapExpr()))) { diff --git a/flang/test/Semantics/bug1214.cuf b/flang/test/Semantics/bug1214.cuf new file mode 100644 index 0000000000000..114fad15ea500 --- /dev/null +++ b/flang/test/Semantics/bug1214.cuf @@ -0,0 +1,49 @@ +! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s +module overrides + type realResult + real a + end type + interface operator(*) + procedure :: multHostDevice, multDeviceHost + end interface + interface assignment(=) + procedure :: assignHostResult, assignDeviceResult + end interface + contains + elemental function multHostDevice(x, y) result(result) + real, intent(in) :: x + real, intent(in), device :: y + type(realResult) result + result%a = x * y + end + elemental function multDeviceHost(x, y) result(result) + real, intent(in), device :: x + real, intent(in) :: y + type(realResult) result + result%a = x * y + end + elemental subroutine assignHostResult(lhs, rhs) + real, intent(out) :: lhs + type(realResult), intent(in) :: rhs + lhs = rhs%a + end + elemental subroutine assignDeviceResult(lhs, rhs) + real, intent(out), device :: lhs + type(realResult), intent(in) :: rhs + lhs = rhs%a + end +end + +program p + use overrides + real, device :: da, db + real :: ha, hb +!CHECK: CALL assigndeviceresult(db,multhostdevice(2._4,da)) + db = 2. * da +!CHECK: CALL assigndeviceresult(db,multdevicehost(da,2._4)) + db = da * 2. +!CHECK: CALL assignhostresult(ha,multhostdevice(2._4,da)) + ha = 2. * da +!CHECK: CALL assignhostresult(ha,multdevicehost(da,2._4)) + ha = da * 2. +end diff --git a/flang/test/Semantics/cuf11.cuf b/flang/test/Semantics/cuf11.cuf index 554ac258e5510..1f5beb02aee45 100644 --- a/flang/test/Semantics/cuf11.cuf +++ b/flang/test/Semantics/cuf11.cuf @@ -16,7 +16,7 @@ subroutine sub1() real, device :: adev(10), bdev(10) real :: ahost(10) -!ERROR: More than one reference to a CUDA object on the right hand side of the assigment +!ERROR: More than one reference to a CUDA object on the right hand side of the assignment ahost = adev + bdev ahost = adev + adev