Skip to content

Commit 1795a8d

Browse files
committed
fix: pytest for evaluating model performance dont read stderr anymore
1 parent bfeca31 commit 1795a8d

File tree

6 files changed

+46
-35
lines changed

6 files changed

+46
-35
lines changed

.github/workflows/pytest.yml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ jobs:
1111
- "3.8"
1212
- "3.9"
1313
- "3.10"
14+
- "3.11"
15+
- "3.12"
1416

1517
steps:
1618
- uses: actions/checkout@v4
17-
- name: Set up Python ${{ matrix.python-version }}
18-
uses: actions/setup-python@v4
19+
20+
- name: Install uv
21+
uses: astral-sh/setup-uv@v5
1922
with:
2023
python-version: ${{ matrix.python-version }}
21-
- name: Install Python dependencies
22-
run: |
23-
python -m pip install -U pip
24-
pip install -e .[test]
24+
25+
- name: Install the project
26+
run: uv sync --locked --all-extras --dev
27+
2528
- name: Run unit tests
26-
run: python -m pytest -m "not slow" --cov=compressai -s tests/
29+
run: uv run pytest -m "not slow" --cov=compressai -s tests/

.github/workflows/static-analysis.yml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,10 @@ jobs:
1010
python-version:
1111
- "3.8"
1212
- "3.10"
13+
- "3.12"
14+
1315
include:
1416
- os: "ubuntu-latest"
1517
steps:
16-
- uses: actions/checkout@v4
17-
- name: Set up Python ${{ matrix.python-version }}
18-
uses: actions/setup-python@v4
19-
with:
20-
python-version: ${{ matrix.python-version }}
21-
- name: Install Python dependencies
22-
run: |
23-
python3 -m pip install -U pip
24-
python3 -m pip install .[dev]
25-
- name: Run static analysis checks
18+
- uses: astral-sh/ruff-action@v3
2619
run: make static-analysis

compressai/entropy_models/entropy_models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,9 +847,9 @@ def _build_cdf(self, scales, means, weights, abs_max):
847847
),
848848
dim=1,
849849
).transpose(0, 1)
850-
pmf_quantized[pmf_real_steal_indices[0], pmf_real_steal_indices[1]] -= (
851-
pmf_zero_count
852-
)
850+
pmf_quantized[
851+
pmf_real_steal_indices[0], pmf_real_steal_indices[1]
852+
] -= pmf_zero_count
853853

854854
cdf = F.pad(torch.cumsum(pmf_quantized, 1).int(), (1, 0), "constant", 0)
855855
cdf = F.pad(cdf, (0, 1), "constant", cdf_limit + 1)

compressai/utils/video/eval_model/__main__.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,13 @@ def create_parser() -> argparse.ArgumentParser:
467467
default="mse",
468468
help="metric trained against (default: %(default)s)",
469469
)
470+
parent_parser.add_argument(
471+
"-d",
472+
"--output_directory",
473+
type=str,
474+
default="",
475+
help="path of output directory. Optional, required for output json file, results per video.",
476+
)
470477
parent_parser.add_argument(
471478
"-o",
472479
"--output-file",
@@ -506,7 +513,7 @@ def create_parser() -> argparse.ArgumentParser:
506513
return parser
507514

508515

509-
def main(args: Any = None) -> None:
516+
def main(args: Any = None) -> None: # noqa: C901
510517
if args is None:
511518
args = sys.argv[1:]
512519
parser = create_parser()
@@ -525,8 +532,8 @@ def main(args: Any = None) -> None:
525532
raise SystemExit(1)
526533

527534
# create output directory
528-
outputdir = args.output
529-
Path(outputdir).mkdir(parents=True, exist_ok=True)
535+
if args.output_directory:
536+
Path(args.output_directory).mkdir(parents=True, exist_ok=True)
530537

531538
if args.source == "pretrained":
532539
args.qualities = [int(q) for q in args.qualities.split(",") if q]
@@ -561,7 +568,7 @@ def main(args: Any = None) -> None:
561568
filepaths,
562569
args.dataset,
563570
model,
564-
outputdir,
571+
args.output_directory,
565572
trained_net=trained_net,
566573
description=description,
567574
**args_dict,
@@ -581,7 +588,9 @@ def main(args: Any = None) -> None:
581588
else:
582589
output_file = args.output_file
583590

584-
with (Path(f"{outputdir}/{output_file}").with_suffix(".json")).open("wb") as f:
591+
with (Path(f"{args.output_directory}/{output_file}").with_suffix(".json")).open(
592+
"wb"
593+
) as f:
585594
f.write(json.dumps(output, indent=2).encode())
586595
print(json.dumps(output, indent=2))
587596

tests/test_eval_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def test_eval_model():
7777
@pytest.mark.parametrize("quality", ("1", "4", "8"))
7878
@pytest.mark.parametrize("metric", ("mse", "ms-ssim"))
7979
@pytest.mark.parametrize("entropy_estimation", (False, True))
80-
def test_eval_model_pretrained(
81-
capsys, model, quality, metric, entropy_estimation, tmpdir
82-
):
80+
def test_eval_model_pretrained(model, quality, metric, entropy_estimation, tmpdir):
8381
here = os.path.dirname(__file__)
8482
dirpath = os.path.join(here, "assets/dataset/image")
8583

@@ -92,13 +90,18 @@ def test_eval_model_pretrained(
9290
metric,
9391
"-q",
9492
quality,
93+
"-o",
94+
f"{model}-{metric}-{quality}",
95+
"-d",
96+
str(tmpdir),
9597
]
9698
if entropy_estimation:
9799
cmd += ["--entropy-estimation"]
98100
eval_model.main(cmd)
99101

100-
output = capsys.readouterr().out
101-
output = json.loads(output)
102+
with open(f"{tmpdir}/{model}-{metric}-{quality}.json") as f:
103+
output = json.load(f)
104+
102105
expected = os.path.join(
103106
here,
104107
"expected",

tests/test_eval_model_video.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ def test_eval_model_video():
7979
@pytest.mark.parametrize("quality", ("1", "4", "8"))
8080
@pytest.mark.parametrize("metric", ("mse",))
8181
@pytest.mark.parametrize("entropy_estimation", (True, False))
82-
def test_eval_model_pretrained(
83-
capsys, model, quality, metric, entropy_estimation, tmpdir
84-
):
82+
def test_eval_model_pretrained(model, quality, metric, entropy_estimation, tmpdir):
8583
here = os.path.dirname(__file__)
8684
dirpath = os.path.join(here, "assets/dataset/video")
8785

@@ -95,13 +93,18 @@ def test_eval_model_pretrained(
9593
metric,
9694
"-q",
9795
quality,
96+
"-d",
97+
str(tmpdir),
98+
"-o",
99+
f"{model}-{metric}-{quality}",
98100
]
99101
if entropy_estimation:
100102
cmd += ["--entropy-estimation"]
101103
eval_model.main(cmd)
102104

103-
output = capsys.readouterr().out
104-
output = json.loads(output)
105+
with open(f"{tmpdir}/{model}-{metric}-{quality}.json") as f:
106+
output = json.load(f)
107+
105108
expected = os.path.join(
106109
here,
107110
"expected",

0 commit comments

Comments
 (0)