Skip to content

Commit 7eb2686

Browse files
authored
fix: False-positive MAP_SUBSET performance warnings (#26304)
Summary: The MAP_SUBSET performance warning was triggering incorrectly in many cases:

1. Warning on transform_values (which transforms values, not filters keys)
2. Warning on map_filter when lambda uses values (e.g., v > 2)
3. Warning on map_filter with non-membership key comparisons (e.g., k > 0)
4. Warning on all map columns, even when not related to features

This diff tightens the detection logic to only warn when ALL conditions are met:

- Function is map_filter (not transform_values or other map functions)
- Lambda does NOT reference the value argument
- Lambda uses simple key membership tests (k = c, k IN (...), OR combinations, CONTAINS)
- Column name contains "features" (the main motivation for this rule)

This eliminates false positives while focusing on the intended use case of optimizing feature map filtering operations.

Implementation details:

- Added isKeyOnlyMembershipFilter() to validate lambda only uses keys
- Added expressionReferencesName() to detect value argument usage
- Added isSimpleKeyEquality() to validate membership-only comparisons
- Added containsFeatures() to limit warnings to feature-related columns
- Updated test cases to reflect correct warning behavior

Differential Revision: D84627028

```
== NO RELEASE NOTE ==
```
1 parent 3705024 commit 7eb2686

File tree

3 files changed

+115
-33
lines changed

3 files changed

+115
-33
lines changed

presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java

Lines changed: 103 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,24 +1132,15 @@ else if (frame.getType() == GROUPS) {
11321132
List<TypeSignature> arguments = functionMetadata.getArgumentTypes();
11331133
String functionName = functionMetadata.getName().toString();
11341134

1135-
if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase())) {
1136-
if (arguments.size() > 1) {
1137-
arguments.stream()
1138-
.skip(1)
1139-
.filter(arg -> {
1140-
String base = arg.getBase();
1141-
return "function".equals(base) || "lambda".equals(base);
1142-
})
1143-
.findFirst()
1144-
.ifPresent(arg -> {
1145-
String warningMessage = createWarningMessage(node,
1146-
String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName));
1147-
warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage));
1148-
});
1149-
}
1150-
else if (arguments.size() == 1) {
1151-
String base = arguments.get(0).getBase();
1152-
if ("function".equals(base) || "lambda".equals(base)) {
1135+
if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase()) &&
1136+
"map_filter".equalsIgnoreCase(functionMetadata.getName().getObjectName()) &&
1137+
arguments.size() > 1 && node.getArguments().size() >= 2) {
1138+
Expression mapArg = node.getArguments().get(0);
1139+
Expression lambdaArg = node.getArguments().get(1);
1140+
1141+
if (containsFeatures(mapArg) && lambdaArg instanceof LambdaExpression) {
1142+
LambdaExpression lambda = (LambdaExpression) lambdaArg;
1143+
if (lambda.getArguments().size() == 2 && isKeyOnlyMembershipFilter(lambda)) {
11531144
String warningMessage = createWarningMessage(node,
11541145
String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName));
11551146
warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage));
@@ -1216,6 +1207,100 @@ private String createWarningMessage(Node node, String message)
12161207
}
12171208
}
12181209

