[BUG] Incompatible function arguments in _ScaledBasis when parsing arg #2967

@Edenzzzz

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
I have a function that looks like this:

[Image not reproduced]

If I use the first layout (the quant_fp4 branch) for mV, set up as in examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py:

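  import torch
  import cutlass
  import cutlass.torch as cutlass_torch

  # quant_v, v_ref and dtype_gen come from the enclosing scope
  if quant_v: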
      ab_dtype = cutlass.Float4E2M1FN
      v_tensor, v_torch_underlying = cutlass_torch.cute_tensor_like(
          v_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16
      )
      # Get the correct stride_order from the reference tensor
      v_stride_order = tuple(v_ref.dim_order())
      v_tensor.mark_compact_shape_dynamic(
          mode=1,  # headdim_v dimension; 32 FP4 elements = 16 bytes
          stride_order=v_stride_order,
          divisibility=32,
      )
      v_tensor = cutlass_torch.convert_cute_tensor(
          v_ref, v_tensor, ab_dtype, is_dynamic_layout=True
      )
  else:
      # V stays as regular dtype (not FP4 quantized) - create CUTE tensor
      # Convert torch dtype to CUTE dtype
      assert dtype_gen in [torch.bfloat16, torch.float16]
      if dtype_gen == torch.bfloat16:
          v_cute_dtype = cutlass.BFloat16
      elif dtype_gen == torch.float16:
          v_cute_dtype = cutlass.Float16

      v_tensor, v_torch_underlying = cutlass_torch.cute_tensor_like(
          v_ref, v_cute_dtype, is_dynamic_layout=True, assumed_align=16
      )
      # Get the correct stride_order from the reference tensor
      v_stride_order = tuple(v_ref.dim_order())
      v_tensor.mark_compact_shape_dynamic(
          mode=1,  # headdim_v dimension
          stride_order=v_stride_order,
          divisibility=16,
      )
      v_tensor = cutlass_torch.convert_cute_tensor(
          v_ref, v_tensor, v_cute_dtype, is_dynamic_layout=True
      )
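The two branches differ only in element type and the `divisibility` hint. A minimal sketch of the arithmetic behind those values, assuming the intent is to keep the dynamic `headdim_v` extent a multiple of the 16-byte `assumed_align`: `Float4E2M1FN` is 4 bits wide, so 32 elements fill 16 bytes, while a 16-bit type such as `BFloat16` needs only 8 elements (the `divisibility=16` used above for BF16/FP16 is therefore a stricter, 32-byte guarantee). The helper `min_divisibility` is hypothetical and not part of CuTe DSL.

```python
def min_divisibility(element_bits: int, align_bytes: int = 16) -> int:
    """Smallest element count whose total size is a multiple of align_bytes.

    Hypothetical helper (not a CuTe DSL API). Assumes element_bits divides
    align_bytes * 8 evenly, which holds for 4-, 8-, 16- and 32-bit types.
    """
    align_bits = align_bytes * 8
    if align_bits % element_bits != 0:
        raise ValueError("element size must evenly divide the alignment")
    return align_bits // element_bits


# FP4 (Float4E2M1FN): 4 bits per element -> 32 elements per 16 bytes
print(min_divisibility(4))   # 32
# BF16 / FP16: 16 bits per element -> 8 elements per 16 bytes
print(min_divisibility(16))  # 8
```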

it throws an error immediately when `cute.compile` converts mV to a string (during kernel-name mangling, per the traceback). The failure happens inside `ScaledBasis`, where the scale prints as `?` (a dynamic value), unlike in the previous arguments. This is really confusing, and I would appreciate any help.

  File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/interface.py", line 702, in _flash_attn_fwd
    _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi",
                                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/cute_dsl_utils.py", line 118, in cute_compile_patched
    output = cute_compile_og(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 578, in __call__
    return self._compile(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 653, in _compile
    return func._dsl_object._func(func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1572, in _func
    result = self.generate_mlir(
             ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1373, in generate_mlir
    module, module_hash, result = self.generate_original_ir(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1193, in generate_original_ir
    module, result = profiler(build_ir_module)()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/utils/timer.py", line 29, in func_wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1175, in build_ir_module
    result = funcBody(*ir_args, **ir_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/cutlass/examples/python/CuTeDSL/blackwell/flash-attention/flash_attn/cute/flash_fwd_sm100_fp4.py", line 990, in __call__
    ).launch(
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cutlass_dsl/cutlass.py", line 1012, in launch
    ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 1780, in kernel_wrapper
    kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}"
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cutlass_dsl/cutlass.py", line 773, in mangle_name
    return super().mangle_name(function_name, args, args_spec)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 577, in mangle_name
    class_name = str(arg).replace("class", "")
                 ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/tensor.py", line 158, in __str__
    return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>"
                                                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1346, in pretty_str
    return arg.__str__()
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 803, in __str__
    return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}"
                                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1340, in pretty_str
    return _tuple_str(arg)
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1330, in _tuple_str
    res = "(" + construct_inner_str(t) + ")"
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1325, in construct_inner_str
    res += pretty_str(t[i])
           ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 1346, in pretty_str
    return arg.__str__()
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 590, in __str__
    return f"{self.to(_ScaledBasis).__str__()}"
              ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py", line 60, in wrapper
    res_or_list = opFunc(*args, **kwargs, loc=loc)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/cute/core.py", line 584, in to
    ret = _ScaledBasis(scale, self._mode)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: __init__(): incompatible function arguments. The following argument types are supported:
    1. __init__(self, value: int, mode: collections.abc.Sequence[int]) -> None
    2. __init__(self, value: cutlass._mlir._mlir_libs._cutlass_ir._cute.Ratio, mode: collections.abc.Sequence[int]) -> None
    3. __init__(self, value: object, mode: collections.abc.Sequence[int], divisibility: int) -> None

Invoked with types: cutlass._mlir._mlir_libs._cutlass_ir._cute.ScaledBasis, cutlass.base_dsl._mlir_helpers.arith.ArithValue, list
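Reading the last frame: `to()` calls `_ScaledBasis(scale, self._mode)` with `scale` as an `ArithValue`, i.e. a run-time value, which is what prints as `?`. Of the three listed overloads, the `int` and `Ratio` ones take only static scales, and the `object` overload (presumably the one meant for dynamic values) also requires a `divisibility` argument that this call does not pass, so dispatch fails. The pure-Python model below mimics that resolution; `Ratio`, `DynamicValue`, and `resolve_overload` are hypothetical stand-ins, not CuTe DSL classes.

```python
from collections.abc import Sequence
from typing import Optional


class Ratio:
    """Stand-in for the static rational scale accepted by overload 2."""

    def __init__(self, num: int, den: int) -> None:
        self.num = num
        self.den = den


class DynamicValue:
    """Stand-in for an ArithValue: a scale known only at run time ('?')."""


def resolve_overload(value: object, mode: Sequence[int],
                     divisibility: Optional[int] = None) -> int:
    """Return which of the three listed __init__ overloads accepts the call."""
    if not isinstance(mode, Sequence):
        raise TypeError("mode must be a sequence of ints")
    if divisibility is not None:
        return 3  # (value: object, mode, divisibility: int)
    if isinstance(value, int):
        return 1  # (value: int, mode)
    if isinstance(value, Ratio):
        return 2  # (value: Ratio, mode)
    raise TypeError("__init__(): incompatible function arguments")


print(resolve_overload(2, [1]))                                # 1
print(resolve_overload(DynamicValue(), [1], divisibility=16))  # 3
try:
    # Mirrors _ScaledBasis(scale, self._mode) with a dynamic scale
    resolve_overload(DynamicValue(), [1])
except TypeError as err:
    print(err)
```

In this model, the failing call corresponds to a dynamic scale passed without `divisibility`.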

Environment details (please complete the following information):

  • CuTe DSL 4.3.5

Additional context
cc @kongroo
