diff --git a/.gitignore b/.gitignore index 329348a7c12..192594e0ba4 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,8 @@ http-client.env.json # Coding agent files (could be symlinks) .claude .clinerules -memory-bank \ No newline at end of file +memory-bank + +# uv environment config (helps with IDE plugins depending on Python) +pyproject.toml +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index efa9c779423..1471a22efd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,11 +3,6 @@ repos: hooks: - id: spotless-format name: Spotless Format - entry: bash -c './gradlew spotlessApply && git add -u' - language: system - pass_filenames: false - - id: spotless-check - name: Spotless Post-format Check - entry: bash -c './gradlew spotlessCheck' + entry: bash -c './gradlew spotlessApply spotlessCheck && git add -u' language: system pass_filenames: false diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index d5c37d405ec..e73f38ede79 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -97,6 +97,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprMissingValue; import org.opensearch.sql.data.type.ExprCoreType; @@ -378,10 +379,10 @@ public LogicalPlan visitRareTopN(RareTopN node, AnalysisContext context) { fields.forEach( field -> newEnv.define(new Symbol(Namespace.FIELD_NAME, field.toString()), field.type())); - List options = node.getArguments(); - Integer noOfResults = (Integer) options.get(0).getValue().getValue(); + RareTopNArguments options = node.getArguments(); - return new LogicalRareTopN(child, node.getCommandType(), noOfResults, fields, groupBys); + return new LogicalRareTopN( + child, node.getCommandType(), options.getNoOfResults(), fields, groupBys); } /** diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index c8600bb9809..becbfa7ef3c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -80,6 +80,7 @@ import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; /** Class of static methods to create specific node instances. */ @UtilityClass @@ -494,8 +495,8 @@ public static Head head(UnresolvedPlan input, Integer size, Integer from) { return new Head(input, size, from); } - public static List defaultTopArgs() { - return exprList(argument("noOfResults", intLiteral(10))); + public static List defaultTopRareArgs() { + return new RareTopNArguments().asExprList(); } public static RareTopN rareTopN( diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java b/core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java index 3fd3aa3a2c0..cb87a550cbb 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/RareTopN.java @@ -17,6 +17,7 @@ import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; /** AST node represent RareTopN operation. */ @Getter @@ -29,11 +30,16 @@ public class RareTopN extends UnresolvedPlan { private UnresolvedPlan child; private final CommandType commandType; - // arguments: noOfResults: Integer, countField: String, showCount: Boolean + // arguments: noOfResults: Integer, countField: String, showCount: Boolean, percentField: String, + // showPerc: Boolean, useOther: Boolean private final List arguments; private final List fields; private final List groupExprList; + public RareTopNArguments getArguments() { + return new RareTopNArguments(arguments); + } + @Override public RareTopN attach(UnresolvedPlan child) { this.child = child; diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/args/RareTopNArguments.java b/core/src/main/java/org/opensearch/sql/ast/tree/args/RareTopNArguments.java new file mode 100644 index 00000000000..df9d2a96ecb --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/args/RareTopNArguments.java @@ -0,0 +1,116 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree.args; + +import static org.opensearch.sql.ast.dsl.AstDSL.argument; +import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.exprList; +import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; + +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; +import org.jetbrains.annotations.TestOnly; +import org.opensearch.sql.ast.expression.Argument; + +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@AllArgsConstructor +public class RareTopNArguments { + public static final String NUMBER_RESULTS = "noOfResults"; + public static final String COUNT_FIELD = "countField"; + public static final String SHOW_COUNT = "showCount"; + public static final String PERCENT_FIELD = "percentField"; + public static final String SHOW_PERCENT = "showPerc"; + public static final String USE_OTHER = "useOther"; + + private int noOfResults = 10; + private String countField = "count"; + private boolean showCount = true; + private String percentField = "percent"; + private boolean showPerc = false; + private boolean useOther = false; + + public RareTopNArguments(List arguments) { + // handle `percent=whatever showperc=false` (though I'm not sure if it's ever useful to do so) + boolean isShowPercOverridden = false; + + for (Argument arg : arguments) { + switch (arg.getArgName()) { + case NUMBER_RESULTS: + noOfResults = (int) arg.getValue().getValue(); + if (noOfResults < 0) { + throw new IllegalArgumentException( + "Illegal number of results requested for top/rare: must be non-negative"); + } + break; + case COUNT_FIELD: + countField = (String) arg.getValue().getValue(); + if (countField.isBlank()) { + throw new IllegalArgumentException("Illegal count field in top/rare: cannot be blank"); + } + break; + case SHOW_COUNT: + showCount = (boolean) arg.getValue().getValue(); + break; + case PERCENT_FIELD: + percentField = (String) arg.getValue().getValue(); + if (percentField.isBlank()) { + throw new IllegalArgumentException( + "Illegal percent field in top/rare: cannot be blank"); + } + if (!isShowPercOverridden) { + showPerc = true; + } + break; + case SHOW_PERCENT: + showPerc = (boolean) arg.getValue().getValue(); + isShowPercOverridden = true; + break; + case USE_OTHER: + useOther = (boolean) arg.getValue().getValue(); + break; + default: + throw new IllegalArgumentException("unknown argument for rare/top: " + arg.getArgName()); + } + } + } + + public String renderOptions() { + StringBuilder options = new StringBuilder(); + if (showCount) { + options.append("countfield='").append(countField).append("' "); + } else { + options.append("showcount=false "); + } + if (showPerc) { + options.append("percfield='").append(percentField).append("' "); + } else { + options.append("showperc=false "); + } + if (useOther) { + options.append("useother=true "); + } + return options.toString(); + } + + @TestOnly + public List asExprList() { + return exprList( + argument("noOfResults", intLiteral(10)), + argument("countField", stringLiteral("count")), + argument("showCount", booleanLiteral(true)), + argument("percentField", stringLiteral("percent")), + argument("showPerc", booleanLiteral(false)), + argument("useOther", booleanLiteral(false))); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index cb01493fcda..9f98e283043 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -77,7 +77,6 @@ import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Argument.ArgumentMap; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.Let; @@ -131,6 +130,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; import org.opensearch.sql.calcite.plan.OpenSearchConstants; import org.opensearch.sql.calcite.utils.BinUtils; import org.opensearch.sql.calcite.utils.JoinAndLookupUtils; @@ -1729,13 +1729,19 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) { public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) { visitChildren(node, context); - ArgumentMap arguments = ArgumentMap.of(node.getArguments()); - String countFieldName = (String) arguments.get("countField").getValue(); + RareTopNArguments arguments = node.getArguments(); + String countFieldName = arguments.getCountField(); if (context.relBuilder.peek().getRowType().getFieldNames().contains(countFieldName)) { throw new IllegalArgumentException( - "Field `" + "The top/rare output field `" + countFieldName - + "` is existed, change the count field by setting countfield='xyz'"); + + "` already exists. Suggestion: change the count field by adding countfield='xyz'"); + } + if (arguments.isUseOther()) { + throw new CalciteUnsupportedException("`useother` is currently unsupported. (Coming soon)"); + } + if (arguments.isShowPerc()) { + throw new CalciteUnsupportedException("`showperc` is currently unsupported. (Coming soon)"); } // 1. group the group-by list + field list and add a count() aggregation @@ -1768,14 +1774,13 @@ public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) { context.relBuilder.alias(rowNumberWindowOver, ROW_NUMBER_COLUMN_NAME)); // 3. filter row_number() <= k in each partition - Integer N = (Integer) arguments.get("noOfResults").getValue(); context.relBuilder.filter( context.relBuilder.lessThanOrEqual( - context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), context.relBuilder.literal(N))); + context.relBuilder.field(ROW_NUMBER_COLUMN_NAME), + context.relBuilder.literal(arguments.getNoOfResults()))); // 4. project final output. the default output is group by list + field list - Boolean showCount = (Boolean) arguments.get("showCount").getValue(); - if (showCount) { + if (arguments.isShowCount()) { context.relBuilder.projectExcept(context.relBuilder.field(ROW_NUMBER_COLUMN_NAME)); } else { context.relBuilder.projectExcept( diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 48b051e456b..05ee5ff3f9d 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -126,6 +126,8 @@ ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; APPEND: 'APPEND'; COUNTFIELD: 'COUNTFIELD'; SHOWCOUNT: 'SHOWCOUNT'; +SHOWPERC: 'SHOWPERC'; +PERCENTFIELD: 'PERCENTFIELD'; LIMIT: 'LIMIT'; USEOTHER: 'USEOTHER'; INPUT: 'INPUT'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 0bc7b784338..9554c7d3f3d 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -302,11 +302,19 @@ logSpanValue ; topCommand - : TOP (number = integerLiteral)? (COUNTFIELD EQUAL countfield = stringLiteral)? (SHOWCOUNT EQUAL showcount = booleanLiteral)? fieldList (byClause)? + : TOP (number = integerLiteral)? topRareParameter* fieldList (byClause)? ; rareCommand - : RARE (number = integerLiteral)? (COUNTFIELD EQUAL countfield = stringLiteral)? (SHOWCOUNT EQUAL showcount = booleanLiteral)? fieldList (byClause)? + : RARE (number = integerLiteral)? topRareParameter* fieldList (byClause)? + ; + +topRareParameter + : (COUNTFIELD EQUAL countfield = stringLiteral) + | (SHOWCOUNT EQUAL showcount = booleanLiteral) + | (PERCENTFIELD EQUAL percentfield = stringLiteral) + | (SHOWPERC EQUAL showperc = booleanLiteral) + | (USEOTHER EQUAL useother = booleanLiteral) ; grokCommand @@ -1462,6 +1470,8 @@ searchableKeyWord | ANOMALY_SCORE_THRESHOLD | COUNTFIELD | SHOWCOUNT + | SHOWPERC + | PERCENTFIELD | PATH | INPUT | OUTPUT diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java index e1d892fdfce..1bf9c1cdf43 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/ArgumentFactory.java @@ -14,6 +14,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; @@ -127,6 +128,38 @@ public static List getArgumentList(SortFieldContext ctx) { : new Argument("type", new Literal(null, DataType.NULL))); } + /** + * Helper: dedupe list generation logic for top & rare. Helps ensure both commands consistently + * support the same args. + */ + private static List topRareArgList( + ParserRuleContext number, List params) { + List args = new ArrayList<>(6); + if (number != null) { + args.add(new Argument(RareTopNArguments.NUMBER_RESULTS, getArgumentValue(number))); + } + + for (OpenSearchPPLParser.TopRareParameterContext param : params) { + if (param.countfield != null) { + args.add(new Argument(RareTopNArguments.COUNT_FIELD, getArgumentValue(param.countfield))); + } + if (param.showcount != null) { + args.add(new Argument(RareTopNArguments.SHOW_COUNT, getArgumentValue(param.showcount))); + } + if (param.percentfield != null) { + args.add( + new Argument(RareTopNArguments.PERCENT_FIELD, getArgumentValue(param.percentfield))); + } + if (param.showperc != null) { + args.add(new Argument(RareTopNArguments.SHOW_PERCENT, getArgumentValue(param.showperc))); + } + if (param.useother != null) { + args.add(new Argument(RareTopNArguments.USE_OTHER, getArgumentValue(param.useother))); + } + } + return args; + } + /** * Get list of {@link Argument}. * @@ -134,16 +167,7 @@ public static List getArgumentList(SortFieldContext ctx) { * @return the list of arguments fetched from the top command */ public static List getArgumentList(TopCommandContext ctx) { - return Arrays.asList( - ctx.number != null - ? new Argument("noOfResults", getArgumentValue(ctx.number)) - : new Argument("noOfResults", new Literal(10, DataType.INTEGER)), - ctx.countfield != null - ? new Argument("countField", getArgumentValue(ctx.countfield)) - : new Argument("countField", new Literal("count", DataType.STRING)), - ctx.showcount != null - ? new Argument("showCount", getArgumentValue(ctx.showcount)) - : new Argument("showCount", new Literal(true, DataType.BOOLEAN))); + return topRareArgList(ctx.number, ctx.topRareParameter()); } /** @@ -153,16 +177,7 @@ public static List getArgumentList(TopCommandContext ctx) { * @return the list of argument with default number of results for the rare command */ public static List getArgumentList(RareCommandContext ctx) { - return Arrays.asList( - ctx.number != null - ? new Argument("noOfResults", getArgumentValue(ctx.number)) - : new Argument("noOfResults", new Literal(10, DataType.INTEGER)), - ctx.countfield != null - ? new Argument("countField", getArgumentValue(ctx.countfield)) - : new Argument("countField", new Literal("count", DataType.STRING)), - ctx.showcount != null - ? new Argument("showCount", getArgumentValue(ctx.showcount)) - : new Argument("showCount", new Literal(true, DataType.BOOLEAN))); + return topRareArgList(ctx.number, ctx.topRareParameter()); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index eeca282cb9d..084b8fac8cd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -26,7 +26,6 @@ import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta; import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Argument.ArgumentMap; import org.opensearch.sql.ast.expression.Between; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; @@ -89,6 +88,7 @@ import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; +import org.opensearch.sql.ast.tree.args.RareTopNArguments; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.planner.logical.LogicalAggregation; @@ -353,17 +353,13 @@ public String visitWindow(Window node, String context) { /** Build {@link LogicalRareTopN}. */ @Override public String visitRareTopN(RareTopN node, String context) { - final String child = node.getChild().get(0).accept(this, context); - ArgumentMap arguments = ArgumentMap.of(node.getArguments()); - Integer noOfResults = (Integer) arguments.get("noOfResults").getValue(); - String countField = (String) arguments.get("countField").getValue(); - Boolean showCount = (Boolean) arguments.get("showCount").getValue(); + final String child = node.getChild().getFirst().accept(this, context); + RareTopNArguments arguments = node.getArguments(); + Integer noOfResults = arguments.getNoOfResults(); String fields = visitFieldList(node.getFields()); String group = visitExpressionList(node.getGroupExprList()); - String options = - isCalciteEnabled(settings) - ? StringUtils.format("countield='%s' showcount=%s ", countField, showCount) - : ""; + String options = isCalciteEnabled(settings) ? arguments.renderOptions() : ""; + return StringUtils.format( "%s | %s %d %s%s", child, diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLRareTopNTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLRareTopNTest.java index 23dab511671..adb5de765c2 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLRareTopNTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLRareTopNTest.java @@ -24,31 +24,35 @@ public void testRare() { String ppl = "source=EMP | rare JOB"; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(JOB=[$0], count=[$1])\n" - + " LogicalFilter(condition=[<=($2, 10)])\n" - + " LogicalProject(JOB=[$0], count=[$1], _row_number_=[ROW_NUMBER() OVER (ORDER BY" - + " $1)])\n" - + " LogicalAggregate(group=[{0}], count=[COUNT()])\n" - + " LogicalProject(JOB=[$2])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + """ + LogicalProject(JOB=[$0], count=[$1]) + LogicalFilter(condition=[<=($2, 10)]) + LogicalProject(JOB=[$0], count=[$1], _row_number_=[ROW_NUMBER() OVER (ORDER BY\ + $1)]) + LogicalAggregate(group=[{0}], count=[COUNT()]) + LogicalProject(JOB=[$2]) + LogicalTableScan(table=[[scott, EMP]]) + """; verifyLogical(root, expectedLogical); String expectedResult = - "" - + "JOB=PRESIDENT; count=1\n" - + "JOB=ANALYST; count=2\n" - + "JOB=MANAGER; count=3\n" - + "JOB=SALESMAN; count=4\n" - + "JOB=CLERK; count=4\n"; + """ + JOB=PRESIDENT; count=1 + JOB=ANALYST; count=2 + JOB=MANAGER; count=3 + JOB=SALESMAN; count=4 + JOB=CLERK; count=4 + """; verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT `JOB`, `count`\n" - + "FROM (SELECT `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (ORDER BY COUNT(*) NULLS" - + " LAST) `_row_number_`\n" - + "FROM `scott`.`EMP`\n" - + "GROUP BY `JOB`) `t1`\n" - + "WHERE `_row_number_` <= 10"; + """ + SELECT `JOB`, `count` + FROM (SELECT `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (ORDER BY COUNT(*) NULLS\ + LAST) `_row_number_` + FROM `scott`.`EMP` + GROUP BY `JOB`) `t1` + WHERE `_row_number_` <= 10"""; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -57,35 +61,39 @@ public void testRareBy() { String ppl = "source=EMP | rare JOB by DEPTNO"; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2])\n" - + " LogicalFilter(condition=[<=($3, 10)])\n" - + " LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2], _row_number_=[ROW_NUMBER()" - + " OVER (PARTITION BY $0 ORDER BY $2)])\n" - + " LogicalAggregate(group=[{0, 1}], count=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7], JOB=[$2])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + """ + LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2]) + LogicalFilter(condition=[<=($3, 10)]) + LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2], _row_number_=[ROW_NUMBER()\ + OVER (PARTITION BY $0 ORDER BY $2)]) + LogicalAggregate(group=[{0, 1}], count=[COUNT()]) + LogicalProject(DEPTNO=[$7], JOB=[$2]) + LogicalTableScan(table=[[scott, EMP]]) + """; verifyLogical(root, expectedLogical); String expectedResult = - "" - + "DEPTNO=20; JOB=MANAGER; count=1\n" - + "DEPTNO=20; JOB=CLERK; count=2\n" - + "DEPTNO=20; JOB=ANALYST; count=2\n" - + "DEPTNO=10; JOB=MANAGER; count=1\n" - + "DEPTNO=10; JOB=CLERK; count=1\n" - + "DEPTNO=10; JOB=PRESIDENT; count=1\n" - + "DEPTNO=30; JOB=MANAGER; count=1\n" - + "DEPTNO=30; JOB=CLERK; count=1\n" - + "DEPTNO=30; JOB=SALESMAN; count=4\n"; + """ + DEPTNO=20; JOB=MANAGER; count=1 + DEPTNO=20; JOB=CLERK; count=2 + DEPTNO=20; JOB=ANALYST; count=2 + DEPTNO=10; JOB=MANAGER; count=1 + DEPTNO=10; JOB=CLERK; count=1 + DEPTNO=10; JOB=PRESIDENT; count=1 + DEPTNO=30; JOB=MANAGER; count=1 + DEPTNO=30; JOB=CLERK; count=1 + DEPTNO=30; JOB=SALESMAN; count=4 + """; verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT `DEPTNO`, `JOB`, `count`\n" - + "FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (PARTITION BY" - + " `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_`\n" - + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`, `JOB`) `t1`\n" - + "WHERE `_row_number_` <= 10"; + """ + SELECT `DEPTNO`, `JOB`, `count` + FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (PARTITION BY\ + `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_` + FROM `scott`.`EMP` + GROUP BY `DEPTNO`, `JOB`) `t1` + WHERE `_row_number_` <= 10"""; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -94,35 +102,39 @@ public void testRareDisableShowCount() { String ppl = "source=EMP | rare showcount=false JOB by DEPTNO"; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(DEPTNO=[$0], JOB=[$1])\n" - + " LogicalFilter(condition=[<=($3, 10)])\n" - + " LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2], _row_number_=[ROW_NUMBER()" - + " OVER (PARTITION BY $0 ORDER BY $2)])\n" - + " LogicalAggregate(group=[{0, 1}], count=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7], JOB=[$2])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + """ + LogicalProject(DEPTNO=[$0], JOB=[$1]) + LogicalFilter(condition=[<=($3, 10)]) + LogicalProject(DEPTNO=[$0], JOB=[$1], count=[$2], _row_number_=[ROW_NUMBER()\ + OVER (PARTITION BY $0 ORDER BY $2)]) + LogicalAggregate(group=[{0, 1}], count=[COUNT()]) + LogicalProject(DEPTNO=[$7], JOB=[$2]) + LogicalTableScan(table=[[scott, EMP]]) + """; verifyLogical(root, expectedLogical); String expectedResult = - "" - + "DEPTNO=20; JOB=MANAGER\n" - + "DEPTNO=20; JOB=CLERK\n" - + "DEPTNO=20; JOB=ANALYST\n" - + "DEPTNO=10; JOB=MANAGER\n" - + "DEPTNO=10; JOB=CLERK\n" - + "DEPTNO=10; JOB=PRESIDENT\n" - + "DEPTNO=30; JOB=MANAGER\n" - + "DEPTNO=30; JOB=CLERK\n" - + "DEPTNO=30; JOB=SALESMAN\n"; + """ + DEPTNO=20; JOB=MANAGER + DEPTNO=20; JOB=CLERK + DEPTNO=20; JOB=ANALYST + DEPTNO=10; JOB=MANAGER + DEPTNO=10; JOB=CLERK + DEPTNO=10; JOB=PRESIDENT + DEPTNO=30; JOB=MANAGER + DEPTNO=30; JOB=CLERK + DEPTNO=30; JOB=SALESMAN + """; verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT `DEPTNO`, `JOB`\n" - + "FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (PARTITION BY" - + " `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_`\n" - + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`, `JOB`) `t1`\n" - + "WHERE `_row_number_` <= 10"; + """ + SELECT `DEPTNO`, `JOB` + FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `count`, ROW_NUMBER() OVER (PARTITION BY\ + `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_` + FROM `scott`.`EMP` + GROUP BY `DEPTNO`, `JOB`) `t1` + WHERE `_row_number_` <= 10"""; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -131,35 +143,39 @@ public void testRareCountField() { String ppl = "source=EMP | rare countfield='my_cnt' JOB by DEPTNO"; RelNode root = getRelNode(ppl); String expectedLogical = - "LogicalProject(DEPTNO=[$0], JOB=[$1], my_cnt=[$2])\n" - + " LogicalFilter(condition=[<=($3, 10)])\n" - + " LogicalProject(DEPTNO=[$0], JOB=[$1], my_cnt=[$2], _row_number_=[ROW_NUMBER()" - + " OVER (PARTITION BY $0 ORDER BY $2)])\n" - + " LogicalAggregate(group=[{0, 1}], my_cnt=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7], JOB=[$2])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + """ + LogicalProject(DEPTNO=[$0], JOB=[$1], my_cnt=[$2]) + LogicalFilter(condition=[<=($3, 10)]) + LogicalProject(DEPTNO=[$0], JOB=[$1], my_cnt=[$2], _row_number_=[ROW_NUMBER()\ + OVER (PARTITION BY $0 ORDER BY $2)]) + LogicalAggregate(group=[{0, 1}], my_cnt=[COUNT()]) + LogicalProject(DEPTNO=[$7], JOB=[$2]) + LogicalTableScan(table=[[scott, EMP]]) + """; verifyLogical(root, expectedLogical); String expectedResult = - "" - + "DEPTNO=20; JOB=MANAGER; my_cnt=1\n" - + "DEPTNO=20; JOB=CLERK; my_cnt=2\n" - + "DEPTNO=20; JOB=ANALYST; my_cnt=2\n" - + "DEPTNO=10; JOB=MANAGER; my_cnt=1\n" - + "DEPTNO=10; JOB=CLERK; my_cnt=1\n" - + "DEPTNO=10; JOB=PRESIDENT; my_cnt=1\n" - + "DEPTNO=30; JOB=MANAGER; my_cnt=1\n" - + "DEPTNO=30; JOB=CLERK; my_cnt=1\n" - + "DEPTNO=30; JOB=SALESMAN; my_cnt=4\n"; + """ + DEPTNO=20; JOB=MANAGER; my_cnt=1 + DEPTNO=20; JOB=CLERK; my_cnt=2 + DEPTNO=20; JOB=ANALYST; my_cnt=2 + DEPTNO=10; JOB=MANAGER; my_cnt=1 + DEPTNO=10; JOB=CLERK; my_cnt=1 + DEPTNO=10; JOB=PRESIDENT; my_cnt=1 + DEPTNO=30; JOB=MANAGER; my_cnt=1 + DEPTNO=30; JOB=CLERK; my_cnt=1 + DEPTNO=30; JOB=SALESMAN; my_cnt=4 + """; verifyResult(root, expectedResult); String expectedSparkSql = - "SELECT `DEPTNO`, `JOB`, `my_cnt`\n" - + "FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `my_cnt`, ROW_NUMBER() OVER (PARTITION BY" - + " `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_`\n" - + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`, `JOB`) `t1`\n" - + "WHERE `_row_number_` <= 10"; + """ + SELECT `DEPTNO`, `JOB`, `my_cnt` + FROM (SELECT `DEPTNO`, `JOB`, COUNT(*) `my_cnt`, ROW_NUMBER() OVER (PARTITION BY\ + `DEPTNO` ORDER BY COUNT(*) NULLS LAST) `_row_number_` + FROM `scott`.`EMP` + GROUP BY `DEPTNO`, `JOB`) `t1` + WHERE `_row_number_` <= 10"""; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -171,15 +187,23 @@ public void failWithDuplicatedName() { } catch (Exception e) { assertThat( e.getMessage(), - is("Field `count` is existed, change the count field by setting countfield='xyz'")); + is( + "The top/rare output field `count` already exists. Suggestion: change the count field" + + " by adding countfield='xyz'")); } + } + + @Test + public void failWithDuplicateOverriddenName() { try { RelNode root = getRelNode("source=EMP | rare countfield='DEPTNO' JOB by DEPTNO"); fail("expected error, got " + root); } catch (Exception e) { assertThat( e.getMessage(), - is("Field `DEPTNO` is existed, change the count field by setting countfield='xyz'")); + is( + "The top/rare output field `DEPTNO` already exists. Suggestion: change the count" + + " field by adding countfield='xyz'")); } } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 4f09a3e02ee..c02af6c9789 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -21,6 +21,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultSortFieldArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultStatsArgs; +import static org.opensearch.sql.ast.dsl.AstDSL.defaultTopRareArgs; import static org.opensearch.sql.ast.dsl.AstDSL.describe; import static org.opensearch.sql.ast.dsl.AstDSL.eval; import static org.opensearch.sql.ast.dsl.AstDSL.exprList; @@ -79,6 +80,7 @@ import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.setting.Settings.Key; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; +import org.opensearch.sql.ppl.utils.ExprLists; import org.opensearch.sql.utils.SystemIndexUtils; public class AstBuilderTest { @@ -630,15 +632,7 @@ public void testIdentifierAsFieldNameStartWithAt() { public void testRareCommand() { assertEqual( "source=t | rare a", - rareTopN( - relation("t"), - CommandType.RARE, - exprList( - argument("noOfResults", intLiteral(10)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), - emptyList(), - field("a"))); + rareTopN(relation("t"), CommandType.RARE, defaultTopRareArgs(), emptyList(), field("a"))); } @Test @@ -648,10 +642,7 @@ public void testRareCommandWithGroupBy() { rareTopN( relation("t"), CommandType.RARE, - exprList( - argument("noOfResults", intLiteral(10)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + defaultTopRareArgs(), exprList(field("b")), field("a"))); } @@ -663,10 +654,7 @@ public void testRareCommandWithMultipleFields() { rareTopN( relation("t"), CommandType.RARE, - exprList( - argument("noOfResults", intLiteral(10)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + defaultTopRareArgs(), exprList(field("c")), field("a"), field("b"))); @@ -679,10 +667,7 @@ public void testTopCommandWithN() { rareTopN( relation("t"), CommandType.TOP, - exprList( - argument("noOfResults", intLiteral(1)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + ExprLists.merge(defaultTopRareArgs(), argument("noOfResults", intLiteral(1))), emptyList(), field("a"))); } @@ -691,43 +676,131 @@ public void testTopCommandWithN() { public void testTopCommandWithoutNAndGroupBy() { assertEqual( "source=t | top a", + rareTopN(relation("t"), CommandType.TOP, defaultTopRareArgs(), emptyList(), field("a"))); + } + + @Test + public void testTopCommandWithNAndGroupBy() { + assertEqual( + "source=t | top 1 a by b", rareTopN( relation("t"), CommandType.TOP, + ExprLists.merge(defaultTopRareArgs(), argument("noOfResults", intLiteral(1))), + exprList(field("b")), + field("a"))); + } + + @Test + public void testTopCommandWithMultipleFields() { + assertEqual( + "source=t | top 1 `a`, `b` by `c`", + rareTopN( + relation("t"), + CommandType.TOP, + ExprLists.merge(defaultTopRareArgs(), argument("noOfResults", intLiteral(1))), + exprList(field("c")), + field("a"), + field("b"))); + } + + @Test + public void testRareCommandWithAllArguments() { + assertEqual( + "source=t | rare 5 countfield='mycounts' showcount=false percentfield='perc' showperc=true" + + " useother=true a", + rareTopN( + relation("t"), + CommandType.RARE, exprList( - argument("noOfResults", intLiteral(10)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + argument("noOfResults", intLiteral(5)), + argument("countField", stringLiteral("mycounts")), + argument("showCount", booleanLiteral(false)), + argument("percentField", stringLiteral("perc")), + argument("showPerc", booleanLiteral(true)), + argument("useOther", booleanLiteral(true))), emptyList(), field("a"))); } @Test - public void testTopCommandWithNAndGroupBy() { + public void testTopCommandWithCustomFields() { assertEqual( - "source=t | top 1 a by b", + "source=t | top countfield='hits' percentfield='percentage' a", rareTopN( relation("t"), CommandType.TOP, - exprList( - argument("noOfResults", intLiteral(1)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + ExprLists.merge( + defaultTopRareArgs(), + argument("countField", stringLiteral("hits")), + argument("percentField", stringLiteral("percentage")), + argument("showPerc", booleanLiteral(true))), + emptyList(), + field("a"))); + } + + @Test + public void testTopCommandWithPercentageOnly() { + assertEqual( + "source=t | top showcount=false showperc=true a by b", + rareTopN( + relation("t"), + CommandType.TOP, + ExprLists.merge( + defaultTopRareArgs(), + argument("showCount", booleanLiteral(false)), + argument("showPerc", booleanLiteral(true))), exprList(field("b")), field("a"))); } @Test - public void testTopCommandWithMultipleFields() { + public void testRareCommandWithUseOther() { assertEqual( - "source=t | top 1 `a`, `b` by `c`", + "source=t | rare useother=true a", + rareTopN( + relation("t"), + CommandType.RARE, + ExprLists.merge(defaultTopRareArgs(), argument("useOther", booleanLiteral(true))), + emptyList(), + field("a"))); + } + + @Test + public void testTopCommandWithAllArguments() { + assertEqual( + "source=t | top 20 countfield='cnt' showcount=true percentfield='pct' showperc=true" + + " useother=true a, b by c", + rareTopN( + relation("t"), + CommandType.TOP, + exprList( + argument("noOfResults", intLiteral(20)), + argument("countField", stringLiteral("cnt")), + argument("showCount", booleanLiteral(true)), + argument("percentField", stringLiteral("pct")), + argument("showPerc", booleanLiteral(true)), + argument("useOther", booleanLiteral(true))), + exprList(field("c")), + field("a"), + field("b"))); + } + + @Test + public void testTopCommandWithAllArgumentsShuffled() { + assertEqual( + "source=t | top 20 showcount=true percentfield='pct' countfield='cnt' useother=true" + + " showperc=true a, b by c", rareTopN( relation("t"), CommandType.TOP, exprList( - argument("noOfResults", intLiteral(1)), - argument("countField", stringLiteral("count")), - argument("showCount", booleanLiteral(true))), + argument("noOfResults", intLiteral(20)), + argument("countField", stringLiteral("cnt")), + argument("showCount", booleanLiteral(true)), + argument("percentField", stringLiteral("pct")), + argument("showPerc", booleanLiteral(true)), + argument("useOther", booleanLiteral(true))), exprList(field("c")), field("a"), field("b"))); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/ExprLists.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/ExprLists.java new file mode 100644 index 00000000000..8aaaffa51b6 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/ExprLists.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import java.util.List; +import org.opensearch.sql.ast.expression.Argument; + +public class ExprLists { + /** + * Apply all the given argument updates to the base argument list + * + * @param base The original list of arguments + * @param updates All updates to apply on the base (either adding new results or updating existing + * ones) + * @return The merged result + */ + public static List merge(List base, Argument... updates) { + for (Argument update : updates) { + boolean updated = false; + for (int i = 0; i < base.size(); i++) { + if (base.get(i).getArgName().equals(update.getArgName())) { + base.set(i, update); + updated = true; + } + } + if (!updated) { + base.add(update); + } + } + return base; + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index 44392cd9f57..75b1e83c50e 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -299,7 +299,7 @@ public void testTopCommandWithNAndGroupBy() { public void testRareCommandWithGroupByWithCalcite() { when(settings.getSettingValue(Key.CALCITE_ENGINE_ENABLED)).thenReturn(true); assertEquals( - "source=table | rare 10 countield='count' showcount=true identifier by identifier", + "source=table | rare 10 countfield='count' showperc=false identifier by identifier", anonymize("source=t | rare a by b")); } @@ -307,7 +307,7 @@ public void testRareCommandWithGroupByWithCalcite() { public void testTopCommandWithNAndGroupByWithCalcite() { when(settings.getSettingValue(Key.CALCITE_ENGINE_ENABLED)).thenReturn(true); assertEquals( - "source=table | top 1 countield='count' showcount=true identifier by identifier", + "source=table | top 1 countfield='count' showperc=false identifier by identifier", anonymize("source=t | top 1 a by b")); }