Skip to content

Commit 163183e

Browse files
committed
Improve CSV loading error reporting
1 parent f01a29f commit 163183e

File tree

3 files changed

+209
-47
lines changed

3 files changed

+209
-47
lines changed

src/utils/io.rs

Lines changed: 156 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! I/O utilities.
22
3-
use csv::ReaderBuilder;
3+
use csv::{ReaderBuilder, StringRecord};
44
use smartcore::linalg::basic::matrix::DenseMatrix;
55
use std::error::Error;
66
use std::fmt;
@@ -49,6 +49,8 @@ impl Error for CsvError {
4949
///
5050
/// Returns an error if the file cannot be read, a value fails to parse into
5151
/// `f64`, or the rows have inconsistent lengths.
52+
/// Row numbers mentioned in error messages are one-based and refer to data rows,
53+
/// excluding the header.
5254
///
5355
/// # Examples
5456
///
@@ -66,29 +68,19 @@ impl Error for CsvError {
6668
/// # }
6769
/// ```
6870
pub fn load_csv_features<P: AsRef<Path>>(path: P) -> Result<DenseMatrix<f64>, CsvError> {
69-
let file = File::open(path.as_ref()).map_err(CsvError::Io)?;
70-
let mut reader = ReaderBuilder::new()
71-
.has_headers(true)
72-
.flexible(true)
73-
.from_reader(file);
74-
71+
let mut reader = build_csv_reader(path.as_ref())?;
7572
let mut features: Vec<Vec<f64>> = Vec::new();
73+
let mut expected_width: Option<usize> = None;
7674

77-
for result in reader.records() {
75+
for (row_idx, result) in reader.records().enumerate() {
7876
let record = result.map_err(|e| CsvError::Parse(Box::new(e)))?;
79-
let row = record
80-
.iter()
81-
.map(|v| {
82-
v.parse::<f64>()
83-
.map_err(|e: ParseFloatError| CsvError::Parse(Box::new(e)))
84-
})
85-
.collect::<Result<Vec<_>, _>>()?;
77+
let row = parse_feature_row(&record, row_idx)?;
78+
ensure_consistent_width(&row, row_idx, &mut expected_width)?;
8679
features.push(row);
8780
}
8881

89-
let expected = features.first().map_or(0, Vec::len);
90-
if features.iter().any(|r| r.len() != expected) {
91-
return Err(CsvError::Shape("inconsistent row lengths".to_string()));
82+
if features.is_empty() {
83+
return Err(CsvError::Shape("no rows found".to_string()));
9284
}
9385

9486
let matrix = DenseMatrix::from_2d_vec(&features).map_err(|e| CsvError::Shape(e.to_string()))?;
@@ -106,6 +98,8 @@ pub fn load_csv_features<P: AsRef<Path>>(path: P) -> Result<DenseMatrix<f64>, Cs
10698
///
10799
/// Returns an error if the file cannot be read, a value fails to parse into
108100
/// `f64`, or the rows have inconsistent lengths.
101+
/// Row numbers mentioned in error messages are one-based and refer to data rows,
102+
/// excluding the header.
109103
///
110104
/// # Examples
111105
///
@@ -123,40 +117,159 @@ pub fn load_labeled_csv<P: AsRef<Path>>(
123117
path: P,
124118
target_col: usize,
125119
) -> Result<(DenseMatrix<f64>, Vec<f64>), CsvError> {
126-
let file = File::open(path.as_ref()).map_err(CsvError::Io)?;
127-
let mut reader = ReaderBuilder::new()
128-
.has_headers(true)
129-
.flexible(true)
130-
.from_reader(file);
131-
120+
let mut reader = build_csv_reader(path.as_ref())?;
132121
let mut features: Vec<Vec<f64>> = Vec::new();
133122
let mut targets: Vec<f64> = Vec::new();
123+
let mut expected_width: Option<usize> = None;
134124

135-
for result in reader.records() {
125+
for (row_idx, result) in reader.records().enumerate() {
136126
let record = result.map_err(|e| CsvError::Parse(Box::new(e)))?;
137-
let mut row: Vec<f64> = Vec::new();
138-
for (idx, field) in record.iter().enumerate() {
139-
let value: f64 = field
140-
.parse::<f64>()
141-
.map_err(|e: ParseFloatError| CsvError::Parse(Box::new(e)))?;
142-
if idx == target_col {
143-
targets.push(value);
144-
} else {
145-
row.push(value);
146-
}
147-
}
127+
let (row, target) = parse_labeled_row(&record, row_idx, target_col)?;
128+
ensure_consistent_width(&row, row_idx, &mut expected_width)?;
129+
targets.push(target);
148130
features.push(row);
149131
}
150132

151-
if targets.len() != features.len() {
152-
return Err(CsvError::Shape("inconsistent row lengths".to_string()));
153-
}
154-
155-
let expected = features.first().map_or(0, Vec::len);
156-
if features.iter().any(|r| r.len() != expected) {
157-
return Err(CsvError::Shape("inconsistent row lengths".to_string()));
133+
if features.is_empty() {
134+
return Err(CsvError::Shape("no rows found".to_string()));
158135
}
159136

160137
let matrix = DenseMatrix::from_2d_vec(&features).map_err(|e| CsvError::Shape(e.to_string()))?;
161138
Ok((matrix, targets))
162139
}
140+
141+
fn build_csv_reader(path: &Path) -> Result<csv::Reader<File>, CsvError> {
142+
let file = File::open(path).map_err(CsvError::Io)?;
143+
Ok(ReaderBuilder::new()
144+
.has_headers(true)
145+
.flexible(true)
146+
.from_reader(file))
147+
}
148+
149+
fn parse_feature_row(record: &StringRecord, row_idx: usize) -> Result<Vec<f64>, CsvError> {
150+
if record.is_empty() {
151+
return Err(CsvError::Shape(format!(
152+
"row {}: expected at least one column",
153+
row_idx + 1
154+
)));
155+
}
156+
157+
record
158+
.iter()
159+
.enumerate()
160+
.map(|(col_idx, value)| parse_numeric_field(value, row_idx, col_idx))
161+
.collect()
162+
}
163+
164+
fn parse_labeled_row(
165+
record: &StringRecord,
166+
row_idx: usize,
167+
target_col: usize,
168+
) -> Result<(Vec<f64>, f64), CsvError> {
169+
if record.len() <= target_col {
170+
return Err(CsvError::Shape(format!(
171+
"row {}: target column index {} out of bounds (row has {} columns)",
172+
row_idx + 1,
173+
target_col,
174+
record.len()
175+
)));
176+
}
177+
178+
if record.len() <= 1 {
179+
return Err(CsvError::Shape(format!(
180+
"row {}: expected at least one feature column in addition to the target",
181+
row_idx + 1
182+
)));
183+
}
184+
185+
let mut target = None;
186+
let mut row = Vec::with_capacity(record.len() - 1);
187+
188+
for (col_idx, value) in record.iter().enumerate() {
189+
let parsed = parse_numeric_field(value, row_idx, col_idx)?;
190+
if col_idx == target_col {
191+
target = Some(parsed);
192+
} else {
193+
row.push(parsed);
194+
}
195+
}
196+
197+
match target {
198+
Some(target_value) => Ok((row, target_value)),
199+
None => Err(CsvError::Shape(format!(
200+
"row {}: missing target column {}",
201+
row_idx + 1,
202+
target_col
203+
))),
204+
}
205+
}
206+
207+
fn parse_numeric_field(value: &str, row_idx: usize, col_idx: usize) -> Result<f64, CsvError> {
208+
value.parse::<f64>().map_err(|err: ParseFloatError| {
209+
CsvError::Parse(Box::new(FloatParseError::new(
210+
row_idx + 1,
211+
col_idx + 1,
212+
err,
213+
)))
214+
})
215+
}
216+
217+
fn ensure_consistent_width(
218+
row: &[f64],
219+
row_idx: usize,
220+
expected_width: &mut Option<usize>,
221+
) -> Result<(), CsvError> {
222+
if row.is_empty() {
223+
return Err(CsvError::Shape(format!(
224+
"row {}: expected at least one column",
225+
row_idx + 1
226+
)));
227+
}
228+
229+
match expected_width {
230+
Some(width) if row.len() != *width => Err(CsvError::Shape(format!(
231+
"row {}: expected {} columns but found {}",
232+
row_idx + 1,
233+
width,
234+
row.len()
235+
))),
236+
Some(_) => Ok(()),
237+
None => {
238+
*expected_width = Some(row.len());
239+
Ok(())
240+
}
241+
}
242+
}
243+
244+
#[derive(Debug)]
245+
struct FloatParseError {
246+
row: usize,
247+
column: usize,
248+
source: ParseFloatError,
249+
}
250+
251+
impl FloatParseError {
252+
fn new(row: usize, column: usize, source: ParseFloatError) -> Self {
253+
Self {
254+
row,
255+
column,
256+
source,
257+
}
258+
}
259+
}
260+
261+
impl fmt::Display for FloatParseError {
262+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263+
write!(
264+
f,
265+
"failed to parse float at row {}, column {}: {}",
266+
self.row, self.column, self.source
267+
)
268+
}
269+
}
270+
271+
impl Error for FloatParseError {
272+
fn source(&self) -> Option<&(dyn Error + 'static)> {
273+
Some(&self.source)
274+
}
275+
}

tests/load_csv_features.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,26 @@ fn errors_on_bad_path() {
1818
#[test]
1919
fn errors_on_non_numeric() {
2020
let err = load_csv_features("tests/fixtures/non_numeric_features.csv").unwrap_err();
21-
assert!(matches!(err, CsvError::Parse(_)));
21+
match err {
22+
CsvError::Parse(parse_err) => {
23+
let message = parse_err.to_string();
24+
assert!(message.contains("row 2"), "unexpected message: {message}");
25+
assert!(
26+
message.contains("column 1"),
27+
"unexpected message: {message}"
28+
);
29+
}
30+
other => panic!("expected parse error, got {other}"),
31+
}
2232
}
2333

2434
#[test]
2535
fn errors_on_inconsistent_rows() {
2636
let err = load_csv_features("tests/fixtures/inconsistent_features.csv").unwrap_err();
27-
assert!(matches!(err, CsvError::Shape(_)));
37+
match err {
38+
CsvError::Shape(message) => {
39+
assert!(message.contains("row 2"), "unexpected message: {message}");
40+
}
41+
other => panic!("expected shape error, got {other}"),
42+
}
2843
}

tests/load_labeled_csv.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,45 @@ fn errors_on_bad_path() {
2222
#[test]
2323
fn errors_on_non_numeric() {
2424
let err = load_labeled_csv("tests/fixtures/non_numeric_labeled.csv", 2).unwrap_err();
25-
assert!(matches!(err, CsvError::Parse(_)));
25+
match err {
26+
CsvError::Parse(parse_err) => {
27+
let message = parse_err.to_string();
28+
assert!(message.contains("row 2"), "unexpected message: {message}");
29+
assert!(
30+
message.contains("column 2"),
31+
"unexpected message: {message}"
32+
);
33+
}
34+
other => panic!("expected parse error, got {other}"),
35+
}
2636
}
2737

2838
#[test]
2939
fn errors_on_inconsistent_rows() {
3040
let err = load_labeled_csv("tests/fixtures/inconsistent_labeled.csv", 2).unwrap_err();
31-
assert!(matches!(err, CsvError::Shape(_)));
41+
match err {
42+
CsvError::Shape(message) => {
43+
assert!(message.contains("row 2"), "unexpected message: {message}");
44+
assert!(
45+
message.contains("out of bounds"),
46+
"unexpected message: {message}"
47+
);
48+
}
49+
other => panic!("expected shape error, got {other}"),
50+
}
51+
}
52+
53+
#[test]
54+
fn errors_when_target_column_out_of_range() {
55+
let err = load_labeled_csv("tests/fixtures/supervised_sample.csv", 5).unwrap_err();
56+
match err {
57+
CsvError::Shape(message) => {
58+
assert!(message.contains("row 1"), "unexpected message: {message}");
59+
assert!(
60+
message.contains("out of bounds"),
61+
"unexpected message: {message}"
62+
);
63+
}
64+
other => panic!("expected shape error, got {other}"),
65+
}
3266
}

0 commit comments

Comments
 (0)