Skip to content

Commit fcc242d

Browse files
committed
Fix rounding with tiny precision
1 parent 57036da commit fcc242d

File tree

3 files changed

+109
-16
lines changed

3 files changed

+109
-16
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ numcodecs-zstd = { version = "0.3", path = "codecs/zstd", default-features = fal
7777

7878
# crates.io third-party dependencies
7979
anyhow = { version = "1.0.93", default-features = false }
80-
burn = { version = "0.17", default-features = false }
80+
burn = { version = "0.18", default-features = false }
8181
clap = { version = "4.5", default-features = false }
8282
convert_case = { version = "0.8", default-features = false }
8383
format_serde_error = { version = "0.3", default-features = false }
8484
indexmap = { version = "2.7.1", default-features = false }
8585
itertools = { version = "0.14", default-features = false }
8686
log = { version = "0.4.27", default-features = false }
8787
simple_logger = { version = "5.0", default-features = false }
88-
miniz_oxide = { version = "0.8.4", default-features = false }
88+
miniz_oxide = { version = "0.8.5", default-features = false }
8989
ndarray = { version = "0.16.1", default-features = false } # keep in sync with numpy
9090
ndarray-rand = { version = "0.15", default-features = false }
9191
numpy = { version = "0.25", default-features = false }

codecs/round/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "numcodecs-round"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
edition = { workspace = true }
55
authors = { workspace = true }
66
repository = { workspace = true }

codecs/round/src/lib.rs

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use thiserror::Error;
3737
/// The codec only supports floating point data.
3838
pub struct RoundCodec {
3939
/// Precision of the rounding operation
40-
pub precision: Positive<f64>,
40+
pub precision: NonNegative<f64>,
4141
/// The codec's encoding format version. Do not provide this parameter explicitly.
4242
#[serde(default, rename = "_version")]
4343
pub version: StaticCodecVersion<1, 0, 0>,
@@ -51,7 +51,7 @@ impl Codec for RoundCodec {
5151
#[expect(clippy::cast_possible_truncation)]
5252
AnyCowArray::F32(data) => Ok(AnyArray::F32(round(
5353
data,
54-
Positive(self.precision.0 as f32),
54+
NonNegative(self.precision.0 as f32),
5555
))),
5656
AnyCowArray::F64(data) => Ok(AnyArray::F64(round(data, self.precision))),
5757
encoded => Err(RoundCodecError::UnsupportedDtype(encoded.dtype())),
@@ -95,37 +95,37 @@ impl StaticCodec for RoundCodec {
9595

9696
#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
9797
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
98-
/// Positive floating point number
99-
pub struct Positive<T: Float>(T);
98+
/// Non-negative floating point number
99+
pub struct NonNegative<T: Float>(T);
100100

101-
impl Serialize for Positive<f64> {
101+
impl Serialize for NonNegative<f64> {
102102
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
103103
serializer.serialize_f64(self.0)
104104
}
105105
}
106106

107-
impl<'de> Deserialize<'de> for Positive<f64> {
107+
impl<'de> Deserialize<'de> for NonNegative<f64> {
108108
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
109109
let x = f64::deserialize(deserializer)?;
110110

111-
if x > 0.0 {
111+
if x >= 0.0 {
112112
Ok(Self(x))
113113
} else {
114114
Err(serde::de::Error::invalid_value(
115115
serde::de::Unexpected::Float(x),
116-
&"a positive value",
116+
&"a non-negative value",
117117
))
118118
}
119119
}
120120
}
121121

122-
impl JsonSchema for Positive<f64> {
122+
impl JsonSchema for NonNegative<f64> {
123123
fn schema_name() -> Cow<'static, str> {
124-
Cow::Borrowed("PositiveF64")
124+
Cow::Borrowed("NonNegativeF64")
125125
}
126126

127127
fn schema_id() -> Cow<'static, str> {
128-
Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
128+
Cow::Borrowed(concat!(module_path!(), "::", "NonNegative<f64>"))
129129
}
130130

131131
fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
@@ -154,11 +154,104 @@ pub enum RoundCodecError {
154154
#[must_use]
155155
/// Rounds the input `data` using
156156
/// `$c = \text{round}\left( \frac{x}{precision} \right) \cdot precision$`
157+
///
158+
/// If precision is zero, the `data` is returned unchanged.
157159
pub fn round<T: Float, S: Data<Elem = T>, D: Dimension>(
158160
data: ArrayBase<S, D>,
159-
precision: Positive<T>,
161+
precision: NonNegative<T>,
160162
) -> Array<T, D> {
161163
let mut encoded = data.into_owned();
162-
encoded.mapv_inplace(|x| (x / precision.0).round() * precision.0);
164+
165+
if precision.0.is_zero() {
166+
return encoded;
167+
}
168+
169+
encoded.mapv_inplace(|x| {
170+
let n = x / precision.0;
171+
172+
// if x / precision is not finite, don't try to round
173+
// e.g. when x / eps = inf
174+
if !n.is_finite() {
175+
return x;
176+
}
177+
178+
// round x to be a multiple of precision
179+
n.round() * precision.0
180+
});
181+
163182
encoded
164183
}
184+
185+
#[cfg(test)]
186+
mod tests {
187+
use ndarray::array;
188+
189+
use super::*;
190+
191+
#[test]
192+
fn round_zero_precision() {
193+
let data = array![1.1, 2.1];
194+
195+
let rounded = round(data.view(), NonNegative(0.0));
196+
197+
assert_eq!(data, rounded);
198+
}
199+
200+
#[test]
201+
fn round_minimal_precision() {
202+
let data = array![0.1, 1.0, 11.0, 21.0];
203+
204+
assert_eq!(11.0 / f64::MIN_POSITIVE, f64::INFINITY);
205+
let rounded = round(data.view(), NonNegative(f64::MIN_POSITIVE));
206+
207+
assert_eq!(data, rounded);
208+
}
209+
210+
#[test]
211+
fn round_roundoff_errors() {
212+
let data = array![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
213+
214+
let rounded = round(data.view(), NonNegative(0.1));
215+
216+
assert_eq!(
217+
rounded,
218+
array![
219+
0.0,
220+
0.1,
221+
0.2,
222+
0.30000000000000004,
223+
0.4,
224+
0.5,
225+
0.6000000000000001,
226+
0.7000000000000001,
227+
0.8,
228+
0.9,
229+
1.0
230+
]
231+
);
232+
233+
let rounded_twice = round(rounded.view(), NonNegative(0.1));
234+
235+
assert_eq!(rounded, rounded_twice);
236+
}
237+
238+
#[test]
239+
fn round_edge_cases() {
240+
let data = array![
241+
-f64::NAN,
242+
-f64::INFINITY,
243+
-42.0,
244+
-0.0,
245+
0.0,
246+
42.0,
247+
f64::INFINITY,
248+
f64::NAN
249+
];
250+
251+
let rounded = round(data.view(), NonNegative(1.0));
252+
253+
for (d, r) in data.into_iter().zip(rounded) {
254+
assert!(d == r || d.to_bits() == r.to_bits());
255+
}
256+
}
257+
}

0 commit comments

Comments
 (0)