diff --git a/Cargo.toml b/Cargo.toml index 9d0b974..72bca34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,4 @@ num-traits = "0.2" [dev-dependencies] maplit = "1.0" +rand = "0.8.5" diff --git a/README.md b/README.md index 153407b..b461c31 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,20 @@ let expected = vec![('c', 3), ('b', 2), ('d', 2), ('a', 1), ('e', 1)]; assert!(by_common == expected); ``` +[`k_most_common_ordered()`] takes an argument `k` of type `usize` and returns the top `k` most +common items. This is functionally equivalent to calling `most_common_ordered()` and then +truncating the result to length `k`. However, if `k` is smaller than the length of the counter +then `k_most_common_ordered()` can be more efficient, often much more so. + +```rust +let by_common = "eaddbbccc".chars().collect::<Counter<_>>().k_most_common_ordered(2); +let expected = vec![('c', 3), ('b', 2)]; +assert!(by_common == expected); +``` + +[`k_most_common_ordered()`]: Counter::k_most_common_ordered +[`most_common_ordered()`]: Counter::most_common_ordered + ### Get the most common items using your own ordering For example, here we break ties reverse alphabetically. diff --git a/src/lib.rs b/src/lib.rs index 1cbe307..1de45de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,6 +68,21 @@ //! assert!(by_common == expected); //! ``` //! +//! [`k_most_common_ordered()`] takes an argument `k` of type `usize` and returns the top `k` most +//! common items. This is functionally equivalent to calling `most_common_ordered()` and then +//! truncating the result to length `k`. However, if `k` is smaller than the length of the counter +//! then `k_most_common_ordered()` can be more efficient, often much more so. +//! +//! ```rust +//! # use counter::Counter; +//! let by_common = "eaddbbccc".chars().collect::<Counter<_>>().k_most_common_ordered(2); +//! let expected = vec![('c', 3), ('b', 2)]; +//! assert!(by_common == expected); +//! ``` +//! +//! 
[`k_most_common_ordered()`]: Counter::k_most_common_ordered +//! [`most_common_ordered()`]: Counter::most_common_ordered +//! //! ## Get the most common items using your own ordering //! //! For example, here we break ties reverse alphabetically. @@ -176,7 +191,7 @@ use num_traits::{One, Zero}; use std::borrow::Borrow; -use std::collections::HashMap; +use std::collections::{BinaryHeap, HashMap}; use std::hash::Hash; use std::iter; use std::ops::{Add, AddAssign, BitAnd, BitOr, Deref, DerefMut, Index, IndexMut, Sub, SubAssign}; @@ -344,7 +359,7 @@ where .iter() .map(|(key, count)| (key.clone(), count.clone())) .collect::<Vec<_>>(); - items.sort_by(|(a_item, a_count), (b_item, b_count)| { + items.sort_unstable_by(|(a_item, a_count), (b_item, b_count)| { b_count .cmp(a_count) .then_with(|| tiebreaker(a_item, b_item)) @@ -363,15 +378,103 @@ where /// In the event that two keys have an equal frequency, use the natural ordering of the keys /// to further sort the results. /// + /// # Examples + /// /// ```rust /// # use counter::Counter; /// let mc = "abracadabra".chars().collect::<Counter<_>>().most_common_ordered(); /// let expect = vec![('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]; /// assert_eq!(mc, expect); /// ``` + /// + /// # Time complexity + /// + /// *O*(*n* \* log *n*), where *n* is the number of items in the counter. If all you want is + /// the top *k* items and *k* < *n* then it can be more efficient to use + /// [`k_most_common_ordered`]. + /// + /// [`k_most_common_ordered`]: Counter::k_most_common_ordered pub fn most_common_ordered(&self) -> Vec<(T, N)> { self.most_common_tiebreaker(Ord::cmp) } + + /// Returns the `k` most common items in decreasing order of their counts. + /// + /// The returned vector is the same as would be obtained by calling `most_common_ordered` and + /// then truncating the result to length `k`. In particular, items with the same count are + /// sorted in *increasing* order of their keys. 
Further, if `k` is greater than the length of + /// the counter then the returned vector will have length equal to that of the counter, not + /// `k`. + /// + /// # Examples + /// + /// ```rust + /// # use counter::Counter; + /// let counter: Counter<_> = "abracadabra".chars().collect(); + /// let top3 = counter.k_most_common_ordered(3); + /// assert_eq!(top3, vec![('a', 5), ('b', 2), ('r', 2)]); + /// ``` + /// + /// # Time complexity + /// + /// This method can be much more efficient than [`most_common_ordered`] when *k* is much + /// smaller than the length of the counter *n*. When *k* = 1 the algorithm is equivalent + /// to finding the minimum (or maximum) of *n* items, which requires *n* \- 1 comparisons. For + /// a fixed value of *k* > 1, the number of comparisons scales with *n* as *n* \+ *O*(log *n*) + /// and the number of swaps scales as *O*(log *n*). As *k* approaches *n*, this algorithm + /// approaches a heapsort of the *n* items, which has complexity *O*(*n* \* log *n*). + /// + /// For values of *k* close to *n* the sorting algorithm used by [`most_common_ordered`] will + /// generally be faster than the heapsort used by this method by a small constant factor. + /// Exactly where the crossover point occurs will depend on several factors. For small *k* + /// choose this method. If *k* is a substantial fraction of *n*, it may be that + /// [`most_common_ordered`] is faster. If performance matters in your application then it may + /// be worth experimenting to see which of the two methods is faster. + /// + /// [`most_common_ordered`]: Counter::most_common_ordered + pub fn k_most_common_ordered(&self, k: usize) -> Vec<(T, N)> { + use std::cmp::Reverse; + + if k == 0 { + return vec![]; + } + + // The quicksort implementation used by `most_common_ordered()` is generally faster than + // the heapsort used below when sorting the entire counter. 
+ if k >= self.map.len() { + return self.most_common_ordered(); + } + + // Clone the counts as we iterate over the map to eliminate an extra indirection when + // comparing counts. This will be an improvement in the typical case where `N: Copy`. + // Defer cloning the keys until we have selected the top `k` items so that we clone only + // `k` keys instead of all of them. + let mut items = self.map.iter().map(|(t, n)| (Reverse(n.clone()), t)); + + // Step 1. Make a heap out of the first `k` items; this makes O(k) comparisons. + let mut heap: BinaryHeap<_> = items.by_ref().take(k).collect(); + + // Step 2. Successively compare each of the remaining `n - k` items to the top of the heap, + // replacing the root (and subsequently sifting down) whenever the item is less than the + // root. This takes on average n - k + k * (1 + log2(k)) * (H(n) - H(k)) comparisons, where + // H(i) is the ith [harmonic number](https://en.wikipedia.org/wiki/Harmonic_number). For + // fixed `k`, this scales as *n* + *O*(log(*n*)). + items.for_each(|item| { + // If `items` is nonempty at this point then we know the heap contains `k > 0` + // elements. + let mut root = heap.peek_mut().expect("the heap is empty"); + if *root > item { + *root = item; + } + }); + + // Step 3. Sort the items in the heap with the second phase of heapsort. The number of + // comparisons is 2 * k * log2(k) + O(k). 
+ heap.into_sorted_vec() + .into_iter() + .map(|(Reverse(n), t)| (t.clone(), n)) + .collect() + } } impl<T, N> Default for Counter<T, N> @@ -1007,6 +1110,7 @@ where mod tests { use super::*; use maplit::hashmap; + use rand::Rng; use std::collections::HashMap; #[test] @@ -1179,6 +1283,41 @@ mod tests { assert!(by_common == expected); } + #[test] + fn test_k_most_common_ordered() { + let counter: Counter<_> = "abracadabra".chars().collect(); + let all = counter.most_common_ordered(); + for k in 0..=counter.len() { + let topk = counter.k_most_common_ordered(k); + assert_eq!(&topk, &all[..k]); + } + } + + /// This test is fundamentally the same as `test_k_most_common_ordered`, but it operates on + /// a wider variety of data. In particular, it tests longer, narrower, and wider + /// distributions of data than the other test does. + #[test] + fn test_k_most_common_ordered_heavy() { + let mut rng = rand::thread_rng(); + + for container_size in [5, 10, 25, 100, 256] { + for max_value_factor in [0.25, 0.5, 1.0, 1.25, 2.0, 10.0, 100.0] { + let max_value = ((container_size as f64) * max_value_factor) as u32; + let mut values = vec![0; container_size]; + for value in values.iter_mut() { + *value = rng.gen_range(0..=max_value); + } + + let counter: Counter<_> = values.into_iter().collect(); + let all = counter.most_common_ordered(); + for k in 0..=counter.len() { + let topk = counter.k_most_common_ordered(k); + assert_eq!(&topk, &all[..k]); + } + } + } + } + #[test] fn test_total() { let counter = Counter::init("".chars());