Skip to content

Commit 869967d

Browse files
authored
refactor: Write edge differences iterator in safe rust. (#791)
BREAKING CHANGE: Several API breakages: * the edge-difference iterator types now carry a lifetime parameter * a standard `Iterator` is now used instead of a lending iterator
1 parent c623a31 commit 869967d

File tree

5 files changed

+314
-190
lines changed

5 files changed

+314
-190
lines changed

book/src/tree_sequence_edge_diffs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Iterating over edge differences
22

3-
As with [trees](tree_sequence_iterate_trees.md), the API provides a *lending* iterator over edge differences.
3+
The API provides an iterator over edge differences.
44
Each step of the iterator advances to the next tree in the tree sequence.
55
For each tree, a standard `Iterator` over removals and insertions is available:
66

src/edge_differences.rs

Lines changed: 141 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,8 @@
1+
use crate::EdgeId;
12
use crate::NodeId;
23
use crate::Position;
34
use crate::TreeSequence;
45

5-
use crate::sys::bindings;
6-
7-
#[repr(transparent)]
8-
struct LLEdgeDifferenceIterator(bindings::tsk_diff_iter_t);
9-
10-
impl Drop for LLEdgeDifferenceIterator {
11-
fn drop(&mut self) {
12-
unsafe { bindings::tsk_diff_iter_free(&mut self.0) };
13-
}
14-
}
15-
16-
impl LLEdgeDifferenceIterator {
17-
pub fn new_from_treeseq(
18-
treeseq: &TreeSequence,
19-
flags: bindings::tsk_flags_t,
20-
) -> Result<Self, crate::TskitError> {
21-
let mut inner = std::mem::MaybeUninit::<bindings::tsk_diff_iter_t>::uninit();
22-
let treeseq_ptr = treeseq.as_ptr();
23-
assert!(!treeseq_ptr.is_null());
24-
// SAFETY: treeseq_ptr is not null
25-
let tables_ptr =
26-
unsafe { (*treeseq_ptr).tables } as *const bindings::tsk_table_collection_t;
27-
assert!(!tables_ptr.is_null());
28-
// SAFETY: tables_ptr is not null,
29-
// init of inner will be handled by tsk_diff_iter_init
30-
let num_trees: i32 = treeseq.num_trees().try_into()?;
31-
let code = unsafe {
32-
bindings::tsk_diff_iter_init(inner.as_mut_ptr(), tables_ptr, num_trees, flags)
33-
};
34-
// SAFETY: tsk_diff_iter_init has initialized our object
35-
handle_tsk_return_value!(code, Self(unsafe { inner.assume_init() }))
36-
}
37-
}
38-
396
/// Marker type for edge insertion.
407
pub struct Insertion {}
418

@@ -49,49 +16,6 @@ mod private {
4916
impl EdgeDifferenceIteration for super::Removal {}
5017
}
5118

52-
struct LLEdgeList<T: private::EdgeDifferenceIteration> {
53-
inner: bindings::tsk_edge_list_t,
54-
marker: std::marker::PhantomData<T>,
55-
}
56-
57-
macro_rules! build_lledgelist {
58-
($name: ident, $generic: ty) => {
59-
type $name = LLEdgeList<$generic>;
60-
61-
impl Default for $name {
62-
fn default() -> Self {
63-
Self {
64-
inner: bindings::tsk_edge_list_t {
65-
head: std::ptr::null_mut(),
66-
tail: std::ptr::null_mut(),
67-
},
68-
marker: std::marker::PhantomData::<$generic> {},
69-
}
70-
}
71-
}
72-
};
73-
}
74-
75-
build_lledgelist!(LLEdgeInsertionList, Insertion);
76-
build_lledgelist!(LLEdgeRemovalList, Removal);
77-
78-
/// Concrete type implementing [`Iterator`] over [`EdgeInsertion`] or [`EdgeRemoval`].
79-
/// Created by [`EdgeDifferencesIterator::edge_insertions`] or
80-
/// [`EdgeDifferencesIterator::edge_removals`], respectively.
81-
pub struct EdgeDifferences<'a, T: private::EdgeDifferenceIteration> {
82-
inner: &'a LLEdgeList<T>,
83-
current: *mut bindings::tsk_edge_list_node_t,
84-
}
85-
86-
impl<'a, T: private::EdgeDifferenceIteration> EdgeDifferences<'a, T> {
87-
fn new(inner: &'a LLEdgeList<T>) -> Self {
88-
Self {
89-
inner,
90-
current: std::ptr::null_mut(),
91-
}
92-
}
93-
}
94-
9519
/// An edge difference. Edge insertions and removals are differentiated by
9620
/// marker types [`Insertion`] and [`Removal`], respectively.
9721
#[derive(Debug, Copy, Clone)]
@@ -149,103 +73,170 @@ pub type EdgeInsertion = EdgeDifference<Insertion>;
14973
/// Type alias for [`EdgeDifference<Removal>`]
15074
pub type EdgeRemoval = EdgeDifference<Removal>;
15175

