Skip to content

Commit ac188cb

Browse files
authored
Add valuesort, new control result mode and value normalizer (#237)
* Merged in PR 123 from parent project Signed-off-by: Bruce Ritchie <[email protected]> * Adding resultmode control and custom value normalizer. Bumped version. Signed-off-by: Bruce Ritchie <[email protected]> * Added test to verify resultmode can be updated Signed-off-by: Bruce Ritchie <[email protected]> * Cargo fmt. Signed-off-by: Bruce Ritchie <[email protected]> * Cargo clippy. Signed-off-by: Bruce Ritchie <[email protected]> * elide lifetimes to get ci check to pass. Signed-off-by: Bruce Ritchie <[email protected]> * Updates after merge with upstream Signed-off-by: Bruce Ritchie <[email protected]> * Added valuesort test, updated changelog. Signed-off-by: Bruce Ritchie <[email protected]> --------- Signed-off-by: Bruce Ritchie <[email protected]>
1 parent e08bc06 commit ac188cb

File tree

12 files changed

+432
-229
lines changed

12 files changed

+432
-229
lines changed

CHANGELOG.md

Lines changed: 103 additions & 52 deletions
Large diffs are not rendered by default.

Cargo.lock

Lines changed: 149 additions & 156 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]
44

55
[workspace.package]
6-
version = "0.23.1"
6+
version = "0.24.0"
77
edition = "2021"
88
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
99
keywords = ["sql", "database", "parser", "cli"]

