diff --git a/src/llmcompressor/recipe/recipe.py b/src/llmcompressor/recipe/recipe.py index 2a4c1e4c5..8c6f025b8 100644 --- a/src/llmcompressor/recipe/recipe.py +++ b/src/llmcompressor/recipe/recipe.py @@ -35,6 +35,12 @@ class Recipe(RecipeBase): when serializing a recipe, yaml will be used by default. """ + version: Optional[str] = Field(default=None) + args: RecipeArgs = Field(default_factory=RecipeArgs) + stages: List[RecipeStage] = Field(default_factory=list) + metadata: Optional[RecipeMetaData] = Field(default=None) + args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs) + @classmethod def from_modifiers( cls, @@ -280,12 +286,6 @@ def simplify_combine_recipes( return combined - version: str = None - args: RecipeArgs = Field(default_factory=RecipeArgs) - stages: List[RecipeStage] = Field(default_factory=list) - metadata: RecipeMetaData = None - args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs) - def calculate_start(self) -> int: """ Calculate and return the start epoch of the recipe. @@ -399,11 +399,12 @@ def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]: formatted_values["stages"] = stages # fill out any default argument values - args = {} + args = {**values.pop("args", {})} for key, val in values.items(): - args[key] = val + # avoid nesting the args in the recipe + if key not in cls.__pydantic_fields__: + args[key] = val formatted_values["args"] = RecipeArgs(args) - return formatted_values @staticmethod @@ -504,52 +505,62 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]): else: self.metadata.update_missing_metadata(metadata) - def dict(self, *args, **kwargs) -> Dict[str, Any]: - """ - :return: A dictionary representation of the recipe + def model_dump(self, *args, **kwargs) -> Dict[str, Any]: """ - dict_ = super().model_dump(*args, **kwargs) - stages = {} - - for stage in dict_["stages"]: - name = f"{stage['group']}_stage" - del stage["group"] + Generate a serializable dictionary representation of this recipe. - if name not in stages: - stages[name] = [] + This method transforms the internal recipe structure into a format + suitable for YAML serialization while preserving all necessary + information for round-trip deserialization. - stages[name].append(stage) + :param args: Additional positional arguments for parent method + :param kwargs: Additional keyword arguments for parent method + :return: Dictionary ready for YAML serialization + """ + # Retrieve base representation from parent class + raw_dict = super().model_dump(*args, **kwargs) + + # Initialize clean output dictionary + serializable_dict = {} + + # Copy recipe metadata attributes + metadata_keys = ["version", "args", "metadata"] + for key in metadata_keys: + if value := raw_dict.get(key): + serializable_dict[key] = value + + # Process and organize stages by group + if "stages" in raw_dict: + # Group stages by their type (e.g., "train", "eval") + grouped_stages = {} + for stage in raw_dict["stages"]: + group_id = ( + f"{stage.pop('group')}_stage" # Remove group field and use as key + ) - dict_["stages"] = stages + if group_id not in grouped_stages: + grouped_stages[group_id] = [] - return dict_ + grouped_stages[group_id].append(stage) - def model_dump(self, *args, **kwargs) -> Dict[str, Any]: - """ - Override the model_dump method to provide a dictionary representation that - is compatible with model_validate. + # Format each stage for YAML output + for group_id, stages in grouped_stages.items(): + for idx, stage_data in enumerate(stages): + # Create unique identifiers for multiple stages of same type + final_id = f"{group_id}_{idx}" if len(stages) > 1 else group_id - Unlike the standard model_dump, this transforms the stages list to a format - expected by the validation logic, ensuring round-trip compatibility with - model_validate. + # Create clean stage representation + stage_yaml = get_yaml_serializable_stage_dict( + modifiers=stage_data["modifiers"] + ) - :return: A dictionary representation of the recipe compatible with - model_validate - """ - # Get the base dictionary from parent class - base_dict = super().model_dump(*args, **kwargs) + # Preserve run type if specified + if run_type := stage_data.get("run_type"): + stage_yaml["run_type"] = run_type - # Transform stages into the expected format - if "stages" in base_dict: - stages_dict = {} - for stage in base_dict["stages"]: - group = stage["group"] - if group not in stages_dict: - stages_dict[group] = [] - stages_dict[group].append(stage) - base_dict["stages"] = stages_dict + serializable_dict[final_id] = stage_yaml - return base_dict + return serializable_dict def yaml(self, file_path: Optional[str] = None) -> str: """ @@ -559,10 +570,9 @@ def yaml(self, file_path: Optional[str] = None) -> str: :return: The yaml string representation of the recipe """ file_stream = None if file_path is None else open(file_path, "w") - yaml_dict = self._get_yaml_dict() ret = yaml.dump( - yaml_dict, + self.model_dump(), stream=file_stream, allow_unicode=True, sort_keys=False, @@ -575,47 +585,6 @@ def yaml(self, file_path: Optional[str] = None) -> str: return ret - def _get_yaml_dict(self) -> Dict[str, Any]: - """ - Get a dictionary representation of the recipe for yaml serialization - The returned dict will only contain information necessary for yaml - serialization and must not be used in place of the dict method - - :return: A dictionary representation of the recipe for yaml serialization - """ - - original_recipe_dict = self.dict() - yaml_recipe_dict = {} - - # populate recipe level attributes - recipe_level_attributes = ["version", "args", "metadata"] - - for attribute in recipe_level_attributes: - if attribute_value := original_recipe_dict.get(attribute): - yaml_recipe_dict[attribute] = attribute_value - - # populate stages - stages = original_recipe_dict["stages"] - for stage_name, stage_list in stages.items(): - for idx, stage in enumerate(stage_list): - if len(stage_list) > 1: - # resolve name clashes caused by combining recipes with - # duplicate stage names - final_stage_name = f"{stage_name}_{idx}" - else: - final_stage_name = stage_name - stage_dict = get_yaml_serializable_stage_dict( - modifiers=stage["modifiers"] - ) - - # infer run_type from stage - if run_type := stage.get("run_type"): - stage_dict["run_type"] = run_type - - yaml_recipe_dict[final_stage_name] = stage_dict - - return yaml_recipe_dict - RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]] RecipeStageInput = Union[str, List[str], List[List[str]]] diff --git a/tests/llmcompressor/helpers.py b/tests/llmcompressor/helpers.py index 2204b708e..7c33fa3df 100644 --- a/tests/llmcompressor/helpers.py +++ b/tests/llmcompressor/helpers.py @@ -65,4 +65,25 @@ def valid_recipe_strings(): [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] ] """, + """ + version: 1.0 + args: + learning_rate: 0.001 + train_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 0.0 + end: 2.0 + targets: ['re:.*weight'] + quantization_modifiers: + QuantizationModifier: + bits: 8 + targets: ['re:.*weight'] + eval_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 2.0 + end: 4.0 + targets: ['re:.*weight'] + """, ] diff --git a/tests/llmcompressor/recipe/test_recipe.py b/tests/llmcompressor/recipe/test_recipe.py index 848b0bc6a..550a707a0 100644 --- a/tests/llmcompressor/recipe/test_recipe.py +++ b/tests/llmcompressor/recipe/test_recipe.py @@ -1,98 +1,185 @@ +import os import tempfile import pytest import yaml from llmcompressor.modifiers.obcq.base import SparseGPTModifier +from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier from llmcompressor.recipe import Recipe from tests.llmcompressor.helpers import valid_recipe_strings @pytest.mark.parametrize("recipe_str", valid_recipe_strings()) -def test_recipe_create_instance_accepts_valid_recipe_string(recipe_str): - recipe = Recipe.create_instance(recipe_str) - assert recipe is not None, "Recipe could not be created from string" +class TestRecipeWithStrings: + """Tests that use various recipe strings for validation.""" + + def test_create_from_string(self, recipe_str): + """Test creating a Recipe from a YAML string.""" + recipe = Recipe.create_instance(recipe_str) + assert recipe is not None, "Recipe could not be created from string" + assert isinstance(recipe, Recipe), "Created object is not a Recipe instance" + + def test_create_from_file(self, recipe_str): + """Test creating a Recipe from a YAML file.""" + content = yaml.safe_load(recipe_str) + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + yaml.dump(content, f) + f.flush() # Ensure content is written + recipe = Recipe.create_instance(f.name) + assert recipe is not None, "Recipe could not be created from file" + assert isinstance(recipe, Recipe), "Created object is not a Recipe instance" + + def test_yaml_serialization_roundtrip(self, recipe_str): + """ + Test that a recipe can be serialized to YAML + and deserialized back with all properties preserved. + """ + # Create original recipe + original_recipe = Recipe.create_instance(recipe_str) + # Serialize to YAML + yaml_str = original_recipe.yaml() + assert yaml_str, "Serialized YAML string should not be empty" -@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) -def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str): - content = yaml.safe_load(recipe_str) - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: - yaml.dump(content, f) - recipe = Recipe.create_instance(f.name) - assert recipe is not None, "Recipe could not be created from file" + # Deserialize from YAML + deserialized_recipe = Recipe.create_instance(yaml_str) + # Compare serialized forms + original_dict = original_recipe.model_dump() + deserialized_dict = deserialized_recipe.model_dump() -@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) -def test_serialization(recipe_str): - recipe_instance = Recipe.create_instance(recipe_str) - serialized_recipe = recipe_instance.yaml() - recipe_from_serialized = Recipe.create_instance(serialized_recipe) + assert original_dict == deserialized_dict, "Serialization roundtrip failed" - expected_dict = recipe_instance.dict() - actual_dict = recipe_from_serialized.dict() + def test_model_dump_and_validate(self, recipe_str): + """ + Test that model_dump produces a format compatible + with model_validate. + """ + recipe = Recipe.create_instance(recipe_str) + validated_recipe = Recipe.model_validate(recipe.model_dump()) + assert ( + recipe == validated_recipe + ), "Recipe instance and validated recipe do not match" - assert expected_dict == actual_dict +class TestRecipeSerialization: + """ + Tests for Recipe serialization and deserialization + edge cases.""" -def test_recipe_creates_correct_modifier(): - start = 1 - end = 10 - targets = "__ALL_PRUNABLE__" + def test_empty_recipe_serialization(self): + """Test serialization of a minimal recipe with no stages.""" + recipe = Recipe() + assert len(recipe.stages) == 0, "New recipe should have no stages" - yaml_str = f""" - test_stage: - pruning_modifiers: - ConstantPruningModifier: - start: {start} - end: {end} - targets: {targets} - """ + # Test roundtrip serialization + dumped = recipe.model_dump() + loaded = Recipe.model_validate(dumped) + assert recipe == loaded, "Empty recipe serialization failed" + + def test_file_serialization(self): + """Test serializing a recipe to a file and reading it back.""" + recipe = Recipe.create_instance(valid_recipe_strings()[0]) - recipe_instance = Recipe.create_instance(yaml_str) - - stage_modifiers = recipe_instance.create_modifier() - assert len(stage_modifiers) == 1 - assert len(modifiers := stage_modifiers[0].modifiers) == 1 - from llmcompressor.modifiers.pruning.constant import ConstantPruningModifier - - assert isinstance(modifier := modifiers[0], ConstantPruningModifier) - assert modifier.start == start - assert modifier.end == end - - -def test_recipe_can_be_created_from_modifier_instances(): - modifier = SparseGPTModifier( - sparsity=0.5, - ) - group_name = "dummy" - - # for pep8 compliance - recipe_str = ( - f"{group_name}_stage:\n" - " pruning_modifiers:\n" - " SparseGPTModifier:\n" - " sparsity: 0.5\n" - ) - - expected_recipe_instance = Recipe.create_instance(recipe_str) - expected_modifiers = expected_recipe_instance.create_modifier() - - actual_recipe_instance = Recipe.create_instance( - [modifier], modifier_group_name=group_name - ) - actual_modifiers = actual_recipe_instance.create_modifier() - - # assert num stages is the same - assert len(actual_modifiers) == len(expected_modifiers) - - # assert num modifiers in each stage is the same - assert len(actual_modifiers[0].modifiers) == len(expected_modifiers[0].modifiers) - - # assert modifiers in each stage are the same type - # and have the same parameters - for actual_modifier, expected_modifier in zip( - actual_modifiers[0].modifiers, expected_modifiers[0].modifiers - ): - assert isinstance(actual_modifier, type(expected_modifier)) - assert actual_modifier.model_dump() == expected_modifier.model_dump() + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, "recipe.yaml") + + # Write to file + recipe.yaml(file_path=file_path) + assert os.path.exists(file_path), "YAML file was not created" + assert os.path.getsize(file_path) > 0, "YAML file is empty" + + # Read back from file + loaded_recipe = Recipe.create_instance(file_path) + assert ( + recipe == loaded_recipe + ), "Recipe loaded from file doesn't match original" + + +class TestRecipeModifiers: + """Tests for creating and working with modifiers in recipes.""" + + def test_creates_correct_modifier(self): + """ + Test that a recipe creates the expected modifier type + with correct parameters. + """ + # Recipe parameters + params = {"start": 1, "end": 10, "targets": "__ALL_PRUNABLE__"} + + # Create recipe from YAML + yaml_str = f""" + test_stage: + pruning_modifiers: + ConstantPruningModifier: + start: {params['start']} + end: {params['end']} + targets: {params['targets']} + """ + recipe = Recipe.create_instance(yaml_str) + + # Get modifiers from recipe + stage_modifiers = recipe.create_modifier() + assert len(stage_modifiers) == 1, "Expected exactly one stage modifier" + + modifiers = stage_modifiers[0].modifiers + assert len(modifiers) == 1, "Expected exactly one modifier in the stage" + + # Verify modifier type and parameters + modifier = modifiers[0] + assert isinstance( + modifier, ConstantPruningModifier + ), "Wrong modifier type created" + assert modifier.start == params["start"], "Modifier start value incorrect" + assert modifier.end == params["end"], "Modifier end value incorrect" + assert modifier.targets == params["targets"], "Modifier targets incorrect" + + def test_create_from_modifier_instances(self): + """Test creating a recipe from modifier instances.""" + # Create a modifier instance + sparsity_value = 0.5 + modifier = SparseGPTModifier(sparsity=sparsity_value) + group_name = "dummy" + + # Expected YAML representation + recipe_str = ( + f"{group_name}_stage:\n" + " pruning_modifiers:\n" + " SparseGPTModifier:\n" + f" sparsity: {sparsity_value}\n" + ) + + # Create recipes for comparison + expected_recipe = Recipe.create_instance(recipe_str) + actual_recipe = Recipe.create_instance( + [modifier], modifier_group_name=group_name + ) + + # Compare recipes by creating and checking their modifiers + self._compare_recipe_modifiers(actual_recipe, expected_recipe) + + def _compare_recipe_modifiers(self, actual_recipe, expected_recipe): + """Helper method to compare modifiers created from two recipes.""" + actual_modifiers = actual_recipe.create_modifier() + expected_modifiers = expected_recipe.create_modifier() + + # Compare stage counts + assert len(actual_modifiers) == len(expected_modifiers), "Stage counts differ" + + if not actual_modifiers: + return # No modifiers to compare + + # Compare modifier counts in each stage + assert len(actual_modifiers[0].modifiers) == len( + expected_modifiers[0].modifiers + ), "Modifier counts differ" + + # Compare modifier types and parameters + for actual_mod, expected_mod in zip( + actual_modifiers[0].modifiers, expected_modifiers[0].modifiers + ): + assert isinstance(actual_mod, type(expected_mod)), "Modifier types differ" + assert ( + actual_mod.model_dump() == expected_mod.model_dump() + ), "Modifier parameters differ" diff --git a/tests/recipe/test_recipe.py b/tests/recipe/test_recipe.py deleted file mode 100644 index 7675162f4..000000000 --- a/tests/recipe/test_recipe.py +++ /dev/null @@ -1,54 +0,0 @@ -from src.llmcompressor.recipe import Recipe - - -def test_recipe_model_dump(): - """Test that model_dump produces a format compatible with model_validate.""" - # Create a recipe with multiple stages and modifiers - recipe_str = """ - version: "1.0" - args: - learning_rate: 0.001 - train_stage: - pruning_modifiers: - ConstantPruningModifier: - start: 0.0 - end: 2.0 - targets: ['re:.*weight'] - quantization_modifiers: - QuantizationModifier: - bits: 8 - targets: ['re:.*weight'] - eval_stage: - pruning_modifiers: - ConstantPruningModifier: - start: 2.0 - end: 4.0 - targets: ['re:.*weight'] - """ - - # Create recipe instance - recipe = Recipe.create_instance(recipe_str) - - # Get dictionary representation - recipe_dict = recipe.model_dump() - - # Verify the structure is compatible with model_validate - # by creating a new recipe from the dictionary - new_recipe = Recipe.model_validate(recipe_dict) - - # Verify version and args are preserved - assert new_recipe.version == recipe.version - assert new_recipe.args == recipe.args - - # Verify stages are preserved - assert len(new_recipe.stages) == len(recipe.stages) - - # Verify stage names and modifiers are preserved - for new_stage, orig_stage in zip(new_recipe.stages, recipe.stages): - assert new_stage.group == orig_stage.group - assert len(new_stage.modifiers) == len(orig_stage.modifiers) - - # Verify modifier types and args are preserved - for new_mod, orig_mod in zip(new_stage.modifiers, orig_stage.modifiers): - assert new_mod.type == orig_mod.type - assert new_mod.args == orig_mod.args