@@ -105,13 +105,39 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
105
105
return rewriter.notifyMatchFailure (
106
106
op, " requires elementwise op on ranked tensors" );
107
107
108
- auto rank = cast<RankedTensorType>(op->getResult (0 ).getType ()).getRank ();
109
- SmallVector<AffineMap, 3 > indexingMaps (
110
- op->getNumResults () + op->getNumOperands (),
111
- rewriter.getMultiDimIdentityMap (rank));
112
- SmallVector<utils::IteratorType, 6 > iteratorTypes (
108
+ auto resTy = cast<RankedTensorType>(op->getResult (0 ).getType ());
109
+ auto rank = resTy.getRank ();
110
+
111
+ // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0.
112
+ AffineMap scalarMap = AffineMap::get (/* dimCount=*/ rank, /* symbolCount=*/ 0 ,
113
+ /* results=*/ {}, rewriter.getContext ());
114
+ AffineMap idMap = rewriter.getMultiDimIdentityMap (rank);
115
+
116
+ // Create indexing maps: one per operand, one per result.
117
+ SmallVector<AffineMap, 6 > indexingMaps;
118
+ indexingMaps.reserve (op->getNumOperands () + op->getNumResults ());
119
+
120
+ for (Value v : op->getOperands ()) {
121
+ Type ty = v.getType ();
122
+ if (isScalarLike (ty))
123
+ indexingMaps.push_back (scalarMap);
124
+ else if (auto rt = dyn_cast<RankedTensorType>(ty)) {
125
+ indexingMaps.push_back (idMap);
126
+ } else
127
+ return rewriter.notifyMatchFailure (
128
+ op,
129
+ " unsupported operand type (expected scalar-like or ranked tensor)" );
130
+ }
131
+
132
+ for (Value r : op->getResults ()) {
133
+ (void )r;
134
+ indexingMaps.push_back (idMap); // results use identity map.
135
+ }
136
+
137
+ SmallVector<utils::IteratorType, 4 > iteratorTypes (
113
138
rank, utils::IteratorType::parallel);
114
- auto outputs = getOrCreateOperandsMatchingResultTypes (rewriter, op);
139
+ SmallVector<Value, 2 > outputs =
140
+ getOrCreateOperandsMatchingResultTypes (rewriter, op);
115
141
rewriter.replaceOpWithNewOp <linalg::GenericOp>(
116
142
op, /* resultTensorTypes=*/ op->getResultTypes (),
117
143
/* inputs=*/ op->getOperands (),
@@ -120,14 +146,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
120
146
/* iteratorTypes=*/ iteratorTypes,
121
147
/* bodyBuilder=*/
122
148
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
123
- auto resultTypes = llvm::to_vector<6 >(
149
+ SmallVector<Type> resultEltTys = llvm::to_vector<6 >(
124
150
llvm::map_range (op->getResultTypes (), [](Type type) {
125
151
return cast<TensorType>(type).getElementType ();
126
152
}));
127
- auto *scalarOp =
153
+ Operation *scalarOp =
128
154
builder.create (loc, op->getName ().getIdentifier (),
129
155
regionArgs.take_front (op->getNumOperands ()),
130
- resultTypes , op->getAttrs ());
156
+ resultEltTys , op->getAttrs ());
131
157
linalg::YieldOp::create (builder, loc, scalarOp->getResults ());
132
158
});
133
159
return success ();
0 commit comments