152-
impl<T> Iterator for EdgeDifferences<'_, T>
153-
where
154-
T: private::EdgeDifferenceIteration,
155-
{
156-
type Item = EdgeDifference<T>;
76+
/// Manages iteration over trees to obtain
77+
/// edge differences.
78+
pub struct EdgeDifferencesIterator<'ts> {
79+
edges_left: &'ts [Position],
80+
edges_right: &'ts [Position],
81+
edges_parent: &'ts [NodeId],
82+
edges_child: &'ts [NodeId],
83+
insertion_order: &'ts [EdgeId],
84+
removal_order: &'ts [EdgeId],
85+
left: f64,
86+
sequence_length: f64,
87+
insertion_index: usize,
88+
removal_index: usize,
89+
}
15790

158-
fn next(&mut self) -> Option<Self::Item> {
159-
if self.current.is_null() {
160-
self.current = self.inner.inner.head;
161-
} else {
162-
self.current = unsafe { *self.current }.next;
163-
}
164-
if self.current.is_null() {
165-
None
166-
} else {
167-
let left = unsafe { (*self.current).edge.left };
168-
let right = unsafe { (*self.current).edge.right };
169-
let parent = unsafe { (*self.current).edge.parent };
170-
let child = unsafe { (*self.current).edge.child };
171-
Some(Self::Item::new(left, right, parent, child))
91+
impl<'ts> EdgeDifferencesIterator<'ts> {
92+
pub(crate) fn new(treeseq: &'ts TreeSequence) -> Self {
93+
Self {
94+
edges_left: treeseq.tables().edges().left_slice(),
95+
edges_right: treeseq.tables().edges().right_slice(),
96+
edges_parent: treeseq.tables().edges().parent_slice(),
97+
edges_child: treeseq.tables().edges().child_slice(),
98+
insertion_order: treeseq.edge_insertion_order(),
99+
removal_order: treeseq.edge_removal_order(),
100+
left: 0.,
101+
sequence_length: treeseq.tables().sequence_length().into(),
102+
insertion_index: 0,
103+
removal_index: 0,
172104
}
173105
}
174106
}
175107

176-
/// Manages iteration over trees to obtain
177-
/// edge differences.
178-
pub struct EdgeDifferencesIterator {
179-
inner: LLEdgeDifferenceIterator,
180-
insertion: LLEdgeInsertionList,
181-
removal: LLEdgeRemovalList,
108+
#[derive(Clone)]
109+
pub struct CurrentTreeEdgeDifferences<'ts> {
110+
edges_left: &'ts [Position],
111+
edges_right: &'ts [Position],
112+
edges_parent: &'ts [NodeId],
113+
edges_child: &'ts [NodeId],
114+
insertion_order: &'ts [EdgeId],
115+
removal_order: &'ts [EdgeId],
116+
removals: (usize, usize),
117+
insertions: (usize, usize),
182118
left: f64,
183119
right: f64,
184-
advanced: i32,
185120
}
186121

187-
impl EdgeDifferencesIterator {
188-
// NOTE: will return None if tskit-c cannot
189-
// allocate memory for internal structures.
190-
pub(crate) fn new_from_treeseq(
191-
treeseq: &TreeSequence,
192-
flags: bindings::tsk_flags_t,
193-
) -> Result<Self, crate::TskitError> {
194-
LLEdgeDifferenceIterator::new_from_treeseq(treeseq, flags).map(|inner| Self {
195-
inner,
196-
insertion: LLEdgeInsertionList::default(),
197-
removal: LLEdgeRemovalList::default(),
198-
left: f64::default(),
199-
right: f64::default(),
200-
advanced: 0,
201-
})
202-
}
122+
#[repr(transparent)]
123+
pub struct EdgeRemovalsIterator<'ts>(CurrentTreeEdgeDifferences<'ts>);
203124

204-
fn advance_tree(&mut self) {
205-
// SAFETY: our tree sequence is guaranteed
206-
// to be valid and own its tables.
207-
self.advanced = unsafe {
208-
bindings::tsk_diff_iter_next(
209-
&mut self.inner.0,
210-
&mut self.left,
211-
&mut self.right,
212-
&mut self.removal.inner,
213-
&mut self.insertion.inner,
214-
)
215-
};
216-
}
125+
#[repr(transparent)]
126+
pub struct EdgeInsertionsIterator<'ts>(CurrentTreeEdgeDifferences<'ts>);
217127

218-
pub fn left(&self) -> Position {
219-
self.left.into()
128+
impl<'ts> Iterator for EdgeRemovalsIterator<'ts> {
129+
type Item = EdgeDifference<Removal>;
130+
fn next(&mut self) -> Option<Self::Item> {
131+
if self.0.removals.0 < self.0.removals.1 {
132+
let index = self.0.removals.0;
133+
self.0.removals.0 += 1;
134+
Some(Self::Item::new(
135+
self.0.edges_left[self.0.removal_order[index].as_usize()],
136+
self.0.edges_right[self.0.removal_order[index].as_usize()],
137+
self.0.edges_parent[self.0.removal_order[index].as_usize()],
138+
self.0.edges_child[self.0.removal_order[index].as_usize()],
139+
))
140+
} else {
141+
None
142+
}
220143
}
144+
}
221145

222-
pub fn right(&self) -> Position {
223-
self.right.into()
146+
impl<'ts> Iterator for EdgeInsertionsIterator<'ts> {
147+
type Item = EdgeDifference<Insertion>;
148+
fn next(&mut self) -> Option<Self::Item> {
149+
if self.0.insertions.0 < self.0.insertions.1 {
150+
let index = self.0.insertions.0;
151+
self.0.insertions.0 += 1;
152+
Some(Self::Item::new(
153+
self.0.edges_left[self.0.insertion_order[index].as_usize()],
154+
self.0.edges_right[self.0.insertion_order[index].as_usize()],
155+
self.0.edges_parent[self.0.insertion_order[index].as_usize()],
156+
self.0.edges_child[self.0.insertion_order[index].as_usize()],
157+
))
158+
} else {
159+
None
160+
}
224161
}
162+
}
225163

226-
pub fn interval(&self) -> (Position, Position) {
227-
(self.left(), self.right())
164+
impl<'ts> CurrentTreeEdgeDifferences<'ts> {
165+
pub fn removals(&self) -> impl Iterator<Item = EdgeRemoval> + '_ {
166+
EdgeRemovalsIterator(self.clone())
228167
}
229168

230-
pub fn edge_removals(&self) -> impl Iterator<Item = EdgeRemoval> + '_ {
231-
EdgeDifferences::<Removal>::new(&self.removal)
169+
pub fn insertions(&self) -> impl Iterator<Item = EdgeInsertion> + '_ {
170+
EdgeInsertionsIterator(self.clone())
232171
}
233172

234-
pub fn edge_insertions(&self) -> impl Iterator<Item = EdgeInsertion> + '_ {
235-
EdgeDifferences::<Insertion>::new(&self.insertion)
173+
pub fn interval(&self) -> (Position, Position) {
174+
(self.left.into(), self.right.into())
236175
}
237176
}
238177

239-
impl crate::StreamingIterator for EdgeDifferencesIterator {
240-
type Item = EdgeDifferencesIterator;
241-
242-
fn advance(&mut self) {
243-
self.advance_tree()
178+
/// Tighten a candidate right coordinate against the next pending edge.
///
/// `diff_slice` is an ordering of edge ids and `index` a cursor into it;
/// `position_slice` holds one coordinate per edge id. If the cursor is
/// still inside `diff_slice` and the coordinate of the edge it points at
/// is strictly less than `right`, that coordinate is returned. In every
/// other case `right` is returned unchanged.
fn update_right(
    right: f64,
    index: usize,
    position_slice: &[Position],
    diff_slice: &[EdgeId],
) -> f64 {
    match diff_slice.get(index) {
        // Cursor ran off the end: nothing pending, keep the current bound.
        None => right,
        Some(edge) => {
            let candidate = position_slice[edge.as_usize()];
            if candidate < right {
                candidate.into()
            } else {
                right
            }
        }
    }
}
245195

246-
fn get(&self) -> Option<&Self::Item> {
247-
if self.advanced > 0 {
248-
Some(self)
196+
impl<'ts> Iterator for EdgeDifferencesIterator<'ts> {
197+
type Item = CurrentTreeEdgeDifferences<'ts>;
198+
199+
fn next(&mut self) -> Option<Self::Item> {
200+
if self.insertion_index < self.insertion_order.len() && self.left < self.sequence_length {
201+
let removals_start = self.removal_index;
202+
while self.removal_index < self.removal_order.len()
203+
&& self.edges_right[self.removal_order[self.removal_index].as_usize()] == self.left
204+
{
205+
self.removal_index += 1;
206+
}
207+
let insertions_start = self.insertion_index;
208+
while self.insertion_index < self.insertion_order.len()
209+
&& self.edges_left[self.insertion_order[self.insertion_index].as_usize()]
210+
== self.left
211+
{
212+
self.insertion_index += 1;
213+
}
214+
let right = update_right(
215+
self.sequence_length,
216+
self.insertion_index,
217+
self.edges_left,
218+
self.insertion_order,
219+
);
220+
let right = update_right(
221+
right,
222+
self.removal_index,
223+
self.edges_right,
224+
self.removal_order,
225+
);
226+
let diffs = CurrentTreeEdgeDifferences {
227+
edges_left: self.edges_left,
228+
edges_right: self.edges_right,
229+
edges_parent: self.edges_parent,
230+
edges_child: self.edges_child,
231+
insertion_order: self.insertion_order,
232+
removal_order: self.removal_order,
233+
removals: (removals_start, self.removal_index),
234+
insertions: (insertions_start, self.insertion_index),
235+
left: self.left,
236+
right,
237+
};
238+
self.left = right;
239+
Some(diffs)
249240
} else {
250241
None
251242
}

src/trees/treeseq.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -477,16 +477,9 @@ impl TreeSequence {
477477
handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv))
478478
}
479479

480-
/// Build a lending iterator over edge differences.
481-
///
482-
/// # Errors
483-
///
484-
/// * [`TskitError`] if the `C` back end is unable to allocate
485-
/// needed memory
486-
pub fn edge_differences_iter(
487-
&self,
488-
) -> Result<crate::edge_differences::EdgeDifferencesIterator, TskitError> {
489-
crate::edge_differences::EdgeDifferencesIterator::new_from_treeseq(self, 0)
480+
/// Build an iterator over edge differences.
481+
pub fn edge_differences_iter(&self) -> crate::edge_differences::EdgeDifferencesIterator {
482+
crate::edge_differences::EdgeDifferencesIterator::new(self)
490483
}
491484

492485
/// Reference to the underlying table collection.

0 commit comments

Comments
 (0)