-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Description
Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
I have a function that looks like this:
If I use the first layout (quant_fp4) for mV (set up as in examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py):
if quant_v:
ab_dtype = cutlass.Float4E2M1FN
v_tensor, v_torch_underlying = cutlass_torch.cute_tensor_like(
v_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
)
# Get the correct stride_order from the reference tensor
v_stride_order = tuple(v_ref.dim_order())
v_tensor.mark_compact_shape_dynamic(
mode=1, # headdim_v dimension needs divisibility for FP4
stride_order=v_stride_order,
divisibility=32,
)
v_tensor = cutlass_torch.convert_cute_tensor(
v_ref, v_tensor, ab_dtype, is_dynamic_layout=True
)
else:
# V stays as regular dtype (not FP4 quantized) - create CUTE tensor
# Convert torch dtype to CUTE dtype
assert dtype_gen in [torch.bfloat16, torch.float16]
if dtype_gen == torch.bfloat16:
v_cute_dtype = cutlass.BFloat16
elif dtype_gen == torch.float16:
v_cute_dtype = cutlass.Float16
v_tensor, v_torch_underlying = cutlass_torch.cute_tensor_like(
v_ref, v_cute_dtype, is_dynamic_layout=True, assumed_align=16
)
# Get the correct stride_order from the reference tensor
v_stride_order = tuple(v_ref.dim_order())
v_tensor.mark_compact_shape_dynamic(
mode=1, # headdim_v dimension
stride_order=v_stride_order,
divisibility=16,
)
v_tensor = cutlass_torch.convert_cute_tensor(
v_ref, v_tensor, v_cute_dtype, is_dynamic_layout=True
)
It immediately throws an error when calling cute.compile, while the string representation of mV is being built: the failure is inside class ScaledBasis, where the scale is printed as "?", unlike in the previous args. This is really confusing, and I would love any help with it.
File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/interface.py", line 702, in _flash_attn_fwd
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/cute_dsl_utils.py", line 118, in cute_compile_patched
output = cute_compile_og(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 578, in __call__
return self._compile(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 653, in _compile
return func._dsl_object._func(func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1572, in _func
result = self.generate_mlir(
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1373, in generate_mlir
module, module_hash, result = self.generate_original_ir(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1193, in generate_original_ir
module, result = profiler(build_ir_module)()
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/utils/timer.py", line 29, in func_wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1175, in build_ir_module
result = funcBody(*ir_args, **ir_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/flash_fwd_sm100_fp4.py", line 990, in __call__
).launch(
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cutlass_dsl/cutlass.py", line 1012, in launch
ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1780, in kernel_wrapper
kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cutlass_dsl/cutlass.py", line 773, in mangle_name
return super().mangle_name(function_name, args, args_spec)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 577, in mangle_name
class_name = str(arg).replace("class", "")
^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/tensor.py", line 158, in __str__
return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>"
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1346, in pretty_str
return arg.__str__()
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 803, in __str__
return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}"
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1340, in pretty_str
return _tuple_str(arg)
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1330, in _tuple_str
res = "(" + construct_inner_str(t) + ")"
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1325, in construct_inner_str
res += pretty_str(t[i])
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1346, in pretty_str
return arg.__str__()
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 590, in __str__
return f"{self.to(_ScaledBasis).__str__()}"
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py", line 60, in wrapper
res_or_list = opFunc(*args, **kwargs, loc=loc)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 584, in to
ret = _ScaledBasis(scale, self._mode)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: __init__(): incompatible function arguments. The following argument types are supported:
1. __init__(self, value: int, mode: collections.abc.Sequence[int]) -> None
2. __init__(self, value: cutlass._mlir._mlir_libs._cutlass_ir._cute.Ratio, mode: collections.abc.Sequence[int]) -> None
3. __init__(self, value: object, mode: collections.abc.Sequence[int], divisibility: int) -> None
Invoked with types: cutlass._mlir._mlir_libs._cutlass_ir._cute.ScaledBasis, cutlass.base_dsl._mlir_helpers.arith.ArithValue, list
Steps/Code to reproduce bug
Follow this guide http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports to craft a minimal bug report. This helps us reproduce the issue you're having and resolve the issue more quickly.
Expected behavior
A clear and concise description of what you expected to happen.
Environment details (please complete the following information):
- CuTe DSL 4.3.5
Additional context
Add any other context about the problem here.
cc @kongroo