Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ jobs:
directory: .
snakefile: workflow/Snakefile
args: "--lint --configfile config/example_config.yaml --config skip_version_check=True"

Pytest:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install test dependencies
run: python -m pip install --upgrade pip -r requirements-test.txt
- name: Run pytest suite
run: python -m pytest -q tests
# Testing:
# runs-on: ubuntu-latest
# needs:
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
[tool.snakefmt]
line_length = 127

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
addopts = "-ra"
4 changes: 4 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pytest
click
numpy
pandas
20 changes: 20 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Tests

This directory contains the pytest suite for workflow scripts and supporting fixtures.

## Layout

- `tests/conftest.py` keeps shared pytest fixtures and makes the repo root importable.
- `tests/<area>/test_*.py` holds the actual tests, grouped by script area such as `count`.
- `tests/fixtures/<area>/` stores reusable input data for those tests.

## Adding a new script test

1. Put the new test in the matching area folder, for example `tests/count/test_new_script.py`.
2. Add any reusable inputs under `tests/fixtures/<area>/`.
3. Prefer `click.testing.CliRunner` for Click commands and pytest fixtures like `tmp_path` for temporary outputs.
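4. Run a focused check with `conda run -n mpralib python -m pytest -q tests/<area>/test_new_script.py` (drop the `conda run -n mpralib` prefix if the packages from `requirements-test.txt` are already installed in your active environment).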

## Current example

The MPRAnalyze compiler test lives in `tests/count/test_mpranalyze_compiler.py` and uses the fixture input at `tests/fixtures/count/minimal_test_input.tsv`.
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import sys
from pathlib import Path

import pytest
from click.testing import CliRunner

# Directory holding this conftest file, i.e. the tests directory.
TESTS_ROOT = Path(__file__).resolve().parent
# The directory that contains the tests directory.
PROJECT_ROOT = TESTS_ROOT.parents[0]
PROJECT_ROOT_STR = str(PROJECT_ROOT)

# Put the project root at the front of the import path (only once) so that
# modules below it can be imported by the tests.
if PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, PROJECT_ROOT_STR)


@pytest.fixture(scope="session")
def tests_root() -> Path:
    """Return the directory that contains this conftest file."""
    root_dir = TESTS_ROOT
    return root_dir


@pytest.fixture(scope="session")
def fixtures_root(tests_root: Path) -> Path:
    """Return the ``fixtures`` directory located under the tests root."""
    fixtures_dir = tests_root / "fixtures"
    return fixtures_dir


@pytest.fixture(scope="session")
def count_fixtures_root(fixtures_root: Path) -> Path:
    """Return the ``count`` sub-directory of the fixtures root."""
    count_dir = fixtures_root / "count"
    return count_dir


@pytest.fixture
def minimal_count_input(count_fixtures_root: Path) -> Path:
    """Return the path of ``minimal_test_input.tsv`` in the count fixtures."""
    input_path = count_fixtures_root / "minimal_test_input.tsv"
    return input_path


@pytest.fixture
def ragged_missing_input(count_fixtures_root: Path) -> Path:
    """Return the path of ``ragged_missing_input.tsv`` in the count fixtures."""
    input_path = count_fixtures_root / "ragged_missing_input.tsv"
    return input_path


@pytest.fixture
def cli_runner() -> CliRunner:
    """Return a fresh Click ``CliRunner`` for each test that requests it."""
    runner = CliRunner()
    return runner
140 changes: 140 additions & 0 deletions tests/count/test_mpranalyze_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import pandas as pd

from workflow.scripts.count import mpranalyze_compiler as compiler


class TestMpranalyzeCompiler:
    """Tests for the compiler's helper functions and its Click command."""

    @staticmethod
    def _run_compiler(cli_runner, input_path, output_dir):
        """Invoke the compiler CLI and load the count tables it wrote.

        All four output files are created inside ``output_dir``. The command
        must exit with status 0. Returns ``(rna, dna)``: the RNA and DNA count
        tables read back from disk and indexed by ``seq_id``.
        """
        # Insertion order is the order in which the options are passed.
        outputs = {
            "--rna-counts-output": output_dir / "rna_counts.tsv.gz",
            "--dna-counts-output": output_dir / "dna_counts.tsv.gz",
            "--rna-annotation-output": output_dir / "rna_annot.tsv.gz",
            "--dna-annotation-output": output_dir / "dna_annot.tsv.gz",
        }
        arguments = ["--input", str(input_path)]
        for option, target in outputs.items():
            arguments.extend([option, str(target)])

        result = cli_runner.invoke(compiler.cli, arguments)
        assert result.exit_code == 0, result.output

        rna = pd.read_csv(outputs["--rna-counts-output"], sep="\t").set_index("seq_id")
        dna = pd.read_csv(outputs["--dna-counts-output"], sep="\t").set_index("seq_id")
        return rna, dna

    def test_get_annot_parses_dna_and_rna_headers(self):
        expected_by_header = {
            "DNA(condition X, replicate 1)": ("DNA", "X", "1"),
            "RNA(condition Y, replicate 2)": ("RNA", "Y", "2"),
        }
        for header, expected in expected_by_header.items():
            assert compiler.get_annot(header) == expected

    def test_get_annot_returns_none_for_unmatched_headers(self):
        for header in ("Sequence", "label"):
            assert compiler.get_annot(header) == (None, None, None)

    def test_generate_annotation_output_repeats_and_numbers_barcodes(self):
        input_frame = pd.DataFrame(
            {
                "type": ["DNA", "RNA"],
                "condition": ["X", "X"],
                "replicate": ["1", "1"],
            }
        )

        output = compiler.generate_annotation_output(input_frame, number_barcodes=2)

        assert list(output["sample"]) == ["DNA_X_1_1", "DNA_X_1_2", "RNA_X_1_1", "RNA_X_1_2"]
        assert list(output["barcode"]) == ["1", "2", "1", "2"]

    def test_generate_count_output_pads_and_flattens_by_barcode(self):
        input_frame = pd.DataFrame(
            {
                "label": ["oligoA", "oligoA", "oligoB"],
                "bc1": [10, 11, 12],
                "bc2": [20, 21, 22],
            }
        ).set_index("label")
        sample_names = ["sample_1", "sample_2", "sample_3", "sample_4"]

        output = compiler.generate_count_output(input_frame, sample_names, number_barcodes=2)

        assert list(output["seq_id"]) == ["oligoA", "oligoB"]
        oligo_a = output.loc[output["seq_id"] == "oligoA", sample_names].iloc[0]
        oligo_b = output.loc[output["seq_id"] == "oligoB", sample_names].iloc[0]
        assert list(oligo_a) == [10, 11, 20, 21]
        assert list(oligo_b) == [12, 0, 22, 0]

    def test_cli_generates_expected_count_tables(self, cli_runner, minimal_count_input, tmp_path):
        rna, dna = self._run_compiler(cli_runner, minimal_count_input, tmp_path)

        rna_cols = ["RNA_X_1_1", "RNA_X_1_2", "RNA_X_2_1", "RNA_X_2_2", "RNA_X_3_1", "RNA_X_3_2"]
        assert list(rna.loc["oligoA"].loc[rna_cols]) == [100, 101, 200, 201, 300, 301]

        dna_cols = ["DNA_X_1_1", "DNA_X_1_2", "DNA_X_2_1", "DNA_X_2_2", "DNA_X_3_1", "DNA_X_3_2"]
        assert list(dna.loc["oligoA"].loc[dna_cols]) == [10, 11, 20, 21, 30, 31]

    def test_cli_missing_count_does_not_shift_replicates(self, cli_runner, ragged_missing_input, tmp_path):
        # Regression guard for the MPRAflow-style "ragged within an oligo" bug
        # (shendurelab/MPRAflow#87): a missing trailing count (empty cell /
        # trailing tab) has to remain a zero in its own replicate/barcode slot
        # and must NOT push later counts into the wrong replicate. In the
        # fixture, oligoA lacks its RNA replicate-3 / barcode-2 value, and
        # oligoB has a single barcode (so the across-oligo padding case is
        # exercised as well).
        rna, dna = self._run_compiler(cli_runner, ragged_missing_input, tmp_path)

        rna_cols = [
            "RNA_X_1_1", "RNA_X_1_2", "RNA_X_1_3",
            "RNA_X_2_1", "RNA_X_2_2", "RNA_X_2_3",
            "RNA_X_3_1", "RNA_X_3_2", "RNA_X_3_3",
        ]
        # oligoA: the gap sits at replicate 3 / barcode 2 -> it must read 0, and
        # the real barcode-3 value (302) must remain in barcode 3 instead of
        # moving up into barcode 2.
        assert list(rna.loc["oligoA", rna_cols]) == [100, 101, 102, 200, 201, 202, 300, 0, 302]
        # oligoB has one barcode; the other barcode slots are zero per replicate.
        assert list(rna.loc["oligoB", rna_cols]) == [103, 0, 0, 203, 0, 0, 303, 0, 0]

        # DNA has no missing values; check that nothing shifted there either.
        dna_cols = [
            "DNA_X_1_1", "DNA_X_1_2", "DNA_X_1_3",
            "DNA_X_2_1", "DNA_X_2_2", "DNA_X_2_3",
            "DNA_X_3_1", "DNA_X_3_2", "DNA_X_3_3",
        ]
        assert list(dna.loc["oligoA", dna_cols]) == [10, 11, 12, 20, 21, 22, 30, 31, 32]
5 changes: 5 additions & 0 deletions tests/fixtures/count/minimal_test_input.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3)
oligoA A BC0 10 20 30 100 200 300
oligoA A BC1 11 21 31 101 201 301
oligoB B BC2 12 22 32 102 202 302
oligoB B BC3 13 23 33 103 203 303
5 changes: 5 additions & 0 deletions tests/fixtures/count/ragged_missing_input.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
label Sequence Barcode DNA(condition X, replicate 1) DNA(condition X, replicate 2) DNA(condition X, replicate 3) RNA(condition X, replicate 1) RNA(condition X, replicate 2) RNA(condition X, replicate 3)
oligoA A BC0 10 20 30 100 200 300
oligoA A BC1 11 21 31 101 201
oligoA A BC2 12 22 32 102 202 302
oligoB B BC3 13 23 33 103 203 303
82 changes: 48 additions & 34 deletions workflow/scripts/count/mpranalyze_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,44 @@
import numpy as np
import pandas as pd

ANNOT_PATTERN = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$")


def get_annot(head: str) -> tuple[str|None, str|None, str|None]:
match = ANNOT_PATTERN.match(head)
if match is not None:
group1 = match.group(1)
group2 = match.group(2)
group3 = match.group(3)
return (group1, group2, group3)
return (None, None, None)


def generate_annotation_output(data, number_barcodes):
    """Expand an annotation table to one row per barcode slot.

    Every row of ``data`` is repeated ``number_barcodes`` times. The copies
    within each (type, condition, replicate) group are numbered from 1 and the
    number is stored as a string in ``barcode``. ``sample`` joins type,
    condition, replicate and barcode with underscores. The input frame is not
    modified.

    Returns a frame with the columns ``sample``, ``type``, ``condition``,
    ``replicate`` and ``barcode``, in that order.
    """
    key_columns = ['type', 'condition', 'replicate']
    expanded = data.loc[data.index.repeat(number_barcodes)].copy()
    barcode_numbers = expanded.groupby(key_columns).cumcount() + 1
    expanded['barcode'] = barcode_numbers.astype(str)
    expanded['sample'] = expanded[key_columns + ['barcode']].agg('_'.join, axis=1)
    ordered_columns = ["sample"] + key_columns + ["barcode"]
    return expanded[ordered_columns]


def generate_count_output(data, columns, number_barcodes):
    """Build the wide per-label count table, padding unused slots with zeros.

    Rows of ``data`` are grouped by ``label`` in order of first appearance.
    Each group is copied into a ``number_barcodes`` x ``data.shape[1]`` block
    of zeros (group rows beyond ``number_barcodes`` are dropped) and the block
    is flattened column by column, so every column of ``data`` contributes
    ``number_barcodes`` consecutive output values.

    Missing counts (NaN) are written as 0 in their own slot instead of being
    cast directly to an integer.

    Args:
        data: Frame of counts that can be grouped by ``label`` (for example
            with ``label`` as the index name); every column is treated as a
            count column.
        columns: Names for the output count columns, one per
            (input column, barcode slot) pair in column-major order.
        number_barcodes: Number of barcode slots reserved per label.

    Returns:
        A DataFrame with ``seq_id`` as first column followed by ``columns``,
        holding one row of int64 counts per label.
    """
    n_observations = data.shape[1]
    rows = []
    seq_ids = []
    for label, group in data.groupby('label', sort=False):
        values = group.to_numpy()[:number_barcodes]
        if values.dtype.kind not in 'iu':
            # Casting NaN straight to int64 is undefined (it usually produces a
            # huge negative number), so turn missing counts into zeros first.
            # Integer input skips this branch and is used unchanged.
            values = values.astype(np.float64)
            values = np.where(np.isnan(values), 0.0, values)
        padded = np.zeros((number_barcodes, n_observations), dtype=np.int64)
        padded[:len(values)] = values.astype(np.int64)
        # Column-major flattening keeps all barcode slots of one input column
        # next to each other.
        rows.append(padded.flatten(order='F'))
        seq_ids.append(label)
    counts = pd.DataFrame(rows, columns=columns)
    counts.insert(0, 'seq_id', seq_ids)
    return counts


def write_table(data, file):
    """Write ``data`` to ``file`` as a gzip-compressed, tab-separated table without the index."""
    data.to_csv(
        file,
        sep='\t',
        index=False,
        compression='gzip',
    )


# options
@click.command()
Expand Down Expand Up @@ -39,12 +77,6 @@

def cli(input_file, rna_counts_output_file, dna_counts_output_file, rna_annotation_output_file, dna_annotation_output_file):

annot_pattern = re.compile(r"^([DR]NA).*\(condition (.*), replicate (.*)\)$")
def get_annot(head):
m = annot_pattern.match(head)
if m is not None:
return m.group(1,2,3)

# read input
df = pd.read_csv(input_file,sep="\t", header='infer')

Expand All @@ -61,44 +93,26 @@ def get_annot(head):

# counts for observation

dna_df = df.iloc[:,2:(2+n_dna_obs)].applymap(np.int64)
rna_df = df.iloc[:,(2+n_dna_obs):].applymap(np.int64)
dna_df = df.iloc[:,2:(2+n_dna_obs)].astype(np.int64)
rna_df = df.iloc[:,(2+n_dna_obs):].astype(np.int64)

## generate output DNA/RNA annotations (type_condition_replicate_barcode)
n_bc = df.groupby('label').Barcode.agg(len).max()
def generateAnnotationOutput(data, number_barcodes):
data = data.loc[data.index.repeat(number_barcodes)]
data['barcode'] = data.groupby(['type','condition','replicate']).cumcount() +1
data['barcode'] = data['barcode'].astype(str)
data['sample'] = data[['type','condition','replicate','barcode']].agg('_'.join,axis=1)
data = data[["sample", "type", "condition", "replicate", "barcode"]]
return(data)
dna_annot = generateAnnotationOutput(dna_annot, n_bc)
rna_annot = generateAnnotationOutput(rna_annot, n_bc)
dna_annot = generate_annotation_output(dna_annot, n_bc)
rna_annot = generate_annotation_output(rna_annot, n_bc)

## generate output DNA/RNA count tables
## rows oligo/seq ids,/assignment then per barcode the counts. padding with zeros
def generateCountOutput(data,columns):
counts = pd.DataFrame(list(data.groupby('label').apply(lambda x: x.values.flatten()))).fillna(0).astype(np.int64)
counts.columns = columns
counts['seq_id'] = data.index.unique()
counts = counts[(['seq_id'] + list(columns))]
return(counts)

dna_counts = generateCountOutput(dna_df,dna_annot['sample'])
rna_counts = generateCountOutput(rna_df,rna_annot['sample'])

## write table function
def write(data,file):
data.to_csv(file, index=False,sep='\t', compression='gzip')
dna_counts = generate_count_output(dna_df, dna_annot['sample'], n_bc)
rna_counts = generate_count_output(rna_df, rna_annot['sample'], n_bc)

## write output DNA/RNA annotations
write(dna_annot,dna_annotation_output_file)
write(rna_annot,rna_annotation_output_file)
write_table(dna_annot, dna_annotation_output_file)
write_table(rna_annot, rna_annotation_output_file)

## write output DNA/RNA count tables
write(dna_counts,dna_counts_output_file)
write(rna_counts,rna_counts_output_file)
write_table(dna_counts, dna_counts_output_file)
write_table(rna_counts, rna_counts_output_file)

if __name__ == '__main__':
cli()
Loading