Skip to content

Commit f5f32ad

Browse files
WindQAQ authored and Google-ML-Automation committed
[Mosaic] Allow padding in small tiling row shuffle reshape.
We just need to make sure vreg-slice is lane aligned, e.g., each row is either fully occupied or fully padded, and only last vreg contains padding on tiled dims. To cover more cases, try to infer 1d tiling with implicit second minor. For example, reshape vector<16x512x56x128xbf16> to vector<16x512x7168xbf16> can use in tiling = (16, 128) and out tiling = (1, 256) and make it no-op. PiperOrigin-RevId: 831053083
1 parent 24e80c4 commit f5f32ad

File tree

2 files changed

+135
-65
lines changed

2 files changed

+135
-65
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5610,6 +5610,12 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
56105610
layout_in.vregSlice(ctx.target_shape);
56115611
const std::array<int64_t, 2> dst_vreg_slice =
56125612
layout_out.vregSlice(ctx.target_shape);
5613+
auto dst_vregs_shape = layout_out.tileArrayShape(
5614+
/*src_is_implicit=*/false, /*res_is_implicit=*/true, dst_shape,
5615+
ctx.target_shape);
5616+
auto src_vregs_shape = layout_in.tileArrayShape(
5617+
/*src_is_implicit=*/false, /*res_is_implicit=*/true, src_shape,
5618+
ctx.target_shape);
56135619
if (layout_in.tiling() == layout_out.tiling() &&
56145620
layout_in.offsets() == layout_out.offsets() &&
56155621
src_tiled_dims == dst_tiled_dims) {
@@ -5650,40 +5656,53 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
56505656
no_op = true;
56515657
}
56525658

