@@ -12427,6 +12427,198 @@ class DecomposeAtenRoundDecimalsOp
};
} // namespace

+namespace {
+class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+public:
+  using OpRewritePattern<AtenAsStridedOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAsStridedOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // The `aten.as_strided` operation is decomposed into a series of
+    // operations that compute the indices based on the provided sizes and
+    // strides, and then index into the flattened input tensor as follows:
+    //
+    // input_flat = input.view(-1)
+    //
+    // for dim, s in enumerate(self.size):
+    //   arange = torch.arange(s)
+    //   view_shape = []
+    //   for i in range(len(self.size)):
+    //     if i == dim:
+    //       view_shape.append(-1)
+    //     else:
+    //       view_shape.append(1)
+    //   arange = arange.view(view_shape)
+    //   if dim == 0:
+    //     idx = arange * self.stride[dim]
+    //   else:
+    //     idx = idx + arange * self.stride[dim]
+    //
+    // # Flatten the indices and add the storage offset
+    // final_indices = idx.reshape(-1) + self.storage_offset
+    //
+    // # Index the flattened input tensor
+    // output = input_flat[final_indices]
+    //
+    // # Reshape to the desired output size
+    // return output.view(self.size)
+
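+    // As a concrete illustration (example values chosen here, not taken from
+    // the op's tests): for a flattened input of 6 elements with
+    // size = [2, 2], stride = [3, 1], and storage_offset = 1, the
+    // per-dimension grids are [[0], [3]] and [[0, 1]]; their broadcasted sum
+    // is [[0, 1], [3, 4]], and adding the offset yields the flat indices
+    // [1, 2, 4, 5], so output = input_flat[[1, 2, 4, 5]].view(2, 2).
+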
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    Value input = op.getSelf();
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+
+    if (!inputType || !inputType.hasSizes() || !inputType.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "input must have known sizes");
+
+    SmallVector<int64_t> sizesInts;
+    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "sizes must be a list of constant ints");
+
+    SmallVector<int64_t> stridesInts;
+    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stridesInts)))
+      return rewriter.notifyMatchFailure(
+          op, "strides must be a list of constant ints");
+
+    int64_t storageOffset = 0;
+    if (!isa<Torch::NoneType>(op.getStorageOffset().getType())) {
+      if (!matchPattern(op.getStorageOffset(),
+                        m_TorchConstantInt(&storageOffset)))
+        return rewriter.notifyMatchFailure(
+            op, "storage_offset must be a constant integer");
+    }
+
+    ArrayRef<int64_t> inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+    int64_t resultRank = sizesInts.size();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    if (inputRank > 1) {
+      // If the input is not a 1-D tensor, flatten it to a 1-D tensor before
+      // applying the strided indexing.
+      int64_t flattenedInputSize = 1;
+      for (int64_t size : inputSizes)
+        flattenedInputSize *= size;
+
+      auto flattenedInputTy =
+          cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+              {flattenedInputSize}, inputType.getOptionalDtype()));
+
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenedInputTy,
+                                                      input, cstZero, end);
+    }
+
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstMinusOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> viewShapeInts(resultRank, 1);
+    SmallVector<Value> viewShapeListElems(resultRank, cstOne);
+
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
+    Value finalIndices;
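+    // For each output dimension, build an arange of that dimension's size,
+    // reshape it so it broadcasts along the other result dimensions, scale it
+    // by the corresponding stride, and accumulate it into the linear indices.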
+    for (unsigned dim = 0; dim < sizesInts.size(); dim++) {
+      int64_t size = sizesInts[dim];
+      Value cstNone = rewriter.create<ConstantNoneOp>(loc);
+      Value end =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(size));
+
+      auto arangeType =
+          ValueTensorType::get(context, llvm::ArrayRef(size), si64Type);
+      Value index = rewriter.create<Torch::AtenArangeOp>(
+          loc, arangeType, end, cstNone, cstNone, cstNone, cstNone);
+
+      // Set the current dimension to -1 for broadcasting.
+      viewShapeInts[dim] = -1;
+      viewShapeListElems[dim] = cstMinusOne;
+
+      Value viewShapeList = rewriter.create<Torch::PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          viewShapeListElems);
+
+      auto viewType = ValueTensorType::get(
+          context, llvm::ArrayRef(viewShapeInts), si64Type);
+      index = rewriter.create<AtenViewOp>(loc, viewType, index, viewShapeList);
+
+      // Multiply the index with the stride for the current dimension.
+      Value cstStride = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(stridesInts[dim]));
+      index = rewriter.create<AtenMulScalarOp>(loc, viewType, index, cstStride);
+
+      // Reset the current dimension to 1 for the next iteration.
+      viewShapeInts[dim] = 1;
+      viewShapeListElems[dim] = cstOne;
+
+      if (dim == 0) {
+        finalIndices = index;
+        continue;
+      }
+
+      // Calculate the common shape for the broadcast.
+      SmallVector<int64_t> broadcastShape;
+      SmallVector<Value> broadcastShapeValue;
+      computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape,
+                            broadcastShapeValue);
+      Type broadcastType = ValueTensorType::get(
+          context, llvm::ArrayRef(broadcastShape), si64Type);
+
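+      // aten.add.Tensor computes self + alpha * other; with alpha == 1 this
+      // is a plain broadcasted addition of the per-dimension index grids.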
+      finalIndices = rewriter.create<AtenAddTensorOp>(
+          loc, broadcastType, finalIndices, index, cstOne);
+    }
+
+    int64_t flattenedResultSize = 1;
+    for (int64_t size : sizesInts)
+      flattenedResultSize *= size;
+
+    // Flatten the indices and add the storage offset.
+    finalIndices = rewriter.create<AtenFlattenUsingIntsOp>(
+        loc,
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             si64Type),
+        finalIndices, cstZero, cstMinusOne); // -1 means flatten all dims
+
+    if (storageOffset != 0) {
+      Value cstStorageOffset = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(storageOffset));
+      finalIndices = rewriter.create<AtenAddScalarOp>(
+          loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne);
+    }
+
+    // Index the flattened input tensor.
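+    // aten.index.Tensor takes a list of (optional) index tensors, one per
+    // indexed dimension; here a single 1-D index tensor gathers from the
+    // flattened input.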
+    Type listElemType =
+        inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                       /*optionalDtype=*/nullptr);
+    Value indicesList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(listElemType),
+        SmallVector<Value>{finalIndices});
+
+    auto flattenedResultTy =
+        ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize),
+                             inputType.getOptionalDtype());
+    Value result = rewriter.create<AtenIndexTensorOp>(loc, flattenedResultTy,
+                                                      input, indicesList);
+
+    // Reshape the result to the desired output size.
+    SmallVector<Value> sizesIntsValues;
+    for (int64_t size : sizesInts) {
+      sizesIntsValues.push_back(rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(size)));
+    }
+    Value resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)),
+        sizesIntsValues);
+    result =
+        rewriter.create<AtenViewOp>(loc, op.getType(), result, resultSizeList);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12750,6 +12942,7 @@ class DecomposeComplexOpsPass
        patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);

    GreedyRewriteConfig config;
    config.setUseTopDownTraversal(true);