Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1836,12 +1836,14 @@ LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op,
if (auto pack = dyn_cast<PackSubelementsOp>(op.getSource().getDefiningOp());
pack && pack.getPackFormat() == op.getPackFormat() &&
pack.getSources().front().getType() == op.getType()) {
rewriter.replaceAllOpUsesWith(
op, pack.getPaddedSources(
pack.getSources(), pack.getPositions(),
op.getType().getElementTypeBitWidth() /
pack.getType().getElementTypeBitWidth())[op.getIndex()]);
return success();
Value source = pack.getPaddedSources(
pack.getSources(), pack.getPositions(),
op.getType().getElementTypeBitWidth() /
pack.getType().getElementTypeBitWidth())[op.getIndex()];
if (source) {
rewriter.replaceAllOpUsesWith(op, source);
return success();
}
}
return failure();
}
Expand Down
85 changes: 52 additions & 33 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5650,40 +5650,51 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
no_op = true;
}

auto can_use_row_shuffle = [&ctx](ArrayRef<int64_t> shape,
VectorLayout layout,
std::array<int64_t, 2> vreg_slice) {
if (shape.size() < 2) {
bool can_use_row_shuffle = [&]() {
if (!llvm::isPowerOf2_32(layout_in.bitwidth())) {
return false;
}
// vreg must not be padded.
if (shape.back() % vreg_slice[1] != 0 ||
shape[shape.size() - 2] % vreg_slice[0] != 0) {
return false;
}
if (!llvm::isPowerOf2_32(layout.bitwidth())) {
return false;
}
if (layout.offsets() != LayoutOffsets{0, 0}) {
return false;
}
if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return false;
bool src_is_1d_tiling =
layout_in.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout_in.packing()};
bool dst_is_1d_tiling =
layout_out.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout_out.packing()};
bool src_is_vreg_slice_lane_aligned =
(!src_is_1d_tiling && src_tiled_dims[1] == src_vreg_slice[1]) ||
(src_is_1d_tiling && src_tiled_dims[1] % src_vreg_slice[1] == 0);
bool dst_is_vreg_slice_lane_aligned =
(!dst_is_1d_tiling && dst_tiled_dims[1] == dst_vreg_slice[1]) ||
(dst_is_1d_tiling && dst_tiled_dims[1] % dst_vreg_slice[1] == 0);
bool src_is_vreg_slice_sublane_aligned =
src_tiled_dims[0] % src_vreg_slice[0] == 0;
bool dst_is_vreg_slice_sublane_aligned =
dst_tiled_dims[0] % dst_vreg_slice[0] == 0;
if (src_is_vreg_slice_lane_aligned && dst_is_vreg_slice_lane_aligned) {
if (src_is_vreg_slice_sublane_aligned &&
dst_is_vreg_slice_sublane_aligned) {
// Both src and dst are aligned to vreg slice sublanes.
return true;
}
if (!src_is_vreg_slice_sublane_aligned &&
!dst_is_vreg_slice_sublane_aligned &&
llvm::product_of(src_tiled_dims) ==
llvm::product_of(dst_tiled_dims)) {
// Neither src nor dst are aligned to vreg slice sublanes.
// Padding happens only on the last vreg in tiled dims.
return true;
}
}
// 2d tiling.
if (layout.tiling()[0] <= ctx.target_shape[0] * layout.packing() &&
layout.tiling()[1] == ctx.target_shape[1] &&
shape.back() == vreg_slice[1]) {
if (src_is_vreg_slice_lane_aligned && dst_is_1d_tiling &&
llvm::product_of(src_tiled_dims) == dst_tiled_dims[1]) {
return true;
}
// 1d tiling.
if (layout.tiling() ==
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout.packing()} &&
shape.back() % vreg_slice[1] == 0) {
if (dst_is_vreg_slice_lane_aligned && src_is_1d_tiling &&
llvm::product_of(dst_tiled_dims) == src_tiled_dims[1]) {
return true;
}
return false;
};
}();

FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_vregs,
Expand Down Expand Up @@ -5717,13 +5728,10 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
} else if (
// Row shuffle within a vreg if there is no padding and each vreg holds
// a contiguous slice of the flattened data.
can_use_row_shuffle(src_shape, layout_in, src_vreg_slice) &&
can_use_row_shuffle(dst_shape, layout_out, dst_vreg_slice)) {
can_use_row_shuffle) {
auto [sublane_count, lane_count] = ctx.target_shape;
auto dst_vregs_shape =
layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape);
auto src_vregs_shape =
layout_in.tileArrayShape(false, false, src_shape, ctx.target_shape);
src_vregs.Reshape(
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
if (bitwidth == 32) {
// For 32 bit data, a sublane is effectively a physical row.
std::array<int64_t, 2> src_sublane_slice = {
Expand Down Expand Up @@ -5845,8 +5853,20 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
// with tiling (16, 128) and then to (8, 512) with tiling (8, 128).
const int64_t src_sublane_tiling = layout_in.tiling()[0];
const int64_t dst_sublane_tiling = layout_out.tiling()[0];
const int64_t native_sublane_tiling =
ctx.target_shape[0] * layout_in.packing();
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(src_sublane_tiling)));
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(dst_sublane_tiling)));
CHECK(
llvm::isPowerOf2_64(static_cast<uint64_t>(native_sublane_tiling)));
// (target_shape[0] * packing, target_shape[1]) <->
// (1, target_shape[1] * packing) is a no-op.
if ((src_sublane_tiling == 1 &&
dst_sublane_tiling == native_sublane_tiling) ||
(src_sublane_tiling == native_sublane_tiling &&
dst_sublane_tiling == 1)) {
return src_vregs;
}
tpu::PackFormat unpack_format, pack_format;
if (src_sublane_tiling > dst_sublane_tiling) {
unpack_format = tpu::PackFormat::kInterleaved;
Expand Down Expand Up @@ -5887,7 +5907,6 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
src_vreg->getLoc(), src_vreg->getType(), dst_vreg);
});
}
src_vregs.Reshape(dst_vregs_shape);
return src_vregs;
} else if (
// Lower shape_casts for {32/16/8}-bit types where the minor dimension
Expand Down
54 changes: 45 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1633,8 +1633,8 @@ class VectorLayoutInferer {
return success();
}

// Find the small tiling such that there is not padding and each vreg holds
// a continuous slice of the flatten data.
// Find the small tiling such that each vreg holds a contiguous slice of the
// flattened data and each row is either fully occupied or is all padding.
auto small_second_minor_tiling_layout =
[&](ArrayRef<int64_t> shape) -> std::optional<VectorLayout> {
const int64_t elements_per_vreg = native_tiling[0] * native_tiling[1];
Expand All @@ -1659,16 +1659,15 @@ class VectorLayoutInferer {
// TODO(b/440370770): Preserve replicated offsets.
auto layout = VectorLayout(bitwidth, {0, 0}, tiling, ImplicitDim::kNone);
auto vreg_slice = layout.vregSlice(target_shape_);
if ((shape.back() != vreg_slice[1] && !can_use_1d_tiling) ||
shape[shape.size() - 2] % vreg_slice[0] != 0) {
if (shape.back() != vreg_slice[1] && !can_use_1d_tiling) {
return std::nullopt;
}
return layout;
};

// Use the small tiling if there's no padding and each vreg holds a
// contiguous slice of the flattened data. It makes reshape a row shuffle
// within a vreg.
// Use the small tiling if each vreg holds a contiguous slice of the
// flattened data and each row is either fully occupied or is all
// padding. It makes reshape a row shuffle within a vreg.
//
// For example,
// - (4, 256) with (4, 128) tiling to (1, 1024) with (1, 128) tiling is
Expand All @@ -1684,8 +1683,45 @@ class VectorLayoutInferer {

if (src_small_second_minor_tiling_layout.has_value() &&
res_small_second_minor_tiling_layout.has_value()) {
setLayout(op, *src_small_second_minor_tiling_layout,
*res_small_second_minor_tiling_layout);
auto src_vreg_slice =
src_small_second_minor_tiling_layout->vregSlice(target_shape_);
auto res_vreg_slice =
res_small_second_minor_tiling_layout->vregSlice(target_shape_);
bool src_vreg_slice_aligned =
src_shape[src_shape.size() - 2] % src_vreg_slice[0] == 0;
bool res_vreg_slice_aligned =
res_shape[res_shape.size() - 2] % res_vreg_slice[0] == 0;
if (
// Both input and output are aligned to its vreg slice.
(src_vreg_slice_aligned && res_vreg_slice_aligned) ||
// If not aligned, make sure the padding only happens on the last vreg in
// the tiled dims. For example, reshape i32 (12, 128) to (6, 256) with
// input tiling (8, 128) and output tiling (4, 128).
(!src_vreg_slice_aligned && !res_vreg_slice_aligned &&
llvm::product_of(src_shape.take_back(2)) ==
llvm::product_of(res_shape.take_back(2)))) {
setLayout(op, *src_small_second_minor_tiling_layout,
*res_small_second_minor_tiling_layout);
return success();
}
}
if (src_small_second_minor_tiling_layout.has_value() &&
llvm::product_of(src_shape.take_back(2)) == res_shape.back()) {
// For example, reshape i32 (8, 10, 128) to (8, 1280) with input tiling
// (8, 128) and output tiling (1, 128).
setLayout(
op, *src_small_second_minor_tiling_layout,
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing}, ImplicitDim::kNone));
return success();
}
if (res_small_second_minor_tiling_layout.has_value() &&
llvm::product_of(res_shape.take_back(2)) == src_shape.back()) {
setLayout(
op,
VectorLayout(layout.bitwidth(), {0, 0},
{1, target_shape_[1] * packing}, ImplicitDim::kNone),
*res_small_second_minor_tiling_layout);
return success();
}

Expand Down
Loading