Skip to content

Initial numba module #3225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft

Initial numba module #3225

wants to merge 18 commits into from

Conversation

benjeffery
Copy link
Member

Part of #3135

@benjeffery
Copy link
Member Author

I've done quite a bit of numba investigation and found a way to use dataclasses in numba code. This seems to come at very little performance cost compared to tuples and is a lot nicer. Using a generator also seems to work fine!

Copy link

codecov bot commented Jun 18, 2025

Codecov Report

❌ Patch coverage is 0.81301% with 244 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.31%. Comparing base (ce09b35) to head (0e9ecb4).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
python/tskit/jit/numba.py 0.00% 244 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3225      +/-   ##
==========================================
- Coverage   89.61%   89.31%   -0.30%     
==========================================
  Files          28       27       -1     
  Lines       31983    30834    -1149     
  Branches     5888     5599     -289     
==========================================
- Hits        28660    27540    -1120     
- Misses       1888     1985      +97     
+ Partials     1435     1309     -126     
Flag Coverage Δ
c-tests 86.59% <ø> (ø)
lwt-tests ?
python-c-tests 88.15% <ø> (ø)
python-tests 98.79% <100.00%> (-0.02%) ⬇️
python-tests-numpy1 50.78% <0.00%> (-1.66%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
python/tskit/drawing.py 98.39% <100.00%> (ø)
python/tskit/jit/numba.py 0.00% <0.00%> (ø)

... and 3 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@benjeffery
Copy link
Member Author

Having some CI weirdness that I'm not yet able to recreate.

@benjeffery
Copy link
Member Author

CI Fixed.

Here's some benchmarking with the "coalescent_nodes" method from #2778 on a TS with 12M edges:

Using ts.edge_diffs: 23.3s
Calculating edge diffs and coalescent nodes in a single numba.njit function: 0.085s
Using the classes here, calculating coalescent nodes in separate client numba.njit function: 0.093s

@jeromekelleher
Copy link
Member

Shall we move the first commit into its own PR? It's cluttering up this one and making it hard to see the real changes.

@jeromekelleher
Copy link
Member

I had imagined something lower level that was basically a copy of the TreePosition class from here: https://github.com/jeromekelleher/sc2ts/blob/7758245c3dc537aeec3b7cd6282241b65f8843dd/sc2ts/jit.py#L107

So, we don't try to provide Pythonic APIs, but just provide direct access to the edges out and edges in, which can be numba compiled like the example in the sc2ts code.

@benjeffery
Copy link
Member Author

benjeffery commented Jun 19, 2025

just provide direct access to the edges out and edges in

That's how this code works,
Your sc2ts code here:

    while tree_pos.next():
        for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
            e = tree_pos.edge_removal_order[j]
            c = edges_child[e]
            p = edges_parent[e]
            parent[c] = -1
            u = p
            while u != -1:
                num_samples[u] -= num_samples[c]
                u = parent[u]

becomes

    for tree_pos in numba_ts.edge_diffs():
        for j in range(*tree_pos.edges_out_index_range):
            e = numba_ts.indexes_edge_removal_order[j]
            c = edges_child[e]
            p = edges_parent[e]
            parent[c] = -1
            u = p
            while u != -1:
                num_samples[u] -= num_samples[c]
                u = parent[u]

It is still compiled, and 30% faster (for the coalesent nodes example)!

@jeromekelleher
Copy link
Member

Ahh, I didn't spot that sorry. How is it faster then?

I do think we should just stick with the TreePosition interface though, because we want to support seeking backwards as well, and ultimately randomly. There's no point in adding a layer for indirection on top of that.

@benjeffery
Copy link
Member Author

How is it faster then?

Mutating numpy arrays to maintain the state involves the following:

  1. Creating a temporary list (build_list).
  2. Performing bounds checks for the slice.
  3. Copying the data from the list into the array's memory.

Whereas yielding lightweight immuatable objects is much more amenable to numba optimisation. We might be able to get the same gains by using native objects for the state rather than numpy arrays if you are set against iteration.

@jeromekelleher
Copy link
Member

Let's talk it through in person - I don't have time to form an educated opinion I'm afraid.

@benjeffery benjeffery force-pushed the numba branch 2 times, most recently from 7aa7151 to 5d22c6c Compare June 27, 2025 15:22
@benjeffery
Copy link
Member Author

I've tried to closely match the exisiting tsutil implementation with next and prev. Need to so some perf checks with this new code under numba.

@benjeffery
Copy link
Member Author

New code looks just as fast, proceeding to add some more tests. Will merge this then before doign docs.

@benjeffery
Copy link
Member Author

Getting some weird failures on Windows here, and coverage not counting for the new module, will fix.

I've added a stab at some docs.

@hyanwong
Copy link
Member

hyanwong commented Jul 3, 2025

Re docs, eventually we probably want a "high performance" tutorial with some of this stuff, but I can have a stab at that after 1.0. There's some comments here: tskit-dev/tutorials#150 (comment) and some code examples at tskit-dev/tutorials#63

docs/numba.md Outdated
print(type(numba_ts))
```

## Tree Traversal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I normally think of "tree traversal" as iterating through the tree structure itself. Do you mean "Iterating through trees" here? I can't see any pre/postorder traversal code here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I was avoiding the word "iteration" not not confuse it with a Python iterator - but I'll change it back as this is more confusing!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just "moving between trees"?

Copy link
Member

@jeromekelleher jeromekelleher left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Not obvious to me why we're setting up the tests like this though.


NODE_IS_SAMPLE = tskit.NODE_IS_SAMPLE

@numba.njit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are importing the jit module within the test functions here and defining the algorithm? I think we can assume that developers have numba installed, and there's pytest ways of skipping the module for CI?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Further up the test file we test importing the tskit.jit.numba module while it is mocked out, which needs the module to not be imported. I could try and find another way to do that which doesn't require all the local imports?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe do that test in its own module so? The local import stuff will confuse people in to thinking it's necessary and the LLMs might also get this idea from the example code.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've fixed by making the imported module look absent in the test that requires it to be, and have removed the function-local imports.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants