diff --git a/partiql-ast/api/partiql-ast.api b/partiql-ast/api/partiql-ast.api index ac6ebdc2ba..6d59f78955 100644 --- a/partiql-ast/api/partiql-ast.api +++ b/partiql-ast/api/partiql-ast.api @@ -5593,7 +5593,7 @@ public final class org/partiql/ast/v1/Ast { public static final fun exprWindowOver (Ljava/util/List;Ljava/util/List;)Lorg/partiql/ast/v1/expr/ExprWindow$Over; public static final fun from (Ljava/util/List;)Lorg/partiql/ast/v1/From; public static final fun fromExpr (Lorg/partiql/ast/v1/expr/Expr;Lorg/partiql/ast/v1/FromType;Lorg/partiql/ast/v1/Identifier;Lorg/partiql/ast/v1/Identifier;)Lorg/partiql/ast/v1/FromExpr; - public static final fun fromJoin (Lorg/partiql/ast/v1/From;Lorg/partiql/ast/v1/From;Lorg/partiql/ast/v1/JoinType;Lorg/partiql/ast/v1/expr/Expr;)Lorg/partiql/ast/v1/FromJoin; + public static final fun fromJoin (Lorg/partiql/ast/v1/FromTableRef;Lorg/partiql/ast/v1/FromTableRef;Lorg/partiql/ast/v1/JoinType;Lorg/partiql/ast/v1/expr/Expr;)Lorg/partiql/ast/v1/FromJoin; public static final fun graphLabelConj (Lorg/partiql/ast/v1/graph/GraphLabel;Lorg/partiql/ast/v1/graph/GraphLabel;)Lorg/partiql/ast/v1/graph/GraphLabel$Conj; public static final fun graphLabelDisj (Lorg/partiql/ast/v1/graph/GraphLabel;Lorg/partiql/ast/v1/graph/GraphLabel;)Lorg/partiql/ast/v1/graph/GraphLabel$Disj; public static final fun graphLabelName (Ljava/lang/String;)Lorg/partiql/ast/v1/graph/GraphLabel$Name; @@ -5737,7 +5737,7 @@ public abstract interface class org/partiql/ast/v1/AstVisitor { public abstract fun visitTableRef (Lorg/partiql/ast/v1/FromTableRef;Ljava/lang/Object;)Ljava/lang/Object; } -public class org/partiql/ast/v1/DataType : org/partiql/ast/v1/Enum { +public class org/partiql/ast/v1/DataType : org/partiql/ast/v1/AstNode, org/partiql/ast/v1/Enum { public static final field BAG I public static final field BIGINT I public static final field BINARY_LARGE_OBJECT I @@ -5859,7 +5859,9 @@ public class org/partiql/ast/v1/DataType : org/partiql/ast/v1/Enum { public static fun USER_DEFINED (Lorg/partiql/ast/v1/IdentifierChain;)Lorg/partiql/ast/v1/DataType; public static fun VARCHAR ()Lorg/partiql/ast/v1/DataType; public static fun VARCHAR (I)Lorg/partiql/ast/v1/DataType; + public fun accept (Lorg/partiql/ast/v1/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; protected fun canEqual (Ljava/lang/Object;)Z + public fun children ()Ljava/util/Collection; public fun code ()I public fun equals (Ljava/lang/Object;)Z public fun getLength ()Ljava/lang/Integer; @@ -6054,9 +6056,9 @@ public class org/partiql/ast/v1/FromExpr$Builder { public class org/partiql/ast/v1/FromJoin : org/partiql/ast/v1/FromTableRef { public final field condition Lorg/partiql/ast/v1/expr/Expr; public final field joinType Lorg/partiql/ast/v1/JoinType; - public final field lhs Lorg/partiql/ast/v1/From; - public final field rhs Lorg/partiql/ast/v1/From; - public fun (Lorg/partiql/ast/v1/From;Lorg/partiql/ast/v1/From;Lorg/partiql/ast/v1/JoinType;Lorg/partiql/ast/v1/expr/Expr;)V + public final field lhs Lorg/partiql/ast/v1/FromTableRef; + public final field rhs Lorg/partiql/ast/v1/FromTableRef; + public fun (Lorg/partiql/ast/v1/FromTableRef;Lorg/partiql/ast/v1/FromTableRef;Lorg/partiql/ast/v1/JoinType;Lorg/partiql/ast/v1/expr/Expr;)V public fun accept (Lorg/partiql/ast/v1/AstVisitor;Ljava/lang/Object;)Ljava/lang/Object; public static fun builder ()Lorg/partiql/ast/v1/FromJoin$Builder; protected fun canEqual (Ljava/lang/Object;)Z @@ -6069,8 +6071,8 @@ public class org/partiql/ast/v1/FromJoin$Builder { public fun build ()Lorg/partiql/ast/v1/FromJoin; public fun condition (Lorg/partiql/ast/v1/expr/Expr;)Lorg/partiql/ast/v1/FromJoin$Builder; public fun joinType (Lorg/partiql/ast/v1/JoinType;)Lorg/partiql/ast/v1/FromJoin$Builder; - public fun lhs (Lorg/partiql/ast/v1/From;)Lorg/partiql/ast/v1/FromJoin$Builder; - public fun rhs (Lorg/partiql/ast/v1/From;)Lorg/partiql/ast/v1/FromJoin$Builder; + public fun lhs (Lorg/partiql/ast/v1/FromTableRef;)Lorg/partiql/ast/v1/FromJoin$Builder; + public fun rhs (Lorg/partiql/ast/v1/FromTableRef;)Lorg/partiql/ast/v1/FromJoin$Builder; public fun toString ()Ljava/lang/String; } diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/Ast.kt b/partiql-ast/src/main/java/org/partiql/ast/v1/Ast.kt index 351163c324..fa34286673 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/Ast.kt +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/Ast.kt @@ -402,7 +402,7 @@ public object Ast { } @JvmStatic - public fun fromJoin(lhs: From, rhs: From, joinType: JoinType?, condition: Expr?): FromJoin { + public fun fromJoin(lhs: FromTableRef, rhs: FromTableRef, joinType: JoinType?, condition: Expr?): FromJoin { return FromJoin(lhs, rhs, joinType, condition) } diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java b/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java index 70804ebe3e..546ca7a79b 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/DataType.java @@ -3,8 +3,12 @@ import lombok.EqualsAndHashCode; import org.jetbrains.annotations.NotNull; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + @EqualsAndHashCode(callSuper = false) -public class DataType implements Enum { +public class DataType extends AstNode implements Enum { public static final int UNKNOWN = 0; // public static final int NULL = 1; @@ -530,4 +534,19 @@ public Integer getLength() { public IdentifierChain getName() { return name; } + + @NotNull + @Override + public Collection children() { + List kids = new ArrayList<>(); + if (name != null) { + kids.add(name); + } + return kids; + } + + @Override + public R accept(@NotNull AstVisitor visitor, C ctx) { + return null; + } } diff --git a/partiql-ast/src/main/java/org/partiql/ast/v1/FromJoin.java b/partiql-ast/src/main/java/org/partiql/ast/v1/FromJoin.java index 4c573f8dee..f87b0405e0 100644 --- a/partiql-ast/src/main/java/org/partiql/ast/v1/FromJoin.java +++ b/partiql-ast/src/main/java/org/partiql/ast/v1/FromJoin.java @@ -17,10 +17,10 @@ @EqualsAndHashCode(callSuper = false) public class FromJoin extends FromTableRef { @NotNull - public final From lhs; + public final FromTableRef lhs; @NotNull - public final From rhs; + public final FromTableRef rhs; @Nullable public final JoinType joinType; @@ -28,7 +28,7 @@ public class FromJoin extends FromTableRef { @Nullable public final Expr condition; - public FromJoin(@NotNull From lhs, @NotNull From rhs, @Nullable JoinType joinType, @Nullable Expr condition) { + public FromJoin(@NotNull FromTableRef lhs, @NotNull FromTableRef rhs, @Nullable JoinType joinType, @Nullable Expr condition) { this.lhs = lhs; this.rhs = rhs; this.joinType = joinType; diff --git a/partiql-parser/api/partiql-parser.api b/partiql-parser/api/partiql-parser.api index 0ae383747e..8870548876 100644 --- a/partiql-parser/api/partiql-parser.api +++ b/partiql-parser/api/partiql-parser.api @@ -118,3 +118,35 @@ public final class org/partiql/parser/SourceLocations : java/util/Map, kotlin/jv public final fun values ()Ljava/util/Collection; } +public abstract interface class org/partiql/parser/V1PartiQLParser { + public static final field Companion Lorg/partiql/parser/V1PartiQLParser$Companion; + public static fun builder ()Lorg/partiql/parser/V1PartiQLParserBuilder; + public abstract fun parse (Ljava/lang/String;)Lorg/partiql/parser/V1PartiQLParser$Result; + public static fun standard ()Lorg/partiql/parser/V1PartiQLParser; +} + +public final class org/partiql/parser/V1PartiQLParser$Companion { + public final fun builder ()Lorg/partiql/parser/V1PartiQLParserBuilder; + public final fun standard ()Lorg/partiql/parser/V1PartiQLParser; +} + +public final class org/partiql/parser/V1PartiQLParser$Result { + public fun (Ljava/lang/String;Lorg/partiql/ast/v1/Statement;Lorg/partiql/parser/SourceLocations;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Lorg/partiql/ast/v1/Statement; + public final fun component3 ()Lorg/partiql/parser/SourceLocations; + public final fun copy (Ljava/lang/String;Lorg/partiql/ast/v1/Statement;Lorg/partiql/parser/SourceLocations;)Lorg/partiql/parser/V1PartiQLParser$Result; + public static synthetic fun copy$default (Lorg/partiql/parser/V1PartiQLParser$Result;Ljava/lang/String;Lorg/partiql/ast/v1/Statement;Lorg/partiql/parser/SourceLocations;ILjava/lang/Object;)Lorg/partiql/parser/V1PartiQLParser$Result; + public fun equals (Ljava/lang/Object;)Z + public final fun getLocations ()Lorg/partiql/parser/SourceLocations; + public final fun getRoot ()Lorg/partiql/ast/v1/Statement; + public final fun getSource ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class org/partiql/parser/V1PartiQLParserBuilder { + public fun ()V + public final fun build ()Lorg/partiql/parser/V1PartiQLParser; +} + diff --git a/partiql-parser/src/main/antlr/PartiQLParser.g4 b/partiql-parser/src/main/antlr/PartiQLParser.g4 index 11c7c6d1e8..b364d6c145 100644 --- a/partiql-parser/src/main/antlr/PartiQLParser.g4 +++ b/partiql-parser/src/main/antlr/PartiQLParser.g4 @@ -358,7 +358,7 @@ excludeExprSteps ; fromClause - : FROM tableReference; + : FROM ( tableReference ( COMMA tableReference)* ); whereClauseSelect : WHERE arg=exprSelect; @@ -468,16 +468,15 @@ edgeAbbrev */ tableReference - : lhs=tableReference joinType? CROSS JOIN rhs=joinRhs # TableCrossJoin - | lhs=tableReference COMMA rhs=joinRhs # TableCrossJoin - | lhs=tableReference joinType? JOIN rhs=joinRhs joinSpec # TableQualifiedJoin - | tableNonJoin # TableRefBase - | PAREN_LEFT tableReference PAREN_RIGHT # TableWrapped + : tablePrimary # TableRefPrimary + | lhs=tableReference CROSS JOIN rhs=tablePrimary # TableCrossJoin + | lhs=tableReference joinType? JOIN rhs=tableReference joinSpec # TableQualifiedJoin ; -tableNonJoin +tablePrimary : tableBaseReference | tableUnpivot + | tableWrapped ; tableBaseReference @@ -487,15 +486,11 @@ tableBaseReference ; tableUnpivot - : UNPIVOT expr asIdent? atIdent? byIdent?; - -joinRhs - : tableNonJoin # JoinRhsBase - | PAREN_LEFT tableReference PAREN_RIGHT # JoinRhsTableJoined + : UNPIVOT expr asIdent? atIdent? byIdent? ; -joinSpec - : ON expr; +tableWrapped + : PAREN_LEFT tableReference PAREN_RIGHT; joinType : mod=INNER @@ -505,6 +500,9 @@ joinType | mod=OUTER ; +joinSpec + : ON expr; + /** * * EXPRESSIONS & PRECEDENCE diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParser.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParser.kt new file mode 100644 index 0000000000..03780ee615 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParser.kt @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +import org.partiql.ast.v1.Statement +import org.partiql.parser.internal.V1PartiQLParserDefault + +public interface V1PartiQLParser { + + @Throws(PartiQLSyntaxException::class, InterruptedException::class) + public fun parse(source: String): Result + + public data class Result( + val source: String, + val root: Statement, + val locations: SourceLocations, + ) + + public companion object { + + @JvmStatic + public fun builder(): V1PartiQLParserBuilder = V1PartiQLParserBuilder() + + @JvmStatic + public fun standard(): V1PartiQLParser = V1PartiQLParserDefault() + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParserBuilder.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParserBuilder.kt new file mode 100644 index 0000000000..ef5b51ecf1 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/V1PartiQLParserBuilder.kt @@ -0,0 +1,27 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser + +import org.partiql.parser.internal.V1PartiQLParserDefault + +/** + * A builder class to instantiate a [V1PartiQLParser]. + */ +public class V1PartiQLParserBuilder { + + public fun build(): V1PartiQLParser { + return V1PartiQLParserDefault() + } +} diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt index dfdcc3591f..9ec61ae525 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/PartiQLParserDefault.kt @@ -1233,7 +1233,12 @@ internal class PartiQLParserDefault : PartiQLParser { * */ - override fun visitFromClause(ctx: GeneratedParser.FromClauseContext) = visitAs(ctx.tableReference()) + override fun visitFromClause(ctx: GeneratedParser.FromClauseContext) = translate(ctx) { + val tableRefs = visitOrEmpty(ctx.tableReference()) + tableRefs.drop(1).fold(tableRefs.first()) { acc, tableRef -> + fromJoin(acc, tableRef, From.Join.Type.CROSS, null) + } + } override fun visitTableBaseRefClauses(ctx: GeneratedParser.TableBaseRefClausesContext) = translate(ctx) { val expr = visitAs(ctx.source) @@ -1276,8 +1281,7 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitTableCrossJoin(ctx: GeneratedParser.TableCrossJoinContext) = translate(ctx) { val lhs = visitAs(ctx.lhs) val rhs = visitAs(ctx.rhs) - val type = convertJoinType(ctx.joinType()) - fromJoin(lhs, rhs, type, null) + fromJoin(lhs, rhs, null, null) } private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): From.Join.Type? { @@ -1323,9 +1327,6 @@ internal class PartiQLParserDefault : PartiQLParser { override fun visitJoinSpec(ctx: GeneratedParser.JoinSpecContext) = visitExpr(ctx.expr()) - override fun visitJoinRhsTableJoined(ctx: GeneratedParser.JoinRhsTableJoinedContext) = - visitAs(ctx.tableReference()) - /** * SIMPLE EXPRESSIONS */ diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt new file mode 100644 index 0000000000..0a215a6f85 --- /dev/null +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/internal/V1PartiQLParserDefault.kt @@ -0,0 +1,2196 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at: + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific + * language governing permissions and limitations under the License. + */ + +package org.partiql.parser.internal + +import com.amazon.ionelement.api.IntElement +import com.amazon.ionelement.api.IntElementSize +import com.amazon.ionelement.api.IonElement +import org.antlr.v4.runtime.BailErrorStrategy +import org.antlr.v4.runtime.BaseErrorListener +import org.antlr.v4.runtime.CharStreams +import org.antlr.v4.runtime.CommonTokenStream +import org.antlr.v4.runtime.ParserRuleContext +import org.antlr.v4.runtime.RecognitionException +import org.antlr.v4.runtime.Recognizer +import org.antlr.v4.runtime.Token +import org.antlr.v4.runtime.TokenSource +import org.antlr.v4.runtime.TokenStream +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException +import org.antlr.v4.runtime.tree.TerminalNode +import org.partiql.ast.v1.Ast +import org.partiql.ast.v1.Ast.exclude +import org.partiql.ast.v1.Ast.excludePath +import org.partiql.ast.v1.Ast.excludeStepCollIndex +import org.partiql.ast.v1.Ast.excludeStepCollWildcard +import org.partiql.ast.v1.Ast.excludeStepStructField +import org.partiql.ast.v1.Ast.excludeStepStructWildcard +import org.partiql.ast.v1.Ast.explain +import org.partiql.ast.v1.Ast.exprAnd +import org.partiql.ast.v1.Ast.exprArray +import org.partiql.ast.v1.Ast.exprBag +import org.partiql.ast.v1.Ast.exprBetween +import org.partiql.ast.v1.Ast.exprCall +import org.partiql.ast.v1.Ast.exprCase +import org.partiql.ast.v1.Ast.exprCaseBranch +import org.partiql.ast.v1.Ast.exprCast +import org.partiql.ast.v1.Ast.exprCoalesce +import org.partiql.ast.v1.Ast.exprExtract +import org.partiql.ast.v1.Ast.exprInCollection +import org.partiql.ast.v1.Ast.exprIsType +import org.partiql.ast.v1.Ast.exprLike +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprMatch +import org.partiql.ast.v1.Ast.exprNot +import org.partiql.ast.v1.Ast.exprNullIf +import org.partiql.ast.v1.Ast.exprOperator +import org.partiql.ast.v1.Ast.exprOr +import org.partiql.ast.v1.Ast.exprOverlay +import org.partiql.ast.v1.Ast.exprParameter +import org.partiql.ast.v1.Ast.exprPath +import org.partiql.ast.v1.Ast.exprPathStepAllElements +import org.partiql.ast.v1.Ast.exprPathStepAllFields +import org.partiql.ast.v1.Ast.exprPathStepElement +import org.partiql.ast.v1.Ast.exprPathStepField +import org.partiql.ast.v1.Ast.exprPosition +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprSessionAttribute +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.exprSubstring +import org.partiql.ast.v1.Ast.exprTrim +import org.partiql.ast.v1.Ast.exprVarRef +import org.partiql.ast.v1.Ast.exprVariant +import org.partiql.ast.v1.Ast.exprWindow +import org.partiql.ast.v1.Ast.exprWindowOver +import org.partiql.ast.v1.Ast.from +import org.partiql.ast.v1.Ast.fromExpr +import org.partiql.ast.v1.Ast.fromJoin +import org.partiql.ast.v1.Ast.graphLabelConj +import org.partiql.ast.v1.Ast.graphLabelDisj +import org.partiql.ast.v1.Ast.graphLabelName +import org.partiql.ast.v1.Ast.graphLabelNegation +import org.partiql.ast.v1.Ast.graphLabelWildcard +import org.partiql.ast.v1.Ast.graphMatch +import org.partiql.ast.v1.Ast.graphMatchEdge +import org.partiql.ast.v1.Ast.graphMatchNode +import org.partiql.ast.v1.Ast.graphMatchPattern +import org.partiql.ast.v1.Ast.graphPattern +import org.partiql.ast.v1.Ast.graphQuantifier +import org.partiql.ast.v1.Ast.graphSelectorAllShortest +import org.partiql.ast.v1.Ast.graphSelectorAny +import org.partiql.ast.v1.Ast.graphSelectorAnyK +import org.partiql.ast.v1.Ast.graphSelectorAnyShortest +import org.partiql.ast.v1.Ast.graphSelectorShortestK +import org.partiql.ast.v1.Ast.graphSelectorShortestKGroup +import org.partiql.ast.v1.Ast.groupBy +import org.partiql.ast.v1.Ast.groupByKey +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.letBinding +import org.partiql.ast.v1.Ast.orderBy +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.queryBodySetOp +import org.partiql.ast.v1.Ast.selectItemExpr +import org.partiql.ast.v1.Ast.selectItemStar +import org.partiql.ast.v1.Ast.selectList +import org.partiql.ast.v1.Ast.selectPivot +import org.partiql.ast.v1.Ast.selectStar +import org.partiql.ast.v1.Ast.selectValue +import org.partiql.ast.v1.Ast.setOp +import org.partiql.ast.v1.Ast.sort +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.DataType +import org.partiql.ast.v1.DatetimeField +import org.partiql.ast.v1.Exclude +import org.partiql.ast.v1.ExcludeStep +import org.partiql.ast.v1.From +import org.partiql.ast.v1.FromTableRef +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.GroupBy +import org.partiql.ast.v1.GroupByStrategy +import org.partiql.ast.v1.Identifier +import org.partiql.ast.v1.IdentifierChain +import org.partiql.ast.v1.JoinType +import org.partiql.ast.v1.Let +import org.partiql.ast.v1.Nulls +import org.partiql.ast.v1.Order +import org.partiql.ast.v1.Select +import org.partiql.ast.v1.SelectItem +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.Sort +import org.partiql.ast.v1.Statement +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprArray +import org.partiql.ast.v1.expr.ExprBag +import org.partiql.ast.v1.expr.ExprCall +import org.partiql.ast.v1.expr.ExprPath +import org.partiql.ast.v1.expr.ExprQuerySet +import org.partiql.ast.v1.expr.PathStep +import org.partiql.ast.v1.expr.Scope +import org.partiql.ast.v1.expr.SessionAttribute +import org.partiql.ast.v1.expr.TrimSpec +import org.partiql.ast.v1.expr.WindowFunction +import org.partiql.ast.v1.graph.GraphDirection +import org.partiql.ast.v1.graph.GraphLabel +import org.partiql.ast.v1.graph.GraphPart +import org.partiql.ast.v1.graph.GraphPattern +import org.partiql.ast.v1.graph.GraphQuantifier +import org.partiql.ast.v1.graph.GraphRestrictor +import org.partiql.ast.v1.graph.GraphSelector +import org.partiql.parser.PartiQLLexerException +import org.partiql.parser.PartiQLParserException +import org.partiql.parser.PartiQLSyntaxException +import org.partiql.parser.SourceLocation +import org.partiql.parser.SourceLocations +import org.partiql.parser.V1PartiQLParser +import org.partiql.parser.internal.antlr.PartiQLParser +import org.partiql.parser.internal.antlr.PartiQLParserBaseVisitor +import org.partiql.parser.internal.util.DateTimeUtils +import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.boolValue +import org.partiql.value.dateValue +import org.partiql.value.datetime.DateTimeException +import org.partiql.value.datetime.DateTimeValue +import org.partiql.value.decimalValue +import org.partiql.value.int32Value +import org.partiql.value.int64Value +import org.partiql.value.intValue +import org.partiql.value.missingValue +import org.partiql.value.nullValue +import org.partiql.value.stringValue +import org.partiql.value.timeValue +import org.partiql.value.timestampValue +import java.math.BigDecimal +import java.math.BigInteger +import java.math.MathContext +import java.math.RoundingMode +import java.nio.channels.ClosedByInterruptException +import java.nio.charset.StandardCharsets +import java.time.LocalDate +import java.time.format.DateTimeFormatter +import java.time.format.DateTimeParseException +import org.partiql.parser.internal.antlr.PartiQLParser as GeneratedParser +import org.partiql.parser.internal.antlr.PartiQLTokens as GeneratedLexer + +/** + * ANTLR Based Implementation of a PartiQLParser + * + * SLL Prediction Mode + * ------------------- + * The [PredictionMode.SLL] mode uses the [BailErrorStrategy]. The [GeneratedParser], upon seeing a syntax error, + * will throw a [ParseCancellationException] due to the [GeneratedParser.getErrorHandler] + * being a [BailErrorStrategy]. The purpose of this is to throw syntax errors as quickly as possible once encountered. + * As noted by the [PredictionMode.SLL] documentation, to guarantee results, it is useful to follow up a failed parse + * by parsing with [PredictionMode.LL]. See the JavaDocs for [PredictionMode.SLL] and [BailErrorStrategy] for more. + * + * LL Prediction Mode + * ------------------ + * The [PredictionMode.LL] mode is capable of parsing all valid inputs for a grammar, + * but is slower than [PredictionMode.SLL]. Upon seeing a syntax error, this parser throws a [PartiQLParserException]. + */ +internal class V1PartiQLParserDefault : V1PartiQLParser { + + @Throws(PartiQLSyntaxException::class, InterruptedException::class) + override fun parse(source: String): V1PartiQLParser.Result { + try { + return V1PartiQLParserDefault.parse(source) + } catch (throwable: Throwable) { + throw PartiQLSyntaxException.wrap(throwable) + } + } + + companion object { + + /** + * To reduce latency costs, the [V1PartiQLParserDefault] attempts to use [PredictionMode.SLL] and falls back to + * [PredictionMode.LL] if a [ParseCancellationException] is thrown by the [BailErrorStrategy]. + */ + private fun parse(source: String): V1PartiQLParser.Result = try { + parse(source, PredictionMode.SLL) + } catch (ex: ParseCancellationException) { + parse(source, PredictionMode.LL) + } + + /** + * Parses an input string [source] using the given prediction mode. + */ + private fun parse(source: String, mode: PredictionMode): V1PartiQLParser.Result { + val tokens = createTokenStream(source) + val parser = InterruptibleParser(tokens) + parser.reset() + parser.removeErrorListeners() + parser.interpreter.predictionMode = mode + when (mode) { + PredictionMode.SLL -> parser.errorHandler = BailErrorStrategy() + PredictionMode.LL -> parser.addErrorListener(ParseErrorListener()) + else -> throw IllegalArgumentException("Unsupported parser mode: $mode") + } + val tree = parser.root() + return Visitor.translate(source, tokens, tree) + } + + private fun createTokenStream(source: String): CountingTokenStream { + val queryStream = source.byteInputStream(StandardCharsets.UTF_8) + val inputStream = try { + CharStreams.fromStream(queryStream) + } catch (ex: ClosedByInterruptException) { + throw InterruptedException() + } + val handler = TokenizeErrorListener() + val lexer = GeneratedLexer(inputStream) + lexer.removeErrorListeners() + lexer.addErrorListener(handler) + return CountingTokenStream(lexer) + } + } + + /** + * Catches Lexical errors (unidentified tokens) and throws a [PartiQLParserException] + */ + private class TokenizeErrorListener : BaseErrorListener() { + @Throws(PartiQLParserException::class) + override fun syntaxError( + recognizer: Recognizer<*, *>?, + offendingSymbol: Any?, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException?, + ) { + if (offendingSymbol is Token) { + val token = offendingSymbol.text + val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) + throw PartiQLLexerException( + token = token, + tokenType = tokenType, + message = msg, + cause = e, + location = SourceLocation( + line = line, + offset = charPositionInLine + 1, + length = token.length, + lengthLegacy = token.length, + ), + ) + } else { + throw IllegalArgumentException("Offending symbol is not a Token.") + } + } + } + + /** + * Catches Parser errors (malformed syntax) and throws a [PartiQLParserException] + */ + private class ParseErrorListener : BaseErrorListener() { + + private val rules = GeneratedParser.ruleNames.asList() + + @Throws(PartiQLParserException::class) + override fun syntaxError( + recognizer: Recognizer<*, *>?, + offendingSymbol: Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException?, + ) { + if (offendingSymbol is Token) { + val rule = e?.ctx?.toString(rules) ?: "UNKNOWN" + val token = offendingSymbol.text + val tokenType = GeneratedParser.VOCABULARY.getSymbolicName(offendingSymbol.type) + throw PartiQLParserException( + rule = rule, + token = token, + tokenType = tokenType, + message = msg, + cause = e, + location = SourceLocation( + line = line, + offset = charPositionInLine + 1, + length = msg.length, + lengthLegacy = offendingSymbol.text.length, + ), + ) + } else { + throw IllegalArgumentException("Offending symbol is not a Token.") + } + } + } + + /** + * A wrapped [GeneratedParser] to allow thread interruption during parse. + */ + internal class InterruptibleParser(input: TokenStream) : GeneratedParser(input) { + override fun enterRule(localctx: ParserRuleContext?, state: Int, ruleIndex: Int) { + if (Thread.interrupted()) { + throw InterruptedException() + } + super.enterRule(localctx, state, ruleIndex) + } + } + + /** + * This token stream creates [parameterIndexes], which is a map, where the keys represent the + * indexes of all [GeneratedLexer.QUESTION_MARK]'s and the values represent their relative index amongst all other + * [GeneratedLexer.QUESTION_MARK]'s. + */ + internal open class CountingTokenStream(tokenSource: TokenSource) : CommonTokenStream(tokenSource) { + // TODO: Research use-case of parameters and implementation -- see https://github.com/partiql/partiql-docs/issues/23 + val parameterIndexes = mutableMapOf() + private var parametersFound = 0 + override fun LT(k: Int): Token? { + val token = super.LT(k) + token?.let { + if (it.type == GeneratedLexer.QUESTION_MARK && parameterIndexes.containsKey(token.tokenIndex).not()) { + parameterIndexes[token.tokenIndex] = ++parametersFound + } + } + return token + } + } + + /** + * Translate an ANTLR ParseTree to a PartiQL + */ + @OptIn(PartiQLValueExperimental::class) + private class Visitor( + private val tokens: CommonTokenStream, + private val locations: SourceLocations.Mutable, + private val parameters: Map = mapOf(), + ) : PartiQLParserBaseVisitor() { + + companion object { + + private val rules = GeneratedParser.ruleNames.asList() + + /** + * Expose an (internal) friendly entry point into the traversal; mostly for keeping mutable state contained. + */ + fun translate( + source: String, + tokens: CountingTokenStream, + tree: GeneratedParser.RootContext, + ): V1PartiQLParser.Result { + val locations = SourceLocations.Mutable() + val visitor = Visitor(tokens, locations, tokens.parameterIndexes) + val root = visitor.visitAs(tree) as Statement + return V1PartiQLParser.Result( + source = source, + root = root, + locations = locations.toMap(), + ) + } + + fun error( + ctx: ParserRuleContext, + message: String, + cause: Throwable? = null, + ) = PartiQLParserException( + rule = ctx.toStringTree(rules), + token = ctx.start.text, + tokenType = GeneratedParser.VOCABULARY.getSymbolicName(ctx.start.type), + message = message, + cause = cause, + location = SourceLocation( + line = ctx.start.line, + offset = ctx.start.charPositionInLine + 1, + length = ctx.stop.stopIndex - ctx.start.startIndex, + lengthLegacy = ctx.start.text.length, + ), + ) + + fun error( + token: Token, + message: String, + cause: Throwable? = null, + ) = PartiQLLexerException( + token = token.text, + tokenType = GeneratedParser.VOCABULARY.getSymbolicName(token.type), + message = message, + cause = cause, + location = SourceLocation( + line = token.line, + offset = token.charPositionInLine + 1, + length = token.stopIndex - token.startIndex, + lengthLegacy = token.text.length, + ), + ) + + internal val DATE_PATTERN_REGEX = Regex("\\d\\d\\d\\d-\\d\\d-\\d\\d") + + internal val GENERIC_TIME_REGEX = Regex("\\d\\d:\\d\\d:\\d\\d(\\.\\d*)?([+|-]\\d\\d:\\d\\d)?") + } + + /** + * Each visit attaches source locations from the given parse tree node; constructs nodes via the factory. + */ + private inline fun translate(ctx: ParserRuleContext, block: () -> T): T { + val node = block() + if (ctx.start != null) { + locations[node.tag] = SourceLocation( + line = ctx.start.line, + offset = ctx.start.charPositionInLine + 1, + length = (ctx.stop?.stopIndex ?: ctx.start.stopIndex) - ctx.start.startIndex + 1, + lengthLegacy = ctx.start.text.length, // LEGACY LENGTH + ) + } + return node + } + + /** + * + * TOP LEVEL + * + */ + + override fun visitQueryDql(ctx: GeneratedParser.QueryDqlContext): AstNode = visitDql(ctx.dql()) + + override fun visitQueryDml(ctx: GeneratedParser.QueryDmlContext): AstNode = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitRoot(ctx: GeneratedParser.RootContext) = translate(ctx) { + when (ctx.EXPLAIN()) { + null -> visit(ctx.statement()) as Statement + else -> { + var type: String? = null + var format: String? = null + ctx.explainOption().forEach { option -> + val parameter = try { + ExplainParameters.valueOf(option.param.text.uppercase()) + } catch (ex: java.lang.IllegalArgumentException) { + throw error(option.param, "Unknown EXPLAIN parameter.", ex) + } + when (parameter) { + ExplainParameters.TYPE -> { + type = parameter.getCompliantString(type, option.value) + } + ExplainParameters.FORMAT -> { + format = parameter.getCompliantString(format, option.value) + } + } + } + explain( + // TODO get rid of usage of PartiQLValue https://github.com/partiql/partiql-lang-kotlin/issues/1589 + options = mapOf( + "type" to stringValue(type), + "format" to stringValue(format) + ), + statement = visit(ctx.statement()) as Statement, + ) + } + } + } + + /** + * + * COMMON USAGES + * + */ + + override fun visitAsIdent(ctx: GeneratedParser.AsIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitAtIdent(ctx: GeneratedParser.AtIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitByIdent(ctx: GeneratedParser.ByIdentContext) = visitSymbolPrimitive(ctx.symbolPrimitive()) + + private fun visitSymbolPrimitive(ctx: GeneratedParser.SymbolPrimitiveContext): Identifier = + when (ctx) { + is GeneratedParser.IdentifierQuotedContext -> visitIdentifierQuoted(ctx) + is GeneratedParser.IdentifierUnquotedContext -> visitIdentifierUnquoted(ctx) + else -> throw error(ctx, "Invalid symbol reference.") + } + + override fun visitIdentifierQuoted(ctx: GeneratedParser.IdentifierQuotedContext): Identifier = translate(ctx) { + identifier( + ctx.IDENTIFIER_QUOTED().getStringValue(), + true + ) + } + + override fun visitIdentifierUnquoted(ctx: GeneratedParser.IdentifierUnquotedContext): Identifier = translate(ctx) { + identifier( + ctx.text, + false + ) + } + + override fun visitQualifiedName(ctx: GeneratedParser.QualifiedNameContext) = translate(ctx) { + val qualifier = ctx.qualifier.map { visitSymbolPrimitive(it) } + val name = identifierChain(visitSymbolPrimitive(ctx.name), null) + if (qualifier.isEmpty()) { + name + } else { + qualifier.reversed().fold(name) { acc, id -> + identifierChain(root = id, next = acc) + } + } + } + + /** + * + * DATA DEFINITION LANGUAGE (DDL) -- deleted in v1; will be added before final v1 release + * + */ + +// override fun visitQueryDdl(ctx: GeneratedParser.QueryDdlContext): AstNode = visitDdl(ctx.ddl()) +// +// override fun visitDropTable(ctx: GeneratedParser.DropTableContext) = translate(ctx) { +// val table = visitQualifiedName(ctx.qualifiedName()) +// statementDDLDropTable(table) +// } +// +// override fun visitDropIndex(ctx: GeneratedParser.DropIndexContext) = translate(ctx) { +// val table = visitSymbolPrimitive(ctx.on) +// val index = visitSymbolPrimitive(ctx.target) +// statementDDLDropIndex(index, table) +// } +// +// override fun visitCreateTable(ctx: GeneratedParser.CreateTableContext) = translate(ctx) { +// val table = visitQualifiedName(ctx.qualifiedName()) +// val definition = ctx.tableDef()?.let { visitTableDef(it) } +// statementDDLCreateTable(table, definition) +// } +// +// override fun visitCreateIndex(ctx: GeneratedParser.CreateIndexContext) = translate(ctx) { +// // TODO add index name to ANTLR grammar +// val name: Identifier? = null +// val table = visitSymbolPrimitive(ctx.symbolPrimitive()) +// val fields = ctx.pathSimple().map { path -> visitPathSimple(path) } +// statementDDLCreateIndex(name, table, fields) +// } +// +// override fun visitTableDef(ctx: GeneratedParser.TableDefContext) = translate(ctx) { +// // Column Definitions are the only thing we currently allow as table definition parts +// val columns = ctx.tableDefPart().filterIsInstance().map { +// visitColumnDeclaration(it) +// } +// tableDefinition(columns) +// } +// +// override fun visitColumnDeclaration(ctx: GeneratedParser.ColumnDeclarationContext) = translate(ctx) { +// val name = visitSymbolPrimitive(ctx.columnName().symbolPrimitive()).symbol +// val type = visit(ctx.type()) as Type +// val constraints = ctx.columnConstraint().map { +// visitColumnConstraint(it) +// } +// tableDefinitionColumn(name, type, constraints) +// } +// +// override fun visitColumnConstraint(ctx: GeneratedParser.ColumnConstraintContext) = translate(ctx) { +// val identifier = ctx.columnConstraintName()?.let { symbolToString(it.symbolPrimitive()) } +// val body = visit(ctx.columnConstraintDef()) as TableDefinition.Column.Constraint.Body +// tableDefinitionColumnConstraint(identifier, body) +// } +// +// override fun visitColConstrNotNull(ctx: GeneratedParser.ColConstrNotNullContext) = translate(ctx) { +// tableDefinitionColumnConstraintBodyNotNull() +// } +// +// override fun visitColConstrNull(ctx: GeneratedParser.ColConstrNullContext) = translate(ctx) { +// tableDefinitionColumnConstraintBodyNullable() +// } + + /** + * + * EXECUTE + * + */ + + override fun visitQueryExec(ctx: GeneratedParser.QueryExecContext) = translate(ctx) { + throw error(ctx, "EXEC no longer supported in the default PartiQLParser.") + } + + /** + * TODO EXEC accepts an `expr` as the procedure name so we have to unpack the string. + * - https://github.com/partiql/partiql-lang-kotlin/issues/707 + */ + override fun visitExecCommand(ctx: GeneratedParser.ExecCommandContext) = translate(ctx) { + throw error(ctx, "EXEC no longer supported in the default PartiQLParser.") + } + + /** + * + * DATA MANIPULATION LANGUAGE (DML) + * + */ + + /** + * The PartiQL grammars allows for multiple DML commands in one UPDATE statement. + * This function unwraps DML commands to the more limited DML.BatchLegacy.Op commands. + */ + override fun visitDmlBaseWrapper(ctx: GeneratedParser.DmlBaseWrapperContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDmlDelete(ctx: GeneratedParser.DmlDeleteContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDmlInsertReturning(ctx: GeneratedParser.DmlInsertReturningContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDmlBase(ctx: GeneratedParser.DmlBaseContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDmlBaseCommand(ctx: GeneratedParser.DmlBaseCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitRemoveCommand(ctx: GeneratedParser.RemoveCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDeleteCommand(ctx: GeneratedParser.DeleteCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + /** + * Legacy INSERT with RETURNING clause is not represented in the AST as this grammar .. + * .. only exists for backwards compatibility. The RETURNING clause is ignored. + * + * TODO remove insertCommandReturning grammar rule + * - https://github.com/partiql/partiql-lang-kotlin/issues/698 + * - https://github.com/partiql/partiql-lang-kotlin/issues/708 + */ + override fun visitInsertCommandReturning(ctx: GeneratedParser.InsertCommandReturningContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitInsertStatementLegacy(ctx: GeneratedParser.InsertStatementLegacyContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitInsertStatement(ctx: GeneratedParser.InsertStatementContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitReplaceCommand(ctx: GeneratedParser.ReplaceCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitUpsertCommand(ctx: GeneratedParser.UpsertCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitReturningClause(ctx: GeneratedParser.ReturningClauseContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitReturningColumn(ctx: GeneratedParser.ReturningColumnContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitOnConflict(ctx: GeneratedParser.OnConflictContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + /** + * TODO Remove this when we remove INSERT LEGACY as no other conflict actions are allowed in PartiQL.g4. + */ + override fun visitOnConflictLegacy(ctx: GeneratedParser.OnConflictLegacyContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitConflictTarget(ctx: GeneratedParser.ConflictTargetContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitConflictAction(ctx: GeneratedParser.ConflictActionContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDoReplace(ctx: GeneratedParser.DoReplaceContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitDoUpdate(ctx: GeneratedParser.DoUpdateContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + // "simple paths" used by previous DDL's CREATE INDEX + override fun visitPathSimple(ctx: GeneratedParser.PathSimpleContext) = translate(ctx) { + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") + } + + // "simple paths" used by previous DDL's CREATE INDEX + override fun visitPathSimpleLiteral(ctx: GeneratedParser.PathSimpleLiteralContext) = translate(ctx) { + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") + } + + // "simple paths" used by previous DDL's CREATE INDEX + override fun visitPathSimpleSymbol(ctx: GeneratedParser.PathSimpleSymbolContext) = translate(ctx) { + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") + } + + // "simple paths" used by previous DDL's CREATE INDEX + override fun visitPathSimpleDotSymbol(ctx: GeneratedParser.PathSimpleDotSymbolContext) = translate(ctx) { + throw error(ctx, "DDL no longer supported in the default PartiQLParser.") + } + + /** + * TODO current PartiQL.g4 grammar models a SET with no UPDATE target as valid DML command. + */ + override fun visitSetCommand(ctx: GeneratedParser.SetCommandContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + override fun visitSetAssignment(ctx: GeneratedParser.SetAssignmentContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + /** + * + * DATA QUERY LANGUAGE (DQL) + * + */ + + override fun visitDql(ctx: GeneratedParser.DqlContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + query(expr) + } + + override fun visitQueryBase(ctx: GeneratedParser.QueryBaseContext): AstNode = visit(ctx.exprSelect()) + + override fun visitSfwQuery(ctx: GeneratedParser.SfwQueryContext) = translate(ctx) { + val select = visit(ctx.select) as Select + val from = visitFromClause(ctx.from) + val exclude = visitOrNull(ctx.exclude) + val let = visitOrNull(ctx.let) + val where = visitOrNull(ctx.where) + val groupBy = ctx.group?.let { visitGroupClause(it) } + val having = visitOrNull(ctx.having?.arg) + val orderBy = ctx.order?.let { visitOrderByClause(it) } + val limit = visitOrNull(ctx.limit?.arg) + val offset = visitOrNull(ctx.offset?.arg) + exprQuerySet( + body = queryBodySFW( + select, exclude, from, let, where, groupBy, having + ), + orderBy = orderBy, + limit = limit, + offset = offset + ) + } + + /** + * + * SELECT & PROJECTIONS + * + */ + + override fun visitSelectAll(ctx: GeneratedParser.SelectAllContext) = translate(ctx) { + val quantifier = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectStar(quantifier) + } + + override fun visitSelectItems(ctx: GeneratedParser.SelectItemsContext) = translate(ctx) { + val items = visitOrEmpty(ctx.projectionItems().projectionItem()) + val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectList(items, setq) + } + + override fun visitSelectPivot(ctx: GeneratedParser.SelectPivotContext) = translate(ctx) { + val key = visitExpr(ctx.at) + val value = visitExpr(ctx.pivot) + selectPivot(key, value) + } + + override fun visitSelectValue(ctx: GeneratedParser.SelectValueContext) = translate(ctx) { + val constructor = visitExpr(ctx.expr()) + val setq = convertSetQuantifier(ctx.setQuantifierStrategy()) + selectValue(constructor, setq) + } + + override fun visitProjectionItem(ctx: GeneratedParser.ProjectionItemContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it) } + if (expr is ExprPath) { + convertPathToProjectionItem(ctx, expr, alias) + } else { + selectItemExpr(expr, alias) + } + } + + /** + * + * SIMPLE CLAUSES + * + */ + + override fun visitLimitClause(ctx: GeneratedParser.LimitClauseContext): Expr = visitAs(ctx.arg) + + override fun visitExpr(ctx: GeneratedParser.ExprContext): Expr { + if (Thread.interrupted()) { + throw InterruptedException() + } + return visitAs(ctx.exprBagOp()) + } + + override fun visitOffsetByClause(ctx: GeneratedParser.OffsetByClauseContext) = visitAs(ctx.arg) + + override fun visitWhereClause(ctx: GeneratedParser.WhereClauseContext) = visitExpr(ctx.arg) + + override fun visitWhereClauseSelect(ctx: GeneratedParser.WhereClauseSelectContext) = visitAs(ctx.arg) + + override fun visitHavingClause(ctx: GeneratedParser.HavingClauseContext) = visitAs(ctx.arg) + + /** + * + * LET CLAUSE + * + */ + + override fun visitLetClause(ctx: GeneratedParser.LetClauseContext) = translate(ctx) { + val bindings = visitOrEmpty(ctx.letBinding()) + Ast.let(bindings) + } + + override fun visitLetBinding(ctx: GeneratedParser.LetBindingContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + val alias = visitSymbolPrimitive(ctx.symbolPrimitive()) + letBinding(expr, alias) + } + + /** + * + * ORDER BY CLAUSE + * + */ + + override fun visitOrderByClause(ctx: GeneratedParser.OrderByClauseContext) = translate(ctx) { + val sorts = visitOrEmpty(ctx.orderSortSpec()) + orderBy(sorts) + } + + override fun visitOrderSortSpec(ctx: GeneratedParser.OrderSortSpecContext) = translate(ctx) { + val expr = visitAs(ctx.expr()) + val dir = when { + ctx.dir == null -> null + ctx.dir.type == GeneratedParser.ASC -> Order.ASC() + ctx.dir.type == GeneratedParser.DESC -> Order.DESC() + else -> throw error(ctx.dir, "Invalid ORDER BY direction; expected ASC or DESC") + } + val nulls = when { + ctx.nulls == null -> null + ctx.nulls.type == GeneratedParser.FIRST -> Nulls.FIRST() + ctx.nulls.type == GeneratedParser.LAST -> Nulls.LAST() + else -> throw error(ctx.nulls, "Invalid ORDER BY null ordering; expected FIRST or LAST") + } + sort(expr, dir, nulls) + } + + /** + * + * GROUP BY CLAUSE + * + */ + + override fun visitGroupClause(ctx: GeneratedParser.GroupClauseContext) = translate(ctx) { + val strategy = if (ctx.PARTIAL() != null) GroupByStrategy.PARTIAL() else GroupByStrategy.FULL() + val keys = visitOrEmpty(ctx.groupKey()) + val alias = ctx.groupAlias()?.symbolPrimitive()?.let { visitSymbolPrimitive(it) } + groupBy(strategy, keys, alias) + } + + override fun visitGroupKey(ctx: GeneratedParser.GroupKeyContext) = translate(ctx) { + val expr = visitAs(ctx.key) + val alias = ctx.symbolPrimitive()?.let { visitSymbolPrimitive(it) } + groupByKey(expr, alias) + } + + /** + * EXCLUDE CLAUSE + */ + override fun visitExcludeClause(ctx: GeneratedParser.ExcludeClauseContext) = translate(ctx) { + val excludeExprs = ctx.excludeExpr().map { expr -> + visitExcludeExpr(expr) + } + exclude(excludeExprs) + } + + override fun visitExcludeExpr(ctx: GeneratedParser.ExcludeExprContext) = translate(ctx) { + val rootId = visitSymbolPrimitive(ctx.symbolPrimitive()) + val root = exprVarRef(identifierChain(rootId, null), Scope.DEFAULT()) + val steps = visitOrEmpty(ctx.excludeExprSteps()) + excludePath(root, steps) + } + + override fun visitExcludeExprTupleAttr(ctx: GeneratedParser.ExcludeExprTupleAttrContext) = translate(ctx) { + val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) + excludeStepStructField(identifier) + } + + override fun visitExcludeExprCollectionIndex(ctx: GeneratedParser.ExcludeExprCollectionIndexContext) = + translate(ctx) { + val index = ctx.index.text.toInt() + excludeStepCollIndex(index) + } + + override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = + translate(ctx) { + val attr = ctx.attr.getStringValue() + val identifier = identifier(attr, true) + excludeStepStructField(identifier) + } + + override fun visitExcludeExprCollectionWildcard(ctx: GeneratedParser.ExcludeExprCollectionWildcardContext) = + translate(ctx) { + excludeStepCollWildcard() + } + + override fun visitExcludeExprTupleWildcard(ctx: GeneratedParser.ExcludeExprTupleWildcardContext) = + translate(ctx) { + excludeStepStructWildcard() + } + + /** + * + * BAG OPERATIONS + * + */ + override fun visitBagOp(ctx: GeneratedParser.BagOpContext) = translate(ctx) { + val setq = when { + ctx.ALL() != null -> SetQuantifier.ALL() + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() + else -> null + } + val op = when (ctx.op.type) { + GeneratedParser.UNION -> setOp(SetOpType.UNION(), setq) + GeneratedParser.INTERSECT -> setOp(SetOpType.INTERSECT(), setq) + GeneratedParser.EXCEPT -> setOp(SetOpType.EXCEPT(), setq) + else -> error("Unsupported bag op token ${ctx.op}") + } + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val outer = ctx.OUTER() != null + val orderBy = ctx.order?.let { visitOrderByClause(it) } + val limit = ctx.limit?.let { visitAs(it) } + val offset = ctx.offset?.let { visitAs(it) } + exprQuerySet( + queryBodySetOp( + op, + outer, + lhs, + rhs + ), + orderBy, + limit, + offset, + ) + } + + /** + * + * GRAPH PATTERN MANIPULATION LANGUAGE (GPML) + * + */ + + override fun visitGpmlPattern(ctx: GeneratedParser.GpmlPatternContext) = translate(ctx) { + val pattern = visitMatchPattern(ctx.matchPattern()) + val selector = visitOrNull(ctx.matchSelector()) + graphMatch(listOf(pattern), selector) + } + + override fun visitGpmlPatternList(ctx: GeneratedParser.GpmlPatternListContext) = translate(ctx) { + val patterns = ctx.matchPattern().map { pattern -> visitMatchPattern(pattern) } + val selector = visitOrNull(ctx.matchSelector()) + graphMatch(patterns, selector) + } + + override fun visitMatchPattern(ctx: GeneratedParser.MatchPatternContext) = translate(ctx) { + val parts = visitOrEmpty(ctx.graphPart()) + val restrictor = ctx.restrictor?.let { + when (ctx.restrictor.text.lowercase()) { + "trail" -> GraphRestrictor.TRAIL() + "acyclic" -> GraphRestrictor.ACYCLIC() + "simple" -> GraphRestrictor.SIMPLE() + else -> throw error(ctx.restrictor, "Unrecognized pattern restrictor") + } + } + val variable = visitOrNull(ctx.variable)?.symbol + graphPattern(restrictor, null, variable, null, parts) + } + + override fun visitPatternPathVariable(ctx: GeneratedParser.PatternPathVariableContext) = + visitSymbolPrimitive(ctx.symbolPrimitive()) + + override fun visitSelectorBasic(ctx: GeneratedParser.SelectorBasicContext) = translate(ctx) { + when (ctx.mod.type) { + GeneratedParser.ANY -> graphSelectorAnyShortest() + GeneratedParser.ALL -> graphSelectorAllShortest() + else -> throw error(ctx, "Unsupported match selector.") + } + } + + override fun visitSelectorAny(ctx: GeneratedParser.SelectorAnyContext) = translate(ctx) { + when (ctx.k) { + null -> graphSelectorAny() + else -> graphSelectorAnyK(ctx.k.text.toLong()) + } + } + + override fun visitSelectorShortest(ctx: GeneratedParser.SelectorShortestContext) = translate(ctx) { + val k = ctx.k.text.toLong() + when (ctx.GROUP()) { + null -> graphSelectorShortestK(k) + else -> graphSelectorShortestKGroup(k) + } + } + + override fun visitLabelSpecOr(ctx: GeneratedParser.LabelSpecOrContext) = translate(ctx) { + val lhs = visit(ctx.labelSpec()) as GraphLabel + val rhs = visit(ctx.labelTerm()) as GraphLabel + graphLabelDisj(lhs, rhs) + } + + override fun visitLabelTermAnd(ctx: GeneratedParser.LabelTermAndContext) = translate(ctx) { + val lhs = visit(ctx.labelTerm()) as GraphLabel + val rhs = visit(ctx.labelFactor()) as GraphLabel + graphLabelConj(lhs, rhs) + } + + override fun visitLabelFactorNot(ctx: GeneratedParser.LabelFactorNotContext) = translate(ctx) { + val arg = visit(ctx.labelPrimary()) as GraphLabel + graphLabelNegation(arg) + } + + override fun visitLabelPrimaryName(ctx: GeneratedParser.LabelPrimaryNameContext) = translate(ctx) { + val x = visitSymbolPrimitive(ctx.symbolPrimitive()) + graphLabelName(x.symbol) + } + + override fun visitLabelPrimaryWild(ctx: GeneratedParser.LabelPrimaryWildContext) = translate(ctx) { + graphLabelWildcard() + } + + override fun visitLabelPrimaryParen(ctx: GeneratedParser.LabelPrimaryParenContext) = + visit(ctx.labelSpec()) as GraphLabel + + override fun visitPattern(ctx: GeneratedParser.PatternContext) = translate(ctx) { + val restrictor = visitRestrictor(ctx.restrictor) + val variable = visitOrNull(ctx.variable)?.symbol + val prefilter = ctx.where?.let { visitExpr(it.expr()) } + val quantifier = ctx.quantifier?.let { visitPatternQuantifier(it) } + val parts = visitOrEmpty(ctx.graphPart()) + graphPattern(restrictor, prefilter, variable, quantifier, parts) + } + + override fun visitEdgeAbbreviated(ctx: GeneratedParser.EdgeAbbreviatedContext) = translate(ctx) { + val direction = visitEdge(ctx.edgeAbbrev()) + val quantifier = visitOrNull(ctx.quantifier) + graphMatchEdge(direction, quantifier, null, null, null) + } + + private fun GraphPart.Edge.copy( + direction: GraphDirection? = null, + quantifier: GraphQuantifier? = null, + prefilter: Expr? = null, + variable: String? = null, + label: GraphLabel? = null, + ) = graphMatchEdge( + direction = direction ?: this.direction, + quantifier = quantifier ?: this.quantifier, + prefilter = prefilter ?: this.prefilter, + variable = variable ?: this.variable, + label = label ?: this.label, + ) + + override fun visitEdgeWithSpec(ctx: GeneratedParser.EdgeWithSpecContext) = translate(ctx) { + val quantifier = visitOrNull(ctx.quantifier) + val edge = visitOrNull(ctx.edgeWSpec()) + edge!!.copy(quantifier = quantifier) + } + + override fun visitEdgeSpec(ctx: GeneratedParser.EdgeSpecContext) = translate(ctx) { + val placeholderDirection = GraphDirection.RIGHT() + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } + val label = visitOrNull(ctx.labelSpec()) + graphMatchEdge(placeholderDirection, null, prefilter, variable, label) + } + + override fun visitEdgeSpecLeft(ctx: GeneratedParser.EdgeSpecLeftContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.LEFT()) + } + + override fun visitEdgeSpecRight(ctx: GeneratedParser.EdgeSpecRightContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.RIGHT()) + } + + override fun visitEdgeSpecBidirectional(ctx: GeneratedParser.EdgeSpecBidirectionalContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.LEFT_OR_RIGHT()) + } + + override fun visitEdgeSpecUndirectedBidirectional(ctx: GeneratedParser.EdgeSpecUndirectedBidirectionalContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.LEFT_UNDIRECTED_OR_RIGHT()) + } + + override fun visitEdgeSpecUndirected(ctx: GeneratedParser.EdgeSpecUndirectedContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.UNDIRECTED()) + } + + override fun visitEdgeSpecUndirectedLeft(ctx: GeneratedParser.EdgeSpecUndirectedLeftContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.LEFT_OR_UNDIRECTED()) + } + + override fun visitEdgeSpecUndirectedRight(ctx: GeneratedParser.EdgeSpecUndirectedRightContext): AstNode { + val edge = visitEdgeSpec(ctx.edgeSpec()) + return edge.copy(direction = GraphDirection.UNDIRECTED_OR_RIGHT()) + } + + private fun visitEdge(ctx: GeneratedParser.EdgeAbbrevContext): GraphDirection = when { + ctx.TILDE() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.UNDIRECTED_OR_RIGHT() + ctx.TILDE() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT_OR_UNDIRECTED() + ctx.TILDE() != null -> GraphDirection.UNDIRECTED() + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.LEFT_OR_RIGHT() + ctx.MINUS() != null && ctx.ANGLE_LEFT() != null -> GraphDirection.LEFT() + ctx.MINUS() != null && ctx.ANGLE_RIGHT() != null -> GraphDirection.RIGHT() + ctx.MINUS() != null -> GraphDirection.LEFT_UNDIRECTED_OR_RIGHT() + else -> throw error(ctx, "Unsupported edge type") + } + + override fun visitGraphPart(ctx: GeneratedParser.GraphPartContext): GraphPart { + val part = super.visitGraphPart(ctx) + if (part is GraphPattern) { + return translate(ctx) { graphMatchPattern(part) } + } + return part as GraphPart + } + + override fun visitPatternQuantifier(ctx: GeneratedParser.PatternQuantifierContext) = translate(ctx) { + when { + ctx.quant == null -> graphQuantifier(ctx.lower.text.toLong(), ctx.upper?.text?.toLong()) + ctx.quant.type == GeneratedParser.PLUS -> graphQuantifier(1L, null) + ctx.quant.type == GeneratedParser.ASTERISK -> graphQuantifier(0L, null) + else -> throw error(ctx, "Unsupported quantifier") + } + } + + override fun visitNode(ctx: GeneratedParser.NodeContext) = translate(ctx) { + val variable = visitOrNull(ctx.symbolPrimitive())?.symbol + val prefilter = ctx.whereClause()?.let { visitExpr(it.expr()) } + val label = visitOrNull(ctx.labelSpec()) + graphMatchNode(prefilter, variable, label) + } + + private fun visitRestrictor(ctx: GeneratedParser.PatternRestrictorContext?): GraphRestrictor? { + if (ctx == null) return null + return when (ctx.restrictor.text.lowercase()) { + "trail" -> GraphRestrictor.TRAIL() + "acyclic" -> GraphRestrictor.ACYCLIC() + "simple" -> GraphRestrictor.SIMPLE() + else -> throw error(ctx, "Unrecognized pattern restrictor") + } + } + + /** + * + * TABLE REFERENCES & JOINS & FROM CLAUSE + * + */ + override fun visitFromClause(ctx: GeneratedParser.FromClauseContext): From = translate(ctx) { + val tableRefs = visitOrEmpty(ctx.tableReference()) + from(tableRefs) + } + + override fun visitTableBaseRefSymbol(ctx: PartiQLParser.TableBaseRefSymbolContext): FromTableRef = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = visitSymbolPrimitive(ctx.symbolPrimitive()) + fromExpr(expr, FromType.SCAN(), asAlias, null) + } + + override fun visitTableBaseRefClauses(ctx: PartiQLParser.TableBaseRefClausesContext): FromTableRef = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + fromExpr(expr, FromType.SCAN(), asAlias, atAlias) + } + + override fun visitTableBaseRefMatch(ctx: PartiQLParser.TableBaseRefMatchContext): FromTableRef = translate(ctx) { + val expr = visitAs(ctx.source) + val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + fromExpr(expr, FromType.SCAN(), asAlias, atAlias) + } + + override fun visitTableUnpivot(ctx: PartiQLParser.TableUnpivotContext): FromTableRef = translate(ctx) { + val expr = visitAs(ctx.expr()) + val asAlias = ctx.asIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + val atAlias = ctx.atIdent()?.let { visitSymbolPrimitive(it.symbolPrimitive()) } + fromExpr(expr, FromType.UNPIVOT(), asAlias, atAlias) + } + + override fun visitTableWrapped(ctx: PartiQLParser.TableWrappedContext): FromTableRef = translate(ctx) { + visitAs(ctx.tableReference()) + } + + override fun visitTableCrossJoin(ctx: PartiQLParser.TableCrossJoinContext): FromTableRef = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + fromJoin(lhs, rhs, JoinType.CROSS(), null) + } + + override fun visitTableQualifiedJoin(ctx: PartiQLParser.TableQualifiedJoinContext): FromTableRef = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs) + val type = convertJoinType(ctx.joinType()) + val condition = ctx.joinSpec()?.let { visitExpr(it.expr()) } + fromJoin(lhs, rhs, type, condition) + } + + private fun convertJoinType(ctx: GeneratedParser.JoinTypeContext?): JoinType? { + if (ctx == null) return null + return when (ctx.mod.type) { + GeneratedParser.INNER -> JoinType.INNER() + GeneratedParser.LEFT -> when (ctx.OUTER()) { + null -> JoinType.LEFT() + else -> JoinType.LEFT_OUTER() + } + GeneratedParser.RIGHT -> when (ctx.OUTER()) { + null -> JoinType.RIGHT() + else -> JoinType.RIGHT_OUTER() + } + GeneratedParser.FULL -> when (ctx.OUTER()) { + null -> JoinType.FULL() + else -> JoinType.FULL_OUTER() + } + GeneratedParser.OUTER -> { + // TODO https://github.com/partiql/partiql-spec/issues/41 + // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1013 + JoinType.FULL_OUTER() + } + else -> null + } + } + + /** + * TODO Remove as/at/by aliases from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleExplicit(ctx: GeneratedParser.FromClauseSimpleExplicitContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + /** + * TODO Remove fromClauseSimple rule from DELETE command grammar in PartiQL.g4 + */ + override fun visitFromClauseSimpleImplicit(ctx: GeneratedParser.FromClauseSimpleImplicitContext) = translate(ctx) { + throw error(ctx, "DML no longer supported in the default PartiQLParser.") + } + + /** + * SIMPLE EXPRESSIONS + */ + + override fun visitOr(ctx: GeneratedParser.OrContext) = translate(ctx) { + val l = visit(ctx.lhs) as Expr + val r = visit(ctx.rhs) as Expr + exprOr(l, r) + } + + override fun visitAnd(ctx: GeneratedParser.AndContext) = translate(ctx) { + val l = visit(ctx.lhs) as Expr + val r = visit(ctx.rhs) as Expr + exprAnd(l, r) + } + + override fun visitNot(ctx: GeneratedParser.NotContext) = translate(ctx) { + val expr = visit(ctx.exprNot()) as Expr + exprNot(expr) + } + + private fun checkForInvalidTokens(op: ParserRuleContext) { + val start = op.start.tokenIndex + val stop = op.stop.tokenIndex + val tokensInRange = tokens.get(start, stop) + if (tokensInRange.any { it.channel == GeneratedLexer.HIDDEN }) { + throw error(op, "Invalid whitespace or comment in operator") + } + } + + private fun convertToOperator(value: ParserRuleContext, op: ParserRuleContext): Expr { + checkForInvalidTokens(op) + return convertToOperator(value, op.text) + } + + private fun convertToOperator(value: ParserRuleContext, op: String): Expr { + val v = visit(value) as Expr + return exprOperator(op, null, v) + } + + private fun convertToOperator(lhs: ParserRuleContext, rhs: ParserRuleContext, op: ParserRuleContext): Expr { + checkForInvalidTokens(op) + return convertToOperator(lhs, rhs, op.text) + } + + private fun convertToOperator(lhs: ParserRuleContext, rhs: ParserRuleContext, op: String): Expr { + val l = visit(lhs) as Expr + val r = visit(rhs) as Expr + return exprOperator(op, l, r) + } + + override fun visitMathOp00(ctx: GeneratedParser.MathOp00Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertToOperator(ctx.lhs, ctx.rhs, ctx.op) + } + + override fun visitMathOp01(ctx: GeneratedParser.MathOp01Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertToOperator(ctx.rhs, ctx.op) + } + + override fun visitMathOp02(ctx: GeneratedParser.MathOp02Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertToOperator(ctx.lhs, ctx.rhs, ctx.op.text) + } + + override fun visitMathOp03(ctx: GeneratedParser.MathOp03Context) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertToOperator(ctx.lhs, ctx.rhs, ctx.op.text) + } + + override fun visitValueExpr(ctx: GeneratedParser.ValueExprContext) = translate(ctx) { + if (ctx.parent != null) return@translate visit(ctx.parent) + convertToOperator(ctx.rhs, ctx.sign.text) + } + + /** + * + * PREDICATES + * + */ + + override fun visitPredicateComparison(ctx: GeneratedParser.PredicateComparisonContext) = translate(ctx) { + convertToOperator(ctx.lhs, ctx.rhs, ctx.op) + } + + /** + * TODO Fix the IN collection grammar, also label alternative forms + * - https://github.com/partiql/partiql-lang-kotlin/issues/1115 + * - https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitPredicateIn(ctx: GeneratedParser.PredicateInContext) = translate(ctx) { + val lhs = visitAs(ctx.lhs) + val rhs = visitAs(ctx.rhs ?: ctx.expr()).let { + // Wrap rhs in an array unless it's a query or already a collection + if (it is ExprQuerySet || it is ExprArray || it is ExprBag || ctx.PAREN_LEFT() == null) { + it + } else { + // IN ( expr ) + exprArray(listOf(it)) + } + } + val not = ctx.NOT() != null + exprInCollection(lhs, rhs, not) + } + + override fun visitPredicateIs(ctx: GeneratedParser.PredicateIsContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val type = visitAs(ctx.type()) + val not = ctx.NOT() != null + exprIsType(value, type, not) + } + + override fun visitPredicateBetween(ctx: GeneratedParser.PredicateBetweenContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val lower = visitAs(ctx.lower) + val upper = visitAs(ctx.upper) + val not = ctx.NOT() != null + exprBetween(value, lower, upper, not) + } + + override fun visitPredicateLike(ctx: GeneratedParser.PredicateLikeContext) = translate(ctx) { + val value = visitAs(ctx.lhs) + val pattern = visitAs(ctx.rhs) + val escape = visitOrNull(ctx.escape) + val not = ctx.NOT() != null + exprLike(value, pattern, escape, not) + } + + /** + * + * PRIMARY EXPRESSIONS + * + */ + + override fun visitExprTermWrappedQuery(ctx: GeneratedParser.ExprTermWrappedQueryContext): AstNode = + visit(ctx.expr()) + + override fun visitVariableIdentifier(ctx: GeneratedParser.VariableIdentifierContext) = translate(ctx) { + val symbol = ctx.ident.getStringValue() + val isDelimited = when (ctx.ident.type) { + GeneratedParser.IDENTIFIER -> false + else -> true + } + val scope = when (ctx.qualifier) { + null -> Scope.DEFAULT() + else -> Scope.LOCAL() + } + exprVarRef( + identifierChain( + root = identifier(symbol, isDelimited), + next = null + ), + scope + ) + } + + override fun visitVariableKeyword(ctx: GeneratedParser.VariableKeywordContext) = translate(ctx) { + val symbol = ctx.key.text + val isDelimited = false + val scope = when (ctx.qualifier) { + null -> Scope.DEFAULT() + else -> Scope.LOCAL() + } + exprVarRef( + identifierChain( + root = identifier(symbol, isDelimited), + next = null + ), + scope + ) + } + + override fun visitParameter(ctx: GeneratedParser.ParameterContext) = translate(ctx) { + val index = parameters[ctx.QUESTION_MARK().symbol.tokenIndex] ?: throw error( + ctx, "Unable to find index of parameter." + ) + exprParameter(index) + } + + override fun visitSequenceConstructor(ctx: GeneratedParser.SequenceConstructorContext) = translate(ctx) { + error("Sequence constructor not supported") + } + + private fun PathStep.copy(next: PathStep?) = when (this) { + is PathStep.Element -> exprPathStepElement(this.element, next) + is PathStep.Field -> exprPathStepField(this.field, next) + is PathStep.AllElements -> exprPathStepAllElements(next) + is PathStep.AllFields -> exprPathStepAllFields(next) + else -> error("Unsupported PathStep: $this") + } + + override fun visitExprPrimaryPath(ctx: GeneratedParser.ExprPrimaryPathContext) = translate(ctx) { + val base = visitAs(ctx.exprPrimary()) + val init: PathStep? = null + val steps = ctx.pathStep().reversed().fold(init) { acc, step -> + val stepExpr = visit(step) as PathStep + stepExpr.copy(acc) + } + exprPath(base, steps) + } + + override fun visitPathStepIndexExpr(ctx: GeneratedParser.PathStepIndexExprContext) = translate(ctx) { + val key = visitAs(ctx.key) + exprPathStepElement(key, null) + } + + override fun visitPathStepDotExpr(ctx: GeneratedParser.PathStepDotExprContext) = translate(ctx) { + val symbol = visitSymbolPrimitive(ctx.symbolPrimitive()) + exprPathStepField(symbol, null) + } + + override fun visitPathStepIndexAll(ctx: GeneratedParser.PathStepIndexAllContext) = translate(ctx) { + exprPathStepAllElements(null) + } + + override fun visitPathStepDotAll(ctx: GeneratedParser.PathStepDotAllContext) = translate(ctx) { + exprPathStepAllFields(null) + } + + override fun visitExprGraphMatchMany(ctx: GeneratedParser.ExprGraphMatchManyContext) = translate(ctx) { + val graph = visit(ctx.exprPrimary()) as Expr + val pattern = visitGpmlPatternList(ctx.gpmlPatternList()) + exprMatch(graph, pattern) + } + + override fun visitExprGraphMatchOne(ctx: GeneratedParser.ExprGraphMatchOneContext) = translate(ctx) { + val graph = visit(ctx.exprPrimary()) as Expr + val pattern = visitGpmlPattern(ctx.gpmlPattern()) + exprMatch(graph, pattern) + } + + override fun visitExprTermCurrentUser(ctx: GeneratedParser.ExprTermCurrentUserContext) = translate(ctx) { + exprSessionAttribute(SessionAttribute.CURRENT_USER()) + } + + override fun visitExprTermCurrentDate(ctx: GeneratedParser.ExprTermCurrentDateContext) = + translate(ctx) { + exprSessionAttribute(SessionAttribute.CURRENT_DATE()) + } + + /** + * + * FUNCTIONS + * + */ + + override fun visitNullIf(ctx: GeneratedParser.NullIfContext) = translate(ctx) { + val value = visitExpr(ctx.expr(0)) + val nullifier = visitExpr(ctx.expr(1)) + exprNullIf(value, nullifier) + } + + override fun visitCoalesce(ctx: GeneratedParser.CoalesceContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprCoalesce(expressions) + } + + override fun visitCaseExpr(ctx: GeneratedParser.CaseExprContext) = translate(ctx) { + val expr = ctx.case_?.let { visitExpr(it) } + val branches = ctx.whens.indices.map { i -> + // consider adding locations + val w = visitExpr(ctx.whens[i]) + val t = visitExpr(ctx.thens[i]) + exprCaseBranch(w, t) + } + val default = ctx.else_?.let { visitExpr(it) } + exprCase(expr, branches, default) + } + + override fun visitCast(ctx: GeneratedParser.CastContext) = translate(ctx) { + val expr = visitExpr(ctx.expr()) + val type = visitAs(ctx.type()) + exprCast(expr, type) + } + + override fun visitCanCast(ctx: GeneratedParser.CanCastContext) = translate(ctx) { + throw error(ctx, "CAN_CAST is no longer supported in the default PartiQLParser") + } + + override fun visitCanLosslessCast(ctx: GeneratedParser.CanLosslessCastContext) = translate(ctx) { + throw error(ctx, "CAN_LOSSLESS_CAST is no longer supported in the default PartiQLParser") + } + + override fun visitFunctionCall(ctx: GeneratedParser.FunctionCallContext) = translate(ctx) { + val args = visitOrEmpty(ctx.expr()) + when (val funcName = ctx.qualifiedName()) { + is GeneratedParser.QualifiedNameContext -> { + when (funcName.name.start.type) { + GeneratedParser.MOD -> exprOperator("%", args[0], args[1]) + GeneratedParser.CHARACTER_LENGTH, GeneratedParser.CHAR_LENGTH -> { + val path = ctx.qualifiedName().qualifier.map { visitSymbolPrimitive(it) } + val name = identifierChain(identifier("char_length", false), null) + if (path.isEmpty()) { + exprCall(name, args, null) // setq = null for scalar fn + } else { + val function = path.reversed().fold(name) { acc, id -> + identifierChain(root = id, next = acc) + } + exprCall(function, args, setq = null) + } + } + else -> visitNonReservedFunctionCall(ctx, args) + } + } + else -> visitNonReservedFunctionCall(ctx, args) + } + } + private fun visitNonReservedFunctionCall(ctx: GeneratedParser.FunctionCallContext, args: List): ExprCall { + val function = visitQualifiedName(ctx.qualifiedName()) + return exprCall(function, args, convertSetQuantifier(ctx.setQuantifierStrategy())) + } + + /** + * + * FUNCTIONS WITH SPECIAL FORMS + * + */ + + override fun visitDateFunction(ctx: GeneratedParser.DateFunctionContext) = translate(ctx) { + try { + DatetimeField.valueOf(ctx.dt.text) + } catch (ex: IllegalArgumentException) { + throw error(ctx.dt, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + } + val lhs = visitExpr(ctx.expr(0)) + val rhs = visitExpr(ctx.expr(1)) + // TODO change to not use PartiQLValue -- https://github.com/partiql/partiql-lang-kotlin/issues/1589 + val fieldLit = exprLit(stringValue(ctx.dt.text.uppercase())) + when { + ctx.DATE_ADD() != null -> exprCall(identifierChain(identifier("DATE_ADD", false), null), listOf(fieldLit, lhs, rhs), null) + ctx.DATE_DIFF() != null -> exprCall(identifierChain(identifier("DATE_DIFF", false), null), listOf(fieldLit, lhs, rhs), null) + else -> throw error(ctx, "Expected DATE_ADD or DATE_DIFF") + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitSubstring(ctx: GeneratedParser.SubstringContext) = translate(ctx) { + if (ctx.FROM() == null) { + // normal form + val function = "SUBSTRING".toIdentifierChain() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args, setq = null) // setq = null for scalar fn + } else { + // special form + val value = visitExpr(ctx.expr(0)) + val start = visitOrNull(ctx.expr(1)) + val length = visitOrNull(ctx.expr(2)) + exprSubstring(value, start, length) + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitPosition(ctx: GeneratedParser.PositionContext) = translate(ctx) { + if (ctx.IN() == null) { + // normal form + val function = "POSITION".toIdentifierChain() + val args = visitOrEmpty(ctx.expr()) + exprCall(function, args, setq = null) // setq = null for scalar fn + } else { + // special form + val lhs = visitExpr(ctx.expr(0)) + val rhs = visitExpr(ctx.expr(1)) + exprPosition(lhs, rhs) + } + } + + /** + * TODO Add labels to each alternative, https://github.com/partiql/partiql-lang-kotlin/issues/1113 + */ + override fun visitOverlay(ctx: GeneratedParser.OverlayContext) = translate(ctx) { + // TODO: figure out why do we have a normalized form for overlay? + if (ctx.PLACING() == null) { + // normal form + val function = "OVERLAY".toIdentifierChain() + val args = arrayOfNulls(4).also { + visitOrEmpty(ctx.expr()).forEachIndexed { index, expr -> + it[index] = expr + } + } + val e = error(ctx, "overlay function requires at least three args") + + exprOverlay(args[0] ?: throw e, args[1] ?: throw e, args[2] ?: throw e, args[3]) + } else { + // special form + val value = visitExpr(ctx.expr(0)) + val overlay = visitExpr(ctx.expr(1)) + val start = visitExpr(ctx.expr(2)) + val length = visitOrNull(ctx.expr(3)) + exprOverlay(value, overlay, start, length) + } + } + + override fun visitExtract(ctx: GeneratedParser.ExtractContext) = translate(ctx) { + val field = try { + DatetimeField.valueOf(ctx.IDENTIFIER().text.uppercase()) + } catch (ex: IllegalArgumentException) { + throw error(ctx.IDENTIFIER().symbol, "Expected one of: ${DatetimeField.values().joinToString()}", ex) + } + val source = visitExpr(ctx.expr()) + exprExtract(field, source) + } + + override fun visitTrimFunction(ctx: GeneratedParser.TrimFunctionContext) = translate(ctx) { + val spec = ctx.mod?.let { + try { + TrimSpec.valueOf(it.text.uppercase()) + } catch (ex: IllegalArgumentException) { + throw error(it, "Expected on of: ${TrimSpec.values().joinToString()}", ex) + } + } + val (chars, value) = when (ctx.expr().size) { + 1 -> null to visitExpr(ctx.expr(0)) + 2 -> visitExpr(ctx.expr(0)) to visitExpr(ctx.expr(1)) + else -> throw error(ctx, "Expected one or two TRIM expression arguments") + } + exprTrim(value, chars, spec) + } + + /** + * Window Functions + */ + + override fun visitLagLeadFunction(ctx: GeneratedParser.LagLeadFunctionContext) = translate(ctx) { + val function = when { + ctx.LAG() != null -> WindowFunction.LAG() + ctx.LEAD() != null -> WindowFunction.LEAD() + else -> throw error(ctx, "Expected LAG or LEAD") + } + val expression = visitExpr(ctx.expr(0)) + val offset = visitOrNull(ctx.expr(1)) + val default = visitOrNull(ctx.expr(2)) + val over = visitOver(ctx.over()) + if (over.sorts == null) { + throw error(ctx.over(), "$function requires Window ORDER BY") + } + exprWindow(function, expression, offset, default, over) + } + + override fun visitOver(ctx: GeneratedParser.OverContext) = translate(ctx) { + val partitions = ctx.windowPartitionList()?.let { visitOrEmpty(it.expr()) } + val sorts = ctx.windowSortSpecList()?.let { visitOrEmpty(it.orderSortSpec()) } + exprWindowOver(partitions, sorts) + } + + /** + * + * LITERALS + * + */ + + override fun visitBag(ctx: GeneratedParser.BagContext) = translate(ctx) { + // Prohibit hidden characters between angle brackets + val startTokenIndex = ctx.start.tokenIndex + val endTokenIndex = ctx.stop.tokenIndex + if (tokens.getHiddenTokensToRight(startTokenIndex, GeneratedLexer.HIDDEN) != null || tokens.getHiddenTokensToLeft(endTokenIndex, GeneratedLexer.HIDDEN) != null) { + throw error(ctx, "Invalid bag expression") + } + val expressions = visitOrEmpty(ctx.expr()) + exprBag(expressions) + } + + override fun visitLiteralDecimal(ctx: GeneratedParser.LiteralDecimalContext) = translate(ctx) { + val decimal = try { + val v = ctx.LITERAL_DECIMAL().text.trim() + BigDecimal(v, MathContext(38, RoundingMode.HALF_EVEN)) + } catch (e: NumberFormatException) { + throw error(ctx, "Invalid decimal literal", e) + } + exprLit(decimalValue(decimal)) + } + + override fun visitArray(ctx: GeneratedParser.ArrayContext) = translate(ctx) { + val expressions = visitOrEmpty(ctx.expr()) + exprArray(expressions) + } + + override fun visitLiteralNull(ctx: GeneratedParser.LiteralNullContext) = translate(ctx) { + exprLit(nullValue()) + } + + override fun visitLiteralMissing(ctx: GeneratedParser.LiteralMissingContext) = translate(ctx) { + exprLit(missingValue()) + } + + override fun visitLiteralTrue(ctx: GeneratedParser.LiteralTrueContext) = translate(ctx) { + exprLit(boolValue(true)) + } + + override fun visitLiteralFalse(ctx: GeneratedParser.LiteralFalseContext) = translate(ctx) { + exprLit(boolValue(false)) + } + + override fun visitLiteralIon(ctx: GeneratedParser.LiteralIonContext) = translate(ctx) { + val value = ctx.ION_CLOSURE().getStringValue() + val encoding = "ion" + exprVariant(value, encoding) + } + + override fun visitLiteralString(ctx: GeneratedParser.LiteralStringContext) = translate(ctx) { + val value = ctx.LITERAL_STRING().getStringValue() + exprLit(stringValue(value)) + } + + override fun visitLiteralInteger(ctx: GeneratedParser.LiteralIntegerContext) = translate(ctx) { + val n = ctx.LITERAL_INTEGER().text + + // 1st, try parse as int + try { + val v = n.toInt(10) + return@translate exprLit(int32Value(v)) + } catch (ex: NumberFormatException) { + // ignore + } + + // 2nd, try parse as long + try { + val v = n.toLong(10) + return@translate exprLit(int64Value(v)) + } catch (ex: NumberFormatException) { + // ignore + } + + // 3rd, try parse as BigInteger + try { + val v = BigInteger(n) + return@translate exprLit(intValue(v)) + } catch (ex: NumberFormatException) { + throw ex + } + } + + override fun visitLiteralDate(ctx: GeneratedParser.LiteralDateContext) = translate(ctx) { + val pattern = ctx.LITERAL_STRING().symbol + val dateString = ctx.LITERAL_STRING().getStringValue() + if (DATE_PATTERN_REGEX.matches(dateString).not()) { + throw error(pattern, "Expected DATE string to be of the format yyyy-MM-dd") + } + val value = try { + LocalDate.parse(dateString, DateTimeFormatter.ISO_LOCAL_DATE) + } catch (e: DateTimeParseException) { + throw error(pattern, e.localizedMessage, e) + } catch (e: IndexOutOfBoundsException) { + throw error(pattern, e.localizedMessage, e) + } + val date = DateTimeValue.date(value.year, value.monthValue, value.dayOfMonth) + exprLit(dateValue(date)) + } + + override fun visitLiteralTime(ctx: GeneratedParser.LiteralTimeContext) = translate(ctx) { + val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) + val time = try { + DateTimeUtils.parseTimeLiteral(timeString) + } catch (e: DateTimeException) { + throw error(ctx, "Invalid Date Time Literal", e) + } + val value = time.toPrecision(precision) + exprLit(timeValue(value)) + } + + override fun visitLiteralTimestamp(ctx: GeneratedParser.LiteralTimestampContext) = translate(ctx) { + val (timeString, precision) = getTimeStringAndPrecision(ctx.LITERAL_STRING(), ctx.LITERAL_INTEGER()) + val timestamp = try { + DateTimeUtils.parseTimestamp(timeString) + } catch (e: DateTimeException) { + throw error(ctx, "Invalid Date Time Literal", e) + } + val value = timestamp.toPrecision(precision) + exprLit(timestampValue(value)) + } + + override fun visitTuple(ctx: GeneratedParser.TupleContext) = translate(ctx) { + val fields = ctx.pair().map { + val k = visitExpr(it.lhs) + val v = visitExpr(it.rhs) + exprStructField(k, v) + } + exprStruct(fields) + } + + /** + * + * TYPES + * + */ + + override fun visitTypeAtomic(ctx: GeneratedParser.TypeAtomicContext) = translate(ctx) { + when (ctx.datatype.type) { + GeneratedParser.NULL -> DataType.NULL() + GeneratedParser.BOOL -> DataType.BOOLEAN() + GeneratedParser.BOOLEAN -> DataType.BOOL() + GeneratedParser.SMALLINT -> DataType.SMALLINT() + GeneratedParser.INT2 -> DataType.INT2() + GeneratedParser.INTEGER2 -> DataType.INTEGER2() + // TODO, we have INT aliased to INT4 when it should be visa-versa. + GeneratedParser.INT4 -> DataType.INT4() + GeneratedParser.INTEGER4 -> DataType.INTEGER4() + GeneratedParser.INT -> DataType.INT() + GeneratedParser.INTEGER -> DataType.INTEGER() + GeneratedParser.BIGINT -> DataType.BIGINT() + GeneratedParser.INT8 -> DataType.INT8() + GeneratedParser.INTEGER8 -> DataType.INTEGER8() + GeneratedParser.FLOAT -> DataType.FLOAT() + GeneratedParser.DOUBLE -> TODO() // not sure if DOUBLE is to be supported + GeneratedParser.REAL -> DataType.REAL() + GeneratedParser.TIMESTAMP -> DataType.TIMESTAMP() + GeneratedParser.CHAR -> DataType.CHAR() + GeneratedParser.CHARACTER -> DataType.CHARACTER() + GeneratedParser.MISSING -> DataType.MISSING() + GeneratedParser.STRING -> DataType.STRING() + GeneratedParser.SYMBOL -> DataType.SYMBOL() + // TODO https://github.com/partiql/partiql-lang-kotlin/issues/1125 + GeneratedParser.BLOB -> DataType.BLOB() + GeneratedParser.CLOB -> DataType.CLOB() + GeneratedParser.DATE -> DataType.DATE() + GeneratedParser.STRUCT -> DataType.STRUCT() + GeneratedParser.TUPLE -> DataType.TUPLE() + GeneratedParser.LIST -> DataType.LIST() + GeneratedParser.SEXP -> DataType.SEXP() + GeneratedParser.BAG -> DataType.BAG() + GeneratedParser.ANY -> TODO() // not sure if ANY is to be supported + else -> throw error(ctx, "Unknown atomic type.") // TODO other types included in parser + } + } + + override fun visitTypeVarChar(ctx: GeneratedParser.TypeVarCharContext): DataType = translate(ctx) { + when (val n = ctx.arg0?.text?.toInt()) { + null -> DataType.VARCHAR() + else -> DataType.VARCHAR(n) + } + } + + override fun visitTypeArgSingle(ctx: GeneratedParser.TypeArgSingleContext) = translate(ctx) { + val n = ctx.arg0?.text?.toInt() + when (ctx.datatype.type) { + GeneratedParser.FLOAT -> when (n) { + null -> DataType.FLOAT(64) + 32 -> DataType.FLOAT(32) + 64 -> DataType.FLOAT(64) + else -> throw error(ctx.datatype, "Invalid FLOAT precision. Expected 32 or 64") + } + GeneratedParser.CHAR, GeneratedParser.CHARACTER -> when (n) { + null -> DataType.CHAR() + else -> DataType.CHAR(n) + } + GeneratedParser.VARCHAR -> when (n) { + null -> DataType.VARCHAR() + else -> DataType.VARCHAR(n) + } + else -> throw error(ctx.datatype, "Invalid datatype") + } + } + + override fun visitTypeArgDouble(ctx: GeneratedParser.TypeArgDoubleContext) = translate(ctx) { + val arg0 = ctx.arg0?.text?.toInt() + val arg1 = ctx.arg1?.text?.toInt() + when (ctx.datatype.type) { + GeneratedParser.DECIMAL -> when { + arg0 == null && arg1 == null -> DataType.DECIMAL() + arg0 != null && arg1 == null -> DataType.DECIMAL(arg0) + arg0 != null && arg1 != null -> DataType.DECIMAL(arg0, arg1) + else -> error("Invalid parameters for decimal") + } + GeneratedParser.DEC -> when { + arg0 == null && arg1 == null -> DataType.DEC() + arg0 != null && arg1 == null -> DataType.DEC(arg0) + arg0 != null && arg1 != null -> DataType.DEC(arg0, arg1) + else -> error("Invalid parameters for dec") + } + GeneratedParser.NUMERIC -> when { + arg0 == null && arg1 == null -> DataType.NUMERIC() + arg0 != null && arg1 == null -> DataType.NUMERIC(arg0) + arg0 != null && arg1 != null -> DataType.NUMERIC(arg0, arg1) + else -> error("Invalid parameters for decimal") + } + else -> throw error(ctx.datatype, "Invalid datatype") + } + } + + override fun visitTypeTimeZone(ctx: GeneratedParser.TypeTimeZoneContext) = translate(ctx) { + val precision = ctx.precision?.let { + val p = ctx.precision.text.toInt() + if (p < 0 || 9 < p) throw error(ctx.precision, "Unsupported time precision") + p + } + + when (ctx.datatype.type) { + GeneratedParser.TIME -> when (ctx.ZONE()) { + null -> when (precision) { + null -> DataType.TIME() + else -> DataType.TIME(precision) + } + else -> when (precision) { + null -> DataType.TIME_WITH_TIME_ZONE() + else -> DataType.TIME_WITH_TIME_ZONE(precision) + } + } + GeneratedParser.TIMESTAMP -> when (ctx.ZONE()) { + null -> when (precision) { + null -> DataType.TIMESTAMP() + else -> DataType.TIMESTAMP(precision) + } + else -> when (precision) { + null -> DataType.TIMESTAMP_WITH_TIME_ZONE() + else -> DataType.TIMESTAMP_WITH_TIME_ZONE(precision) + } + } + else -> throw error(ctx.datatype, "Invalid datatype") + } + } + + override fun visitTypeCustom(ctx: GeneratedParser.TypeCustomContext) = translate(ctx) { + DataType.USER_DEFINED(ctx.text.uppercase().toIdentifierChain()) + } + + private inline fun visitOrEmpty(ctx: List?): List = when { + ctx.isNullOrEmpty() -> emptyList() + else -> ctx.map { visit(it) as T } + } + + private inline fun visitOrNull(ctx: ParserRuleContext?): T? = + ctx?.let { it.accept(this) as T } + + private inline fun visitAs(ctx: ParserRuleContext): T = visit(ctx) as T + + /** + * Visiting a symbol to get a string, skip the wrapping, unwrapping, and location tracking. + */ + private fun symbolToString(ctx: GeneratedParser.SymbolPrimitiveContext) = when (ctx) { + is GeneratedParser.IdentifierQuotedContext -> ctx.IDENTIFIER_QUOTED().getStringValue() + is GeneratedParser.IdentifierUnquotedContext -> ctx.text + else -> throw error(ctx, "Invalid symbol reference.") + } + + /** + * Convert [ALL|DISTINCT] to SetQuantifier Enum + */ + private fun convertSetQuantifier(ctx: GeneratedParser.SetQuantifierStrategyContext?): SetQuantifier? = when { + ctx == null -> null + ctx.ALL() != null -> SetQuantifier.ALL() + ctx.DISTINCT() != null -> SetQuantifier.DISTINCT() + else -> throw error(ctx, "Expected set quantifier ALL or DISTINCT") + } + + /** + * With the and nodes of a literal time expression, returns the parsed string and precision. + * TIME ()? (WITH TIME ZONE)? + */ + private fun getTimeStringAndPrecision( + stringNode: TerminalNode, + integerNode: TerminalNode?, + ): Pair { + val timeString = stringNode.getStringValue() + val precision = when (integerNode) { + null -> { + try { + getPrecisionFromTimeString(timeString) + } catch (e: Exception) { + throw error(stringNode.symbol, "Unable to parse precision.", e) + } + } + else -> { + val p = integerNode.text.toBigInteger().toInt() + if (p < 0 || 9 < p) throw error(integerNode.symbol, "Precision out of bounds") + p + } + } + return timeString to precision + } + + private fun getPrecisionFromTimeString(timeString: String): Int { + val matcher = GENERIC_TIME_REGEX.toPattern().matcher(timeString) + if (!matcher.find()) { + throw IllegalArgumentException("Time string does not match the format 'HH:MM:SS[.ddd....][+|-HH:MM]'") + } + val fraction = matcher.group(1)?.removePrefix(".") + return fraction?.length ?: 0 + } + + /** + * Converts a Path expression into a Projection Item (either ALL or EXPR). Note: A Projection Item only allows a + * subset of a typical Path expressions. See the following examples. + * + * Examples of valid projections are: + * + * ```partiql + * SELECT * FROM foo + * SELECT foo.* FROM foo + * SELECT f.* FROM foo as f + * SELECT foo.bar.* FROM foo + * SELECT f.bar.* FROM foo as f + * ``` + * Also validates that the expression is valid for select list context. It does this by making + * sure that expressions looking like the following do not appear: + * + * ```partiql + * SELECT foo[*] FROM foo + * SELECT f.*.bar FROM foo as f + * SELECT foo[1].* FROM foo + * SELECT foo.*.bar FROM foo + * ``` + */ + protected fun convertPathToProjectionItem(ctx: ParserRuleContext, path: ExprPath, alias: Identifier?) = + translate(ctx) { + val steps = mutableListOf() + var containsIndex = false + var curStep = path.next + var last = curStep + while (curStep != null) { + val isLastStep = curStep.next == null + // Only last step can have a '.*' + if (curStep is PathStep.AllFields && !isLastStep) { + throw error(ctx, "Projection item cannot unpivot unless at end.") + } + // No step can have an indexed wildcard: '[*]' + if (curStep is PathStep.AllElements) { + throw error(ctx, "Projection item cannot index using wildcard.") + } + // TODO If the last step is '.*', no indexing is allowed + // if (step.metas.containsKey(IsPathIndexMeta.TAG)) { + // containsIndex = true + // } + if (curStep !is PathStep.AllFields) { + steps.add(curStep) + } + + if (isLastStep && curStep is PathStep.AllFields && containsIndex) { + throw error(ctx, "Projection item use wildcard with any indexing.") + } + last = curStep + curStep = curStep.next + } + when { + last is PathStep.AllFields && steps.isEmpty() -> { + selectItemStar(path.root) + } + last is PathStep.AllFields -> { + val init: PathStep? = null + val newSteps = steps.reversed().fold(init) { acc, step -> + when (step) { + is PathStep.Element -> PathStep.Element(step.element, acc) + is PathStep.Field -> PathStep.Field(step.field, acc) + is PathStep.AllElements -> PathStep.AllElements(acc) + is PathStep.AllFields -> PathStep.AllFields(acc) + else -> error("Unexpected path step") + } + } + selectItemStar(exprPath(path.root, newSteps)) + } + else -> { + selectItemExpr(path, alias) + } + } + } + + private fun TerminalNode.getStringValue(): String = this.symbol.getStringValue() + + private fun Token.getStringValue(): String = when (this.type) { + GeneratedParser.IDENTIFIER -> this.text + GeneratedParser.IDENTIFIER_QUOTED -> this.text.removePrefix("\"").removeSuffix("\"").replace("\"\"", "\"") + GeneratedParser.LITERAL_STRING -> this.text.removePrefix("'").removeSuffix("'").replace("''", "'") + GeneratedParser.ION_CLOSURE -> this.text.removePrefix("`").removeSuffix("`") + else -> throw error(this, "Unsupported token for grabbing string value.") + } + + private fun String.toIdentifier(): Identifier = identifier(this, false) + + private fun String.toIdentifierChain(): IdentifierChain = identifierChain(root = this.toIdentifier(), next = null) + + private fun String.toBigInteger() = BigInteger(this, 10) + + private fun assertIntegerElement(token: Token?, value: IonElement?) { + if (value == null || token == null) return + if (value !is IntElement) throw error(token, "Expected an integer value.") + if (value.integerSize == IntElementSize.BIG_INTEGER || value.longValue > Int.MAX_VALUE || value.longValue < Int.MIN_VALUE) throw error( + token, "Type parameter exceeded maximum value" + ) + } + + private enum class ExplainParameters { + TYPE, FORMAT; + + fun getCompliantString(target: String?, input: Token): String = when (target) { + null -> input.text!! + else -> throw error(input, "Cannot set EXPLAIN parameter ${this.name} multiple times.") + } + } + } +} diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt index 51a63feb92..f886401019 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserBagOpTests.kt @@ -1,13 +1,24 @@ package org.partiql.parser.internal import org.junit.jupiter.api.Test -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.SetOp -import org.partiql.ast.SetQuantifier -import org.partiql.ast.builder.AstBuilder -import org.partiql.ast.builder.ast +import org.partiql.ast.v1.Ast.exprBag +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprQuerySet +import org.partiql.ast.v1.Ast.exprStruct +import org.partiql.ast.v1.Ast.exprStructField +import org.partiql.ast.v1.Ast.from +import org.partiql.ast.v1.Ast.fromExpr +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.Ast.queryBodySFW +import org.partiql.ast.v1.Ast.queryBodySetOp +import org.partiql.ast.v1.Ast.selectStar +import org.partiql.ast.v1.Ast.setOp +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.FromType +import org.partiql.ast.v1.SetOpType +import org.partiql.ast.v1.SetQuantifier +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.ExprQuerySet import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import org.partiql.value.stringValue @@ -15,110 +26,137 @@ import kotlin.test.assertEquals class PartiQLParserBagOpTests { - private val parser = PartiQLParserDefault() + private val parser = V1PartiQLParserDefault() - private fun query(block: AstBuilder.() -> Expr) = ast { statementQuery { expr = block() } } + private fun queryBody(body: () -> Expr) = query(body()) @OptIn(PartiQLValueExperimental::class) - private fun createSFW(i: Int): Expr.QuerySet = - ast { - exprQuerySet { - body = queryBodySFW { - select = selectStar() - from = fromValue { - expr = exprCollection { - type = Expr.Collection.Type.BAG - values = mutableListOf( - exprStruct { - fields = mutableListOf( - exprStructField { - name = exprLit { value = stringValue("a") } - value = exprLit { value = int32Value(i) } - } + private fun createSFW(i: Int): ExprQuerySet = + exprQuerySet( + body = queryBodySFW( + select = selectStar(setq = null), + from = from( + tableRefs = listOf( + fromExpr( + expr = exprBag( + values = mutableListOf( + exprStruct( + fields = mutableListOf( + exprStructField( + name = exprLit(value = stringValue("a")), + value = exprLit(value = int32Value(i)) + ) + ) ) - } - ) - } - type = From.Value.Type.SCAN - } - } - } - } + ) + ), + fromType = FromType.SCAN(), + asAlias = null, + atAlias = null + ) + ) + ), + exclude = null, + let = null, + where = null, + groupBy = null, + having = null, + ), + orderBy = null, + limit = null, + offset = null + ) @OptIn(PartiQLValueExperimental::class) - private fun createLit(i: Int) = ast { exprLit { value = int32Value(i) } } + private fun createLit(i: Int) = exprLit(int32Value(i)) // SQL Union @Test fun sqlUnion() = assertExpression( "SELECT * FROM <<{'a': 1}>> UNION SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlUnionMultiple() = assertExpression( "SELECT * FROM <<{'a': 1}>> UNION ALL SELECT * FROM <<{'a': 2}>> UNION DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlUnionMultipleRight() = assertExpression( "SELECT * FROM <<{'a': 1}>> UNION ALL (SELECT * FROM <<{'a': 2}>> UNION DISTINCT SELECT * FROM <<{'a': 3}>>)", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.DISTINCT - } - lhs = createSFW(2) - rhs = createSFW(3) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.DISTINCT() + ), + lhs = createSFW(2), + rhs = createSFW(3), isOuter = false - } - isOuter = false - } - } - } + ), + orderBy = null, + limit = null, + offset = null + ), + isOuter = false + ), + orderBy = null, + limit = null, + offset = null + ) } ) @@ -126,90 +164,110 @@ class PartiQLParserBagOpTests { @Test fun outerUnion() = assertExpression( "SELECT * FROM <<{'a': 1}>> OUTER UNION SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = true - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerUnionNonSpecified() = assertExpression( "SELECT * FROM <<{'a': 1}>> UNION 2", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - } - lhs = createSFW(1) - rhs = createLit(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = null + ), + lhs = createSFW(1), + rhs = createLit(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlUnionAndOuterUnion() = assertExpression( "SELECT * FROM <<{'a': 1}>> UNION ALL SELECT * FROM <<{'a': 2}>> UNION DISTINCT 3", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createLit(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createLit(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerUnionAndSQLUnion() = assertExpression( "1 UNION ALL SELECT * FROM <<{'a': 2}>> UNION DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.UNION - setq = SetQuantifier.ALL - } - lhs = createLit(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.UNION(), + setq = SetQuantifier.ALL() + ), + lhs = createLit(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @@ -217,73 +275,89 @@ class PartiQLParserBagOpTests { @Test fun sqlExcept() = assertExpression( "SELECT * FROM <<{'a': 1}>> EXCEPT SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlExceptMultiple() = assertExpression( "SELECT * FROM <<{'a': 1}>> EXCEPT ALL SELECT * FROM <<{'a': 2}>> EXCEPT DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlExceptMultipleRight() = assertExpression( "SELECT * FROM <<{'a': 1}>> EXCEPT ALL (SELECT * FROM <<{'a': 2}>> EXCEPT DISTINCT SELECT * FROM <<{'a': 3}>>)", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.DISTINCT - } - lhs = createSFW(2) - rhs = createSFW(3) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = createSFW(2), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @@ -291,90 +365,112 @@ class PartiQLParserBagOpTests { @Test fun outerExcept() = assertExpression( "SELECT * FROM <<{'a': 1}>> OUTER EXCEPT SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = true - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerExceptNonSpecified() = assertExpression( "SELECT * FROM <<{'a': 1}>> EXCEPT 2", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - } - lhs = createSFW(1) - rhs = createLit(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = null + ), + lhs = createSFW(1), + rhs = createLit(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlExceptAndOuterExcept() = assertExpression( "SELECT * FROM <<{'a': 1}>> EXCEPT ALL SELECT * FROM <<{'a': 2}>> EXCEPT DISTINCT 3", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createLit(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createLit(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerExceptAndSQLExcept() = assertExpression( "1 EXCEPT ALL SELECT * FROM <<{'a': 2}>> EXCEPT DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.EXCEPT - setq = SetQuantifier.ALL - } - lhs = createLit(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.EXCEPT(), + setq = SetQuantifier.ALL() + ), + lhs = createLit(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @@ -382,73 +478,89 @@ class PartiQLParserBagOpTests { @Test fun sqlIntersect() = assertExpression( "SELECT * FROM <<{'a': 1}>> INTERSECT SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlIntersectMultiple() = assertExpression( "SELECT * FROM <<{'a': 1}>> INTERSECT ALL SELECT * FROM <<{'a': 2}>> INTERSECT DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlIntersectMultipleRight() = assertExpression( "SELECT * FROM <<{'a': 1}>> INTERSECT ALL (SELECT * FROM <<{'a': 2}>> INTERSECT DISTINCT SELECT * FROM <<{'a': 3}>>)", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.DISTINCT - } - lhs = createSFW(2) - rhs = createSFW(3) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = createSFW(2), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @@ -456,90 +568,110 @@ class PartiQLParserBagOpTests { @Test fun outerIntersect() = assertExpression( "SELECT * FROM <<{'a': 1}>> OUTER INTERSECT SELECT * FROM <<{'a': 2}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = null + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = true - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerIntersectNonSpecified() = assertExpression( "SELECT * FROM <<{'a': 1}>> INTERSECT 2", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - } - lhs = createSFW(1) - rhs = createLit(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = null + ), + lhs = createSFW(1), + rhs = createLit(2), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun sqlIntersectAndOuterIntersect() = assertExpression( "SELECT * FROM <<{'a': 1}>> INTERSECT ALL SELECT * FROM <<{'a': 2}>> INTERSECT DISTINCT 3", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.ALL - } - lhs = createSFW(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.ALL() + ), + lhs = createSFW(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createLit(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createLit(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) @Test fun outerIntersectAndSQLIntersect() = assertExpression( "1 INTERSECT ALL SELECT * FROM <<{'a': 2}>> INTERSECT DISTINCT SELECT * FROM <<{'a': 3}>>", - query { - exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.DISTINCT - } - lhs = exprQuerySet { - body = queryBodySetOp { - type = setOp { - type = SetOp.Type.INTERSECT - setq = SetQuantifier.ALL - } - lhs = createLit(1) - rhs = createSFW(2) + queryBody { + exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.DISTINCT() + ), + lhs = exprQuerySet( + body = queryBodySetOp( + type = setOp( + setOpType = SetOpType.INTERSECT(), + setq = SetQuantifier.ALL() + ), + lhs = createLit(1), + rhs = createSFW(2), isOuter = false - } - } - rhs = createSFW(3) + ), + orderBy = null, + limit = null, + offset = null + ), + rhs = createSFW(3), isOuter = false - } - } + ), + orderBy = null, + limit = null, + offset = null + ) } ) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt index 3fbb0321a4..bd4a3d7ddb 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserDDLTests.kt @@ -1,22 +1,11 @@ package org.partiql.parser.internal -import org.junit.jupiter.api.extension.ExtensionContext -import org.junit.jupiter.params.ParameterizedTest -import org.junit.jupiter.params.provider.Arguments -import org.junit.jupiter.params.provider.ArgumentsProvider -import org.junit.jupiter.params.provider.ArgumentsSource -import org.partiql.ast.AstNode -import org.partiql.ast.Identifier -import org.partiql.ast.identifierQualified -import org.partiql.ast.identifierSymbol -import org.partiql.ast.statementDDLCreateTable -import org.partiql.ast.statementDDLDropTable -import java.util.stream.Stream +import org.partiql.ast.v1.AstNode import kotlin.test.assertEquals class PartiQLParserDDLTests { - private val parser = PartiQLParserDefault() + private val parser = V1PartiQLParserDefault() data class SuccessTestCase( val description: String? = null, @@ -24,107 +13,108 @@ class PartiQLParserDDLTests { val node: AstNode ) - @ArgumentsSource(TestProvider::class) - @ParameterizedTest - fun errorTests(tc: SuccessTestCase) = assertExpression(tc.query, tc.node) - - class TestProvider : ArgumentsProvider { - val createTableTests = listOf( - SuccessTestCase( - "CREATE TABLE with unqualified case insensitive name", - "CREATE TABLE foo", - statementDDLCreateTable( - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - null - ) - ), - // Support Case Sensitive identifier as table name - // Subsequent process may need to change - // See: https://www.db-fiddle.com/f/9A8mknSNYuRGLfkqkLeiHD/0 for reference. - SuccessTestCase( - "CREATE TABLE with unqualified case sensitive name", - "CREATE TABLE \"foo\"", - statementDDLCreateTable( - identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), - null - ) - ), - SuccessTestCase( - "CREATE TABLE with qualified case insensitive name", - "CREATE TABLE myCatalog.mySchema.foo", - statementDDLCreateTable( - identifierQualified( - identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), - listOf( - identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - ) - ), - null - ) - ), - SuccessTestCase( - "CREATE TABLE with qualified name with mixed case sensitivity", - "CREATE TABLE myCatalog.\"mySchema\".foo", - statementDDLCreateTable( - identifierQualified( - identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), - listOf( - identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - ) - ), - null - ) - ), - ) - - val dropTableTests = listOf( - SuccessTestCase( - "DROP TABLE with unqualified case insensitive name", - "DROP TABLE foo", - statementDDLDropTable( - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - ) - ), - SuccessTestCase( - "DROP TABLE with unqualified case sensitive name", - "DROP TABLE \"foo\"", - statementDDLDropTable( - identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), - ) - ), - SuccessTestCase( - "DROP TABLE with qualified case insensitive name", - "DROP TABLE myCatalog.mySchema.foo", - statementDDLDropTable( - identifierQualified( - identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), - listOf( - identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - ) - ), - ) - ), - SuccessTestCase( - "DROP TABLE with qualified name with mixed case sensitivity", - "DROP TABLE myCatalog.\"mySchema\".foo", - statementDDLDropTable( - identifierQualified( - identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), - listOf( - identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), - ) - ), - ) - ), - ) - - override fun provideArguments(context: ExtensionContext?): Stream = - (createTableTests + dropTableTests).map { Arguments.of(it) }.stream() - } + // DDL not yet supported in v1 AST +// @ArgumentsSource(TestProvider::class) +// @ParameterizedTest +// fun errorTests(tc: SuccessTestCase) = assertExpression(tc.query, tc.node) +// +// class TestProvider : ArgumentsProvider { +// val createTableTests = listOf( +// SuccessTestCase( +// "CREATE TABLE with unqualified case insensitive name", +// "CREATE TABLE foo", +// statementDDLCreateTable( +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// null +// ) +// ), +// // Support Case Sensitive identifier as table name +// // Subsequent process may need to change +// // See: https://www.db-fiddle.com/f/9A8mknSNYuRGLfkqkLeiHD/0 for reference. +// SuccessTestCase( +// "CREATE TABLE with unqualified case sensitive name", +// "CREATE TABLE \"foo\"", +// statementDDLCreateTable( +// identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), +// null +// ) +// ), +// SuccessTestCase( +// "CREATE TABLE with qualified case insensitive name", +// "CREATE TABLE myCatalog.mySchema.foo", +// statementDDLCreateTable( +// identifierQualified( +// identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), +// listOf( +// identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// ) +// ), +// null +// ) +// ), +// SuccessTestCase( +// "CREATE TABLE with qualified name with mixed case sensitivity", +// "CREATE TABLE myCatalog.\"mySchema\".foo", +// statementDDLCreateTable( +// identifierQualified( +// identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), +// listOf( +// identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// ) +// ), +// null +// ) +// ), +// ) +// +// val dropTableTests = listOf( +// SuccessTestCase( +// "DROP TABLE with unqualified case insensitive name", +// "DROP TABLE foo", +// statementDDLDropTable( +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// ) +// ), +// SuccessTestCase( +// "DROP TABLE with unqualified case sensitive name", +// "DROP TABLE \"foo\"", +// statementDDLDropTable( +// identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), +// ) +// ), +// SuccessTestCase( +// "DROP TABLE with qualified case insensitive name", +// "DROP TABLE myCatalog.mySchema.foo", +// statementDDLDropTable( +// identifierQualified( +// identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), +// listOf( +// identifierSymbol("mySchema", Identifier.CaseSensitivity.INSENSITIVE), +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// ) +// ), +// ) +// ), +// SuccessTestCase( +// "DROP TABLE with qualified name with mixed case sensitivity", +// "DROP TABLE myCatalog.\"mySchema\".foo", +// statementDDLDropTable( +// identifierQualified( +// identifierSymbol("myCatalog", Identifier.CaseSensitivity.INSENSITIVE), +// listOf( +// identifierSymbol("mySchema", Identifier.CaseSensitivity.SENSITIVE), +// identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), +// ) +// ), +// ) +// ), +// ) +// +// override fun provideArguments(context: ExtensionContext?): Stream = +// (createTableTests + dropTableTests).map { Arguments.of(it) }.stream() +// } private fun assertExpression(input: String, expected: AstNode) { val result = parser.parse(input) diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt index a35372845c..6cc3947584 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserFunctionCallTests.kt @@ -1,27 +1,26 @@ package org.partiql.parser.internal import org.junit.jupiter.api.Test -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.Identifier -import org.partiql.ast.exprCall -import org.partiql.ast.identifierQualified -import org.partiql.ast.identifierSymbol -import org.partiql.ast.statementQuery +import org.partiql.ast.v1.Ast.exprCall +import org.partiql.ast.v1.Ast.identifier +import org.partiql.ast.v1.Ast.identifierChain +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.expr.Expr import kotlin.test.assertEquals class PartiQLParserFunctionCallTests { - private val parser = PartiQLParserDefault() + private val parser = V1PartiQLParserDefault() - private inline fun query(body: () -> Expr) = statementQuery(body()) + private inline fun queryBody(body: () -> Expr) = query(body()) @Test fun callUnqualifiedNonReservedInsensitive() = assertExpression( "foo()", - query { + queryBody { exprCall( - function = identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + function = identifierChain(identifier("foo", false), null), args = emptyList(), setq = null ) @@ -31,9 +30,9 @@ class PartiQLParserFunctionCallTests { @Test fun callUnqualifiedNonReservedSensitive() = assertExpression( "\"foo\"()", - query { + queryBody { exprCall( - function = identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain(identifier("foo", true), null), args = emptyList(), setq = null ) @@ -43,9 +42,9 @@ class PartiQLParserFunctionCallTests { @Test fun callUnqualifiedReservedInsensitive() = assertExpression( "upper()", - query { + queryBody { exprCall( - function = identifierSymbol("upper", Identifier.CaseSensitivity.INSENSITIVE), + function = identifierChain(identifier("upper", false), null), args = emptyList(), setq = null ) @@ -55,9 +54,9 @@ class PartiQLParserFunctionCallTests { @Test fun callUnqualifiedReservedSensitive() = assertExpression( "\"upper\"()", - query { + queryBody { exprCall( - function = identifierSymbol("upper", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain(identifier("upper", true), null), args = emptyList(), setq = null ) @@ -67,13 +66,13 @@ class PartiQLParserFunctionCallTests { @Test fun callQualifiedNonReservedInsensitive() = assertExpression( "my_catalog.my_schema.foo()", - query { + queryBody { exprCall( - function = identifierQualified( - root = identifierSymbol("my_catalog", Identifier.CaseSensitivity.INSENSITIVE), - steps = listOf( - identifierSymbol("my_schema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.INSENSITIVE), + function = identifierChain( + root = identifier("my_catalog", false), + next = identifierChain( + root = identifier("my_schema", false), + next = identifierChain(identifier("foo", false), null), ) ), args = emptyList(), @@ -85,13 +84,13 @@ class PartiQLParserFunctionCallTests { @Test fun callQualifiedNonReservedSensitive() = assertExpression( "my_catalog.my_schema.\"foo\"()", - query { + queryBody { exprCall( - function = identifierQualified( - root = identifierSymbol("my_catalog", Identifier.CaseSensitivity.INSENSITIVE), - steps = listOf( - identifierSymbol("my_schema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("foo", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain( + root = identifier("my_catalog", false), + next = identifierChain( + identifier("my_schema", false), + identifierChain(identifier("foo", true), null), ) ), args = emptyList(), @@ -103,13 +102,13 @@ class PartiQLParserFunctionCallTests { @Test fun callQualifiedReservedInsensitive() = assertExpression( "my_catalog.my_schema.upper()", - query { + queryBody { exprCall( - function = identifierQualified( - root = identifierSymbol("my_catalog", Identifier.CaseSensitivity.INSENSITIVE), - steps = listOf( - identifierSymbol("my_schema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("upper", Identifier.CaseSensitivity.INSENSITIVE), + function = identifierChain( + root = identifier("my_catalog", false), + next = identifierChain( + identifier("my_schema", false), + identifierChain(identifier("upper", false), null), ) ), args = emptyList(), @@ -121,13 +120,13 @@ class PartiQLParserFunctionCallTests { @Test fun callQualifiedReservedSensitive() = assertExpression( "my_catalog.my_schema.\"upper\"()", - query { + queryBody { exprCall( - function = identifierQualified( - root = identifierSymbol("my_catalog", Identifier.CaseSensitivity.INSENSITIVE), - steps = listOf( - identifierSymbol("my_schema", Identifier.CaseSensitivity.INSENSITIVE), - identifierSymbol("upper", Identifier.CaseSensitivity.SENSITIVE), + function = identifierChain( + root = identifier("my_catalog", false), + next = identifierChain( + identifier("my_schema", false), + identifierChain(identifier("upper", true), null), ) ), args = emptyList(), diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt index e05f33551c..7b25f97e43 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserOperatorTests.kt @@ -1,11 +1,11 @@ package org.partiql.parser.internal import org.junit.jupiter.api.Test -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.exprLit -import org.partiql.ast.exprOperator -import org.partiql.ast.statementQuery +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprOperator +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.expr.Expr import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import kotlin.test.assertEquals @@ -13,14 +13,14 @@ import kotlin.test.assertEquals @OptIn(PartiQLValueExperimental::class) class PartiQLParserOperatorTests { - private val parser = PartiQLParserDefault() + private val parser = V1PartiQLParserDefault() - private inline fun query(body: () -> Expr) = statementQuery(body()) + private inline fun queryBody(body: () -> Expr) = query(body()) @Test fun builtinUnaryOperator() = assertExpression( "-2", - query { + queryBody { exprOperator( symbol = "-", lhs = null, @@ -32,7 +32,7 @@ class PartiQLParserOperatorTests { @Test fun builtinBinaryOperator() = assertExpression( "1 <= 2", - query { + queryBody { exprOperator( symbol = "<=", lhs = exprLit(int32Value(1)), @@ -44,7 +44,7 @@ class PartiQLParserOperatorTests { @Test fun customUnaryOperator() = assertExpression( "==!2", - query { + queryBody { exprOperator( symbol = "==!", lhs = null, @@ -56,7 +56,7 @@ class PartiQLParserOperatorTests { @Test fun customBinaryOperator() = assertExpression( "1 ==! 2", - query { + queryBody { exprOperator( symbol = "==!", lhs = exprLit(int32Value(1)), diff --git a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt index 2ec95b2d7d..7f0604759c 100644 --- a/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt +++ b/partiql-parser/src/test/kotlin/org/partiql/parser/internal/PartiQLParserSessionAttributeTests.kt @@ -1,12 +1,13 @@ package org.partiql.parser.internal import org.junit.jupiter.api.Test -import org.partiql.ast.AstNode -import org.partiql.ast.Expr -import org.partiql.ast.exprLit -import org.partiql.ast.exprOperator -import org.partiql.ast.exprSessionAttribute -import org.partiql.ast.statementQuery +import org.partiql.ast.v1.Ast.exprLit +import org.partiql.ast.v1.Ast.exprOperator +import org.partiql.ast.v1.Ast.exprSessionAttribute +import org.partiql.ast.v1.Ast.query +import org.partiql.ast.v1.AstNode +import org.partiql.ast.v1.expr.Expr +import org.partiql.ast.v1.expr.SessionAttribute import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import kotlin.test.assertEquals @@ -14,42 +15,42 @@ import kotlin.test.assertEquals @OptIn(PartiQLValueExperimental::class) class PartiQLParserSessionAttributeTests { - private val parser = PartiQLParserDefault() + private val parser = V1PartiQLParserDefault() - private inline fun query(body: () -> Expr) = statementQuery(body()) + private inline fun queryBody(body: () -> Expr) = query(body()) @Test fun currentUserUpperCase() = assertExpression( "CURRENT_USER", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_USER()) } ) @Test fun currentUserMixedCase() = assertExpression( "CURRENT_user", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_USER()) } ) @Test fun currentUserLowerCase() = assertExpression( "current_user", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_USER()) } ) @Test fun currentUserEquals() = assertExpression( "1 = current_user", - query { + queryBody { exprOperator( symbol = "=", lhs = exprLit(int32Value(1)), - rhs = exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_USER) + rhs = exprSessionAttribute(SessionAttribute.CURRENT_USER()) ) } ) @@ -57,24 +58,24 @@ class PartiQLParserSessionAttributeTests { @Test fun currentDateUpperCase() = assertExpression( "CURRENT_DATE", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_DATE()) } ) @Test fun currentDateMixedCase() = assertExpression( "CURRENT_date", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_DATE()) } ) @Test fun currentDateLowerCase() = assertExpression( "current_date", - query { - exprSessionAttribute(Expr.SessionAttribute.Attribute.CURRENT_DATE) + queryBody { + exprSessionAttribute(SessionAttribute.CURRENT_DATE()) } )