diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 39c6e37..10679ab 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -54,6 +54,23 @@ jobs: directory: . snakefile: workflow/Snakefile args: "--lint --configfile config/example_config.yaml --config skip_version_check=True" + + Pytest: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13', '3.14'] + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install test dependencies + run: python -m pip install --upgrade pip -r requirements-test.txt + - name: Run pytest suite + run: python -m pytest -q tests # Testing: # runs-on: ubuntu-latest # needs: diff --git a/pyproject.toml b/pyproject.toml index 52c96bc..1050e93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,7 @@ [tool.snakefmt] line_length = 127 + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +addopts = "-ra" diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..42b555e --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +pytest +click +numpy +pandas diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..161fdb3 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,20 @@ +# Tests + +This directory contains the pytest suite for workflow scripts and supporting fixtures. + +## Layout + +- `tests/conftest.py` keeps shared pytest fixtures and makes the repo root importable. +- `tests/<area>/test_*.py` holds the actual tests, grouped by script area such as `count`. +- `tests/fixtures/<area>/` stores reusable input data for those tests. + +## Adding a new script test + +1. Put the new test in the matching area folder, for example `tests/count/test_new_script.py`. +2. Add any reusable inputs under `tests/fixtures/<area>/`. +3. 
# Shared pytest configuration for the test suite: makes the repository root
# importable and exposes path/CLI fixtures used by the tests.
import sys
from pathlib import Path

import pytest
from click.testing import CliRunner

# Resolve the directory holding this file first; the repository root is its parent.
TESTS_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = TESTS_ROOT.parent
PROJECT_ROOT_STR = str(PROJECT_ROOT)

# Put the repository root at the front of sys.path, but only if it is not there yet.
if PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, PROJECT_ROOT_STR)


@pytest.fixture(scope="session")
def tests_root() -> Path:
    """Return the directory that contains this conftest file."""
    return TESTS_ROOT


@pytest.fixture(scope="session")
def fixtures_root(tests_root: Path) -> Path:
    """Return the ``fixtures`` directory below the tests root."""
    return tests_root.joinpath("fixtures")


@pytest.fixture(scope="session")
def count_fixtures_root(fixtures_root: Path) -> Path:
    """Return the ``count`` sub-directory of the fixtures root."""
    return fixtures_root.joinpath("count")


@pytest.fixture
def minimal_count_input(count_fixtures_root: Path) -> Path:
    """Return the path of ``minimal_test_input.tsv`` in the count fixtures."""
    return count_fixtures_root.joinpath("minimal_test_input.tsv")


@pytest.fixture
def ragged_missing_input(count_fixtures_root: Path) -> Path:
    """Return the path of ``ragged_missing_input.tsv`` in the count fixtures."""
    return count_fixtures_root.joinpath("ragged_missing_input.tsv")


@pytest.fixture
def cli_runner() -> CliRunner:
    """Return a fresh Click ``CliRunner`` for each test."""
    runner = CliRunner()
    return runner
