@@ -557,13 +557,14 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
557557// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
558558// and computes a new offset (mlir::Value) by applying padding based on
559559// shared memory layout.
560- SmallVector<Value> lowerLdStShared(
561- Location loc, MLIRContext *ctx, LinearLayout cvt,
562- ArrayRef<Value> valsArray, // Input for store, output for load
563- Type llvmElemTy, Value smemBase,
564- std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
565- uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
566- const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);
560+ SmallVector<Value>
561+ lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
562+ ArrayRef<Value> valsArray, // Input for store, output for load
563+ Type llvmElemTy, Value smemBase,
564+ std::function<Value(Value)> calcPaddedOffset,
565+ Value affineOffset, uint64_t maskSpanAffineOffset,
566+ RewriterBase &rewriter, const TargetInfoBase &targetInfo,
567+                 Operation *localLoadOp = nullptr);
567568
568569// Lower an ld/st-like operation given a layout and a callback that creates the
569570// PTX instruction Lowers to st when valArrays is empty, and to ld when it is
@@ -576,10 +577,10 @@ SmallVector<Value> lowerLdSt(
576577 ArrayRef<Value> valsArray, // Input for store, output for load
577578 Type llvmElemTy, Value smemBase,
578579 std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
579- uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
580+ uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
580581 const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
581- std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
582-                                     ArrayRef<Value>, Value, int, VectorType)>
582+    std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
583+                                     Value, int, VectorType)>
583584 lowerInst);
584585
585586// Lower local_load/local_store via ld.shared/st.shared
@@ -588,7 +589,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
588589 LinearLayout cvt, // Map from registers to offset
589590 ArrayRef<Value> valsArray, // Input for store, empty for load
590591 Type llvmElemTy, triton::gpu::MemDescType srcTy,
591- SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
592+ SharedMemoryObject smemObj, RewriterBase &rewriter,
592593 const TargetInfoBase &targetInfo,
593594               Operation *localLoadOp = nullptr);
594595
@@ -643,6 +644,12 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
643644 const LLVMTypeConverter *typeConverter,
644645 RewriterBase &rewriter);
645646
647+ LogicalResult
648+ transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
649+ const TargetInfoBase &targetInfo,
650+ const LLVMTypeConverter *typeConverter,
651+ RewriterBase &rewriter);
652+
646653SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
647654 ArrayRef<Value> args,
648655 mlir::TypeID terminatorTypeId,
@@ -655,6 +662,13 @@ SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
655662 mlir::TypeID::get<TerminatorOp>(), loc);
656663}
657664
665+ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
666+ ConversionPatternRewriter &rewriter,
667+ SmallVector<Value> &resultVals,
668+ Type valueElemTy, TritonLLVMOpBuilder &b,
669+ Value threadPred,
670+ const TargetInfoBase &targetInfo,
671+ const LLVMTypeConverter *typeConverter);
658672} // namespace mlir
659673
660674#endif
0 commit comments