Skip to content

Commit 80a0449

Browse files
authored
Added more tests for Quantization24SparseW4A16 (#1434)
SUMMARY: Added following validations for each stage: 1. Check if recipe.yaml file exists 2. Check if config.json file exists 3. Check if `quantization_config` field is present in config.json 4. Check if format is set correctly TEST: - The current setup pass the test. Manually removed `format` and it failed expectedly. QUESTION: - Any more fields I should check for? - Naming? --------- Signed-off-by: shanjiaz <[email protected]>
1 parent 2db8fab commit 80a0449

File tree

1 file changed

+56
-3
lines changed

1 file changed

+56
-3
lines changed

tests/examples/test_quantization_2of4_sparse_w4a16.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55

66
import pytest
7+
from transformers import AutoConfig
78

89
from tests.examples.utils import (
910
ReadMe,
@@ -28,19 +29,71 @@ class TestQuantization24SparseW4A16:
2829

2930
def test_doc_example_command(self, example_dir: str, tmp_path: Path):
3031
"""
31-
Test for the example command in the README.
32+
Validates the quantization_2of4_sparse_w4a16 example by executing the README
33+
command and verifying output artifacts for each processing stage.
3234
"""
3335
readme_path = Path.cwd() / example_dir / "README.md"
3436
readme = ReadMe(readme_path)
3537

3638
command = readme.get_code_block_content(position=2, lang="shell")
37-
assert command.startswith("python")
39+
assert command.startswith("python"), (
40+
"Expected shell command to start with 'python'"
41+
)
3842

3943
command = shlex.split(command)
4044
result = copy_and_run_command(tmp_path, example_dir, command)
41-
4245
assert result.returncode == 0, gen_cmd_fail_message(command, result)
4346

47+
output_dir = Path("output_llama7b_2of4_w4a16_channel")
48+
49+
stages = {
50+
"quantization": {
51+
"path": Path("quantization_stage"),
52+
"format": "marlin-24",
53+
},
54+
"sparsity": {
55+
"path": Path("sparsity_stage"),
56+
"format": "sparse-24-bitmask",
57+
},
58+
"finetuning": {
59+
"path": Path("finetuning_stage"),
60+
"format": "sparse-24-bitmask",
61+
},
62+
}
63+
64+
for stage, stage_info in stages.items():
65+
stage_path = (
66+
tmp_path / example_dir / output_dir / stage_info["path"]
67+
)
68+
recipe_path = stage_path / "recipe.yaml"
69+
config_path = stage_path / "config.json"
70+
71+
assert recipe_path.exists(), (
72+
f"Missing recipe file in {stage}: {recipe_path}"
73+
)
74+
assert config_path.exists(), (
75+
f"Missing config file in {stage}: {config_path}"
76+
)
77+
78+
config = AutoConfig.from_pretrained(stage_path)
79+
assert config is not None, f"Failed to load config in {stage}"
80+
81+
quant_config = getattr(config, "quantization_config", {})
82+
if stage == "quantization":
83+
actual_format = quant_config.get("format")
84+
else:
85+
actual_format = quant_config.get(
86+
"sparsity_config", {}
87+
).get("format")
88+
89+
assert actual_format, (
90+
f"Missing expected format field in {stage} config"
91+
)
92+
assert actual_format == stage_info["format"], (
93+
f"Unexpected format in {stage}: got '{actual_format}', "
94+
f"expected '{stage_info['format']}'"
95+
)
96+
4497
def test_alternative_recipe(self, example_dir: str, tmp_path: Path):
4598
"""
4699
Test for the example command in the README with the alternative recipe file.

0 commit comments

Comments
 (0)