@@ -426,6 +426,107 @@ class ConvertAtenReplicationPad2dOp
426426};
427427} // namespace
428428
429+ namespace {
430+
431+ // Lower aten.replication_pad3d operator into a sequence of
432+ // tensor.extract_slice and tensor.concat operations.
433+ class ConvertAtenReplicationPad3dOp
434+ : public OpConversionPattern<AtenReplicationPad3dOp> {
435+
436+ private:
437+ enum sliceLoc { START = 0 , END = 1 };
438+
439+ Value extractSlice (ConversionPatternRewriter &rewriter, Location loc,
440+ Value input, int64_t dimension, sliceLoc sliceLoc) const {
441+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
442+ int64_t inputRank = inputType.getRank ();
443+ SmallVector<Value> inputShape = getTensorSizes (rewriter, loc, input);
444+
445+ SmallVector<OpFoldResult> offsets (inputRank, rewriter.getIndexAttr (0 ));
446+ if (sliceLoc == END) {
447+ Value dimSize = inputShape[dimension];
448+ Value one = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
449+ Value endIdx = rewriter.create <arith::SubIOp>(loc, dimSize, one);
450+ offsets[dimension] = getAsOpFoldResult (endIdx);
451+ }
452+
453+ SmallVector<OpFoldResult> allOneStrides (inputRank,
454+ rewriter.getIndexAttr (1 ));
455+ SmallVector<OpFoldResult> sizes (inputRank, rewriter.getIndexAttr (0 ));
456+ for (int i = 0 ; i < inputRank; ++i)
457+ sizes[i] = (i == dimension) ? rewriter.getIndexAttr (1 )
458+ : getAsOpFoldResult (inputShape[i]);
459+
460+ Value extractedSlice = rewriter.create <tensor::ExtractSliceOp>(
461+ loc, input, offsets, sizes, allOneStrides);
462+ return extractedSlice;
463+ }
464+
465+ Value createTile (ConversionPatternRewriter &rewriter, Location loc,
466+ Value slice, int64_t tileWidth, int64_t dimension) const {
467+ SmallVector<Value> slices (tileWidth, slice);
468+ if (tileWidth == 1 )
469+ return slice;
470+ return rewriter.create <tensor::ConcatOp>(loc, dimension, slices);
471+ }
472+
473+ public:
474+ using OpConversionPattern::OpConversionPattern;
475+
476+ LogicalResult
477+ matchAndRewrite (AtenReplicationPad3dOp op, OpAdaptor adaptor,
478+ ConversionPatternRewriter &rewriter) const override {
479+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
480+ return failure ();
481+
482+ Location loc = op->getLoc ();
483+ Value input = adaptor.getSelf ();
484+ auto inputType = llvm::cast<RankedTensorType>(input.getType ());
485+ int64_t inputRank = inputType.getRank ();
486+ unsigned numDims = inputType.getRank ();
487+ assert (numDims >= 2 && " Not enough input dimensions" );
488+
489+ SmallVector<int64_t > padInts;
490+ if (!matchPattern (op.getPadding (), m_TorchListOfConstantInts (padInts)))
491+ return rewriter.notifyMatchFailure (
492+ op, " only support constant int pad ranges" );
493+
494+ if (padInts.size () != 6 )
495+ return rewriter.notifyMatchFailure (
496+ op, " pad range must have exactly six values" );
497+
498+ Value res = input;
499+ int64_t padIdx = 0 ;
500+ for (int64_t dim = inputRank - 1 ; dim >= inputRank - 3 ; dim--) {
501+ int64_t startTileWidth = padInts[padIdx++];
502+ int64_t endTileWidth = padInts[padIdx++];
503+
504+ SmallVector<Value> resultParts;
505+ if (startTileWidth > 0 ) {
506+ Value slice = extractSlice (rewriter, loc, res, dim, sliceLoc::START);
507+ Value tile = createTile (rewriter, loc, slice, startTileWidth, dim);
508+ resultParts.push_back (tile);
509+ }
510+
511+ resultParts.push_back (res);
512+
513+ if (endTileWidth > 0 ) {
514+ Value slice = extractSlice (rewriter, loc, res, dim, sliceLoc::END);
515+ Value tile = createTile (rewriter, loc, slice, endTileWidth, dim);
516+ resultParts.push_back (tile);
517+ }
518+
519+ if (resultParts.size () > 1 )
520+ res = rewriter.create <tensor::ConcatOp>(loc, dim, resultParts);
521+ }
522+
523+ Type resultType = getTypeConverter ()->convertType (op.getType ());
524+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, resultType, res);
525+ return success ();
526+ }
527+ };
528+
529+ } // namespace
429530namespace {
430531// Converts constant tensor allocation like ops.
431532template <typename OpTy, int fillVal>
@@ -696,6 +797,8 @@ void mlir::torch::torch_to_linalg::
696797 RewritePatternSet &patterns,
697798 ConversionTarget &target) {
698799 MLIRContext *context = patterns.getContext ();
800+ target.addIllegalOp <AtenReplicationPad3dOp>();
801+ patterns.add <ConvertAtenReplicationPad3dOp>(typeConverter, context);
699802 target.addIllegalOp <AtenReplicationPad2dOp>();
700803 patterns.add <ConvertAtenReplicationPad2dOp>(typeConverter, context);
701804 target.addIllegalOp <AtenReplicationPad1dOp>();
0 commit comments