Refactor flashinfer/__init__.py so that applications could selectively pack submodules without modifying __init__.py #2027
Conversation
This is to make the code more modular: applications can selectively pack and import submodules as needed while using the same __init__.py.
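For illustration, here is a minimal sketch of the guarded-import pattern the refactor applies (the sampling import is just one of the many groups; see the full diff below):

```python
# Minimal sketch of the pattern used in the refactored flashinfer/__init__.py:
# if an optional submodule was not packaged, the top-level import still succeeds
# and the corresponding names are simply not exported.
import contextlib

with contextlib.suppress(ImportError):
    from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
```

A build that ships only a subset of the submodules can then keep using the unmodified __init__.py.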
Walkthrough: The module initialization converts eager imports to lazy, guarded imports organized into thematic sections, using contextlib.suppress(ImportError) around each group.
Estimated code review effort: 2 (Simple) | ~10 minutes
Summary of Changes: Hello @bangshengtang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refactors the flashinfer/__init__.py module initialization so that submodule imports become optional.
Code Review
This pull request refactors flashinfer/__init__.py to make submodule imports optional by wrapping them in contextlib.suppress(ImportError). This is a good change that improves the modularity of the package, allowing users to selectively pack submodules. My review includes a suggestion to improve code style by removing redundant aliases in the import statements, which will enhance readability.
```python
from .version import __version__ as __version__
from .version import __git_version__ as __git_version__

import contextlib

# JIT compilation support
from . import jit as jit
from .activation import gelu_and_mul as gelu_and_mul
from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
from .activation import silu_and_mul as silu_and_mul
from .activation import (
    silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
)
from .attention import BatchAttention as BatchAttention
from .attention import (
    BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
)
from .autotuner import autotune as autotune
from .cascade import (
    BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .cascade import (
    MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
)
from .cascade import merge_state as merge_state
from .cascade import merge_state_in_place as merge_state_in_place
from .cascade import merge_states as merge_states
from .decode import (
    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
)
from .decode import (
    BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .decode import (
    fast_decode_plan as fast_decode_plan,
)
from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
from .fp4_quantization import (
    SfLayout,
    block_scale_interleave,
    nvfp4_block_scale_interleave,
    e2m1_and_ufp8sf_scale_to_float,
    fp4_quantize,
    mxfp4_dequantize_host,
    mxfp4_dequantize,
    mxfp4_quantize,
    nvfp4_quantize,
    nvfp4_batched_quantize,
    shuffle_matrix_a,
    shuffle_matrix_sf_a,
    scaled_fp4_grouped_quantize,
)
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
from .fused_moe import (
    RoutingMethodType,
    GatedActType,
    cutlass_fused_moe,
    reorder_rows_for_gated_act_gemm,
    trtllm_fp4_block_scale_moe,
    trtllm_fp4_block_scale_routed_moe,
    trtllm_fp8_block_scale_moe,
    trtllm_fp8_per_tensor_scale_moe,
)
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens
from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper
from .prefill import (
    BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)
from .prefill import (
    BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
)
from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
from .prefill import (
    single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
)
from .quantization import packbits as packbits
from .quantization import segment_packbits as segment_packbits
from .rope import apply_llama31_rope as apply_llama31_rope
from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
from .rope import (
    apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
)
from .rope import apply_rope as apply_rope
from .rope import apply_rope_inplace as apply_rope_inplace
from .rope import apply_rope_pos_ids as apply_rope_pos_ids
from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
from .rope import (
    apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
)
from .sampling import chain_speculative_sampling as chain_speculative_sampling
from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
from .sampling import sampling_from_logits as sampling_from_logits
from .sampling import sampling_from_probs as sampling_from_probs
from .sampling import softmax as softmax
from .sampling import top_k_mask_logits as top_k_mask_logits
from .sampling import top_k_renorm_probs as top_k_renorm_probs
from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
from .sampling import (
    top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
)
from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs
from .sampling import top_p_renorm_probs as top_p_renorm_probs
from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
from .sparse import (
    VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
)
from .trtllm_low_latency_gemm import (
    prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
)
from .utils import next_positive_power_of_2 as next_positive_power_of_2
from .xqa import xqa as xqa
from .xqa import xqa_mla as xqa_mla

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
    from .activation import gelu_and_mul as gelu_and_mul
    from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
    from .activation import silu_and_mul as silu_and_mul
    from .activation import (
        silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
    )

# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
    from .attention import BatchAttention as BatchAttention
    from .attention import (
        BatchAttentionWithAttentionSinkWrapper as BatchAttentionWithAttentionSinkWrapper,
    )

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
    from .autotuner import autotune as autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .cascade import (
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
    )
    from .cascade import merge_state as merge_state
    from .cascade import merge_state_in_place as merge_state_in_place
    from .cascade import merge_states as merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .decode import (
        BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
    )
    from .decode import (
        BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache,
    )
    from .decode import fast_decode_plan as fast_decode_plan
    from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp4_quantization import (
        SfLayout,
        block_scale_interleave,
        nvfp4_block_scale_interleave,
        e2m1_and_ufp8sf_scale_to_float,
        fp4_quantize,
        mxfp4_dequantize_host,
        mxfp4_dequantize,
        mxfp4_quantize,
        nvfp4_quantize,
        nvfp4_batched_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        scaled_fp4_grouped_quantize,
    )

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
    from .fused_moe import (
        RoutingMethodType,
        GatedActType,
        cutlass_fused_moe,
        reorder_rows_for_gated_act_gemm,
        trtllm_fp4_block_scale_moe,
        trtllm_fp4_block_scale_routed_moe,
        trtllm_fp8_block_scale_moe,
        trtllm_fp8_per_tensor_scale_moe,
    )

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
    from .gemm import bmm_fp8 as bmm_fp8
    from .gemm import mm_fp4 as mm_fp4
    from .gemm import mm_fp8 as mm_fp8
    from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
    from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .norm import fused_add_rmsnorm as fused_add_rmsnorm
    from .norm import layernorm as layernorm
    from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
    from .norm import gemma_rmsnorm as gemma_rmsnorm
    from .norm import rmsnorm as rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .page import append_paged_kv_cache as append_paged_kv_cache
    from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
    from .page import get_batch_indices_positions as get_batch_indices_positions
    from .page import get_seq_lens as get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .pod import PODWithPagedKVCacheWrapper as PODWithPagedKVCacheWrapper

# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .prefill import (
        BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
    )
    from .prefill import (
        BatchPrefillWithRaggedKVCacheWrapper as BatchPrefillWithRaggedKVCacheWrapper,
    )
    from .prefill import single_prefill_with_kv_cache as single_prefill_with_kv_cache
    from .prefill import (
        single_prefill_with_kv_cache_return_lse as single_prefill_with_kv_cache_return_lse,
    )

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .quantization import packbits as packbits
    from .quantization import segment_packbits as segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .rope import apply_llama31_rope as apply_llama31_rope
    from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
    from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
    from .rope import (
        apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
    )
    from .rope import apply_rope as apply_rope
    from .rope import apply_rope_inplace as apply_rope_inplace
    from .rope import apply_rope_pos_ids as apply_rope_pos_ids
    from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
    from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
    from .rope import (
        apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
    )

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .sampling import chain_speculative_sampling as chain_speculative_sampling
    from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
    from .sampling import sampling_from_logits as sampling_from_logits
    from .sampling import sampling_from_probs as sampling_from_probs
    from .sampling import softmax as softmax
    from .sampling import top_k_mask_logits as top_k_mask_logits
    from .sampling import top_k_renorm_probs as top_k_renorm_probs
    from .sampling import top_k_sampling_from_probs as top_k_sampling_from_probs
    from .sampling import (
        top_k_top_p_sampling_from_logits as top_k_top_p_sampling_from_logits,
    )
    from .sampling import (
        top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs,
    )
    from .sampling import top_p_renorm_probs as top_p_renorm_probs
    from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
    from .sparse import (
        VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
    )

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
    from .trtllm_low_latency_gemm import (
        prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights,
    )

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .utils import next_positive_power_of_2 as next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
    from .xqa import xqa as xqa
    from .xqa import xqa_mla as xqa_mla
```
The as <name> part of the imports is redundant when the imported name is the same as the alias. Removing these redundant aliases improves code readability and conciseness.
```python
from .version import __version__
from .version import __git_version__

import contextlib

# JIT compilation support
from . import jit

# ============================================================================
# Activation functions
# ============================================================================
with contextlib.suppress(ImportError):
    from .activation import gelu_and_mul
    from .activation import gelu_tanh_and_mul
    from .activation import silu_and_mul
    from .activation import (
        silu_and_mul_scaled_nvfp4_experts_quantize,
    )

# ============================================================================
# Attention modules
# ============================================================================
with contextlib.suppress(ImportError):
    from .attention import BatchAttention
    from .attention import (
        BatchAttentionWithAttentionSinkWrapper,
    )

# ============================================================================
# Autotuner
# ============================================================================
with contextlib.suppress(ImportError):
    from .autotuner import autotune

# ============================================================================
# Cascade attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .cascade import (
        BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
    )
    from .cascade import (
        MultiLevelCascadeAttentionWrapper,
    )
    from .cascade import merge_state
    from .cascade import merge_state_in_place
    from .cascade import merge_states

# ============================================================================
# Decode operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .decode import (
        BatchDecodeMlaWithPagedKVCacheWrapper,
    )
    from .decode import (
        BatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
    )
    from .decode import (
        cudnn_batch_decode_with_kv_cache,
    )
    from .decode import fast_decode_plan
    from .decode import single_decode_with_kv_cache

# ============================================================================
# FP4 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp4_quantization import (
        SfLayout,
        block_scale_interleave,
        nvfp4_block_scale_interleave,
        e2m1_and_ufp8sf_scale_to_float,
        fp4_quantize,
        mxfp4_dequantize_host,
        mxfp4_dequantize,
        mxfp4_quantize,
        nvfp4_quantize,
        nvfp4_batched_quantize,
        shuffle_matrix_a,
        shuffle_matrix_sf_a,
        scaled_fp4_grouped_quantize,
    )

# ============================================================================
# FP8 quantization
# ============================================================================
with contextlib.suppress(ImportError):
    from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize

# ============================================================================
# Fused mixture-of-experts (MoE)
# ============================================================================
with contextlib.suppress(ImportError):
    from .fused_moe import (
        RoutingMethodType,
        GatedActType,
        cutlass_fused_moe,
        reorder_rows_for_gated_act_gemm,
        trtllm_fp4_block_scale_moe,
        trtllm_fp4_block_scale_routed_moe,
        trtllm_fp8_block_scale_moe,
        trtllm_fp8_per_tensor_scale_moe,
    )

# ============================================================================
# GEMM operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .gemm import SegmentGEMMWrapper
    from .gemm import bmm_fp8
    from .gemm import mm_fp4
    from .gemm import mm_fp8
    from .gemm import tgv_gemm_sm100

# ============================================================================
# Multi-latent attention (MLA)
# ============================================================================
with contextlib.suppress(ImportError):
    from .mla import BatchMLAPagedAttentionWrapper

# ============================================================================
# Normalization operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .norm import fused_add_rmsnorm
    from .norm import layernorm
    from .norm import gemma_fused_add_rmsnorm
    from .norm import gemma_rmsnorm
    from .norm import rmsnorm

# ============================================================================
# Paged KV cache operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .page import append_paged_kv_cache
    from .page import append_paged_mla_kv_cache
    from .page import get_batch_indices_positions
    from .page import get_seq_lens

# ============================================================================
# POD (Persistent Output Decoding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .pod import PODWithPagedKVCacheWrapper

# ============================================================================
# Prefill operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .prefill import (
        BatchPrefillWithPagedKVCacheWrapper,
    )
    from .prefill import (
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from .prefill import single_prefill_with_kv_cache
    from .prefill import (
        single_prefill_with_kv_cache_return_lse,
    )

# ============================================================================
# Quantization utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .quantization import packbits
    from .quantization import segment_packbits

# ============================================================================
# RoPE (Rotary Position Embedding)
# ============================================================================
with contextlib.suppress(ImportError):
    from .rope import apply_llama31_rope
    from .rope import apply_llama31_rope_inplace
    from .rope import apply_llama31_rope_pos_ids
    from .rope import (
        apply_llama31_rope_pos_ids_inplace,
    )
    from .rope import apply_rope
    from .rope import apply_rope_inplace
    from .rope import apply_rope_pos_ids
    from .rope import apply_rope_pos_ids_inplace
    from .rope import apply_rope_with_cos_sin_cache
    from .rope import (
        apply_rope_with_cos_sin_cache_inplace,
    )

# ============================================================================
# Sampling operations
# ============================================================================
with contextlib.suppress(ImportError):
    from .sampling import chain_speculative_sampling
    from .sampling import min_p_sampling_from_probs
    from .sampling import sampling_from_logits
    from .sampling import sampling_from_probs
    from .sampling import softmax
    from .sampling import top_k_mask_logits
    from .sampling import top_k_renorm_probs
    from .sampling import top_k_sampling_from_probs
    from .sampling import (
        top_k_top_p_sampling_from_logits,
    )
    from .sampling import (
        top_k_top_p_sampling_from_probs,
    )
    from .sampling import top_p_renorm_probs
    from .sampling import top_p_sampling_from_probs

# ============================================================================
# Sparse attention
# ============================================================================
with contextlib.suppress(ImportError):
    from .sparse import BlockSparseAttentionWrapper
    from .sparse import (
        VariableBlockSparseAttentionWrapper,
    )

# ============================================================================
# TRT-LLM low-latency GEMM
# ============================================================================
with contextlib.suppress(ImportError):
    from .trtllm_low_latency_gemm import (
        prepare_low_latency_gemm_weights,
    )

# ============================================================================
# Utilities
# ============================================================================
with contextlib.suppress(ImportError):
    from .utils import next_positive_power_of_2

# ============================================================================
# XQA (Cross-Query Attention)
# ============================================================================
with contextlib.suppress(ImportError):
    from .xqa import xqa
    from .xqa import xqa_mla
```
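One caveat worth noting alongside this suggestion (not raised in the review itself): the `from .x import y as y` spelling is the conventional way to mark an explicit re-export for type checkers, so dropping the aliases can change how strict tools treat the package's public API. A small illustration, using names that already appear in the diff:

```python
# Explicit re-export: under mypy --no-implicit-reexport (enabled by --strict)
# and comparable pyright settings, the aliased form marks `softmax` as an
# intentional public re-export of the package.
from .sampling import softmax as softmax

# Implicit import: with the same settings, a plain import may be treated as a
# private name rather than part of the package's exported API.
from .sampling import top_k_mask_logits
```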
```python
with contextlib.suppress(ImportError):
    from .activation import gelu_and_mul as gelu_and_mul
    from .activation import gelu_tanh_and_mul as gelu_tanh_and_mul
    from .activation import silu_and_mul as silu_and_mul
    from .activation import (
        silu_and_mul_scaled_nvfp4_experts_quantize as silu_and_mul_scaled_nvfp4_experts_quantize,
    )
```
Don't silently swallow real ImportErrors
Blanket contextlib.suppress(ImportError) means any ImportError raised while importing .activation (including missing transitive deps or a bad refactor) now disappears, import flashinfer succeeds, and callers only discover the breakage later via an unhelpful AttributeError. Please differentiate "module truly not packaged" from "module broken": e.g., probe importlib.util.find_spec and only skip when the spec is absent, or cache the caught exception and surface it via __getattr__ so users still see the real failure instead of a silent drop in exports.
Prompt for AI Agents
In flashinfer/__init__.py around lines 30 to 36, the blanket
contextlib.suppress(ImportError) hides real import errors from .activation;
change the logic to first check
importlib.util.find_spec("flashinfer.activation") and only skip imports if the
spec is None (module not packaged), otherwise attempt the import and on
ImportError cache the caught exception in a module-level variable (e.g.,
_activation_import_error) and implement __getattr__ to raise that cached
exception when consumers try to access the missing attributes; this preserves
silent skipping for genuinely absent optional modules but surfaces real
ImportErrors to callers.
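A minimal sketch of that suggestion (hypothetical code, not part of this PR; the cache name `_activation_import_error` comes from the prompt above and is only illustrative):

```python
# Hypothetical sketch inside flashinfer/__init__.py: skip the optional submodule
# only when it is genuinely absent, otherwise keep the real ImportError around.
import importlib.util

_activation_import_error = None  # illustrative module-level cache

if importlib.util.find_spec("flashinfer.activation") is not None:
    try:
        from .activation import silu_and_mul as silu_and_mul
    except ImportError as exc:  # module is packaged but broken
        _activation_import_error = exc
# else: the submodule was not packaged at all, so silently skip its exports


def __getattr__(name):  # PEP 562 module-level __getattr__
    # Surface the original failure instead of a bare AttributeError.
    if name == "silu_and_mul" and _activation_import_error is not None:
        raise ImportError(
            "flashinfer.activation is present but failed to import"
        ) from _activation_import_error
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```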
LGTM. What was the main use case for the change?
@nvmbreughe We have use cases that only need a subset of flashinfer, e.g. sampling. We can make an additional packaging target that includes only the relevant files from flashinfer/, and the callsite can keep importing flashinfer through the same unmodified __init__.py.
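For example (a hypothetical trimmed distribution; the exact packaging target is not specified in this thread), the guarded imports make the following behavior possible:

```python
# Hypothetical scenario: a build that packages only flashinfer/__init__.py and
# flashinfer/sampling.py (plus its dependencies), leaving out e.g. flashinfer/gemm.py.
import flashinfer

# Present, because the sampling submodule was packaged and imported successfully.
print(hasattr(flashinfer, "top_p_sampling_from_probs"))  # True

# Absent, because the gemm submodule was left out; its guarded import block was
# skipped instead of making `import flashinfer` fail.
print(hasattr(flashinfer, "bmm_fp8"))  # False
```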
Description
Related Issues
Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.
Tests
- Tests have been added or updated as needed (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit