From c0b6e0122acc86948819e596182728b12cdea625 Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Mon, 12 Aug 2024 14:26:46 +0200 Subject: [PATCH] chore(embedded/sql): improvements on SQL layer. - Allow selecting generic expressions; - Refactor FnCall implementation; - Add support for INSERT INTO SELECT statements; - Add more core builtin functions; Signed-off-by: Stefano Scafiti --- embedded/document/engine.go | 12 +- embedded/sql/aggregated_values.go | 20 + embedded/sql/engine.go | 1 - embedded/sql/engine_test.go | 346 +++++++++- embedded/sql/functions.go | 334 +++++++++ embedded/sql/grouped_row_reader.go | 28 +- embedded/sql/grouped_row_reader_test.go | 6 +- embedded/sql/json_type.go | 31 +- embedded/sql/num_operator.go | 21 +- embedded/sql/num_operator_test.go | 6 + embedded/sql/parser.go | 1 + embedded/sql/parser_test.go | 422 ++++++----- embedded/sql/proj_row_reader.go | 121 ++-- embedded/sql/sql_grammar.y | 55 +- embedded/sql/sql_parser.go | 884 ++++++++++++------------ embedded/sql/stmt.go | 403 +++++++---- embedded/sql/stmt_test.go | 242 ++++++- embedded/sql/type_conversion.go | 38 +- embedded/sql/values_row_reader.go | 24 +- embedded/sql/values_row_reader_test.go | 14 +- embedded/store/options.go | 6 +- 21 files changed, 2129 insertions(+), 886 deletions(-) create mode 100644 embedded/sql/functions.go diff --git a/embedded/document/engine.go b/embedded/document/engine.go index 4e7731e18e..553a5b9b4b 100644 --- a/embedded/document/engine.go +++ b/embedded/document/engine.go @@ -748,10 +748,10 @@ func (e *Engine) upsertDocuments(ctx context.Context, sqlTx *sql.SQLTx, collecti ctx, sqlTx, []sql.SQLStmt{ - sql.NewUpserIntoStmt( + sql.NewUpsertIntoStmt( collectionName, colNames, - rows, + sql.NewValuesDataSource(rows), isInsert, nil, ), @@ -879,7 +879,7 @@ func (e *Engine) ReplaceDocuments(ctx context.Context, username string, query *p } queryStmt := sql.NewSelectStmt( - []sql.Selector{sql.NewColSelector(query.CollectionName, documentIdFieldName)}, + []sql.TargetEntry{{Exp: sql.NewColSelector(query.CollectionName, documentIdFieldName)}}, sql.NewTableRef(query.CollectionName, ""), queryCondition, generateSQLOrderByClauses(table, query.OrderBy), @@ -982,7 +982,7 @@ func (e *Engine) GetDocuments(ctx context.Context, query *protomodel.Query, offs } op := sql.NewSelectStmt( - []sql.Selector{sql.NewColSelector(query.CollectionName, DocumentBLOBField)}, + []sql.TargetEntry{{Exp: sql.NewColSelector(query.CollectionName, DocumentBLOBField)}}, sql.NewTableRef(query.CollectionName, ""), queryCondition, generateSQLOrderByClauses(table, query.OrderBy), @@ -1023,7 +1023,7 @@ func (e *Engine) CountDocuments(ctx context.Context, query *protomodel.Query, of } ds := sql.NewSelectStmt( - []sql.Selector{sql.NewColSelector(query.CollectionName, table.Cols()[0].Name())}, + []sql.TargetEntry{{Exp: sql.NewColSelector(query.CollectionName, table.Cols()[0].Name())}}, sql.NewTableRef(query.CollectionName, ""), queryCondition, generateSQLOrderByClauses(table, query.OrderBy), @@ -1032,7 +1032,7 @@ func (e *Engine) CountDocuments(ctx context.Context, query *protomodel.Query, of ) op := sql.NewSelectStmt( - []sql.Selector{sql.NewAggColSelector(sql.COUNT, query.CollectionName, "*")}, + []sql.TargetEntry{{Exp: sql.NewAggColSelector(sql.COUNT, query.CollectionName, "*")}}, ds, nil, nil, diff --git a/embedded/sql/aggregated_values.go b/embedded/sql/aggregated_values.go index 6935ab54bb..3182ac9074 100644 --- a/embedded/sql/aggregated_values.go +++ b/embedded/sql/aggregated_values.go @@ -102,6 +102,10 @@ func (v *CountValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedVal return nil, ErrUnexpected } +func (v *CountValue) selectors() []Selector { + return nil +} + func (v *CountValue) reduceSelectors(row *Row, implicitTable string) ValueExp { return nil } @@ -198,6 +202,10 @@ func (v *SumValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue return nil, ErrUnexpected } +func (v *SumValue) selectors() []Selector { + return nil +} + func (v *SumValue) reduceSelectors(row *Row, implicitTable string) ValueExp { return v } @@ -301,6 +309,10 @@ func (v *MinValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue return nil, ErrUnexpected } +func (v *MinValue) selectors() []Selector { + return nil +} + func (v *MinValue) reduceSelectors(row *Row, implicitTable string) ValueExp { return nil } @@ -404,6 +416,10 @@ func (v *MaxValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue return nil, ErrUnexpected } +func (v *MaxValue) selectors() []Selector { + return nil +} + func (v *MaxValue) reduceSelectors(row *Row, implicitTable string) ValueExp { return nil } @@ -516,6 +532,10 @@ func (v *AVGValue) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue return nil, ErrUnexpected } +func (v *AVGValue) selectors() []Selector { + return nil +} + func (v *AVGValue) reduceSelectors(row *Row, implicitTable string) ValueExp { return nil } diff --git a/embedded/sql/engine.go b/embedded/sql/engine.go index cb1da140c1..0adad90fcd 100644 --- a/embedded/sql/engine.go +++ b/embedded/sql/engine.go @@ -658,7 +658,6 @@ func (e *Engine) InferParameters(ctx context.Context, tx *SQLTx, sql string) (pa if err != nil { return nil, fmt.Errorf("%w: %v", ErrParsingError, err) } - return e.InferParametersPreparedStmts(ctx, tx, stmts) } diff --git a/embedded/sql/engine_test.go b/embedded/sql/engine_test.go index ee07507f6b..2ddd48f6c2 100644 --- a/embedded/sql/engine_test.go +++ b/embedded/sql/engine_test.go @@ -1743,6 +1743,89 @@ func TestUpsertInto(t *testing.T) { require.NoError(t, err) } +func TestUpsertIntoSelect(t *testing.T) { + st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true)) + require.NoError(t, err) + defer closeStore(t, st) + + engine, err := NewEngine(st, DefaultOptions().WithPrefix(sqlPrefix)) + require.NoError(t, err) + + _, _, err = engine.Exec( + context.Background(), + nil, `CREATE TABLE table1 ( + id INTEGER AUTO_INCREMENT, + meta JSON, + + PRIMARY KEY id + )`, nil) + require.NoError(t, err) + + _, _, err = engine.Exec( + context.Background(), + nil, `CREATE TABLE table2 ( + id INTEGER AUTO_INCREMENT, + name VARCHAR, + age INTEGER, + active BOOLEAN, + created_at TIMESTAMP, + + PRIMARY KEY id + )`, nil) + require.NoError(t, err) + + n := 100 + for i := 0; i < n; i++ { + name := fmt.Sprintf("name%d", i) + age := 10 + rand.Intn(50) + active := rand.Intn(2) == 1 + + upsert := fmt.Sprintf( + `INSERT INTO table1 (meta) VALUES ('{"name": "%s", "age": %d, "active": %t, "createdAt": "%s"}')`, + name, + age, + active, + time.Now().Format("2006-01-02 15:04:05.999999"), + ) + _, _, err = engine.Exec( + context.Background(), + nil, + upsert, + nil, + ) + require.NoError(t, err) + } + + _, _, err = engine.Exec( + context.Background(), + nil, + `INSERT INTO table2(name, age, active, created_at) + SELECT meta->'name', meta->'age', meta->'active', meta->'createdAt'::TIMESTAMP + FROM table1 + `, + nil, + ) + require.NoError(t, err) + + rows, err := engine.queryAll( + context.Background(), + nil, + `SELECT t1.meta->'name' = t2.name, t1.meta->'age' = t2.age, t1.meta->'active' = t2.active, t1.meta->'createdAt'::TIMESTAMP = t2.created_at + FROM table1 AS t1 JOIN table2 AS t2 on t1.id = t2.id`, + nil, + ) + require.NoError(t, err) + require.Len(t, rows, 100) + + for _, row := range rows { + require.Len(t, row.ValuesByPosition, 4) + + for _, v := range row.ValuesByPosition { + require.True(t, v.RawValue().(bool)) + } + } +} + func TestInsertIntoEdgeCases(t *testing.T) { engine := setupCommonTest(t) @@ -2705,6 +2788,15 @@ func TestQuery(t *testing.T) { err = r.Close() require.NoError(t, err) + + r, err = engine.Query(context.Background(), nil, "SELECT id, title, active FROM table1 WHERE id % 0", nil) + require.NoError(t, err) + + _, err = r.Read(context.Background()) + require.ErrorIs(t, err, ErrDivisionByZero) + + err = r.Close() + require.NoError(t, err) }) t.Run("Query with floating-point division by zero", func(t *testing.T) { @@ -2716,6 +2808,15 @@ func TestQuery(t *testing.T) { err = r.Close() require.NoError(t, err) + + r, err = engine.Query(context.Background(), nil, "SELECT id, title, active FROM table1 WHERE id % (1.0-1.0)", nil) + require.NoError(t, err) + + _, err = r.Read(context.Background()) + require.ErrorIs(t, err, ErrDivisionByZero) + + err = r.Close() + require.NoError(t, err) }) r, err = engine.Query(context.Background(), nil, "SELECT id, title, active FROM table1 WHERE id = 0 AND NOT active OR active", nil) @@ -2760,34 +2861,60 @@ func TestQuery(t *testing.T) { require.NoError(t, err) }) - r, err = engine.Query(context.Background(), nil, "INVALID QUERY", nil) - require.ErrorIs(t, err, ErrParsingError) - require.EqualError(t, err, "parsing error: syntax error: unexpected IDENTIFIER at position 7") - require.Nil(t, r) + t.Run("query expressions", func(t *testing.T) { + reader, err := engine.Query(context.Background(), nil, "SELECT 1, (id + 1) * 2.0, id % 2 = 0 FROM table1", nil) + require.NoError(t, err) - r, err = engine.Query(context.Background(), nil, "UPSERT INTO table1 (id) VALUES(1)", nil) - require.ErrorIs(t, err, ErrExpectingDQLStmt) - require.Nil(t, r) + cols, err := reader.Columns(context.Background()) + require.NoError(t, err) + require.Len(t, cols, 3) - r, err = engine.Query(context.Background(), nil, "UPSERT INTO table1 (id) VALUES(1); UPSERT INTO table1 (id) VALUES(1)", nil) - require.ErrorIs(t, err, ErrExpectingDQLStmt) - require.Nil(t, r) + require.Equal(t, ColDescriptor{Table: "table1", Column: "col0", Type: IntegerType}, cols[0]) + require.Equal(t, ColDescriptor{Table: "table1", Column: "col1", Type: Float64Type}, cols[1]) + require.Equal(t, ColDescriptor{Table: "table1", Column: "col2", Type: BooleanType}, cols[2]) - r, err = engine.QueryPreparedStmt(context.Background(), nil, nil, nil) - require.ErrorIs(t, err, ErrIllegalArguments) - require.Nil(t, r) + rows, err := ReadAllRows(context.Background(), reader) + require.NoError(t, err) + require.Len(t, rows, 10) + require.NoError(t, reader.Close()) - params = make(map[string]interface{}) - params["null_param"] = nil + for i, row := range rows { + require.Equal(t, int64(1), row.ValuesBySelector[EncodeSelector("", "table1", "col0")].RawValue()) + require.Equal(t, float64((i+1)*2), row.ValuesBySelector[EncodeSelector("", "table1", "col1")].RawValue()) + require.Equal(t, i%2 == 0, row.ValuesBySelector[EncodeSelector("", "table1", "col2")].RawValue()) + } + }) - r, err = engine.Query(context.Background(), nil, "SELECT id FROM table1 WHERE active = @null_param", params) - require.NoError(t, err) + t.Run("invalid queries", func(t *testing.T) { + r, err = engine.Query(context.Background(), nil, "INVALID QUERY", nil) + require.ErrorIs(t, err, ErrParsingError) + require.EqualError(t, err, "parsing error: syntax error: unexpected IDENTIFIER at position 7") + require.Nil(t, r) - _, err = r.Read(context.Background()) - require.ErrorIs(t, err, ErrNoMoreRows) + r, err = engine.Query(context.Background(), nil, "UPSERT INTO table1 (id) VALUES(1)", nil) + require.ErrorIs(t, err, ErrExpectingDQLStmt) + require.Nil(t, r) - err = r.Close() - require.NoError(t, err) + r, err = engine.Query(context.Background(), nil, "UPSERT INTO table1 (id) VALUES(1); UPSERT INTO table1 (id) VALUES(1)", nil) + require.ErrorIs(t, err, ErrExpectingDQLStmt) + require.Nil(t, r) + + r, err = engine.QueryPreparedStmt(context.Background(), nil, nil, nil) + require.ErrorIs(t, err, ErrIllegalArguments) + require.Nil(t, r) + + params = make(map[string]interface{}) + params["null_param"] = nil + + r, err = engine.Query(context.Background(), nil, "SELECT id FROM table1 WHERE active = @null_param", params) + require.NoError(t, err) + + _, err = r.Read(context.Background()) + require.ErrorIs(t, err, ErrNoMoreRows) + + err = r.Close() + require.NoError(t, err) + }) } func TestJSON(t *testing.T) { @@ -8848,3 +8975,180 @@ func TestGrantSQLPrivileges(t *testing.T) { checkGrants("SHOW GRANTS") checkGrants("SHOW GRANTS FOR myuser") } + +func TestFunctions(t *testing.T) { + st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true)) + require.NoError(t, err) + defer closeStore(t, st) + + engine, err := NewEngine(st, DefaultOptions().WithPrefix(sqlPrefix)) + require.NoError(t, err) + + _, _, err = engine.Exec( + context.Background(), + nil, + "CREATE TABLE mytable(id INTEGER, PRIMARY KEY id)", + nil, + ) + require.NoError(t, err) + + _, _, err = engine.Exec( + context.Background(), + nil, + "INSERT INTO mytable(id) VALUES (1)", + nil, + ) + require.NoError(t, err) + + t.Run("timestamp functions", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT NOW(1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT NOW() FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.IsType(t, time.Time{}, rows[0].ValuesByPosition[0].RawValue()) + }) + + t.Run("uuid functions", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT RANDOM_UUID(1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT RANDOM_UUID() FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.IsType(t, uuid.UUID{}, rows[0].ValuesByPosition[0].RawValue()) + }) + + t.Run("string functions", func(t *testing.T) { + t.Run("length", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT LENGTH(NULL, 1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT LENGTH(10) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT LENGTH(NULL) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ValuesByPosition[0].IsNull()) + require.Equal(t, IntegerType, rows[0].ValuesByPosition[0].Type()) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT LENGTH('immudb'), LENGTH('') FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, int64(6), rows[0].ValuesByPosition[0].RawValue().(int64)) + require.Equal(t, int64(0), rows[0].ValuesByPosition[1].RawValue().(int64)) + }) + + t.Run("substring", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 0, 6, true) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 0, 6) FROM mytable", nil) + require.ErrorContains(t, err, "parameter 'position' must be greater than zero") + + _, err = engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 1, -1) FROM mytable", nil) + require.ErrorContains(t, err, "parameter 'length' cannot be negative") + + rows, err := engine.queryAll(context.Background(), nil, "SELECT SUBSTRING(NULL, 8, 0) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ValuesByPosition[0].IsNull()) + require.Equal(t, VarcharType, rows[0].ValuesByPosition[0].Type()) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 8, 0) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "", rows[0].ValuesByPosition[0].RawValue().(string)) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 8, 6) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, "immudb", rows[0].ValuesByPosition[0].RawValue().(string)) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT SUBSTRING('Hello, immudb!', 8, 100) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, "immudb!", rows[0].ValuesByPosition[0].RawValue().(string)) + }) + + t.Run("trim", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT TRIM(1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT TRIM(NULL, 1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT TRIM(NULL) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ValuesByPosition[0].IsNull()) + require.Equal(t, VarcharType, rows[0].ValuesByPosition[0].Type()) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT TRIM(' \t\n\r Hello, immudb! ') FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, "Hello, immudb!", rows[0].ValuesByPosition[0].RawValue().(string)) + }) + + t.Run("concat", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT CONCAT() FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT CONCAT('ciao', NULL, true) FROM mytable", nil) + require.ErrorContains(t, err, "'CONCAT' function doesn't accept arguments of type BOOL") + + rows, err := engine.queryAll(context.Background(), nil, "SELECT CONCAT('Hello', ', ', NULL, 'immudb', NULL, '!') FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, "Hello, immudb!", rows[0].ValuesByPosition[0].RawValue().(string)) + }) + + t.Run("upper/lower", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT UPPER(1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT LOWER(NULL, 1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT UPPER(NULL), LOWER(NULL) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.True(t, rows[0].ValuesByPosition[0].IsNull()) + require.True(t, rows[0].ValuesByPosition[1].IsNull()) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT UPPER('immudb'), LOWER('IMMUDB') FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + + require.Equal(t, "IMMUDB", rows[0].ValuesByPosition[0].RawValue().(string)) + require.Equal(t, "immudb", rows[0].ValuesByPosition[1].RawValue().(string)) + }) + }) + + t.Run("json functions", func(t *testing.T) { + _, err := engine.queryAll(context.Background(), nil, "SELECT JSON_TYPEOF(true) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + _, err = engine.queryAll(context.Background(), nil, "SELECT JSON_TYPEOF('{}'::JSON, 1) FROM mytable", nil) + require.ErrorIs(t, err, ErrIllegalArguments) + + rows, err := engine.queryAll(context.Background(), nil, "SELECT JSON_TYPEOF(NULL) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Nil(t, rows[0].ValuesByPosition[0].RawValue()) + + rows, err = engine.queryAll(context.Background(), nil, "SELECT JSON_TYPEOF('{}'::JSON) FROM mytable", nil) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, "OBJECT", rows[0].ValuesByPosition[0].RawValue().(string)) + }) +} diff --git a/embedded/sql/functions.go b/embedded/sql/functions.go new file mode 100644 index 0000000000..d7eebddcd9 --- /dev/null +++ b/embedded/sql/functions.go @@ -0,0 +1,334 @@ +/* +Copyright 2024 Codenotary Inc. All rights reserved. + +SPDX-License-Identifier: BUSL-1.1 +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://mariadb.com/bsl11/ + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sql + +import ( + "fmt" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + LengthFnCall string = "LENGTH" + SubstringFnCall string = "SUBSTRING" + ConcatFnCall string = "CONCAT" + LowerFnCall string = "LOWER" + UpperFnCall string = "UPPER" + TrimFnCall string = "TRIM" + NowFnCall string = "NOW" + UUIDFnCall string = "RANDOM_UUID" + DatabasesFnCall string = "DATABASES" + TablesFnCall string = "TABLES" + TableFnCall string = "TABLE" + UsersFnCall string = "USERS" + ColumnsFnCall string = "COLUMNS" + IndexesFnCall string = "INDEXES" + GrantsFnCall string = "GRANTS" + JSONTypeOfFnCall string = "JSON_TYPEOF" +) + +var builtinFunctions = map[string]Function{ + LengthFnCall: &LengthFn{}, + SubstringFnCall: &SubstringFn{}, + ConcatFnCall: &ConcatFn{}, + LowerFnCall: &LowerUpperFnc{}, + UpperFnCall: &LowerUpperFnc{isUpper: true}, + TrimFnCall: &TrimFnc{}, + NowFnCall: &NowFn{}, + UUIDFnCall: &UUIDFn{}, + JSONTypeOfFnCall: &JsonTypeOfFn{}, +} + +type Function interface { + requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error + inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) + Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) +} + +// ------------------------------------- +// String Functions +// ------------------------------------- + +type LengthFn struct{} + +func (f *LengthFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return IntegerType, nil +} + +func (f *LengthFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != IntegerType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, t) + } + return nil +} + +func (f *LengthFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) != 1 { + return nil, fmt.Errorf("%w: '%s' function does expects one argument but %d were provided", ErrIllegalArguments, LengthFnCall, len(params)) + } + + v := params[0] + if v.IsNull() { + return &NullValue{t: IntegerType}, nil + } + + if v.Type() != VarcharType { + return nil, fmt.Errorf("%w: '%s' function expects an argument of type %s", ErrIllegalArguments, LengthFnCall, VarcharType) + } + + s, _ := v.RawValue().(string) + return &Integer{val: int64(len(s))}, nil +} + +type ConcatFn struct{} + +func (f *ConcatFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return VarcharType, nil +} + +func (f *ConcatFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != VarcharType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) + } + return nil +} + +func (f *ConcatFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) == 0 { + return nil, fmt.Errorf("%w: '%s' function does expects at least one argument", ErrIllegalArguments, ConcatFnCall) + } + + for _, v := range params { + if v.Type() != AnyType && v.Type() != VarcharType { + return nil, fmt.Errorf("%w: '%s' function doesn't accept arguments of type %s", ErrIllegalArguments, ConcatFnCall, v.Type()) + } + } + + var builder strings.Builder + for _, v := range params { + s, _ := v.RawValue().(string) + builder.WriteString(s) + } + return &Varchar{val: builder.String()}, nil +} + +type SubstringFn struct { +} + +func (f *SubstringFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return VarcharType, nil +} + +func (f *SubstringFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != VarcharType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) + } + return nil +} + +func (f *SubstringFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) != 3 { + return nil, fmt.Errorf("%w: '%s' function does expects one argument but %d were provided", ErrIllegalArguments, SubstringFnCall, len(params)) + } + + v1, v2, v3 := params[0], params[1], params[2] + + if v1.IsNull() || v2.IsNull() || v3.IsNull() { + return &NullValue{t: VarcharType}, nil + } + + s, _ := v1.RawValue().(string) + pos, _ := v2.RawValue().(int64) + length, _ := v3.RawValue().(int64) + + if pos <= 0 { + return nil, fmt.Errorf("%w: parameter 'position' must be greater than zero", ErrIllegalArguments) + } + + if length < 0 { + return nil, fmt.Errorf("%w: parameter 'length' cannot be negative", ErrIllegalArguments) + } + + end := pos - 1 + length + if end > int64(len(s)) { + end = int64(len(s)) + } + return &Varchar{val: s[pos-1 : end]}, nil +} + +type LowerUpperFnc struct { + isUpper bool +} + +func (f *LowerUpperFnc) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return VarcharType, nil +} + +func (f *LowerUpperFnc) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != VarcharType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) + } + return nil +} + +func (f *LowerUpperFnc) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) != 1 { + return nil, fmt.Errorf("%w: '%s' function does expects one argument but %d were provided", ErrIllegalArguments, f.name(), len(params)) + } + + v := params[0] + if v.IsNull() { + return &NullValue{t: VarcharType}, nil + } + + if v.Type() != VarcharType { + return nil, fmt.Errorf("%w: '%s' function expects an argument of type %s", ErrIllegalArguments, f.name(), VarcharType) + } + + s, _ := v.RawValue().(string) + + var res string + if f.isUpper { + res = strings.ToUpper(s) + } else { + res = strings.ToLower(s) + } + return &Varchar{val: res}, nil +} + +func (f *LowerUpperFnc) name() string { + if f.isUpper { + return UpperFnCall + } + return LowerFnCall +} + +type TrimFnc struct { +} + +func (f *TrimFnc) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return VarcharType, nil +} + +func (f *TrimFnc) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != VarcharType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) + } + return nil +} + +func (f *TrimFnc) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) != 1 { + return nil, fmt.Errorf("%w: '%s' function does expects one argument but %d were provided", ErrIllegalArguments, TrimFnCall, len(params)) + } + + v := params[0] + if v.IsNull() { + return &NullValue{t: VarcharType}, nil + } + + if v.Type() != VarcharType { + return nil, fmt.Errorf("%w: '%s' function expects an argument of type %s", ErrIllegalArguments, TrimFnCall, VarcharType) + } + + s, _ := v.RawValue().(string) + return &Varchar{val: strings.Trim(s, " \t\n\r\v\f")}, nil +} + +// ------------------------------------- +// Time Functions +// ------------------------------------- + +type NowFn struct{} + +func (f *NowFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return TimestampType, nil +} + +func (f *NowFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != TimestampType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t) + } + return nil +} + +func (f *NowFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) > 0 { + return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, NowFnCall, len(params)) + } + return &Timestamp{val: tx.Timestamp().Truncate(time.Microsecond).UTC()}, nil +} + +// ------------------------------------- +// JSON Functions +// ------------------------------------- + +type JsonTypeOfFn struct{} + +func (f *JsonTypeOfFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return VarcharType, nil +} + +func (f *JsonTypeOfFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != VarcharType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) + } + return nil +} + +func (f *JsonTypeOfFn) Apply(tx *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) != 1 { + return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, JSONTypeOfFnCall, 1, len(params)) + } + + v := params[0] + if v.IsNull() { + return NewNull(AnyType), nil + } + + jsonVal, ok := v.(*JSON) + if !ok { + return nil, fmt.Errorf("%w: '%s' function expects an argument of type JSON", ErrIllegalArguments, JSONTypeOfFnCall) + } + return NewVarchar(jsonVal.primitiveType()), nil +} + +// ------------------------------------- +// UUID Functions +// ------------------------------------- + +type UUIDFn struct{} + +func (f *UUIDFn) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { + return UUIDType, nil +} + +func (f *UUIDFn) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { + if t != UUIDType { + return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t) + } + return nil +} + +func (f *UUIDFn) Apply(_ *SQLTx, params []TypedValue) (TypedValue, error) { + if len(params) > 0 { + return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, UUIDFnCall, len(params)) + } + return &UUID{val: uuid.New()}, nil +} diff --git a/embedded/sql/grouped_row_reader.go b/embedded/sql/grouped_row_reader.go index 8cf59420a1..25fddd2b8b 100644 --- a/embedded/sql/grouped_row_reader.go +++ b/embedded/sql/grouped_row_reader.go @@ -27,24 +27,26 @@ import ( type groupedRowReader struct { rowReader RowReader - selectors []Selector - groupByCols []*ColSelector - cols []ColDescriptor + selectors []*AggColSelector + groupByCols []*ColSelector + cols []ColDescriptor + allAggregations bool currRow *Row empty bool } -func newGroupedRowReader(rowReader RowReader, selectors []Selector, groupBy []*ColSelector) (*groupedRowReader, error) { - if rowReader == nil || len(selectors) == 0 { +func newGroupedRowReader(rowReader RowReader, allAggregations bool, selectors []*AggColSelector, groupBy []*ColSelector) (*groupedRowReader, error) { + if rowReader == nil { return nil, ErrIllegalArguments } gr := &groupedRowReader{ - rowReader: rowReader, - selectors: selectors, - groupByCols: groupBy, - empty: true, + rowReader: rowReader, + selectors: selectors, + groupByCols: groupBy, + empty: true, + allAggregations: allAggregations, } cols, err := gr.columns() @@ -139,9 +141,9 @@ func (gr *groupedRowReader) colsBySelector(ctx context.Context) (map[string]ColD return colDescriptors, nil } -func allAggregations(selectors []Selector) bool { - for _, sel := range selectors { - _, isAggregation := sel.(*AggColSelector) +func allAggregations(targets []TargetEntry) bool { + for _, t := range targets { + _, isAggregation := t.Exp.(*AggColSelector) if !isAggregation { return false } @@ -254,7 +256,7 @@ func updateRow(currRow, newRow *Row) error { } func (gr *groupedRowReader) emitCurrentRow(ctx context.Context) (*Row, error) { - if gr.empty && allAggregations(gr.selectors) && len(gr.groupByCols) == 0 { + if gr.empty && gr.allAggregations && len(gr.groupByCols) == 0 { zr, err := gr.zeroRow(ctx) if err != nil { return nil, err diff --git a/embedded/sql/grouped_row_reader_test.go b/embedded/sql/grouped_row_reader_test.go index e84fe26e22..b036aa2a8c 100644 --- a/embedded/sql/grouped_row_reader_test.go +++ b/embedded/sql/grouped_row_reader_test.go @@ -31,7 +31,7 @@ func TestGroupedRowReader(t *testing.T) { engine, err := NewEngine(st, DefaultOptions().WithPrefix(sqlPrefix)) require.NoError(t, err) - _, err = newGroupedRowReader(nil, nil, nil) + _, err = newGroupedRowReader(nil, false, nil, nil) require.ErrorIs(t, err, ErrIllegalArguments) tx, err := engine.NewTx(context.Background(), DefaultTxOptions()) @@ -50,7 +50,7 @@ func TestGroupedRowReader(t *testing.T) { r, err := newRawRowReader(tx, nil, table, period{}, "", &ScanSpecs{Index: table.primaryIndex}) require.NoError(t, err) - gr, err := newGroupedRowReader(r, []Selector{&ColSelector{col: "id"}}, []*ColSelector{{col: "id"}}) + gr, err := newGroupedRowReader(r, false, []*AggColSelector{{aggFn: "COUNT", col: "id"}}, []*ColSelector{{col: "id"}}) require.NoError(t, err) orderBy := gr.OrderBy() @@ -61,7 +61,7 @@ func TestGroupedRowReader(t *testing.T) { cols, err := gr.Columns(context.Background()) require.NoError(t, err) - require.Len(t, cols, 1) + require.Len(t, cols, 2) scanSpecs := gr.ScanSpecs() require.NotNil(t, scanSpecs) diff --git a/embedded/sql/json_type.go b/embedded/sql/json_type.go index 31e7264b96..e3665912ea 100644 --- a/embedded/sql/json_type.go +++ b/embedded/sql/json_type.go @@ -1,3 +1,19 @@ +/* +Copyright 2024 Codenotary Inc. All rights reserved. + +SPDX-License-Identifier: BUSL-1.1 +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://mariadb.com/bsl11/ + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package sql import ( @@ -73,6 +89,10 @@ func (v *JSON) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, er return v, nil } +func (sel *JSON) selectors() []Selector { + return nil +} + func (v *JSON) reduceSelectors(row *Row, implicitTable string) ValueExp { return v } @@ -183,13 +203,6 @@ func (sel *JSONSelector) substitute(params map[string]interface{}) (ValueExp, er return sel, nil } -func (v *JSONSelector) alias() string { - if v.ColSelector.as != "" { - return v.ColSelector.as - } - return v.String() -} - func (v *JSONSelector) resolve(implicitTable string) (string, string, string) { aggFn, table, _ := v.ColSelector.resolve(implicitTable) return aggFn, table, v.String() @@ -212,6 +225,10 @@ func (sel *JSONSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (Type return jsonVal.lookup(sel.fields), nil } +func (sel *JSONSelector) selectors() []Selector { + return []Selector{sel} +} + func (sel *JSONSelector) reduceSelectors(row *Row, implicitTable string) ValueExp { val := sel.ColSelector.reduceSelectors(row, implicitTable) diff --git a/embedded/sql/num_operator.go b/embedded/sql/num_operator.go index 8d4731e7f3..30ac10009d 100644 --- a/embedded/sql/num_operator.go +++ b/embedded/sql/num_operator.go @@ -16,7 +16,10 @@ limitations under the License. package sql -import "fmt" +import ( + "fmt" + "math" +) func applyNumOperator(op NumOperator, vl, vr TypedValue) (TypedValue, error) { if vl.Type() == Float64Type || vr.Type() == Float64Type { @@ -63,6 +66,14 @@ func applyNumOperatorInteger(op NumOperator, vl, vr TypedValue) (TypedValue, err return &Integer{val: nl / nr}, nil } + case MODOP: + { + if nr == 0 { + return nil, ErrDivisionByZero + } + + return &Integer{val: nl % nr}, nil + } case MULTOP: { return &Integer{val: nl * nr}, nil @@ -110,6 +121,14 @@ func applyNumOperatorFloat64(op NumOperator, vl, vr TypedValue) (TypedValue, err return &Float64{val: nl / nr}, nil } + case MODOP: + { + if nr == 0 { + return nil, ErrDivisionByZero + } + + return &Float64{val: math.Mod(nl, nr)}, nil + } case MULTOP: { return &Float64{val: nl * nr}, nil diff --git a/embedded/sql/num_operator_test.go b/embedded/sql/num_operator_test.go index 1635b53e02..784b6d2e89 100644 --- a/embedded/sql/num_operator_test.go +++ b/embedded/sql/num_operator_test.go @@ -18,6 +18,7 @@ package sql import ( "fmt" + "math" "testing" "github.com/stretchr/testify/require" @@ -47,6 +48,11 @@ func TestNumOperator(t *testing.T) { {DIVOP, &Float64{val: 10}, &Integer{val: 3}, float64(10.0 / 3.0)}, {DIVOP, &Float64{val: 10}, &Float64{val: 3}, float64(10.0 / 3.0)}, + {MODOP, &Integer{val: 10}, &Integer{val: 3}, int64(1)}, + {MODOP, &Integer{val: 10}, &Float64{val: 3}, float64(1)}, + {MODOP, &Float64{val: 10}, &Integer{val: 3}, float64(1)}, + {MODOP, &Float64{val: 10.5}, &Float64{val: 3.2}, math.Mod(10.5, 3.2)}, + {MULTOP, &Integer{val: 10}, &Integer{val: 3}, int64(30)}, {MULTOP, &Float64{val: 10}, &Integer{val: 3}, float64(30)}, {MULTOP, &Integer{val: 10}, &Float64{val: 3}, float64(30)}, diff --git a/embedded/sql/parser.go b/embedded/sql/parser.go index 14c0923e9b..9d26227be2 100644 --- a/embedded/sql/parser.go +++ b/embedded/sql/parser.go @@ -55,6 +55,7 @@ var reservedWords = map[string]int{ "CONFLICT": CONFLICT, "DO": DO, "NOTHING": NOTHING, + "RETURNING": RETURNING, "UPSERT": UPSERT, "INTO": INTO, "VALUES": VALUES, diff --git a/embedded/sql/parser_test.go b/embedded/sql/parser_test.go index 55b8b3da03..36b2455a01 100644 --- a/embedded/sql/parser_test.go +++ b/embedded/sql/parser_test.go @@ -464,16 +464,17 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "time", "title", "active", "compressed", "payload", "note"}, - rows: []*RowSpec{ - {Values: []ValueExp{ - &Integer{val: 2}, - &FnCall{fn: "now"}, - &Varchar{val: "un'titled row"}, - &Bool{val: true}, - &Bool{val: false}, - &Blob{val: decodedBLOB}, - &Param{id: "param1"}, - }, + ds: &valuesDataSource{ + rows: []*RowSpec{{ + Values: []ValueExp{ + &Integer{val: 2}, + &FnCall{fn: "now"}, + &Varchar{val: "un'titled row"}, + &Bool{val: true}, + &Bool{val: false}, + &Blob{val: decodedBLOB}, + &Param{id: "param1"}, + }}, }, }, }, @@ -486,16 +487,17 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "time", "title", "active", "compressed", "payload", "note"}, - rows: []*RowSpec{ - {Values: []ValueExp{ - &Integer{val: 2}, - &FnCall{fn: "now"}, - &Varchar{val: ""}, - &Bool{val: true}, - &Bool{val: false}, - &Blob{val: decodedBLOB}, - &Param{id: "param1"}, - }, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{ + &Integer{val: 2}, + &FnCall{fn: "now"}, + &Varchar{val: ""}, + &Bool{val: true}, + &Bool{val: false}, + &Blob{val: decodedBLOB}, + &Param{id: "param1"}, + }}, }, }, }, @@ -508,17 +510,17 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "time", "title", "active", "compressed", "payload", "note"}, - rows: []*RowSpec{ - {Values: []ValueExp{ - &Integer{val: 2}, - &FnCall{fn: "now"}, - &Varchar{val: "'"}, - &Bool{val: true}, - &Bool{val: false}, - &Blob{val: decodedBLOB}, - &Param{id: "param1"}, - }, - }, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{ + &Integer{val: 2}, + &FnCall{fn: "now"}, + &Varchar{val: "'"}, + &Bool{val: true}, + &Bool{val: false}, + &Blob{val: decodedBLOB}, + &Param{id: "param1"}, + }}}, }, }, }, @@ -530,16 +532,17 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "time", "title", "active", "compressed", "payload", "note"}, - rows: []*RowSpec{ - {Values: []ValueExp{ - &Integer{val: 2}, - &FnCall{fn: "now"}, - &Varchar{val: "untitled row"}, - &Bool{val: true}, - &Param{id: "param1", pos: 1}, - &Blob{val: decodedBLOB}, - &Param{id: "param2", pos: 2}, - }, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{ + &Integer{val: 2}, + &FnCall{fn: "now"}, + &Varchar{val: "untitled row"}, + &Bool{val: true}, + &Param{id: "param1", pos: 1}, + &Blob{val: decodedBLOB}, + &Param{id: "param2", pos: 2}, + }}, }, }, }, @@ -552,16 +555,18 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "time", "title", "active", "compressed", "payload", "note"}, - rows: []*RowSpec{ - {Values: []ValueExp{ - &Integer{val: 2}, - &FnCall{fn: "now"}, - &Param{id: "param1", pos: 1}, - &Bool{val: true}, - &Param{id: "param2", pos: 2}, - &Blob{val: decodedBLOB}, - &Param{id: "param1", pos: 1}, - }, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{ + &Integer{val: 2}, + &FnCall{fn: "now"}, + &Param{id: "param1", pos: 1}, + &Bool{val: true}, + &Param{id: "param2", pos: 2}, + &Blob{val: decodedBLOB}, + &Param{id: "param1", pos: 1}, + }, + }, }, }, }, @@ -614,10 +619,50 @@ func TestInsertIntoStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "active"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 1}, &Bool{val: false}}}, - {Values: []ValueExp{&Integer{val: 2}, &Bool{val: true}}}, - {Values: []ValueExp{&Integer{val: 3}, &Bool{val: true}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 1}, &Bool{val: false}}}, + {Values: []ValueExp{&Integer{val: 2}, &Bool{val: true}}}, + {Values: []ValueExp{&Integer{val: 3}, &Bool{val: true}}}, + }, + }, + }, + }, + expectedError: nil, + }, + { + input: "INSERT INTO table1(id, active) SELECT * FROM my_table", + expectedOutput: []SQLStmt{ + &UpsertIntoStmt{ + isInsert: true, + tableRef: &tableRef{table: "table1"}, + cols: []string{"id", "active"}, + ds: &SelectStmt{ + ds: &tableRef{ + table: "my_table", + }, + targets: nil, + }, + }, + }, + expectedError: nil, + }, + { + input: "UPSERT INTO table1(id, active) SELECT * FROM my_table WHERE balance >= 0 AND deleted_at IS NULL", + expectedOutput: []SQLStmt{ + &UpsertIntoStmt{ + tableRef: &tableRef{table: "table1"}, + cols: []string{"id", "active"}, + ds: &SelectStmt{ + ds: &tableRef{ + table: "my_table", + }, + targets: nil, + where: &BinBoolExp{ + op: AND, + left: &CmpBoolExp{op: GE, left: &ColSelector{col: "balance"}, right: &Integer{val: 0}}, + right: &CmpBoolExp{op: EQ, left: &ColSelector{col: "deleted_at"}, right: &NullValue{t: AnyType}}, + }, }, }, }, @@ -745,15 +790,19 @@ func TestTxStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "label"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + }, }, }, &UpsertIntoStmt{ tableRef: &tableRef{table: "table2"}, cols: []string{"id"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 10}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 10}}}, + }, }, }, &RollbackStmt{}, @@ -775,8 +824,10 @@ func TestTxStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "label"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + }, }, }, &CommitStmt{}, @@ -817,8 +868,10 @@ func TestTxStmt(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "label"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + }, }, }, &CommitStmt{}, @@ -850,9 +903,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1"}, }}, @@ -863,9 +916,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1", as: "t1"}, }}, @@ -876,9 +929,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{table: "t1", col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{table: "t1", col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1", as: "t1"}, }}, @@ -889,9 +942,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{table: "table1", col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{table: "table1", col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1", as: "t1"}, where: &CmpBoolExp{ @@ -909,9 +962,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{table: "table1", col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{table: "table1", col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1", as: "t1"}, where: &CmpBoolExp{ @@ -929,9 +982,9 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{table: "table1", col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{table: "table1", col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1", as: "t1"}, where: &CmpBoolExp{ @@ -949,10 +1002,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: true, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "time"}, - &ColSelector{col: "name"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "time"}}, + {Exp: &ColSelector{col: "name"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -990,10 +1043,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, - &ColSelector{col: "year"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, + {Exp: &ColSelector{col: "year"}}, }, ds: &tableRef{table: "table1"}, orderBy: []*OrdCol{ @@ -1008,10 +1061,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "name"}, - &ColSelector{table: "table2", col: "status"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "name"}}, + {Exp: &ColSelector{table: "table2", col: "status"}}, }, ds: &tableRef{table: "table1"}, joins: []*JoinSpec{ @@ -1047,10 +1100,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "name"}, - &ColSelector{table: "table2", col: "status"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "name"}}, + {Exp: &ColSelector{table: "table2", col: "status"}}, }, ds: &tableRef{table: "table1"}, joins: []*JoinSpec{ @@ -1086,10 +1139,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "name"}, - &ColSelector{table: "table2", col: "status"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "name"}}, + {Exp: &ColSelector{table: "table2", col: "status"}}, }, ds: &tableRef{table: "table1"}, joins: []*JoinSpec{ @@ -1125,15 +1178,15 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "col1", as: "id"}, - &ColSelector{col: "col2", as: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "col1"}, As: "id"}, + {Exp: &ColSelector{col: "col2"}, As: "title"}, }, ds: &tableRef{table: "table2"}, limit: &Integer{val: 100}, @@ -1148,10 +1201,10 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "name"}, - &ColSelector{col: "time"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "name"}}, + {Exp: &ColSelector{col: "time"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1179,10 +1232,49 @@ func TestSelectStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &JSONSelector{ - ColSelector: &ColSelector{col: "json_data"}, - fields: []string{"info", "address", "street"}, + targets: []TargetEntry{ + { + Exp: &JSONSelector{ + ColSelector: &ColSelector{col: "json_data"}, + fields: []string{"info", "address", "street"}, + }, + }, + }, + ds: &tableRef{table: "table1"}, + }}, + expectedError: nil, + }, + { + input: "SELECT 1, (balance * balance) + 1, amount % 2, data::JSON FROM table1", + expectedOutput: []SQLStmt{ + &SelectStmt{ + distinct: false, + targets: []TargetEntry{ + { + Exp: &Integer{ + val: int64(1), + }, + }, + { + Exp: &NumExp{ + op: ADDOP, + left: &NumExp{ + op: MULTOP, + left: &ColSelector{col: "balance"}, + right: &ColSelector{col: "balance"}, + }, + right: &Integer{val: int64(1)}, + }, + }, + { + Exp: &NumExp{ + op: MODOP, + left: &ColSelector{col: "amount"}, + right: &Integer{val: int64(2)}, + }, + }, + { + Exp: &Cast{val: &ColSelector{col: "data"}, t: JSONType}, }, }, ds: &tableRef{table: "table1"}, @@ -1214,17 +1306,17 @@ func TestSelectUnionStmt(t *testing.T) { distinct: true, left: &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1"}, }, right: &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "title"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "title"}}, }, ds: &tableRef{table: "table1"}, }}, @@ -1254,8 +1346,8 @@ func TestAggFnStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &AggColSelector{aggFn: COUNT, col: "*"}, + targets: []TargetEntry{ + {Exp: &AggColSelector{aggFn: COUNT, col: "*"}}, }, ds: &tableRef{table: "table1"}, }}, @@ -1266,9 +1358,9 @@ func TestAggFnStmt(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "country"}, - &AggColSelector{aggFn: SUM, col: "amount"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "country"}}, + {Exp: &AggColSelector{aggFn: SUM, col: "amount"}}, }, ds: &tableRef{table: "table1"}, groupBy: []*ColSelector{ @@ -1305,8 +1397,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &CmpBoolExp{ @@ -1324,8 +1416,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1355,8 +1447,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &NotBoolExp{ @@ -1386,8 +1478,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1405,8 +1497,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1437,8 +1529,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &LikeBoolExp{ @@ -1456,8 +1548,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &LikeBoolExp{ @@ -1475,8 +1567,8 @@ func TestExpressions(t *testing.T) { expectedOutput: []SQLStmt{ &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1516,14 +1608,14 @@ func TestExpressions(t *testing.T) { input: "SELECT id FROM clients WHERE EXISTS (SELECT id FROM orders WHERE clients.id = orders.id_client)", expectedOutput: []SQLStmt{ &SelectStmt{ - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "clients"}, where: &ExistsBoolExp{ q: &SelectStmt{ - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "orders"}, where: &CmpBoolExp{ @@ -1546,8 +1638,8 @@ func TestExpressions(t *testing.T) { input: "SELECT id FROM clients WHERE deleted_at IS NULL", expectedOutput: []SQLStmt{ &SelectStmt{ - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "clients"}, where: &CmpBoolExp{ @@ -1566,8 +1658,8 @@ func TestExpressions(t *testing.T) { input: "SELECT id FROM clients WHERE deleted_at IS NOT NULL", expectedOutput: []SQLStmt{ &SelectStmt{ - selectors: []Selector{ - &ColSelector{col: "id"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, }, ds: &tableRef{table: "clients"}, where: &CmpBoolExp{ @@ -1638,24 +1730,28 @@ func TestMultiLineStmts(t *testing.T) { &UpsertIntoStmt{ tableRef: &tableRef{table: "table1"}, cols: []string{"id", "label"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 100}, &Varchar{val: "label1"}}}, + }, }, }, &UpsertIntoStmt{ tableRef: &tableRef{table: "table2"}, cols: []string{"id"}, - rows: []*RowSpec{ - {Values: []ValueExp{&Integer{val: 10}}}, + ds: &valuesDataSource{ + rows: []*RowSpec{ + {Values: []ValueExp{&Integer{val: 10}}}, + }, }, }, &CommitStmt{}, &SelectStmt{ distinct: false, - selectors: []Selector{ - &ColSelector{col: "id"}, - &ColSelector{col: "name"}, - &ColSelector{col: "time"}, + targets: []TargetEntry{ + {Exp: &ColSelector{col: "id"}}, + {Exp: &ColSelector{col: "name"}}, + {Exp: &ColSelector{col: "time"}}, }, ds: &tableRef{table: "table1"}, where: &BinBoolExp{ @@ -1719,9 +1815,11 @@ func TestFloatCornerCases(t *testing.T) { table: "t1", }, cols: []string{"v"}, - rows: []*RowSpec{{ - Values: []ValueExp{d.v}, - }}, + ds: &valuesDataSource{ + rows: []*RowSpec{{ + Values: []ValueExp{d.v}, + }}, + }, }, }, stmt) } @@ -1794,7 +1892,7 @@ func TestGrantRevokeStmt(t *testing.T) { func TestExprString(t *testing.T) { exps := []string{ - "(1 + 1) / (2 * 5 - 10)", + "(1 + 1) / (2 * 5 - 10) % 2", "@param LIKE 'pattern'", "((col1 AND (col2 < 10)) OR (@param = 3 AND (col4 = TRUE))) AND NOT (col5 = 'value' OR (2 + 2 != 4))", "CAST (func_call(1, 'two', 2.5) AS TIMESTAMP)", diff --git a/embedded/sql/proj_row_reader.go b/embedded/sql/proj_row_reader.go index c56fa24ffe..dda316fa36 100644 --- a/embedded/sql/proj_row_reader.go +++ b/embedded/sql/proj_row_reader.go @@ -25,30 +25,31 @@ type projectedRowReader struct { rowReader RowReader tableAlias string - selectors []Selector + targets []TargetEntry } -func newProjectedRowReader(ctx context.Context, rowReader RowReader, tableAlias string, selectors []Selector) (*projectedRowReader, error) { +func newProjectedRowReader(ctx context.Context, rowReader RowReader, tableAlias string, targets []TargetEntry) (*projectedRowReader, error) { // case: SELECT * - if len(selectors) == 0 { + if len(targets) == 0 { cols, err := rowReader.Columns(ctx) if err != nil { return nil, err } for _, col := range cols { - sel := &ColSelector{ - table: col.Table, - col: col.Column, - } - selectors = append(selectors, sel) + targets = append(targets, TargetEntry{ + Exp: &ColSelector{ + table: col.Table, + col: col.Column, + }, + }) } } return &projectedRowReader{ rowReader: rowReader, tableAlias: tableAlias, - selectors: selectors, + targets: targets, }, nil } @@ -82,38 +83,33 @@ func (pr *projectedRowReader) Columns(ctx context.Context) ([]ColDescriptor, err return nil, err } - colsByPos := make([]ColDescriptor, len(pr.selectors)) + colsByPos := make([]ColDescriptor, len(pr.targets)) - for i, sel := range pr.selectors { - aggFn, table, col := sel.resolve(pr.rowReader.TableAlias()) + for i, t := range pr.targets { + var aggFn, table, col string = "", pr.rowReader.TableAlias(), "" + if s, ok := t.Exp.(Selector); ok { + aggFn, table, col = s.resolve(pr.rowReader.TableAlias()) + } if pr.tableAlias != "" { table = pr.tableAlias } - if aggFn == "" && sel.alias() != "" { - col = sel.alias() - } - - if aggFn != "" { - aggFn = "" - col = sel.alias() - if col == "" { - col = fmt.Sprintf("col%d", i) - } + if t.As != "" { + col = t.As + } else if aggFn != "" || col == "" { + col = fmt.Sprintf("col%d", i) } + aggFn = "" colsByPos[i] = ColDescriptor{ AggFn: aggFn, Table: table, Column: col, } - encSel := colsByPos[i].Selector() - colsByPos[i].Type = colsBySel[encSel].Type } - return colsByPos, nil } @@ -123,13 +119,16 @@ func (pr *projectedRowReader) colsBySelector(ctx context.Context) (map[string]Co return nil, err } - colDescriptors := make(map[string]ColDescriptor, len(pr.selectors)) + colDescriptors := make(map[string]ColDescriptor, len(pr.targets)) emptyParams := make(map[string]string) - for i, sel := range pr.selectors { - aggFn, table, col := sel.resolve(pr.rowReader.TableAlias()) + for i, t := range pr.targets { + var aggFn, table, col string = "", pr.rowReader.TableAlias(), "" + if s, ok := t.Exp.(Selector); ok { + aggFn, table, col = s.resolve(pr.rowReader.TableAlias()) + } - sqlType, err := sel.inferType(dsColDescriptors, emptyParams, pr.rowReader.TableAlias()) + sqlType, err := t.Exp.inferType(dsColDescriptors, emptyParams, pr.rowReader.TableAlias()) if err != nil { return nil, err } @@ -138,17 +137,12 @@ func (pr *projectedRowReader) colsBySelector(ctx context.Context) (map[string]Co table = pr.tableAlias } - if aggFn == "" && sel.alias() != "" { - col = sel.alias() - } - - if aggFn != "" { - aggFn = "" - col = sel.alias() - if col == "" { - col = fmt.Sprintf("col%d", i) - } + if t.As != "" { + col = t.As + } else if aggFn != "" || col == "" { + col = fmt.Sprintf("col%d", i) } + aggFn = "" des := ColDescriptor{ AggFn: aggFn, @@ -156,15 +150,28 @@ func (pr *projectedRowReader) colsBySelector(ctx context.Context) (map[string]Co Column: col, Type: sqlType, } - colDescriptors[des.Selector()] = des } - return colDescriptors, nil } func (pr *projectedRowReader) InferParameters(ctx context.Context, params map[string]SQLValueType) error { - return pr.rowReader.InferParameters(ctx, params) + if err := pr.rowReader.InferParameters(ctx, params); err != nil { + return err + } + + cols, err := pr.rowReader.colsBySelector(ctx) + if err != nil { + return err + } + + for _, ex := range pr.targets { + _, err = ex.Exp.inferType(cols, params, pr.TableAlias()) + if err != nil { + return err + } + } + return nil } func (pr *projectedRowReader) Parameters() map[string]interface{} { @@ -178,35 +185,33 @@ func (pr *projectedRowReader) Read(ctx context.Context) (*Row, error) { } prow := &Row{ - ValuesByPosition: make([]TypedValue, len(pr.selectors)), - ValuesBySelector: make(map[string]TypedValue, len(pr.selectors)), + ValuesByPosition: make([]TypedValue, len(pr.targets)), + ValuesBySelector: make(map[string]TypedValue, len(pr.targets)), } - for i, sel := range pr.selectors { - v, err := sel.reduce(pr.Tx(), row, pr.rowReader.TableAlias()) + for i, t := range pr.targets { + v, err := t.Exp.reduce(pr.Tx(), row, pr.rowReader.TableAlias()) if err != nil { return nil, err } - aggFn, table, col := sel.resolve(pr.rowReader.TableAlias()) - if pr.tableAlias != "" { - table = pr.tableAlias + var aggFn, table, col string = "", pr.rowReader.TableAlias(), "" + if s, ok := t.Exp.(Selector); ok { + aggFn, table, col = s.resolve(pr.rowReader.TableAlias()) } - if aggFn == "" && sel.alias() != "" { - col = sel.alias() + if pr.tableAlias != "" { + table = pr.tableAlias } - if aggFn != "" { - aggFn = "" - col = sel.alias() - if col == "" { - col = fmt.Sprintf("col%d", i) - } + if t.As != "" { + col = t.As + } else if aggFn != "" || col == "" { + col = fmt.Sprintf("col%d", i) } prow.ValuesByPosition[i] = v - prow.ValuesBySelector[EncodeSelector(aggFn, table, col)] = v + prow.ValuesBySelector[EncodeSelector("", table, col)] = v } return prow, nil } diff --git a/embedded/sql/sql_grammar.y b/embedded/sql/sql_grammar.y index 60e7a14a80..f31bb1e026 100644 --- a/embedded/sql/sql_grammar.y +++ b/embedded/sql/sql_grammar.y @@ -46,7 +46,7 @@ func setResult(l yyLexer, stmts []SQLStmt) { ids []string col *ColSelector sel Selector - sels []Selector + targets []TargetEntry jsonFields []string distinct bool ds DataSource @@ -77,7 +77,7 @@ func setResult(l yyLexer, stmts []SQLStmt) { %token CREATE DROP USE DATABASE USER WITH PASSWORD READ READWRITE ADMIN SNAPSHOT HISTORY SINCE AFTER BEFORE UNTIL TX OF TIMESTAMP %token TABLE UNIQUE INDEX ON ALTER ADD RENAME TO COLUMN CONSTRAINT PRIMARY KEY CHECK GRANT REVOKE GRANTS FOR PRIVILEGES %token BEGIN TRANSACTION COMMIT ROLLBACK -%token INSERT UPSERT INTO VALUES DELETE UPDATE SET CONFLICT DO NOTHING +%token INSERT UPSERT INTO VALUES DELETE UPDATE SET CONFLICT DO NOTHING RETURNING %token SELECT DISTINCT FROM JOIN HAVING WHERE GROUP BY LIMIT OFFSET ORDER ASC DESC AS UNION ALL %token NOT LIKE IF EXISTS IN IS %token AUTO_INCREMENT NULL CAST SCAST @@ -106,7 +106,7 @@ func setResult(l yyLexer, stmts []SQLStmt) { %right NOT %left CMPOP %left '+' '-' -%left '*' '/' +%left '*' '/' '%' %left '.' %right STMT_SEPARATOR %left IS @@ -122,11 +122,10 @@ func setResult(l yyLexer, stmts []SQLStmt) { %type values opt_values %type val fnCall %type selector -%type opt_selectors selectors %type jsonFields %type col %type opt_distinct opt_all -%type ds +%type ds values_or_query %type tableRef %type opt_period %type opt_period_start @@ -140,6 +139,7 @@ func setResult(l yyLexer, stmts []SQLStmt) { %type binExp %type opt_groupby %type opt_limit opt_offset +%type opt_targets targets %type opt_max_len %type opt_as %type ordcols opt_orderby @@ -393,14 +393,14 @@ one_or_more_ids: } dmlstmt: - INSERT INTO tableRef '(' opt_ids ')' VALUES rows opt_on_conflict + INSERT INTO tableRef '(' opt_ids ')' values_or_query opt_on_conflict { - $$ = &UpsertIntoStmt{isInsert: true, tableRef: $3, cols: $5, rows: $8, onConflict: $9} + $$ = &UpsertIntoStmt{isInsert: true, tableRef: $3, cols: $5, ds: $7, onConflict: $8} } | - UPSERT INTO tableRef '(' ids ')' VALUES rows + UPSERT INTO tableRef '(' ids ')' values_or_query { - $$ = &UpsertIntoStmt{tableRef: $3, cols: $5, rows: $8} + $$ = &UpsertIntoStmt{tableRef: $3, cols: $5, ds: $7} } | DELETE FROM tableRef opt_where opt_indexon opt_limit opt_offset @@ -413,6 +413,18 @@ dmlstmt: $$ = &UpdateStmt{tableRef: $2, updates: $4, where: $5, indexOn: $6, limit: $7, offset: $8} } +values_or_query: + VALUES rows + { + $$ = &valuesDataSource{rows: $2} + } +| + dqlstmt + { + $$ = $1.(DataSource) + } + + opt_on_conflict: { $$ = nil @@ -681,11 +693,11 @@ dqlstmt: } } -select_stmt: SELECT opt_distinct opt_selectors FROM ds opt_indexon opt_joins opt_where opt_groupby opt_having opt_orderby opt_limit opt_offset +select_stmt: SELECT opt_distinct opt_targets FROM ds opt_indexon opt_joins opt_where opt_groupby opt_having opt_orderby opt_limit opt_offset { $$ = &SelectStmt{ distinct: $2, - selectors: $3, + targets: $3, ds: $5, indexOn: $6, joins: $7, @@ -718,28 +730,26 @@ opt_distinct: $$ = true } -opt_selectors: +opt_targets: '*' { $$ = nil } | - selectors + targets { $$ = $1 } -selectors: - selector opt_as +targets: + exp opt_as { - $1.setAlias($2) - $$ = []Selector{$1} + $$ = []TargetEntry{{Exp: $1, As: $2}} } | - selectors ',' selector opt_as + targets ',' exp opt_as { - $3.setAlias($4) - $$ = append($1, $3) + $$ = append($1, TargetEntry{Exp: $3, As: $4}) } selector: @@ -1137,6 +1147,11 @@ binExp: { $$ = &NumExp{left: $1, op: MULTOP, right: $3} } +| + exp '%' exp + { + $$ = &NumExp{left: $1, op: MODOP, right: $3} + } | exp LOP exp { diff --git a/embedded/sql/sql_parser.go b/embedded/sql/sql_parser.go index ec706c2cf8..ac6c63176e 100644 --- a/embedded/sql/sql_parser.go +++ b/embedded/sql/sql_parser.go @@ -32,7 +32,7 @@ type yySymType struct { ids []string col *ColSelector sel Selector - sels []Selector + targets []TargetEntry jsonFields []string distinct bool ds DataSource @@ -111,53 +111,54 @@ const SET = 57393 const CONFLICT = 57394 const DO = 57395 const NOTHING = 57396 -const SELECT = 57397 -const DISTINCT = 57398 -const FROM = 57399 -const JOIN = 57400 -const HAVING = 57401 -const WHERE = 57402 -const GROUP = 57403 -const BY = 57404 -const LIMIT = 57405 -const OFFSET = 57406 -const ORDER = 57407 -const ASC = 57408 -const DESC = 57409 -const AS = 57410 -const UNION = 57411 -const ALL = 57412 -const NOT = 57413 -const LIKE = 57414 -const IF = 57415 -const EXISTS = 57416 -const IN = 57417 -const IS = 57418 -const AUTO_INCREMENT = 57419 -const NULL = 57420 -const CAST = 57421 -const SCAST = 57422 -const SHOW = 57423 -const DATABASES = 57424 -const TABLES = 57425 -const USERS = 57426 -const NPARAM = 57427 -const PPARAM = 57428 -const JOINTYPE = 57429 -const LOP = 57430 -const CMPOP = 57431 -const IDENTIFIER = 57432 -const TYPE = 57433 -const INTEGER = 57434 -const FLOAT = 57435 -const VARCHAR = 57436 -const BOOLEAN = 57437 -const BLOB = 57438 -const AGGREGATE_FUNC = 57439 -const ERROR = 57440 -const DOT = 57441 -const ARROW = 57442 -const STMT_SEPARATOR = 57443 +const RETURNING = 57397 +const SELECT = 57398 +const DISTINCT = 57399 +const FROM = 57400 +const JOIN = 57401 +const HAVING = 57402 +const WHERE = 57403 +const GROUP = 57404 +const BY = 57405 +const LIMIT = 57406 +const OFFSET = 57407 +const ORDER = 57408 +const ASC = 57409 +const DESC = 57410 +const AS = 57411 +const UNION = 57412 +const ALL = 57413 +const NOT = 57414 +const LIKE = 57415 +const IF = 57416 +const EXISTS = 57417 +const IN = 57418 +const IS = 57419 +const AUTO_INCREMENT = 57420 +const NULL = 57421 +const CAST = 57422 +const SCAST = 57423 +const SHOW = 57424 +const DATABASES = 57425 +const TABLES = 57426 +const USERS = 57427 +const NPARAM = 57428 +const PPARAM = 57429 +const JOINTYPE = 57430 +const LOP = 57431 +const CMPOP = 57432 +const IDENTIFIER = 57433 +const TYPE = 57434 +const INTEGER = 57435 +const FLOAT = 57436 +const VARCHAR = 57437 +const BOOLEAN = 57438 +const BLOB = 57439 +const AGGREGATE_FUNC = 57440 +const ERROR = 57441 +const DOT = 57442 +const ARROW = 57443 +const STMT_SEPARATOR = 57444 var yyToknames = [...]string{ "$end", @@ -214,6 +215,7 @@ var yyToknames = [...]string{ "CONFLICT", "DO", "NOTHING", + "RETURNING", "SELECT", "DISTINCT", "FROM", @@ -265,6 +267,7 @@ var yyToknames = [...]string{ "'-'", "'*'", "'/'", + "'%'", "'.'", "STMT_SEPARATOR", "'('", @@ -283,153 +286,156 @@ var yyExca = [...]int16{ -1, 1, 1, -1, -2, 0, - -1, 112, - 72, 182, - 75, 182, - -2, 170, - -1, 260, - 58, 143, - -2, 138, - -1, 303, - 58, 143, + -1, 97, + 73, 184, + 76, 184, + -2, 172, + -1, 262, + 59, 145, -2, 140, + -1, 307, + 59, 145, + -2, 142, } const yyPrivate = 57344 -const yyLast = 512 +const yyLast = 547 var yyAct = [...]int16{ - 111, 408, 117, 97, 296, 197, 254, 152, 313, 330, - 334, 126, 203, 194, 239, 302, 329, 240, 6, 143, - 71, 220, 277, 374, 146, 319, 110, 318, 22, 252, - 287, 252, 109, 252, 252, 252, 395, 379, 357, 355, - 378, 320, 288, 253, 114, 208, 375, 116, 367, 358, - 356, 129, 125, 335, 21, 345, 312, 310, 127, 128, - 173, 309, 307, 130, 96, 120, 121, 122, 123, 124, - 98, 336, 171, 172, 286, 284, 115, 114, 272, 271, - 116, 119, 24, 251, 129, 125, 167, 168, 170, 169, - 331, 127, 128, 237, 283, 276, 130, 182, 120, 121, - 122, 123, 124, 98, 206, 207, 209, 131, 148, 115, - 114, 166, 211, 116, 119, 177, 178, 129, 125, 158, - 180, 267, 266, 265, 127, 128, 264, 222, 182, 130, - 205, 120, 121, 122, 123, 124, 98, 183, 181, 179, - 164, 165, 115, 160, 157, 199, 142, 119, 141, 144, - 407, 173, 361, 99, 212, 400, 196, 173, 360, 287, - 98, 217, 210, 171, 172, 200, 173, 94, 225, 226, - 227, 228, 229, 230, 99, 273, 396, 167, 168, 170, - 169, 252, 238, 241, 151, 170, 169, 173, 216, 83, - 201, 214, 167, 168, 170, 169, 156, 242, 236, 171, - 172, 158, 133, 259, 270, 354, 257, 243, 248, 260, - 99, 285, 369, 167, 168, 170, 169, 98, 268, 173, - 269, 224, 215, 353, 262, 261, 258, 360, 154, 324, - 32, 171, 172, 274, 282, 316, 275, 33, 315, 173, - 235, 99, 76, 195, 349, 167, 168, 170, 169, 173, - 153, 171, 172, 305, 342, 328, 298, 327, 326, 311, - 294, 147, 172, 202, 300, 167, 168, 170, 169, 306, - 247, 290, 246, 245, 244, 167, 168, 170, 169, 221, - 223, 218, 213, 192, 241, 191, 184, 161, 325, 149, - 321, 132, 102, 221, 100, 314, 91, 54, 333, 80, - 79, 323, 322, 78, 75, 337, 77, 70, 69, 39, - 22, 263, 332, 31, 373, 341, 348, 343, 344, 339, - 346, 338, 10, 12, 11, 49, 281, 352, 58, 176, - 232, 372, 241, 173, 351, 159, 21, 231, 175, 233, - 101, 65, 234, 60, 90, 13, 362, 55, 409, 410, - 363, 22, 210, 366, 14, 15, 390, 297, 255, 7, - 399, 8, 9, 16, 17, 382, 144, 18, 19, 64, - 376, 387, 383, 22, 365, 384, 381, 21, 386, 385, - 340, 150, 391, 43, 47, 52, 393, 56, 57, 59, - 62, 204, 397, 139, 388, 401, 398, 66, 67, 21, - 377, 88, 405, 403, 406, 402, 48, 295, 293, 51, - 411, 53, 50, 412, 25, 82, 92, 316, 370, 347, - 315, 368, 188, 189, 44, 186, 187, 185, 46, 45, - 136, 289, 250, 249, 42, 36, 394, 299, 104, 26, - 30, 162, 85, 86, 87, 103, 84, 81, 256, 40, - 34, 68, 35, 134, 135, 27, 29, 28, 2, 38, - 308, 108, 107, 73, 74, 278, 279, 280, 190, 163, - 137, 105, 292, 291, 37, 140, 138, 198, 23, 41, - 359, 145, 174, 63, 350, 371, 389, 404, 317, 364, - 113, 112, 380, 304, 303, 301, 106, 72, 89, 61, - 155, 95, 93, 118, 392, 193, 219, 20, 5, 4, - 3, 1, + 129, 411, 102, 105, 300, 151, 256, 321, 197, 339, + 336, 113, 203, 234, 239, 306, 6, 235, 194, 142, + 295, 71, 286, 145, 378, 327, 22, 326, 254, 254, + 281, 254, 254, 399, 340, 254, 127, 383, 363, 353, + 328, 290, 99, 128, 255, 101, 382, 379, 371, 116, + 112, 354, 21, 341, 24, 352, 114, 115, 159, 350, + 316, 117, 96, 107, 108, 109, 110, 111, 106, 314, + 157, 158, 313, 311, 100, 280, 278, 277, 275, 253, + 104, 337, 285, 174, 152, 153, 155, 154, 156, 274, + 173, 143, 227, 173, 269, 99, 268, 267, 101, 266, + 165, 166, 116, 112, 161, 168, 147, 241, 183, 114, + 115, 130, 159, 176, 117, 172, 107, 108, 109, 110, + 111, 106, 171, 167, 157, 158, 160, 100, 141, 182, + 140, 232, 201, 104, 410, 403, 208, 228, 152, 153, + 155, 154, 156, 365, 199, 230, 282, 281, 254, 150, + 83, 212, 159, 213, 214, 215, 216, 217, 218, 219, + 196, 210, 180, 181, 200, 159, 170, 174, 132, 232, + 76, 276, 250, 233, 236, 231, 106, 157, 158, 243, + 155, 154, 156, 319, 226, 229, 362, 361, 161, 283, + 400, 152, 153, 155, 154, 156, 206, 207, 209, 32, + 245, 225, 244, 261, 211, 324, 33, 232, 323, 259, + 160, 195, 262, 357, 347, 332, 270, 331, 271, 330, + 315, 298, 264, 205, 273, 260, 263, 146, 249, 99, + 248, 247, 101, 246, 240, 77, 116, 112, 242, 237, + 222, 192, 191, 114, 115, 184, 177, 148, 117, 131, + 107, 108, 109, 110, 111, 106, 120, 118, 302, 91, + 284, 100, 94, 54, 240, 80, 304, 104, 79, 310, + 297, 78, 297, 292, 75, 236, 299, 70, 69, 202, + 22, 309, 320, 31, 164, 39, 159, 296, 376, 360, + 377, 317, 318, 163, 272, 22, 359, 322, 157, 158, + 159, 49, 338, 265, 329, 159, 21, 159, 175, 342, + 65, 373, 152, 153, 155, 154, 156, 346, 158, 348, + 349, 21, 351, 344, 356, 343, 152, 153, 155, 154, + 156, 152, 153, 155, 154, 156, 119, 90, 236, 221, + 99, 55, 223, 101, 22, 224, 220, 116, 112, 412, + 413, 366, 394, 257, 114, 115, 301, 210, 370, 117, + 367, 107, 108, 109, 110, 111, 106, 64, 402, 138, + 21, 386, 100, 58, 369, 391, 381, 388, 104, 387, + 143, 390, 389, 279, 385, 345, 395, 149, 60, 52, + 397, 159, 62, 392, 204, 66, 67, 159, 380, 404, + 401, 364, 88, 157, 158, 408, 406, 409, 405, 157, + 158, 51, 50, 414, 53, 25, 415, 152, 153, 155, + 154, 156, 82, 152, 153, 155, 154, 156, 10, 12, + 11, 43, 47, 56, 57, 59, 122, 92, 324, 374, + 355, 323, 372, 188, 189, 85, 86, 87, 186, 187, + 185, 13, 291, 252, 48, 135, 251, 398, 334, 303, + 14, 15, 178, 121, 84, 7, 81, 8, 9, 16, + 17, 36, 44, 18, 19, 38, 46, 45, 133, 134, + 22, 258, 68, 42, 26, 30, 34, 2, 35, 312, + 37, 126, 125, 190, 73, 74, 179, 136, 40, 123, + 27, 29, 28, 287, 288, 289, 21, 294, 293, 139, + 137, 198, 63, 23, 41, 333, 144, 162, 358, 375, + 393, 407, 325, 95, 93, 368, 98, 97, 384, 308, + 307, 305, 124, 72, 89, 61, 169, 103, 335, 396, + 193, 238, 20, 5, 4, 3, 1, } var yyPact = [...]int16{ - 318, -1000, -1000, -25, -1000, -1000, -1000, 372, -1000, -1000, - 432, 223, 427, 451, 379, 379, 365, 362, 328, 207, - 278, 305, 334, -1000, 318, -1000, 268, 268, 268, 426, - 218, -1000, 217, 447, 214, 216, 213, 210, 209, 421, - 375, 88, -1000, -1000, -1000, -1000, -1000, -1000, -1000, 420, - 207, 207, 207, 350, -1000, 274, -1000, -1000, 206, -1000, - 377, 63, -1000, -1000, 204, 269, 202, 419, 268, 462, - -1000, -1000, 443, 6, 6, -1000, 201, 103, -1000, 425, - 461, 469, -1000, 379, 468, 40, 38, 306, 171, 255, - -1000, -1000, 199, 324, -1000, 83, 160, 96, 36, 102, - -1000, 261, 35, 197, 415, 459, -1000, 6, 6, -1000, - 39, 163, 258, -1000, 39, 39, 31, -1000, -1000, 39, - -1000, -1000, -1000, -1000, -1000, 30, -1000, -1000, -1000, -1000, - 20, -1000, 29, 196, 396, 395, 391, 458, 195, -1000, - 193, 153, 153, 471, 39, 89, -1000, 174, -1000, -1000, - 22, 120, -1000, -1000, 192, 91, 128, 84, 191, -1000, - 189, 19, 190, 127, -1000, -1000, 163, 39, 39, 39, - 39, 39, 39, 259, 267, 149, -1000, 173, 81, 255, - -16, 39, 39, 153, -1000, 189, 184, 183, 182, 180, - 114, 403, 402, -26, 80, -1000, -66, 295, 423, 163, - 471, 171, 39, 471, 447, 296, 18, 15, 14, 13, - 160, -11, 160, -1000, 110, -1000, -30, -31, -1000, 74, - -1000, 142, 153, -13, 454, 81, 81, 257, 257, 173, - 90, -1000, 248, 39, -14, -1000, -34, -1000, 143, -35, - 58, 163, -67, -1000, -1000, 401, -1000, -1000, 454, 465, - 464, 360, 170, 359, 293, 39, 411, 295, -1000, 163, - 166, 160, -47, 439, -48, -52, 169, -53, -1000, -1000, - -1000, -1000, -1000, 203, -83, -68, 153, -1000, -1000, -1000, - -1000, -1000, 173, -27, -1000, 138, -1000, 39, -1000, 168, - -1000, 167, 165, -18, -1000, -18, -1000, 39, 163, -37, - 293, 306, -1000, 166, 322, -1000, -1000, 160, 164, 160, - 160, -54, 160, 386, -1000, 39, 154, 256, 131, 113, - -1000, -70, -59, -71, -60, 163, -1000, -1000, -1000, 126, - -1000, 39, 57, 163, -1000, -1000, 153, -1000, 313, -1000, - 22, -1000, -61, -1000, -1000, -1000, -1000, 387, 111, 383, - 254, -1000, 236, -88, -63, -1000, -1000, -1000, -1000, -1000, - -18, 348, -69, -72, 317, 303, 471, 160, -37, 385, - 39, -1000, -1000, -1000, -1000, -1000, -1000, 341, -1000, -1000, - 291, 39, 151, 410, -1000, -73, -1000, 75, 338, 295, - 298, 163, 54, -1000, 39, -1000, 385, -1000, 293, 120, - 151, 163, -1000, -1000, 49, 282, -1000, 120, -1000, -1000, - -1000, 282, -1000, + 424, -1000, -1000, -55, -1000, -1000, -1000, 373, -1000, -1000, + 477, 192, 463, 467, 427, 427, 365, 364, 331, 172, + 271, 350, 335, -1000, 424, -1000, 236, 236, 236, 457, + 187, -1000, 186, 478, 183, 144, 180, 177, 174, 440, + 382, 48, -1000, -1000, -1000, -1000, -1000, -1000, -1000, 438, + 172, 172, 172, 351, -1000, 266, -1000, -1000, 168, -1000, + 398, 157, -1000, -1000, 166, 264, 165, 437, 236, 490, + -1000, -1000, 473, 23, 23, -1000, 158, 68, -1000, 450, + 488, 503, -1000, 427, 502, 20, 18, 319, 136, 224, + -1000, -1000, 156, 329, -1000, 47, 35, 212, -1000, 268, + 268, 13, -1000, -1000, 268, 65, 12, -1000, -1000, -1000, + -1000, -1000, 5, -1000, -1000, -1000, -1000, -17, -1000, 233, + 3, 155, 436, 486, -1000, 23, 23, -1000, 268, 320, + -1000, -2, 154, 419, 418, 412, 483, 151, -1000, 150, + 120, 120, 505, 268, 30, -1000, 189, -1000, -1000, 113, + 268, -1000, 268, 268, 268, 268, 268, 268, 268, 267, + -1000, 149, 269, 109, -1000, 228, 75, 224, -19, 36, + 90, 40, 268, 268, 148, -1000, 143, -3, 147, 84, + -1000, -1000, 320, 120, -1000, 143, 142, 140, 139, 137, + 77, 426, 423, -32, 46, -1000, -67, 289, 456, 320, + 505, 136, 268, 505, 478, 288, -11, -13, -14, -16, + 119, -20, 35, 75, 75, 230, 230, 230, 228, 223, + -1000, 215, -1000, 268, -21, -1000, -33, -1000, 76, -1000, + -34, -35, 67, 314, -36, 45, 320, -1000, 44, -1000, + 97, 120, -28, 492, -70, -1000, -1000, 422, -1000, -1000, + 492, 500, 499, 239, 130, 239, 291, 268, 433, 289, + -1000, 320, 193, 119, -38, 468, -39, -42, 129, -51, + -1000, -1000, -1000, 228, -30, -1000, -1000, -1000, -1000, 91, + -1000, 268, 173, -85, -71, 120, -1000, -1000, -1000, -1000, + -1000, 128, -1000, 126, 124, 432, -29, -1000, -1000, -1000, + -1000, 268, 320, -57, 291, 319, -1000, 193, 326, -1000, + -1000, 119, 123, 119, 119, -52, 119, -56, -72, -60, + 320, 407, -1000, 268, 122, 217, 94, 93, -1000, -73, + -1000, -1000, -1000, -1000, 349, 41, -1000, 268, 320, -1000, + -1000, 120, -1000, 312, -1000, 113, -1000, -63, -1000, -1000, + -1000, -1000, -1000, -1000, -1000, 408, 209, 404, 210, -1000, + 211, -89, -64, -1000, 345, -29, -65, -74, 324, 308, + 505, 119, -57, 406, 268, -1000, -1000, -1000, -1000, -1000, + 339, -1000, -1000, -1000, 286, 268, 116, 431, -1000, -78, + -1000, 88, -1000, 289, 305, 320, 33, -1000, 268, -1000, + 406, 291, 78, 116, 320, -1000, -1000, 32, 282, -1000, + 78, -1000, -1000, -1000, 282, -1000, } var yyPgo = [...]int16{ - 0, 511, 458, 510, 509, 508, 18, 507, 506, 21, - 13, 10, 505, 504, 16, 9, 17, 14, 503, 11, - 2, 502, 501, 500, 3, 499, 498, 12, 391, 20, - 497, 496, 32, 495, 15, 494, 493, 8, 0, 19, - 492, 491, 490, 489, 6, 4, 488, 7, 487, 486, - 1, 5, 369, 485, 484, 482, 24, 481, 480, 22, - 479, 309, 478, + 0, 546, 487, 545, 544, 543, 16, 542, 541, 14, + 18, 9, 540, 539, 538, 10, 17, 13, 537, 11, + 2, 536, 3, 535, 534, 12, 20, 394, 21, 533, + 532, 36, 531, 15, 530, 529, 7, 0, 19, 528, + 527, 526, 525, 6, 4, 524, 523, 522, 5, 521, + 520, 1, 8, 367, 519, 518, 517, 23, 516, 515, + 22, 514, 285, 513, } var yyR1 = [...]int8{ - 0, 1, 2, 2, 62, 62, 3, 3, 3, 4, + 0, 1, 2, 2, 63, 63, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 61, 61, 61, 60, 60, 60, 60, - 60, 60, 60, 59, 59, 59, 59, 52, 52, 11, - 11, 5, 5, 5, 5, 58, 58, 57, 57, 56, - 12, 12, 14, 14, 15, 10, 10, 13, 13, 17, - 17, 16, 16, 18, 18, 18, 18, 18, 18, 18, - 18, 18, 18, 19, 8, 8, 9, 46, 46, 46, - 53, 53, 54, 54, 54, 6, 6, 6, 6, 6, - 6, 6, 6, 7, 26, 26, 25, 25, 21, 21, - 22, 22, 20, 20, 20, 20, 23, 23, 24, 24, - 27, 27, 27, 27, 27, 27, 27, 27, 28, 29, - 30, 30, 30, 31, 31, 31, 32, 32, 33, 33, - 34, 34, 35, 36, 36, 39, 39, 43, 43, 40, - 40, 44, 44, 45, 45, 49, 49, 51, 51, 48, - 48, 50, 50, 50, 47, 47, 47, 37, 37, 37, - 38, 38, 38, 38, 38, 38, 38, 38, 41, 41, - 41, 41, 55, 55, 42, 42, 42, 42, 42, 42, - 42, 42, + 4, 4, 4, 62, 62, 62, 61, 61, 61, 61, + 61, 61, 61, 60, 60, 60, 60, 53, 53, 11, + 11, 5, 5, 5, 5, 26, 26, 59, 59, 58, + 58, 57, 12, 12, 14, 14, 15, 10, 10, 13, + 13, 17, 17, 16, 16, 18, 18, 18, 18, 18, + 18, 18, 18, 18, 18, 19, 8, 8, 9, 47, + 47, 47, 54, 54, 55, 55, 55, 6, 6, 6, + 6, 6, 6, 6, 6, 7, 24, 24, 23, 23, + 45, 45, 46, 46, 20, 20, 20, 20, 21, 21, + 22, 22, 25, 25, 25, 25, 25, 25, 25, 25, + 27, 28, 29, 29, 29, 30, 30, 30, 31, 31, + 32, 32, 33, 33, 34, 35, 35, 38, 38, 42, + 42, 39, 39, 43, 43, 44, 44, 50, 50, 52, + 52, 49, 49, 51, 51, 51, 48, 48, 48, 36, + 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, + 40, 40, 40, 40, 56, 56, 41, 41, 41, 41, + 41, 41, 41, 41, 41, } var yyR2 = [...]int8{ @@ -438,124 +444,124 @@ var yyR2 = [...]int8{ 9, 7, 5, 6, 6, 8, 6, 6, 7, 7, 3, 8, 8, 2, 1, 3, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 3, 1, - 3, 9, 8, 7, 8, 0, 4, 1, 3, 3, - 0, 1, 1, 3, 3, 1, 3, 1, 3, 0, - 1, 1, 3, 1, 1, 1, 1, 1, 6, 1, - 1, 1, 1, 4, 1, 3, 5, 0, 3, 3, - 0, 1, 0, 1, 2, 1, 4, 2, 2, 3, - 2, 2, 4, 13, 0, 1, 0, 1, 1, 1, - 2, 4, 1, 2, 4, 4, 2, 3, 1, 3, - 3, 4, 4, 4, 4, 4, 2, 6, 1, 2, - 0, 2, 2, 0, 2, 2, 2, 1, 0, 1, - 1, 2, 6, 0, 1, 0, 2, 0, 3, 0, - 2, 0, 2, 0, 2, 0, 3, 0, 4, 2, - 4, 0, 1, 1, 0, 1, 2, 0, 4, 6, - 1, 1, 2, 2, 4, 4, 6, 6, 1, 1, - 3, 3, 0, 1, 3, 3, 3, 3, 3, 3, - 3, 4, + 3, 8, 7, 7, 8, 2, 1, 0, 4, 1, + 3, 3, 0, 1, 1, 3, 3, 1, 3, 1, + 3, 0, 1, 1, 3, 1, 1, 1, 1, 1, + 6, 1, 1, 1, 1, 4, 1, 3, 5, 0, + 3, 3, 0, 1, 0, 1, 2, 1, 4, 2, + 2, 3, 2, 2, 4, 13, 0, 1, 0, 1, + 1, 1, 2, 4, 1, 2, 4, 4, 2, 3, + 1, 3, 3, 4, 4, 4, 4, 4, 2, 6, + 1, 2, 0, 2, 2, 0, 2, 2, 2, 1, + 0, 1, 1, 2, 6, 0, 1, 0, 2, 0, + 3, 0, 2, 0, 2, 0, 2, 0, 3, 0, + 4, 2, 4, 0, 1, 1, 0, 1, 2, 0, + 4, 6, 1, 1, 2, 2, 4, 4, 6, 6, + 1, 1, 3, 3, 0, 1, 3, 3, 3, 3, + 3, 3, 3, 3, 4, } var yyChk = [...]int16{ -1000, -1, -2, -3, -4, -5, -6, 41, 43, 44, 4, 6, 5, 27, 36, 37, 45, 46, 49, 50, - -7, 81, 55, -62, 107, 42, 7, 23, 25, 24, - 8, 90, 7, 14, 23, 25, 8, 23, 8, -61, - 70, -60, 55, 4, 45, 50, 49, 5, 27, -61, - 47, 47, 57, -28, 90, 69, 82, 83, 23, 84, - 38, -25, 56, -2, -52, 73, -52, -52, 25, 90, - 90, -29, -30, 16, 17, 90, 26, 90, 90, 90, - 90, 26, 40, 101, 26, -28, -28, -28, 51, -26, - 70, 90, 39, -21, 104, -22, -20, -24, 97, 90, - 90, 71, 90, 26, -52, 9, -31, 19, 18, -32, - 20, -38, -41, -42, 71, 103, 74, -20, -18, 108, - 92, 93, 94, 95, 96, 79, -19, 85, 86, 78, - 90, -32, 90, 99, 28, 29, 5, 9, 7, -61, - 7, 108, 108, -39, 60, -57, -56, 90, -6, 90, - 57, 101, -47, 90, 68, -23, 100, 108, 99, 74, - 108, 90, 26, 10, -32, -32, -38, 102, 103, 105, - 104, 88, 89, 76, -55, 80, 71, -38, -38, 108, - -38, 108, 108, 108, 90, 31, 30, 31, 31, 32, - 10, 90, 90, -12, -10, 90, -10, -51, 6, -38, - -39, 101, 89, -27, -28, 108, 82, 83, 23, 84, - -19, 90, -20, 90, 100, 94, 104, -24, 90, -8, - -9, 90, 108, 90, 94, -38, -38, -38, -38, -38, - -38, 78, 71, 72, 75, 91, -6, 109, -38, -17, - -16, -38, -10, -9, 90, 90, 90, 90, 94, 30, - 30, 109, 101, 109, -44, 63, 25, -51, -56, -38, - -51, -29, -6, 15, 108, 108, 108, 108, -47, -47, - 94, 109, 109, 101, 91, -10, 108, -59, 11, 12, - 13, 78, -38, 108, 109, 68, 109, 101, 109, 30, - -59, 8, 8, 48, 90, 48, -45, 64, -38, 26, - -44, -33, -34, -35, -36, 87, -47, 109, 21, 109, - 109, 90, 109, -37, -9, 35, 32, -46, 110, 108, - 109, -10, -6, -16, 91, -38, 90, 90, 90, -14, - -15, 108, -14, -38, -11, 90, 108, -45, -39, -34, - 58, -47, 90, -47, -47, 109, -47, 33, -38, 90, - -54, 78, 71, 92, 92, 109, 109, 109, 109, -58, - 101, 26, -17, -10, -43, 61, -27, 109, 34, 101, - 35, -53, 77, 78, 111, 109, -15, 52, 109, 109, - -40, 59, 62, -51, -47, -11, -37, -38, 53, -49, - 65, -38, -13, -24, 26, 109, 101, 54, -44, 62, - 101, -38, -37, -45, -48, -20, -24, 101, -50, 66, - 67, -20, -50, + -7, 82, 56, -63, 109, 42, 7, 23, 25, 24, + 8, 91, 7, 14, 23, 25, 8, 23, 8, -62, + 71, -61, 56, 4, 45, 50, 49, 5, 27, -62, + 47, 47, 58, -27, 91, 70, 83, 84, 23, 85, + 38, -23, 57, -2, -53, 74, -53, -53, 25, 91, + 91, -28, -29, 16, 17, 91, 26, 91, 91, 91, + 91, 26, 40, 102, 26, -27, -27, -27, 51, -24, + 71, 91, 39, -45, 105, -46, -37, -40, -41, 72, + 104, 75, -20, -18, 110, -22, 98, 93, 94, 95, + 96, 97, 80, -19, 86, 87, 79, 91, 91, 72, + 91, 26, -53, 9, -30, 19, 18, -31, 20, -37, + -31, 91, 100, 28, 29, 5, 9, 7, -62, 7, + 110, 110, -38, 61, -58, -57, 91, -6, 91, 58, + 102, -48, 103, 104, 106, 105, 107, 89, 90, 77, + 91, 69, -56, 81, 72, -37, -37, 110, -37, -21, + 101, 110, 110, 110, 100, 75, 110, 91, 26, 10, + -31, -31, -37, 110, 91, 31, 30, 31, 31, 32, + 10, 91, 91, -12, -10, 91, -10, -52, 6, -37, + -38, 102, 90, -25, -27, 110, 83, 84, 23, 85, + -19, 91, -37, -37, -37, -37, -37, -37, -37, -37, + 79, 72, 91, 73, 76, 92, -6, 111, 101, 95, + 105, -22, 91, -37, -17, -16, -37, 91, -8, -9, + 91, 110, 91, 95, -10, -9, 91, 91, 91, 91, + 95, 30, 30, 111, 102, 111, -43, 64, 25, -52, + -57, -37, -52, -28, -6, 15, 110, 110, 110, 110, + -48, -48, 79, -37, 110, 111, 95, 111, 111, 69, + 111, 102, 102, 92, -10, 110, -60, 11, 12, 13, + 111, 30, -60, 8, 8, -26, 48, -6, 91, -26, + -44, 65, -37, 26, -43, -32, -33, -34, -35, 88, + -48, 111, 21, 111, 111, 91, 111, -6, -16, 92, + -37, -36, -9, 35, 32, -47, 112, 110, 111, -10, + 91, 91, 91, -59, 26, -14, -15, 110, -37, -11, + 91, 110, -44, -38, -33, 59, -48, 91, -48, -48, + 111, -48, 111, 111, 111, 33, -37, 91, -55, 79, + 72, 93, 93, 111, 52, 102, -17, -10, -42, 62, + -25, 111, 34, 102, 35, -54, 78, 79, 113, 111, + 53, -15, 111, 111, -39, 60, 63, -52, -48, -11, + -36, -37, 54, -50, 66, -37, -13, -22, 26, 111, + 102, -43, 63, 102, -37, -36, -44, -49, -20, -22, + 102, -51, 67, 68, -20, -51, } var yyDef = [...]int16{ 0, -2, 1, 4, 6, 7, 8, 10, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 95, 0, 106, 2, 5, 9, 47, 47, 47, 0, - 0, 14, 0, 130, 0, 0, 0, 0, 0, 0, + 97, 0, 108, 2, 5, 9, 47, 47, 47, 0, + 0, 14, 0, 132, 0, 0, 0, 0, 0, 0, 0, 34, 36, 37, 38, 39, 40, 41, 42, 0, - 0, 0, 0, 0, 128, 104, 97, 98, 0, 100, - 101, 0, 107, 3, 0, 0, 0, 0, 47, 0, - 15, 16, 133, 0, 0, 18, 0, 0, 30, 0, - 0, 0, 33, 0, 0, 0, 0, 145, 0, 0, - 105, 99, 0, 0, 108, 109, 164, 112, 0, 118, - 13, 0, 0, 0, 0, 0, 129, 0, 0, 131, - 0, 137, -2, 171, 0, 0, 0, 178, 179, 0, - 73, 74, 75, 76, 77, 0, 79, 80, 81, 82, - 118, 132, 0, 0, 0, 0, 0, 0, 0, 35, - 0, 60, 0, 157, 0, 145, 57, 0, 96, 102, - 0, 0, 110, 165, 0, 113, 0, 0, 0, 48, - 0, 0, 0, 0, 134, 135, 136, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 183, 172, 173, 0, - 0, 0, 69, 0, 22, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 61, 65, 0, 151, 0, 146, - 157, 0, 0, 157, 130, 0, 0, 0, 0, 0, - 164, 128, 164, 166, 0, 116, 0, 0, 119, 0, - 84, 0, 0, 0, 43, 184, 185, 186, 187, 188, - 189, 190, 0, 0, 0, 181, 0, 180, 0, 0, - 70, 71, 0, 23, 24, 0, 26, 27, 43, 0, - 0, 0, 0, 0, 153, 0, 0, 151, 58, 59, - -2, 164, 0, 0, 0, 0, 0, 0, 126, 111, - 117, 114, 115, 167, 87, 0, 0, 28, 44, 45, - 46, 191, 174, 0, 175, 0, 83, 0, 21, 0, - 29, 0, 0, 0, 66, 0, 53, 0, 152, 0, - 153, 145, 139, -2, 0, 144, 120, 164, 0, 164, - 164, 0, 164, 0, 85, 0, 0, 92, 0, 0, - 19, 0, 0, 0, 0, 72, 25, 31, 32, 55, - 62, 69, 52, 154, 158, 49, 0, 54, 147, 141, - 0, 121, 0, 122, 123, 124, 125, 0, 0, 0, - 90, 93, 0, 0, 0, 20, 176, 177, 78, 51, - 0, 0, 0, 0, 149, 0, 157, 164, 0, 167, - 0, 86, 91, 94, 88, 89, 63, 0, 64, 50, - 155, 0, 0, 0, 127, 0, 168, 0, 0, 151, - 0, 150, 148, 67, 0, 17, 167, 56, 153, 0, - 0, 142, 169, 103, 156, 161, 68, 0, 159, 162, - 163, 161, 160, + 0, 0, 0, 0, 130, 106, 99, 100, 0, 102, + 103, 0, 109, 3, 0, 0, 0, 0, 47, 0, + 15, 16, 135, 0, 0, 18, 0, 0, 30, 0, + 0, 0, 33, 0, 0, 0, 0, 147, 0, 0, + 107, 101, 0, 0, 110, 111, 166, -2, 173, 0, + 0, 0, 180, 181, 0, 114, 0, 75, 76, 77, + 78, 79, 0, 81, 82, 83, 84, 120, 13, 0, + 0, 0, 0, 0, 131, 0, 0, 133, 0, 139, + 134, 0, 0, 0, 0, 0, 0, 0, 35, 0, + 62, 0, 159, 0, 147, 59, 0, 98, 104, 0, + 0, 112, 0, 0, 0, 0, 0, 0, 0, 0, + 167, 0, 0, 0, 185, 174, 175, 0, 0, 115, + 0, 0, 0, 71, 0, 48, 0, 0, 0, 0, + 136, 137, 138, 0, 22, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 63, 67, 0, 153, 0, 148, + 159, 0, 0, 159, 132, 0, 0, 0, 0, 0, + 166, 130, 166, 186, 187, 188, 189, 190, 191, 192, + 193, 0, 168, 0, 0, 183, 0, 182, 0, 118, + 0, 0, 120, 0, 0, 72, 73, 121, 0, 86, + 0, 0, 0, 43, 0, 23, 24, 0, 26, 27, + 43, 0, 0, 0, 0, 0, 155, 0, 0, 153, + 60, 61, -2, 166, 0, 0, 0, 0, 0, 0, + 128, 113, 194, 176, 0, 177, 119, 116, 117, 0, + 85, 0, 169, 89, 0, 0, 28, 44, 45, 46, + 21, 0, 29, 0, 0, 57, 0, 56, 68, 52, + 53, 0, 154, 0, 155, 147, 141, -2, 0, 146, + 122, 166, 0, 166, 166, 0, 166, 0, 0, 0, + 74, 0, 87, 0, 0, 94, 0, 0, 19, 0, + 25, 31, 32, 51, 0, 55, 64, 71, 156, 160, + 49, 0, 54, 149, 143, 0, 123, 0, 124, 125, + 126, 127, 178, 179, 80, 0, 0, 0, 92, 95, + 0, 0, 0, 20, 0, 0, 0, 0, 151, 0, + 159, 166, 0, 169, 0, 88, 93, 96, 90, 91, + 0, 65, 66, 50, 157, 0, 0, 0, 129, 0, + 170, 0, 58, 153, 0, 152, 150, 69, 0, 17, + 169, 155, 0, 0, 144, 171, 105, 158, 163, 70, + 0, 161, 164, 165, 163, 162, } var yyTok1 = [...]int8{ 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 107, 3, 3, + 110, 111, 105, 103, 102, 104, 108, 106, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 108, 109, 104, 102, 101, 103, 106, 105, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 110, 3, 111, + 3, 112, 3, 113, } var yyTok2 = [...]int8{ @@ -568,7 +574,8 @@ var yyTok2 = [...]int8{ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, - 92, 93, 94, 95, 96, 97, 98, 99, 100, 107, + 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, + 109, } var yyTok3 = [...]int8{ @@ -1141,14 +1148,14 @@ yydefault: yyVAL.ids = yyDollar[2].ids } case 51: - yyDollar = yyS[yypt-9 : yypt+1] + yyDollar = yyS[yypt-8 : yypt+1] { - yyVAL.stmt = &UpsertIntoStmt{isInsert: true, tableRef: yyDollar[3].tableRef, cols: yyDollar[5].ids, rows: yyDollar[8].rows, onConflict: yyDollar[9].onConflict} + yyVAL.stmt = &UpsertIntoStmt{isInsert: true, tableRef: yyDollar[3].tableRef, cols: yyDollar[5].ids, ds: yyDollar[7].ds, onConflict: yyDollar[8].onConflict} } case 52: - yyDollar = yyS[yypt-8 : yypt+1] + yyDollar = yyS[yypt-7 : yypt+1] { - yyVAL.stmt = &UpsertIntoStmt{tableRef: yyDollar[3].tableRef, cols: yyDollar[5].ids, rows: yyDollar[8].rows} + yyVAL.stmt = &UpsertIntoStmt{tableRef: yyDollar[3].tableRef, cols: yyDollar[5].ids, ds: yyDollar[7].ds} } case 53: yyDollar = yyS[yypt-7 : yypt+1] @@ -1161,211 +1168,221 @@ yydefault: yyVAL.stmt = &UpdateStmt{tableRef: yyDollar[2].tableRef, updates: yyDollar[4].updates, where: yyDollar[5].exp, indexOn: yyDollar[6].ids, limit: yyDollar[7].exp, offset: yyDollar[8].exp} } case 55: + yyDollar = yyS[yypt-2 : yypt+1] + { + yyVAL.ds = &valuesDataSource{rows: yyDollar[2].rows} + } + case 56: + yyDollar = yyS[yypt-1 : yypt+1] + { + yyVAL.ds = yyDollar[1].stmt.(DataSource) + } + case 57: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.onConflict = nil } - case 56: + case 58: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.onConflict = &OnConflictDo{} } - case 57: + case 59: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.updates = []*colUpdate{yyDollar[1].update} } - case 58: + case 60: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.updates = append(yyDollar[1].updates, yyDollar[3].update) } - case 59: + case 61: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.update = &colUpdate{col: yyDollar[1].id, op: yyDollar[2].cmpOp, val: yyDollar[3].exp} } - case 60: + case 62: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.ids = nil } - case 61: + case 63: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.ids = yyDollar[1].ids } - case 62: + case 64: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.rows = []*RowSpec{yyDollar[1].row} } - case 63: + case 65: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.rows = append(yyDollar[1].rows, yyDollar[3].row) } - case 64: + case 66: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.row = &RowSpec{Values: yyDollar[2].values} } - case 65: + case 67: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.ids = []string{yyDollar[1].id} } - case 66: + case 68: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.ids = append(yyDollar[1].ids, yyDollar[3].id) } - case 67: + case 69: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.cols = []*ColSelector{yyDollar[1].col} } - case 68: + case 70: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.cols = append(yyDollar[1].cols, yyDollar[3].col) } - case 69: + case 71: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.values = nil } - case 70: + case 72: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.values = yyDollar[1].values } - case 71: + case 73: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.values = []ValueExp{yyDollar[1].exp} } - case 72: + case 74: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.values = append(yyDollar[1].values, yyDollar[3].exp) } - case 73: + case 75: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Integer{val: int64(yyDollar[1].integer)} } - case 74: + case 76: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Float64{val: float64(yyDollar[1].float)} } - case 75: + case 77: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Varchar{val: yyDollar[1].str} } - case 76: + case 78: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Bool{val: yyDollar[1].boolean} } - case 77: + case 79: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Blob{val: yyDollar[1].blob} } - case 78: + case 80: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.value = &Cast{val: yyDollar[3].exp, t: yyDollar[5].sqlType} } - case 79: + case 81: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = yyDollar[1].value } - case 80: + case 82: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Param{id: yyDollar[1].id} } - case 81: + case 83: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &Param{id: fmt.Sprintf("param%d", yyDollar[1].pparam), pos: yyDollar[1].pparam} } - case 82: + case 84: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.value = &NullValue{t: AnyType} } - case 83: + case 85: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.value = &FnCall{fn: yyDollar[1].id, params: yyDollar[3].values} } - case 84: + case 86: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.colsSpec = []*ColSpec{yyDollar[1].colSpec} } - case 85: + case 87: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.colsSpec = append(yyDollar[1].colsSpec, yyDollar[3].colSpec) } - case 86: + case 88: yyDollar = yyS[yypt-5 : yypt+1] { yyVAL.colSpec = &ColSpec{colName: yyDollar[1].id, colType: yyDollar[2].sqlType, maxLen: int(yyDollar[3].integer), notNull: yyDollar[4].boolean, autoIncrement: yyDollar[5].boolean} } - case 87: + case 89: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.integer = 0 } - case 88: + case 90: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.integer = yyDollar[2].integer } - case 89: + case 91: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.integer = yyDollar[2].integer } - case 90: + case 92: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.boolean = false } - case 91: + case 93: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.boolean = true } - case 92: + case 94: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.boolean = false } - case 93: + case 95: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.boolean = false } - case 94: + case 96: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.boolean = true } - case 95: + case 97: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.stmt = yyDollar[1].stmt } - case 96: + case 98: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.stmt = &UnionStmt{ @@ -1374,506 +1391,509 @@ yydefault: right: yyDollar[4].stmt.(DataSource), } } - case 97: + case 99: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "databases"}}, } } - case 98: + case 100: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "tables"}}, } } - case 99: + case 101: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "table", params: []ValueExp{&Varchar{val: yyDollar[3].id}}}}, } } - case 100: + case 102: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "users"}}, } } - case 101: + case 103: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "grants"}}, } } - case 102: + case 104: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.stmt = &SelectStmt{ ds: &FnDataSourceStmt{fnCall: &FnCall{fn: "grants", params: []ValueExp{&Varchar{val: yyDollar[4].id}}}}, } } - case 103: + case 105: yyDollar = yyS[yypt-13 : yypt+1] { yyVAL.stmt = &SelectStmt{ - distinct: yyDollar[2].distinct, - selectors: yyDollar[3].sels, - ds: yyDollar[5].ds, - indexOn: yyDollar[6].ids, - joins: yyDollar[7].joins, - where: yyDollar[8].exp, - groupBy: yyDollar[9].cols, - having: yyDollar[10].exp, - orderBy: yyDollar[11].ordcols, - limit: yyDollar[12].exp, - offset: yyDollar[13].exp, + distinct: yyDollar[2].distinct, + targets: yyDollar[3].targets, + ds: yyDollar[5].ds, + indexOn: yyDollar[6].ids, + joins: yyDollar[7].joins, + where: yyDollar[8].exp, + groupBy: yyDollar[9].cols, + having: yyDollar[10].exp, + orderBy: yyDollar[11].ordcols, + limit: yyDollar[12].exp, + offset: yyDollar[13].exp, } } - case 104: + case 106: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.distinct = true } - case 105: + case 107: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.distinct = false } - case 106: + case 108: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.distinct = false } - case 107: + case 109: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.distinct = true } - case 108: + case 110: yyDollar = yyS[yypt-1 : yypt+1] { - yyVAL.sels = nil + yyVAL.targets = nil } - case 109: + case 111: yyDollar = yyS[yypt-1 : yypt+1] { - yyVAL.sels = yyDollar[1].sels + yyVAL.targets = yyDollar[1].targets } - case 110: + case 112: yyDollar = yyS[yypt-2 : yypt+1] { - yyDollar[1].sel.setAlias(yyDollar[2].id) - yyVAL.sels = []Selector{yyDollar[1].sel} + yyVAL.targets = []TargetEntry{{Exp: yyDollar[1].exp, As: yyDollar[2].id}} } - case 111: + case 113: yyDollar = yyS[yypt-4 : yypt+1] { - yyDollar[3].sel.setAlias(yyDollar[4].id) - yyVAL.sels = append(yyDollar[1].sels, yyDollar[3].sel) + yyVAL.targets = append(yyDollar[1].targets, TargetEntry{Exp: yyDollar[3].exp, As: yyDollar[4].id}) } - case 112: + case 114: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.sel = yyDollar[1].col } - case 113: + case 115: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.sel = &JSONSelector{ColSelector: yyDollar[1].col, fields: yyDollar[2].jsonFields} } - case 114: + case 116: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.sel = &AggColSelector{aggFn: yyDollar[1].aggFn, col: "*"} } - case 115: + case 117: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.sel = &AggColSelector{aggFn: yyDollar[1].aggFn, table: yyDollar[3].col.table, col: yyDollar[3].col.col} } - case 116: + case 118: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.jsonFields = []string{yyDollar[2].str} } - case 117: + case 119: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.jsonFields = append(yyVAL.jsonFields, yyDollar[3].str) } - case 118: + case 120: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.col = &ColSelector{col: yyDollar[1].id} } - case 119: + case 121: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.col = &ColSelector{table: yyDollar[1].id, col: yyDollar[3].id} } - case 120: + case 122: yyDollar = yyS[yypt-3 : yypt+1] { yyDollar[1].tableRef.period = yyDollar[2].period yyDollar[1].tableRef.as = yyDollar[3].id yyVAL.ds = yyDollar[1].tableRef } - case 121: + case 123: yyDollar = yyS[yypt-4 : yypt+1] { yyDollar[2].stmt.(*SelectStmt).as = yyDollar[4].id yyVAL.ds = yyDollar[2].stmt.(DataSource) } - case 122: + case 124: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ds = &FnDataSourceStmt{fnCall: &FnCall{fn: "databases"}, as: yyDollar[4].id} } - case 123: + case 125: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ds = &FnDataSourceStmt{fnCall: &FnCall{fn: "tables"}, as: yyDollar[4].id} } - case 124: + case 126: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ds = &FnDataSourceStmt{fnCall: &FnCall{fn: "table", params: []ValueExp{&Varchar{val: yyDollar[3].id}}}} } - case 125: + case 127: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ds = &FnDataSourceStmt{fnCall: &FnCall{fn: "users"}, as: yyDollar[4].id} } - case 126: + case 128: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.ds = &FnDataSourceStmt{fnCall: yyDollar[1].value.(*FnCall), as: yyDollar[2].id} } - case 127: + case 129: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.ds = &tableRef{table: yyDollar[4].id, history: true, as: yyDollar[6].id} } - case 128: + case 130: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.tableRef = &tableRef{table: yyDollar[1].id} } - case 129: + case 131: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.period = period{start: yyDollar[1].openPeriod, end: yyDollar[2].openPeriod} } - case 130: + case 132: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.openPeriod = nil } - case 131: + case 133: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.openPeriod = &openPeriod{inclusive: true, instant: yyDollar[2].periodInstant} } - case 132: + case 134: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.openPeriod = &openPeriod{instant: yyDollar[2].periodInstant} } - case 133: + case 135: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.openPeriod = nil } - case 134: + case 136: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.openPeriod = &openPeriod{inclusive: true, instant: yyDollar[2].periodInstant} } - case 135: + case 137: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.openPeriod = &openPeriod{instant: yyDollar[2].periodInstant} } - case 136: + case 138: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.periodInstant = periodInstant{instantType: txInstant, exp: yyDollar[2].exp} } - case 137: + case 139: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.periodInstant = periodInstant{instantType: timeInstant, exp: yyDollar[1].exp} } - case 138: + case 140: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.joins = nil } - case 139: + case 141: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.joins = yyDollar[1].joins } - case 140: + case 142: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.joins = []*JoinSpec{yyDollar[1].join} } - case 141: + case 143: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.joins = append([]*JoinSpec{yyDollar[1].join}, yyDollar[2].joins...) } - case 142: + case 144: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.join = &JoinSpec{joinType: yyDollar[1].joinType, ds: yyDollar[3].ds, indexOn: yyDollar[4].ids, cond: yyDollar[6].exp} } - case 143: + case 145: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.joinType = InnerJoin } - case 144: + case 146: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.joinType = yyDollar[1].joinType } - case 145: + case 147: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.exp = nil } - case 146: + case 148: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = yyDollar[2].exp } - case 147: + case 149: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.cols = nil } - case 148: + case 150: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.cols = yyDollar[3].cols } - case 149: + case 151: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.exp = nil } - case 150: + case 152: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = yyDollar[2].exp } - case 151: + case 153: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.exp = nil } - case 152: + case 154: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = yyDollar[2].exp } - case 153: + case 155: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.exp = nil } - case 154: + case 156: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = yyDollar[2].exp } - case 155: + case 157: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.ordcols = nil } - case 156: + case 158: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.ordcols = yyDollar[3].ordcols } - case 157: + case 159: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.ids = nil } - case 158: + case 160: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ids = yyDollar[4].ids } - case 159: + case 161: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.ordcols = []*OrdCol{{sel: yyDollar[1].sel, descOrder: yyDollar[2].opt_ord}} } - case 160: + case 162: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.ordcols = append(yyDollar[1].ordcols, &OrdCol{sel: yyDollar[3].sel, descOrder: yyDollar[4].opt_ord}) } - case 161: + case 163: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.opt_ord = false } - case 162: + case 164: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.opt_ord = false } - case 163: + case 165: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.opt_ord = true } - case 164: + case 166: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.id = "" } - case 165: + case 167: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.id = yyDollar[1].id } - case 166: + case 168: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.id = yyDollar[2].id } - case 167: + case 169: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.checks = nil } - case 168: + case 170: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.checks = append([]CheckConstraint{{exp: yyDollar[2].exp}}, yyDollar[4].checks...) } - case 169: + case 171: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.checks = append([]CheckConstraint{{name: yyDollar[2].id, exp: yyDollar[4].exp}}, yyDollar[6].checks...) } - case 170: + case 172: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.exp = yyDollar[1].exp } - case 171: + case 173: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.exp = yyDollar[1].binExp } - case 172: + case 174: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = &NotBoolExp{exp: yyDollar[2].exp} } - case 173: + case 175: yyDollar = yyS[yypt-2 : yypt+1] { yyVAL.exp = &NumExp{left: &Integer{val: 0}, op: SUBSOP, right: yyDollar[2].exp} } - case 174: + case 176: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.exp = &LikeBoolExp{val: yyDollar[1].exp, notLike: yyDollar[2].boolean, pattern: yyDollar[4].exp} } - case 175: + case 177: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.exp = &ExistsBoolExp{q: (yyDollar[3].stmt).(DataSource)} } - case 176: + case 178: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.exp = &InSubQueryExp{val: yyDollar[1].exp, notIn: yyDollar[2].boolean, q: yyDollar[5].stmt.(*SelectStmt)} } - case 177: + case 179: yyDollar = yyS[yypt-6 : yypt+1] { yyVAL.exp = &InListExp{val: yyDollar[1].exp, notIn: yyDollar[2].boolean, values: yyDollar[5].values} } - case 178: + case 180: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.exp = yyDollar[1].sel } - case 179: + case 181: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.exp = yyDollar[1].value } - case 180: + case 182: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.exp = yyDollar[2].exp } - case 181: + case 183: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.exp = &Cast{val: yyDollar[1].exp, t: yyDollar[3].sqlType} } - case 182: + case 184: yyDollar = yyS[yypt-0 : yypt+1] { yyVAL.boolean = false } - case 183: + case 185: yyDollar = yyS[yypt-1 : yypt+1] { yyVAL.boolean = true } - case 184: + case 186: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &NumExp{left: yyDollar[1].exp, op: ADDOP, right: yyDollar[3].exp} } - case 185: + case 187: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &NumExp{left: yyDollar[1].exp, op: SUBSOP, right: yyDollar[3].exp} } - case 186: + case 188: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &NumExp{left: yyDollar[1].exp, op: DIVOP, right: yyDollar[3].exp} } - case 187: + case 189: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &NumExp{left: yyDollar[1].exp, op: MULTOP, right: yyDollar[3].exp} } - case 188: + case 190: + yyDollar = yyS[yypt-3 : yypt+1] + { + yyVAL.binExp = &NumExp{left: yyDollar[1].exp, op: MODOP, right: yyDollar[3].exp} + } + case 191: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &BinBoolExp{left: yyDollar[1].exp, op: yyDollar[2].logicOp, right: yyDollar[3].exp} } - case 189: + case 192: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &CmpBoolExp{left: yyDollar[1].exp, op: yyDollar[2].cmpOp, right: yyDollar[3].exp} } - case 190: + case 193: yyDollar = yyS[yypt-3 : yypt+1] { yyVAL.binExp = &CmpBoolExp{left: yyDollar[1].exp, op: EQ, right: &NullValue{t: AnyType}} } - case 191: + case 194: yyDollar = yyS[yypt-4 : yypt+1] { yyVAL.binExp = &CmpBoolExp{left: yyDollar[1].exp, op: NE, right: &NullValue{t: AnyType}} diff --git a/embedded/sql/stmt.go b/embedded/sql/stmt.go index e7bdd69410..6fa519228a 100644 --- a/embedded/sql/stmt.go +++ b/embedded/sql/stmt.go @@ -175,6 +175,7 @@ const ( SUBSOP DIVOP MULTOP + MODOP ) func NumOperatorString(op NumOperator) string { @@ -187,6 +188,8 @@ func NumOperatorString(op NumOperator) string { return "/" case MULTOP: return "*" + case MODOP: + return "%" } return "" } @@ -199,19 +202,6 @@ const ( RightJoin ) -const ( - NowFnCall string = "NOW" - UUIDFnCall string = "RANDOM_UUID" - DatabasesFnCall string = "DATABASES" - TablesFnCall string = "TABLES" - TableFnCall string = "TABLE" - UsersFnCall string = "USERS" - ColumnsFnCall string = "COLUMNS" - IndexesFnCall string = "INDEXES" - GrantsFnCall string = "GRANTS" - JSONTypeOfFnCall string = "JSON_TYPEOF" -) - type SQLStmt interface { readOnly() bool requiredPrivileges() []SQLPrivilege @@ -1029,7 +1019,7 @@ type UpsertIntoStmt struct { isInsert bool tableRef *tableRef cols []string - rows []*RowSpec + ds DataSource onConflict *OnConflictDo } @@ -1038,18 +1028,26 @@ func (stmt *UpsertIntoStmt) readOnly() bool { } func (stmt *UpsertIntoStmt) requiredPrivileges() []SQLPrivilege { + privileges := stmt.privileges() + if stmt.ds != nil { + privileges = append(privileges, stmt.ds.requiredPrivileges()...) + } + return privileges +} + +func (stmt *UpsertIntoStmt) privileges() []SQLPrivilege { if stmt.isInsert { return []SQLPrivilege{SQLPrivilegeInsert} } return []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate} } -func NewUpserIntoStmt(table string, cols []string, rows []*RowSpec, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt { +func NewUpsertIntoStmt(table string, cols []string, ds DataSource, isInsert bool, onConflict *OnConflictDo) *UpsertIntoStmt { return &UpsertIntoStmt{ isInsert: isInsert, tableRef: NewTableRef(table, ""), cols: cols, - rows: rows, + ds: ds, onConflict: onConflict, } } @@ -1067,9 +1065,13 @@ func NewRowSpec(values []ValueExp) *RowSpec { type OnConflictDo struct{} func (stmt *UpsertIntoStmt) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error { - emptyDescriptors := make(map[string]ColDescriptor) + ds, ok := stmt.ds.(*valuesDataSource) + if !ok { + return stmt.ds.inferParameters(ctx, tx, params) + } - for _, row := range stmt.rows { + emptyDescriptors := make(map[string]ColDescriptor) + for _, row := range ds.rows { if len(stmt.cols) != len(row.Values) { return ErrInvalidNumberOfValues } @@ -1130,8 +1132,22 @@ func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[st ValuesBySelector: make(map[string]TypedValue), } - for _, row := range stmt.rows { - if len(row.Values) != len(stmt.cols) { + reader, err := stmt.ds.Resolve(ctx, tx, params, nil) + if err != nil { + return nil, err + } + defer reader.Close() + + for { + row, err := reader.Read(ctx) + if errors.Is(err, ErrNoMoreRows) { + break + } + if err != nil { + return nil, err + } + + if len(row.ValuesByPosition) != len(stmt.cols) { return nil, ErrInvalidNumberOfValues } @@ -1165,7 +1181,7 @@ func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[st } // value was specified - cVal := row.Values[colPos] + cVal := row.ValuesByPosition[colPos] val, err := cVal.substitute(params) if err != nil { @@ -1261,7 +1277,6 @@ func (stmt *UpsertIntoStmt) execAt(ctx context.Context, tx *SQLTx, params map[st return nil, err } } - return tx, nil } @@ -1872,6 +1887,7 @@ type ValueExp interface { inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error substitute(params map[string]interface{}) (ValueExp, error) + selectors() []Selector reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) reduceSelectors(row *Row, implicitTable string) ValueExp isConstant() bool @@ -2058,6 +2074,10 @@ func (v *NullValue) requiresType(t SQLValueType, cols map[string]ColDescriptor, return nil } +func (v *NullValue) selectors() []Selector { + return nil +} + func (v *NullValue) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2110,6 +2130,10 @@ func (v *Integer) requiresType(t SQLValueType, cols map[string]ColDescriptor, pa return nil } +func (v *Integer) selectors() []Selector { + return nil +} + func (v *Integer) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2194,6 +2218,10 @@ func (v *Timestamp) requiresType(t SQLValueType, cols map[string]ColDescriptor, return nil } +func (v *Timestamp) selectors() []Selector { + return nil +} + func (v *Timestamp) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2271,6 +2299,10 @@ func (v *Varchar) requiresType(t SQLValueType, cols map[string]ColDescriptor, pa return nil } +func (v *Varchar) selectors() []Selector { + return nil +} + func (v *Varchar) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2346,6 +2378,10 @@ func (v *UUID) requiresType(t SQLValueType, cols map[string]ColDescriptor, param return nil } +func (v *UUID) selectors() []Selector { + return nil +} + func (v *UUID) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2415,6 +2451,10 @@ func (v *Bool) requiresType(t SQLValueType, cols map[string]ColDescriptor, param return nil } +func (v *Bool) selectors() []Selector { + return nil +} + func (v *Bool) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2498,6 +2538,10 @@ func (v *Blob) requiresType(t SQLValueType, cols map[string]ColDescriptor, param return nil } +func (v *Blob) selectors() []Selector { + return nil +} + func (v *Blob) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2567,6 +2611,10 @@ func (v *Float64) requiresType(t SQLValueType, cols map[string]ColDescriptor, pa return nil } +func (v *Float64) selectors() []Selector { + return nil +} + func (v *Float64) substitute(params map[string]interface{}) (ValueExp, error) { return v, nil } @@ -2628,51 +2676,31 @@ type FnCall struct { } func (v *FnCall) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { - if strings.ToUpper(v.fn) == NowFnCall { - return TimestampType, nil - } - - if strings.ToUpper(v.fn) == UUIDFnCall { - return UUIDType, nil - } - - if strings.ToUpper(v.fn) == JSONTypeOfFnCall { - return VarcharType, nil + fn, err := v.resolveFunc() + if err != nil { + return AnyType, err } - - return AnyType, fmt.Errorf("%w: unknown function %s", ErrIllegalArguments, v.fn) + return fn.inferType(cols, params, implicitTable) } func (v *FnCall) requiresType(t SQLValueType, cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) error { - if strings.ToUpper(v.fn) == NowFnCall { - if t != TimestampType { - return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, TimestampType, t) - } - - return nil + fn, err := v.resolveFunc() + if err != nil { + return err } + return fn.requiresType(t, cols, params, implicitTable) +} - if strings.ToUpper(v.fn) == UUIDFnCall { - if t != UUIDType { - return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, UUIDType, t) - } - - return nil +func (v *FnCall) selectors() []Selector { + selectors := make([]Selector, 0) + for _, param := range v.params { + selectors = append(selectors, param.selectors()...) } - - if strings.ToUpper(v.fn) == JSONTypeOfFnCall { - if t != VarcharType { - return fmt.Errorf("%w: %v can not be interpreted as type %v", ErrInvalidTypes, VarcharType, t) - } - return nil - } - - return fmt.Errorf("%w: unkown function %s", ErrIllegalArguments, v.fn) + return selectors } func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err error) { ps := make([]ValueExp, len(v.params)) - for i, p := range v.params { ps[i], err = p.substitute(params) if err != nil { @@ -2687,41 +2715,39 @@ func (v *FnCall) substitute(params map[string]interface{}) (val ValueExp, err er } func (v *FnCall) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, error) { - if strings.ToUpper(v.fn) == NowFnCall { - if len(v.params) > 0 { - return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, NowFnCall, len(v.params)) - } - return &Timestamp{val: tx.Timestamp().Truncate(time.Microsecond).UTC()}, nil + fn, err := v.resolveFunc() + if err != nil { + return nil, err } - if strings.ToUpper(v.fn) == UUIDFnCall { - if len(v.params) > 0 { - return nil, fmt.Errorf("%w: '%s' function does not expect any argument but %d were provided", ErrIllegalArguments, UUIDFnCall, len(v.params)) - } - return &UUID{val: uuid.New()}, nil + fnInputs, err := v.reduceParams(tx, row, implicitTable) + if err != nil { + return nil, err } + return fn.Apply(tx, fnInputs) +} - if strings.ToUpper(v.fn) == JSONTypeOfFnCall { - if len(v.params) != 1 { - return nil, fmt.Errorf("%w: '%s' function expects %d arguments but %d were provided", ErrIllegalArguments, JSONTypeOfFnCall, 1, len(v.params)) - } - - v, err := v.params[0].reduce(tx, row, implicitTable) - if err != nil { - return nil, err - } - - if v.IsNull() { - return NewNull(AnyType), nil +func (v *FnCall) reduceParams(tx *SQLTx, row *Row, implicitTable string) ([]TypedValue, error) { + var values []TypedValue + if len(v.params) > 0 { + values = make([]TypedValue, len(v.params)) + for i, p := range v.params { + v, err := p.reduce(tx, row, implicitTable) + if err != nil { + return nil, err + } + values[i] = v } + } + return values, nil +} - jsonVal, ok := v.(*JSON) - if !ok { - return nil, fmt.Errorf("%w: '%s' function expects an argument of type JSON", ErrIllegalArguments, JSONTypeOfFnCall) - } - return NewVarchar(jsonVal.primitiveType()), nil +func (v *FnCall) resolveFunc() (Function, error) { + fn, exists := builtinFunctions[strings.ToUpper(v.fn)] + if !exists { + return nil, fmt.Errorf("%w: unkown function %s", ErrIllegalArguments, v.fn) } - return nil, fmt.Errorf("%w: unkown function %s", ErrIllegalArguments, v.fn) + return fn, nil } func (v *FnCall) reduceSelectors(row *Row, implicitTable string) ValueExp { @@ -2791,6 +2817,10 @@ func (c *Cast) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, er return conv(val) } +func (v *Cast) selectors() []Selector { + return v.val.selectors() +} + func (c *Cast) reduceSelectors(row *Row, implicitTable string) ValueExp { return &Cast{ val: c.val.reduceSelectors(row, implicitTable), @@ -2891,6 +2921,10 @@ func (p *Param) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedValue, e return nil, ErrUnexpected } +func (p *Param) selectors() []Selector { + return nil +} + func (p *Param) reduceSelectors(row *Row, implicitTable string) ValueExp { return p } @@ -2923,8 +2957,14 @@ type DataSource interface { Alias() string } +type TargetEntry struct { + Exp ValueExp + As string +} + type SelectStmt struct { distinct bool + targets []TargetEntry selectors []Selector ds DataSource indexOn []string @@ -2939,7 +2979,7 @@ type SelectStmt struct { } func NewSelectStmt( - selectors []Selector, + targets []TargetEntry, ds DataSource, where ValueExp, orderBy []*OrdCol, @@ -2947,12 +2987,12 @@ func NewSelectStmt( offset ValueExp, ) *SelectStmt { return &SelectStmt{ - selectors: selectors, - ds: ds, - where: where, - orderBy: orderBy, - limit: limit, - offset: offset, + targets: targets, + ds: ds, + where: where, + orderBy: orderBy, + limit: limit, + offset: offset, } } @@ -2970,7 +3010,7 @@ func (stmt *SelectStmt) inferParameters(ctx context.Context, tx *SQLTx, params m return err } - // TODO (jeroiraz) may be optimized so to resolve the query statement just once + // TODO: (jeroiraz) may be optimized so to resolve the query statement just once rowReader, err := stmt.Resolve(ctx, tx, nil, nil) if err != nil { return err @@ -2986,7 +3026,7 @@ func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string } if stmt.containsAggregations() || len(stmt.groupBy) > 0 { - for _, sel := range stmt.selectors { + for _, sel := range stmt.getSelectors() { _, isAgg := sel.(*AggColSelector) if !isAgg && !stmt.groupByContains(sel) { return nil, fmt.Errorf("%s: %w", EncodeSelector(sel.resolve(stmt.Alias())), ErrColumnMustAppearInGroupByOrAggregation) @@ -3006,10 +3046,17 @@ func (stmt *SelectStmt) execAt(ctx context.Context, tx *SQLTx, params map[string return tx, nil } +func (stmt *SelectStmt) getSelectors() []Selector { + if stmt.selectors == nil { + stmt.selectors = stmt.extractSelectors() + } + return stmt.selectors +} + func (stmt *SelectStmt) containsSelector(s Selector) bool { encSel := EncodeSelector(s.resolve(stmt.Alias())) - for _, sel := range stmt.selectors { + for _, sel := range stmt.getSelectors() { if EncodeSelector(sel.resolve(stmt.Alias())) == encSel { return true } @@ -3028,6 +3075,29 @@ func (stmt *SelectStmt) groupByContains(sel Selector) bool { return false } +func (stmt *SelectStmt) extractGroupByCols() []*AggColSelector { + cols := make([]*AggColSelector, 0, len(stmt.targets)) + + for _, t := range stmt.targets { + selectors := t.Exp.selectors() + for _, sel := range selectors { + aggSel, isAgg := sel.(*AggColSelector) + if isAgg { + cols = append(cols, aggSel) + } + } + } + return cols +} + +func (stmt *SelectStmt) extractSelectors() []Selector { + selectors := make([]Selector, 0, len(stmt.targets)) + for _, t := range stmt.targets { + selectors = append(selectors, t.Exp.selectors()...) + } + return selectors +} + func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (ret RowReader, err error) { scanSpecs, err := stmt.genScanSpecs(tx, params) if err != nil { @@ -3068,7 +3138,7 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin } var groupedRowReader *groupedRowReader - groupedRowReader, err = newGroupedRowReader(rowReader, stmt.selectors, stmt.groupBy) + groupedRowReader, err = newGroupedRowReader(rowReader, allAggregations(stmt.targets), stmt.extractGroupByCols(), stmt.groupBy) if err != nil { return nil, err } @@ -3088,7 +3158,7 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin rowReader = sortRowReader } - projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.selectors) + projectedRowReader, err := newProjectedRowReader(ctx, rowReader, stmt.as, stmt.targets) if err != nil { return nil, err } @@ -3181,7 +3251,7 @@ func ordColumnsHaveAggregations(cols []*OrdCol) bool { } func (stmt *SelectStmt) containsAggregations() bool { - for _, sel := range stmt.selectors { + for _, sel := range stmt.getSelectors() { _, isAgg := sel.(*AggColSelector) if isAgg { return true @@ -3227,7 +3297,7 @@ func (stmt *SelectStmt) Alias() string { } func (stmt *SelectStmt) hasTxMetadata() bool { - for _, sel := range stmt.selectors { + for _, sel := range stmt.getSelectors() { switch s := sel.(type) { case *ColSelector: if s.col == txMetadataCol { @@ -3594,6 +3664,56 @@ func (stmt *tableRef) Alias() string { return stmt.as } +type valuesDataSource struct { + rows []*RowSpec +} + +func NewValuesDataSource(rows []*RowSpec) *valuesDataSource { + return &valuesDataSource{ + rows: rows, + } +} + +func (ds *valuesDataSource) readOnly() bool { + return true +} + +func (ds *valuesDataSource) requiredPrivileges() []SQLPrivilege { + return nil +} + +func (ds *valuesDataSource) execAt(ctx context.Context, tx *SQLTx, params map[string]interface{}) (*SQLTx, error) { + return tx, nil +} + +func (ds *valuesDataSource) inferParameters(ctx context.Context, tx *SQLTx, params map[string]SQLValueType) error { + return nil +} + +func (ds *valuesDataSource) Alias() string { + return "" +} + +func (ds *valuesDataSource) Resolve(ctx context.Context, tx *SQLTx, params map[string]interface{}, scanSpecs *ScanSpecs) (RowReader, error) { + if tx == nil { + return nil, ErrIllegalArguments + } + + cols := make([]ColDescriptor, len(ds.rows[0].Values)) + for i := range cols { + cols[i] = ColDescriptor{ + Type: AnyType, + Column: fmt.Sprintf("col%d", i), + } + } + + values := make([][]ValueExp, len(ds.rows)) + for i, rowSpec := range ds.rows { + values[i] = rowSpec.Values + } + return newValuesRowReader(tx, params, cols, false, "values", values) +} + type JoinSpec struct { joinType JoinType ds DataSource @@ -3616,14 +3736,11 @@ func NewOrdCol(table string, col string, descOrder bool) *OrdCol { type Selector interface { ValueExp resolve(implicitTable string) (aggFn, table, col string) - alias() string - setAlias(alias string) } type ColSelector struct { table string col string - as string } func NewColSelector(table, col string) *ColSelector { @@ -3641,18 +3758,6 @@ func (sel *ColSelector) resolve(implicitTable string) (aggFn, table, col string) return "", table, sel.col } -func (sel *ColSelector) alias() string { - if sel.as == "" { - return sel.col - } - - return sel.as -} - -func (sel *ColSelector) setAlias(alias string) { - sel.as = alias -} - func (sel *ColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { _, table, col := sel.resolve(implicitTable) encSel := EncodeSelector("", table, col) @@ -3700,6 +3805,10 @@ func (sel *ColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (Typed return v, nil } +func (sel *ColSelector) selectors() []Selector { + return []Selector{sel} +} + func (sel *ColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp { aggFn, table, col := sel.resolve(implicitTable) @@ -3727,7 +3836,6 @@ type AggColSelector struct { aggFn AggregateFn table string col string - as string } func NewAggColSelector(aggFn AggregateFn, table, col string) *AggColSelector { @@ -3747,18 +3855,9 @@ func (sel *AggColSelector) resolve(implicitTable string) (aggFn, table, col stri if sel.table != "" { table = sel.table } - return sel.aggFn, table, sel.col } -func (sel *AggColSelector) alias() string { - return sel.as -} - -func (sel *AggColSelector) setAlias(alias string) { - sel.as = alias -} - func (sel *AggColSelector) inferType(cols map[string]ColDescriptor, params map[string]SQLValueType, implicitTable string) (SQLValueType, error) { if sel.aggFn == COUNT { return IntegerType, nil @@ -3818,6 +3917,10 @@ func (sel *AggColSelector) reduce(tx *SQLTx, row *Row, implicitTable string) (Ty return v, nil } +func (sel *AggColSelector) selectors() []Selector { + return []Selector{sel} +} + func (sel *AggColSelector) reduceSelectors(row *Row, implicitTable string) ValueExp { return sel } @@ -3845,7 +3948,7 @@ func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]S if err != nil { return AnyType, err } - if tleft != AnyType && tleft != IntegerType && tleft != Float64Type { + if tleft != AnyType && tleft != IntegerType && tleft != Float64Type && tleft != JSONType { return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tleft) } @@ -3853,7 +3956,7 @@ func (bexp *NumExp) inferType(cols map[string]ColDescriptor, params map[string]S if err != nil { return AnyType, err } - if tright != AnyType && tright != IntegerType && tright != Float64Type { + if tright != AnyType && tright != IntegerType && tright != Float64Type && tright != JSONType { return AnyType, fmt.Errorf("%w: %v or %v can not be interpreted as type %v", ErrInvalidTypes, IntegerType, Float64Type, tright) } @@ -3967,6 +4070,10 @@ func unwrapJSON(v TypedValue) TypedValue { return v } +func (bexp *NumExp) selectors() []Selector { + return append(bexp.left.selectors(), bexp.right.selectors()...) +} + func (bexp *NumExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return &NumExp{ op: bexp.op, @@ -4033,6 +4140,10 @@ func (bexp *NotBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Typed return &Bool{val: !r}, nil } +func (bexp *NotBoolExp) selectors() []Selector { + return bexp.exp.selectors() +} + func (bexp *NotBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return &NotBoolExp{ exp: bexp.exp.reduceSelectors(row, implicitTable), @@ -4153,6 +4264,10 @@ func (bexp *LikeBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Type return &Bool{val: matched != bexp.notLike}, nil } +func (bexp *LikeBoolExp) selectors() []Selector { + return bexp.val.selectors() +} + func (bexp *LikeBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return bexp } @@ -4278,6 +4393,10 @@ func (bexp *CmpBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Typed return &Bool{val: cmpSatisfiesOp(r, bexp.op)}, nil } +func (bexp *CmpBoolExp) selectors() []Selector { + return append(bexp.left.selectors(), bexp.right.selectors()...) +} + func (bexp *CmpBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return &CmpBoolExp{ op: bexp.op, @@ -4529,6 +4648,10 @@ func (bexp *BinBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Typed return nil, ErrUnexpected } +func (bexp *BinBoolExp) selectors() []Selector { + return append(bexp.left.selectors(), bexp.right.selectors()...) +} + func (bexp *BinBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return &BinBoolExp{ op: bexp.op, @@ -4605,6 +4728,10 @@ func (bexp *ExistsBoolExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Ty return nil, fmt.Errorf("'EXISTS' clause: %w", ErrNoSupported) } +func (bexp *ExistsBoolExp) selectors() []Selector { + return nil +} + func (bexp *ExistsBoolExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return bexp } @@ -4643,6 +4770,10 @@ func (bexp *InSubQueryExp) reduce(tx *SQLTx, row *Row, implicitTable string) (Ty return nil, fmt.Errorf("error inferring type in 'IN' clause: %w", ErrNoSupported) } +func (bexp *InSubQueryExp) selectors() []Selector { + return bexp.val.selectors() +} + func (bexp *InSubQueryExp) reduceSelectors(row *Row, implicitTable string) ValueExp { return bexp } @@ -4745,6 +4876,14 @@ func (bexp *InListExp) reduce(tx *SQLTx, row *Row, implicitTable string) (TypedV return &Bool{val: found != bexp.notIn}, nil } +func (bexp *InListExp) selectors() []Selector { + selectors := make([]Selector, 0, len(bexp.values)) + for _, v := range bexp.values { + selectors = append(selectors, v.selectors()...) + } + return append(bexp.val.selectors(), selectors...) +} + func (bexp *InListExp) reduceSelectors(row *Row, implicitTable string) ValueExp { values := make([]ValueExp, len(bexp.values)) @@ -4901,7 +5040,7 @@ func (stmt *FnDataSourceStmt) resolveListDatabases(ctx context.Context, tx *SQLT values[i] = []ValueExp{&Varchar{val: db}} } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) { @@ -4923,7 +5062,7 @@ func (stmt *FnDataSourceStmt) resolveListTables(ctx context.Context, tx *SQLTx, values[i] = []ValueExp{&Varchar{val: t.name}} } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) { @@ -5001,7 +5140,7 @@ func (stmt *FnDataSourceStmt) resolveShowTable(ctx context.Context, tx *SQLTx, p } } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (rowReader RowReader, err error) { @@ -5042,7 +5181,7 @@ func (stmt *FnDataSourceStmt) resolveListUsers(ctx context.Context, tx *SQLTx, p } } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) { @@ -5137,7 +5276,7 @@ func (stmt *FnDataSourceStmt) resolveListColumns(ctx context.Context, tx *SQLTx, } } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) { @@ -5194,7 +5333,7 @@ func (stmt *FnDataSourceStmt) resolveListIndexes(ctx context.Context, tx *SQLTx, } } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, params map[string]interface{}, _ *ScanSpecs) (RowReader, error) { @@ -5256,7 +5395,7 @@ func (stmt *FnDataSourceStmt) resolveListGrants(ctx context.Context, tx *SQLTx, } } - return newValuesRowReader(tx, params, cols, stmt.Alias(), values) + return newValuesRowReader(tx, params, cols, true, stmt.Alias(), values) } // DropTableStmt represents a statement to delete a table. diff --git a/embedded/sql/stmt_test.go b/embedded/sql/stmt_test.go index bcd81a9d20..4d4e1f791b 100644 --- a/embedded/sql/stmt_test.go +++ b/embedded/sql/stmt_test.go @@ -548,7 +548,7 @@ func TestRequiresTypeSysFnValueExp(t *testing.T) { expectedError error }{ { - exp: &FnCall{fn: "NOW"}, + exp: &FnCall{fn: NowFnCall}, cols: cols, params: params, implicitTable: "mytable", @@ -556,7 +556,7 @@ func TestRequiresTypeSysFnValueExp(t *testing.T) { expectedError: nil, }, { - exp: &FnCall{fn: "NOW"}, + exp: &FnCall{fn: NowFnCall}, cols: cols, params: params, implicitTable: "mytable", @@ -564,20 +564,124 @@ func TestRequiresTypeSysFnValueExp(t *testing.T) { expectedError: ErrInvalidTypes, }, { - exp: &FnCall{fn: "LOWER"}, + exp: &FnCall{fn: LengthFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: IntegerType, + expectedError: nil, + }, + { + exp: &FnCall{fn: LengthFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: SubstringFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: nil, + }, + { + exp: &FnCall{fn: SubstringFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: IntegerType, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: ConcatFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: nil, + }, + { + exp: &FnCall{fn: ConcatFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: IntegerType, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: TrimFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: nil, + }, + { + exp: &FnCall{fn: TrimFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: IntegerType, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: UpperFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: nil, + }, + { + exp: &FnCall{fn: LowerFnCall}, cols: cols, params: params, implicitTable: "mytable", requiredType: VarcharType, - expectedError: ErrIllegalArguments, + expectedError: nil, }, { - exp: &FnCall{fn: "LOWER"}, + exp: &FnCall{fn: LowerFnCall}, cols: cols, params: params, implicitTable: "mytable", requiredType: Float64Type, - expectedError: ErrIllegalArguments, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: JSONTypeOfFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: nil, + }, + { + exp: &FnCall{fn: JSONTypeOfFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: IntegerType, + expectedError: ErrInvalidTypes, + }, + { + exp: &FnCall{fn: UUIDFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: UUIDType, + expectedError: nil, + }, + { + exp: &FnCall{fn: UUIDFnCall}, + cols: cols, + params: params, + implicitTable: "mytable", + requiredType: VarcharType, + expectedError: ErrInvalidTypes, }, } @@ -1394,6 +1498,16 @@ func TestRequiredPrivileges(t *testing.T) { readOnly: false, privileges: []SQLPrivilege{SQLPrivilegeInsert}, }, + { + stmt: &UpsertIntoStmt{ds: &SelectStmt{}}, + readOnly: false, + privileges: []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeUpdate, SQLPrivilegeSelect}, + }, + { + stmt: &UpsertIntoStmt{ds: &SelectStmt{}, isInsert: true}, + readOnly: false, + privileges: []SQLPrivilege{SQLPrivilegeInsert, SQLPrivilegeSelect}, + }, { stmt: &DeleteFromStmt{}, readOnly: false, @@ -1474,6 +1588,11 @@ func TestRequiredPrivileges(t *testing.T) { readOnly: false, privileges: []SQLPrivilege{SQLPrivilegeAlter}, }, + { + stmt: &RenameTableStmt{}, + readOnly: false, + privileges: []SQLPrivilege{SQLPrivilegeAlter}, + }, { stmt: &RenameColumnStmt{}, readOnly: false, @@ -1493,3 +1612,114 @@ func TestRequiredPrivileges(t *testing.T) { }) } } + +func TestExprSelectors(t *testing.T) { + type testCase struct { + Expr ValueExp + selectors []Selector + } + + tests := []testCase{ + { + Expr: &Integer{}, + }, + { + Expr: &Bool{}, + }, + { + Expr: &Float64{}, + }, + { + Expr: &NullValue{}, + }, + { + Expr: &Blob{}, + }, + { + Expr: &UUID{}, + }, + { + Expr: &JSON{}, + }, + { + Expr: &Timestamp{}, + }, + { + Expr: &Varchar{}, + }, + { + Expr: &Param{}, + }, + { + Expr: &ColSelector{col: "col"}, + selectors: []Selector{ + &ColSelector{col: "col"}, + }, + }, + { + Expr: &JSONSelector{ColSelector: &ColSelector{col: "col"}}, + selectors: []Selector{ + &JSONSelector{ColSelector: &ColSelector{col: "col"}}, + }, + }, + { + Expr: &BinBoolExp{ + left: &ColSelector{col: "col"}, + right: &ColSelector{col: "col1"}, + }, + selectors: []Selector{ + &ColSelector{col: "col"}, + &ColSelector{col: "col1"}, + }, + }, + { + Expr: &NumExp{ + left: &ColSelector{col: "col"}, + right: &ColSelector{col: "col1"}, + }, + selectors: []Selector{ + &ColSelector{col: "col"}, + &ColSelector{col: "col1"}, + }, + }, + { + Expr: &LikeBoolExp{ + val: &ColSelector{col: "col"}, + }, + selectors: []Selector{ + &ColSelector{col: "col"}, + }, + }, + { + Expr: &ExistsBoolExp{}, + }, + { + Expr: &InSubQueryExp{val: &ColSelector{col: "col"}}, + selectors: []Selector{ + &ColSelector{col: "col"}, + }, + }, + { + Expr: &InListExp{ + val: &ColSelector{col: "col"}, + values: []ValueExp{ + &ColSelector{col: "col1"}, + &ColSelector{col: "col2"}, + &ColSelector{col: "col3"}, + }, + }, + selectors: []Selector{ + &ColSelector{col: "col"}, + &ColSelector{col: "col1"}, + &ColSelector{col: "col2"}, + &ColSelector{col: "col3"}, + }, + }, + } + + for _, tc := range tests { + t.Run(reflect.TypeOf(tc.Expr).Elem().Name(), func(t *testing.T) { + require.Equal(t, tc.selectors, tc.Expr.selectors()) + }) + } +} diff --git a/embedded/sql/type_conversion.go b/embedded/sql/type_conversion.go index cdf4669d94..930ec4ec82 100644 --- a/embedded/sql/type_conversion.go +++ b/embedded/sql/type_conversion.go @@ -98,6 +98,27 @@ func getConverter(src, dst SQLValueType) (converterFunc, error) { }, nil } + if src == JSONType { + jsonToStr, err := getConverter(src, VarcharType) + if err != nil { + return nil, err + } + + strToTimestamp, err := getConverter(VarcharType, TimestampType) + if err != nil { + return nil, err + } + + return func(tv TypedValue) (TypedValue, error) { + v, err := jsonToStr(tv) + if err != nil { + return nil, err + } + s, _ := v.RawValue().(string) + return strToTimestamp(NewVarchar(strings.Trim(s, `"`))) + }, nil + } + return nil, fmt.Errorf( "%w: only INTEGER and VARCHAR types can be cast as TIMESTAMP", ErrUnsupportedCast, @@ -267,11 +288,18 @@ func getConverter(src, dst SQLValueType) (converterFunc, error) { }, nil } + if src == JSONType { + return func(val TypedValue) (TypedValue, error) { + jsonStr := val.String() + return &Blob{val: []byte(jsonStr)}, nil + }, nil + } + return nil, fmt.Errorf( - "%w: only UUID and VARCHAR types can be cast as BLOB", + "%w: cannot cast type %s to BLOB", ErrUnsupportedCast, + src, ) - } if dst == VarcharType { @@ -312,6 +340,12 @@ func getConverter(src, dst SQLValueType) (converterFunc, error) { err := json.Unmarshal([]byte(s), &x) return &JSON{val: x}, err + case BLOBType: + rawJson, ok := tv.RawValue().([]byte) + if !ok { + return nil, fmt.Errorf("invalid %s value", JSONType) + } + return NewJsonFromString(string(rawJson)) } return nil, fmt.Errorf( diff --git a/embedded/sql/values_row_reader.go b/embedded/sql/values_row_reader.go index 3e23df8ccd..398e56843d 100644 --- a/embedded/sql/values_row_reader.go +++ b/embedded/sql/values_row_reader.go @@ -31,17 +31,13 @@ type valuesRowReader struct { values [][]ValueExp read int - params map[string]interface{} - + params map[string]interface{} + checkTypes bool onCloseCallback func() closed bool } -func newValuesRowReader(tx *SQLTx, params map[string]interface{}, cols []ColDescriptor, tableAlias string, values [][]ValueExp) (*valuesRowReader, error) { - if len(cols) == 0 { - return nil, fmt.Errorf("%w: empty column list", ErrIllegalArguments) - } - +func newValuesRowReader(tx *SQLTx, params map[string]interface{}, cols []ColDescriptor, checkTypes bool, tableAlias string, values [][]ValueExp) (*valuesRowReader, error) { if tableAlias == "" { return nil, fmt.Errorf("%w: table alias is mandatory", ErrIllegalArguments) } @@ -77,6 +73,7 @@ func newValuesRowReader(tx *SQLTx, params map[string]interface{}, cols []ColDesc colsBySel: colsBySel, tableAlias: tableAlias, values: values, + checkTypes: checkTypes, }, nil } @@ -115,7 +112,10 @@ func (vr *valuesRowReader) colsBySelector(ctx context.Context) (map[string]ColDe func (vr *valuesRowReader) InferParameters(ctx context.Context, params map[string]SQLValueType) error { for _, vs := range vr.values { for _, v := range vs { - v.inferType(vr.colsBySel, params, vr.tableAlias) + _, err := v.inferType(vr.colsBySel, params, vr.tableAlias) + if err != nil { + return err + } } } return nil @@ -142,9 +142,11 @@ func (vr *valuesRowReader) Read(ctx context.Context) (*Row, error) { return nil, err } - err = rv.requiresType(vr.colsByPos[i].Type, vr.colsBySel, nil, vr.tableAlias) - if err != nil { - return nil, err + if vr.checkTypes { + err = rv.requiresType(vr.colsByPos[i].Type, vr.colsBySel, nil, vr.tableAlias) + if err != nil { + return nil, err + } } valuesByPosition[i] = rv diff --git a/embedded/sql/values_row_reader_test.go b/embedded/sql/values_row_reader_test.go index 176b8cfb01..00b9d94f8d 100644 --- a/embedded/sql/values_row_reader_test.go +++ b/embedded/sql/values_row_reader_test.go @@ -24,23 +24,23 @@ import ( ) func TestValuesRowReader(t *testing.T) { - _, err := newValuesRowReader(nil, nil, nil, "", nil) + _, err := newValuesRowReader(nil, nil, nil, true, "", nil) require.ErrorIs(t, err, ErrIllegalArguments) cols := []ColDescriptor{ {Column: "col1"}, } - _, err = newValuesRowReader(nil, nil, cols, "", nil) + _, err = newValuesRowReader(nil, nil, cols, true, "", nil) require.ErrorIs(t, err, ErrIllegalArguments) - _, err = newValuesRowReader(nil, nil, cols, "", nil) + _, err = newValuesRowReader(nil, nil, cols, true, "", nil) require.ErrorIs(t, err, ErrIllegalArguments) - _, err = newValuesRowReader(nil, nil, cols, "table1", nil) + _, err = newValuesRowReader(nil, nil, cols, true, "table1", nil) require.NoError(t, err) - _, err = newValuesRowReader(nil, nil, cols, "table1", + _, err = newValuesRowReader(nil, nil, cols, true, "table1", [][]ValueExp{ { &Bool{val: true}, @@ -52,7 +52,7 @@ func TestValuesRowReader(t *testing.T) { _, err = newValuesRowReader(nil, nil, []ColDescriptor{ {Table: "table1", Column: "col1"}, - }, "", nil) + }, true, "", nil) require.ErrorIs(t, err, ErrIllegalArguments) values := [][]ValueExp{ @@ -65,7 +65,7 @@ func TestValuesRowReader(t *testing.T) { "param1": 1, } - rowReader, err := newValuesRowReader(nil, params, cols, "table1", values) + rowReader, err := newValuesRowReader(nil, params, cols, true, "table1", values) require.NoError(t, err) require.Nil(t, rowReader.OrderBy()) require.Nil(t, rowReader.ScanSpecs()) diff --git a/embedded/store/options.go b/embedded/store/options.go index 72c8eff2e5..4e9e7a8107 100644 --- a/embedded/store/options.go +++ b/embedded/store/options.go @@ -240,10 +240,8 @@ func DefaultOptions() *Options { CompressionLevel: DefaultCompressionLevel, EmbeddedValues: DefaultEmbeddedValues, PreallocFiles: DefaultPreallocFiles, - - IndexOpts: DefaultIndexOptions(), - - AHTOpts: DefaultAHTOptions(), + IndexOpts: DefaultIndexOptions(), + AHTOpts: DefaultAHTOptions(), } }