94 changes: 74 additions & 20 deletions tests/models/multimodal/generation/test_maverick.py
@@ -22,6 +22,9 @@
                          GenerationConfig)

from vllm import LLM, SamplingParams
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                        FullAttentionSpec)

from ....utils import multi_gpu_test

@@ -69,6 +72,26 @@ def run_maverick_serving(model: str):
        raise


def get_rope_layers_config(model_path: str) -> list[int]:
    """
    Get the interleaved RoPE configuration from the HuggingFace config.

    Args:
        model_path: Path to the local directory containing the reduced
            Maverick model checkpoint

    Returns:
        List with one entry per layer: 1 if the layer uses RoPE (and hence
        chunked local attention), 0 if it is a NoPE layer (full attention).
    """
    config_path = Path(model_path) / "config.json"
    model_config = json.loads(config_path.read_text())
    text_config = model_config["text_config"]
    no_rope_layers = text_config["no_rope_layers"]
    print(f"Found no_rope_layers: {no_rope_layers}")
    return no_rope_layers
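
# Illustrative sketch only (hypothetical values, not read from a real
# checkpoint): with Llama-4-style interleaved RoPE where every 4th layer is a
# NoPE layer, an 8-layer text config would carry
#
#   "text_config": {"no_rope_layers": [1, 1, 1, 0, 1, 1, 1, 0]}
#
# and get_rope_layers_config() would return that list, marking layers 3 and 7
# as NoPE (full-attention) layers.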


def create_reduced_maverick_model(
        original_model_name:
        str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
@@ -113,7 +136,6 @@ def create_reduced_maverick_model(
    print("Loading original model configuration...")
    original_config = AutoConfig.from_pretrained(original_model_name,
                                                 trust_remote_code=True)

    print("Creating reduced configuration...")
    reduced_config = create_reduced_config(original_config, text_layers,
                                           num_experts, vision_layers)
@@ -510,21 +532,32 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
          f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")


def run_reduced_model(model_path: str,
                      should_profile: bool = False,
                      **kwargs) -> None:
    """Test the created reduced model with vLLM."""

    print(f"\nTesting reduced model at {model_path}...")

    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=512,  # Small context for testing
        gpu_memory_utilization=0.3,  # Conservative memory usage
        **kwargs,
def check_attention_spec_interleaved_rope(
    llm: LLM,
    num_attention_layers: int,
    num_ranks: int,
    rope_layers: list[int],
):
    """Check that every attention layer was assigned the expected KV cache
    spec: FullAttentionSpec for NoPE layers (rope_layers[i] == 0) and
    ChunkedLocalAttentionSpec for RoPE layers."""
    assert isinstance(llm.llm_engine.model_executor, Executor)
    kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs(
    )

    for rank in range(num_ranks):
        kv_cache_specs = kv_cache_specs_per_rank[rank]
        assert len(kv_cache_specs.keys()) == num_attention_layers
        for i in range(num_attention_layers):
            if rope_layers[i] == 0:
                expected_spec = FullAttentionSpec
            else:
                expected_spec = ChunkedLocalAttentionSpec
            assert isinstance(
                kv_cache_specs[
                    f"language_model.model.layers.{i}.self_attn.attn"],
                expected_spec)
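
# Hypothetical illustration of the check above: for rope_layers == [1, 1, 1, 0]
# and num_ranks == 2, every rank's KV cache specs are expected to map
#   "language_model.model.layers.0..2.self_attn.attn" -> ChunkedLocalAttentionSpec
#   "language_model.model.layers.3.self_attn.attn"    -> FullAttentionSpec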


def run_reduced_model(llm: LLM, should_profile: bool = False) -> None:
    """Test the created reduced model with vLLM."""
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=50)
@@ -551,6 +584,7 @@
@pytest.mark.parametrize("tp,ep", [(2, True)])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dummy_maverick(
    monkeypatch,
    original_model_name: str,
    text_layers: int,
    num_experts: int,
@@ -562,6 +596,10 @@
    force_recreate: bool = True,
    profile: bool = False,
) -> None:
    # Disabling multiprocessing lets us access the model executor from the
    # LLM engine in-process.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    model_path = create_reduced_maverick_model(
        original_model_name=original_model_name,
        output_dir=output_dir,
@@ -573,11 +611,27 @@

print(f"\nReduced model created successfully at: {model_path}")

run_reduced_model(model_path=model_path,
should_profile=profile,
enforce_eager=enforce_eager,
tensor_parallel_size=tp,
enable_expert_parallel=ep)
rope_layers = get_rope_layers_config(model_path)

llm = LLM(
model=model_path,
trust_remote_code=True,
max_model_len=512, # Small context for testing
gpu_memory_utilization=0.3, # Conservative memory usage
enforce_eager=enforce_eager,
tensor_parallel_size=tp,
enable_expert_parallel=ep,
)

check_attention_spec_interleaved_rope(
llm,
text_layers,
tp,
rope_layers,
)

print(f"\nTesting reduced model at {model_path}...")
run_reduced_model(llm=llm, should_profile=profile)
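
# Sketch of a local invocation (assumes two GPUs, since tp == 2 is the only
# parametrization above):
#   pytest tests/models/multimodal/generation/test_maverick.py -k test_dummy_maverick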


def main():