From 19bb984254b8d4e7f40bc7ee1ba1976327706dfd Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 11 Nov 2025 13:45:05 -0800 Subject: [PATCH] [Mosaic] Allow padding in small tiling row shuffle reshape. We just need to make sure shape aligns to vreg-slice lane dim and only last vreg contains padding on tiled dims. Examples: 1: Reshape vector<10x128xi32> to vector<5x256xi32> can use the row shuffle reshape routine by inferring in tiling = (8, 128) and out tiling = (4, 128) because 1) vregs are still one-to-one mapping, ensured by vreg-slice lane aligned, and 2) only last vreg in tiled dims are padded, ensured by #elements are the same in tiled dims. 2: Reshape vector<16x512x56x128xbf16> to vector<16x512x7168xbf16> can use in tiling = (16, 128) and out tiling = (1, 256) and make it no-op. PiperOrigin-RevId: 831053083 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 14 +-- .../tpu/transforms/apply_vector_layout.cc | 85 ++++++++++++------- .../tpu/transforms/infer_vector_layout.cc | 54 ++++++++++-- 3 files changed, 105 insertions(+), 48 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 78faa1565506..12f42a72aca6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1836,12 +1836,14 @@ LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op, if (auto pack = dyn_cast(op.getSource().getDefiningOp()); pack && pack.getPackFormat() == op.getPackFormat() && pack.getSources().front().getType() == op.getType()) { - rewriter.replaceAllOpUsesWith( - op, pack.getPaddedSources( - pack.getSources(), pack.getPositions(), - op.getType().getElementTypeBitWidth() / - pack.getType().getElementTypeBitWidth())[op.getIndex()]); - return success(); + Value source = pack.getPaddedSources( + pack.getSources(), pack.getPositions(), + op.getType().getElementTypeBitWidth() / + pack.getType().getElementTypeBitWidth())[op.getIndex()]; + if (source) { + rewriter.replaceAllOpUsesWith(op, source); + return success(); + } } return failure(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 0dee1129a35a..92b7f2939364 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -5650,40 +5650,51 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op, no_op = true; } - auto can_use_row_shuffle = [&ctx](ArrayRef shape, - VectorLayout layout, - std::array vreg_slice) { - if (shape.size() < 2) { + bool can_use_row_shuffle = [&]() { + if (!llvm::isPowerOf2_32(layout_in.bitwidth())) { return false; } - // vreg must not be padded. - if (shape.back() % vreg_slice[1] != 0 || - shape[shape.size() - 2] % vreg_slice[0] != 0) { - return false; - } - if (!llvm::isPowerOf2_32(layout.bitwidth())) { - return false; - } - if (layout.offsets() != LayoutOffsets{0, 0}) { - return false; - } - if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) { - return false; + bool src_is_1d_tiling = + layout_in.tiling() == + std::array{1, ctx.target_shape[1] * layout_in.packing()}; + bool dst_is_1d_tiling = + layout_out.tiling() == + std::array{1, ctx.target_shape[1] * layout_out.packing()}; + bool src_is_vreg_slice_lane_aligned = + (!src_is_1d_tiling && src_tiled_dims[1] == src_vreg_slice[1]) || + (src_is_1d_tiling && src_tiled_dims[1] % src_vreg_slice[1] == 0); + bool dst_is_vreg_slice_lane_aligned = + (!dst_is_1d_tiling && dst_tiled_dims[1] == dst_vreg_slice[1]) || + (dst_is_1d_tiling && dst_tiled_dims[1] % dst_vreg_slice[1] == 0); + bool src_is_vreg_slice_sublane_aligned = + src_tiled_dims[0] % src_vreg_slice[0] == 0; + bool dst_is_vreg_slice_sublane_aligned = + dst_tiled_dims[0] % dst_vreg_slice[0] == 0; + if (src_is_vreg_slice_lane_aligned && dst_is_vreg_slice_lane_aligned) { + if (src_is_vreg_slice_sublane_aligned && + dst_is_vreg_slice_sublane_aligned) { + // Both src and dst are aligned to vreg slice sublanes. + return true; + } + if (!src_is_vreg_slice_sublane_aligned && + !dst_is_vreg_slice_sublane_aligned && + llvm::product_of(src_tiled_dims) == + llvm::product_of(dst_tiled_dims)) { + // Neither src nor dst are aligned to vreg slice sublanes. + // Padding happens only on the last vreg in tiled dims. + return true; + } } - // 2d tiling. - if (layout.tiling()[0] <= ctx.target_shape[0] * layout.packing() && - layout.tiling()[1] == ctx.target_shape[1] && - shape.back() == vreg_slice[1]) { + if (src_is_vreg_slice_lane_aligned && dst_is_1d_tiling && + llvm::product_of(src_tiled_dims) == dst_tiled_dims[1]) { return true; } - // 1d tiling. - if (layout.tiling() == - std::array{1, ctx.target_shape[1] * layout.packing()} && - shape.back() % vreg_slice[1] == 0) { + if (dst_is_vreg_slice_lane_aligned && src_is_1d_tiling && + llvm::product_of(dst_tiled_dims) == src_tiled_dims[1]) { return true; } return false; - }; + }(); FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_vregs, @@ -5717,13 +5728,10 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op, } else if ( // Row shuffle within a vreg if there is no padding and each vreg holds // a contiguous slice of the flattened data. - can_use_row_shuffle(src_shape, layout_in, src_vreg_slice) && - can_use_row_shuffle(dst_shape, layout_out, dst_vreg_slice)) { + can_use_row_shuffle) { auto [sublane_count, lane_count] = ctx.target_shape; - auto dst_vregs_shape = - layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape); - auto src_vregs_shape = - layout_in.tileArrayShape(false, false, src_shape, ctx.target_shape); + src_vregs.Reshape( + layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape)); if (bitwidth == 32) { // For 32 bit data, a sublane is effectively a physical row. std::array src_sublane_slice = { @@ -5845,8 +5853,20 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op, // with tiling (16, 128) and then to (8, 512) with tiling (8, 128). const int64_t src_sublane_tiling = layout_in.tiling()[0]; const int64_t dst_sublane_tiling = layout_out.tiling()[0]; + const int64_t native_sublane_tiling = + ctx.target_shape[0] * layout_in.packing(); CHECK(llvm::isPowerOf2_64(static_cast(src_sublane_tiling))); CHECK(llvm::isPowerOf2_64(static_cast(dst_sublane_tiling))); + CHECK( + llvm::isPowerOf2_64(static_cast(native_sublane_tiling))); + // (target_shape[0] * packing, target_shape[1]) <-> + // (1, target_shape[1] * packing) is a no-op. + if ((src_sublane_tiling == 1 && + dst_sublane_tiling == native_sublane_tiling) || + (src_sublane_tiling == native_sublane_tiling && + dst_sublane_tiling == 1)) { + return src_vregs; + } tpu::PackFormat unpack_format, pack_format; if (src_sublane_tiling > dst_sublane_tiling) { unpack_format = tpu::PackFormat::kInterleaved; @@ -5887,7 +5907,6 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op, src_vreg->getLoc(), src_vreg->getType(), dst_vreg); }); } - src_vregs.Reshape(dst_vregs_shape); return src_vregs; } else if ( // Lower shape_casts for {32/16/8}-bit types where the minor dimension diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index aa9b1a3ca7f5..6bfc765ae87e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1633,8 +1633,8 @@ class VectorLayoutInferer { return success(); } - // Find the small tiling such that there is not padding and each vreg holds - // a continuous slice of the flatten data. + // Find the small tiling such that each vreg holds a continuous slice of the + // flatten data and each row is either fully occupied or is all padding. auto small_second_minor_tiling_layout = [&](ArrayRef shape) -> std::optional { const int64_t elements_per_vreg = native_tiling[0] * native_tiling[1]; @@ -1659,16 +1659,15 @@ class VectorLayoutInferer { // TODO(b/440370770): Preserve replicated offsets. auto layout = VectorLayout(bitwidth, {0, 0}, tiling, ImplicitDim::kNone); auto vreg_slice = layout.vregSlice(target_shape_); - if ((shape.back() != vreg_slice[1] && !can_use_1d_tiling) || - shape[shape.size() - 2] % vreg_slice[0] != 0) { + if (shape.back() != vreg_slice[1] && !can_use_1d_tiling) { return std::nullopt; } return layout; }; - // Use the small tiling if there's no padding and each vreg holds a - // contiguous slice of the flattened data. It makes reshape a row shuffle - // within a vreg. + // Use the small tiling if each vreg holds a contiguous slice of the + // flattened data and each row is either fully occupied or is all + // padding. It makes reshape a row shuffle within a vreg. // // For example, // - (4, 256) with (4, 128) tiling to (1, 1024) with (1, 128) tiling is @@ -1684,8 +1683,45 @@ class VectorLayoutInferer { if (src_small_second_minor_tiling_layout.has_value() && res_small_second_minor_tiling_layout.has_value()) { - setLayout(op, *src_small_second_minor_tiling_layout, - *res_small_second_minor_tiling_layout); + auto src_vreg_slice = + src_small_second_minor_tiling_layout->vregSlice(target_shape_); + auto res_vreg_slice = + res_small_second_minor_tiling_layout->vregSlice(target_shape_); + bool src_vreg_slice_aligned = + src_shape[src_shape.size() - 2] % src_vreg_slice[0] == 0; + bool res_vreg_slice_aligned = + res_shape[res_shape.size() - 2] % res_vreg_slice[0] == 0; + if ( + // Both input and output are aligned to its vreg slice. + (src_vreg_slice_aligned && res_vreg_slice_aligned) || + // If not aligned, make sure the padding only happens on last vreg in + // tiled dims. For example, reshape i32 (12, 128) to (6, 256) with + // input tiling (8, 128) and output tiling (4, 128). + (!src_vreg_slice_aligned && !res_vreg_slice_aligned && + llvm::product_of(src_shape.take_back(2)) == + llvm::product_of(res_shape.take_back(2)))) { + setLayout(op, *src_small_second_minor_tiling_layout, + *res_small_second_minor_tiling_layout); + return success(); + } + } + if (src_small_second_minor_tiling_layout.has_value() && + llvm::product_of(src_shape.take_back(2)) == res_shape.back()) { + // For example, reshape i32 (8, 10, 128) to (8, 1280) with input tiling + // (8, 128) and output tiling (1, 128). + setLayout( + op, *src_small_second_minor_tiling_layout, + VectorLayout(layout.bitwidth(), {0, 0}, + {1, target_shape_[1] * packing}, ImplicitDim::kNone)); + return success(); + } + if (res_small_second_minor_tiling_layout.has_value() && + llvm::product_of(res_shape.take_back(2)) == src_shape.back()) { + setLayout( + op, + VectorLayout(layout.bitwidth(), {0, 0}, + {1, target_shape_[1] * packing}, ImplicitDim::kNone), + *res_small_second_minor_tiling_layout); return success(); }