[wip] feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' #1979
Conversation
Walkthrough
Removes the static mm_fp4 compute-capability → backend mapping in benchmarks; adds "auto" backend support and runtime backend selection/autotuning for FP4 GEMM across flashinfer/gemm.py, updates runner names and cuDNN tactic plumbing, extends tests to exercise backend="auto", and refactors test entry points.
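As a quick orientation, here is a hedged usage sketch of the new behavior; the tensor setup is elided, autotune is the context manager from flashinfer/autotuner.py referenced later in this review, and the positional argument list is approximated from this PR's checker signatures rather than the final API.

```python
# Sketch only: assumes quantized operands (a, b, a_descale, b_descale, alpha)
# and out_dtype have been prepared as for any existing mm_fp4 call.
import flashinfer
from flashinfer.autotuner import autotune

with autotune():  # tuning pass: benchmarks candidate (runner, tactic) pairs
    out = flashinfer.mm_fp4(
        a, b, a_descale, b_descale, alpha, out_dtype, backend="auto"
    )

# Subsequent calls reuse the cached best backend/tactic for matching shapes.
out = flashinfer.mm_fp4(
    a, b, a_descale, b_descale, alpha, out_dtype, backend="auto"
)
```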
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant mm_fp4
    participant SupportCheckers
    participant BackendSelector
    participant RunnerFactory
    participant Runner
    Caller->>mm_fp4: call mm_fp4(backend="auto", ...)
    mm_fp4->>SupportCheckers: query device capability & requirements
    SupportCheckers-->>mm_fp4: feasible backends (list)
    rect rgba(120,180,200,0.12)
        Note over mm_fp4,BackendSelector: Auto-backend filtering & validation
        mm_fp4->>BackendSelector: apply layout/size/constraint checks
        BackendSelector-->>mm_fp4: filtered backends
    end
    mm_fp4->>RunnerFactory: construct runners for filtered backends
    RunnerFactory-->>mm_fp4: runners
    mm_fp4->>Runner: autotune/benchmark runners (parallel/iterative)
    Runner-->>mm_fp4: metrics
    mm_fp4->>Runner: pick best runner and execute
    Runner-->>Caller: result
```
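In plain Python, the diagrammed flow amounts to something like the following self-contained sketch; the checkers, runners, and timings are toy stand-ins, not the real flashinfer objects.

```python
from typing import Callable, Dict, List

def select_backends(requested: str, checkers: Dict[str, Callable[[], bool]]) -> List[str]:
    """Return the feasible backends: every backend whose check passes for
    'auto', or just the requested backend after validating its check."""
    if requested == "auto":
        return [name for name, check in checkers.items() if check()]
    if not checkers[requested]():
        raise ValueError(f"backend {requested!r} unsupported for this problem")
    return [requested]

def run_best(runners: Dict[str, Callable[[], float]], backends: List[str]) -> str:
    """Benchmark each candidate runner and return the fastest one's name."""
    timings = {name: runners[name]() for name in backends}  # smaller is better
    return min(timings, key=timings.get)

# Toy run: cutlass fails its check, so "auto" resolves to cudnn.
checkers = {"cudnn": lambda: True, "cutlass": lambda: False}
runners = {"cudnn": lambda: 1.2, "cutlass": lambda: 0.9}
print(run_best(runners, select_backends("auto", checkers)))  # -> cudnn
```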
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Force-pushed from bc94c4c to 254827a
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/gemm.py (1)
2096-2134: Consider extracting auto-backend selection into a helper function. The auto-backend selection logic (lines 2096-2134) is complex and involves:
- CUDA/cuDNN version inspection
- Backend ordering heuristics
- Problem size validation
- Exception handling for unsupported configurations
This logic could benefit from extraction into a dedicated helper function (e.g., _select_mm_fp4_backends) to improve readability and testability. Additionally, the bare except Exception at lines 2131-2132 might hide unexpected errors. Consider either:
- Catching more specific exceptions (e.g., ValueError, RuntimeError)
- Adding logging to track which backends fail validation and why

Example refactoring:

```python
def _select_mm_fp4_backends(
    cuda_major: int,
    cudnn_version: int,
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    out: torch.Tensor,
    block_size: int,
    use_8x4_sf_layout: bool,
    use_nvfp4: bool,
) -> List[str]:
    """Select supported backends for mm_fp4 based on device capabilities."""
    # Backend ordering heuristics
    if cuda_major >= 13 and cudnn_version >= 91400:
        candidate_backends = ("cudnn", "cutlass")
    else:
        candidate_backends = ("cutlass", "cudnn")

    # Filter by problem size support
    backends = []
    for candidate in candidate_backends:
        try:
            _check_mm_fp4_problem_size(
                a,
                b,
                a_descale,
                b_descale,
                alpha,
                out_dtype,
                out,
                block_size,
                use_8x4_sf_layout,
                cast(Literal["cudnn", "trtllm", "cutlass", "auto"], candidate),
                use_nvfp4,
            )
            backends.append(candidate)
        except (ValueError, RuntimeError):
            pass  # Backend not supported for this problem
    return backends
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
- benchmarks/routines/gemm.py (5 hunks)
- flashinfer/gemm.py (17 hunks)
- tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm.py (4)
- flashinfer/jit/cpp_ext.py (1): get_cuda_version (64-83)
- flashinfer/autotuner.py (9): TunableRunner (194-247), get_valid_tactics (196-214), OptimizationProfile (168-183), forward (220-244), AutoTuner (335-784), get (362-365), TuningConfig (101-141), choose_one (400-529), get_opt_shapes (177-183)
- flashinfer/trtllm_low_latency_gemm.py (2): get_valid_tactics (52-77), forward (79-109)
- flashinfer/utils.py (4): supported_compute_capability (772-852), get_compute_capability (251-254), is_compute_capability_supported (966-972), backend_requirement (855-1028)
🪛 Ruff (0.14.2)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
432-432: Unused method argument: inputs
(ARG002)
433-433: Unused method argument: profile
(ARG002)
441-441: Unused method argument: do_preparation
(ARG002)
442-442: Unused method argument: kwargs
(ARG002)
1722-1722: Unused method argument: profile
(ARG002)
1733-1733: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1736-1736: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1772-1772: Unused method argument: do_preparation
(ARG002)
1773-1773: Unused method argument: kwargs
(ARG002)
1855-1855: Avoid specifying long messages outside the exception class
(TRY003)
1876-1876: Unused function argument: backend
(ARG001)
1934-1934: Unused function argument: backend
(ARG001)
1956-1956: Unused function argument: backend
(ARG001)
1957-1957: Unused function argument: use_nvfp4
(ARG001)
1965-1965: Unused function argument: b
(ARG001)
1966-1966: Unused function argument: a_descale
(ARG001)
1967-1967: Unused function argument: b_descale
(ARG001)
1968-1968: Unused function argument: alpha
(ARG001)
1969-1969: Unused function argument: out_dtype
(ARG001)
1970-1970: Unused function argument: out
(ARG001)
1971-1971: Unused function argument: block_size
(ARG001)
1972-1972: Unused function argument: use_8x4_sf_layout
(ARG001)
1973-1973: Unused function argument: backend
(ARG001)
1974-1974: Unused function argument: use_nvfp4
(ARG001)
2099-2099: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2099-2099: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2131-2132: try-except-pass detected, consider logging the exception
(S110)
2131-2131: Do not catch blind exception: Exception
(BLE001)
2163-2163: Avoid specifying long messages outside the exception class
(TRY003)
2509-2509: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2510-2510: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2511-2511: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2513-2513: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2516-2516: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2530-2530: Unused method argument: do_preparation
(ARG002)
2531-2531: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (12)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
241-243: LGTM! Auto backend addition is correct. The addition of "auto" to the supported backends list for mm_fp4 at compute capabilities 10.0, 10.3, and 12.0 is consistent with the PR objectives and aligns with the auto-backend implementation in flashinfer/gemm.py.
benchmarks/routines/gemm.py (2)
134-134: LGTM! Backend choices updated correctly. The addition of "auto" to the --backends argument choices is consistent with the auto-backend support introduced in this PR.
793-793: LGTM! Auto backend support properly integrated. The changes correctly:
- Add "auto" to the list of autotune-supported backends for mm_fp4
- Implement backend filtering logic for "auto" that respects the use_128x4_sf_layout constraint
- Include "auto" in the run_backend execution path

The filtering logic at lines 836-842 appropriately mirrors the filtering done for other backends (cudnn, cutlass) and ensures "auto" is removed when layout constraints aren't met.
Also applies to: 836-842, 899-899
flashinfer/gemm.py (7)
425-465: LGTM! Runner refactoring improves consistency. The refactoring of CUTLASS FP4 GEMM into cutlass_fp4_gemm_runner with the helper function _create_cutlass_fp4_gemm_module improves naming consistency and aligns with the pattern used for other runners (e.g., trtllm_fp4_gemm_runner).
1270-1294: LGTM! cuDNN tactic support enables fine-grained autotuning. The addition of a tactic parameter to build_plans_cudnn_fp4_gemm_graph and execute_cudnn_gemm_fp4_graph enables plan-specific execution for autotuning. The logic correctly:
- Builds a specific plan when tactic != -1
- Builds all plans when tactic == -1 (fallback)
- Executes the selected plan or uses default execution

This aligns with the autotuning framework's expectations and follows the pattern established by other tunable runners.
Also applies to: 1306-1331
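The tactic convention described above can be sketched as follows; _StubGraph, plan_ids, and build_plan are invented stand-ins for the real cuDNN graph API.

```python
class _StubGraph:
    """Stand-in for the cuDNN graph object (methods are invented)."""
    def plan_ids(self):
        return [0, 1, 2]

    def build_plan(self, i):
        return f"plan-{i}"

def build_plans(graph, tactic: int = -1):
    # tactic == -1 is the fallback: build every available execution plan
    # and let the library choose its default at execution time.
    if tactic == -1:
        return [graph.build_plan(i) for i in graph.plan_ids()]
    # Otherwise build only the plan the autotuner selected.
    return [graph.build_plan(tactic)]

print(build_plans(_StubGraph()))     # -> ['plan-0', 'plan-1', 'plan-2']
print(build_plans(_StubGraph(), 1))  # -> ['plan-1']
```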
1665-1802: LGTM! cuDNN FP4 runner properly implements TunableRunner interface. The new _cudnn_gemm_fp4 and _cudnn_gemm_fp4_runner functions correctly:
- Encapsulate cuDNN FP4 GEMM execution with tactic support
- Implement the TunableRunner interface with get_valid_tactics and forward methods
- Query available execution plans from the cuDNN graph
- Support tactic-specific execution for autotuning

The implementation follows the established pattern for tunable runners and integrates well with the autotuning framework.
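A runner conforming to this interface looks roughly like the sketch below; the class is a stub (it does not actually subclass TunableRunner or call into cuDNN), with method names following the interface described above.

```python
from typing import Any, List

class CudnnFp4GemmRunnerSketch:
    """Stub following the TunableRunner interface described above; it does
    not subclass the real class or call into cuDNN."""

    def __init__(self, plan_count: int):
        self._plan_count = plan_count  # would come from the built graph

    def get_valid_tactics(self, inputs: List[Any], profile: Any) -> List[int]:
        # One tactic per cuDNN execution plan available for this graph.
        return list(range(self._plan_count))

    def forward(self, inputs: List[Any], tactic: int = -1, **kwargs) -> Any:
        # tactic == -1 means "use the default plan"; otherwise run the
        # specific plan the autotuner chose.
        return "default" if tactic == -1 else f"plan-{tactic}"

runner = CudnnFp4GemmRunnerSketch(plan_count=3)
print(runner.get_valid_tactics([], None), runner.forward([], tactic=2))
```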
1962-1997: LGTM! Auto backend requirement validation is well-implemented. The _auto_gemm_fp4_requirement function correctly validates that the "auto" backend can be used by:
- Checking compute capability support for candidate backends (cudnn, cutlass)
- Explicitly excluding trtllm due to its different interface (as documented in the PR description)
- Returning True if at least one backend is supported

The implementation ensures that "auto" will only be accepted on devices where at least one compatible backend is available.
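The shape of this validation can be shown with a toy sketch; the capability gates here are invented for illustration and are not the real flashinfer checks.

```python
def auto_requirement_sketch(cc_major: int) -> bool:
    """'auto' is acceptable when at least one candidate backend supports
    the device; trtllm is excluded by design."""
    cudnn_ok = cc_major in (10, 12)    # invented gate, for illustration
    cutlass_ok = cc_major in (10, 12)  # invented gate, for illustration
    return cudnn_ok or cutlass_ok

assert auto_requirement_sketch(10) and not auto_requirement_sketch(8)
```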
2136-2163: LGTM! Runner construction logic handles all backend cases correctly. The runner construction for each backend (cudnn, trtllm, cutlass) correctly:
- Creates appropriate runner instances based on backend type
- Handles dtype conversions for cutlass backend (uint8 ↔ float8_e4m3fn)
- Dispatches to the correct module based on device architecture (SM120 vs SM100/SM103)
- Falls through to a clear error for unsupported backends
The logic is well-structured and handles all supported backend configurations.
2165-2217: LGTM! Autotuning integration is well-structured. The autotuning setup correctly:
- Defines dynamic tensor specs for batch size variation (power-of-2 bucketing)
- Sets constraint specs to maintain shape relationships
- Prepares input tensors in the expected format
- Uses AutoTuner.choose_one to select the best (runner, tactic) combination
- Executes the chosen runner with the selected tactic

The integration follows the established autotuning framework patterns and enables cross-backend tuning when backend="auto".
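The power-of-2 bucketing mentioned above maps a dynamic batch dimension to a tuning bucket, roughly like this sketch:

```python
def next_power_of_2_bucket(m: int) -> int:
    """Round the dynamic dimension up to the next power of two so tuned
    configs are reused across nearby batch sizes."""
    bucket = 1
    while bucket < m:
        bucket *= 2
    return bucket

assert [next_power_of_2_bucket(m) for m in (1, 3, 8, 100)] == [1, 4, 8, 128]
```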
2487-2563: LGTM! TRTLLM FP4 runner refactoring enables autotuning. The refactoring of trtllm_fp4_gemm_runner to:
- Accept use_8x4_sf_layout as a parameter
- Implement the TunableRunner interface with tactic support
- Return a properly configured runner instance

This change aligns the TRTLLM backend with the autotuning framework and maintains consistency with other FP4 runners. The implementation correctly handles the use_8x4_sf_layout parameter throughout the runner lifecycle.
tests/gemm/test_mm_fp4.py (2)
15-95: LGTM! Test refactoring improves maintainability. Extracting the test logic into _test_mm_fp4 is a good refactoring that:
- Eliminates code duplication between test functions
- Makes the test logic reusable and easier to maintain
- Consolidates backend support checks and skip conditions

The updated skip condition at lines 34-35 correctly limits mxfp4 support to cudnn and auto backends, which aligns with the implementation in flashinfer/gemm.py.
97-127: LGTM! Test split provides good coverage of auto backend. The split between test_mm_fp4 (non-auto backends) and test_mm_fp4_backend_auto (auto backend) is well-designed:
- test_mm_fp4 maintains full parameter coverage for individual backends
- test_mm_fp4_backend_auto tests the auto backend with a reduced but representative parameter space
- The reduced parameter space (fewer m/n/k combinations, only use_128x4_sf_layout=True) is appropriate for auto backend testing and helps keep test execution time reasonable

This approach provides comprehensive coverage while avoiding combinatorial explosion of test cases.
```python
def get_cuda_version(device: torch.device):
    return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)
```
Unused device parameter in get_cuda_version.
The device parameter is not used in the function body—the function only accesses torch.version.cuda, which is a global property independent of the device.
Consider one of the following:
- If the CUDA version should be device-specific, use the device parameter to query it appropriately
- If the CUDA version is indeed global, remove the unused parameter
```diff
-def get_cuda_version(device: torch.device):
+def get_cuda_version():
     return tuple(map(int, torch.version.cuda.split(".")))  # (major, minor)
```

And update the call site at line 2098:

```diff
-    cuda_major, _ = get_cuda_version(a.device)
+    cuda_major, _ = get_cuda_version()
```
🧰 Tools
🪛 Ruff (0.14.2)
96-96: Unused function argument: device
(ARG001)
🤖 Prompt for AI Agents
In flashinfer/gemm.py around lines 96-98, the get_cuda_version function declares
an unused device parameter and instead reads the global torch.version.cuda;
remove the unused device parameter from the function signature (and adjust its
type hint to return Tuple[int, int]) and update the call site at line 2098 to
call get_cuda_version() with no arguments; also search and update any other
callers to remove the argument so the function and its uses are consistent.
benchmarks/routines/gemm.py (outdated)

```python
    "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
)
backends.remove("cutlass")
remove_cutlass = True
```
Another way to avoid these remove_backend_x bools is to call the related backend check (which should be annotated with the decorator), or have the decorator return a filtered list as I proposed. #2000 (comment)
Regardless of whether you stuff it into the decorator, this pattern will recur for all APIs, so we should think about encapsulating the "if backend and checks_dont_pass: filter_it_out" logic, as sketched below.
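One possible encapsulation of that pattern, as a sketch (filter_backends is a hypothetical helper, not existing API):

```python
from typing import Callable, Dict, List

def filter_backends(backends: List[str],
                    checks: Dict[str, Callable[[], bool]]) -> List[str]:
    """Keep only the backends whose support check passes, logging the rest."""
    kept = []
    for name in backends:
        check = checks.get(name)
        if check is None or check():
            kept.append(name)
        else:
            print(f"[INFO] {name} backend removed: support check failed")
    return kept

# e.g. mxfp4 inputs: cutlass's check fails, so it is filtered out.
print(filter_backends(["cudnn", "cutlass"],
                      {"cudnn": lambda: True, "cutlass": lambda: False}))
```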
This is a good idea. I have removed these hard-coded checks entirely and have started using the checkers in the latest commit.
```python
)
# Auto-select the best backend
if backend == "auto":
    cuda_major, _ = get_cuda_version(a.device)
```
These checks should be part of the _auto_gemm_fp4_requirement check.
I think a cleaner way would be to move the generation of the candidate_backends list into the @backend_requirement decorator, where the "auto" backend is treated specially; it already lists the required checks for each backend. An alternative is to create a separate decorator that composes and uses the backend checks of backend_requirement, roughly as sketched below.
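A decorator along these lines might expose the per-backend checks it already holds; this is a hypothetical shape, not the existing backend_requirement API:

```python
def with_backend_checks(checkers):
    """Hypothetical decorator: attaches the per-backend checkers so an
    'auto' path can enumerate feasible backends without repeating them."""
    def decorate(fn):
        def feasible_backends(**kwargs):
            return [b for b, check in checkers.items() if check(**kwargs)]
        fn.feasible_backends = feasible_backends
        return fn
    return decorate

@with_backend_checks({"cudnn": lambda **kw: True,
                      "cutlass": lambda **kw: kw.get("use_nvfp4", True)})
def mm_fp4_stub(**kwargs):
    return "result"

print(mm_fp4_stub.feasible_backends(use_nvfp4=False))  # -> ['cudnn']
```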
The danger here is that we may be repeating some checks, but not all of them.
When writing the code path for this PR, I noted that the following questions had to be answered at different times by the auto backend logic:
- Is there at least one runnable backend for the given input params -- for early error raising
- What are the runnable backends for the given input params -- to consider which backends to choose from
- In the current GPU/CUDA/cuDNN environment, what is the preferred ordering of backends -- for heuristics
The current implementation in the PR answers 1 in @backend_requirement and 2 & 3 in the body of mm_fp4, while you're suggesting putting 2 inside @backend_requirement. I agree that this helps us avoid repeating checks, but this will involve -- as you raised -- special treatment for the auto backend and a change to backend_requirement. We can discuss.
```python
candidate_backends = ("cutlass", "cudnn")

# Filter to only supported backends for this compute capability
# Note: The requirement function already validated that at least one backend is supported
```
So this is the dangerous part: at this point, we know 1 backend replied that its check is ok. But we are considering all backends. Maybe cudnn supports it but not trtllm or cutlass.
You are correct here. In the latest commit, I now check whether the backend is supported generally + for the inputs
flashinfer/gemm.py (outdated)

```python
            Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
        )
        try:
            _check_mm_fp4_problem_size(
```
we shouldn't have to do this here.
I'm finding that we actually do because the result of _check_mm_fp4_problem_size depends on the backend.
With mxfp4 cases, for example, we need to check each backend so that we can keep cudnn and reject all other backends.
flashinfer/gemm.py (outdated)

```python
for candidate in candidate_backends:
    # mypy requires explicit type casting for the backend literal
    backend_literal = cast(
        Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
```
why is auto added back?
Auto is actually not being added here since the cast() is telling pre-commit tests that backend_literal will be one of ["cudnn", "trtllm", "cutlass", "auto"] while candidate_backends will never contain auto.
However, there is no need for auto to be there, and I can see it being confusing, so I have removed it in the latest commit.
flashinfer/gemm.py (outdated)

```python
# At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'.
runners = []
for cur_backend in backends:
    if cur_backend == "cudnn":
```
Is there a cleaner way than the if-then-else structure? Maybe some dictionary that maps between backend and runner retrieval?
It is a good idea. Updated in latest commit to use a dictionary
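The dictionary-based dispatch could look roughly like this sketch; the factory functions are stubs standing in for the real runner constructors.

```python
def make_cudnn_runner():  # stub factories standing in for the real ones
    return "cudnn-runner"

def make_cutlass_runner():
    return "cutlass-runner"

# Lookup table replaces the if/elif chain over cur_backend.
runner_factories = {
    "cudnn": make_cudnn_runner,
    "cutlass": make_cutlass_runner,
}
backends = ["cudnn", "cutlass"]
runners = [runner_factories[name]() for name in backends]
print(runners)  # -> ['cudnn-runner', 'cutlass-runner']
```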
flashinfer/gemm.py (outdated)

```python
        )
    elif cur_backend == "cutlass":
        if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
            a_descale = a_descale.view(torch.uint8)
```
This seems like an implementation detail, and maybe needs to be moved to the cutlass runner itself, just like we do with the cudnn_runner.
Agree and this allows removal of the if-then-else structure above. Updated in latest commit
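Folding the dtype reinterpretation into the runner itself might look like the sketch below, assuming the bitwise view shown in the snippet above is the only normalization needed.

```python
import torch

def cutlass_prepare_descale(a: torch.Tensor, a_descale: torch.Tensor) -> torch.Tensor:
    """Normalize the descale dtype inside the runner instead of at the
    call site; view() reinterprets the bytes without copying."""
    if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
        a_descale = a_descale.view(torch.uint8)
    return a_descale
```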
```python
        a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer
    )
# Now we have a list of runners for desired & supported backends.
tuner = AutoTuner.get()
```
I think it is great that we unify the autotuning logic, and separating the different runner implementations.
…uto, but no cudnn autotune
Force-pushed from 254827a to c9f3d52
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/routines/flashinfer_benchmark_utils.py (1 hunks)
- benchmarks/routines/gemm.py (4 hunks)
- flashinfer/gemm.py (17 hunks)
- tests/gemm/test_mm_fp4.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/routines/gemm.py (2)
- flashinfer/gemm.py (1): _check_mm_fp4_problem_size (1812-1870)
- flashinfer/autotuner.py (1): autotune (251-262)

flashinfer/gemm.py (2)
- flashinfer/jit/cpp_ext.py (1): get_cuda_version (64-83)
- flashinfer/utils.py (4): supported_compute_capability (772-852), get_compute_capability (251-254), is_compute_capability_supported (966-972), backend_requirement (855-1028)
🪛 Ruff (0.14.2)
benchmarks/routines/gemm.py
900-900: Do not catch blind exception: Exception
(BLE001)
flashinfer/gemm.py
96-96: Unused function argument: device
(ARG001)
436-436: Unused method argument: inputs
(ARG002)
437-437: Unused method argument: profile
(ARG002)
445-445: Unused method argument: do_preparation
(ARG002)
446-446: Unused method argument: kwargs
(ARG002)
1730-1730: Unused method argument: profile
(ARG002)
1741-1741: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1744-1744: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1780-1780: Unused method argument: do_preparation
(ARG002)
1781-1781: Unused method argument: kwargs
(ARG002)
1863-1863: Avoid specifying long messages outside the exception class
(TRY003)
1884-1884: Unused function argument: backend
(ARG001)
1942-1942: Unused function argument: backend
(ARG001)
1964-1964: Unused function argument: backend
(ARG001)
1965-1965: Unused function argument: use_nvfp4
(ARG001)
1973-1973: Unused function argument: b
(ARG001)
1974-1974: Unused function argument: a_descale
(ARG001)
1975-1975: Unused function argument: b_descale
(ARG001)
1976-1976: Unused function argument: alpha
(ARG001)
1977-1977: Unused function argument: out_dtype
(ARG001)
1978-1978: Unused function argument: out
(ARG001)
1979-1979: Unused function argument: block_size
(ARG001)
1980-1980: Unused function argument: use_8x4_sf_layout
(ARG001)
1981-1981: Unused function argument: backend
(ARG001)
1982-1982: Unused function argument: use_nvfp4
(ARG001)
2109-2109: Unpacked variable cc_major is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2109-2109: Unpacked variable cc_minor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2153-2154: try-except-pass detected, consider logging the exception
(S110)
2153-2153: Do not catch blind exception: Exception
(BLE001)
2517-2517: Unpacked variable a_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2518-2518: Unpacked variable b_descale is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2519-2519: Unpacked variable alpha is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2521-2521: Unpacked variable out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2524-2524: Unpacked variable workspace_buffer is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
2538-2538: Unused method argument: do_preparation
(ARG002)
2539-2539: Unused method argument: kwargs
(ARG002)
```python
candidate_backends = ["cudnn", "cutlass", "trtllm"]
backend_checkers = {
    "cudnn": _cudnn_gemm_fp4_requirement,
    "cutlass": _cutlass_gemm_fp4_requirement,
},
common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
# Does not consider trtllm due to different interface.
}

for candidate in candidate_backends:
    checker = backend_checkers[candidate]
    if hasattr(
```
Fix auto backend support KeyError.
_auto_gemm_fp4_requirement() enumerates candidate_backends = ["cudnn", "cutlass", "trtllm"], but the backend_checkers dict only defines entries for "cudnn" and "cutlass". The first time this helper runs—e.g., via mm_fp4.is_backend_supported("auto", …) in the updated tests/benchmarks—it raises a KeyError when it hits "trtllm", so the new backend="auto" path never runs. We need to either include a checker for trtllm or drop it from candidate_backends. Here's a minimal fix removing the stray entry:
```diff
-    candidate_backends = ["cudnn", "cutlass", "trtllm"]
-    backend_checkers = {
-        "cudnn": _cudnn_gemm_fp4_requirement,
-        "cutlass": _cutlass_gemm_fp4_requirement,
-        # Does not consider trtllm due to different interface.
-    }
+    candidate_backends = ["cudnn", "cutlass"]
+    backend_checkers = {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    }
```
+ }🤖 Prompt for AI Agents
In flashinfer/gemm.py around lines 1989 to 1998, candidate_backends includes
"trtllm" but backend_checkers only defines "cudnn" and "cutlass", causing a
KeyError when iterating; fix by keeping candidate_backends and backend_checkers
in sync—either remove "trtllm" from candidate_backends or add a corresponding
checker entry for "trtllm" (preferably remove the stray "trtllm" entry here to
match the comment and avoid referencing an undefined checker).
📌 Description
DRAFT. Please do not merge.
Current PR:
- Adds an auto backend to mm_fp4 that can be autotuned. It replaces cudnn as the default.
- Enables the cudnn backend to be autotuned.

Behavior of auto backend:
- Chooses between the cutlass or cudnn kernel backends. The trtllm kernel is not considered due to a non-interchangeable interface between trtllm and the (cutlass, cudnn) backends. The auto backend therefore only supports inputs runnable by cutlass and/or cudnn.
- With use_nvfp4=False, cutlass will be removed from the backend list as it fails the support check.
- backend='trtllm' or backend='cutlass' or backend='cudnn' --> Autotunes within the backend. Same as previous behavior, but now autotuning is supported for cudnn.
- backend='auto' --> Autotunes within and across backends (cudnn & cutlass) and chooses the best config of the best backend.
- The runners in mm_fp4 were refactored to enable cross-backend autotuning. Refactoring was done to match the cross-backend autotune-enabled bmm_fp8 as a reference.

🔍 Related Issues
#1722
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Refactor
Behavior Changes
Tests