Skip to content

Commit 721f598

Browse files
chsiggcopybara-github
authored andcommitted
NFC: Use the free function variants for dyn_cast/cast/isa/....
The member functions in Type/Attribute/Value/Location/AffineExpr got [removed](llvm/llvm-project@0078cf7). PiperOrigin-RevId: 748166276
1 parent 9958390 commit 721f598

File tree

6 files changed

+22
-16
lines changed

6 files changed

+22
-16
lines changed

loop_nest.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/ADT/SetVector.h"
1818
#include "llvm/ADT/SmallString.h"
1919
#include "mlir/IR/Builders.h"
20+
#include "mlir/Support/LLVM.h"
2021
#include "sequence.h"
2122
#include "util.h"
2223

@@ -106,7 +107,7 @@ const IterationSpace &IterationSpaceAnalysis::ComputeIterationSpace(
106107
loop_names.reserve(num_loops);
107108

108109
for (mlir::Attribute attr : compute_op.Loops()) {
109-
LoopAttr loop = attr.cast<LoopAttr>();
110+
LoopAttr loop = mlir::cast<LoopAttr>(attr);
110111
loop_names.push_back(loop.name());
111112
exprs.push_back(loop.iter());
112113
}
@@ -265,10 +266,10 @@ mlir::LogicalResult VerifyLoopNestWellFormed(
265266
int domain_size = shape.Dimensions().size();
266267

267268
for (int i = 0, e = loop_nest.size(); i < e; ++i) {
268-
LoopAttr loop = loop_nest[i].dyn_cast<LoopAttr>();
269+
LoopAttr loop = mlir::dyn_cast<LoopAttr>(loop_nest[i]);
269270
// Ensure that symbols are unique in the loop nest.
270271
for (int j = 0; j < i; ++j) {
271-
if (loop.name() == loop_nest[j].cast<LoopAttr>().name()) {
272+
if (loop.name() == mlir::cast<LoopAttr>(loop_nest[j]).name()) {
272273
return mlir::emitError(loc)
273274
<< "name " << loop.name() << " used twice in the same loop nest";
274275
}
@@ -317,7 +318,7 @@ class LoopNestState {
317318
int common_prefix_size = 0;
318319
for (int e = std::min(loop_nest.size(), open_loops_.size());
319320
common_prefix_size < e; ++common_prefix_size) {
320-
LoopAttr loop = loop_nest[common_prefix_size].cast<LoopAttr>();
321+
LoopAttr loop = mlir::cast<LoopAttr>(loop_nest[common_prefix_size]);
321322
if (loop.name() != open_loops_[common_prefix_size]) break;
322323
}
323324

@@ -326,7 +327,7 @@ class LoopNestState {
326327

327328
// Add remaining loops to the current fusion prefix.
328329
for (mlir::Attribute attribute : loop_nest.drop_front(common_prefix_size)) {
329-
LoopAttr loop = attribute.cast<LoopAttr>();
330+
LoopAttr loop = mlir::cast<LoopAttr>(attribute);
330331

331332
if (closed_loops_.count(loop.name()) > 0) {
332333
return op.EmitError() << "occurrences of loop " << loop.name()
@@ -500,7 +501,7 @@ static mlir::LogicalResult VerifyLoopRanges(
500501
const LoopFusionAnalysis &fusion_analysis,
501502
const SequenceAnalysis &sequence_analysis) {
502503
for (mlir::Attribute attr : loop_nest) {
503-
LoopAttr loop = attr.cast<LoopAttr>();
504+
LoopAttr loop = mlir::cast<LoopAttr>(attr);
504505
const LoopFusionClass &fusion_class = fusion_analysis.GetClass(loop.name());
505506
for (const auto &dimension : fusion_class.getDomain()) {
506507
if (sequence_analysis.IsBefore(op, dimension.value.defining_op())) {
@@ -670,7 +671,7 @@ mlir::LogicalResult LoopFusionAnalysis::Init(
670671
// Returns the unroll factor of the `pos`-th loop in the given compute op.
671672
// Expects the op to have a well-formed loop nest attribute.
672673
static unsigned ExtractUnrollFactor(const ComputeOpInstance &op, unsigned pos) {
673-
auto loop = op.Loops()[pos].cast<LoopAttr>();
674+
auto loop = mlir::cast<LoopAttr>(op.Loops()[pos]);
674675
if (mlir::IntegerAttr unroll_factor = loop.unroll()) {
675676
return unroll_factor.getInt();
676677
}
@@ -686,7 +687,7 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
686687
loop_names.reserve(loop_pos);
687688
iter_exprs.reserve(loop_pos);
688689
for (int i = 0; i < loop_pos; ++i) {
689-
LoopAttr loop = op.Loops()[i].cast<LoopAttr>();
690+
LoopAttr loop = mlir::cast<LoopAttr>(op.Loops()[i]);
690691
loop_names.push_back(loop.name());
691692
iter_exprs.push_back(loop.iter());
692693
}
@@ -698,7 +699,7 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
698699
auto loop_nest_mapping =
699700
MappingAttr::get(op.context(), op.domain_size(), iter_exprs);
700701

701-
LoopAttr loop = op.Loops()[loop_pos].cast<LoopAttr>();
702+
LoopAttr loop = mlir::cast<LoopAttr>(op.Loops()[loop_pos]);
702703
auto [it, was_inserted] =
703704
fusion_classes_.try_emplace(loop.name(), loop.name(), op, loop_nest);
704705
LoopFusionClass &fusion_class = it->second;

sair_op_interfaces.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/IR/Operation.h"
2727
#include "mlir/IR/Value.h"
2828
#include "mlir/Interfaces/SideEffectInterfaces.h"
29+
#include "mlir/Support/LLVM.h"
2930
#include "mlir/Support/LogicalResult.h"
3031
#include "sair_attributes.h"
3132
#include "sair_types.h"
@@ -67,7 +68,7 @@ class ValueOperand {
6768

6869
// Returns the type of the value referenced by the operand.
6970
ValueType GetType() const {
70-
return operand_->get().getType().cast<ValueType>();
71+
return mlir::cast<ValueType>(operand_->get().getType());
7172
}
7273

7374
// Returns the operation owning the operand.
@@ -510,14 +511,14 @@ class OperandInstance {
510511
inline auto OpInstance::getDomain() const {
511512
return llvm::map_range(GetDomainValues(), [](mlir::Value v) {
512513
OpInstance dim_op = OpInstance(llvm::cast<SairOp>(v.getDefiningOp()));
513-
return ResultInstance(dim_op, v.cast<OpResult>().getResultNumber());
514+
return ResultInstance(dim_op, mlir::cast<OpResult>(v).getResultNumber());
514515
});
515516
}
516517

517518
inline auto OpInstance::DomainWithDependencies() const {
518519
DomainShapeAttr shape = GetShape();
519520
return llvm::map_range(llvm::enumerate(GetDomainValues()), [=](auto p) {
520-
auto value = p.value().template cast<mlir::OpResult>();
521+
auto value = mlir::cast<mlir::OpResult>(p.value());
521522
OpInstance dim_op = OpInstance(llvm::cast<SairOp>(value.getOwner()));
522523
ValueAccessInstance access = {
523524
.value = ResultInstance(dim_op, value.getResultNumber()),

sequence.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "llvm/ADT/iterator_range.h"
2222
#include "llvm/Support/Debug.h"
23+
#include "mlir/Support/LLVM.h"
2324
#include "loop_nest.h"
2425
#include "sair_op_interfaces.h"
2526
#include "sair_ops.h"
@@ -454,7 +455,7 @@ ProgramPoint SequenceAnalysis::FindInsertionPoint(
454455
llvm::ArrayRef<mlir::Attribute> new_loops = new_op.Loops();
455456
num_common_loops = std::min<int>(new_loops.size(), num_common_loops);
456457
for (; num_common_loops > 0; --num_common_loops) {
457-
auto loop = new_loops[num_common_loops - 1].cast<LoopAttr>();
458+
auto loop = mlir::cast<LoopAttr>(new_loops[num_common_loops - 1]);
458459
if (loop.name() == start_loop_nest[num_common_loops - 1]) break;
459460
}
460461
if (num_common_loops <= num_loops) break;

test/passes.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "test/passes.h"
1616

1717
#include "mlir/IR/Builders.h"
18+
#include "mlir/Support/LLVM.h"
1819
#include "sair_attributes.h"
1920
#include "sair_dialect.h"
2021

@@ -34,7 +35,7 @@ static llvm::SmallVector<T, 4> GetAttrVector(llvm::StringRef name,
3435
llvm::SmallVector<T, 4> vector;
3536
vector.reserve(array.size());
3637
for (mlir::Attribute element : array.getValue()) {
37-
vector.push_back(element.cast<T>());
38+
vector.push_back(mlir::cast<T>(element));
3839
}
3940
return vector;
4041
}

transforms/lower_proj_any.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Func/IR/FuncOps.h"
1616
#include "mlir/Dialect/SCF/IR/SCF.h"
1717
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Support/LLVM.h"
1819
#include "loop_nest.h"
1920
#include "sair_dialect.h"
2021
#include "sair_ops.h"
@@ -86,7 +87,7 @@ class LowerProjAny : public impl::LowerProjAnyPassBase<LowerProjAny> {
8687
// guaranteed to be identity by verifiers.
8788
for (MappingExpr expr :
8889
mapping.Dimensions().drop_front(num_common_loops)) {
89-
if (!expr.isa<MappingNoneExpr>()) {
90+
if (!mlir::isa<MappingNoneExpr>(expr)) {
9091
return op.emitError()
9192
<< "cannot lower operation to proj_last on scalars";
9293
}

util.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define THIRD_PARTY_SAIR_TRANSFORMS_UTIL_H_
1717

1818
#include "mlir/IR/Builders.h"
19+
#include "mlir/Support/LLVM.h"
1920
#include "sair_attributes.h"
2021
#include "sair_op_interfaces.h"
2122

@@ -95,7 +96,7 @@ std::function<mlir::ArrayAttr(mlir::ArrayAttr)> MkArrayAttrMapper(
9596
llvm::SmallVector<mlir::Attribute> output;
9697
output.reserve(array.size());
9798
for (mlir::Attribute attr : array.getValue()) {
98-
output.push_back(scalar_fn(attr.cast<T>()));
99+
output.push_back(scalar_fn(mlir::cast<T>(attr)));
99100
}
100101
return mlir::ArrayAttr::get(array.getContext(), output);
101102
};

0 commit comments

Comments
 (0)