@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
116
116
// opIdx: 0 => a, 1 => b
117
117
auto type = cast<triton::MemDescType>(v.getType ());
118
118
SmallVector<int64_t > shape{type.getShape ().begin (), type.getShape ().end ()};
119
- SmallVector<int64_t > offset{ 0 , 0 } ;
119
+ SmallVector<int64_t > offset (shape. size () , 0 ) ;
120
120
Type elementType = type.getElementType ();
121
121
122
122
// k => (prefetchWidth, k - prefetchWidth)
@@ -140,8 +140,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
140
140
type.getMemorySpace ()),
141
141
v, offsetsVal);
142
142
143
+ // We need to assign kwidth to zero in the case where the parent layout is
144
+ // Blocked, otherwise the verifier emits a failure. The parent layout is
145
+ // Blocked only when Tensor Cores are disabled.
146
+ int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
147
+ ? 0
148
+ : prefetchWidth / 8 ;
143
149
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get (
144
- builder.getContext (), opIdx, dotEncoding, prefetchWidth / 8 );
150
+ builder.getContext (), opIdx, dotEncoding, kwidth );
145
151
Value prefetchSlice = builder.create <triton::gpu::LocalLoadOp>(
146
152
v.getLoc (), RankedTensorType::get (shape, elementType, dotOperandEnc),
147
153
newSmem);
@@ -190,6 +196,22 @@ LogicalResult Prefetcher::initialize() {
190
196
break ;
191
197
if (!op->getResult (0 ).hasOneUse ())
192
198
break ;
199
+ // Similar to issues faced in HoistLayoutConversion pattern in
200
+ // OptimizeDotOperands.cpp, we can't propagate through type casts from
201
+ // predicates as they aren't supported in Triton when encoded with dot_op
202
+ // layout.
203
+ if (isa<arith::UIToFPOp>(op)) {
204
+ Type srcType = getElementTypeOrSelf (op->getOperand (0 ));
205
+ if (srcType.isInteger (1 ))
206
+ break ;
207
+ }
208
+ // Propagation through ExpandDims is currently not supported. This blindly
209
+ // replaces the encoding with dot encoding & but ExpandDims requires a
210
+ // SliceEncoding. This could be rewritten to support it somehow, but I
211
+ // don't think it's trivial & it's currently crashing.
212
+ if (isa<ExpandDimsOp>(op)) {
213
+ break ;
214
+ }
193
215
rets.push_back (op->getOperand (0 ));
194
216
if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
195
217
foundConvertFromShared = true ;
0 commit comments