Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions loop_nest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "sequence.h"
#include "util.h"

Expand Down Expand Up @@ -106,7 +107,7 @@ const IterationSpace &IterationSpaceAnalysis::ComputeIterationSpace(
loop_names.reserve(num_loops);

for (mlir::Attribute attr : compute_op.Loops()) {
LoopAttr loop = attr.cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(attr);
loop_names.push_back(loop.name());
exprs.push_back(loop.iter());
}
Expand Down Expand Up @@ -265,10 +266,10 @@ mlir::LogicalResult VerifyLoopNestWellFormed(
int domain_size = shape.Dimensions().size();

for (int i = 0, e = loop_nest.size(); i < e; ++i) {
LoopAttr loop = loop_nest[i].dyn_cast<LoopAttr>();
LoopAttr loop = mlir::dyn_cast<LoopAttr>(loop_nest[i]);
// Ensure that symbols are unique in the loop nest.
for (int j = 0; j < i; ++j) {
if (loop.name() == loop_nest[j].cast<LoopAttr>().name()) {
if (loop.name() == mlir::cast<LoopAttr>(loop_nest[j]).name()) {
return mlir::emitError(loc)
<< "name " << loop.name() << " used twice in the same loop nest";
}
Expand Down Expand Up @@ -317,7 +318,7 @@ class LoopNestState {
int common_prefix_size = 0;
for (int e = std::min(loop_nest.size(), open_loops_.size());
common_prefix_size < e; ++common_prefix_size) {
LoopAttr loop = loop_nest[common_prefix_size].cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(loop_nest[common_prefix_size]);
if (loop.name() != open_loops_[common_prefix_size]) break;
}

Expand All @@ -326,7 +327,7 @@ class LoopNestState {

// Add remaining loops to the current fusion prefix.
for (mlir::Attribute attribute : loop_nest.drop_front(common_prefix_size)) {
LoopAttr loop = attribute.cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(attribute);

if (closed_loops_.count(loop.name()) > 0) {
return op.EmitError() << "occurrences of loop " << loop.name()
Expand Down Expand Up @@ -500,7 +501,7 @@ static mlir::LogicalResult VerifyLoopRanges(
const LoopFusionAnalysis &fusion_analysis,
const SequenceAnalysis &sequence_analysis) {
for (mlir::Attribute attr : loop_nest) {
LoopAttr loop = attr.cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(attr);
const LoopFusionClass &fusion_class = fusion_analysis.GetClass(loop.name());
for (const auto &dimension : fusion_class.getDomain()) {
if (sequence_analysis.IsBefore(op, dimension.value.defining_op())) {
Expand Down Expand Up @@ -670,7 +671,7 @@ mlir::LogicalResult LoopFusionAnalysis::Init(
// Returns the unroll factor of the `pos`-th loop in the given compute op.
// Expects the op to have a well-formed loop nest attribute.
static unsigned ExtractUnrollFactor(const ComputeOpInstance &op, unsigned pos) {
auto loop = op.Loops()[pos].cast<LoopAttr>();
auto loop = mlir::cast<LoopAttr>(op.Loops()[pos]);
if (mlir::IntegerAttr unroll_factor = loop.unroll()) {
return unroll_factor.getInt();
}
Expand All @@ -686,7 +687,7 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
loop_names.reserve(loop_pos);
iter_exprs.reserve(loop_pos);
for (int i = 0; i < loop_pos; ++i) {
LoopAttr loop = op.Loops()[i].cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(op.Loops()[i]);
loop_names.push_back(loop.name());
iter_exprs.push_back(loop.iter());
}
Expand All @@ -698,7 +699,7 @@ mlir::LogicalResult LoopFusionAnalysis::RegisterLoop(
auto loop_nest_mapping =
MappingAttr::get(op.context(), op.domain_size(), iter_exprs);

LoopAttr loop = op.Loops()[loop_pos].cast<LoopAttr>();
LoopAttr loop = mlir::cast<LoopAttr>(op.Loops()[loop_pos]);
auto [it, was_inserted] =
fusion_classes_.try_emplace(loop.name(), loop.name(), op, loop_nest);
LoopFusionClass &fusion_class = it->second;
Expand Down
7 changes: 4 additions & 3 deletions sair_op_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "sair_attributes.h"
#include "sair_types.h"
Expand Down Expand Up @@ -67,7 +68,7 @@ class ValueOperand {

// Returns the type of the value referenced by the operand.
ValueType GetType() const {
return operand_->get().getType().cast<ValueType>();
return mlir::cast<ValueType>(operand_->get().getType());
}

// Returns the operation owning the operand.
Expand Down Expand Up @@ -510,14 +511,14 @@ class OperandInstance {
inline auto OpInstance::getDomain() const {
return llvm::map_range(GetDomainValues(), [](mlir::Value v) {
OpInstance dim_op = OpInstance(llvm::cast<SairOp>(v.getDefiningOp()));
return ResultInstance(dim_op, v.cast<OpResult>().getResultNumber());
return ResultInstance(dim_op, mlir::cast<OpResult>(v).getResultNumber());
});
}

inline auto OpInstance::DomainWithDependencies() const {
DomainShapeAttr shape = GetShape();
return llvm::map_range(llvm::enumerate(GetDomainValues()), [=](auto p) {
auto value = p.value().template cast<mlir::OpResult>();
auto value = mlir::cast<mlir::OpResult>(p.value());
OpInstance dim_op = OpInstance(llvm::cast<SairOp>(value.getOwner()));
ValueAccessInstance access = {
.value = ResultInstance(dim_op, value.getResultNumber()),
Expand Down
3 changes: 2 additions & 1 deletion sequence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "mlir/Support/LLVM.h"
#include "loop_nest.h"
#include "sair_op_interfaces.h"
#include "sair_ops.h"
Expand Down Expand Up @@ -454,7 +455,7 @@ ProgramPoint SequenceAnalysis::FindInsertionPoint(
llvm::ArrayRef<mlir::Attribute> new_loops = new_op.Loops();
num_common_loops = std::min<int>(new_loops.size(), num_common_loops);
for (; num_common_loops > 0; --num_common_loops) {
auto loop = new_loops[num_common_loops - 1].cast<LoopAttr>();
auto loop = mlir::cast<LoopAttr>(new_loops[num_common_loops - 1]);
if (loop.name() == start_loop_nest[num_common_loops - 1]) break;
}
if (num_common_loops <= num_loops) break;
Expand Down
3 changes: 2 additions & 1 deletion test/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "test/passes.h"

#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "sair_attributes.h"
#include "sair_dialect.h"

Expand All @@ -34,7 +35,7 @@ static llvm::SmallVector<T, 4> GetAttrVector(llvm::StringRef name,
llvm::SmallVector<T, 4> vector;
vector.reserve(array.size());
for (mlir::Attribute element : array.getValue()) {
vector.push_back(element.cast<T>());
vector.push_back(mlir::cast<T>(element));
}
return vector;
}
Expand Down
3 changes: 2 additions & 1 deletion transforms/lower_proj_any.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "loop_nest.h"
#include "sair_dialect.h"
#include "sair_ops.h"
Expand Down Expand Up @@ -86,7 +87,7 @@ class LowerProjAny : public impl::LowerProjAnyPassBase<LowerProjAny> {
// guaranteed to be identity by verifiers.
for (MappingExpr expr :
mapping.Dimensions().drop_front(num_common_loops)) {
if (!expr.isa<MappingNoneExpr>()) {
if (!mlir::isa<MappingNoneExpr>(expr)) {
return op.emitError()
<< "cannot lower operation to proj_last on scalars";
}
Expand Down
3 changes: 2 additions & 1 deletion util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define THIRD_PARTY_SAIR_TRANSFORMS_UTIL_H_

#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "sair_attributes.h"
#include "sair_op_interfaces.h"

Expand Down Expand Up @@ -95,7 +96,7 @@ std::function<mlir::ArrayAttr(mlir::ArrayAttr)> MkArrayAttrMapper(
llvm::SmallVector<mlir::Attribute> output;
output.reserve(array.size());
for (mlir::Attribute attr : array.getValue()) {
output.push_back(scalar_fn(attr.cast<T>()));
output.push_back(scalar_fn(mlir::cast<T>(attr)));
}
return mlir::ArrayAttr::get(array.getContext(), output);
};
Expand Down