Skip to content

Commit e36862c

Browse files
committed
Fix on top of Copilot's work.
1 parent 3de19b9 commit e36862c

File tree

4 files changed

+97
-68
lines changed

4 files changed

+97
-68
lines changed

src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs

Lines changed: 30 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
2424
{
2525
var visitedExpression = base.VisitBinary(binaryExpression);
2626

27-
return TryOptimizeQueryableNullCheck(visitedExpression)
28-
?? TryOptimizeConditionalEquality(visitedExpression)
27+
return TryOptimizeConditionalEquality(visitedExpression)
28+
?? TryOptimizeQueryableNullCheck(visitedExpression)
2929
?? visitedExpression;
3030
}
3131

@@ -77,34 +77,6 @@ protected override Expression VisitConditional(ConditionalExpression conditional
7777
return base.VisitConditional(conditionalExpression);
7878
}
7979

80-
private static Expression? TryOptimizeQueryableNullCheck(Expression expression)
81-
{
82-
// Optimize IQueryable/DbSet null checks
83-
// IQueryable != null => true
84-
// IQueryable == null => false
85-
if (expression is BinaryExpression
86-
{
87-
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
88-
} binaryExpression)
89-
{
90-
var isLeftNull = IsNullConstant(binaryExpression.Left);
91-
var isRightNull = IsNullConstant(binaryExpression.Right);
92-
93-
if (isLeftNull != isRightNull)
94-
{
95-
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
96-
97-
if (IsQueryableType(nonNullExpression.Type))
98-
{
99-
var result = binaryExpression.NodeType == ExpressionType.NotEqual;
100-
return Expression.Constant(result, typeof(bool));
101-
}
102-
}
103-
}
104-
105-
return null;
106-
}
107-
10880
private static Expression? TryOptimizeConditionalEquality(Expression expression)
10981
{
11082
// Simplify (a ? b : null) == null => !a || b == null
@@ -145,6 +117,34 @@ protected override Expression VisitConditional(ConditionalExpression conditional
145117
return null;
146118
}
147119

120+
private static Expression? TryOptimizeQueryableNullCheck(Expression expression)
121+
{
122+
// Optimize IQueryable/DbSet null checks:
123+
// * IQueryable != null => true
124+
// * IQueryable == null => false
125+
if (expression is BinaryExpression
126+
{
127+
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
128+
} binaryExpression)
129+
{
130+
var isLeftNull = IsNullConstant(binaryExpression.Left);
131+
var isRightNull = IsNullConstant(binaryExpression.Right);
132+
133+
if (isLeftNull != isRightNull)
134+
{
135+
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
136+
137+
if (nonNullExpression.Type.IsAssignableTo(typeof(IQueryable)))
138+
{
139+
var result = binaryExpression.NodeType == ExpressionType.NotEqual;
140+
return Expression.Constant(result);
141+
}
142+
}
143+
}
144+
145+
return null;
146+
}
147+
148148
private sealed class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor
149149
{
150150
private readonly ISet<Expression> _nullSafeAccesses = new HashSet<Expression>(ExpressionEqualityComparer.Instance);
@@ -191,17 +191,4 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
191191

192192
private static bool IsNullConstant(Expression expression)
193193
=> expression is ConstantExpression { Value: null };
194-
195-
private static bool IsQueryableType(Type type)
196-
{
197-
if (type.IsGenericType)
198-
{
199-
var genericTypeDefinition = type.GetGenericTypeDefinition();
200-
return genericTypeDefinition == typeof(IQueryable<>)
201-
|| genericTypeDefinition == typeof(IOrderedQueryable<>)
202-
|| genericTypeDefinition == typeof(DbSet<>);
203-
}
204-
205-
return type == typeof(IQueryable);
206-
}
207194
}

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,22 @@ public override async Task Where_Queryable_AsEnumerable_Contains_negated(bool as
16991699
AssertSql();
17001700
}
17011701

1702+
public override async Task Where_Queryable_not_null_check_with_Contains(bool async)
1703+
{
1704+
// Cosmos client evaluation. Issue #17246.
1705+
await AssertTranslationFailed(() => base.Where_Queryable_not_null_check_with_Contains(async));
1706+
1707+
AssertSql();
1708+
}
1709+
1710+
public override async Task Where_Queryable_null_check_with_Contains(bool async)
1711+
{
1712+
// Cosmos client evaluation. Issue #17246.
1713+
await AssertTranslationFailed(() => base.Where_Queryable_null_check_with_Contains(async));
1714+
1715+
AssertSql();
1716+
}
1717+
17021718
public override Task Where_list_object_contains_over_value_type(bool async)
17031719
=> Fixture.NoSyncTest(
17041720
async, async a =>

test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,31 +1418,6 @@ public virtual Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
14181418
elementSorter: e => e.CustomerID,
14191419
elementAsserter: (e, a) => AssertCollection(e.Subquery, a.Subquery));
14201420

1421-
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1422-
public virtual Task Where_Queryable_null_check_with_Contains(bool async)
1423-
{
1424-
return AssertQuery(
1425-
async,
1426-
ss =>
1427-
{
1428-
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1429-
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1430-
});
1431-
}
1432-
1433-
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1434-
public virtual Task Where_Queryable_null_check_equal(bool async)
1435-
{
1436-
return AssertQuery(
1437-
async,
1438-
ss =>
1439-
{
1440-
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1441-
return ss.Set<Customer>().Where(c => ids == null || !ids.Contains(c.CustomerID));
1442-
},
1443-
assertEmpty: true);
1444-
}
1445-
14461421
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
14471422
public virtual Task Where_Queryable_ToList_Count_member(bool async)
14481423
=> AssertQuery(
@@ -1463,6 +1438,27 @@ public virtual Task Where_Queryable_ToArray_Length_member(bool async)
14631438
assertOrder: true,
14641439
elementAsserter: (e, a) => AssertCollection(e, a));
14651440

1441+
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1442+
public virtual Task Where_Queryable_not_null_check_with_Contains(bool async)
1443+
=> AssertQuery(
1444+
async,
1445+
ss =>
1446+
{
1447+
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1448+
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1449+
});
1450+
1451+
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
1452+
public virtual Task Where_Queryable_null_check_with_Contains(bool async)
1453+
=> AssertQuery(
1454+
async,
1455+
ss =>
1456+
{
1457+
var ids = ss.Set<Customer>().Select(c => c.CustomerID);
1458+
return ss.Set<Customer>().Where(c => ids == null || !ids.Contains(c.CustomerID));
1459+
},
1460+
assertEmpty: true);
1461+
14661462
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
14671463
public virtual Task Where_collection_navigation_ToList_Count(bool async)
14681464
=> AssertQuery(

test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,6 +1890,36 @@ ORDER BY [c].[CustomerID]
18901890
""");
18911891
}
18921892

1893+
public override async Task Where_Queryable_not_null_check_with_Contains(bool async)
1894+
{
1895+
await base.Where_Queryable_not_null_check_with_Contains(async);
1896+
1897+
AssertSql(
1898+
"""
1899+
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
1900+
FROM [Customers] AS [c]
1901+
WHERE [c].[CustomerID] IN (
1902+
SELECT [c0].[CustomerID]
1903+
FROM [Customers] AS [c0]
1904+
)
1905+
""");
1906+
}
1907+
1908+
public override async Task Where_Queryable_null_check_with_Contains(bool async)
1909+
{
1910+
await base.Where_Queryable_null_check_with_Contains(async);
1911+
1912+
AssertSql(
1913+
"""
1914+
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
1915+
FROM [Customers] AS [c]
1916+
WHERE [c].[CustomerID] NOT IN (
1917+
SELECT [c0].[CustomerID]
1918+
FROM [Customers] AS [c0]
1919+
)
1920+
""");
1921+
}
1922+
18931923
public override async Task Where_collection_navigation_ToList_Count(bool async)
18941924
{
18951925
await base.Where_collection_navigation_ToList_Count(async);

0 commit comments

Comments
 (0)