191 changes: 107 additions & 84 deletions jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1822,7 +1822,8 @@ def astype(
# Otherwise, mypy is unhappy with using ``idx`` for both range and
# np.ndenumerate.
idx: Any
reg_type = self.registers.flat[0].type
any_reg = self.registers.flat[0]
reg_type = any_reg.type
is_vector_reg = ir.VectorType.isinstance(reg_type)
reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else (1,)
[vector_len] = reg_shape # This is meant to be a 1D assertion.
@@ -1831,6 +1832,41 @@
"Register bitwidth in target type must be divisible by 8, got"
f" {new_reg_bitwidth}"
)
# If the vector originates from a slice (common after relayouts), we
# can fuse the slicing into the conversion and reuse many
# preprocessing ops (shifts, prmts) across different vectors.
regs_from_32bit_slice = (
isinstance(
_slice_op := any_reg.owner.opview, vector.ExtractStridedSliceOp
)
and utils.bitwidth(_slice_op.source.type) == 32
and _slice_op.strides[0].value == 1
)
def packed_registers(
dst_vector_len: int, *, if_not_sliced: bool
) -> Iterable[tuple[Sequence[tuple[int, ...]], ir.Value]]:
"""Tries to pack registers up to destination vector length."""
if regs_from_32bit_slice and if_not_sliced:
for idx, reg in np.ndenumerate(self.registers):
yield [idx], reg
return
generator = np.ndenumerate(self.registers)
indices = []
regs = []
while True:
try:
for _ in range(max(dst_vector_len // vector_len, 1)):
idx, reg = next(generator)
indices.append(idx)
regs.append(reg)
yield indices, utils.vector_concat(regs)
regs.clear()
indices.clear()
except StopIteration:
break
if regs:
yield indices, utils.vector_concat(regs)

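As a rough illustration of the grouping that packed_registers performs, here is a standalone Python sketch in which plain lists stand in for MLIR vector registers; the helper name and the concrete lengths are invented for the example, and list concatenation stands in for utils.vector_concat:

import numpy as np

def group_registers(registers, vector_len, dst_vector_len):
  """Pure-Python analogue of packed_registers: yields (indices, merged values)."""
  group = max(dst_vector_len // vector_len, 1)
  indices, regs = [], []
  for idx, reg in np.ndenumerate(registers):
    indices.append(idx)
    regs.append(reg)
    if len(regs) == group:
      yield indices, sum(regs, [])  # stands in for utils.vector_concat
      indices, regs = [], []
  if regs:  # leftover partial group
    yield indices, sum(regs, [])

# With 2-element registers packed towards 8-wide conversions, four source
# registers end up sharing a single conversion group:
registers = np.empty((4,), dtype=object)
for i in range(4):
  registers[i] = [2 * i, 2 * i + 1]
indices, merged = next(group_registers(registers, vector_len=2, dst_vector_len=8))
assert merged == list(range(8)) and indices == [(0,), (1,), (2,), (3,)]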
if cur_dtype == i4 and new_dtype == f8e4m3fn:
# The algorithm here is taken from CUTLASS's `NumericArrayConverter`
# specialization for int4 -> f8e4m3, available at
@@ -1871,27 +1907,11 @@ def upcast_to_f8e4m3fn(reg: ir.Value, part: int):
)
new_registers = np.empty_like(self.registers)

def packed_registers() -> Iterable[tuple[Sequence[int], ir.Value]]:
"""Tries to pack registers into groups of 16 bits if vector_len < 4."""
generator = np.ndenumerate(self.registers)
indices = []
regs = []
while True:
try:
for _ in range(max(4 // vector_len, 1)):
idx, reg = next(generator)
indices.append(cast(int, idx))
regs.append(reg)
yield indices, utils.vector_concat(regs)
regs.clear()
indices.clear()
except StopIteration:
break
if regs:
yield indices, utils.vector_concat(regs)

for indices, reg in packed_registers():
group_size = ir.VectorType(reg.type).shape[0]
# TODO(apaszke,bchetioui): Using 8 helps some (but not all) cases.
# TODO(apaszke,bchetioui): Add the slice optimization here.
packing_width = 8 if vector_len == 2 else 4
for indices, reg in packed_registers(packing_width, if_not_sliced=False):
[group_size] = ir.VectorType(reg.type).shape
assert group_size % vector_len == 0
int_ty = ir.IntegerType.get_signless(group_size * 4)
reg_as_i32 = utils.bitcast(reg, int_ty)
@@ -1926,7 +1946,11 @@ def packed_registers() -> Iterable[tuple[Sequence[int], ir.Value]]:
if cur_dtype == i4 and self.is_signed and new_dtype == bf16 and vector_len % 2 == 0:
new_registers = np.empty_like(self.registers)
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
for idx, reg in np.ndenumerate(self.registers):
# We use packed_registers for consistency, even though the packing is not
# really profitable here: the PTX below begins with an op dependent on the
# extracted part and so there are no ops that can be shared across packed
# parts.
for indices, reg in packed_registers(2, if_not_sliced=True):
# The algorithm here is largely the same as CUTLASS's
# NumericArrayConverter specialization for int4 -> bf16 casts.
# We modify it slightly, because we only extract 2 values.
@@ -1942,7 +1966,7 @@ def packed_registers() -> Iterable[tuple[Sequence[int], ir.Value]]:
# bias coming from flipping the sign bit which is 136 (0x4308 as bits).
def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
assert 0 <= part < 4
return llvm.inline_asm(
int_reg = llvm.inline_asm(
i32,
[reg, reg_shr],
f"""
@@ -1956,49 +1980,50 @@ def upcast_i4_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
""",
"=r,r,r",
)
return utils.bitcast(int_reg, ir.VectorType.get((2,), bf16))
[group_size] = ir.VectorType(reg.type).shape
assert group_size % vector_len == 0
assert group_size * 4 <= 32
int_ty = ir.IntegerType.get_signless(group_size * 4)
# If the vector originates from a slice (common after relayouts), we
# can fuse the slicing into the conversion and prevent LLVM from
# generating a bunch of shifts to align the vector data to the LSB.
# This also lets us share the right shift among more vectors.
out_int_regs = []
if regs_from_32bit_slice:
slice_op = reg.owner.opview
slice_offset = slice_op.offsets[0].value
reg_int = utils.bitcast(slice_op.source, i32)
reg_int_shr = arith.shrui(reg_int, c(4, i32))
assert slice_offset % 2 == 0
out_int_regs.extend(
upcast_i4_to_bf16(reg_int, reg_int_shr, part=slice_offset // 2 + part)
for part in range(group_size // 2)
)
else:
reg_slice_int = utils.bitcast(reg, int_ty)
if int_ty != i32:
reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
out_int_regs.extend(
upcast_i4_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
for part in range(group_size // 2)
)
out_reg = utils.vector_concat(out_int_regs)
offset = 0
out_int_regs: list[ir.Value] = []
for group_size in (8, 4, 2):
int_ty = ir.IntegerType.get_signless(group_size * 4)
while vector_len - offset >= group_size:
# If the vector originates from a slice (common after relayouts), we
# can fuse the slicing into the conversion and prevent LLVM from
# generating a bunch of shifts to align the vector data to the LSB.
# This also lets us share the right shift among more vectors.
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
and utils.bitwidth(slice_op.source.type) == 32
and slice_op.strides[0].value == 1):
slice_offset = slice_op.offsets[0].value + offset
reg_int = utils.bitcast(slice_op.source, i32)
reg_int_shr = arith.shrui(reg_int, c(4, i32))
out_int_regs.extend(
upcast_i4_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
for part in range(group_size // 2)
)
else:
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
reg_slice_int = utils.bitcast(reg_slice, int_ty)
if int_ty != i32:
reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
out_int_regs.extend(
upcast_i4_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
for part in range(group_size // 2)
)
offset += group_size
assert offset == vector_len
out_vec_int = utils.vector_concat([
vector.broadcast(ir.VectorType.get((1,), i32), reg)
for reg in out_int_regs
])
new_registers[idx] = utils.bitcast(out_vec_int, out_vec_ty)
for idx in indices:
new_registers[idx] = new_reg = utils.vector_slice(
out_reg, slice(offset, offset + vector_len)
)
offset += vector_len
assert new_reg.type == out_vec_ty
return FragmentedArray(
_registers=new_registers, _layout=self.layout, _is_signed=None
)
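For reference, the 136 bias mentioned in the comment above (0x4308 as bits) can be modeled in plain numpy. This is a scalar model of the trick rather than a transcription of the PTX; it assumes the ml_dtypes package (already a JAX dependency) for a bfloat16 numpy dtype:

import numpy as np
import ml_dtypes  # assumed available; provides the bfloat16 numpy dtype

def upcast_i4_to_bf16_reference(nibble_bits: np.ndarray) -> np.ndarray:
  """Models the int4 -> bf16 trick: splice the nibble into bf16(128.0), then debias."""
  # Flipping the sign bit maps the signed nibble v in [-8, 7] to v + 8 in [0, 15].
  biased = (nibble_bits.astype(np.uint16) & 0xF) ^ 0x8
  # Splicing that into the low mantissa bits of 0x4300 (bf16 for 128.0) yields
  # the value 128 + (v + 8) = v + 136.
  as_bf16 = (np.uint16(0x4300) | biased).view(ml_dtypes.bfloat16)
  # Subtracting the bias of 136 (0x4308 as bits) recovers v exactly.
  return as_bf16 - ml_dtypes.bfloat16(136)

nibble_bits = np.arange(16, dtype=np.uint16)             # every possible nibble
signed_vals = (nibble_bits ^ 0x8).astype(np.int16) - 8   # their signed int4 values
np.testing.assert_array_equal(
    upcast_i4_to_bf16_reference(nibble_bits).astype(np.int16), signed_vals)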
if cur_dtype == i4 and self.is_signed and new_dtype == i8 and is_signed:
new_registers = np.empty_like(self.registers)
out_vec_ty = ir.VectorType.get((vector_len,), new_dtype)
for idx, reg in np.ndenumerate(self.registers):
for indices, reg in packed_registers(8, if_not_sliced=True):
def upcast_i4_to_i8(reg: ir.Value, first_valid_nibble: int = 0):
# When first_valid_nibble is >0, then only the nibbles in the range
# [first_valid_nibble, 8) will be upcast and placed in the low
@@ -2035,31 +2060,29 @@ def upcast_i4_to_i8(reg: ir.Value, first_valid_nibble: int = 0):
utils.bitcast(llvm.extractvalue(i32, out_struct, (i,)), i8_vec)
for i in range(2)
])
[group_size] = ir.VectorType(reg.type).shape
assert group_size % vector_len == 0
assert group_size * 4 <= 32
int_ty = ir.IntegerType.get_signless(group_size * 4)
if regs_from_32bit_slice:
slice_op = reg.owner.opview
slice_offset = slice_op.offsets[0].value
reg_int = utils.bitcast(slice_op.source, i32)
reg_i8 = upcast_i4_to_i8(reg_int, first_valid_nibble=slice_offset)
else:
reg_slice_int = utils.bitcast(reg, int_ty)
if int_ty != i32:
reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_i8 = upcast_i4_to_i8(reg_slice_int)

# Distribute packed registers to their original indices.
offset = 0
out_regs: list[ir.Value] = []
for group_size in (8, 4, 2):
int_ty = ir.IntegerType.get_signless(group_size * 4)
while vector_len - offset >= group_size:
# If the vector originates from a slice (common after relayouts), we
# can fuse the slicing into the conversion and reuse many
# preprocessing ops (shifts, prmts) across different vectors.
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
and utils.bitwidth(slice_op.source.type) == 32
and slice_op.strides[0].value == 1):
slice_offset = slice_op.offsets[0].value + offset
reg_int = utils.bitcast(slice_op.source, i32)
reg_i8 = upcast_i4_to_i8(reg_int, first_valid_nibble=slice_offset)
else:
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
reg_slice_int = utils.bitcast(reg_slice, int_ty)
if int_ty != i32:
reg_slice_int = arith.extsi(i32, reg_slice_int)
reg_i8 = upcast_i4_to_i8(reg_slice_int)
out_regs.append(utils.vector_slice(reg_i8, slice(group_size)))
offset += group_size
assert offset == vector_len
new_registers[idx] = new_reg = utils.vector_concat(out_regs)
assert new_reg.type == out_vec_ty
for idx in indices:
new_registers[idx] = new_reg = utils.vector_slice(
reg_i8, slice(offset, offset + vector_len)
)
offset += vector_len
assert new_reg.type == out_vec_ty
return FragmentedArray(
_registers=new_registers, _layout=self.layout, _is_signed=is_signed
)
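Similarly, a minimal numpy model of what upcast_i4_to_i8 computes per 32-bit register. It is a semantic reference only, not the prmt-based PTX; first_valid_nibble mirrors the parameter documented above:

import numpy as np

def upcast_i4_to_i8_reference(packed: int, first_valid_nibble: int = 0) -> np.ndarray:
  """Sign-extends the nibbles [first_valid_nibble, 8) of a 32-bit word to int8."""
  shifts = 4 * np.arange(first_valid_nibble, 8, dtype=np.uint32)
  nibbles = (np.uint32(packed) >> shifts) & 0xF
  # Two's-complement sign extension of each 4-bit value to 8 bits.
  return ((nibbles ^ 0x8).astype(np.int16) - 8).astype(np.int8)

# 0xF0 packs the nibbles 0x0 and 0xF, i.e. 0 and -1 as signed int4.
np.testing.assert_array_equal(
    upcast_i4_to_i8_reference(0x0000_00F0)[:2], np.array([0, -1], dtype=np.int8))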
31 changes: 13 additions & 18 deletions jax/experimental/mosaic/gpu/utils.py
@@ -1794,34 +1794,29 @@ def vector_slice(v: ir.Value, s: slice):


def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
index = ir.IndexType.get()
if not vectors:
raise ValueError("Cannot concatenate an empty list of vectors")
vty = vectors[0].type
if not ir.VectorType.isinstance(vty):
raise ValueError("Cannot concatenate non-vector values")
vty = ir.VectorType(vty)
if vty.rank != 1:
raise NotImplementedError("Only 1D vectors are supported")
for v in vectors:
if v.type != vty:
raise ValueError("Cannot concatenate vectors of different types")
result = llvm.mlir_undef(
ir.VectorType.get((vty.shape[0] * len(vectors),), vty.element_type)
)
offset = 0
for v in vectors:
for i in range(vty.shape[0]):
elem = vector.extract(
v, dynamic_position=[], static_position=ir.DenseI64ArrayAttr.get([i])
)
result = vector.insert(
elem,
result,
dynamic_position=[],
static_position=ir.DenseI64ArrayAttr.get([offset + i]),
)
offset += vty.shape[0]
return result
return _vector_concat_rec(vectors)


def _vector_concat_rec(vectors: Sequence[ir.Value]) -> ir.Value:
if len(vectors) == 1:
return vectors[0]
elif len(vectors) == 2:
[vec_len] = ir.VectorType(vectors[0].type).shape
return vector.shuffle(*vectors, mask=ir.DenseI64ArrayAttr.get(list(range(2 * vec_len))))
l = _vector_concat_rec(vectors[: len(vectors) // 2])
r = _vector_concat_rec(vectors[len(vectors) // 2 :])
return _vector_concat_rec([l, r])
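
The rewrite replaces the old element-by-element extract/insert loop with a balanced tree of vector.shuffle ops, one shuffle per pairwise merge. A pure-Python analogue of the recursion, with lists standing in for vectors and list concatenation standing in for the identity-mask shuffle (the sketch assumes a power-of-two number of equal-length inputs):

def concat_rec(chunks: list[list[int]]) -> list[int]:
  """Pure-Python analogue of _vector_concat_rec."""
  if len(chunks) == 1:
    return chunks[0]
  if len(chunks) == 2:
    # Models vector.shuffle with the identity mask [0, ..., 2 * vec_len - 1].
    return chunks[0] + chunks[1]
  left = concat_rec(chunks[: len(chunks) // 2])
  right = concat_rec(chunks[len(chunks) // 2 :])
  return concat_rec([left, right])

assert concat_rec([[0, 1], [2, 3], [4, 5], [6, 7]]) == list(range(8))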


def is_known_divisible(value, divisor, max_depth=10) -> bool:
54 changes: 38 additions & 16 deletions tests/mosaic/gpu_test.py
@@ -663,17 +663,25 @@ def kernel(ctx, inp, out, smem):
jax_dtype_from_to=(
(jnp.int8, jnp.bfloat16),
(jnp.int4, jnp.bfloat16),
(jnp.int4, jnp.float8_e4m3fn),
(jnp.int4, jnp.int8),
# TODO(apaszke,bchetioui): bf16/f32 -> f8e4m3fn
),
layout=(
fa.WGMMA_LAYOUT,
fa.WGMMA_LAYOUT_UPCAST_2X,
fa.WGMMA_LAYOUT_UPCAST_4X,
layout_descs=(
("WGMMA_LAYOUT", "WGMMA_LAYOUT"),
("WGMMA_LAYOUT_8BIT", "WGMMA_LAYOUT_8BIT"),
("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT_UPCAST_2X"),
("WGMMA_LAYOUT_UPCAST_2X", "WGMMA_LAYOUT"),
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_4X"),
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT_UPCAST_2X"),
("WGMMA_LAYOUT_UPCAST_4X", "WGMMA_LAYOUT"),
),
change_layout=(False, True),
)
@jtu.skip_if_mosaic_gpu_exceeds_shared_memory(device_patterns="RTX PRO 6000 Blackwell")
def test_optimized_conversion(self, jax_dtype_from_to, layout, change_layout):
def test_optimized_conversion(self, jax_dtype_from_to, layout_descs):
layout_desc_from, layout_desc_to = layout_descs
layout_from: fa.TiledLayout = getattr(fa, layout_desc_from)
layout_to: fa.TiledLayout = getattr(fa, layout_desc_to)
jax_dtype_from, jax_dtype_to = jax_dtype_from_to
mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from)
mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to)
@@ -684,21 +692,19 @@ def kernel(ctx, inp, out, smem):
t = mgpu.FragmentedArray.load_untiled(
inp,
is_signed=utils.is_signed(jax_dtype_from),
layout=layout,
layout=layout_from,
optimized=False,
)
if change_layout:
if layout_from != layout_to:
if (
layout == fa.WGMMA_LAYOUT_UPCAST_4X
and utils.bitwidth(mlir_dtype_from) > 4
layout_from == fa.WGMMA_LAYOUT_UPCAST_4X
and utils.bitwidth(mlir_dtype_from) != 4
):
self.skipTest("Unimplemented relayout")
t = t.to_layout(fa.WGMMA_LAYOUT)
t = t.to_layout(layout_to)
t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to))
t.store_untiled(out, optimized=False)

# We only test lossless conversions for now.
# TODO(apaszke): Test and fix failures that appear with lossy conversions.
int_sample_dtype = getattr(
jnp,
"int" + str(min(bitwidth(mlir_dtype_from), bitwidth(mlir_dtype_to))),
@@ -709,9 +715,25 @@ def kernel(ctx, inp, out, smem):
).astype(jax_dtype_from)

expected = values.astype(np.int32).astype(jax_dtype_to)
res = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), values, expected, ()
)(values)
@contextlib.contextmanager
def _maybe_profile():
yield; return  # Comment this line out to gather statistics.
with jtu.set_env(MOSAIC_GPU_DUMP_SASS="1"), self.capture_stdout() as sass:
yield
log_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp")
file_path = os.path.join(log_dir, "conversion_stats.csv")
with open(file_path, "a") as f:
data = (
jnp.dtype(jax_dtype_from).name, jnp.dtype(jax_dtype_to).name,
layout_desc_from, layout_desc_to, sass().count("\n")
)
f.write(",".join(map(str, data)) + "\n")
f.flush()
self.fail("Disable profiling before submission")
with _maybe_profile():
res = mgpu.as_gpu_kernel(
kernel, (1, 1, 1), (128, 1, 1), values, expected, ()
)(values)
np.testing.assert_array_equal(res, expected)
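
When the early yield; return in _maybe_profile is commented out, each test case appends one row to conversion_stats.csv. A hypothetical helper for aggregating those rows afterwards (the path and column order come from the test above; the helper itself is not part of this change):

import collections
import csv

def summarize_conversion_stats(path: str = "/tmp/conversion_stats.csv") -> None:
  """Prints the mean SASS line count per (dtype pair, layout pair)."""
  samples = collections.defaultdict(list)
  with open(path) as f:
    for src, dst, layout_from, layout_to, sass_lines in csv.reader(f):
      samples[(src, dst, layout_from, layout_to)].append(int(sass_lines))
  for key, lines in sorted(samples.items()):
    print(*key, sum(lines) / len(lines), sep=",")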

@parameterized.named_parameters(