Skip to content

Conversation

@bangshengtang
Copy link

@bangshengtang bangshengtang commented Nov 3, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Enhanced module initialization to gracefully handle missing optional dependencies without causing import failures.

this is to make the code more modular, applications can selectively pack and import modules based on needs using the same __init__.py
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 3, 2025

Walkthrough

The module initialization converts eager imports to lazy, guarded imports organized into thematic sections, using contextlib.suppress(ImportError) to gracefully handle missing optional dependencies without raising errors at import time.

Changes

Cohort / File(s) Summary
Import reorganization and lazy loading
flashinfer/__init__.py
Replaces eager imports with conditional, lazy imports wrapped in contextlib.suppress(ImportError). Reorganizes imports into thematic sections (Activation, Attention, Autotuner, Cascade, Decode, Quantization, MoE, GEMM, MLA, Normalization, Paged KV, POD, Prefill, RoPE, Sampling, Sparse attention, TRT-LLM, Utilities, XQA) to improve maintainability and handle optional dependencies gracefully.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

  • Review areas of attention:
    • Verify all previously exported symbols remain importable or are intentionally made optional
    • Confirm no imports were accidentally removed or reorganized incorrectly
    • Validate that the thematic section organization is consistent and logical
    • Ensure contextlib.suppress(ImportError) is appropriately applied to all optional dependencies

Poem

🐰 Imports once eager, now they're lazy and wise,
Waiting for dependencies under autumn skies,
Gracefully skipping what's not yet there,
A flexible module, with thoughtful care,
✨ Dependencies danceβ€”some present, some spare!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description uses the template structure but leaves the critical 'Description' section empty, providing no explanation of what changes were made or why they're needed. Fill in the Description section with details about the changes made, the rationale for using lazy imports and contextlib.suppress, and how this enables selective submodule packing.
βœ… Passed checks (1 passed)
Check name Status Explanation
Title check βœ… Passed The title clearly and specifically describes the main change: refactoring flashinfer/init.py to enable selective submodule packing without modifying the file.
✨ Finishing touches
  • πŸ“ Generate docstrings
πŸ§ͺ Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❀️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @bangshengtang, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the flashinfer/__init__.py file to enhance the modularity and flexibility of the FlashInfer library. By wrapping submodule imports within contextlib.suppress(ImportError) blocks, applications can now selectively import specific functionalities. This change prevents ImportError exceptions when certain submodules are not available or desired, making the library more adaptable to diverse usage scenarios and potentially reducing dependency overhead.

Highlights

  • Conditional Submodule Imports: All submodule imports in flashinfer/__init__.py are now wrapped with contextlib.suppress(ImportError). This allows applications to selectively import FlashInfer submodules without encountering errors if certain components are not present or required.
  • Improved Modularity: The refactoring enhances the modularity of the library, enabling more flexible usage patterns where users can choose to load only the necessary parts of FlashInfer.
  • Code Organization: Descriptive comments have been added to categorize groups of submodule imports, improving the readability and maintainability of the __init__.py file.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with πŸ‘ and πŸ‘Ž on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors flashinfer/__init__.py to make submodule imports optional by wrapping them in contextlib.suppress(ImportError). This is a good change that improves the modularity of the package, allowing users to selectively pack submodules. My review includes a suggestion to improve code style by removing redundant aliases in the import statements, which will enhance readability.

Comment on lines 19 to +262
from .version import __version__ as __version__
from .version import __git_version__ as __git_version__

import contextlib

