diff --git a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 index e03d57ba7e2a..e699bf8778ad 100644 --- a/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 +++ b/core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4 @@ -591,10 +591,8 @@ primaryExpression (ON OVERFLOW listAggOverflowBehavior)? ')' (WITHIN GROUP '(' orderBy ')') filter? over? #listagg - | processingMode? qualifiedName '(' (label=identifier '.')? ASTERISK ')' - filter? over? #functionCall - | processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)? - orderBy? ')' filter? (nullTreatment? over)? #functionCall + | function #functions + | left=primaryExpression '.' right=function #chainedFunctionCalls | identifier over #measure | identifier '->' expression #lambda | '(' (identifier (',' identifier)*)? ')' '->' expression #lambda @@ -659,6 +657,13 @@ primaryExpression ')' #jsonArray ; +function + : processingMode? qualifiedName '(' (label=identifier '.')? ASTERISK ')' + filter? over? #functionCall + | processingMode? qualifiedName '(' (setQuantifier? expression (',' expression)*)? + orderBy? ')' filter? (nullTreatment? over)? #functionCall + ; + literal : interval #intervalLiteral | identifier string #typeConstructor diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestChainedFunctionCalls.java b/core/trino-main/src/test/java/io/trino/sql/query/TestChainedFunctionCalls.java new file mode 100644 index 000000000000..0455294fea82 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestChainedFunctionCalls.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.query; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.parallel.Execution; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; +import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; + +@TestInstance(PER_CLASS) +@Execution(CONCURRENT) +final class TestChainedFunctionCalls +{ + private final QueryAssertions assertions = new QueryAssertions(); + + @AfterAll + void teardown() + { + assertions.close(); + } + + @Test + void testChainedFunctionCalls() + { + assertThat(assertions.query("SELECT ('hello').upper().concat(' world!')")) + .matches("SELECT VARCHAR 'HELLO world!'"); + + assertThat(assertions.query("SELECT (-123).abs()")) + .matches("SELECT 123"); + } + + @Test + void testQualifiedFunctionName() + { + assertThat(assertions.query("SELECT (-123).system.builtin.abs()")) + .matches("SELECT 123"); + } + + @Test + void testInvalidType() + { + assertThat(assertions.query("SELECT ('hello').abs()")).failure() + .hasMessage("line 1:8: Unexpected parameters (varchar(5)) for function abs. " + + "Expected: abs(bigint), abs(decimal(p,s)), abs(double), abs(integer), abs(real), abs(smallint), abs(tinyint)"); + } + + @Test + void testInvalidParameter() + { + assertThat(assertions.query("SELECT (-123).e()")).failure() + .hasMessage("line 1:8: Unexpected parameters (integer) for function e. Expected: e()"); + } +} diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 6cdc3086b816..ba54ff0b84f4 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -2403,6 +2403,19 @@ public Node visitConcatenation(SqlBaseParser.ConcatenationContext context) (Expression) visit(context.right))); } + @Override + public Node visitChainedFunctionCalls(SqlBaseParser.ChainedFunctionCallsContext context) + { + FunctionCall functionCall = (FunctionCall) visit(context.right); + return new FunctionCall( + getLocation(context), + functionCall.getName(), + ImmutableList.builder() + .add((Expression) visit(context.left)) + .addAll(functionCall.getArguments()) + .build()); + } + @Override public Node visitAtTimeZone(SqlBaseParser.AtTimeZoneContext context) { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java index 3a9118cd3d0d..781f0b42411a 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java @@ -1673,6 +1673,21 @@ public void testSubstringRegisteredFunction() new FunctionCall(QualifiedName.of("substring"), Lists.newArrayList(new StringLiteral(givenString), new LongLiteral("2"), new LongLiteral("3")))))); } + @Test + void testChainedFunctionCalls() + { + assertStatement("SELECT ('hello').concat(' ').concat('world')", + simpleQuery(selectList( + new FunctionCall(QualifiedName.of("concat"), Lists.newArrayList( + new FunctionCall(QualifiedName.of("concat"), Lists.newArrayList(new StringLiteral("hello"), new StringLiteral(" "))), + new StringLiteral("world")))))); + + assertStatement("SELECT (-123).abs()", + simpleQuery(selectList( + new FunctionCall(QualifiedName.of("abs"), Lists.newArrayList(new LongLiteral("-123")))))); + + } + @Test void testCreateBranch() { diff --git a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java index 92e9ec4db3a9..10e5140eca9f 100644 --- a/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java +++ b/core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java @@ -75,7 +75,7 @@ private static Stream statements() Arguments.of("select CAST(-12223222232535343423232435343 AS BIGINT)", "line 1:13: Invalid numeric literal: -12223222232535343423232435343"), Arguments.of("select foo.!", - "line 1:12: mismatched input '!'. Expecting: '*', "), + "line 1:12: mismatched input '!'. Expecting: '*', 'FINAL', 'RUNNING', "), Arguments.of("select foo(,1)", "line 1:12: mismatched input ','. Expecting: ')', '*', 'ALL', 'DISTINCT', 'ORDER', "), Arguments.of("select foo ( ,1)",