Skip to content

Commit 655a7e9

Browse files
committed
Model Dump Fixes
Signed-off-by: Rahul Tuli <[email protected]>
1 parent: cff6171 · commit: 655a7e9

File tree

4 files changed: 65 additions (+65) and 128 deletions (-128).

src/llmcompressor/recipe/recipe.py

Lines changed: 32 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -280,10 +280,10 @@ def simplify_combine_recipes(
280280

281281
return combined
282282

283-
version: str = None
283+
version: Optional[str] = None
284284
args: RecipeArgs = Field(default_factory=RecipeArgs)
285285
stages: List[RecipeStage] = Field(default_factory=list)
286-
metadata: RecipeMetaData = None
286+
metadata: Optional[RecipeMetaData] = None
287287
args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs)
288288

289289
def calculate_start(self) -> int:
@@ -399,11 +399,12 @@ def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]:
399399
formatted_values["stages"] = stages
400400

401401
# fill out any default argument values
402-
args = {}
402+
args = {**values.pop("args", {})}
403403
for key, val in values.items():
404-
args[key] = val
404+
# avoid nesting the args in the recipe
405+
if key not in cls.__pydantic_fields__:
406+
args[key] = val
405407
formatted_values["args"] = RecipeArgs(args)
406-
407408
return formatted_values
408409

409410
@staticmethod
@@ -504,7 +505,7 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]):
504505
else:
505506
self.metadata.update_missing_metadata(metadata)
506507

507-
def dict(self, *args, **kwargs) -> Dict[str, Any]:
508+
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
508509
"""
509510
:return: A dictionary representation of the recipe
510511
"""
@@ -522,80 +523,17 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
522523

523524
dict_["stages"] = stages
524525

525-
return dict_
526-
527-
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
528-
"""
529-
Override the model_dump method to provide a dictionary representation that
530-
is compatible with model_validate.
531-
532-
Unlike the standard model_dump, this transforms the stages list to a format
533-
expected by the validation logic, ensuring round-trip compatibility with
534-
model_validate.
535-
536-
:return: A dictionary representation of the recipe compatible with
537-
model_validate
538-
"""
539-
# Get the base dictionary from parent class
540-
base_dict = super().model_dump(*args, **kwargs)
541-
542-
# Transform stages into the expected format
543-
if "stages" in base_dict:
544-
stages_dict = {}
545-
for stage in base_dict["stages"]:
546-
group = stage["group"]
547-
if group not in stages_dict:
548-
stages_dict[group] = []
549-
stages_dict[group].append(stage)
550-
base_dict["stages"] = stages_dict
551-
552-
return base_dict
553-
554-
def yaml(self, file_path: Optional[str] = None) -> str:
555-
"""
556-
Return a yaml string representation of the recipe.
557-
558-
:param file_path: optional file path to save yaml to
559-
:return: The yaml string representation of the recipe
560-
"""
561-
file_stream = None if file_path is None else open(file_path, "w")
562-
yaml_dict = self._get_yaml_dict()
563-
564-
ret = yaml.dump(
565-
yaml_dict,
566-
stream=file_stream,
567-
allow_unicode=True,
568-
sort_keys=False,
569-
default_flow_style=None,
570-
width=88,
571-
)
572-
573-
if file_stream is not None:
574-
file_stream.close()
575-
576-
return ret
577-
578-
def _get_yaml_dict(self) -> Dict[str, Any]:
579-
"""
580-
Get a dictionary representation of the recipe for yaml serialization
581-
The returned dict will only contain information necessary for yaml
582-
serialization and must not be used in place of the dict method
583-
584-
:return: A dictionary representation of the recipe for yaml serialization
585-
"""
586-
587-
original_recipe_dict = self.dict()
588526
yaml_recipe_dict = {}
589527

590528
# populate recipe level attributes
591529
recipe_level_attributes = ["version", "args", "metadata"]
592530

593531
for attribute in recipe_level_attributes:
594-
if attribute_value := original_recipe_dict.get(attribute):
532+
if attribute_value := dict_.get(attribute):
595533
yaml_recipe_dict[attribute] = attribute_value
596534

597535
# populate stages
598-
stages = original_recipe_dict["stages"]
536+
stages = dict_.pop("stages", {})
599537
for stage_name, stage_list in stages.items():
600538
for idx, stage in enumerate(stage_list):
601539
if len(stage_list) > 1:
@@ -616,6 +554,29 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
616554

617555
return yaml_recipe_dict
618556

557+
def yaml(self, file_path: Optional[str] = None) -> str:
558+
"""
559+
Return a yaml string representation of the recipe.
560+
561+
:param file_path: optional file path to save yaml to
562+
:return: The yaml string representation of the recipe
563+
"""
564+
file_stream = None if file_path is None else open(file_path, "w")
565+
566+
ret = yaml.dump(
567+
self.model_dump(),
568+
stream=file_stream,
569+
allow_unicode=True,
570+
sort_keys=False,
571+
default_flow_style=None,
572+
width=88,
573+
)
574+
575+
if file_stream is not None:
576+
file_stream.close()
577+
578+
return ret
579+
619580

620581
RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]]
621582
RecipeStageInput = Union[str, List[str], List[List[str]]]

tests/llmcompressor/helpers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,25 @@ def valid_recipe_strings():
6565
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
6666
]
6767
""",
68+
"""
69+
version: 1.0
70+
args:
71+
learning_rate: 0.001
72+
train_stage:
73+
pruning_modifiers:
74+
ConstantPruningModifier:
75+
start: 0.0
76+
end: 2.0
77+
targets: ['re:.*weight']
78+
quantization_modifiers:
79+
QuantizationModifier:
80+
bits: 8
81+
targets: ['re:.*weight']
82+
eval_stage:
83+
pruning_modifiers:
84+
ConstantPruningModifier:
85+
start: 2.0
86+
end: 4.0
87+
targets: ['re:.*weight']
88+
""",
6889
]

tests/llmcompressor/recipe/test_recipe.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,26 @@ def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str):
2424

2525

2626
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
27-
def test_serialization(recipe_str):
27+
def test_yaml_serialization(recipe_str):
2828
recipe_instance = Recipe.create_instance(recipe_str)
2929
serialized_recipe = recipe_instance.yaml()
3030
recipe_from_serialized = Recipe.create_instance(serialized_recipe)
3131

32-
expected_dict = recipe_instance.dict()
33-
actual_dict = recipe_from_serialized.dict()
32+
expected_dict = recipe_instance.model_dump()
33+
actual_dict = recipe_from_serialized.model_dump()
3434

3535
assert expected_dict == actual_dict
3636

3737

38+
@pytest.mark.parametrize("recipe_str", valid_recipe_strings())
39+
def test_model_dump_and_validate(recipe_str):
40+
recipe_instance = Recipe.create_instance(recipe_str)
41+
validated_recipe = Recipe.model_validate(recipe_instance.model_dump())
42+
assert (
43+
recipe_instance == validated_recipe
44+
), "Recipe instance and validated recipe do not match"
45+
46+
3847
def test_recipe_creates_correct_modifier():
3948
start = 1
4049
end = 10

tests/recipe/test_recipe.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)