# JIT compilation support
from . import jit as jit
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)
from .attention import BatchAttention as BatchAttention
from .attention import (
BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
)
from .autotuner import autotune as autotune
from .cascade import (
BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
)
from .cascade import merge_state as merge_state
from .cascade import merge_state_in_place as merge_state_in_place
from .cascade import merge_states as merge_states
from .decode import (
BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
)
from .decode import (
BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
fast_decode_plan as fast_decode_plan,
)
from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
from .fp4_quantization import (
SfLayout,
block_scale_interleave,
nvfp4_block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp4_dequantize_host,
mxfp4_dequantize,
mxfp4_quantize,
nvfp4_quantize,
nvfp4_batched_quantize,
shuffle_matrix_a,
shuffle_matrix_sf_a,
scaled_fp4_grouped_quantize,
)
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
from .fused_moe import (
RoutingMethodType,
GatedActType,
cutlass_fused_moe,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens
from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)
from .prefill import (
BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
)
from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
from .prefill import (
single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_logits as sampling_from_logits
from .sampling import sampling_from_probs as sampling_from_probs
from .sampling import softmax as softmax
from .sampling import top_k_mask_logits as top_k_mask_logits
from .sampling import top_k_renorm_probs as top_k_renorm_probs
from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
from .sampling import (
top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
)
from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs
from .sampling import top_p_renorm_probs as top_p_renorm_probs
from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
from .sparse import (
VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
)
from .trtllm_low_latency_gemm import (
prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)
from .utils import next_positive_power_of_2 as next_positive_power_of_2
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)

# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
from .attention import BatchAttention as BatchAttention
from .attention import (
BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
)

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
from .autotuner import autotune as autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
from .cascade import (
BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
)
from .cascade import merge_state as merge_state
from .cascade import merge_state_in_place as merge_state_in_place
from .cascade import merge_states as merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
from .decode import (
BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
)
from .decode import (
BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache,
)
from .decode import fast_decode_plan as fast_decode_plan
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
from .fp4_quantization import (
SfLayout,
block_scale_interleave,
nvfp4_block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp4_dequantize_host,
mxfp4_dequantize,
mxfp4_quantize,
nvfp4_quantize,
nvfp4_batched_quantize,
shuffle_matrix_a,
shuffle_matrix_sf_a,
scaled_fp4_grouped_quantize,
)

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
from .fused_moe import (
RoutingMethodType,
GatedActType,
cutlass_fused_moe,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper


# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)
from .prefill import (
BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
)
from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
from .prefill import (
single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_logits as sampling_from_logits
from .sampling import sampling_from_probs as sampling_from_probs
from .sampling import softmax as softmax
from .sampling import top_k_mask_logits as top_k_mask_logits
from .sampling import top_k_renorm_probs as top_k_renorm_probs
from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
from .sampling import (
top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
)
from .sampling import (
top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs,
)
from .sampling import top_p_renorm_probs as top_p_renorm_probs
from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
from .sparse import (
VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
)

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
from .trtllm_low_latency_gemm import (
prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
from .utils import next_positive_power_of_2 as next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The as <name> part of the imports is redundant when the imported name is the same as the alias. Removing these redundant aliases improves code readability and conciseness.

from .version import __version__
from .version import __git_version__

import contextlib

# JIT compilation support
from . import jit

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
    from .activation import gelu_and_mul
    from .activation import gelu_tanh_and_mul
    from .activation import silu_and_mul
    from .activation import (
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )

# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
    from .attention import BatchAttention
    from .attention import (
        BatchAttentionWithAttentionSinkWrapper,
    )

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
    from .autotuner import autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .cascade import (
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        MultiLevelCascadeAttentionWrapper,
    )
    from .cascade import merge_state
    from .cascade import merge_state_in_place
    from .cascade import merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .decode import (
        BatchDecodeMlaWithPagedKVCacheWrapper,
    )
    from .decode import (
        BatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        cudnn_batch_decode_with_kv_cache,
    )
    from .decode import fast_decode_plan
    from .decode import single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp4_quantization import (
        SfLayout,
        block_scale_interleave,
        nvfp4_block_scale_interleave,
        e2m1_and_ufp8sf_scale_to_float,
        fp4_quantize,
        mxfp4_dequantize_host,
        mxfp4_dequantize,
        mxfp4_quantize,
        nvfp4_quantize,
        nvfp4_batched_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        scaled_fp4_grouped_quantize,
    )

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
    from .fused_moe import (
        RoutingMethodType,
        GatedActType,
        cutlass_fused_moe,
        reorder_rows_for_gated_act_gemm,
        trtllm_fp4_block_scale_moe,
        trtllm_fp4_block_scale_routed_moe,
        trtllm_fp8_block_scale_moe,
        trtllm_fp8_per_tensor_scale_moe,
    )

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .gemm import SegmentGEMMWrapper
    from .gemm import bmm_fp8
    from .gemm import mm_fp4
    from .gemm import mm_fp8
    from .gemm import tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
    from .mla import BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .norm import fused_add_rmsnorm
    from .norm import layernorm
    from .norm import gemma_fused_add_rmsnorm
    from .norm import gemma_rmsnorm
    from .norm import rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .page import append_paged_kv_cache
    from .page import append_paged_mla_kv_cache
    from .page import get_batch_indices_positions
    from .page import get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .pod import PODWithPagedKVCacheWrapper


# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .prefill import (
        BatchPrefillWithPagedKVCacheWrapper,
    )
    from .prefill import (
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from .prefill import single_prefill_with_kv_cache
    from .prefill import (
        single_prefill_with_kv_cache_return_lse,
    )

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .quantization import packbits
    from .quantization import segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .rope import apply_llama31_rope
    from .rope import apply_llama31_rope_inplace
    from .rope import apply_llama31_rope_pos_ids
    from .rope import (
        apply_llama31_rope_pos_ids_inplace,
    )
    from .rope import apply_rope
    from .rope import apply_rope_inplace
    from .rope import apply_rope_pos_ids
    from .rope import apply_rope_pos_ids_inplace
    from .rope import apply_rope_with_cos_sin_cache
    from .rope import (
        apply_rope_with_cos_sin_cache_inplace,
    )

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .sampling import chain_speculative_sampling
    from .sampling import min_p_sampling_from_probs
    from .sampling import sampling_from_logits
    from .sampling import sampling_from_probs
    from .sampling import softmax
    from .sampling import top_k_mask_logits
    from .sampling import top_k_renorm_probs
    from .sampling import top_k_sampling_from_probs
    from .sampling import (
        top_k_top_p_sampling_from_logits,
    )
    from .sampling import (
        top_k_top_p_sampling_from_probs,
    )
    from .sampling import top_p_renorm_probs
    from .sampling import top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .sparse import BlockSparseAttentionWrapper
    from .sparse import (
        VariableBlockSparseAttentionWrapper,
    )

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
    from .trtllm_low_latency_gemm import (
        prepare_low_latency_gemm_weights,
    )

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .utils import next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
    from .xqa import xqa
    from .xqa import xqa_mla

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between da01b1b and c41d23e.

πŸ“’ Files selected for processing (1)
  • flashinfer/__init__.py (1 hunks)

Comment on lines +30 to +36
with contextlib.suppress(ImportError):
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't silently swallow real ImportErrors

Blanket contextlib.suppress(ImportError) means any ImportError raised while importing .activation (including missing transitive deps or a bad refactor) now disappears, import flashinfer succeeds, and callers only discover the breakage later via an unhelpful AttributeError. Please differentiate β€œmodule truly not packaged” from β€œmodule broken”: e.g., probe importlib.util.find_spec and only skip when the spec is absent, or cache the caught exception and surface it via __getattr__ so users still see the real failure instead of a silent drop in exports.

πŸ€– Prompt for AI Agents
In flashinfer/__init__.py around lines 30 to 36, the blanket
contextlib.suppress(ImportError) hides real import errors from .activation;
change the logic to first check
importlib.util.find_spec("flashinfer.activation") and only skip imports if the
spec is None (module not packaged), otherwise attempt the import and on
ImportError cache the caught exception in a module-level variable (e.g.,
_activation_import_error) and implement __getattr__ to raise that cached
exception when consumers try to access the missing attributes; this preserves
silent skipping for genuinely absent optional modules but surfaces real
ImportErrors to callers.

@nvmbreughe
Copy link
Contributor

LGTM. What was the main use case for the change?

@bangshengtang
Copy link
Author

@nvmbreughe we have use cases that only need a subset of flashinfer, e.g. sampling, we can make an additional target that only includes all the relevant files from flashinfer/, and the callsite can import flashinfer.sampling just fine. thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants