
Commit 77cb67d

sarckk authored and jinzhen-lin committed
Add interleaved RoPE test for Llama4 (Maverick) (vllm-project#21478)
Signed-off-by: Yong Hoon Shin <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 3da4fd1 commit 77cb67d

1 file changed (+74, −20)

tests/models/multimodal/generation/test_maverick.py

Lines changed: 74 additions & 20 deletions
@@ -22,6 +22,9 @@
                           GenerationConfig)

 from vllm import LLM, SamplingParams
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec)

 from ....utils import multi_gpu_test

@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
         raise


+def get_rope_layers_config(model_path: str) -> list[int]:
+    """
+    Get the interleaved RoPE configuration from the HuggingFace config.
+
+    Args:
+        model_path: Path to the local directory containing the reduced
+            Maverick model checkpoint
+
+    Returns:
+        List of 0s and 1s indicating whether each layer uses RoPE and local
+        attention: 0 means the layer does not use RoPE, 1 means it does.
+    """
+    config_path = Path(model_path) / "config.json"
+    model_config = json.loads(config_path.read_text())
+    text_config = model_config["text_config"]
+    no_rope_layers = text_config["no_rope_layers"]
+    print(f"Found no_rope_layers: {no_rope_layers}")
+    return no_rope_layers
+
+
 def create_reduced_maverick_model(
     original_model_name:
     str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
     print("Loading original model configuration...")
     original_config = AutoConfig.from_pretrained(original_model_name,
                                                  trust_remote_code=True)
-
     print("Creating reduced configuration...")
     reduced_config = create_reduced_config(original_config, text_layers,
                                            num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
           f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")


-def run_reduced_model(model_path: str,
-                      should_profile: bool = False,
-                      **kwargs) -> None:
-    """Test the created reduced model with vLLM."""
-
-    print(f"\nTesting reduced model at {model_path}...")
-
-    llm = LLM(
-        model=model_path,
-        trust_remote_code=True,
-        max_model_len=512,  # Small context for testing
-        gpu_memory_utilization=0.3,  # Conservative memory usage
-        **kwargs,
+def check_attention_spec_interleaved_rope(
+    llm: LLM,
+    num_attention_layers: int,
+    num_ranks: int,
+    rope_layers: list[int],
+):
+    """Check that the attention spec is correct."""
+    assert isinstance(llm.llm_engine.model_executor, Executor)
+    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
     )
-
+    for rank in range(num_ranks):
+        kv_cache_specs = kv_cache_specs_per_rank[rank]
+        assert len(kv_cache_specs.keys()) == num_attention_layers
+        for i in range(num_attention_layers):
+            if rope_layers[i] == 0:
+                expected_spec = FullAttentionSpec
+            else:
+                expected_spec = ChunkedLocalAttentionSpec
+            assert isinstance(
+                kv_cache_specs[
+                    f"language_model.model.layers.{i}.self_attn.attn"],
+                expected_spec)
+
+
+def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
+    """Test the created reduced model with vLLM."""
     sampling_params = SamplingParams(temperature=0.8,
                                      top_p=0.95,
                                      max_tokens=50)
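
check_attention_spec_interleaved_rope asserts that, on every rank, the KV-cache spec registered for each decoder layer matches its RoPE flag: NoPE layers (flag 0) are expected to use FullAttentionSpec (global attention), while RoPE layers (flag 1) are expected to use ChunkedLocalAttentionSpec. A minimal sketch of that expected mapping, using stand-in classes instead of vLLM's real spec types and a hypothetical flag list:

# Stand-ins for vLLM's spec classes, only to visualize the expected mapping.
class FullAttentionSpec: ...
class ChunkedLocalAttentionSpec: ...

no_rope_layers = [1, 1, 1, 0]  # hypothetical per-layer RoPE flags

expected = {
    f"language_model.model.layers.{i}.self_attn.attn":
    (FullAttentionSpec if flag == 0 else ChunkedLocalAttentionSpec)
    for i, flag in enumerate(no_rope_layers)
}
for name, spec in expected.items():
    print(f"{name} -> {spec.__name__}")
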
@@ -551,6 +584,7 @@ def run_reduced_model(model_path: str,
 @pytest.mark.parametrize("tp,ep", [(2, True)])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_dummy_maverick(
+    monkeypatch,
     original_model_name: str,
     text_layers: int,
     num_experts: int,
@@ -562,6 +596,10 @@ def test_dummy_maverick(
     force_recreate: bool = True,
     profile: bool = False,
 ) -> None:
+    # Disabling multiprocessing lets the test access the model executor from the LLM engine
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+
     model_path = create_reduced_maverick_model(
         original_model_name=original_model_name,
         output_dir=output_dir,
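
The two environment variables are set through pytest's monkeypatch fixture, so they are restored automatically when the test finishes. A standalone illustration of that scoping behavior (SOME_FLAG is a hypothetical variable name, unrelated to vLLM):

import os

import pytest


def test_env_var_is_scoped(monkeypatch: pytest.MonkeyPatch) -> None:
    # setenv applies only for the duration of this test...
    monkeypatch.setenv("SOME_FLAG", "0")
    assert os.environ["SOME_FLAG"] == "0"
    # ...and is undone on teardown, so other tests never see SOME_FLAG.
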
@@ -573,11 +611,27 @@ def test_dummy_maverick(

     print(f"\nReduced model created successfully at: {model_path}")

-    run_reduced_model(model_path=model_path,
-                      should_profile=profile,
-                      enforce_eager=enforce_eager,
-                      tensor_parallel_size=tp,
-                      enable_expert_parallel=ep)
+    rope_layers = get_rope_layers_config(model_path)
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        enforce_eager=enforce_eager,
+        tensor_parallel_size=tp,
+        enable_expert_parallel=ep,
+    )
+
+    check_attention_spec_interleaved_rope(
+        llm,
+        text_layers,
+        tp,
+        rope_layers,
+    )
+
+    print(f"\nTesting reduced model at {model_path}...")
+    run_reduced_model(llm=llm, should_profile=profile)


 def main():