("DNA", "X", "1") + assert compiler.get_annot("RNA(condition Y, replicate 2)") == ("RNA", "Y", "2") + + def test_get_annot_returns_none_for_unmatched_headers(self): + assert compiler.get_annot("Sequence") == (None, None, None) + assert compiler.get_annot("label") == (None, None, None) + + def test_generate_annotation_output_repeats_and_numbers_barcodes(self): + input_frame = pd.DataFrame( + [ + {"type": "DNA", "condition": "X", "replicate": "1"}, + {"type": "RNA", "condition": "X", "replicate": "1"}, + ] + ) + + output = compiler.generate_annotation_output(input_frame, number_barcodes=2) + + assert list(output["sample"]) == ["DNA_X_1_1", "DNA_X_1_2", "RNA_X_1_1", "RNA_X_1_2"] + assert list(output["barcode"]) == ["1", "2", "1", "2"] + + def test_generate_count_output_pads_and_flattens_by_barcode(self): + input_frame = pd.DataFrame( + [ + {"label": "oligoA", "bc1": 10, "bc2": 20}, + {"label": "oligoA", "bc1": 11, "bc2": 21}, + {"label": "oligoB", "bc1": 12, "bc2": 22}, + ] + ).set_index("label") + + output = compiler.generate_count_output(input_frame, ["sample_1", "sample_2", "sample_3", "sample_4"], number_barcodes=2) + + assert list(output["seq_id"]) == ["oligoA", "oligoB"] + assert list(output.loc[output["seq_id"] == "oligoA", ["sample_1", "sample_2", "sample_3", "sample_4"]].iloc[0]) == [10, 11, 20, 21] + assert list(output.loc[output["seq_id"] == "oligoB", ["sample_1", "sample_2", "sample_3", "sample_4"]].iloc[0]) == [12, 0, 22, 0] + + def test_cli_generates_expected_count_tables(self, cli_runner, minimal_count_input, tmp_path): + rna_counts_output = tmp_path / "rna_counts.tsv.gz" + dna_counts_output = tmp_path / "dna_counts.tsv.gz" + rna_annotation_output = tmp_path / "rna_annot.tsv.gz" + dna_annotation_output = tmp_path / "dna_annot.tsv.gz" + + result = cli_runner.invoke( + compiler.cli, + [ + "--input", + str(minimal_count_input), + "--rna-counts-output", + str(rna_counts_output), + "--dna-counts-output", + str(dna_counts_output), + "--rna-annotation-output", 
+ str(rna_annotation_output), + "--dna-annotation-output", + str(dna_annotation_output), + ], + ) + + assert result.exit_code == 0, result.output + + rna: pd.DataFrame = pd.read_csv(rna_counts_output, sep="\t").set_index("seq_id") + dna: pd.DataFrame = pd.read_csv(dna_counts_output, sep="\t").set_index("seq_id") + cols= ["RNA_X_1_1", "RNA_X_1_2", "RNA_X_2_1", "RNA_X_2_2", "RNA_X_3_1", "RNA_X_3_2"] + row = rna.loc["oligoA"] + assert list(row.loc[cols]) == [ + 100, + 101, + 200, + 201, + 300, + 301, + ] + cols= ["DNA_X_1_1", "DNA_X_1_2", "DNA_X_2_1", "DNA_X_2_2", "DNA_X_3_1", "DNA_X_3_2"] + row = dna.loc["oligoA"] + assert list(row.loc[cols]) == [ + 10, + 11, + 20, + 21, + 30, + 31, + ] + + def test_cli_missing_count_does_not_shift_replicates(self, cli_runner, ragged_missing_input, tmp_path): + # Regression guard for the MPRAflow-style "ragged within an oligo" bug + # (shendurelab/MPRAflow#87): a missing trailing count (empty cell / trailing + # tab) must stay a zero in its own replicate/barcode slot and must NOT shift + # later counts into the wrong replicate. The fixture's oligoA is missing its + # RNA replicate-3 / barcode-2 value, and oligoB has a single barcode (so it + # also exercises the across-oligo padding case). 
+ rna_counts_output = tmp_path / "rna_counts.tsv.gz" + dna_counts_output = tmp_path / "dna_counts.tsv.gz" + rna_annotation_output = tmp_path / "rna_annot.tsv.gz" + dna_annotation_output = tmp_path / "dna_annot.tsv.gz" + + result = cli_runner.invoke( + compiler.cli, + [ + "--input", + str(ragged_missing_input), + "--rna-counts-output", + str(rna_counts_output), + "--dna-counts-output", + str(dna_counts_output), + "--rna-annotation-output", + str(rna_annotation_output), + "--dna-annotation-output", + str(dna_annotation_output), + ], + ) + + assert result.exit_code == 0, result.output + + rna: pd.DataFrame = pd.read_csv(rna_counts_output, sep="\t").set_index("seq_id") + dna: pd.DataFrame = pd.read_csv(dna_counts_output, sep="\t").set_index("seq_id") + + rna_cols = [ + "RNA_X_1_1", "RNA_X_1_2", "RNA_X_1_3", + "RNA_X_2_1", "RNA_X_2_2", "RNA_X_2_3", + "RNA_X_3_1", "RNA_X_3_2", "RNA_X_3_3", + ] + # oligoA: the hole is at replicate 3 / barcode 2 -> must be 0, and the real + # barcode-3 value (302) must stay in barcode 3, not shift up to barcode 2. + assert list(rna.loc["oligoA", rna_cols]) == [100, 101, 102, 200, 201, 202, 300, 0, 302] + # oligoB has one barcode; remaining barcode slots pad with zeros per replicate. + assert list(rna.loc["oligoB", rna_cols]) == [103, 0, 0, 203, 0, 0, 303, 0, 0] + + # DNA has no missing values; confirm nothing shifted there either. 
+ dna_cols = [ + "DNA_X_1_1", "DNA_X_1_2", "DNA_X_1_3", + "DNA_X_2_1", "DNA_X_2_2", "DNA_X_2_3", + "DNA_X_3_1", "DNA_X_3_2", "DNA_X_3_3", + ] + assert list(dna.loc["oligoA", dna_cols]) == [10, 11, 12, 20, 21, 22, 30, 31, 32] diff --git a/tests/fixtures/count/minimal_test_input.tsv b/tests/fixtures/count/minimal_test_input.tsv new file mode 100644 index 0000000..1a330ce --- /dev/null +++ b/tests/fixtures/count/minimal_test_input.tsv @@ -0,0 +1,5 @@ +label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3) +oligoA A BC0 10 20 30 100 200 300 +oligoA A BC1 11 21 31 101 201 301 +oligoB B BC2 12 22 32 102 202 302 +oligoB B BC3 13 23 33 103 203 303 diff --git a/tests/fixtures/count/ragged_missing_input.tsv b/tests/fixtures/count/ragged_missing_input.tsv new file mode 100644 index 0000000..cd53b62 --- /dev/null +++ b/tests/fixtures/count/ragged_missing_input.tsv @@ -0,0 +1,5 @@ +label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3) +oligoA A BC0 10 20 30 100 200 300 +oligoA A BC1 11 21 31 101 201 +oligoA A BC2 12 22 32 102 202 302 +oligoB B BC3 13 23 33 103 203 303 diff --git a/workflow/scripts/count/mpranalyze_compiler.py b/workflow/scripts/count/mpranalyze_compiler.py index b559c1f..e758108 100644 --- a/workflow/scripts/count/mpranalyze_compiler.py +++ b/workflow/scripts/count/mpranalyze_compiler.py @@ -6,6 +6,44 @@ import numpy as np import pandas as pd +ANNOT_PATTERN = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$") + + +def get_annot(head: str) -> tuple[str|None, str|None, str|None]: + match = ANNOT_PATTERN.match(head) + if match is not None: + group1 = match.group(1) + group2 = match.group(2) + group3 = match.group(3) + return (group1, group2, group3) + return 
(None, None, None) + + +def generate_annotation_output(data, number_barcodes): + data = data.loc[data.index.repeat(number_barcodes)].copy() + data['barcode'] = data.groupby(['type', 'condition', 'replicate']).cumcount() + 1 + data['barcode'] = data['barcode'].astype(str) + data['sample'] = data[['type', 'condition', 'replicate', 'barcode']].agg('_'.join, axis=1) + return data[["sample", "type", "condition", "replicate", "barcode"]] + + +def generate_count_output(data, columns, number_barcodes): + rows = [] + seq_ids = [] + for label, group in data.groupby('label', sort=False): + padded = np.zeros((number_barcodes, data.shape[1]), dtype=np.int64) + vals = group.values[:number_barcodes].astype(np.int64) + padded[:len(vals)] = vals + rows.append(padded.flatten(order='F')) + seq_ids.append(label) + counts = pd.DataFrame(rows, columns=columns) + counts.insert(0, 'seq_id', seq_ids) + return counts + + +def write_table(data, file): + data.to_csv(file, index=False, sep='\t', compression='gzip') + # options @click.command() @@ -39,12 +77,6 @@ def cli(input_file, rna_counts_output_file, dna_counts_output_file, rna_annotation_output_file, dna_annotation_output_file): - annot_pattern = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$") - def get_annot(head): - m = annot_pattern.match(head) - if m is not None: - return m.group(1,2,3) - # read input df = pd.read_csv(input_file,sep="\t", header='infer') @@ -61,44 +93,26 @@ def get_annot(head): # counts for observation - dna_df = df.iloc[:,2:(2+n_dna_obs)].applymap(np.int64) - rna_df = df.iloc[:,(2+n_dna_obs):].applymap(np.int64) + dna_df = df.iloc[:,2:(2+n_dna_obs)].astype(np.int64) + rna_df = df.iloc[:,(2+n_dna_obs):].astype(np.int64) ## generate output DNA/RNA annotations (type_condition_replicate_barcode) n_bc = df.groupby('label').Barcode.agg(len).max() - def generateAnnotationOutput(data, number_barcodes): - data = data.loc[data.index.repeat(number_barcodes)] - data['barcode'] = 
data.groupby(['type','condition','replicate']).cumcount() +1 - data['barcode'] = data['barcode'].astype(str) - data['sample'] = data[['type','condition','replicate','barcode']].agg('_'.join,axis=1) - data = data[["sample", "type", "condition", "replicate", "barcode"]] - return(data) - dna_annot = generateAnnotationOutput(dna_annot, n_bc) - rna_annot = generateAnnotationOutput(rna_annot, n_bc) + dna_annot = generate_annotation_output(dna_annot, n_bc) + rna_annot = generate_annotation_output(rna_annot, n_bc) ## generate output DNA/RNA count tables ## rows oligo/seq ids,/assignment then per barcode the counts. padding with zeros - def generateCountOutput(data,columns): - counts = pd.DataFrame(list(data.groupby('label').apply(lambda x: x.values.flatten()))).fillna(0).astype(np.int64) - counts.columns = columns - counts['seq_id'] = data.index.unique() - counts = counts[(['seq_id'] + list(columns))] - return(counts) - - dna_counts = generateCountOutput(dna_df,dna_annot['sample']) - rna_counts = generateCountOutput(rna_df,rna_annot['sample']) - - ## write table function - def write(data,file): - data.to_csv(file, index=False,sep='\t', compression='gzip') + dna_counts = generate_count_output(dna_df, dna_annot['sample'], n_bc) + rna_counts = generate_count_output(rna_df, rna_annot['sample'], n_bc) ## write output DNA/RNA annotations - write(dna_annot,dna_annotation_output_file) - write(rna_annot,rna_annotation_output_file) + write_table(dna_annot, dna_annotation_output_file) + write_table(rna_annot, rna_annotation_output_file) ## write output DNA/RNA annotations - write(dna_counts,dna_counts_output_file) - write(rna_counts,rna_counts_output_file) + write_table(dna_counts, dna_counts_output_file) + write_table(rna_counts, rna_counts_output_file) if __name__ == '__main__': cli()