368 changes: 239 additions & 129 deletions flashinfer/__init__.py
@@ -19,134 +19,244 @@
from .version import __version__ as __version__
from .version import __git_version__ as __git_version__

import contextlib

# JIT compilation support
from . import jit as jit
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)
from .attention import BatchAttention as BatchAttention
from .attention import (
BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
)
from .autotuner import autotune as autotune
from .cascade import (
BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
)
from .cascade import merge_state as merge_state
from .cascade import merge_state_in_place as merge_state_in_place
from .cascade import merge_states as merge_states
from .decode import (
BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
)
from .decode import (
BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
fast_decode_plan as fast_decode_plan,
)
from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
from .fp4_quantization import (
SfLayout,
block_scale_interleave,
nvfp4_block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp4_dequantize_host,
mxfp4_dequantize,
mxfp4_quantize,
nvfp4_quantize,
nvfp4_batched_quantize,
shuffle_matrix_a,
shuffle_matrix_sf_a,
scaled_fp4_grouped_quantize,
)
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
from .fused_moe import (
RoutingMethodType,
GatedActType,
cutlass_fused_moe,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens
from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)
from .prefill import (
BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
)
from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
from .prefill import (
single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_logits as sampling_from_logits
from .sampling import sampling_from_probs as sampling_from_probs
from .sampling import softmax as softmax
from .sampling import top_k_mask_logits as top_k_mask_logits
from .sampling import top_k_renorm_probs as top_k_renorm_probs
from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
from .sampling import (
top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
)
from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs
from .sampling import top_p_renorm_probs as top_p_renorm_probs
from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
from .sparse import (
VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
)
from .trtllm_low_latency_gemm import (
prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)
from .utils import next_positive_power_of_2 as next_positive_power_of_2
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)
Comment on lines +30 to +36
Contributor

⚠️ Potential issue | 🟠 Major

Don't silently swallow real ImportErrors

Blanket contextlib.suppress(ImportError) means any ImportError raised while importing .activation (including missing transitive deps or a bad refactor) now disappears, import flashinfer succeeds, and callers only discover the breakage later via an unhelpful AttributeError. Please differentiate "module truly not packaged" from "module broken": e.g., probe importlib.util.find_spec and only skip when the spec is absent, or cache the caught exception and surface it via __getattr__ so users still see the real failure instead of a silent drop in exports.

πŸ€– Prompt for AI Agents
In flashinfer/__init__.py around lines 30 to 36, the blanket
contextlib.suppress(ImportError) hides real import errors from .activation;
change the logic to first check
importlib.util.find_spec("flashinfer.activation") and only skip imports if the
spec is None (module not packaged), otherwise attempt the import and on
ImportError cache the caught exception in a module-level variable (e.g.,
_activation_import_error) and implement __getattr__ to raise that cached
exception when consumers try to access the missing attributes; this preserves
silent skipping for genuinely absent optional modules but surfaces real
ImportErrors to callers.
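
For illustration, here is a minimal sketch of the pattern this comment suggests, as it might look for the activation block inside flashinfer/__init__.py. The cache variable _activation_import_error and the single re-exported symbol are illustrative assumptions, not part of the PR:

```python
import importlib.util

# Hypothetical cache for a real import failure of the optional module.
_activation_import_error = None

if importlib.util.find_spec("flashinfer.activation") is not None:
    # Module is packaged: try to import it, and remember why it broke if it fails.
    try:
        from .activation import silu_and_mul as silu_and_mul
    except ImportError as e:
        _activation_import_error = e
# else: module genuinely not packaged, so silently skip the optional exports.


def __getattr__(name):
    # PEP 562 module-level __getattr__: surface the cached failure instead of
    # letting callers hit a bare AttributeError on a missing export.
    if name == "silu_and_mul" and _activation_import_error is not None:
        raise _activation_import_error
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With this shape, import flashinfer still succeeds when the optional module is simply absent, but accessing flashinfer.silu_and_mul re-raises the original ImportError when the module is present yet broken.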


# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
from .attention import BatchAttention as BatchAttention
from .attention import (
BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
)

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
from .autotuner import autotune as autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
from .cascade import (
BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
)
from .cascade import merge_state as merge_state
from .cascade import merge_state_in_place as merge_state_in_place
from .cascade import merge_states as merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
from .decode import (
BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
)
from .decode import (
BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache,
)
from .decode import fast_decode_plan as fast_decode_plan
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
from .fp4_quantization import (
SfLayout,
block_scale_interleave,
nvfp4_block_scale_interleave,
e2m1_and_ufp8sf_scale_to_float,
fp4_quantize,
mxfp4_dequantize_host,
mxfp4_dequantize,
mxfp4_quantize,
nvfp4_quantize,
nvfp4_batched_quantize,
shuffle_matrix_a,
shuffle_matrix_sf_a,
scaled_fp4_grouped_quantize,
)

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
from .fused_moe import (
RoutingMethodType,
GatedActType,
cutlass_fused_moe,
reorder_rows_for_gated_act_gemm,
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
)

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper


# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)
from .prefill import (
BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
)
from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
from .prefill import (
single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_logits as sampling_from_logits
from .sampling import sampling_from_probs as sampling_from_probs
from .sampling import softmax as softmax
from .sampling import top_k_mask_logits as top_k_mask_logits
from .sampling import top_k_renorm_probs as top_k_renorm_probs
from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
from .sampling import (
top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
)
from .sampling import (
top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs,
)
from .sampling import top_p_renorm_probs as top_p_renorm_probs
from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
from .sparse import (
VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
)

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
from .trtllm_low_latency_gemm import (
prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
from .utils import next_positive_power_of_2 as next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla
Comment on lines 19 to +262
Contributor

medium

The as <name> part of the imports is redundant when the imported name is the same as the alias. Removing these redundant aliases improves code readability and conciseness.

from .version import __version__
from .version import __git_version__

import contextlib

# JIT compilation support
from . import jit

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
    from .activation import gelu_and_mul
    from .activation import gelu_tanh_and_mul
    from .activation import silu_and_mul
    from .activation import (
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )

# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
    from .attention import BatchAttention
    from .attention import (
        BatchAttentionWithAttentionSinkWrapper,
    )

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
    from .autotuner import autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .cascade import (
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        MultiLevelCascadeAttentionWrapper,
    )
    from .cascade import merge_state
    from .cascade import merge_state_in_place
    from .cascade import merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .decode import (
        BatchDecodeMlaWithPagedKVCacheWrapper,
    )
    from .decode import (
        BatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        cudnn_batch_decode_with_kv_cache,
    )
    from .decode import fast_decode_plan
    from .decode import single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp4_quantization import (
        SfLayout,
        block_scale_interleave,
        nvfp4_block_scale_interleave,
        e2m1_and_ufp8sf_scale_to_float,
        fp4_quantize,
        mxfp4_dequantize_host,
        mxfp4_dequantize,
        mxfp4_quantize,
        nvfp4_quantize,
        nvfp4_batched_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        scaled_fp4_grouped_quantize,
    )

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
    from .fused_moe import (
        RoutingMethodType,
        GatedActType,
        cutlass_fused_moe,
        reorder_rows_for_gated_act_gemm,
        trtllm_fp4_block_scale_moe,
        trtllm_fp4_block_scale_routed_moe,
        trtllm_fp8_block_scale_moe,
        trtllm_fp8_per_tensor_scale_moe,
    )

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .gemm import SegmentGEMMWrapper
    from .gemm import bmm_fp8
    from .gemm import mm_fp4
    from .gemm import mm_fp8
    from .gemm import tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
    from .mla import BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .norm import fused_add_rmsnorm
    from .norm import layernorm
    from .norm import gemma_fused_add_rmsnorm
    from .norm import gemma_rmsnorm
    from .norm import rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .page import append_paged_kv_cache
    from .page import append_paged_mla_kv_cache
    from .page import get_batch_indices_positions
    from .page import get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .pod import PODWithPagedKVCacheWrapper


# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .prefill import (
        BatchPrefillWithPagedKVCacheWrapper,
    )
    from .prefill import (
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from .prefill import single_prefill_with_kv_cache
    from .prefill import (
        single_prefill_with_kv_cache_return_lse,
    )

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .quantization import packbits
    from .quantization import segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .rope import apply_llama31_rope
    from .rope import apply_llama31_rope_inplace
    from .rope import apply_llama31_rope_pos_ids
    from .rope import (
        apply_llama31_rope_pos_ids_inplace,
    )
    from .rope import apply_rope
    from .rope import apply_rope_inplace
    from .rope import apply_rope_pos_ids
    from .rope import apply_rope_pos_ids_inplace
    from .rope import apply_rope_with_cos_sin_cache
    from .rope import (
        apply_rope_with_cos_sin_cache_inplace,
    )

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .sampling import chain_speculative_sampling
    from .sampling import min_p_sampling_from_probs
    from .sampling import sampling_from_logits
    from .sampling import sampling_from_probs
    from .sampling import softmax
    from .sampling import top_k_mask_logits
    from .sampling import top_k_renorm_probs
    from .sampling import top_k_sampling_from_probs
    from .sampling import (
        top_k_top_p_sampling_from_logits,
    )
    from .sampling import (
        top_k_top_p_sampling_from_probs,
    )
    from .sampling import top_p_renorm_probs
    from .sampling import top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .sparse import BlockSparseAttentionWrapper
    from .sparse import (
        VariableBlockSparseAttentionWrapper,
    )

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
    from .trtllm_low_latency_gemm import (
        prepare_low_latency_gemm_weights,
    )

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .utils import next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
    from .xqa import xqa
    from .xqa import xqa_mla