1210+
private boolean isKeyOnlyMembershipFilter(LambdaExpression lambda)
1211+
{
1212+
String valueArgName = lambda.getArguments().get(1).getName().getValue();
1213+
Expression body = lambda.getBody();
1214+
1215+
if (expressionReferencesName(body, valueArgName)) {
1216+
return false;
1217+
}
1218+
1219+
return isSimpleKeyEquality(body);
1220+
}
1221+
1222+
private boolean expressionReferencesName(Expression expression, String name)
1223+
{
1224+
if (expression == null) {
1225+
return false;
1226+
}
1227+
if (expression instanceof Identifier) {
1228+
return ((Identifier) expression).getValue().equalsIgnoreCase(name);
1229+
}
1230+
if (expression instanceof ComparisonExpression) {
1231+
ComparisonExpression comp = (ComparisonExpression) expression;
1232+
return expressionReferencesName(comp.getLeft(), name) || expressionReferencesName(comp.getRight(), name);
1233+
}
1234+
if (expression instanceof LogicalBinaryExpression) {
1235+
LogicalBinaryExpression logical = (LogicalBinaryExpression) expression;
1236+
return expressionReferencesName(logical.getLeft(), name) || expressionReferencesName(logical.getRight(), name);
1237+
}
1238+
if (expression instanceof InPredicate) {
1239+
InPredicate inPred = (InPredicate) expression;
1240+
return expressionReferencesName(inPred.getValue(), name) || expressionReferencesName(inPred.getValueList(), name);
1241+
}
1242+
if (expression instanceof InListExpression) {
1243+
InListExpression inList = (InListExpression) expression;
1244+
for (Expression value : inList.getValues()) {
1245+
if (expressionReferencesName(value, name)) {
1246+
return true;
1247+
}
1248+
}
1249+
}
1250+
if (expression instanceof ArithmeticBinaryExpression) {
1251+
ArithmeticBinaryExpression arith = (ArithmeticBinaryExpression) expression;
1252+
return expressionReferencesName(arith.getLeft(), name) || expressionReferencesName(arith.getRight(), name);
1253+
}
1254+
if (expression instanceof FunctionCall) {
1255+
FunctionCall func = (FunctionCall) expression;
1256+
for (Expression arg : func.getArguments()) {
1257+
if (expressionReferencesName(arg, name)) {
1258+
return true;
1259+
}
1260+
}
1261+
}
1262+
// Literals don't reference any names
1263+
return false;
1264+
}
1265+
1266+
private boolean containsFeatures(Expression expression)
1267+
{
1268+
if (expression instanceof Identifier) {
1269+
return ((Identifier) expression).getValue().toLowerCase().contains("features");
1270+
}
1271+
if (expression instanceof SymbolReference) {
1272+
return ((SymbolReference) expression).getName().toLowerCase().contains("features");
1273+
}
1274+
if (expression instanceof DereferenceExpression) {
1275+
DereferenceExpression deref = (DereferenceExpression) expression;
1276+
return containsFeatures(deref.getBase()) || deref.getField().getValue().toLowerCase().contains("features");
1277+
}
1278+
return false;
1279+
}
1280+
1281+
private boolean isSimpleKeyEquality(Expression expression)
1282+
{
1283+
if (expression instanceof ComparisonExpression) {
1284+
ComparisonExpression comparison = (ComparisonExpression) expression;
1285+
return comparison.getOperator() == ComparisonExpression.Operator.EQUAL;
1286+
}
1287+
if (expression instanceof InPredicate) {
1288+
return true;
1289+
}
1290+
if (expression instanceof LogicalBinaryExpression) {
1291+
LogicalBinaryExpression logical = (LogicalBinaryExpression) expression;
1292+
if (logical.getOperator() == LogicalBinaryExpression.Operator.OR) {
1293+
return isSimpleKeyEquality(logical.getLeft()) && isSimpleKeyEquality(logical.getRight());
1294+
}
1295+
}
1296+
if (expression instanceof FunctionCall) {
1297+
FunctionCall func = (FunctionCall) expression;
1298+
String funcName = func.getName().toString();
1299+
return funcName.equalsIgnoreCase("contains") || funcName.equalsIgnoreCase("presto.default.contains");
1300+
}
1301+
return false;
1302+
}
1303+
12191304
private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, StackableAstVisitorContext<Context> context, Window window)
12201305
{
12211306
if (!window.getOrderBy().isPresent()) {

presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -168,30 +168,27 @@ void testNoORWarning()
168168
@Test
169169
public void testMapFilterWarnings()
170170
{
171-
assertHasWarning(
172-
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(x)"),
173-
PERFORMANCE_WARNING,
174-
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
171+
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)"));
175172

176173
assertHasWarning(
177-
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
174+
analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"),
178175
PERFORMANCE_WARNING,
179176
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
180177

181178
assertHasWarning(
182-
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
179+
analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"),
183180
PERFORMANCE_WARNING,
184181
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
185182

186-
assertHasWarning(
187-
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
188-
PERFORMANCE_WARNING,
189-
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
183+
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));
190184

191-
assertHasWarning(
192-
analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"),
193-
PERFORMANCE_WARNING,
194-
"Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset");
185+
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));
186+
187+
assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k > 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"));
188+
189+
assertNoWarning(analyzeWithWarnings("SELECT transform_values(user_features, (k, v) -> v * 2) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)"));
190+
191+
assertNoWarning(analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)"));
195192
}
196193

197194
@Test

presto-tests/src/test/java/com/facebook/presto/execution/TestWarnings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ public void testMapWithDoubleKeysProducesWarnings()
180180
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
181181

182182
query = "select transform_keys(map(ARRAY [25.5E0, 26.5E0, 27.5E0], ARRAY [25.5E0, 26.5E0, 27.5E0]), (k, v) -> k + v)";
183-
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode(), PERFORMANCE_WARNING.toWarningCode()));
183+
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
184184

185185
query = "SELECT histogram(RETAILPRICE) FROM tpch.tiny.part";
186186
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));

0 commit comments

Comments
 (0)