5653-
auto can_use_row_shuffle = [&ctx](ArrayRef<int64_t> shape,
5654-
VectorLayout layout,
5655-
std::array<int64_t, 2> vreg_slice) {
5656-
if (shape.size() < 2) {
5659+
bool can_use_row_shuffle = [&]() {
5660+
if (!llvm::isPowerOf2_32(bitwidth)) {
56575661
return false;
56585662
}
5659-
// vreg must not be padded.
5660-
if (shape.back() % vreg_slice[1] != 0 ||
5661-
shape[shape.size() - 2] % vreg_slice[0] != 0) {
5663+
if (layout_in.offsets() != LayoutOffsets{0, 0} ||
5664+
layout_out.offsets() != LayoutOffsets{0, 0}) {
56625665
return false;
56635666
}
5664-
if (!llvm::isPowerOf2_32(layout.bitwidth())) {
5665-
return false;
5666-
}
5667-
if (layout.offsets() != LayoutOffsets{0, 0}) {
5668-
return false;
5669-
}
5670-
if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
5667+
5668+
auto is_lane_aligned = [&](std::array<int64_t, 2> tiled_ishape,
5669+
VectorLayout layout) -> bool {
5670+
bool is_1d_tiling =
5671+
layout.tiling() ==
5672+
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout.packing()};
5673+
auto vreg_slice = layout.vregSlice(ctx.target_shape);
5674+
return is_1d_tiling || tiled_ishape[1] % vreg_slice[1] == 0;
5675+
};
5676+
5677+
if (!is_lane_aligned(src_tiled_dims, layout_in) ||
5678+
!is_lane_aligned(dst_tiled_dims, layout_out)) {
56715679
return false;
56725680
}
5673-
// 2d tiling.
5674-
if (layout.tiling()[0] <= ctx.target_shape[0] * layout.packing() &&
5675-
layout.tiling()[1] == ctx.target_shape[1] &&
5676-
shape.back() == vreg_slice[1]) {
5681+
5682+
auto has_padding = [&](std::array<int64_t, 2> tiled_ishape,
5683+
VectorLayout layout) -> bool {
5684+
auto vreg_slice = layout.vregSlice(ctx.target_shape);
5685+
bool is_1d_tiling =
5686+
layout.tiling() ==
5687+
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout.packing()};
5688+
if (is_1d_tiling) {
5689+
return tiled_ishape[1] % vreg_slice[1] != 0;
5690+
}
5691+
return (tiled_ishape[0] % vreg_slice[0] != 0) ||
5692+
(tiled_ishape[1] != vreg_slice[1]);
5693+
};
5694+
5695+
bool src_vreg_has_padding = has_padding(src_tiled_dims, layout_in);
5696+
bool dst_vreg_has_padding = has_padding(dst_tiled_dims, layout_out);
5697+
if (!src_vreg_has_padding && !dst_vreg_has_padding) {
56775698
return true;
56785699
}
5679-
// 1d tiling.
5680-
if (layout.tiling() ==
5681-
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout.packing()} &&
5682-
shape.back() % vreg_slice[1] == 0) {
5683-
return true;
5700+
if (src_vreg_has_padding && dst_vreg_has_padding) {
5701+
return llvm::product_of(src_tiled_dims) ==
5702+
llvm::product_of(dst_tiled_dims);
56845703
}
56855704
return false;
5686-
};
5705+
}();
56875706

56885707
FAILUREOR_ASSIGN_OR_RETURN(
56895708
xla::Array<Value> src_vregs,
@@ -5715,15 +5734,11 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
57155734
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
57165735
return dst_vregs_local;
57175736
} else if (
5718-
// Row shuffle within a vreg if there is no padding and each vreg holds
5719-
// a contiguous slice of the flattened data.
5720-
can_use_row_shuffle(src_shape, layout_in, src_vreg_slice) &&
5721-
can_use_row_shuffle(dst_shape, layout_out, dst_vreg_slice)) {
5737+
// Row shuffle within a vreg if each vreg holds a contiguous slice of
5738+
// the flattened data.
5739+
can_use_row_shuffle) {
57225740
auto [sublane_count, lane_count] = ctx.target_shape;
5723-
auto dst_vregs_shape =
5724-
layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape);
5725-
auto src_vregs_shape =
5726-
layout_in.tileArrayShape(false, false, src_shape, ctx.target_shape);
5741+
src_vregs.Reshape(dst_vregs_shape);
57275742
if (bitwidth == 32) {
57285743
// For 32 bit data, a sublane is effectively a physical row.
57295744
std::array<int64_t, 2> src_sublane_slice = {
@@ -5845,8 +5860,20 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
58455860
// with tiling (16, 128) and then to (8, 512) with tiling (8, 128).
58465861
const int64_t src_sublane_tiling = layout_in.tiling()[0];
58475862
const int64_t dst_sublane_tiling = layout_out.tiling()[0];
5863+
const int64_t native_sublane_tiling =
5864+
ctx.target_shape[0] * layout_in.packing();
58485865
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(src_sublane_tiling)));
58495866
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(dst_sublane_tiling)));
5867+
CHECK(
5868+
llvm::isPowerOf2_64(static_cast<uint64_t>(native_sublane_tiling)));
5869+
// (target_shape[0] * packing, target_shape[1]) <->
5870+
// (1, target_shape[1] * packing) is a no-op.
5871+
if ((src_sublane_tiling == 1 &&
5872+
dst_sublane_tiling == native_sublane_tiling) ||
5873+
(src_sublane_tiling == native_sublane_tiling &&
5874+
dst_sublane_tiling == 1)) {
5875+
return src_vregs;
5876+
}
58505877
tpu::PackFormat unpack_format, pack_format;
58515878
if (src_sublane_tiling > dst_sublane_tiling) {
58525879
unpack_format = tpu::PackFormat::kInterleaved;
@@ -5887,7 +5914,6 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
58875914
src_vreg->getLoc(), src_vreg->getType(), dst_vreg);
58885915
});
58895916
}
5890-
src_vregs.Reshape(dst_vregs_shape);
58915917
return src_vregs;
58925918
} else if (
58935919
// Lower shape_casts for {32/16/8}-bit types where the minor dimension

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,42 +1631,59 @@ class VectorLayoutInferer {
16311631
return success();
16321632
}
16331633

1634-
// Find the small tiling such that there is not padding and each vreg holds
1635-
// a continuous slice of the flatten data.
1634+
// Find the small tiling such that each vreg holds a continuous slice of the
1635+
// flatten data and each row is either fully occupied or is all padding.
16361636
auto small_second_minor_tiling_layout =
1637-
[&](ArrayRef<int64_t> shape) -> std::optional<VectorLayout> {
1637+
[&](ArrayRef<int64_t> shape,
1638+
ImplicitDim implicit_dim) -> std::optional<VectorLayout> {
1639+
if (!llvm::isPowerOf2_32(bitwidth)) {
1640+
return std::nullopt;
1641+
}
1642+
16381643
const int64_t elements_per_vreg = native_tiling[0] * native_tiling[1];
1644+
bool aligned_1d_tiling = shape.back() % elements_per_vreg == 0;
1645+
// Force 1d tiling with implicit second minor.
1646+
if (implicit_dim == ImplicitDim::kSecondMinor || aligned_1d_tiling) {
1647+
return VectorLayout(bitwidth, {0, 0}, {1, target_shape_[1] * packing},
1648+
implicit_dim);
1649+
}
1650+
1651+
CHECK_EQ(implicit_dim, ImplicitDim::kNone);
16391652
if (shape.size() < 2) {
16401653
return std::nullopt;
16411654
}
1642-
if (!llvm::isPowerOf2_32(bitwidth)) {
1655+
int64_t second_minor_tiling = elements_per_vreg / shape.back();
1656+
if (elements_per_vreg % shape.back() != 0 ||
1657+
second_minor_tiling % packing != 0 ||
1658+
second_minor_tiling > native_tiling[0]) {
16431659
return std::nullopt;
16441660
}
1645-
int64_t second_minor_tiling = elements_per_vreg / shape.back();
1646-
bool can_use_1d_tiling = shape.back() % elements_per_vreg == 0;
1647-
std::array<int64_t, 2> tiling;
1648-
if (can_use_1d_tiling) {
1649-
tiling = {1, target_shape_[1] * packing};
1650-
} else if (elements_per_vreg % shape.back() == 0 &&
1651-
second_minor_tiling % packing == 0 &&
1652-
second_minor_tiling <= native_tiling[0]) {
1653-
tiling = {second_minor_tiling, target_shape_[1]};
1654-
} else {
1661+
auto layout =
1662+
VectorLayout(bitwidth, {0, 0},
1663+
{second_minor_tiling, target_shape_[1]}, implicit_dim);
1664+
// Must be lane-aligned. This makes sure vreg is one-to-one mapping.
1665+
if (shape.back() != layout.vregSlice(target_shape_)[1]) {
16551666
return std::nullopt;
16561667
}
16571668
// TODO(b/440370770): Preserve replicated offsets.
1658-
auto layout = VectorLayout(bitwidth, {0, 0}, tiling, ImplicitDim::kNone);
1669+
return layout;
1670+
};
1671+
1672+
auto has_padding = [&](std::array<int64_t, 2> tiled_ishape,
1673+
VectorLayout layout) -> bool {
16591674
auto vreg_slice = layout.vregSlice(target_shape_);
1660-
if ((shape.back() != vreg_slice[1] && !can_use_1d_tiling) ||
1661-
shape[shape.size() - 2] % vreg_slice[0] != 0) {
1662-
return std::nullopt;
1675+
bool is_1d_tiling = layout.tiling() ==
1676+
std::array<int64_t, 2>{1, target_shape_[1] * packing};
1677+
if (is_1d_tiling) {
1678+
return tiled_ishape[1] % vreg_slice[1] != 0;
16631679
}
1664-
return layout;
1680+
return (tiled_ishape[0] % vreg_slice[0] != 0) ||
1681+
(tiled_ishape[1] != vreg_slice[1]);
16651682
};
16661683

1667-
// Use the small tiling if there's no padding and each vreg holds a
1668-
// contiguous slice of the flattened data. It makes reshape a row shuffle
1669-
// within a vreg.
1684+
// Use the small tiling if each vreg holds a contiguous slice of the
1685+
// flattened data and each row is either fully occupied or is all
1686+
// padding. It makes reshape a row shuffle within a vreg.
16701687
//
16711688
// For example,
16721689
// - (4, 256) with (4, 128) tiling to (1, 1024) with (1, 128) tiling is
@@ -1675,16 +1692,43 @@ class VectorLayoutInferer {
16751692
// - (4, 256) with (4, 128) tiling to (2, 512) with (2, 128) tiling is
16761693
// to shuffle sublane from [0, 1, 2, 3, 4, 5, 6, 7] to
16771694
// [0, 2, 4, 6, 1, 3, 5, 7]
1678-
auto src_small_second_minor_tiling_layout =
1679-
small_second_minor_tiling_layout(src_shape);
1680-
auto res_small_second_minor_tiling_layout =
1681-
small_second_minor_tiling_layout(res_shape);
1682-
1683-
if (src_small_second_minor_tiling_layout.has_value() &&
1684-
res_small_second_minor_tiling_layout.has_value()) {
1685-
setLayout(op, *src_small_second_minor_tiling_layout,
1686-
*res_small_second_minor_tiling_layout);
1687-
return success();
1695+
//
1696+
// Use implicit second minor to simplify the logic a bit.
1697+
for (ImplicitDim src_implicit_dim :
1698+
{ImplicitDim::kNone, ImplicitDim::kSecondMinor}) {
1699+
for (ImplicitDim res_implicit_dim :
1700+
{ImplicitDim::kNone, ImplicitDim::kSecondMinor}) {
1701+
auto src_small_second_minor_tiling_layout =
1702+
small_second_minor_tiling_layout(src_shape, src_implicit_dim);
1703+
auto res_small_second_minor_tiling_layout =
1704+
small_second_minor_tiling_layout(res_shape, res_implicit_dim);
1705+
if (!src_small_second_minor_tiling_layout.has_value() ||
1706+
!res_small_second_minor_tiling_layout.has_value()) {
1707+
continue;
1708+
}
1709+
auto src_layout = *src_small_second_minor_tiling_layout;
1710+
auto res_layout = *res_small_second_minor_tiling_layout;
1711+
auto src_tiled_ishape = src_layout.getImplicitTiledDims(src_shape, 1);
1712+
auto res_tiled_ishape = res_layout.getImplicitTiledDims(res_shape, 1);
1713+
bool src_vreg_has_padding = has_padding(src_tiled_ishape, src_layout);
1714+
bool res_vreg_has_padding = has_padding(res_tiled_ishape, res_layout);
1715+
if (!src_vreg_has_padding && !res_vreg_has_padding) {
1716+
// No padding on either side, e.g., reshape i32 (8, 128) to (4, 256)
1717+
// with input tiling (8, 128) and output tiling (4, 128).
1718+
setLayout(op, src_layout, res_layout);
1719+
return success();
1720+
}
1721+
if (src_vreg_has_padding && res_vreg_has_padding) {
1722+
// Padding on both sides, e.g., reshape i32 (10, 128) to (5, 256)
1723+
// with input tiling (8, 128) and output tiling (4, 128). We need to
1724+
// make sure only the last vreg in tiled dims is padded.
1725+
if (llvm::product_of(src_tiled_ishape) ==
1726+
llvm::product_of(res_tiled_ishape)) {
1727+
setLayout(op, src_layout, res_layout);
1728+
return success();
1729+
}
1730+
}
1731+
}
16881732
}
16891733

16901734
// Shape casts for {32/16/8}-bit vector types with rank >= 2.

0 commit comments

Comments (0)