diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 39c6e37..10679ab 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -54,6 +54,23 @@ jobs:
directory: .
snakefile: workflow/Snakefile
args: "--lint --configfile config/example_config.yaml --config skip_version_check=True"
+
+ Pytest:
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
+ steps:
+ - uses: actions/checkout@v6
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install test dependencies
+ run: python -m pip install --upgrade pip -r requirements-test.txt
+ - name: Run pytest suite
+ run: python -m pytest -q tests
# Testing:
# runs-on: ubuntu-latest
# needs:
diff --git a/pyproject.toml b/pyproject.toml
index 52c96bc..1050e93 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,2 +1,7 @@
[tool.snakefmt]
line_length = 127
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py", "*_test.py"]
+addopts = "-ra"
diff --git a/requirements-test.txt b/requirements-test.txt
new file mode 100644
index 0000000..42b555e
--- /dev/null
+++ b/requirements-test.txt
@@ -0,0 +1,4 @@
+pytest
+click
+numpy
+pandas
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..161fdb3
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,20 @@
+# Tests
+
+This directory contains the pytest suite for workflow scripts and supporting fixtures.
+
+## Layout
+
+- `tests/conftest.py` keeps shared pytest fixtures and makes the repo root importable.
+- `tests/<area>/test_*.py` holds the actual tests, grouped by script area such as `count`.
+- `tests/fixtures/<area>/` stores reusable input data for those tests.
+
+## Adding a new script test
+
+1. Put the new test in the matching area folder, for example `tests/count/test_new_script.py`.
+2. Add any reusable inputs under `tests/fixtures/<area>/`.
+3. Prefer `click.testing.CliRunner` for Click commands and pytest fixtures like `tmp_path` for temporary outputs.
+4. Run a focused check with `conda run -n mpralib python -m pytest -q tests/<area>/test_new_script.py`.
+
+## Current example
+
+The MPRAnalyze compiler test lives in `tests/count/test_mpranalyze_compiler.py` and uses the fixture input at `tests/fixtures/count/minimal_test_input.tsv`.
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..0496f1d
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,42 @@
+import sys
+from pathlib import Path
+
+import pytest
+from click.testing import CliRunner
+
+TESTS_ROOT = Path(__file__).resolve().parent
+PROJECT_ROOT = TESTS_ROOT.parent
+PROJECT_ROOT_STR = str(PROJECT_ROOT)
+# Put the repository root first on sys.path so tests can import the `workflow` package tree.
+if PROJECT_ROOT_STR not in sys.path:
+    sys.path.insert(0, PROJECT_ROOT_STR)
+
+
+@pytest.fixture(scope="session")
+def tests_root() -> Path:
+    return TESTS_ROOT  # directory that contains this conftest.py
+
+
+@pytest.fixture(scope="session")
+def fixtures_root(tests_root: Path) -> Path:
+    return tests_root / "fixtures"  # tests/fixtures: reusable input data for the tests
+
+
+@pytest.fixture(scope="session")
+def count_fixtures_root(fixtures_root: Path) -> Path:
+    return fixtures_root / "count"  # inputs used by the tests under tests/count
+
+
+@pytest.fixture
+def minimal_count_input(count_fixtures_root: Path) -> Path:
+    return count_fixtures_root / "minimal_test_input.tsv"  # 2 oligos x 2 barcodes, no missing counts
+
+
+@pytest.fixture
+def ragged_missing_input(count_fixtures_root: Path) -> Path:
+    return count_fixtures_root / "ragged_missing_input.tsv"  # one RNA count missing; oligoB has one barcode
+
+
+@pytest.fixture
+def cli_runner() -> CliRunner:
+    return CliRunner()  # function-scoped, so every test gets a fresh Click runner
diff --git a/tests/count/test_mpranalyze_compiler.py b/tests/count/test_mpranalyze_compiler.py
new file mode 100644
index 0000000..c8f0f8d
--- /dev/null
+++ b/tests/count/test_mpranalyze_compiler.py
@@ -0,0 +1,140 @@
+import pandas as pd
+
+from workflow.scripts.count import mpranalyze_compiler as compiler
+
+
+class TestMpranalyzeCompiler:  # unit tests for the module helpers plus end-to-end runs of the Click command
+    def test_get_annot_parses_dna_and_rna_headers(self):
+        assert compiler.get_annot("DNA(condition X, replicate 1)") == ("DNA", "X", "1")
+        assert compiler.get_annot("RNA(condition Y, replicate 2)") == ("RNA", "Y", "2")
+
+    def test_get_annot_returns_none_for_unmatched_headers(self):
+        assert compiler.get_annot("Sequence") == (None, None, None)
+        assert compiler.get_annot("label") == (None, None, None)
+
+    def test_generate_annotation_output_repeats_and_numbers_barcodes(self):
+        input_frame = pd.DataFrame(
+            [
+                {"type": "DNA", "condition": "X", "replicate": "1"},
+                {"type": "RNA", "condition": "X", "replicate": "1"},
+            ]
+        )
+
+        output = compiler.generate_annotation_output(input_frame, number_barcodes=2)  # each row repeated twice
+
+        assert list(output["sample"]) == ["DNA_X_1_1", "DNA_X_1_2", "RNA_X_1_1", "RNA_X_1_2"]
+        assert list(output["barcode"]) == ["1", "2", "1", "2"]
+
+    def test_generate_count_output_pads_and_flattens_by_barcode(self):
+        input_frame = pd.DataFrame(
+            [
+                {"label": "oligoA", "bc1": 10, "bc2": 20},
+                {"label": "oligoA", "bc1": 11, "bc2": 21},
+                {"label": "oligoB", "bc1": 12, "bc2": 22},  # single row -> its second barcode slot is zero-padded
+            ]
+        ).set_index("label")  # the function under test groups rows on "label"
+
+        output = compiler.generate_count_output(input_frame, ["sample_1", "sample_2", "sample_3", "sample_4"], number_barcodes=2)
+
+        assert list(output["seq_id"]) == ["oligoA", "oligoB"]  # one output row per label
+        assert list(output.loc[output["seq_id"] == "oligoA", ["sample_1", "sample_2", "sample_3", "sample_4"]].iloc[0]) == [10, 11, 20, 21]
+        assert list(output.loc[output["seq_id"] == "oligoB", ["sample_1", "sample_2", "sample_3", "sample_4"]].iloc[0]) == [12, 0, 22, 0]
+
+    def test_cli_generates_expected_count_tables(self, cli_runner, minimal_count_input, tmp_path):
+        rna_counts_output = tmp_path / "rna_counts.tsv.gz"
+        dna_counts_output = tmp_path / "dna_counts.tsv.gz"
+        rna_annotation_output = tmp_path / "rna_annot.tsv.gz"
+        dna_annotation_output = tmp_path / "dna_annot.tsv.gz"
+
+        result = cli_runner.invoke(
+            compiler.cli,
+            [
+                "--input",
+                str(minimal_count_input),
+                "--rna-counts-output",
+                str(rna_counts_output),
+                "--dna-counts-output",
+                str(dna_counts_output),
+                "--rna-annotation-output",
+                str(rna_annotation_output),
+                "--dna-annotation-output",
+                str(dna_annotation_output),
+            ],
+        )
+
+        assert result.exit_code == 0, result.output  # show the captured CLI output when the command fails
+
+        rna: pd.DataFrame = pd.read_csv(rna_counts_output, sep="\t").set_index("seq_id")
+        dna: pd.DataFrame = pd.read_csv(dna_counts_output, sep="\t").set_index("seq_id")
+        cols= ["RNA_X_1_1", "RNA_X_1_2", "RNA_X_2_1", "RNA_X_2_2", "RNA_X_3_1", "RNA_X_3_2"]
+        row = rna.loc["oligoA"]  # expected values follow the column order: replicate first, then barcode
+        assert list(row.loc[cols]) == [
+            100,
+            101,
+            200,
+            201,
+            300,
+            301,
+        ]
+        cols= ["DNA_X_1_1", "DNA_X_1_2", "DNA_X_2_1", "DNA_X_2_2", "DNA_X_3_1", "DNA_X_3_2"]
+        row = dna.loc["oligoA"]
+        assert list(row.loc[cols]) == [
+            10,
+            11,
+            20,
+            21,
+            30,
+            31,
+        ]
+
+    def test_cli_missing_count_does_not_shift_replicates(self, cli_runner, ragged_missing_input, tmp_path):
+        # Regression guard for the MPRAflow-style "ragged within an oligo" bug
+        # (shendurelab/MPRAflow#87): a missing trailing count (empty cell / trailing
+        # tab) must stay a zero in its own replicate/barcode slot and must NOT shift
+        # later counts into the wrong replicate. The fixture's oligoA is missing its
+        # RNA replicate-3 / barcode-2 value, and oligoB has a single barcode (so it
+        # also exercises the across-oligo padding case).
+        rna_counts_output = tmp_path / "rna_counts.tsv.gz"
+        dna_counts_output = tmp_path / "dna_counts.tsv.gz"
+        rna_annotation_output = tmp_path / "rna_annot.tsv.gz"
+        dna_annotation_output = tmp_path / "dna_annot.tsv.gz"
+
+        result = cli_runner.invoke(
+            compiler.cli,
+            [
+                "--input",
+                str(ragged_missing_input),
+                "--rna-counts-output",
+                str(rna_counts_output),
+                "--dna-counts-output",
+                str(dna_counts_output),
+                "--rna-annotation-output",
+                str(rna_annotation_output),
+                "--dna-annotation-output",
+                str(dna_annotation_output),
+            ],
+        )
+
+        assert result.exit_code == 0, result.output  # show the captured CLI output when the command fails
+
+        rna: pd.DataFrame = pd.read_csv(rna_counts_output, sep="\t").set_index("seq_id")
+        dna: pd.DataFrame = pd.read_csv(dna_counts_output, sep="\t").set_index("seq_id")
+
+        rna_cols = [
+            "RNA_X_1_1", "RNA_X_1_2", "RNA_X_1_3",
+            "RNA_X_2_1", "RNA_X_2_2", "RNA_X_2_3",
+            "RNA_X_3_1", "RNA_X_3_2", "RNA_X_3_3",
+        ]
+        # oligoA: the hole is at replicate 3 / barcode 2 -> must be 0, and the real
+        # barcode-3 value (302) must stay in barcode 3, not shift up to barcode 2.
+        assert list(rna.loc["oligoA", rna_cols]) == [100, 101, 102, 200, 201, 202, 300, 0, 302]
+        # oligoB has one barcode; remaining barcode slots pad with zeros per replicate.
+        assert list(rna.loc["oligoB", rna_cols]) == [103, 0, 0, 203, 0, 0, 303, 0, 0]
+
+        # DNA has no missing values; confirm nothing shifted there either.
+        dna_cols = [
+            "DNA_X_1_1", "DNA_X_1_2", "DNA_X_1_3",
+            "DNA_X_2_1", "DNA_X_2_2", "DNA_X_2_3",
+            "DNA_X_3_1", "DNA_X_3_2", "DNA_X_3_3",
+        ]
+        assert list(dna.loc["oligoA", dna_cols]) == [10, 11, 12, 20, 21, 22, 30, 31, 32]
diff --git a/tests/fixtures/count/minimal_test_input.tsv b/tests/fixtures/count/minimal_test_input.tsv
new file mode 100644
index 0000000..1a330ce
--- /dev/null
+++ b/tests/fixtures/count/minimal_test_input.tsv
@@ -0,0 +1,5 @@
+label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3)
+oligoA A BC0 10 20 30 100 200 300
+oligoA A BC1 11 21 31 101 201 301
+oligoB B BC2 12 22 32 102 202 302
+oligoB B BC3 13 23 33 103 203 303
diff --git a/tests/fixtures/count/ragged_missing_input.tsv b/tests/fixtures/count/ragged_missing_input.tsv
new file mode 100644
index 0000000..cd53b62
--- /dev/null
+++ b/tests/fixtures/count/ragged_missing_input.tsv
@@ -0,0 +1,5 @@
+label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3)
+oligoA A BC0 10 20 30 100 200 300
+oligoA A BC1 11 21 31 101 201
+oligoA A BC2 12 22 32 102 202 302
+oligoB B BC3 13 23 33 103 203 303
diff --git a/workflow/scripts/count/mpranalyze_compiler.py b/workflow/scripts/count/mpranalyze_compiler.py
index b559c1f..e758108 100644
--- a/workflow/scripts/count/mpranalyze_compiler.py
+++ b/workflow/scripts/count/mpranalyze_compiler.py
@@ -6,6 +6,44 @@
import numpy as np
import pandas as pd
+ANNOT_PATTERN = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$")
+
+
+def get_annot(head: str) -> tuple[str|None, str|None, str|None]:
+    """Split a count-column header into ``(type, condition, replicate)``.
+
+    A header such as ``"DNA(condition X, replicate 1)"`` gives ``("DNA", "X", "1")``;
+    a header that does not match ``ANNOT_PATTERN`` gives ``(None, None, None)``.
+    """
+    found = ANNOT_PATTERN.match(head)
+    return (None, None, None) if found is None else found.group(1, 2, 3)
+
+
+def generate_annotation_output(data, number_barcodes):
+    """Repeat every annotation row once per barcode and derive the barcode and sample ids."""
+    expanded = data.loc[data.index.repeat(number_barcodes)].copy()
+    expanded['barcode'] = (expanded.groupby(['type', 'condition', 'replicate']).cumcount() + 1).astype(str)
+    expanded['sample'] = expanded[['type', 'condition', 'replicate', 'barcode']].agg('_'.join, axis=1)
+    return expanded[["sample", "type", "condition", "replicate", "barcode"]]
+
+
+def generate_count_output(data, columns, number_barcodes):
+    """Return one row per label; every input column expands to number_barcodes zero-padded slots."""
+    seq_ids, rows = [], []
+    for seq_id, observations in data.groupby('label', sort=False):
+        block = np.zeros((number_barcodes, data.shape[1]), dtype=np.int64)
+        kept = observations.values[:number_barcodes].astype(np.int64)
+        block[:kept.shape[0]] = kept
+        seq_ids.append(seq_id)
+        rows.append(block.flatten(order='F'))  # column-major: all barcodes of column 1, then column 2, ...
+    counts = pd.DataFrame(rows, columns=columns)
+    counts.insert(0, 'seq_id', seq_ids)
+    return counts
+
+
+def write_table(data, file):
+    data.to_csv(file, index=False, sep='\t', compression='gzip')  # gzip-compressed TSV, index column omitted
+
# options
@click.command()
@@ -39,12 +77,6 @@
def cli(input_file, rna_counts_output_file, dna_counts_output_file, rna_annotation_output_file, dna_annotation_output_file):
- annot_pattern = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$")
- def get_annot(head):
- m = annot_pattern.match(head)
- if m is not None:
- return m.group(1,2,3)
-
# read input
df = pd.read_csv(input_file,sep="\t", header='infer')
@@ -61,44 +93,26 @@ def get_annot(head):
# counts for observation
- dna_df = df.iloc[:,2:(2+n_dna_obs)].applymap(np.int64)
- rna_df = df.iloc[:,(2+n_dna_obs):].applymap(np.int64)
+ dna_df = df.iloc[:,2:(2+n_dna_obs)].astype(np.int64)
+ rna_df = df.iloc[:,(2+n_dna_obs):].astype(np.int64)
## generate output DNA/RNA annotations (type_condition_replicate_barcode)
n_bc = df.groupby('label').Barcode.agg(len).max()
- def generateAnnotationOutput(data, number_barcodes):
- data = data.loc[data.index.repeat(number_barcodes)]
- data['barcode'] = data.groupby(['type','condition','replicate']).cumcount() +1
- data['barcode'] = data['barcode'].astype(str)
- data['sample'] = data[['type','condition','replicate','barcode']].agg('_'.join,axis=1)
- data = data[["sample", "type", "condition", "replicate", "barcode"]]
- return(data)
- dna_annot = generateAnnotationOutput(dna_annot, n_bc)
- rna_annot = generateAnnotationOutput(rna_annot, n_bc)
+ dna_annot = generate_annotation_output(dna_annot, n_bc)
+ rna_annot = generate_annotation_output(rna_annot, n_bc)
## generate output DNA/RNA count tables
## rows oligo/seq ids,/assignment then per barcode the counts. padding with zeros
- def generateCountOutput(data,columns):
- counts = pd.DataFrame(list(data.groupby('label').apply(lambda x: x.values.flatten()))).fillna(0).astype(np.int64)
- counts.columns = columns
- counts['seq_id'] = data.index.unique()
- counts = counts[(['seq_id'] + list(columns))]
- return(counts)
-
- dna_counts = generateCountOutput(dna_df,dna_annot['sample'])
- rna_counts = generateCountOutput(rna_df,rna_annot['sample'])
-
- ## write table function
- def write(data,file):
- data.to_csv(file, index=False,sep='\t', compression='gzip')
+ dna_counts = generate_count_output(dna_df, dna_annot['sample'], n_bc)
+ rna_counts = generate_count_output(rna_df, rna_annot['sample'], n_bc)
## write output DNA/RNA annotations
- write(dna_annot,dna_annotation_output_file)
- write(rna_annot,rna_annotation_output_file)
+ write_table(dna_annot, dna_annotation_output_file)
+ write_table(rna_annot, rna_annotation_output_file)
## write output DNA/RNA annotations
- write(dna_counts,dna_counts_output_file)
- write(rna_counts,rna_counts_output_file)
+ write_table(dna_counts, dna_counts_output_file)
+ write_table(rna_counts, rna_counts_output_file)
if __name__ == '__main__':
cli()