@@ -37,7 +37,7 @@ use thiserror::Error;
37
37
/// The codec only supports floating point data.
38
38
pub struct RoundCodec {
39
39
/// Precision of the rounding operation
40
- pub precision : Positive < f64 > ,
40
+ pub precision : NonNegative < f64 > ,
41
41
/// The codec's encoding format version. Do not provide this parameter explicitly.
42
42
#[ serde( default , rename = "_version" ) ]
43
43
pub version : StaticCodecVersion < 1 , 0 , 0 > ,
@@ -51,7 +51,7 @@ impl Codec for RoundCodec {
51
51
#[ expect( clippy:: cast_possible_truncation) ]
52
52
AnyCowArray :: F32 ( data) => Ok ( AnyArray :: F32 ( round (
53
53
data,
54
- Positive ( self . precision . 0 as f32 ) ,
54
+ NonNegative ( self . precision . 0 as f32 ) ,
55
55
) ) ) ,
56
56
AnyCowArray :: F64 ( data) => Ok ( AnyArray :: F64 ( round ( data, self . precision ) ) ) ,
57
57
encoded => Err ( RoundCodecError :: UnsupportedDtype ( encoded. dtype ( ) ) ) ,
@@ -95,37 +95,37 @@ impl StaticCodec for RoundCodec {
95
95
96
96
#[ expect( clippy:: derive_partial_eq_without_eq) ] // floats are not Eq
97
97
#[ derive( Copy , Clone , PartialEq , PartialOrd , Hash ) ]
98
- /// Positive floating point number
99
- pub struct Positive < T : Float > ( T ) ;
98
+ /// Non-negative floating point number
99
+ pub struct NonNegative < T : Float > ( T ) ;
100
100
101
- impl Serialize for Positive < f64 > {
101
+ impl Serialize for NonNegative < f64 > {
102
102
fn serialize < S : Serializer > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error > {
103
103
serializer. serialize_f64 ( self . 0 )
104
104
}
105
105
}
106
106
107
- impl < ' de > Deserialize < ' de > for Positive < f64 > {
107
+ impl < ' de > Deserialize < ' de > for NonNegative < f64 > {
108
108
fn deserialize < D : Deserializer < ' de > > ( deserializer : D ) -> Result < Self , D :: Error > {
109
109
let x = f64:: deserialize ( deserializer) ?;
110
110
111
- if x > 0.0 {
111
+ if x >= 0.0 {
112
112
Ok ( Self ( x) )
113
113
} else {
114
114
Err ( serde:: de:: Error :: invalid_value (
115
115
serde:: de:: Unexpected :: Float ( x) ,
116
- & "a positive value" ,
116
+ & "a non-negative value" ,
117
117
) )
118
118
}
119
119
}
120
120
}
121
121
122
- impl JsonSchema for Positive < f64 > {
122
+ impl JsonSchema for NonNegative < f64 > {
123
123
fn schema_name ( ) -> Cow < ' static , str > {
124
- Cow :: Borrowed ( "PositiveF64 " )
124
+ Cow :: Borrowed ( "NonNegativeF64 " )
125
125
}
126
126
127
127
fn schema_id ( ) -> Cow < ' static , str > {
128
- Cow :: Borrowed ( concat ! ( module_path!( ) , "::" , "Positive <f64>" ) )
128
+ Cow :: Borrowed ( concat ! ( module_path!( ) , "::" , "NonNegative <f64>" ) )
129
129
}
130
130
131
131
fn json_schema ( _gen : & mut SchemaGenerator ) -> Schema {
@@ -154,11 +154,104 @@ pub enum RoundCodecError {
154
154
#[ must_use]
155
155
/// Rounds the input `data` using
156
156
/// `$c = \text{round}\left( \frac{x}{precision} \right) \cdot precision$`
157
+ ///
158
+ /// If precision is zero, the `data` is returned unchanged.
157
159
pub fn round < T : Float , S : Data < Elem = T > , D : Dimension > (
158
160
data : ArrayBase < S , D > ,
159
- precision : Positive < T > ,
161
+ precision : NonNegative < T > ,
160
162
) -> Array < T , D > {
161
163
let mut encoded = data. into_owned ( ) ;
162
- encoded. mapv_inplace ( |x| ( x / precision. 0 ) . round ( ) * precision. 0 ) ;
164
+
165
+ if precision. 0 . is_zero ( ) {
166
+ return encoded;
167
+ }
168
+
169
+ encoded. mapv_inplace ( |x| {
170
+ let n = x / precision. 0 ;
171
+
172
+ // if x / precision is not finite, don't try to round
173
+ // e.g. when x / eps = inf
174
+ if !n. is_finite ( ) {
175
+ return x;
176
+ }
177
+
178
+ // round x to be a multiple of precision
179
+ n. round ( ) * precision. 0
180
+ } ) ;
181
+
163
182
encoded
164
183
}
184
+
185
+ #[ cfg( test) ]
186
+ mod tests {
187
+ use ndarray:: array;
188
+
189
+ use super :: * ;
190
+
191
+ #[ test]
192
+ fn round_zero_precision ( ) {
193
+ let data = array ! [ 1.1 , 2.1 ] ;
194
+
195
+ let rounded = round ( data. view ( ) , NonNegative ( 0.0 ) ) ;
196
+
197
+ assert_eq ! ( data, rounded) ;
198
+ }
199
+
200
+ #[ test]
201
+ fn round_minimal_precision ( ) {
202
+ let data = array ! [ 0.1 , 1.0 , 11.0 , 21.0 ] ;
203
+
204
+ assert_eq ! ( 11.0 / f64 :: MIN_POSITIVE , f64 :: INFINITY ) ;
205
+ let rounded = round ( data. view ( ) , NonNegative ( f64:: MIN_POSITIVE ) ) ;
206
+
207
+ assert_eq ! ( data, rounded) ;
208
+ }
209
+
210
+ #[ test]
211
+ fn round_roundoff_errors ( ) {
212
+ let data = array ! [ 0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 ] ;
213
+
214
+ let rounded = round ( data. view ( ) , NonNegative ( 0.1 ) ) ;
215
+
216
+ assert_eq ! (
217
+ rounded,
218
+ array![
219
+ 0.0 ,
220
+ 0.1 ,
221
+ 0.2 ,
222
+ 0.30000000000000004 ,
223
+ 0.4 ,
224
+ 0.5 ,
225
+ 0.6000000000000001 ,
226
+ 0.7000000000000001 ,
227
+ 0.8 ,
228
+ 0.9 ,
229
+ 1.0
230
+ ]
231
+ ) ;
232
+
233
+ let rounded_twice = round ( rounded. view ( ) , NonNegative ( 0.1 ) ) ;
234
+
235
+ assert_eq ! ( rounded, rounded_twice) ;
236
+ }
237
+
238
+ #[ test]
239
+ fn round_edge_cases ( ) {
240
+ let data = array ! [
241
+ -f64 :: NAN ,
242
+ -f64 :: INFINITY ,
243
+ -42.0 ,
244
+ -0.0 ,
245
+ 0.0 ,
246
+ 42.0 ,
247
+ f64 :: INFINITY ,
248
+ f64 :: NAN
249
+ ] ;
250
+
251
+ let rounded = round ( data. view ( ) , NonNegative ( 1.0 ) ) ;
252
+
253
+ for ( d, r) in data. into_iter ( ) . zip ( rounded) {
254
+ assert ! ( d == r || d. to_bits( ) == r. to_bits( ) ) ;
255
+ }
256
+ }
257
+ }
0 commit comments