|
20 | 20 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
21 | 21 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
22 | 22 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
| 23 | +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
23 | 24 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
24 | 25 | #include "mlir/IR/AffineMap.h"
|
25 | 26 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
@@ -732,6 +733,32 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
|
732 | 733 | return success();
|
733 | 734 | }
|
734 | 735 |
|
| 736 | +struct FoldSubviewIntoGatherToLDSOp |
| 737 | + : public OpRewritePattern<amdgpu::GatherToLDSOp> { |
| 738 | + using OpRewritePattern<amdgpu::GatherToLDSOp>::OpRewritePattern; |
| 739 | + LogicalResult |
| 740 | + matchAndRewrite(amdgpu::GatherToLDSOp op, PatternRewriter &rewriter) const override { |
| 741 | + Location loc = op.getLoc(); |
| 742 | + |
| 743 | + // Check if the source is a subview operation: |
| 744 | + auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp()); |
| 745 | + if (!subviewOp) |
| 746 | + return rewriter.notifyMatchFailure( |
| 747 | + loc, "GatherToLDSOp can only be folded if the source is a SubviewOp"); |
| 748 | + |
| 749 | + SmallVector<Value> sourceIndices; |
| 750 | + mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
| 751 | + rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(), |
| 752 | + subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices); |
| 753 | + |
| 754 | + rewriter.replaceOpWithNewOp<admgpu::GatherToLDSOp>( |
| 755 | + op, subviewOp.getSource(), sourceIndices, op.getDst(), op.getDstIndices(), |
| 756 | + op.getTransferType()); |
| 757 | + |
| 758 | + return success(); |
| 759 | + } |
| 760 | +}; |
| 761 | + |
735 | 762 | void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
|
736 | 763 | patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
|
737 | 764 | LoadOpOfSubViewOpFolder<memref::LoadOp>,
|
@@ -762,8 +789,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
|
762 | 789 | StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
|
763 | 790 | StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
|
764 | 791 | StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
|
765 |
| - SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>( |
766 |
| - patterns.getContext()); |
| 792 | + SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder, |
| 793 | + FoldSubviewIntoGatherToLDSOp>(patterns.getContext()); |
767 | 794 | }
|
768 | 795 |
|
769 | 796 | //===----------------------------------------------------------------------===//
|
|
0 commit comments