From 2460711f4f47ca855296f4aa7e00134bbf391a97 Mon Sep 17 00:00:00 2001 From: Oittaa Date: Fri, 11 Jul 2025 22:49:52 +0200 Subject: [PATCH 1/8] Add selection algorithms --- CMakeLists.txt | 1 + lib/std/select.zig | 484 +++++++++++++++++++++++++++++++++++++++++++++ lib/std/std.zig | 1 + 3 files changed, 486 insertions(+) create mode 100644 lib/std/select.zig diff --git a/CMakeLists.txt b/CMakeLists.txt index db580b05fa28..83513606e81f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -496,6 +496,7 @@ set(ZIG_STAGE2_SOURCES lib/std/pdb.zig lib/std/process.zig lib/std/process/Child.zig + lib/std/select.zig lib/std/sort.zig lib/std/start.zig lib/std/static_string_map.zig diff --git a/lib/std/select.zig b/lib/std/select.zig new file mode 100644 index 000000000000..f850e2ac0808 --- /dev/null +++ b/lib/std/select.zig @@ -0,0 +1,484 @@ +const std = @import("std.zig"); +const assert = std.debug.assert; +const sort = std.sort; +const mem = std.mem; +const math = std.math; +const testing = std.testing; + +/// A convenient wrapper for nthElementContext that finds the nth smallest +/// element in a slice and places it at items[n]. +/// +/// This function modifies the order of the other elements in the slice. After execution, +/// all elements before items[n] will be less than or equal to items[n], and all +/// elements after items[n] will be greater than or equal to items[n]. +/// +/// This is a high-level wrapper that creates the necessary context for the +/// core nthElementContext implementation. +/// +/// Parameters: +/// - T: The type of the elements in the slice. +/// - items: The slice of items to select from. The order of elements will be modified. +/// - n: The 0-based index of the element to find (e.g., 0 for the smallest, items.len - 1 for the largest). +/// - context: A user-provided context that will be passed to lessThanFn. +/// - lessThanFn: The comparison function that defines the ordering of elements. 
pub fn nthElement(
    comptime T: type,
    items: []T,
    n: usize,
    context: anytype,
    comptime lessThanFn: fn (context: @TypeOf(context), lhs: T, rhs: T) bool,
) void {
    // Adapter that exposes the slice through the index-based
    // `lessThan(i, j)` / `swap(i, j)` interface expected by `nthElementContext`.
    const Adapter = struct {
        const Self = @This();

        elems: []T,
        inner: @TypeOf(context),

        /// Compares the elements stored at indices `i` and `j` with the caller's predicate.
        pub fn lessThan(self: Self, i: usize, j: usize) bool {
            const lhs = self.elems[i];
            const rhs = self.elems[j];
            return lessThanFn(self.inner, lhs, rhs);
        }

        /// Exchanges the elements stored at indices `i` and `j`.
        pub fn swap(self: Self, i: usize, j: usize) void {
            mem.swap(T, &self.elems[i], &self.elems[j]);
        }
    };

    const adapter: Adapter = .{ .elems = items, .inner = context };
    // Select over the whole slice: the sub-range is `[0, items.len)`.
    nthElementContext(n, 0, items.len, adapter);
}

/// The core implementation of the nth-element search algorithm (Introselect).
/// It finds the element that would be at index `a + n` if the sub-slice `[a, b)` were sorted.
///
/// This function operates on a sub-slice of a collection managed by the context
/// and modifies it in place. After execution, the nth smallest element of the
/// sub-slice is at index `a + n`, every element of the sub-slice before it is
/// less than or equal to it, and every element after it is greater than or
/// equal to it.
///
/// The algorithm is a hybrid:
/// 1. Quicksort-like partitioning: the search window is narrowed iteratively
///    around the target index.
/// 2. Insertion sort: windows of at most 8 elements are simply sorted.
/// 3. Heap fallback: to guard against worst-case O(n^2) behavior on
///    pathological inputs, the number of partitioning rounds is limited to
///    `2 * log2(b - a)`. Once the limit is exhausted the remaining window is
///    handed to `heapSelectContext`, which is O(n log n) in the worst case.
///
/// Asserts that `a < b` and `n < b - a`.
///
/// Parameters:
/// - n: The 0-based index of the element to find, relative to the start of the sub-slice.
/// - a: The starting index of the sub-slice to search within.
/// - b: The exclusive end index of the sub-slice.
/// - context: An object providing `lessThan(i, j)` and `swap(i, j)` methods.
pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void {
    // very short windows get sorted using insertion sort.
    const max_insertion = 8;
    assert(a < b);
    const len = b - a;
    assert(n < len);

    // Absolute index at which the selected element has to end up.
    // Loop invariant: `left <= target < right`.
    const target = a + n;

    var left: usize = a;
    var right: usize = b;
    // This is what C++ std::nth_element does.
    // Widen before doubling: `log2_int` returns the narrow `Log2Int(usize)` type,
    // in which the multiplication could overflow for very large `len`.
    var depth_limit: usize = @as(usize, math.log2_int(usize, len)) * 2;

    while (right > left) {
        if (right - left <= max_insertion) {
            // The window contains `target`, so sorting it places the target element.
            sort.insertionContext(left, right, context);
            break;
        }
        if (depth_limit == 0) {
            // Too many partitioning rounds: fall back to heap selection on the
            // remaining window. `target - left` cannot underflow thanks to the invariant.
            heapSelectContext(target - left, left, right, context);
            break;
        }
        depth_limit -= 1;

        var pivot: usize = 0;
        chosePivot(left, right, &pivot, context);

        // if the chosen pivot is equal to the predecessor, then it's the smallest element in the
        // window. Partition the window into elements equal to and elements greater than the pivot.
        // This case is usually hit when the slice contains many duplicate elements.
        if (left > a and !context.lessThan(left - 1, pivot)) {
            left = partitionEqual(left, right, pivot, context);
            // The run of elements equal to the pivot now ends right before `left`.
            // If the target index lies inside that run it already holds its final
            // value; continuing would search a window that no longer contains
            // the target (and would underflow in the heap fallback above).
            if (target < left) break;
            continue;
        }

        partition(left, right, &pivot, context);
        if (pivot == target) break;
        if (pivot > target) {
            // target is left of the pivot: drop the pivot and everything after it.
            right = pivot;
        } else {
            // target is right of the pivot: drop the pivot and everything before it.
            left = pivot + 1;
        }
    }
}

/// A convenient wrapper for `heapSelectContext`. It creates the appropriate
/// context for a given slice and less-than function. After execution, the
/// nth smallest element of `items` will be at `items[n]`.
///
/// Parameters:
/// - T: The type of the elements in the slice.
/// - items: The slice of items to select from.
/// - n: The 0-based index of the element to find (0 for smallest, 1 for 2nd smallest, etc.).
/// - context: A user-provided context to be passed to `lessThanFn`.
/// - lessThanFn: The comparison function.
pub fn heapSelect(
    comptime T: type,
    items: []T,
    n: usize,
    context: anytype,
    comptime lessThanFn: fn (@TypeOf(context), lhs: T, rhs: T) bool,
) void {
    // A local struct to adapt the user's slice and functions to the
    // index-based interface required by `heapSelectContext`.
    const Context = struct {
        items: []T,
        sub_ctx: @TypeOf(context),

        pub fn lessThan(ctx: @This(), i: usize, j: usize) bool {
            return lessThanFn(ctx.sub_ctx, ctx.items[i], ctx.items[j]);
        }

        pub fn swap(ctx: @This(), i: usize, j: usize) void {
            return mem.swap(T, &ctx.items[i], &ctx.items[j]);
        }
    };

    // Create an instance of the context and call the core selection function.
    heapSelectContext(n, 0, items.len, Context{ .items = items, .sub_ctx = context });
}

/// heapSelectContext finds the nth smallest element within the index range `[a, b)`.
/// The result (the nth smallest element) is placed at index `a + n` of the underlying
/// collection managed by the context. Afterwards the elements in `[a, a + n)` are less
/// than or equal to it and the elements in `(a + n, b)` are greater than or equal to it.
///
/// This function modifies the order of elements in the range.
/// Asserts that `a < b` and `n < b - a`.
///
/// Parameters:
/// - n: The 0-based index of the element to find in the sorted version of the range (0 for smallest, 1 for 2nd smallest, etc.).
/// - a: The starting index of the range.
/// - b: The exclusive end index of the range.
/// - context: An object with `lessThan(i, j)` and `swap(i, j)` methods.
pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void {
    assert(a < b);
    const len = b - a;
    assert(n < len);
    // Number of elements ranked at or above the target, the target itself included.
    const n_largest = len - n;

    // build the (max-)heap in linear time.
    var i = a + len / 2;
    while (i > a) {
        i -= 1;
        siftDown(a, i, b, context);
    }

    // Pop the `n_largest - 1` largest elements to the back of the range,
    // shrinking the heap from the right each time.
    var heap_end = b;
    var popped: usize = 0;
    while (popped < n_largest - 1) : (popped += 1) {
        heap_end -= 1;
        context.swap(a, heap_end);
        siftDown(a, a, heap_end, context);
    }

    // The heap now occupies `[a, a + n]` and its root (at index `a`) is the nth
    // smallest element. Swap it into its final position `a + n`. The range is
    // never empty here because `a < b` is asserted above.
    context.swap(a, a + n);
}

/// Calculates the median of a slice using the nthElement function.
/// For slices with an odd number of elements, it returns the middle element.
/// For slices with an even number of elements, it returns the mean of the two central elements.
/// For integer types the mean is rounded towards zero and is computed in a wider
/// integer type, so it cannot overflow even when both central elements are close
/// to the limits of `T`. For floating-point types it is `(lower + upper) / 2`.
/// Only integer and floating-point element types are supported; any other `T`
/// is a compile error.
/// Asserts that `items` is not empty.
/// This function modifies the order of elements in the slice.
pub fn median(
    comptime T: type,
    items: []T,
    context: anytype,
    comptime lessThanFn: fn (context: @TypeOf(context), lhs: T, rhs: T) bool,
) T {
    const len = items.len;
    assert(len > 0); // Ensure the slice is not empty.
    const mid = len / 2;
    if (len % 2 == 1) {
        nthElement(T, items, mid, context, lessThanFn);
        return items[mid];
    }

    // Even length: place the lower of the two central elements at `mid - 1`.
    nthElement(T, items, mid - 1, context, lessThanFn);
    const lower_median = items[mid - 1];

    // Every element after `mid - 1` is now greater than or equal to the lower
    // median, so the upper median is the minimum of that tail.
    var upper_median = items[mid];
    for (items[mid + 1 ..]) |item| {
        if (lessThanFn(context, item, upper_median)) {
            upper_median = item;
        }
    }

    return switch (@typeInfo(T)) {
        .int => |info| blk: {
            // One extra bit is enough to hold the sum of any two values of `T`.
            const Wide = std.meta.Int(info.signedness, info.bits + 1);
            const sum = @as(Wide, lower_median) + @as(Wide, upper_median);
            // The mean of two values of `T` always fits back into `T`.
            const mean: T = @intCast(@divTrunc(sum, 2));
            break :blk mean;
        },
        .float => (lower_median + upper_median) / 2,
        else => @compileError("Unsupported type for median: " ++ @typeName(T)),
    };
}

/// partitions `items[a..b]` into elements smaller than `items[pivot]`,
/// followed by elements greater than or equal to `items[pivot]`.
///
/// sets the new pivot.
fn partition(a: usize, b: usize, pivot: *usize, context: anytype) void {
    // move pivot to the first place
    context.swap(a, pivot.*);
    var i = a + 1;
    var j = b - 1;
    while (true) {
        // advance `i` over elements that already belong to the left part
        while (i <= j and context.lessThan(i, a)) i += 1;
        // retreat `j` over elements that already belong to the right part
        while (i <= j and !context.lessThan(j, a)) j -= 1;
        if (i > j) break;
        context.swap(i, j);
        i += 1;
        j -= 1;
    }
    // `j` is the last index of the left part: put the pivot there.
    context.swap(j, a);
    pivot.* = j;
}

/// partitions items into elements equal to `items[pivot]`
/// followed by elements greater than `items[pivot]`.
+/// +/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. +fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { + // move pivot to the first place + context.swap(a, pivot); + + var i = a + 1; + var j = b - 1; + + while (true) { + while (i <= j and !context.lessThan(a, i)) i += 1; + while (i <= j and context.lessThan(a, j)) j -= 1; + if (i > j) break; + + context.swap(i, j); + i += 1; + j -= 1; + } + + return i; +} + +/// chooses a pivot in `items[a..b]`. +/// It's modeled directly after the `chosePivot` function in `std.sort`. +fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) void { + // minimum length for using the Tukey's ninther method + const shortest_ninther = 50; + const len = b - a; + const i = a + len / 4 * 1; + const j = a + len / 4 * 2; + const k = a + len / 4 * 3; + + if (len >= 8) { + if (len >= shortest_ninther) { + // find medians in the neighborhoods of `i`, `j` and `k` + sort3(i - 1, i, i + 1, context); + sort3(j - 1, j, j + 1, context); + sort3(k - 1, k, k + 1, context); + } + + // find the median among `i`, `j` and `k` and stores it in `j` + sort3(i, j, k, context); + } + + pivot.* = j; +} + +fn sort3(a: usize, b: usize, c: usize, context: anytype) void { + if (context.lessThan(b, a)) { + context.swap(b, a); + } + + if (context.lessThan(c, b)) { + context.swap(c, b); + } + + if (context.lessThan(b, a)) { + context.swap(b, a); + } +} + +fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { + var cur = target; + while (true) { + // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 + // The `+ a + 1` is safe because: + // for `a > 0` then `2a >= a + 1`. + // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. 
+ var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; + + // stop if we overshot the boundary + if (!(child < b)) break; + + // `next_child` is at most `b`, therefore no overflow is possible + const next_child = child + 1; + + // store the greater child in `child` + if (next_child < b and context.lessThan(child, next_child)) { + child = next_child; + } + + // stop if the Heap invariant holds at `cur`. + if (context.lessThan(child, cur)) break; + + // swap `cur` with the greater child, + // move one step down, and continue sifting. + context.swap(child, cur); + cur = child; + } +} + +// Tests + +const select_funcs = &[_]fn (comptime type, anytype, anytype, anytype, comptime anytype) void{ + nthElement, + heapSelect, +}; + +const context_select_funcs = &[_]fn (usize, usize, usize, anytype) void{ + nthElementContext, + heapSelectContext, +}; + +test "select" { + const asc_u8 = sort.asc(u8); + const asc_i32 = sort.asc(i32); + + const u8cases = [_][]const []const u8{ + &[_][]const u8{ + "a", + "a", + }, + &[_][]const u8{ + "az", + "az", + }, + &[_][]const u8{ + "za", + "az", + }, + &[_][]const u8{ + "asdf", + "adfs", + }, + &[_][]const u8{ + "one", + "eno", + }, + }; + + const i32cases = [_][]const []const i32{ + &[_][]const i32{ + &[_]i32{1}, + &[_]i32{1}, + }, + &[_][]const i32{ + &[_]i32{ 0, 1 }, + &[_]i32{ 0, 1 }, + }, + &[_][]const i32{ + &[_]i32{ 1, 0 }, + &[_]i32{ 0, 1 }, + }, + &[_][]const i32{ + &[_]i32{ 1, -1, 0 }, + &[_]i32{ -1, 0, 1 }, + }, + &[_][]const i32{ + &[_]i32{ 2, 1, 3 }, + &[_]i32{ 1, 2, 3 }, + }, + &[_][]const i32{ + &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 }, + &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 21, 22, 32, 39, 43, 55, 58, 59, 88 }, + }, + }; + + inline for (select_funcs) |selectFn| { + for (u8cases) |case| { + var buf: [20]u8 = undefined; + const slice = buf[0..case[0].len]; + const mid = slice.len / 2; + @memcpy(slice, case[0]); + selectFn(u8, slice, mid, {}, asc_u8); + try 
testing.expectEqual(slice[mid], case[1][mid]); + } + + for (i32cases) |case| { + var buf: [20]i32 = undefined; + const slice = buf[0..case[0].len]; + const mid = slice.len / 2; + @memcpy(slice, case[0]); + selectFn(i32, slice, mid, {}, asc_i32); + try testing.expectEqual(slice[mid], case[1][mid]); + } + } +} + +test "select descending" { + const desc_i32 = sort.desc(i32); + + const rev_cases = [_][]const []const i32{ + &[_][]const i32{ + &[_]i32{1}, + &[_]i32{1}, + }, + &[_][]const i32{ + &[_]i32{ 0, 1 }, + &[_]i32{ 1, 0 }, + }, + &[_][]const i32{ + &[_]i32{ 1, 0 }, + &[_]i32{ 1, 0 }, + }, + &[_][]const i32{ + &[_]i32{ 1, -1, 0 }, + &[_]i32{ 1, 0, -1 }, + }, + &[_][]const i32{ + &[_]i32{ 2, 1, 3 }, + &[_]i32{ 3, 2, 1 }, + }, + }; + + inline for (select_funcs) |selectFn| { + for (rev_cases) |case| { + var buf: [8]i32 = undefined; + const slice = buf[0..case[0].len]; + const mid = slice.len / 2; + @memcpy(slice, case[0]); + selectFn(i32, slice, mid, {}, desc_i32); + try testing.expectEqual(slice[mid], case[1][mid]); + } + } +} + +test "median odd length" { + var items = [_]i32{ 1, 3, 2, 5, 4 }; // sorted: 1, 2, 3, 4, 5 -> median 3 + const m = median(i32, &items, {}, sort.asc(i32)); + try testing.expectEqual(3, m); +} + +test "median even length" { + var items = [_]u32{ 1, 3, 2, 5, 4, 6 }; // sorted: 1, 2, 3, 4, 5, 6 -> median (3+4)/2 = 3.5 + const m = median(u32, &items, {}, sort.asc(u32)); + try testing.expectEqual(3, m); +} + +test "median even length negative" { + var items = [_]i32{ -1, -3, -2, -5, -4, -6 }; // sorted: 1, 2, 3, 4, 5, 6 -> median (3+4)/2 = 3.5 + const m = median(i32, &items, {}, sort.asc(i32)); + try testing.expectEqual(-3, m); +} + +test "median odd length float" { + var items = [_]f64{ 1.1, 3.3, 2.2, 5.5, 4.4 }; // sorted: 1.1, 2.2, 3.3, 4.4, 5.5 -> median 3.3 + const m = median(f64, &items, {}, sort.asc(f64)); + try testing.expectEqual(3.3, m); +} + +test "median even length float" { + var items = [_]f32{ 1.1, 3.3, 2.2, 5.5, 4.4, 6.6 }; // 
sorted: 1.1, 2.2, 3.3, 4.4, 5.5, 6.6 -> median (3.3+4.4)/2 = 3.85 + const m = median(f32, &items, {}, sort.asc(f32)); + try testing.expectApproxEqRel(3.85, m, 0.00001); +} diff --git a/lib/std/std.zig b/lib/std/std.zig index 564b04c609f8..ee0c459815e1 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -84,6 +84,7 @@ pub const pdb = @import("pdb.zig"); pub const pie = @import("pie.zig"); pub const posix = @import("posix.zig"); pub const process = @import("process.zig"); +pub const select = @import("select.zig"); pub const sort = @import("sort.zig"); pub const simd = @import("simd.zig"); pub const ascii = @import("ascii.zig"); From acc13e14e604c3f5460032c7a619caa7c7ea59be Mon Sep 17 00:00:00 2001 From: Oittaa Date: Fri, 11 Jul 2025 23:49:55 +0200 Subject: [PATCH 2/8] Add fuzz testing, fix comments --- lib/std/select.zig | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/lib/std/select.zig b/lib/std/select.zig index f850e2ac0808..538fc95b3d69 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -453,6 +453,30 @@ test "select descending" { } } +test "select fuzz testing" { + const asc_i32 = sort.asc(i32); + var prng = std.Random.DefaultPrng.init(std.testing.random_seed); + const random = prng.random(); + const test_case_count = 10; + + inline for (select_funcs) |selectFn| { + for (0..test_case_count) |_| { + const array_size = random.intRangeLessThan(usize, 1, 1000); + const array = try testing.allocator.alloc(i32, array_size); + defer testing.allocator.free(array); + // populate with random data + for (array) |*item| { + item.* = random.intRangeLessThan(i32, 0, 100); + } + const n = random.intRangeLessThan(usize, 0, array_size); + selectFn(i32, array, n, {}, asc_i32); + const n_val = array[n]; + sort.pdq(i32, array, {}, asc_i32); + try testing.expectEqual(n_val, array[n]); + } + } +} + test "median odd length" { var items = [_]i32{ 1, 3, 2, 5, 4 }; // sorted: 1, 2, 3, 4, 5 -> median 3 const m = median(i32, &items, 
{}, sort.asc(i32)); @@ -466,7 +490,7 @@ test "median even length" { } test "median even length negative" { - var items = [_]i32{ -1, -3, -2, -5, -4, -6 }; // sorted: 1, 2, 3, 4, 5, 6 -> median (3+4)/2 = 3.5 + var items = [_]i32{ -1, -3, -2, -5, -4, -6 }; // sorted: -6, -5, -4, -3, -2, -1 -> median (-4 + -3)/2 = -3.5 const m = median(i32, &items, {}, sort.asc(i32)); try testing.expectEqual(-3, m); } From eff69e9f0916439e152677f4fe69f1b732ad6e88 Mon Sep 17 00:00:00 2001 From: Oittaa Date: Sat, 12 Jul 2025 15:26:09 +0200 Subject: [PATCH 3/8] Move shared functions to sort/utils.zig --- lib/std/select.zig | 166 +++++++++++------------------------------ lib/std/sort.zig | 35 +-------- lib/std/sort/pdq.zig | 131 +------------------------------- lib/std/sort/utils.zig | 153 +++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 284 deletions(-) create mode 100644 lib/std/sort/utils.zig diff --git a/lib/std/select.zig b/lib/std/select.zig index 538fc95b3d69..b066d6065b35 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -86,15 +86,15 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { } depth_limit -= 1; var pivot: usize = 0; - chosePivot(left, right, &pivot, context); + _ = sort.utils.chosePivot(left, right, &pivot, context); // if the chosen pivot is equal to the predecessor, then it's the smallest element in the // slice. Partition the slice into elements equal to and elements greater than the pivot. // This case is usually hit when the slice contains many duplicate elements. 
if (left > a and !context.lessThan(left - 1, pivot)) { - left = partitionEqual(left, right, pivot, context); + left = sort.utils.partitionEqual(left, right, pivot, context); continue; } - partition(left, right, &pivot, context); + sort.utils.partition(left, right, &pivot, context); const target = a + n; if (pivot == target) { break; @@ -162,7 +162,7 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { var i = a + (b - a) / 2; while (i > a) { i -= 1; - siftDown(a, i, b, context); + sort.utils.siftDown(a, i, b, context); } var heap_end = b; @@ -170,7 +170,7 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { while (i < n_largest - 1) : (i += 1) { heap_end -= 1; context.swap(a, heap_end); - siftDown(a, a, heap_end, context); + sort.utils.siftDown(a, a, heap_end, context); } // After the loop, the root of the heap (at index `a`) is the nth smallest element. @@ -214,120 +214,6 @@ pub fn median( }; } -/// partitions `items[a..b]` into elements smaller than `items[pivot]`, -/// followed by elements greater than or equal to `items[pivot]`. -/// -/// sets the new pivot. -fn partition(a: usize, b: usize, pivot: *usize, context: anytype) void { - // move pivot to the first place - context.swap(a, pivot.*); - var i = a + 1; - var j = b - 1; - while (true) { - while (i <= j and context.lessThan(i, a)) i += 1; - while (i <= j and !context.lessThan(j, a)) j -= 1; - if (i > j) break; - context.swap(i, j); - i += 1; - j -= 1; - } - context.swap(j, a); - pivot.* = j; -} - -/// partitions items into elements equal to `items[pivot]` -/// followed by elements greater than `items[pivot]`. -/// -/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. 
-fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { - // move pivot to the first place - context.swap(a, pivot); - - var i = a + 1; - var j = b - 1; - - while (true) { - while (i <= j and !context.lessThan(a, i)) i += 1; - while (i <= j and context.lessThan(a, j)) j -= 1; - if (i > j) break; - - context.swap(i, j); - i += 1; - j -= 1; - } - - return i; -} - -/// chooses a pivot in `items[a..b]`. -/// It's modeled directly after the `chosePivot` function in `std.sort`. -fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) void { - // minimum length for using the Tukey's ninther method - const shortest_ninther = 50; - const len = b - a; - const i = a + len / 4 * 1; - const j = a + len / 4 * 2; - const k = a + len / 4 * 3; - - if (len >= 8) { - if (len >= shortest_ninther) { - // find medians in the neighborhoods of `i`, `j` and `k` - sort3(i - 1, i, i + 1, context); - sort3(j - 1, j, j + 1, context); - sort3(k - 1, k, k + 1, context); - } - - // find the median among `i`, `j` and `k` and stores it in `j` - sort3(i, j, k, context); - } - - pivot.* = j; -} - -fn sort3(a: usize, b: usize, c: usize, context: anytype) void { - if (context.lessThan(b, a)) { - context.swap(b, a); - } - - if (context.lessThan(c, b)) { - context.swap(c, b); - } - - if (context.lessThan(b, a)) { - context.swap(b, a); - } -} - -fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { - var cur = target; - while (true) { - // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 - // The `+ a + 1` is safe because: - // for `a > 0` then `2a >= a + 1`. - // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. 
- var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; - - // stop if we overshot the boundary - if (!(child < b)) break; - - // `next_child` is at most `b`, therefore no overflow is possible - const next_child = child + 1; - - // store the greater child in `child` - if (next_child < b and context.lessThan(child, next_child)) { - child = next_child; - } - - // stop if the Heap invariant holds at `cur`. - if (context.lessThan(child, cur)) break; - - // swap `cur` with the greater child, - // move one step down, and continue sifting. - context.swap(child, cur); - cur = child; - } -} - // Tests const select_funcs = &[_]fn (comptime type, anytype, anytype, anytype, comptime anytype) void{ @@ -401,7 +287,7 @@ test "select" { const mid = slice.len / 2; @memcpy(slice, case[0]); selectFn(u8, slice, mid, {}, asc_u8); - try testing.expectEqual(slice[mid], case[1][mid]); + try testing.expectEqual(case[1][mid], slice[mid]); } for (i32cases) |case| { @@ -410,7 +296,7 @@ test "select" { const mid = slice.len / 2; @memcpy(slice, case[0]); selectFn(i32, slice, mid, {}, asc_i32); - try testing.expectEqual(slice[mid], case[1][mid]); + try testing.expectEqual(case[1][mid], slice[mid]); } } } @@ -448,14 +334,46 @@ test "select descending" { const mid = slice.len / 2; @memcpy(slice, case[0]); selectFn(i32, slice, mid, {}, desc_i32); - try testing.expectEqual(slice[mid], case[1][mid]); + try testing.expectEqual(case[1][mid], slice[mid]); + } + } +} + +test "select with context in the middle of a slice" { + const Context = struct { + items: []i32, + + pub fn lessThan(ctx: @This(), a: usize, b: usize) bool { + return ctx.items[a] < ctx.items[b]; + } + + pub fn swap(ctx: @This(), a: usize, b: usize) void { + return mem.swap(i32, &ctx.items[a], &ctx.items[b]); + } + }; + + const i32case = &[_]i32{ 0, 1, 8, 3, 6, 5, 4, 2, 9, 7, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 }; + + const ranges = [_]struct { start: usize, end: usize, n: usize, expected: i32 }{ + .{ .start = 10, .end = 20, 
.n = 1, .expected = 21 }, + .{ .start = 1, .end = 11, .n = 2, .expected = 3 }, + .{ .start = 3, .end = 7, .n = 3, .expected = 6 }, + }; + + inline for (context_select_funcs) |selectFn| { + for (ranges) |range| { + var buf: [20]i32 = undefined; + const slice = buf[0..i32case.len]; + @memcpy(slice, i32case); + selectFn(range.n, range.start, range.end, Context{ .items = slice }); + try testing.expectEqual(range.expected, slice[range.start + range.n]); } } } test "select fuzz testing" { const asc_i32 = sort.asc(i32); - var prng = std.Random.DefaultPrng.init(std.testing.random_seed); + var prng = std.Random.DefaultPrng.init(testing.random_seed); const random = prng.random(); const test_case_count = 10; @@ -472,7 +390,7 @@ test "select fuzz testing" { selectFn(i32, array, n, {}, asc_i32); const n_val = array[n]; sort.pdq(i32, array, {}, asc_i32); - try testing.expectEqual(n_val, array[n]); + try testing.expectEqual(array[n], n_val); } } } diff --git a/lib/std/sort.zig b/lib/std/sort.zig index 8705d2401730..747dbfff67e0 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -9,6 +9,7 @@ pub const Mode = enum { stable, unstable }; pub const block = @import("sort/block.zig").block; pub const pdq = @import("sort/pdq.zig").pdq; pub const pdqContext = @import("sort/pdq.zig").pdqContext; +pub const utils = @import("sort/utils.zig"); /// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case. /// O(1) memory (no allocator required). @@ -86,7 +87,7 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void { var i = a + (b - a) / 2; while (i > a) { i -= 1; - siftDown(a, i, b, context); + utils.siftDown(a, i, b, context); } // pop maximal elements from the heap. 
@@ -94,37 +95,7 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void { while (i > a) { i -= 1; context.swap(a, i); - siftDown(a, a, i, context); - } -} - -fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { - var cur = target; - while (true) { - // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 - // The `+ a + 1` is safe because: - // for `a > 0` then `2a >= a + 1`. - // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. - var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; - - // stop if we overshot the boundary - if (!(child < b)) break; - - // `next_child` is at most `b`, therefore no overflow is possible - const next_child = child + 1; - - // store the greater child in `child` - if (next_child < b and context.lessThan(child, next_child)) { - child = next_child; - } - - // stop if the Heap invariant holds at `cur`. - if (context.lessThan(child, cur)) break; - - // swap `cur` with the greater child, - // move one step down, and continue sifting. - context.swap(child, cur); - cur = child; + utils.siftDown(a, a, i, context); } } diff --git a/lib/std/sort/pdq.zig b/lib/std/sort/pdq.zig index 55bd17ae93e5..c1d115bddf48 100644 --- a/lib/std/sort/pdq.zig +++ b/lib/std/sort/pdq.zig @@ -29,12 +29,6 @@ pub fn pdq( pdqContext(0, items.len, Context{ .items = items, .sub_ctx = context }); } -const Hint = enum { - increasing, - decreasing, - unknown, -}; - /// Unstable in-place sort. O(n) best case, O(n*log(n)) worst case and average case. /// O(log(n)) memory (no allocator required). /// `context` must have methods `swap` and `lessThan`, @@ -80,7 +74,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // choose a pivot and try guessing whether the slice is already sorted. 
var pivot: usize = 0; - var hint = chosePivot(range.a, range.b, &pivot, context); + var hint = sort.utils.chosePivot(range.a, range.b, &pivot, context); if (hint == .decreasing) { // The maximum number of swaps was performed, so items are likely @@ -102,13 +96,13 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // slice. Partition the slice into elements equal to and elements greater than the pivot. // This case is usually hit when the slice contains many duplicate elements. if (range.a > a and !context.lessThan(range.a - 1, pivot)) { - range.a = partitionEqual(range.a, range.b, pivot, context); + range.a = sort.utils.partitionEqual(range.a, range.b, pivot, context); continue; } // partition the slice. var mid = pivot; - was_partitioned = partition(range.a, range.b, &mid, context); + was_partitioned = sort.utils.partition(range.a, range.b, &mid, context); const left_len = mid - range.a; const right_len = range.b - mid; @@ -131,74 +125,6 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { } } -/// partitions `items[a..b]` into elements smaller than `items[pivot]`, -/// followed by elements greater than or equal to `items[pivot]`. -/// -/// sets the new pivot. -/// returns `true` if already partitioned. 
-fn partition(a: usize, b: usize, pivot: *usize, context: anytype) bool { - // move pivot to the first place - context.swap(a, pivot.*); - - var i = a + 1; - var j = b - 1; - - while (i <= j and context.lessThan(i, a)) i += 1; - while (i <= j and !context.lessThan(j, a)) j -= 1; - - // check if items are already partitioned (no item to swap) - if (i > j) { - // put pivot back to the middle - context.swap(j, a); - pivot.* = j; - return true; - } - - context.swap(i, j); - i += 1; - j -= 1; - - while (true) { - while (i <= j and context.lessThan(i, a)) i += 1; - while (i <= j and !context.lessThan(j, a)) j -= 1; - if (i > j) break; - - context.swap(i, j); - i += 1; - j -= 1; - } - - // TODO: Enable the BlockQuicksort optimization - - context.swap(j, a); - pivot.* = j; - return false; -} - -/// partitions items into elements equal to `items[pivot]` -/// followed by elements greater than `items[pivot]`. -/// -/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. -fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { - // move pivot to the first place - context.swap(a, pivot); - - var i = a + 1; - var j = b - 1; - - while (true) { - while (i <= j and !context.lessThan(a, i)) i += 1; - while (i <= j and context.lessThan(a, j)) j -= 1; - if (i > j) break; - - context.swap(i, j); - i += 1; - j -= 1; - } - - return i; -} - /// partially sorts a slice by shifting several out-of-order elements around. /// /// returns `true` if the slice is sorted at the end. This function is `O(n)` worst-case. @@ -268,57 +194,6 @@ fn breakPatterns(a: usize, b: usize, context: anytype) void { } } -/// chooses a pivot in `items[a..b]`. -/// swaps likely_sorted when `items[a..b]` seems to be already sorted. 
-fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { - // minimum length for using the Tukey's ninther method - const shortest_ninther = 50; - // max_swaps is the maximum number of swaps allowed in this function - const max_swaps = 4 * 3; - - const len = b - a; - const i = a + len / 4 * 1; - const j = a + len / 4 * 2; - const k = a + len / 4 * 3; - var swaps: usize = 0; - - if (len >= 8) { - if (len >= shortest_ninther) { - // find medians in the neighborhoods of `i`, `j` and `k` - sort3(i - 1, i, i + 1, &swaps, context); - sort3(j - 1, j, j + 1, &swaps, context); - sort3(k - 1, k, k + 1, &swaps, context); - } - - // find the median among `i`, `j` and `k` and stores it in `j` - sort3(i, j, k, &swaps, context); - } - - pivot.* = j; - return switch (swaps) { - 0 => .increasing, - max_swaps => .decreasing, - else => .unknown, - }; -} - -fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void { - if (context.lessThan(b, a)) { - swaps.* += 1; - context.swap(b, a); - } - - if (context.lessThan(c, b)) { - swaps.* += 1; - context.swap(c, b); - } - - if (context.lessThan(b, a)) { - swaps.* += 1; - context.swap(b, a); - } -} - fn reverseRange(a: usize, b: usize, context: anytype) void { var i = a; var j = b - 1; diff --git a/lib/std/sort/utils.zig b/lib/std/sort/utils.zig new file mode 100644 index 000000000000..5d40d86eadf6 --- /dev/null +++ b/lib/std/sort/utils.zig @@ -0,0 +1,153 @@ +const std = @import("../std.zig"); +const math = std.math; + +pub const Hint = enum { increasing, decreasing, unknown }; + +pub fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { + var cur = target; + while (true) { + // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 + // The `+ a + 1` is safe because: + // for `a > 0` then `2a >= a + 1`. + // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. 
+ var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; + + // stop if we overshot the boundary + if (!(child < b)) break; + + // `next_child` is at most `b`, therefore no overflow is possible + const next_child = child + 1; + + // store the greater child in `child` + if (next_child < b and context.lessThan(child, next_child)) { + child = next_child; + } + + // stop if the Heap invariant holds at `cur`. + if (context.lessThan(child, cur)) break; + + // swap `cur` with the greater child, + // move one step down, and continue sifting. + context.swap(child, cur); + cur = child; + } +} + +/// partitions `items[a..b]` into elements smaller than `items[pivot]`, +/// followed by elements greater than or equal to `items[pivot]`. +/// +/// sets the new pivot. +/// returns `true` if already partitioned. +pub fn partition(a: usize, b: usize, pivot: *usize, context: anytype) bool { + // move pivot to the first place + context.swap(a, pivot.*); + + var i = a + 1; + var j = b - 1; + + while (i <= j and context.lessThan(i, a)) i += 1; + while (i <= j and !context.lessThan(j, a)) j -= 1; + + // check if items are already partitioned (no item to swap) + if (i > j) { + // put pivot back to the middle + context.swap(j, a); + pivot.* = j; + return true; + } + + context.swap(i, j); + i += 1; + j -= 1; + + while (true) { + while (i <= j and context.lessThan(i, a)) i += 1; + while (i <= j and !context.lessThan(j, a)) j -= 1; + if (i > j) break; + + context.swap(i, j); + i += 1; + j -= 1; + } + + // TODO: Enable the BlockQuicksort optimization + + context.swap(j, a); + pivot.* = j; + return false; +} + +/// partitions items into elements equal to `items[pivot]` +/// followed by elements greater than `items[pivot]`. +/// +/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. 
+pub fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { + // move pivot to the first place + context.swap(a, pivot); + + var i = a + 1; + var j = b - 1; + + while (true) { + while (i <= j and !context.lessThan(a, i)) i += 1; + while (i <= j and context.lessThan(a, j)) j -= 1; + if (i > j) break; + + context.swap(i, j); + i += 1; + j -= 1; + } + + return i; +} + +/// chooses a pivot in `items[a..b]`. +/// swaps likely_sorted when `items[a..b]` seems to be already sorted. +pub fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { + // minimum length for using the Tukey's ninther method + const shortest_ninther = 50; + // max_swaps is the maximum number of swaps allowed in this function + const max_swaps = 4 * 3; + + const len = b - a; + const i = a + len / 4 * 1; + const j = a + len / 4 * 2; + const k = a + len / 4 * 3; + var swaps: usize = 0; + + if (len >= 8) { + if (len >= shortest_ninther) { + // find medians in the neighborhoods of `i`, `j` and `k` + sort3(i - 1, i, i + 1, &swaps, context); + sort3(j - 1, j, j + 1, &swaps, context); + sort3(k - 1, k, k + 1, &swaps, context); + } + + // find the median among `i`, `j` and `k` and stores it in `j` + sort3(i, j, k, &swaps, context); + } + + pivot.* = j; + return switch (swaps) { + 0 => .increasing, + max_swaps => .decreasing, + else => .unknown, + }; +} + +fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void { + if (context.lessThan(b, a)) { + swaps.* += 1; + context.swap(b, a); + } + + if (context.lessThan(c, b)) { + swaps.* += 1; + context.swap(c, b); + } + + if (context.lessThan(b, a)) { + swaps.* += 1; + context.swap(b, a); + } +} From 0b30c44d90c8c53b09331ec277822e2ec81124b4 Mon Sep 17 00:00:00 2001 From: Oittaa Date: Sat, 12 Jul 2025 15:46:04 +0200 Subject: [PATCH 4/8] Discard return value from partition() --- lib/std/select.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/select.zig b/lib/std/select.zig index 
b066d6065b35..4bd3e3dee1bc 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -94,7 +94,7 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { left = sort.utils.partitionEqual(left, right, pivot, context); continue; } - sort.utils.partition(left, right, &pivot, context); + _ = sort.utils.partition(left, right, &pivot, context); const target = a + n; if (pivot == target) { break; From 5fba3027e386752065b0ddba807626da73237fca Mon Sep 17 00:00:00 2001 From: Oittaa Date: Sat, 12 Jul 2025 16:53:52 +0200 Subject: [PATCH 5/8] Seems like utils like crypto.utils is to be removed, let's do the same --- lib/std/select.zig | 10 +-- lib/std/sort.zig | 155 ++++++++++++++++++++++++++++++++++++++++- lib/std/sort/pdq.zig | 6 +- lib/std/sort/utils.zig | 153 ---------------------------------------- 4 files changed, 160 insertions(+), 164 deletions(-) delete mode 100644 lib/std/sort/utils.zig diff --git a/lib/std/select.zig b/lib/std/select.zig index 4bd3e3dee1bc..40b8fc9a651f 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -86,15 +86,15 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { } depth_limit -= 1; var pivot: usize = 0; - _ = sort.utils.chosePivot(left, right, &pivot, context); + _ = sort.chosePivot(left, right, &pivot, context); // if the chosen pivot is equal to the predecessor, then it's the smallest element in the // slice. Partition the slice into elements equal to and elements greater than the pivot. // This case is usually hit when the slice contains many duplicate elements. 
if (left > a and !context.lessThan(left - 1, pivot)) { - left = sort.utils.partitionEqual(left, right, pivot, context); + left = sort.partitionEqual(left, right, pivot, context); continue; } - _ = sort.utils.partition(left, right, &pivot, context); + _ = sort.partition(left, right, &pivot, context); const target = a + n; if (pivot == target) { break; @@ -162,7 +162,7 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { var i = a + (b - a) / 2; while (i > a) { i -= 1; - sort.utils.siftDown(a, i, b, context); + sort.siftDown(a, i, b, context); } var heap_end = b; @@ -170,7 +170,7 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { while (i < n_largest - 1) : (i += 1) { heap_end -= 1; context.swap(a, heap_end); - sort.utils.siftDown(a, a, heap_end, context); + sort.siftDown(a, a, heap_end, context); } // After the loop, the root of the heap (at index `a`) is the nth smallest element. diff --git a/lib/std/sort.zig b/lib/std/sort.zig index 747dbfff67e0..49bf021be5de 100644 --- a/lib/std/sort.zig +++ b/lib/std/sort.zig @@ -4,12 +4,12 @@ const testing = std.testing; const mem = std.mem; const math = std.math; +pub const Hint = enum { increasing, decreasing, unknown }; pub const Mode = enum { stable, unstable }; pub const block = @import("sort/block.zig").block; pub const pdq = @import("sort/pdq.zig").pdq; pub const pdqContext = @import("sort/pdq.zig").pdqContext; -pub const utils = @import("sort/utils.zig"); /// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case. /// O(1) memory (no allocator required). @@ -87,7 +87,7 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void { var i = a + (b - a) / 2; while (i > a) { i -= 1; - utils.siftDown(a, i, b, context); + siftDown(a, i, b, context); } // pop maximal elements from the heap. 
@@ -95,7 +95,37 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void { while (i > a) { i -= 1; context.swap(a, i); - utils.siftDown(a, a, i, context); + siftDown(a, a, i, context); + } +} + +pub fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { + var cur = target; + while (true) { + // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 + // The `+ a + 1` is safe because: + // for `a > 0` then `2a >= a + 1`. + // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. + var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; + + // stop if we overshot the boundary + if (!(child < b)) break; + + // `next_child` is at most `b`, therefore no overflow is possible + const next_child = child + 1; + + // store the greater child in `child` + if (next_child < b and context.lessThan(child, next_child)) { + child = next_child; + } + + // stop if the Heap invariant holds at `cur`. + if (context.lessThan(child, cur)) break; + + // swap `cur` with the greater child, + // move one step down, and continue sifting. + context.swap(child, cur); + cur = child; } } @@ -955,3 +985,122 @@ test isSorted { try testing.expect(isSorted(u8, "ffff", {}, asc_u8)); try testing.expect(isSorted(u8, "ffff", {}, desc_u8)); } + +/// partitions `items[a..b]` into elements smaller than `items[pivot]`, +/// followed by elements greater than or equal to `items[pivot]`. +/// +/// sets the new pivot. +/// returns `true` if already partitioned. 
+pub fn partition(a: usize, b: usize, pivot: *usize, context: anytype) bool { + // move pivot to the first place + context.swap(a, pivot.*); + + var i = a + 1; + var j = b - 1; + + while (i <= j and context.lessThan(i, a)) i += 1; + while (i <= j and !context.lessThan(j, a)) j -= 1; + + // check if items are already partitioned (no item to swap) + if (i > j) { + // put pivot back to the middle + context.swap(j, a); + pivot.* = j; + return true; + } + + context.swap(i, j); + i += 1; + j -= 1; + + while (true) { + while (i <= j and context.lessThan(i, a)) i += 1; + while (i <= j and !context.lessThan(j, a)) j -= 1; + if (i > j) break; + + context.swap(i, j); + i += 1; + j -= 1; + } + + // TODO: Enable the BlockQuicksort optimization + + context.swap(j, a); + pivot.* = j; + return false; +} + +/// partitions items into elements equal to `items[pivot]` +/// followed by elements greater than `items[pivot]`. +/// +/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. +pub fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { + // move pivot to the first place + context.swap(a, pivot); + + var i = a + 1; + var j = b - 1; + + while (true) { + while (i <= j and !context.lessThan(a, i)) i += 1; + while (i <= j and context.lessThan(a, j)) j -= 1; + if (i > j) break; + + context.swap(i, j); + i += 1; + j -= 1; + } + + return i; +} + +/// chooses a pivot in `items[a..b]`. +/// swaps likely_sorted when `items[a..b]` seems to be already sorted. 
+pub fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { + // minimum length for using the Tukey's ninther method + const shortest_ninther = 50; + // max_swaps is the maximum number of swaps allowed in this function + const max_swaps = 4 * 3; + + const len = b - a; + const i = a + len / 4 * 1; + const j = a + len / 4 * 2; + const k = a + len / 4 * 3; + var swaps: usize = 0; + + if (len >= 8) { + if (len >= shortest_ninther) { + // find medians in the neighborhoods of `i`, `j` and `k` + sort3(i - 1, i, i + 1, &swaps, context); + sort3(j - 1, j, j + 1, &swaps, context); + sort3(k - 1, k, k + 1, &swaps, context); + } + + // find the median among `i`, `j` and `k` and stores it in `j` + sort3(i, j, k, &swaps, context); + } + + pivot.* = j; + return switch (swaps) { + 0 => .increasing, + max_swaps => .decreasing, + else => .unknown, + }; +} + +pub fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void { + if (context.lessThan(b, a)) { + swaps.* += 1; + context.swap(b, a); + } + + if (context.lessThan(c, b)) { + swaps.* += 1; + context.swap(c, b); + } + + if (context.lessThan(b, a)) { + swaps.* += 1; + context.swap(b, a); + } +} diff --git a/lib/std/sort/pdq.zig b/lib/std/sort/pdq.zig index c1d115bddf48..59d759d1c2c6 100644 --- a/lib/std/sort/pdq.zig +++ b/lib/std/sort/pdq.zig @@ -74,7 +74,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // choose a pivot and try guessing whether the slice is already sorted. var pivot: usize = 0; - var hint = sort.utils.chosePivot(range.a, range.b, &pivot, context); + var hint = sort.chosePivot(range.a, range.b, &pivot, context); if (hint == .decreasing) { // The maximum number of swaps was performed, so items are likely @@ -96,13 +96,13 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void { // slice. Partition the slice into elements equal to and elements greater than the pivot. // This case is usually hit when the slice contains many duplicate elements. 
if (range.a > a and !context.lessThan(range.a - 1, pivot)) { - range.a = sort.utils.partitionEqual(range.a, range.b, pivot, context); + range.a = sort.partitionEqual(range.a, range.b, pivot, context); continue; } // partition the slice. var mid = pivot; - was_partitioned = sort.utils.partition(range.a, range.b, &mid, context); + was_partitioned = sort.partition(range.a, range.b, &mid, context); const left_len = mid - range.a; const right_len = range.b - mid; diff --git a/lib/std/sort/utils.zig b/lib/std/sort/utils.zig deleted file mode 100644 index 5d40d86eadf6..000000000000 --- a/lib/std/sort/utils.zig +++ /dev/null @@ -1,153 +0,0 @@ -const std = @import("../std.zig"); -const math = std.math; - -pub const Hint = enum { increasing, decreasing, unknown }; - -pub fn siftDown(a: usize, target: usize, b: usize, context: anytype) void { - var cur = target; - while (true) { - // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1 - // The `+ a + 1` is safe because: - // for `a > 0` then `2a >= a + 1`. - // for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe. - var child = (math.mul(usize, cur - a, 2) catch break) + a + 1; - - // stop if we overshot the boundary - if (!(child < b)) break; - - // `next_child` is at most `b`, therefore no overflow is possible - const next_child = child + 1; - - // store the greater child in `child` - if (next_child < b and context.lessThan(child, next_child)) { - child = next_child; - } - - // stop if the Heap invariant holds at `cur`. - if (context.lessThan(child, cur)) break; - - // swap `cur` with the greater child, - // move one step down, and continue sifting. - context.swap(child, cur); - cur = child; - } -} - -/// partitions `items[a..b]` into elements smaller than `items[pivot]`, -/// followed by elements greater than or equal to `items[pivot]`. -/// -/// sets the new pivot. -/// returns `true` if already partitioned. 
-pub fn partition(a: usize, b: usize, pivot: *usize, context: anytype) bool { - // move pivot to the first place - context.swap(a, pivot.*); - - var i = a + 1; - var j = b - 1; - - while (i <= j and context.lessThan(i, a)) i += 1; - while (i <= j and !context.lessThan(j, a)) j -= 1; - - // check if items are already partitioned (no item to swap) - if (i > j) { - // put pivot back to the middle - context.swap(j, a); - pivot.* = j; - return true; - } - - context.swap(i, j); - i += 1; - j -= 1; - - while (true) { - while (i <= j and context.lessThan(i, a)) i += 1; - while (i <= j and !context.lessThan(j, a)) j -= 1; - if (i > j) break; - - context.swap(i, j); - i += 1; - j -= 1; - } - - // TODO: Enable the BlockQuicksort optimization - - context.swap(j, a); - pivot.* = j; - return false; -} - -/// partitions items into elements equal to `items[pivot]` -/// followed by elements greater than `items[pivot]`. -/// -/// it assumed that `items[a..b]` does not contain elements smaller than the `items[pivot]`. -pub fn partitionEqual(a: usize, b: usize, pivot: usize, context: anytype) usize { - // move pivot to the first place - context.swap(a, pivot); - - var i = a + 1; - var j = b - 1; - - while (true) { - while (i <= j and !context.lessThan(a, i)) i += 1; - while (i <= j and context.lessThan(a, j)) j -= 1; - if (i > j) break; - - context.swap(i, j); - i += 1; - j -= 1; - } - - return i; -} - -/// chooses a pivot in `items[a..b]`. -/// swaps likely_sorted when `items[a..b]` seems to be already sorted. 
-pub fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint { - // minimum length for using the Tukey's ninther method - const shortest_ninther = 50; - // max_swaps is the maximum number of swaps allowed in this function - const max_swaps = 4 * 3; - - const len = b - a; - const i = a + len / 4 * 1; - const j = a + len / 4 * 2; - const k = a + len / 4 * 3; - var swaps: usize = 0; - - if (len >= 8) { - if (len >= shortest_ninther) { - // find medians in the neighborhoods of `i`, `j` and `k` - sort3(i - 1, i, i + 1, &swaps, context); - sort3(j - 1, j, j + 1, &swaps, context); - sort3(k - 1, k, k + 1, &swaps, context); - } - - // find the median among `i`, `j` and `k` and stores it in `j` - sort3(i, j, k, &swaps, context); - } - - pivot.* = j; - return switch (swaps) { - 0 => .increasing, - max_swaps => .decreasing, - else => .unknown, - }; -} - -fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void { - if (context.lessThan(b, a)) { - swaps.* += 1; - context.swap(b, a); - } - - if (context.lessThan(c, b)) { - swaps.* += 1; - context.swap(c, b); - } - - if (context.lessThan(b, a)) { - swaps.* += 1; - context.swap(b, a); - } -} From 6423f97192ec957f2fc8bb5ccea112a9bb53c3d5 Mon Sep 17 00:00:00 2001 From: Oittaa Date: Sun, 13 Jul 2025 19:00:28 +0200 Subject: [PATCH 6/8] Prevent median overflows, cleanup --- lib/std/select.zig | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/lib/std/select.zig b/lib/std/select.zig index 40b8fc9a651f..c1cc94e2cfaa 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -159,7 +159,7 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { assert(n < len); const n_largest = len - n; // build the heap in linear time. 
- var i = a + (b - a) / 2; + var i = a + len / 2; while (i > a) { i -= 1; sort.siftDown(a, i, b, context); @@ -173,17 +173,15 @@ pub fn heapSelectContext(n: usize, a: usize, b: usize, context: anytype) void { sort.siftDown(a, a, heap_end, context); } - // After the loop, the root of the heap (at index `a`) is the nth smallest element. + // After the loop, the root (at index `a`) is the n-th smallest element (0-indexed). // We swap it into the correct position `a + n`. - if (len > 0) { - context.swap(a, a + n); - } + context.swap(a, a + n); } /// Calculates the median of a slice using the nthElement function. /// For slices with an odd number of elements, it returns the middle element. /// For slices with an even number of elements, it returns the mean of the two central elements. -/// The result from integer types is rounded towards zero, while for floating-point types it is the exact mean. +/// The result from integer types is floor divided. /// This function modifies the order of elements in the slice. 
pub fn median( comptime T: type, @@ -208,8 +206,8 @@ pub fn median( } } return switch (@typeInfo(T)) { - .int => @divTrunc((lower_median + upper_median), 2), - .float => (lower_median + upper_median) / 2, + .int => (lower_median & upper_median) + ((lower_median ^ upper_median) >> 1), + .float => lower_median / 2 + upper_median / 2, else => @compileError("Unsupported type for median: " ++ @typeName(T)), }; } @@ -410,7 +408,7 @@ test "median even length" { test "median even length negative" { var items = [_]i32{ -1, -3, -2, -5, -4, -6 }; // sorted: -6, -5, -4, -3, -2, -1 -> median (-4 + -3)/2 = -3.5 const m = median(i32, &items, {}, sort.asc(i32)); - try testing.expectEqual(-3, m); + try testing.expectEqual(-4, m); } test "median odd length float" { @@ -424,3 +422,30 @@ test "median even length float" { const m = median(f32, &items, {}, sort.asc(f32)); try testing.expectApproxEqRel(3.85, m, 0.00001); } + +test "median overflow i8" { + const asc = sort.asc(i8); + const fill_values = [_]i8{ 127, -128 }; + for (fill_values) |fill| { + var items = [_]i8{fill} ** 4; + const m = median(i8, &items, {}, asc); + try testing.expectEqual(fill, m); + } +} + +test "median overflow f32" { + const asc = sort.asc(f32); + const fill_values = [_]f32{ math.floatMax(f32), math.floatMin(f32) }; + for (fill_values) |fill| { + var items = [_]f32{fill} ** 4; + const m = median(f32, &items, {}, asc); + try testing.expectEqual(fill, m); + } +} + +test "median mixed min max i8" { + const asc = sort.asc(i8); + var items = [_]i8{ -128, 127, 127, -128 }; + const m = median(i8, &items, {}, asc); + try testing.expectEqual(@as(i8, -1), m); +} From 4fee68cb3f3ab9f0078acfd4df5529299bce81f2 Mon Sep 17 00:00:00 2001 From: Oittaa Date: Mon, 14 Jul 2025 19:39:14 +0200 Subject: [PATCH 7/8] Use testing.expectEqualStrings with strings, loop over every position in the "select" test --- lib/std/select.zig | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git 
a/lib/std/select.zig b/lib/std/select.zig index c1cc94e2cfaa..de8683e21738 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -280,21 +280,25 @@ test "select" { inline for (select_funcs) |selectFn| { for (u8cases) |case| { + const len = case[0].len; var buf: [20]u8 = undefined; - const slice = buf[0..case[0].len]; - const mid = slice.len / 2; - @memcpy(slice, case[0]); - selectFn(u8, slice, mid, {}, asc_u8); - try testing.expectEqual(case[1][mid], slice[mid]); + for (0..len) |n| { + const slice = buf[0..len]; + @memcpy(slice, case[0]); + selectFn(u8, slice, n, {}, asc_u8); + try testing.expectEqualStrings(case[1][n .. n + 1], slice[n .. n + 1]); + } } for (i32cases) |case| { + const len = case[0].len; var buf: [20]i32 = undefined; - const slice = buf[0..case[0].len]; - const mid = slice.len / 2; - @memcpy(slice, case[0]); - selectFn(i32, slice, mid, {}, asc_i32); - try testing.expectEqual(case[1][mid], slice[mid]); + for (0..len) |n| { + const slice = buf[0..len]; + @memcpy(slice, case[0]); + selectFn(i32, slice, n, {}, asc_i32); + try testing.expectEqual(case[1][n], slice[n]); + } } } } @@ -445,7 +449,7 @@ test "median overflow f32" { test "median mixed min max i8" { const asc = sort.asc(i8); - var items = [_]i8{ -128, 127, 127, -128 }; + var items = [_]i8{ -128, 127, 127, -128 }; // sorted: -128, -128, 127, 127 -> median (-128+127)/2 = -0.5 const m = median(i8, &items, {}, asc); try testing.expectEqual(@as(i8, -1), m); } From 45e07bd55b171ff4c46f3416fc8112a9a8c49f9a Mon Sep 17 00:00:00 2001 From: Oittaa Date: Thu, 17 Jul 2025 23:47:42 +0200 Subject: [PATCH 8/8] Break out early if the target is in `partitionEqual` --- lib/std/select.zig | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/std/select.zig b/lib/std/select.zig index de8683e21738..a4077af29255 100644 --- a/lib/std/select.zig +++ b/lib/std/select.zig @@ -72,6 +72,7 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { assert(a < b); const len 
= b - a; assert(n < len); + const target = a + n; var left: usize = a; var right: usize = b; var depth_limit: usize = math.log2_int(usize, len) * 2; // This is what C++ std::nth_element does. @@ -81,7 +82,7 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { break; } if (depth_limit == 0) { - heapSelectContext(n - (left - a), left, right, context); + heapSelectContext(target - left, left, right, context); break; } depth_limit -= 1; @@ -92,10 +93,10 @@ pub fn nthElementContext(n: usize, a: usize, b: usize, context: anytype) void { // This case is usually hit when the slice contains many duplicate elements. if (left > a and !context.lessThan(left - 1, pivot)) { left = sort.partitionEqual(left, right, pivot, context); + if (target < left) break; continue; } _ = sort.partition(left, right, &pivot, context); - const target = a + n; if (pivot == target) { break; } else if (pivot > target) {