sqllogictest-bin/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ glob = "0.3"
2323
itertools = "0.13"
2424
quick-junit = { version = "0.5" }
2525
rand = "0.8"
26-
sqllogictest = { path = "../sqllogictest", version = "0.23" }
27-
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.23" }
26+
sqllogictest = { path = "../sqllogictest", version = "0.24" }
27+
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.24" }
2828
tokio = { version = "1", features = [
2929
"rt",
3030
"rt-multi-thread",

sqllogictest-bin/src/main.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
1717
use rand::distributions::DistString;
1818
use rand::seq::SliceRandom;
1919
use sqllogictest::{
20-
default_column_validator, default_validator, update_record_with_output, AsyncDB, Injected,
21-
MakeConnection, Record, Runner,
20+
default_column_validator, default_normalizer, default_validator, update_record_with_output,
21+
AsyncDB, Injected, MakeConnection, Record, Runner,
2222
};
2323

2424
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
@@ -770,6 +770,7 @@ async fn update_record<M: MakeConnection>(
770770
&record_output,
771771
"\t",
772772
default_validator,
773+
default_normalizer,
773774
default_column_validator,
774775
) {
775776
Some(new_record) => {

sqllogictest-engines/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
2020
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
2121
serde = { version = "1", features = ["derive"] }
2222
serde_json = "1"
23-
sqllogictest = { path = "../sqllogictest", version = "0.23" }
23+
sqllogictest = { path = "../sqllogictest", version = "0.24" }
2424
thiserror = "2"
2525
tokio = { version = "1", features = [
2626
"rt",

sqllogictest/src/parser.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ pub enum QueryExpect<T: ColumnType> {
8585
Results {
8686
types: Vec<T>,
8787
sort_mode: Option<SortMode>,
88+
result_mode: Option<ResultMode>,
8889
label: Option<String>,
8990
results: Vec<String>,
9091
},
@@ -98,6 +99,7 @@ impl<T: ColumnType> QueryExpect<T> {
9899
Self::Results {
99100
types: Vec::new(),
100101
sort_mode: None,
102+
result_mode: None,
101103
label: None,
102104
results: Vec::new(),
103105
}
@@ -287,6 +289,7 @@ impl<T: ColumnType> std::fmt::Display for Record<T> {
287289
}
288290
Record::Control(c) => match c {
289291
Control::SortMode(m) => write!(f, "control sortmode {}", m.as_str()),
292+
Control::ResultMode(m) => write!(f, "control resultmode {}", m.as_str()),
290293
Control::Substitution(s) => write!(f, "control substitution {}", s.as_str()),
291294
},
292295
Record::Condition(cond) => match cond {
@@ -435,6 +438,8 @@ impl PartialEq for ExpectedError {
435438
pub enum Control {
436439
/// Control sort mode.
437440
SortMode(SortMode),
441+
/// control result mode.
442+
ResultMode(ResultMode),
438443
/// Control whether or not to substitute variables in the SQL.
439444
Substitution(bool),
440445
}
@@ -545,6 +550,38 @@ impl ControlItem for SortMode {
545550
}
546551
}
547552

553+
/// Whether the results should be parsed as value-wise or row-wise
554+
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
555+
pub enum ResultMode {
556+
/// Results are in a single column
557+
ValueWise,
558+
/// The default option where results are in columns separated by spaces
559+
RowWise,
560+
}
561+
562+
impl ControlItem for ResultMode {
563+
fn try_from_str(s: &str) -> Result<Self, ParseErrorKind> {
564+
match s {
565+
"rowwise" => Ok(Self::RowWise),
566+
"valuewise" => Ok(Self::ValueWise),
567+
_ => Err(ParseErrorKind::InvalidSortMode(s.to_string())),
568+
}
569+
}
570+
571+
fn as_str(&self) -> &'static str {
572+
match self {
573+
Self::RowWise => "rowwise",
574+
Self::ValueWise => "valuewise",
575+
}
576+
}
577+
}
578+
579+
impl fmt::Display for ResultMode {
580+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
581+
write!(f, "{self:?}")
582+
}
583+
}
584+
548585
/// The error type for parsing sqllogictest.
549586
#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
550587
#[error("parse error at {loc}: {kind}")]
@@ -754,6 +791,7 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
754791
QueryExpect::Results {
755792
types,
756793
sort_mode,
794+
result_mode: None,
757795
label,
758796
results: Vec::new(),
759797
}
@@ -812,6 +850,12 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
812850
});
813851
}
814852
["control", res @ ..] => match res {
853+
["resultmode", result_mode] => match ResultMode::try_from_str(result_mode) {
854+
Ok(result_mode) => {
855+
records.push(Record::Control(Control::ResultMode(result_mode)))
856+
}
857+
Err(k) => return Err(k.at(loc)),
858+
},
815859
["sortmode", sort_mode] => match SortMode::try_from_str(sort_mode) {
816860
Ok(sort_mode) => records.push(Record::Control(Control::SortMode(sort_mode))),
817861
Err(k) => return Err(k.at(loc)),
@@ -988,6 +1032,11 @@ mod tests {
9881032
parse_roundtrip::<DefaultColumnType>("../tests/slt/rowsort.slt")
9891033
}
9901034

1035+
#[test]
1036+
fn test_valuesort() {
1037+
parse_roundtrip::<DefaultColumnType>("../tests/slt/valuesort.slt")
1038+
}
1039+
9911040
#[test]
9921041
fn test_substitution() {
9931042
parse_roundtrip::<DefaultColumnType>("../tests/substitution/basic.slt")

sqllogictest/src/runner.rs

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -449,26 +449,39 @@ fn format_column_diff(expected: &str, actual: &str, colorize: bool) -> String {
449449
format!("[Expected] {expected}\n[Actual ] {actual}")
450450
}
451451

452+
/// Normalizer will be used by [`Runner`] to normalize the result values
453+
///
454+
/// # Default
455+
///
456+
/// By default, the ([`default_normalizer`]) will be used to normalize values.
457+
pub type Normalizer = fn(s: &String) -> String;
458+
452459
/// Trim and replace multiple whitespaces with one.
453460
#[allow(clippy::ptr_arg)]
454-
fn normalize_string(s: &String) -> String {
461+
pub fn default_normalizer(s: &String) -> String {
455462
s.trim().split_ascii_whitespace().join(" ")
456463
}
457464

458465
/// Validator will be used by [`Runner`] to validate the output.
459466
///
460467
/// # Default
461468
///
462-
/// By default ([`default_validator`]), we will use compare normalized results.
463-
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;
464-
465-
pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
466-
let expected_results = expected.iter().map(normalize_string).collect_vec();
469+
/// By default, the ([`default_validator`]) will be used compare normalized results.
470+
pub type Validator =
471+
fn(normalizer: Normalizer, actual: &[Vec<String>], expected: &[String]) -> bool;
472+
473+
pub fn default_validator(
474+
normalizer: Normalizer,
475+
actual: &[Vec<String>],
476+
expected: &[String],
477+
) -> bool {
478+
let expected_results = expected.iter().map(normalizer).collect_vec();
467479
// Default, we compare normalized results. Whitespace characters are ignored.
468480
let normalized_rows = actual
469481
.iter()
470-
.map(|strs| strs.iter().map(normalize_string).join(" "))
482+
.map(|strs| strs.iter().map(normalizer).join(" "))
471483
.collect_vec();
484+
472485
normalized_rows == expected_results
473486
}
474487

@@ -502,9 +515,12 @@ pub struct Runner<D: AsyncDB, M: MakeConnection> {
502515
conn: Connections<D, M>,
503516
// validator is used for validate if the result of query equals to expected.
504517
validator: Validator,
518+
// normalizer is used to normalize the result text
519+
normalizer: Normalizer,
505520
column_type_validator: ColumnTypeValidator<D::ColumnType>,
506521
substitution: Option<Substitution>,
507522
sort_mode: Option<SortMode>,
523+
result_mode: Option<ResultMode>,
508524
/// 0 means never hashing
509525
hash_threshold: usize,
510526
/// Labels for condition `skipif` and `onlyif`.
@@ -518,9 +534,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
518534
pub fn new(make_conn: M) -> Self {
519535
Runner {
520536
validator: default_validator,
537+
normalizer: default_normalizer,
521538
column_type_validator: default_column_validator,
522539
substitution: None,
523540
sort_mode: None,
541+
result_mode: None,
524542
hash_threshold: 0,
525543
labels: HashSet::new(),
526544
conn: Connections::new(make_conn),
@@ -532,6 +550,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
532550
self.labels.insert(label.to_string());
533551
}
534552

553+
pub fn with_normalizer(&mut self, normalizer: Normalizer) {
554+
self.normalizer = normalizer;
555+
}
535556
pub fn with_validator(&mut self, validator: Validator) {
536557
self.validator = validator;
537558
}
@@ -769,15 +790,31 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
769790
QueryExpect::Error(_) => None,
770791
}
771792
.or(self.sort_mode);
793+
794+
let mut value_sort = false;
772795
match sort_mode {
773796
None | Some(SortMode::NoSort) => {}
774797
Some(SortMode::RowSort) => {
775798
rows.sort_unstable();
776799
}
777-
Some(SortMode::ValueSort) => todo!("value sort"),
800+
Some(SortMode::ValueSort) => {
801+
rows = rows
802+
.iter()
803+
.flat_map(|row| row.iter())
804+
.map(|s| vec![s.to_owned()])
805+
.collect();
806+
rows.sort_unstable();
807+
value_sort = true;
808+
}
778809
};
779810

780-
if self.hash_threshold > 0 && rows.len() * types.len() > self.hash_threshold {
811+
let num_values = if value_sort {
812+
rows.len()
813+
} else {
814+
rows.len() * types.len()
815+
};
816+
817+
if self.hash_threshold > 0 && num_values > self.hash_threshold {
781818
let mut md5 = md5::Md5::new();
782819
for line in &rows {
783820
for value in line {
@@ -808,6 +845,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
808845
Control::SortMode(sort_mode) => {
809846
self.sort_mode = Some(sort_mode);
810847
}
848+
Control::ResultMode(result_mode) => {
849+
self.result_mode = Some(result_mode);
850+
}
811851
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
812852
(s @ None, true) => *s = Some(Substitution::default()),
813853
(s @ Some(_), false) => *s = None,
@@ -996,7 +1036,17 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
9961036
.at(loc));
9971037
}
9981038

999-
if !(self.validator)(rows, &expected_results) {
1039+
let actual_results = match self.result_mode {
1040+
Some(ResultMode::ValueWise) => rows
1041+
.iter()
1042+
.flat_map(|strs| strs.iter())
1043+
.map(|str| vec![str.to_string()])
1044+
.collect_vec(),
1045+
// default to rowwise
1046+
_ => rows.clone(),
1047+
};
1048+
1049+
if !(self.validator)(self.normalizer, &actual_results, &expected_results) {
10001050
let output_rows =
10011051
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
10021052
return Err(TestErrorKind::QueryResultMismatch {
@@ -1167,9 +1217,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
11671217
conn_builder(target.clone(), db_name.clone()).map(Ok)
11681218
}),
11691219
validator: self.validator,
1220+
normalizer: self.normalizer,
11701221
column_type_validator: self.column_type_validator,
11711222
substitution: self.substitution.clone(),
11721223
sort_mode: self.sort_mode,
1224+
result_mode: self.result_mode,
11731225
hash_threshold: self.hash_threshold,
11741226
labels: self.labels.clone(),
11751227
};
@@ -1240,6 +1292,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
12401292
filename: impl AsRef<Path>,
12411293
col_separator: &str,
12421294
validator: Validator,
1295+
normalizer: Normalizer,
12431296
column_type_validator: ColumnTypeValidator<D::ColumnType>,
12441297
) -> Result<(), Box<dyn std::error::Error>> {
12451298
use std::io::{Read, Seek, SeekFrom, Write};
@@ -1355,6 +1408,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
13551408
&record_output,
13561409
col_separator,
13571410
validator,
1411+
normalizer,
13581412
column_type_validator,
13591413
)
13601414
.unwrap_or(record);
@@ -1384,6 +1438,7 @@ pub fn update_record_with_output<T: ColumnType>(
13841438
record_output: &RecordOutput<T>,
13851439
col_separator: &str,
13861440
validator: Validator,
1441+
normalizer: Normalizer,
13871442
column_type_validator: ColumnTypeValidator<T>,
13881443
) -> Option<Record<T>> {
13891444
match (record.clone(), record_output) {
@@ -1523,7 +1578,7 @@ pub fn update_record_with_output<T: ColumnType>(
15231578
QueryExpect::Results {
15241579
results: expected_results,
15251580
..
1526-
} if validator(rows, expected_results) => expected_results.clone(),
1581+
} if validator(normalizer, rows, expected_results) => expected_results.clone(),
15271582
_ => rows.iter().map(|cols| cols.join(col_separator)).collect(),
15281583
};
15291584
let types = match &expected {
@@ -1541,17 +1596,22 @@ pub fn update_record_with_output<T: ColumnType>(
15411596
connection,
15421597
expected: match expected {
15431598
QueryExpect::Results {
1544-
sort_mode, label, ..
1599+
sort_mode,
1600+
label,
1601+
result_mode,
1602+
..
15451603
} => QueryExpect::Results {
15461604
results,
15471605
types,
15481606
sort_mode,
1607+
result_mode,
15491608
label,
15501609
},
15511610
QueryExpect::Error(_) => QueryExpect::Results {
15521611
results,
15531612
types,
15541613
sort_mode: None,
1614+
result_mode: None,
15551615
label: None,
15561616
},
15571617
},
@@ -2009,6 +2069,7 @@ Caused by:
20092069
&record_output,
20102070
" ",
20112071
default_validator,
2072+
default_normalizer,
20122073
strict_column_validator,
20132074
);
20142075

tests/custom_type/custom_type.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,8 @@ fn test() {
6969
let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
7070
tester.with_column_validator(strict_column_validator);
7171

72-
tester.run_file("./custom_type/custom_type.slt").unwrap();
72+
let r = tester.run_file("./custom_type/custom_type.slt");
73+
if let Err(err) = r {
74+
eprintln!("{:?}", err);
75+
}
7376
}

0 commit comments

Comments
 (0)