@@ -1631,42 +1631,59 @@ class VectorLayoutInferer {
16311631 return success ();
16321632 }
16331633
1634- // Find the small tiling such that there is not padding and each vreg holds
1635- // a continuous slice of the flatten data .
1634+ // Find a small tiling such that each vreg holds a contiguous slice of the
1635+ // flattened data and each row is either fully occupied or is all padding.
      // Returns std::nullopt when no such layout is found for `shape` with the
      // requested `implicit_dim`.
16361636 auto small_second_minor_tiling_layout =
1637- [&](ArrayRef<int64_t > shape) -> std::optional<VectorLayout> {
1637+ [&](ArrayRef<int64_t > shape,
1638+ ImplicitDim implicit_dim) -> std::optional<VectorLayout> {
1639+ if (!llvm::isPowerOf2_32 (bitwidth)) {
1640+ return std::nullopt ;
1641+ }
1642+
16381643 const int64_t elements_per_vreg = native_tiling[0 ] * native_tiling[1 ];
1644+ bool aligned_1d_tiling = shape.back () % elements_per_vreg == 0 ;
1645+ // Use the 1D tiling (1, target_shape_[1] * packing) when the second minor
      // dim is implicit, or when the minor dim is a multiple of
      // elements_per_vreg.
      // NOTE(review): shape.back() above and this early return both run before
      // the shape.size() < 2 check below. Confirm that callers never pass an
      // empty shape, and that a rank-1 shape combined with ImplicitDim::kNone
      // cannot reach this return.
1646+ if (implicit_dim == ImplicitDim::kSecondMinor || aligned_1d_tiling) {
1647+ return VectorLayout (bitwidth, {0 , 0 }, {1 , target_shape_[1 ] * packing},
1648+ implicit_dim);
1649+ }
1650+
      // kSecondMinor has already returned above; any value other than kNone is
      // unexpected from here on.
1651+ CHECK_EQ (implicit_dim, ImplicitDim::kNone );
16391652 if (shape.size () < 2 ) {
16401653 return std::nullopt ;
16411654 }
1642- if (!llvm::isPowerOf2_32 (bitwidth)) {
      // Reject when the minor dim does not evenly divide elements_per_vreg, or
      // when the resulting second minor tiling is not a multiple of packing or
      // exceeds native_tiling[0].
1655+ int64_t second_minor_tiling = elements_per_vreg / shape.back ();
1656+ if (elements_per_vreg % shape.back () != 0 ||
1657+ second_minor_tiling % packing != 0 ||
1658+ second_minor_tiling > native_tiling[0 ]) {
16431659 return std::nullopt ;
16441660 }
1645- int64_t second_minor_tiling = elements_per_vreg / shape.back ();
1646- bool can_use_1d_tiling = shape.back () % elements_per_vreg == 0 ;
1647- std::array<int64_t , 2 > tiling;
1648- if (can_use_1d_tiling) {
1649- tiling = {1 , target_shape_[1 ] * packing};
1650- } else if (elements_per_vreg % shape.back () == 0 &&
1651- second_minor_tiling % packing == 0 &&
1652- second_minor_tiling <= native_tiling[0 ]) {
1653- tiling = {second_minor_tiling, target_shape_[1 ]};
1654- } else {
1661+ auto layout =
1662+ VectorLayout (bitwidth, {0 , 0 },
1663+ {second_minor_tiling, target_shape_[1 ]}, implicit_dim);
1664+ // The minor dim must equal the minor extent of the vreg slice
      // (lane-aligned), so that vregs map one-to-one.
1665+ if (shape.back () != layout.vregSlice (target_shape_)[1 ]) {
16551666 return std::nullopt ;
16561667 }
16571668 // TODO(b/440370770): Preserve replicated offsets.
1658- auto layout = VectorLayout (bitwidth, {0 , 0 }, tiling, ImplicitDim::kNone );
1669+ return layout;
1670+ };
1671+
// Returns true when `tiled_ishape` (the two tiled dims, implicit dim
// included) does not exactly fill the vreg slice of `layout`.
1672+ auto has_padding = [&](std::array<int64_t , 2 > tiled_ishape,
1673+ VectorLayout layout) -> bool {
16591674 auto vreg_slice = layout.vregSlice (target_shape_);
1660- if ((shape.back () != vreg_slice[1 ] && !can_use_1d_tiling) ||
1661- shape[shape.size () - 2 ] % vreg_slice[0 ] != 0 ) {
1662- return std::nullopt ;
1675+ bool is_1d_tiling = layout.tiling () ==
1676+ std::array<int64_t , 2 >{1 , target_shape_[1 ] * packing};
      // With the 1D tiling only the minor dim matters: it is padded unless it
      // is a multiple of the minor extent of the vreg slice.
1677+ if (is_1d_tiling) {
1678+ return tiled_ishape[1 ] % vreg_slice[1 ] != 0 ;
16631679 }
1664- return layout;
      // Otherwise the second minor dim must be a multiple of vreg_slice[0] and
      // the minor dim must equal vreg_slice[1]; anything else counts as padding.
1680+ return (tiled_ishape[0 ] % vreg_slice[0 ] != 0 ) ||
1681+ (tiled_ishape[1 ] != vreg_slice[1 ]);
16651682 };
16661683
1667- // Use the small tiling if there's no padding and each vreg holds a
1668- // contiguous slice of the flattened data. It makes reshape a row shuffle
1669- // within a vreg.
1684+ // Use the small tiling if each vreg holds a contiguous slice of the
1685+ // flattened data and each row is either fully occupied or is all
1686+ // padding. It makes reshape a row shuffle within a vreg.
16701687 //
16711688 // For example,
16721689 // - (4, 256) with (4, 128) tiling to (1, 1024) with (1, 128) tiling is
@@ -1675,16 +1692,43 @@ class VectorLayoutInferer {
16751692 // - (4, 256) with (4, 128) tiling to (2, 512) with (2, 128) tiling is
16761693 // to shuffle sublane from [0, 1, 2, 3, 4, 5, 6, 7] to
16771694 // [0, 2, 4, 6, 1, 3, 5, 7]
1678- auto src_small_second_minor_tiling_layout =
1679- small_second_minor_tiling_layout (src_shape);
1680- auto res_small_second_minor_tiling_layout =
1681- small_second_minor_tiling_layout (res_shape);
1682-
1683- if (src_small_second_minor_tiling_layout.has_value () &&
1684- res_small_second_minor_tiling_layout.has_value ()) {
1685- setLayout (op, *src_small_second_minor_tiling_layout,
1686- *res_small_second_minor_tiling_layout);
1687- return success ();
1695+ //
1696+ // Try both kNone and kSecondMinor implicit dims on each side to simplify
      // the logic a bit. Combinations are visited with kNone first, and the
      // first pair of layouts that passes one of the padding checks below is
      // used. If no combination qualifies, control falls through to the code
      // after the loops.
1697+ for (ImplicitDim src_implicit_dim :
1698+ {ImplicitDim::kNone , ImplicitDim::kSecondMinor }) {
1699+ for (ImplicitDim res_implicit_dim :
1700+ {ImplicitDim::kNone , ImplicitDim::kSecondMinor }) {
1701+ auto src_small_second_minor_tiling_layout =
1702+ small_second_minor_tiling_layout (src_shape, src_implicit_dim);
1703+ auto res_small_second_minor_tiling_layout =
1704+ small_second_minor_tiling_layout (res_shape, res_implicit_dim);
      // Both sides need a candidate layout; otherwise try the next pair.
1705+ if (!src_small_second_minor_tiling_layout.has_value () ||
1706+ !res_small_second_minor_tiling_layout.has_value ()) {
1707+ continue ;
1708+ }
1709+ auto src_layout = *src_small_second_minor_tiling_layout;
1710+ auto res_layout = *res_small_second_minor_tiling_layout;
1711+ auto src_tiled_ishape = src_layout.getImplicitTiledDims (src_shape, 1 );
1712+ auto res_tiled_ishape = res_layout.getImplicitTiledDims (res_shape, 1 );
1713+ bool src_vreg_has_padding = has_padding (src_tiled_ishape, src_layout);
1714+ bool res_vreg_has_padding = has_padding (res_tiled_ishape, res_layout);
1715+ if (!src_vreg_has_padding && !res_vreg_has_padding) {
1716+ // No padding on either side, e.g., reshape i32 (8, 128) to (4, 256)
1717+ // with input tiling (8, 128) and output tiling (4, 128).
1718+ setLayout (op, src_layout, res_layout);
1719+ return success ();
1720+ }
1721+ if (src_vreg_has_padding && res_vreg_has_padding) {
1722+ // Padding on both sides, e.g., reshape i32 (10, 128) to (5, 256)
1723+ // with input tiling (8, 128) and output tiling (4, 128). We need to
1724+ // make sure only the last vreg in tiled dims is padded. The check
      // below requires the tiled dims of both sides to hold the same number
      // of elements. If exactly one side has padding, neither branch applies
      // and the next combination is tried.
1725+ if (llvm::product_of (src_tiled_ishape) ==
1726+ llvm::product_of (res_tiled_ishape)) {
1727+ setLayout (op, src_layout, res_layout);
1728+ return success ();
1729+ }
1730+ }
1731+ }
16881732 }
16891733
16901734 // Shape casts for {32/16/8}-bit vector types with rank >= 2.
0 commit comments