                          GenerationConfig)

from vllm import LLM, SamplingParams
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec)

from ....utils import multi_gpu_test

@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
        raise


+def get_rope_layers_config(model_path: str) -> list[int]:
+    """
+    Get the interleaved RoPE configuration from the HuggingFace config.
+
+    Args:
+        model_path: Path to the local directory containing the reduced
+            Maverick model checkpoint
+
+    Returns:
+        List with one entry per layer: 0 means the layer does not use RoPE
+        (full attention), 1 means it uses RoPE with chunked local attention.
+    """
+    config_path = Path(model_path) / "config.json"
+    model_config = json.loads(config_path.read_text())
+    text_config = model_config["text_config"]
+    no_rope_layers = text_config["no_rope_layers"]
+    print(f"Found no_rope_layers: {no_rope_layers}")
+    return no_rope_layers
+
+
def create_reduced_maverick_model(
    original_model_name:
    str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
    print("Loading original model configuration...")
    original_config = AutoConfig.from_pretrained(original_model_name,
                                                 trust_remote_code=True)
-
    print("Creating reduced configuration...")
    reduced_config = create_reduced_config(original_config, text_layers,
                                           num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
          f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")


-def run_reduced_model(model_path: str,
-                      should_profile: bool = False,
-                      **kwargs) -> None:
-    """Test the created reduced model with vLLM."""
-
-    print(f"\nTesting reduced model at {model_path}...")
-
-    llm = LLM(
-        model=model_path,
-        trust_remote_code=True,
-        max_model_len=512,  # Small context for testing
-        gpu_memory_utilization=0.3,  # Conservative memory usage
-        **kwargs,
+def check_attention_spec_interleaved_rope(
+    llm: LLM,
+    num_attention_layers: int,
+    num_ranks: int,
+    rope_layers: list[int],
+):
+    """Check that each attention layer gets the expected KV cache spec."""
+    assert isinstance(llm.llm_engine.model_executor, Executor)
+    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
    )
-
+    for rank in range(num_ranks):
+        kv_cache_specs = kv_cache_specs_per_rank[rank]
+        assert len(kv_cache_specs.keys()) == num_attention_layers
+        for i in range(num_attention_layers):
+            if rope_layers[i] == 0:
+                expected_spec = FullAttentionSpec
+            else:
+                expected_spec = ChunkedLocalAttentionSpec
+            assert isinstance(
+                kv_cache_specs[
+                    f"language_model.model.layers.{i}.self_attn.attn"],
+                expected_spec)
+
+
+def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
+    """Test the created reduced model with vLLM."""
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=50)
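The assertion above encodes a simple per-layer rule: a 0 in no_rope_layers maps to FullAttentionSpec, a 1 to ChunkedLocalAttentionSpec. A minimal sketch of that mapping in isolation, with hypothetical flag values:

from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                        FullAttentionSpec)

no_rope_layers = [0, 1, 1, 0]  # hypothetical per-layer flags
expected_specs = [
    FullAttentionSpec if flag == 0 else ChunkedLocalAttentionSpec
    for flag in no_rope_layers
]
# Layer 0 uses full (global) attention; layers 1 and 2 use chunked local attention.
assert expected_specs[0] is FullAttentionSpec
assert expected_specs[1] is ChunkedLocalAttentionSpec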
@@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
@pytest.mark.parametrize("tp,ep", [(2, True)])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dummy_maverick(
+    monkeypatch,
    original_model_name: str,
    text_layers: int,
    num_experts: int,
@@ -562,6 +596,10 @@ def test_dummy_maverick(
    force_recreate: bool = True,
    profile: bool = False,
) -> None:
+    # Disabling multiprocessing lets us access the model executor from the LLM engine
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
    model_path = create_reduced_maverick_model(
        original_model_name=original_model_name,
        output_dir=output_dir,
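These environment variables matter because, with V1 multiprocessing disabled, the engine runs in-process and its executor stays reachable from the returned LLM object. A rough sketch outside pytest, where the model name is only an example of a small checkpoint and not part of this test:

import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
# Per-rank mapping of attention layer name -> KV cache spec, as used by
# check_attention_spec_interleaved_rope above.
kv_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs()
print(list(kv_specs_per_rank[0].keys())[:3])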
@@ -573,11 +611,27 @@ def test_dummy_maverick(

    print(f"\nReduced model created successfully at: {model_path}")

-    run_reduced_model(model_path=model_path,
-                      should_profile=profile,
-                      enforce_eager=enforce_eager,
-                      tensor_parallel_size=tp,
-                      enable_expert_parallel=ep)
+    rope_layers = get_rope_layers_config(model_path)
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        enforce_eager=enforce_eager,
+        tensor_parallel_size=tp,
+        enable_expert_parallel=ep,
+    )
+
+    check_attention_spec_interleaved_rope(
+        llm,
+        text_layers,
+        tp,
+        rope_layers,
+    )
+
+    print(f"\nTesting reduced model at {model_path}...")
+    run_reduced_model(llm=llm, should_profile=profile)


def main():