1
1
//! I/O utilities.
2
2
3
- use csv:: ReaderBuilder ;
3
+ use csv:: { ReaderBuilder , StringRecord } ;
4
4
use smartcore:: linalg:: basic:: matrix:: DenseMatrix ;
5
5
use std:: error:: Error ;
6
6
use std:: fmt;
@@ -49,6 +49,8 @@ impl Error for CsvError {
49
49
///
50
50
/// Returns an error if the file cannot be read, a value fails to parse into
51
51
/// `f64`, or the rows have inconsistent lengths.
52
+ /// Row numbers mentioned in error messages are one-based and refer to data rows,
53
+ /// excluding the header.
52
54
///
53
55
/// # Examples
54
56
///
@@ -66,29 +68,19 @@ impl Error for CsvError {
66
68
/// # }
67
69
/// ```
68
70
pub fn load_csv_features < P : AsRef < Path > > ( path : P ) -> Result < DenseMatrix < f64 > , CsvError > {
69
- let file = File :: open ( path. as_ref ( ) ) . map_err ( CsvError :: Io ) ?;
70
- let mut reader = ReaderBuilder :: new ( )
71
- . has_headers ( true )
72
- . flexible ( true )
73
- . from_reader ( file) ;
74
-
71
+ let mut reader = build_csv_reader ( path. as_ref ( ) ) ?;
75
72
let mut features: Vec < Vec < f64 > > = Vec :: new ( ) ;
73
+ let mut expected_width: Option < usize > = None ;
76
74
77
- for result in reader. records ( ) {
75
+ for ( row_idx , result) in reader. records ( ) . enumerate ( ) {
78
76
let record = result. map_err ( |e| CsvError :: Parse ( Box :: new ( e) ) ) ?;
79
- let row = record
80
- . iter ( )
81
- . map ( |v| {
82
- v. parse :: < f64 > ( )
83
- . map_err ( |e : ParseFloatError | CsvError :: Parse ( Box :: new ( e) ) )
84
- } )
85
- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
77
+ let row = parse_feature_row ( & record, row_idx) ?;
78
+ ensure_consistent_width ( & row, row_idx, & mut expected_width) ?;
86
79
features. push ( row) ;
87
80
}
88
81
89
- let expected = features. first ( ) . map_or ( 0 , Vec :: len) ;
90
- if features. iter ( ) . any ( |r| r. len ( ) != expected) {
91
- return Err ( CsvError :: Shape ( "inconsistent row lengths" . to_string ( ) ) ) ;
82
+ if features. is_empty ( ) {
83
+ return Err ( CsvError :: Shape ( "no rows found" . to_string ( ) ) ) ;
92
84
}
93
85
94
86
let matrix = DenseMatrix :: from_2d_vec ( & features) . map_err ( |e| CsvError :: Shape ( e. to_string ( ) ) ) ?;
@@ -106,6 +98,8 @@ pub fn load_csv_features<P: AsRef<Path>>(path: P) -> Result<DenseMatrix<f64>, Cs
106
98
///
107
99
/// Returns an error if the file cannot be read, a value fails to parse into
108
100
/// `f64`, or the rows have inconsistent lengths.
101
+ /// Row numbers mentioned in error messages are one-based and refer to data rows,
102
+ /// excluding the header.
109
103
///
110
104
/// # Examples
111
105
///
@@ -123,40 +117,159 @@ pub fn load_labeled_csv<P: AsRef<Path>>(
123
117
path : P ,
124
118
target_col : usize ,
125
119
) -> Result < ( DenseMatrix < f64 > , Vec < f64 > ) , CsvError > {
126
- let file = File :: open ( path. as_ref ( ) ) . map_err ( CsvError :: Io ) ?;
127
- let mut reader = ReaderBuilder :: new ( )
128
- . has_headers ( true )
129
- . flexible ( true )
130
- . from_reader ( file) ;
131
-
120
+ let mut reader = build_csv_reader ( path. as_ref ( ) ) ?;
132
121
let mut features: Vec < Vec < f64 > > = Vec :: new ( ) ;
133
122
let mut targets: Vec < f64 > = Vec :: new ( ) ;
123
+ let mut expected_width: Option < usize > = None ;
134
124
135
- for result in reader. records ( ) {
125
+ for ( row_idx , result) in reader. records ( ) . enumerate ( ) {
136
126
let record = result. map_err ( |e| CsvError :: Parse ( Box :: new ( e) ) ) ?;
137
- let mut row: Vec < f64 > = Vec :: new ( ) ;
138
- for ( idx, field) in record. iter ( ) . enumerate ( ) {
139
- let value: f64 = field
140
- . parse :: < f64 > ( )
141
- . map_err ( |e : ParseFloatError | CsvError :: Parse ( Box :: new ( e) ) ) ?;
142
- if idx == target_col {
143
- targets. push ( value) ;
144
- } else {
145
- row. push ( value) ;
146
- }
147
- }
127
+ let ( row, target) = parse_labeled_row ( & record, row_idx, target_col) ?;
128
+ ensure_consistent_width ( & row, row_idx, & mut expected_width) ?;
129
+ targets. push ( target) ;
148
130
features. push ( row) ;
149
131
}
150
132
151
- if targets. len ( ) != features. len ( ) {
152
- return Err ( CsvError :: Shape ( "inconsistent row lengths" . to_string ( ) ) ) ;
153
- }
154
-
155
- let expected = features. first ( ) . map_or ( 0 , Vec :: len) ;
156
- if features. iter ( ) . any ( |r| r. len ( ) != expected) {
157
- return Err ( CsvError :: Shape ( "inconsistent row lengths" . to_string ( ) ) ) ;
133
+ if features. is_empty ( ) {
134
+ return Err ( CsvError :: Shape ( "no rows found" . to_string ( ) ) ) ;
158
135
}
159
136
160
137
let matrix = DenseMatrix :: from_2d_vec ( & features) . map_err ( |e| CsvError :: Shape ( e. to_string ( ) ) ) ?;
161
138
Ok ( ( matrix, targets) )
162
139
}
140
+
141
+ fn build_csv_reader ( path : & Path ) -> Result < csv:: Reader < File > , CsvError > {
142
+ let file = File :: open ( path) . map_err ( CsvError :: Io ) ?;
143
+ Ok ( ReaderBuilder :: new ( )
144
+ . has_headers ( true )
145
+ . flexible ( true )
146
+ . from_reader ( file) )
147
+ }
148
+
149
+ fn parse_feature_row ( record : & StringRecord , row_idx : usize ) -> Result < Vec < f64 > , CsvError > {
150
+ if record. is_empty ( ) {
151
+ return Err ( CsvError :: Shape ( format ! (
152
+ "row {}: expected at least one column" ,
153
+ row_idx + 1
154
+ ) ) ) ;
155
+ }
156
+
157
+ record
158
+ . iter ( )
159
+ . enumerate ( )
160
+ . map ( |( col_idx, value) | parse_numeric_field ( value, row_idx, col_idx) )
161
+ . collect ( )
162
+ }
163
+
164
+ fn parse_labeled_row (
165
+ record : & StringRecord ,
166
+ row_idx : usize ,
167
+ target_col : usize ,
168
+ ) -> Result < ( Vec < f64 > , f64 ) , CsvError > {
169
+ if record. len ( ) <= target_col {
170
+ return Err ( CsvError :: Shape ( format ! (
171
+ "row {}: target column index {} out of bounds (row has {} columns)" ,
172
+ row_idx + 1 ,
173
+ target_col,
174
+ record. len( )
175
+ ) ) ) ;
176
+ }
177
+
178
+ if record. len ( ) <= 1 {
179
+ return Err ( CsvError :: Shape ( format ! (
180
+ "row {}: expected at least one feature column in addition to the target" ,
181
+ row_idx + 1
182
+ ) ) ) ;
183
+ }
184
+
185
+ let mut target = None ;
186
+ let mut row = Vec :: with_capacity ( record. len ( ) - 1 ) ;
187
+
188
+ for ( col_idx, value) in record. iter ( ) . enumerate ( ) {
189
+ let parsed = parse_numeric_field ( value, row_idx, col_idx) ?;
190
+ if col_idx == target_col {
191
+ target = Some ( parsed) ;
192
+ } else {
193
+ row. push ( parsed) ;
194
+ }
195
+ }
196
+
197
+ match target {
198
+ Some ( target_value) => Ok ( ( row, target_value) ) ,
199
+ None => Err ( CsvError :: Shape ( format ! (
200
+ "row {}: missing target column {}" ,
201
+ row_idx + 1 ,
202
+ target_col
203
+ ) ) ) ,
204
+ }
205
+ }
206
+
207
+ fn parse_numeric_field ( value : & str , row_idx : usize , col_idx : usize ) -> Result < f64 , CsvError > {
208
+ value. parse :: < f64 > ( ) . map_err ( |err : ParseFloatError | {
209
+ CsvError :: Parse ( Box :: new ( FloatParseError :: new (
210
+ row_idx + 1 ,
211
+ col_idx + 1 ,
212
+ err,
213
+ ) ) )
214
+ } )
215
+ }
216
+
217
+ fn ensure_consistent_width (
218
+ row : & [ f64 ] ,
219
+ row_idx : usize ,
220
+ expected_width : & mut Option < usize > ,
221
+ ) -> Result < ( ) , CsvError > {
222
+ if row. is_empty ( ) {
223
+ return Err ( CsvError :: Shape ( format ! (
224
+ "row {}: expected at least one column" ,
225
+ row_idx + 1
226
+ ) ) ) ;
227
+ }
228
+
229
+ match expected_width {
230
+ Some ( width) if row. len ( ) != * width => Err ( CsvError :: Shape ( format ! (
231
+ "row {}: expected {} columns but found {}" ,
232
+ row_idx + 1 ,
233
+ width,
234
+ row. len( )
235
+ ) ) ) ,
236
+ Some ( _) => Ok ( ( ) ) ,
237
+ None => {
238
+ * expected_width = Some ( row. len ( ) ) ;
239
+ Ok ( ( ) )
240
+ }
241
+ }
242
+ }
243
+
244
/// A float-parsing failure annotated with the position of the offending field.
///
/// `row` and `column` are stored and displayed exactly as given to
/// [`FloatParseError::new`]; `parse_numeric_field` passes one-based values.
#[derive(Debug)]
struct FloatParseError {
    /// Row position included in the message.
    row: usize,
    /// Column position included in the message.
    column: usize,
    /// The underlying standard-library parse error.
    source: ParseFloatError,
}

impl FloatParseError {
    /// Bundles a parse error with the row and column it occurred at.
    fn new(row: usize, column: usize, source: ParseFloatError) -> Self {
        Self { row, column, source }
    }
}

impl fmt::Display for FloatParseError {
    /// Writes the position followed by the underlying error's message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { row, column, source } = self;
        write!(
            f,
            "failed to parse float at row {}, column {}: {}",
            row, column, source
        )
    }
}

impl Error for FloatParseError {
    /// Exposes the wrapped `ParseFloatError` as the cause of this error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = &self.source;
        Some(cause)
    }
}
0 commit comments