@@ -1648,6 +1648,15 @@ struct LoadOpConversion
16481648 usePackedType = true ;
16491649 }
16501650
1651+ if (isTransposeRequired) {
1652+ if (!usePackedType) {
1653+ // use the d32 transpose 2d load.
1654+ loadResultElemType = i32_ty;
1655+ packedElemsPerLanePerDPASInst = 32 / elemSizeInBits;
1656+ usePackedType = true ;
1657+ }
1658+ }
1659+
16511660 Type packedDPASOperandType =
16521661 LLVM::getVectorType (loadResultElemType, packedElemsPerLanePerDPASInst);
16531662
@@ -2082,12 +2091,14 @@ struct LoadOpConversion
20822091 offsetX = b.udiv (offsetX, b.i32_val (32 / originalElemBits));
20832092 }
20842093
2094+ Value base_width = b.mul (baseWidth, elemSizeInBytes);
2095+ Value base_pitch = b.mul (pitch, elemSizeInBytes);
20852096 auto load2dOp = rewriter.create <TritonGEN::Matrix2DBlockLoadOp>(
20862097 loc, load2DGenXType,
20872098 /* ptr*/ base,
2088- /* base_width*/ b. mul (baseWidth, elemSizeInBytes) ,
2099+ /* base_width*/ base_width ,
20892100 /* base_height*/ baseHeight,
2090- /* base_pitch*/ b. mul (pitch, elemSizeInBytes) ,
2101+ /* base_pitch*/ base_pitch ,
20912102 /* x*/ b.trunc (i32_ty, offsetX),
20922103 /* y*/ b.trunc (i32_ty, offsetY),
20932104 /* elem_size_in_bits*/ elemSizeInBits,
@@ -2105,6 +2116,10 @@ struct LoadOpConversion
21052116 rewriter.eraseOp (load2dOp);
21062117 return failure ();
21072118 }
2119+ #if 0
2120+ targetInfo.printf(rewriter, "base: %p, baseWidth: %d, baseHeight:%d, pitch:%d, offset_x:%d, offset_y:%d, loadVal: %d",
2121+ {base, base_width, baseHeight, base_pitch, offsetX, offsetY, load2dOp.getResult()});
2122+ #endif
21082123 LLVM_DEBUG (llvm::dbgs () << " Generated load op: " << load2dOp << " \n " );
21092124
21102125 unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA
@@ -2166,11 +2181,14 @@ struct LoadOpConversion
21662181 vblk * packedColNumPerVBlock + col)
21672182 << " , " << std::to_string (k + row) << " \n " ;
21682183 });
2184+ auto ret = b.bitcast (loadVal, unpackedDPASOperandType);
2185+ #if 0
2186+ targetInfo.printf(rewriter, "loadVal: %d", {ret});
2187+ #endif
21692188 loadVals[{outer * packedColNum * numLoadPerOutRepCluster +
21702189 rep * packedColNum +
21712190 vblk * packedColNumPerVBlock + col,
2172- k + row}] =
2173- b.bitcast (loadVal, unpackedDPASOperandType);
2191+ k + row}] = ret;
21742192 } break ;
21752193 case DpasEncodingAttr::OpIdx::OperandC: {
21762194 llvm_unreachable (" unexpected OpIdx::OperandC" );
0 commit comments