diff --git a/parser/src/nullable.rs b/parser/src/nullable.rs index 84edb00..1575da1 100644 --- a/parser/src/nullable.rs +++ b/parser/src/nullable.rs @@ -27,6 +27,10 @@ pub fn is_column_nullable( .. } = select.body.select { + if is_table_name_in_left_or_right_join(table_name, &from) { + return None; + } + // If a column is declared as not null we only need to prove that the table it came from is always present if notnull { return get_used_table_name(table_name, &from).map(|_| NullableResult::NotNull); @@ -61,6 +65,25 @@ pub fn is_column_nullable( None } +fn is_table_name_in_left_or_right_join(table_name: &str, from: &ast::FromClause) -> bool { + if let Some(joins) = &from.joins { + for join in joins { + if let ast::JoinedSelectTable { + operator: ast::JoinOperator::TypedJoin(Some(join_type)), + table: ast::SelectTable::Table(name, _, _), + .. + } = join && (*join_type == (ast::JoinType::LEFT | ast::JoinType::OUTER) + || *join_type == (ast::JoinType::RIGHT | ast::JoinType::OUTER)) + && compare_identifier(&name.name.0, table_name) + { + return true; + } + } + } + + false +} + fn get_used_table_name<'a>(table_name: &str, from: &'a ast::FromClause) -> Option<&'a str> { if let Some(table) = &from.select { match table.as_ref() { @@ -532,5 +555,25 @@ mod tests { is_column_nullable("id", "bar", true, "select * from foo left join bar"), None ); + assert_eq!( + is_column_nullable("id", "foo", true, "select * from foo left join bar"), + Some(NullableResult::NotNull) + ); + assert_eq!( + is_column_nullable("id", "foo", true, "select * from foo left join foo"), + None + ); + assert_eq!( + is_column_nullable("id", "bar", true, "select * from foo right join bar"), + None + ); + assert_eq!( + is_column_nullable("id", "foo", true, "select * from foo right join bar"), + Some(NullableResult::NotNull) + ); + assert_eq!( + is_column_nullable("id", "foo", true, "select * from foo right join foo"), + None + ); } }