
Conversation

@bkryu bkryu (Collaborator) commented Oct 25, 2025

📌 Description

DRAFT. Please do not merge.

Current PR:

  • Introduces an auto backend for mm_fp4 that can be autotuned; it replaces cudnn as the default.
  • Allows the cudnn backend to be autotuned.

Behavior of auto backend:

  • Examines the CUDA and cuDNN versions and dispatches to either the cutlass or cudnn kernel backend.
    • The trtllm kernel is not considered because its interface is not interchangeable with the cutlass and cudnn backends.
    • **The auto backend therefore only supports inputs runnable by cutlass and/or cudnn.**
  • Non-autotuned behavior:
    • Constructs an ordered list of backends, (cudnn, cutlass) or (cutlass, cudnn), where the ordering is based on previous microbenchmark study results (see the sketch after this list):
      • If CUDA 12 --> cutlass comes first.
      • If CUDA 13 and cuDNN version < 9.14 --> cutlass comes first.
      • If CUDA 13 and cuDNN version >= 9.14 --> cudnn comes first.
    • If a backend fails its support check, it is removed from the list.
      • For example, if use_nvfp4=False, cutlass is removed from the backend list because it fails the support check.
  • Autotune behavior:
    • If backend='trtllm', 'cutlass', or 'cudnn' --> autotunes within that backend. Same as the previous behavior, except autotuning is now also supported for cudnn.
    • If backend='auto' --> autotunes within and across backends (cudnn & cutlass) and chooses the best config of the best backend.
      • The trtllm kernel is not considered because its interface is not interchangeable with the cutlass and cudnn backends.
  • Note: many of mm_fp4's helper functions were refactored to enable cross-backend autotuning, using the cross-backend autotune-enabled bmm_fp8 as a reference.
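
A minimal sketch of the ordering-plus-filtering logic described above (illustrative only: the helper name and the supports predicate are placeholders, not the actual implementation):

def order_and_filter_backends(cuda_major, cudnn_version, supports):
    # Ordering heuristic from the microbenchmark study: CUDA 13 with
    # cuDNN >= 9.14 prefers cudnn first; everything else prefers cutlass.
    if cuda_major >= 13 and cudnn_version >= (9, 14):
        candidates = ("cudnn", "cutlass")
    else:
        candidates = ("cutlass", "cudnn")
    # Drop any backend whose support check fails for the given inputs,
    # e.g. cutlass when use_nvfp4=False.
    return [b for b in candidates if supports(b)]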

🔍 Related Issues

#1722

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Automatic backend selection for FP4 matrix multiplication via backend="auto".
    • CUDA version detection utility for device compatibility.
    • Tactic selection support for cuDNN FP4 operations.
  • Refactor

    • Restructured FP4 GEMM runner components and runner exports for modularity.
  • Behavior Changes

    • FP4 backend availability is now determined at runtime via support checks (no static backend list).
  • Tests

    • Tests reorganized and new auto-backend test added for FP4.

coderabbitai bot (Contributor) commented Oct 25, 2025

Walkthrough

Removes the static mm_fp4 compute-capability → backend mapping in benchmarks; adds "auto" backend support and runtime backend selection/autotuning for FP4 GEMM in flashinfer/gemm.py; updates runner names and cuDNN tactic plumbing; extends tests to exercise backend="auto"; and refactors test entry points.

Changes

  • Benchmark Backend Mapping (benchmarks/routines/flashinfer_benchmark_utils.py): Removes the static mm_fp4 entry from routine_cc_to_supported_backends (the previous cc-to-backend map). Adds a comment that mm_fp4 uses runtime support checkers instead of a static list.
  • Benchmark Test Framework (benchmarks/routines/gemm.py): Extends parse_gemm_args to accept "auto" for backends; updates mm_fp4-related test logic to accept and validate "auto"; replaces static cc-based filtering with runtime backend validation and dynamic backend removal.
  • Core GEMM Library (flashinfer/gemm.py): Adds get_cuda_version(device); renames/exports runner accessors as *_runner variants (e.g., cutlass_fp4_gemm_runner, trtllm_fp4_gemm_runner); introduces tactic-aware cuDNN FP4 plan/execution APIs; implements backend="auto" support for mm_fp4 with _auto_gemm_fp4_requirement, multi-runner construction, autotuning, and runtime backend selection.
  • Test Suite (tests/gemm/test_mm_fp4.py): Extracts core test logic to _test_mm_fp4, makes test_mm_fp4 a wrapper, and adds test_mm_fp4_backend_auto to exercise backend="auto"; consolidates skip rules and adjusts parameterization/thresholds.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Caller
    participant mm_fp4
    participant SupportCheckers
    participant BackendSelector
    participant RunnerFactory
    participant Runner

    Caller->>mm_fp4: call mm_fp4(backend="auto", ...)
    mm_fp4->>SupportCheckers: query device capability & requirements
    SupportCheckers-->>mm_fp4: feasible backends (list)
    
    rect rgba(120,180,200,0.12)
    Note over mm_fp4,BackendSelector: Auto-backend filtering & validation
    mm_fp4->>BackendSelector: apply layout/size/constraint checks
    BackendSelector-->>mm_fp4: filtered backends
    end

    mm_fp4->>RunnerFactory: construct runners for filtered backends
    RunnerFactory-->>mm_fp4: runners
    mm_fp4->>Runner: autotune/benchmark runners (parallel/iterative)
    Runner-->>mm_fp4: metrics
    mm_fp4->>Runner: pick best runner and execute
    Runner-->>Caller: result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay attention to device capability detection & support-checker logic in flashinfer/gemm.py.
  • Verify correct renaming and export of runner accessors and that callers are updated accordingly.
  • Review autotuning/autoselection flow for race/ordering issues and correct metrics aggregation.
  • Confirm benchmarks/tests correctly handle removed static mapping and that no code still expects the removed mm_fp4 entry.

Possibly related PRs

Suggested reviewers

  • Anerudhan
  • nvmbreughe
  • cyx-6
  • yzh119
  • wenscarl

Poem

🐰 Hopped from cc maps to runtime checks today,

Auto picks runners so workloads may play,
Tactics and CUDA versions checked with care,
Tests now wrap and probe the auto-backend lair,
Hop, tune, run — performance carrots everywhere 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 15.15%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)
  • Title Check ✅ Passed: The pull request title "[wip] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn'" clearly and concisely summarizes the primary changes in the changeset. The title is directly related to the main objectives of the PR, which involve introducing an auto backend option for the mm_fp4 function and enabling autotuning capabilities for the cudnn backend. It is specific, avoids vague language, and accurately reflects the key changes across all modified files (flashinfer/gemm.py, benchmarks/routines/gemm.py, benchmarks/routines/flashinfer_benchmark_utils.py, and tests/gemm/test_mm_fp4.py). The [wip] prefix appropriately signals draft status without detracting from the substantive message.
  • Description Check ✅ Passed: The pull request description comprehensively follows the provided template structure. It includes all required sections: a detailed 📌 Description explaining what the PR does and why (introducing an auto backend and cross-backend autotuning), a 🔍 Related Issues section linking to #1722, and a complete 🚀 Pull Request Checklist with pre-commit checks and tests sections marked as complete. The description provides substantial detail about the auto backend behavior under different conditions (CUDA versions, cuDNN versions, support checks), the autotune logic, and references to the refactoring pattern based on bmm_fp8. The explicit "DRAFT. Please do not merge." notice is clearly communicated, and all required information is present and well-organized.


@bkryu bkryu changed the title feat: Add backend='auto' to mm_fp4 feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu changed the title feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' feat: [DRAFT] Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu changed the title feat: [DRAFT] Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' [wip] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' Oct 25, 2025
@bkryu bkryu self-assigned this Oct 27, 2025
@bkryu bkryu force-pushed the mm_fp4_auto_backend branch 4 times, most recently from bc94c4c to 254827a on October 30, 2025 17:19
@bkryu bkryu marked this pull request as ready for review October 30, 2025 17:26
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/gemm.py (1)

2096-2134: Consider extracting auto-backend selection into a helper function.

The auto-backend selection logic (lines 2096-2134) is complex and involves:

  1. CUDA/cuDNN version inspection
  2. Backend ordering heuristics
  3. Problem size validation
  4. Exception handling for unsupported configurations

This logic could benefit from extraction into a dedicated helper function (e.g., _select_mm_fp4_backends) to improve readability and testability.

Additionally, the bare except Exception at lines 2131-2132 might hide unexpected errors. Consider either:

  1. Catching more specific exceptions (e.g., ValueError, RuntimeError)
  2. Adding logging to track which backends fail validation and why

Example refactoring:

def _select_mm_fp4_backends(
    cuda_major: int,
    cudnn_version: int,
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    out: torch.Tensor,
    block_size: int,
    use_8x4_sf_layout: bool,
    use_nvfp4: bool,
) -> List[str]:
    """Select supported backends for mm_fp4 based on device capabilities."""
    # Backend ordering heuristics
    if cuda_major >= 13 and cudnn_version >= 91400:
        candidate_backends = ("cudnn", "cutlass")
    else:
        candidate_backends = ("cutlass", "cudnn")
    
    # Filter by problem size support
    backends = []
    for candidate in candidate_backends:
        try:
            _check_mm_fp4_problem_size(
                a, b, a_descale, b_descale, alpha, out_dtype,
                out, block_size, use_8x4_sf_layout,
                cast(Literal["cudnn", "trtllm", "cutlass", "auto"], candidate),
                use_nvfp4,
            )
            backends.append(candidate)
        except (ValueError, RuntimeError):
            pass  # Backend not supported for this problem
    
    return backends
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9287c9 and 254827a.

📒 Files selected for processing (4)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • benchmarks/routines/gemm.py (5 hunks)
  • flashinfer/gemm.py (17 hunks)
  • tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
flashinfer/jit/cpp_ext.py (1)
  • get_cuda_version (64-83)
flashinfer/autotuner.py (9)
  • TunableRunner (194-247)
  • get_valid_tactics (196-214)
  • OptimizationProfile (168-183)
  • forward (220-244)
  • AutoTuner (335-784)
  • get (362-365)
  • TuningConfig (101-141)
  • choose_one (400-529)
  • get_opt_shapes (177-183)
flashinfer/trtllm_low_latency_gemm.py (2)
  • get_valid_tactics (52-77)
  • forward (79-109)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • get_compute_capability (251-254)
  • is_compute_capability_supported (966-972)
  • backend_requirement (855-1028)
🪛 Ruff (0.14.2)
flashinfer/gemm.py

96-96: Unused function argument: device

(ARG001)


432-432: Unused method argument: inputs

(ARG002)


433-433: Unused method argument: profile

(ARG002)


441-441: Unused method argument: do_preparation

(ARG002)


442-442: Unused method argument: kwargs

(ARG002)


1722-1722: Unused method argument: profile

(ARG002)


1733-1733: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1736-1736: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1772-1772: Unused method argument: do_preparation

(ARG002)


1773-1773: Unused method argument: kwargs

(ARG002)


1855-1855: Avoid specifying long messages outside the exception class

(TRY003)


1876-1876: Unused function argument: backend

(ARG001)


1934-1934: Unused function argument: backend

(ARG001)


1956-1956: Unused function argument: backend

(ARG001)


1957-1957: Unused function argument: use_nvfp4

(ARG001)


1965-1965: Unused function argument: b

(ARG001)


1966-1966: Unused function argument: a_descale

(ARG001)


1967-1967: Unused function argument: b_descale

(ARG001)


1968-1968: Unused function argument: alpha

(ARG001)


1969-1969: Unused function argument: out_dtype

(ARG001)


1970-1970: Unused function argument: out

(ARG001)


1971-1971: Unused function argument: block_size

(ARG001)


1972-1972: Unused function argument: use_8x4_sf_layout

(ARG001)


1973-1973: Unused function argument: backend

(ARG001)


1974-1974: Unused function argument: use_nvfp4

(ARG001)


2099-2099: Unpacked variable cc_major is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2099-2099: Unpacked variable cc_minor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2131-2132: try-except-pass detected, consider logging the exception

(S110)


2131-2131: Do not catch blind exception: Exception

(BLE001)


2163-2163: Avoid specifying long messages outside the exception class

(TRY003)


2509-2509: Unpacked variable a_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2510-2510: Unpacked variable b_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2511-2511: Unpacked variable alpha is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2513-2513: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2516-2516: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2530-2530: Unused method argument: do_preparation

(ARG002)


2531-2531: Unused method argument: kwargs

(ARG002)

🔇 Additional comments (12)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

241-243: LGTM! Auto backend addition is correct.

The addition of "auto" to the supported backends list for mm_fp4 at compute capabilities 10.0, 10.3, and 12.0 is consistent with the PR objectives and aligns with the auto-backend implementation in flashinfer/gemm.py.

benchmarks/routines/gemm.py (2)

134-134: LGTM! Backend choices updated correctly.

The addition of "auto" to the --backends argument choices is consistent with the auto-backend support introduced in this PR.


793-793: LGTM! Auto backend support properly integrated.

The changes correctly:

  1. Add "auto" to the list of autotune-supported backends for mm_fp4
  2. Implement backend filtering logic for "auto" that respects the use_128x4_sf_layout constraint
  3. Include "auto" in the run_backend execution path

The filtering logic at lines 836-842 appropriately mirrors the filtering done for other backends (cudnn, cutlass) and ensures "auto" is removed when layout constraints aren't met.

Also applies to: 836-842, 899-899

flashinfer/gemm.py (7)

425-465: LGTM! Runner refactoring improves consistency.

The refactoring of CUTLASS FP4 GEMM into cutlass_fp4_gemm_runner with the helper function _create_cutlass_fp4_gemm_module improves naming consistency and aligns with the pattern used for other runners (e.g., trtllm_fp4_gemm_runner).


1270-1294: LGTM! cuDNN tactic support enables fine-grained autotuning.

The addition of tactic parameter to build_plans_cudnn_fp4_gemm_graph and execute_cudnn_gemm_fp4_graph enables plan-specific execution for autotuning. The logic correctly:

  • Builds a specific plan when tactic != -1
  • Builds all plans when tactic == -1 (fallback)
  • Executes the selected plan or uses default execution

This aligns with the autotuning framework's expectations and follows the pattern established by other tunable runners.

Also applies to: 1306-1331
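
The tactic sentinel convention can be captured in a tiny, self-contained sketch (hypothetical helper for illustration; the real code builds and executes cuDNN execution plans rather than returning indices):

def plans_to_build(tactic: int, num_plans: int) -> list:
    # tactic == -1 is the autotuner's "no specific tactic" sentinel:
    # build every plan and let default execution pick one.
    if tactic == -1:
        return list(range(num_plans))
    # Otherwise build (and later execute) only the selected plan.
    return [tactic]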


1665-1802: LGTM! cuDNN FP4 runner properly implements TunableRunner interface.

The new _cudnn_gemm_fp4 and _cudnn_gemm_fp4_runner functions correctly:

  1. Encapsulate cuDNN FP4 GEMM execution with tactic support
  2. Implement the TunableRunner interface with get_valid_tactics and forward methods
  3. Query available execution plans from the cuDNN graph
  4. Support tactic-specific execution for autotuning

The implementation follows the established pattern for tunable runners and integrates well with the autotuning framework.


1962-1997: LGTM! Auto backend requirement validation is well-implemented.

The _auto_gemm_fp4_requirement function correctly validates that the "auto" backend can be used by:

  1. Checking compute capability support for candidate backends (cudnn, cutlass)
  2. Explicitly excluding trtllm due to its different interface (as documented in the PR description)
  3. Returning True if at least one backend is supported

The implementation ensures that "auto" will only be accepted on devices where at least one compatible backend is available.


2136-2163: LGTM! Runner construction logic handles all backend cases correctly.

The runner construction for each backend (cudnn, trtllm, cutlass) correctly:

  1. Creates appropriate runner instances based on backend type
  2. Handles dtype conversions for cutlass backend (uint8 ↔ float8_e4m3fn)
  3. Dispatches to the correct module based on device architecture (SM120 vs SM100/SM103)
  4. Falls through to a clear error for unsupported backends

The logic is well-structured and handles all supported backend configurations.


2165-2217: LGTM! Autotuning integration is well-structured.

The autotuning setup correctly:

  1. Defines dynamic tensor specs for batch size variation (power-of-2 bucketing)
  2. Sets constraint specs to maintain shape relationships
  3. Prepares input tensors in the expected format
  4. Uses AutoTuner.choose_one to select the best (runner, tactic) combination
  5. Executes the chosen runner with the selected tactic

The integration follows the established autotuning framework patterns and enables cross-backend tuning when backend="auto".
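
A condensed sketch of that flow (hedged: the exact AutoTuner and TuningConfig signatures live in flashinfer/autotuner.py and may differ, so the selection call is shown as a comment; the bucketing helper is illustrative):

def next_power_of_two(m: int) -> int:
    # Power-of-2 bucketing of the dynamic batch dimension, so one tuned
    # config serves a whole range of m values.
    return 1 << max(0, (m - 1).bit_length())

# Hypothetical shape of the cross-backend selection step:
# tuner = AutoTuner.get()
# best_runner, best_tactic = tuner.choose_one(
#     "mm_fp4", runners, tuning_config, inputs
# )
# out = best_runner(inputs, tactic=best_tactic)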


2487-2563: LGTM! TRTLLM FP4 runner refactoring enables autotuning.

The refactoring of trtllm_fp4_gemm_runner to:

  1. Accept use_8x4_sf_layout as a parameter
  2. Implement the TunableRunner interface with tactic support
  3. Return a properly configured runner instance

This change aligns the TRTLLM backend with the autotuning framework and maintains consistency with other FP4 runners. The implementation correctly handles the use_8x4_sf_layout parameter throughout the runner lifecycle.

tests/gemm/test_mm_fp4.py (2)

15-95: LGTM! Test refactoring improves maintainability.

Extracting the test logic into _test_mm_fp4 is a good refactoring that:

  1. Eliminates code duplication between test functions
  2. Makes the test logic reusable and easier to maintain
  3. Consolidates backend support checks and skip conditions

The updated skip condition at lines 34-35 correctly limits mxfp4 support to cudnn and auto backends, which aligns with the implementation in flashinfer/gemm.py.


97-127: LGTM! Test split provides good coverage of auto backend.

The split between test_mm_fp4 (non-auto backends) and test_mm_fp4_backend_auto (auto backend) is well-designed:

  1. test_mm_fp4 maintains full parameter coverage for individual backends
  2. test_mm_fp4_backend_auto tests the auto backend with a reduced but representative parameter space
  3. The reduced parameter space (fewer m/n/k combinations, only use_128x4_sf_layout=True) is appropriate for auto backend testing and helps keep test execution time reasonable

This approach provides comprehensive coverage while avoiding combinatorial explosion of test cases.

Comment on lines +96 to +98
def get_cuda_version(device: torch.device):
    return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)

Contributor

⚠️ Potential issue | 🟡 Minor

Unused device parameter in get_cuda_version.

The device parameter is not used in the function body—the function only accesses torch.version.cuda, which is a global property independent of the device.

Consider one of the following:

  1. If the CUDA version should be device-specific, use the device parameter to query it appropriately
  2. If the CUDA version is indeed global, remove the unused parameter
-def get_cuda_version(device: torch.device):
+def get_cuda_version():
     return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)

And update the call site at line 2098:

-        cuda_major, _ = get_cuda_version(a.device)
+        cuda_major, _ = get_cuda_version()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def get_cuda_version(device: torch.device):
-    return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)
+def get_cuda_version():
+    return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)
🧰 Tools
🪛 Ruff (0.14.2)

96-96: Unused function argument: device

(ARG001)

🤖 Prompt for AI Agents
In flashinfer/gemm.py around lines 96-98, the get_cuda_version function declares
an unused device parameter and instead reads the global torch.version.cuda;
remove the unused device parameter from the function signature (and adjust its
type hint to return Tuple[int, int]) and update the call site at line 2098 to
call get_cuda_version() with no arguments; also search and update any other
callers to remove the argument so the function and its uses are consistent.

"[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
)
backends.remove("cutlass")
remove_cutlass = True
Contributor

Another way to avoid these remove_backend_x bools is to call the related backend check (which should be annotated with the decorator), or have the decorator return a filtered list as I proposed. #2000 (comment)

Contributor

Regardless of whether you stuff it into the decorator, this will be a pattern that recurs for all APIs, so we should think about encapsulating the "if backend and checks_dont_pass: filter_it_out" logic.

Collaborator Author

This is a good idea. I have removed these hard-coded checks entirely and have started using the checkers in the latest commit; the sketch below illustrates the pattern.
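
For concreteness, the checker-driven filtering being converged on might look like this sketch (mm_fp4.is_backend_supported is referenced later in this thread; the exact signature used here is an assumption):

# Hypothetical encapsulation of "if backend fails its checks, filter it out",
# replacing per-backend remove_backend_x booleans.
def filter_supported(op, candidates, *args, **kwargs):
    return [
        b for b in candidates
        if op.is_backend_supported(b, *args, **kwargs)  # assumed signature
    ]

# e.g.: backends = filter_supported(mm_fp4, ["cudnn", "cutlass"], a, b, ...)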

)
# Auto-select the best backend
if backend == "auto":
cuda_major, _ = get_cuda_version(a.device)
Contributor

These checks should be part of the _auto_gemm_fp4_requirement check.

I think a cleaner way would be to move the generation of the candidate_backends list into the @backend_requirement decorator, where the "auto" backend is treated specially. The decorator already lists the required checks for each backend. An alternative is to create a separate decorator that composes and uses the backend checks of backend_requirement.

Contributor

The danger here is that we may be repeating some checks, but not all of them.

Collaborator Author

When writing the code path for this PR, I noted that the following questions had to be answered at different times by the auto backend logic:

  1. Is there at least one runnable backend for the given input params -- for early error raising
  2. What are the runnable backends for the given input params -- to consider which backends to choose from
  3. In the current GPU/CUDA/cuDNN environment, what is the preferred ordering of backends -- for heuristics

The current implementation in the PR answers 1 in @backend_requirement and 2 & 3 in the body of mm_fp4, while you're suggesting putting 2 inside @backend_requirement. I agree that this helps us avoid repeating checks, but it will involve, as you raised, special treatment for the auto backend and a change to backend_requirement. We can discuss.

candidate_backends = ("cutlass", "cudnn")

# Filter to only supported backends for this compute capability
# Note: The requirement function already validated that at least one backend is supported
Contributor

So this is the dangerous part: at this point, we know 1 backend replied that its check is ok. But we are considering all backends. Maybe cudnn supports it but not trtllm or cutlass.

Collaborator Author

You are correct here. In the latest commit, I now check whether each backend is supported both generally and for the given inputs.

Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
)
try:
_check_mm_fp4_problem_size(
Contributor

we shouldn't have to do this here.

Collaborator Author

I'm finding that we actually do because the result of _check_mm_fp4_problem_size depends on the backend.

With mxfp4 cases, for example, we need to check each backend so that we can keep cudnn and reject all other backends.

for candidate in candidate_backends:
# mypy requires explicit type casting for the backend literal
backend_literal = cast(
Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
Contributor

why is auto added back?

Collaborator Author
@bkryu bkryu Oct 31, 2025

Auto is actually not being added here since the cast() is telling pre-commit tests that backend_literal will be one of ["cudnn", "trtllm", "cutlass", "auto"] while candidate_backends will never contain auto.

However, there is no need for auto to be there, and I can see it being confusing, so I have removed it in the latest commit.

# At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'.
runners = []
for cur_backend in backends:
if cur_backend == "cudnn":
Contributor

Is there a cleaner way than the if-then-else structure? Maybe some dictionary that maps between backend and runner retrieval?

Collaborator Author
@bkryu bkryu Oct 31, 2025

It is a good idea. Updated in the latest commit to use a dictionary; see the sketch below.
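
The dictionary-based dispatch could look roughly like the following sketch (the *_runner factory names appear in the review walkthrough above; the constructor arguments are assumptions):

# Sketch: map backend name -> runner factory, replacing the if/elif chain.
# Constructor arguments are illustrative, not the actual signatures.
runner_factories = {
    "cudnn": lambda: _cudnn_gemm_fp4_runner(),
    "cutlass": lambda: cutlass_fp4_gemm_runner(),
    "trtllm": lambda: trtllm_fp4_gemm_runner(use_8x4_sf_layout),
}
runners = [runner_factories[b]() for b in backends]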

)
elif cur_backend == "cutlass":
if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
a_descale = a_descale.view(torch.uint8)
Contributor

This seems like an implementation detail, and maybe needs to be moved to the cutlass runner itself, just like we do with the cudnn_runner.

Collaborator Author
@bkryu bkryu Oct 31, 2025

Agreed, and this also allows removal of the if-then-else structure above. Updated in the latest commit.

a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
)
# Now we have a list of runners for desired & supported backends.
tuner = AutoTuner.get()
Contributor

I think it is great that we unify the autotuning logic and separate the different runner implementations.

@bkryu bkryu force-pushed the mm_fp4_auto_backend branch from 254827a to c9f3d52 on October 31, 2025 18:09
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 254827a and c9f3d52.

📒 Files selected for processing (4)
  • benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
  • benchmarks/routines/gemm.py (4 hunks)
  • flashinfer/gemm.py (17 hunks)
  • tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (2)
flashinfer/gemm.py (1)
  • _check_mm_fp4_problem_size (1812-1870)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm.py (2)
flashinfer/jit/cpp_ext.py (1)
  • get_cuda_version (64-83)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • get_compute_capability (251-254)
  • is_compute_capability_supported (966-972)
  • backend_requirement (855-1028)
🪛 Ruff (0.14.2)
benchmarks/routines/gemm.py

900-900: Do not catch blind exception: Exception

(BLE001)

flashinfer/gemm.py

96-96: Unused function argument: device

(ARG001)


436-436: Unused method argument: inputs

(ARG002)


437-437: Unused method argument: profile

(ARG002)


445-445: Unused method argument: do_preparation

(ARG002)


446-446: Unused method argument: kwargs

(ARG002)


1730-1730: Unused method argument: profile

(ARG002)


1741-1741: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1744-1744: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1780-1780: Unused method argument: do_preparation

(ARG002)


1781-1781: Unused method argument: kwargs

(ARG002)


1863-1863: Avoid specifying long messages outside the exception class

(TRY003)


1884-1884: Unused function argument: backend

(ARG001)


1942-1942: Unused function argument: backend

(ARG001)


1964-1964: Unused function argument: backend

(ARG001)


1965-1965: Unused function argument: use_nvfp4

(ARG001)


1973-1973: Unused function argument: b

(ARG001)


1974-1974: Unused function argument: a_descale

(ARG001)


1975-1975: Unused function argument: b_descale

(ARG001)


1976-1976: Unused function argument: alpha

(ARG001)


1977-1977: Unused function argument: out_dtype

(ARG001)


1978-1978: Unused function argument: out

(ARG001)


1979-1979: Unused function argument: block_size

(ARG001)


1980-1980: Unused function argument: use_8x4_sf_layout

(ARG001)


1981-1981: Unused function argument: backend

(ARG001)


1982-1982: Unused function argument: use_nvfp4

(ARG001)


2109-2109: Unpacked variable cc_major is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2109-2109: Unpacked variable cc_minor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2153-2154: try-except-pass detected, consider logging the exception

(S110)


2153-2153: Do not catch blind exception: Exception

(BLE001)


2517-2517: Unpacked variable a_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2518-2518: Unpacked variable b_descale is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2519-2519: Unpacked variable alpha is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2521-2521: Unpacked variable out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2524-2524: Unpacked variable workspace_buffer is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


2538-2538: Unused method argument: do_preparation

(ARG002)


2539-2539: Unused method argument: kwargs

(ARG002)

Comment on lines +1989 to +1998
    candidate_backends = ["cudnn", "cutlass", "trtllm"]
    backend_checkers = {
        "cudnn": _cudnn_gemm_fp4_requirement,
        "cutlass": _cutlass_gemm_fp4_requirement,
        # Does not consider trtllm due to different interface.
    }

    for candidate in candidate_backends:
        checker = backend_checkers[candidate]
        if hasattr(
Contributor

⚠️ Potential issue | 🔴 Critical

Fix auto backend support KeyError.

_auto_gemm_fp4_requirement() enumerates candidate_backends = ["cudnn", "cutlass", "trtllm"], but the backend_checkers dict only defines entries for "cudnn" and "cutlass". The first time this helper runs—e.g., via mm_fp4.is_backend_supported("auto", …) in the updated tests/benchmarks—it raises a KeyError when it hits "trtllm", so the new backend="auto" path never runs. We need to either include a checker for trtllm or drop it from candidate_backends. Here's a minimal fix removing the stray entry:

-    candidate_backends = ["cudnn", "cutlass", "trtllm"]
-    backend_checkers = {
-        "cudnn": _cudnn_gemm_fp4_requirement,
-        "cutlass": _cutlass_gemm_fp4_requirement,
-        # Does not consider trtllm due to different interface.
-    }
+    candidate_backends = ["cudnn", "cutlass"]
+    backend_checkers = {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    }
🤖 Prompt for AI Agents
In flashinfer/gemm.py around lines 1989 to 1998, candidate_backends includes
"trtllm" but backend_checkers only defines "cudnn" and "cutlass", causing a
KeyError when iterating; fix by keeping candidate_backends and backend_checkers
in sync—either remove "trtllm" from candidate_backends or add a corresponding
checker entry for "trtllm" (preferably remove the stray "trtllm" entry here to
match the comment and avoid referencing an undefined checker).

@bkryu bkryu marked this pull request as draft October 31, 2025 22:19
