Skip to content

Commit dae8232

Browse files
committed
Fix on top of Copilot's work.
1 parent 3de19b9 commit dae8232

File tree

3 files changed

+81
-43
lines changed

3 files changed

+81
-43
lines changed

src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs

Lines changed: 30 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
2424
{
2525
var visitedExpression = base.VisitBinary(binaryExpression);
2626

27-
return TryOptimizeQueryableNullCheck(visitedExpression)
28-
?? TryOptimizeConditionalEquality(visitedExpression)
27+
return TryOptimizeConditionalEquality(visitedExpression)
28+
?? TryOptimizeQueryableNullCheck(visitedExpression)
2929
?? visitedExpression;
3030
}
3131

@@ -77,34 +77,6 @@ protected override Expression VisitConditional(ConditionalExpression conditional
7777
return base.VisitConditional(conditionalExpression);
7878
}
7979

80-
private static Expression? TryOptimizeQueryableNullCheck(Expression expression)
81-
{
82-
// Optimize IQueryable/DbSet null checks
83-
// IQueryable != null => true
84-
// IQueryable == null => false
85-
if (expression is BinaryExpression
86-
{
87-
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
88-
} binaryExpression)
89-
{
90-
var isLeftNull = IsNullConstant(binaryExpression.Left);
91-
var isRightNull = IsNullConstant(binaryExpression.Right);
92-
93-
if (isLeftNull != isRightNull)
94-
{
95-
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
96-
97-
if (IsQueryableType(nonNullExpression.Type))
98-
{
99-
var result = binaryExpression.NodeType == ExpressionType.NotEqual;
100-
return Expression.Constant(result, typeof(bool));
101-
}
102-
}
103-
}
104-
105-
return null;
106-
}
107-
10880
private static Expression? TryOptimizeConditionalEquality(Expression expression)
10981
{
11082
// Simplify (a ? b : null) == null => !a || b == null
@@ -145,6 +117,34 @@ protected override Expression VisitConditional(ConditionalExpression conditional
145117
return null;
146118
}
147119

120+
private static Expression? TryOptimizeQueryableNullCheck(Expression expression)
121+
{
122+
// Optimize IQueryable/DbSet null checks:
123+
// * IQueryable != null => true
124+
// * IQueryable == null => false
125+
if (expression is BinaryExpression
126+
{
127+
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
128+
} binaryExpression)
129+
{
130+
var isLeftNull = IsNullConstant(binaryExpression.Left);
131+
var isRightNull = IsNullConstant(binaryExpression.Right);
132+
133+
if (isLeftNull != isRightNull)
134+
{
135+
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
136+
137+
if (nonNullExpression.Type.IsAssignableTo(typeof(IQueryable)))
138+
{
139+
var result = binaryExpression.NodeType == ExpressionType.NotEqual;
140+
return Expression.Constant(result);
141+
}
142+
}
143+
}
144+
145+
return null;
146+
}
147+
148148
private sealed class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor
149149
{
150150
private readonly ISet<Expression> _nullSafeAccesses = new HashSet<Expression>(ExpressionEqualityComparer.Instance);
@@ -191,17 +191,4 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
191191

192192
private static bool IsNullConstant(Expression expression)
193193
=> expression is ConstantExpression { Value: null };
194-
195-
private static bool IsQueryableType(Type type)
196-
{
197-
if (type.IsGenericType)
198-
{
199-
var genericTypeDefinition = type.GetGenericTypeDefinition();
200-
return genericTypeDefinition == typeof(IQueryable<>)
201-
|| genericTypeDefinition == typeof(IOrderedQueryable<>)
202-
|| genericTypeDefinition == typeof(DbSet<>);
203-
}
204-
205-
return type == typeof(IQueryable);
206-
}
207194
}

test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,27 @@ public virtual Task Where_Queryable_ToArray_Length_member(bool async)
14631463
assertOrder: true,
14641464
elementAsserter: (e, a) => AssertCollection(e, a));
14651465

1466+
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1467+
public virtual Task Where_Queryable_not_null_check_with_Contains(bool async)
1468+
=> AssertQuery(
1469+
async,
1470+
ss =>
1471+
{
1472+
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1473+
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1474+
});
1475+
1476+
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1477+
public virtual Task Where_Queryable_null_check_with_Contains(bool async)
1478+
=> AssertQuery(
1479+
async,
1480+
ss =>
1481+
{
1482+
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1483+
return ss.Set<Customer>().Where(c => ids == null || !ids.Contains(c.CustomerID));
1484+
},
1485+
assertEmpty: true);
1486+
14661487
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
14671488
public virtual Task Where_collection_navigation_ToList_Count(bool async)
14681489
=> AssertQuery(

test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,36 @@ ORDER BY [c].[CustomerID]
18901890
""");
18911891
}
18921892

1893+
public override async Task Where_Queryable_not_null_check_with_Contains(bool async)
1894+
{
1895+
await base.Where_Queryable_not_null_check_with_Contains(async);
1896+
1897+
AssertSql(
1898+
"""
1899+
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
1900+
FROM [Customers] AS [c]
1901+
WHERE [c].[CustomerID] IN (
1902+
SELECT [c0].[CustomerID]
1903+
FROM [Customers] AS [c0]
1904+
)
1905+
""");
1906+
}
1907+
1908+
public override async Task Where_Queryable_null_check_with_Contains(bool async)
1909+
{
1910+
await base.Where_Queryable_null_check_with_Contains(async);
1911+
1912+
AssertSql(
1913+
"""
1914+
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
1915+
FROM [Customers] AS [c]
1916+
WHERE [c].[CustomerID] NOT IN (
1917+
SELECT [c0].[CustomerID]
1918+
FROM [Customers] AS [c0]
1919+
)
1920+
""");
1921+
}
1922+
18931923
public override async Task Where_collection_navigation_ToList_Count(bool async)
18941924
{
18951925
await base.Where_collection_navigation_ToList_Count(async);

0 commit comments

Comments
 (0)