Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions parser/src/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ pub fn is_column_nullable(
..
} = select.body.select
{
if is_table_name_in_left_or_right_join(table_name, &from) {
return None;
}

// If a column is declared as not null we only need to prove that the table it came from is always present
if notnull {
return get_used_table_name(table_name, &from).map(|_| NullableResult::NotNull);
Expand Down Expand Up @@ -61,6 +65,25 @@ pub fn is_column_nullable(
None
}

fn is_table_name_in_left_or_right_join(table_name: &str, from: &ast::FromClause) -> bool {
if let Some(joins) = &from.joins {
for join in joins {
if let ast::JoinedSelectTable {
operator: ast::JoinOperator::TypedJoin(Some(join_type)),
table: ast::SelectTable::Table(name, _, _),
..
} = join && (*join_type == (ast::JoinType::LEFT | ast::JoinType::OUTER)
|| *join_type == (ast::JoinType::RIGHT | ast::JoinType::OUTER))
&& compare_identifier(&name.name.0, table_name)
{
return true;
}
}
}

false
}

fn get_used_table_name<'a>(table_name: &str, from: &'a ast::FromClause) -> Option<&'a str> {
if let Some(table) = &from.select {
match table.as_ref() {
Expand Down Expand Up @@ -532,5 +555,25 @@ mod tests {
is_column_nullable("id", "bar", true, "select * from foo left join bar"),
None
);
assert_eq!(
is_column_nullable("id", "foo", true, "select * from foo left join bar"),
Some(NullableResult::NotNull)
);
assert_eq!(
is_column_nullable("id", "foo", true, "select * from foo left join foo"),
None
);
assert_eq!(
is_column_nullable("id", "bar", true, "select * from foo right join bar"),
None
);
assert_eq!(
is_column_nullable("id", "foo", true, "select * from foo right join bar"),
Some(NullableResult::NotNull)
);
assert_eq!(
is_column_nullable("id", "foo", true, "select * from foo right join foo"),
None
);
}
}