From 91c6f0fcaadfb7bd983e070e6ceffc8aeba7d5a2 Mon Sep 17 00:00:00 2001 From: dfinkel Date: Wed, 6 Nov 2024 00:14:19 -0500 Subject: [PATCH 1/2] feat(spanner/spansql): add support for protobuf column types & Proto bundles (#10945) ### feat(spansql): CREATE/ALTER/DROP PROTO BUNDLE Add support for parsing and serializing CREATE, ALTER and DROP PROTO BUNDLE DDL statements. ### feat(spanner/spansql): support for protobuf types Now that Spanner supports protobuf message and enum-typed columns and casts, add support for parsing those types. Since protobuf columns aren't distinguished by a keyword, adjust the parser to see any unquoted identifier that's not a known type as a possible protobuf type and loop, consuming `.`s and identifiers until it hits a non-ident/`.` token (to match the proto namespace components up through the message or enum names). To track the fully-qualified message/enum type-name add an additional field to the `Type` struct (tentatively) called `ProtoRef` so we can recover the message/enum name if canonicalizing everything. 
closes: #10944 --- spanner/spansql/parser.go | 192 ++++++++++++++++++++++++++++++++- spanner/spansql/parser_test.go | 143 ++++++++++++++++++++++++ spanner/spansql/sql.go | 34 ++++++ spanner/spansql/sql_test.go | 85 +++++++++++++++ spanner/spansql/types.go | 41 +++++++ 5 files changed, 490 insertions(+), 5 deletions(-) diff --git a/spanner/spansql/parser.go b/spanner/spansql/parser.go index 8e66c5d54926..5c1d0bb03915 100644 --- a/spanner/spansql/parser.go +++ b/spanner/spansql/parser.go @@ -46,6 +46,7 @@ This file is structured as follows: import ( "fmt" "os" + "regexp" "strconv" "strings" "time" @@ -1007,6 +1008,7 @@ func (p *parser) parseDDLStmt() (DDLStmt, *parseError) { // DROP VIEW view_name // DROP ROLE role_name // DROP CHANGE STREAM change_stream_name + // DROP PROTO BUNDLE tok := p.next() if tok.err != nil { return nil, tok.err @@ -1065,6 +1067,12 @@ func (p *parser) parseDDLStmt() (DDLStmt, *parseError) { return nil, err } return &DropSequence{Name: name, IfExists: ifExists, Position: pos}, nil + case tok.caseEqual("PROTO"): + // the syntax for this is dead simple: DROP PROTO BUNDLE + if bundleErr := p.expect("BUNDLE"); bundleErr != nil { + return nil, bundleErr + } + return &DropProtoBundle{Position: pos}, nil } } else if p.sniff("RENAME", "TABLE") { a, err := p.parseRenameTable() @@ -1096,6 +1104,12 @@ func (p *parser) parseDDLStmt() (DDLStmt, *parseError) { } else if p.sniff("ALTER", "SEQUENCE") { as, err := p.parseAlterSequence() return as, err + } else if p.sniff("CREATE", "PROTO", "BUNDLE") { + cp, err := p.parseCreateProtoBundle() + return cp, err + } else if p.sniff("ALTER", "PROTO", "BUNDLE") { + ap, err := p.parseAlterProtoBundle() + return ap, err } return nil, p.errorf("unknown DDL statement") @@ -2877,6 +2891,93 @@ func (p *parser) parseCreateSequence() (*CreateSequence, *parseError) { return cs, nil } +func (p *parser) parseCreateProtoBundle() (*CreateProtoBundle, *parseError) { + debugf("parseCreateProtoBundle: %v", p) + + /* + CREATE PROTO 
BUNDLE ( + [proto_type_name"," ...] + ) + */ + + if err := p.expect("CREATE"); err != nil { + return nil, err + } + pos := p.Pos() + if err := p.expect("PROTO", "BUNDLE"); err != nil { + return nil, err + } + + typeNames, listErr := p.parseProtobufTypeNameList() + if listErr != nil { + return nil, listErr + } + + return &CreateProtoBundle{Types: typeNames, Position: pos}, nil +} + +func (p *parser) parseAlterProtoBundle() (*AlterProtoBundle, *parseError) { + debugf("parseAlterProtoBundle: %v", p) + + /* + ALTER PROTO BUNDLE + INSERT ( + [proto_type_name"," ...] + ) + UPDATE ( + [proto_type_name"," ...] + ) + DELETE ( + [proto_type_name"," ...] + ) + */ + + if err := p.expect("ALTER"); err != nil { + return nil, err + } + pos := p.Pos() + if err := p.expect("PROTO", "BUNDLE"); err != nil { + return nil, err + } + alter := AlterProtoBundle{Position: pos} + for { + var typeSlice *[]string + switch { + case p.eat("INSERT"): + if alter.AddTypes != nil { + return nil, p.errorf("multiple INSERTs in ALTER PROTO BUNDLE") + } + typeSlice = &alter.AddTypes + case p.eat("UPDATE"): + if alter.UpdateTypes != nil { + return nil, p.errorf("multiple UPDATEs in ALTER PROTO BUNDLE") + } + typeSlice = &alter.UpdateTypes + case p.eat("DELETE"): + if alter.DeleteTypes != nil { + return nil, p.errorf("multiple DELETEs in ALTER PROTO BUNDLE") + } + typeSlice = &alter.DeleteTypes + default: + tok := p.next() + if tok.err == eof { + return &alter, nil + } else if tok.err != nil { + return nil, tok.err + } else if tok.typ == unknownToken && tok.value == ";" { + p.back() + return &alter, nil + } + return nil, p.errorf("invalid clause in ALTER PROTO BUNDLE %q", tok.value) + } + + typeNames, listErr := p.parseProtobufTypeNameList() + if listErr != nil { + return nil, listErr + } + *typeSlice = typeNames + } +} func (p *parser) parseAlterSequence() (*AlterSequence, *parseError) { debugf("parseAlterSequence: %v", p) @@ -3004,6 +3105,8 @@ var baseTypes = map[string]TypeBase{ "DATE": Date, 
"TIMESTAMP": Timestamp, "JSON": JSON, + "PROTO": Proto, // for use in CAST + "ENUM": Enum, // for use in CAST } func (p *parser) parseBaseType() (Type, *parseError) { @@ -3035,6 +3138,59 @@ func (p *parser) parseExtractType() (Type, string, *parseError) { return t, strings.ToUpper(tok.value), nil } +// protobuf identifiers allow one letter-class character followed by any number of alphanumeric characters or an +// underscore (which matches the [:word:] class in RE2/go regexp). +// Fully qualified protobuf type names (enums or messages) include a dot-separated namespace, where each component is +// also an identifier. +// https://github.com/protocolbuffers/protobuf/blob/eeb7dc88f286df558d933214fff829205ffa5506/src/google/protobuf/io/tokenizer.cc#L653-L655 +// https://github.com/protocolbuffers/protobuf/blob/eeb7dc88f286df558d933214fff829205ffa5506/src/google/protobuf/io/tokenizer.cc#L115-L120 +var fqProtoMsgName = regexp.MustCompile(`^(?:[[:alpha:]][[:word:]]*)(?:\.(?:[[:alpha:]][[:word:]]*))*$`) + +func (p *parser) parseProtobufTypeName(consumed string) (string, *parseError) { + // Whether it's quoted or not, we might have multiple namespace components (with `.` separation) + possibleProtoTypeName := strings.Builder{} + possibleProtoTypeName.WriteString(consumed) + ntok := p.next() +PROTO_TOK_CONSUME: + for ; ntok.err == nil; ntok = p.next() { + appendVal := ntok.value + switch ntok.typ { + case unquotedID: + case quotedID: + if !fqProtoMsgName.MatchString(ntok.string) { + return "", p.errorf("got %q, want fully qualified protobuf type", ntok.string) + } + appendVal = ntok.string + case unknownToken: + if ntok.value != "." 
{ + p.back() + break PROTO_TOK_CONSUME + } + default: + p.back() + break PROTO_TOK_CONSUME + } + possibleProtoTypeName.WriteString(appendVal) + } + if ntok.err != nil { + return "", ntok.err + } + return possibleProtoTypeName.String(), nil +} + +func (p *parser) parseProtobufTypeNameList() ([]string, *parseError) { + var list []string + err := p.parseCommaList("(", ")", func(p *parser) *parseError { + tn, err := p.parseProtobufTypeName("") + if err != nil { + return err + } + list = append(list, tn) + return nil + }) + return list, err +} + func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) { debugf("parseBaseOrParameterizedType: %v", p) @@ -3064,12 +3220,38 @@ func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError return Type{}, tok.err } } - base, ok := baseTypes[strings.ToUpper(tok.value)] // baseTypes is keyed by upper case strings. - if !ok { - return Type{}, p.errorf("got %q, want scalar type", tok.value) - } - t.Base = base + switch tok.typ { + case unquotedID: + base, ok := baseTypes[strings.ToUpper(tok.value)] // baseTypes is keyed by upper case strings. + if ok { + t.Base = base + break + } + // Likely a protobuf type; make sure its value matches the regexp + // protobuf types can be either quoted or unquoted, so we need to handle both. + // Fortunately, the identifier tokenization rules match between protobuf and SQL, so we only need to + // verify quoted identifiers. 
+ pbTypeName, pbParseErr := p.parseProtobufTypeName(tok.value) + if pbParseErr != nil { + return Type{}, pbParseErr + } + t.ProtoRef = pbTypeName + t.Base = Proto + return t, nil + case quotedID: + if !fqProtoMsgName.MatchString(tok.string) { + return Type{}, p.errorf("got %q, want fully qualified protobuf type", tok.value) + } + pbTypeName, pbParseErr := p.parseProtobufTypeName(tok.string) + if pbParseErr != nil { + return Type{}, pbParseErr + } + t.ProtoRef = pbTypeName + t.Base = Proto + return t, nil + default: + } if withParam && (t.Base == String || t.Base == Bytes) { if err := p.expect("("); err != nil { return Type{}, err diff --git a/spanner/spansql/parser_test.go b/spanner/spansql/parser_test.go index 811cfc8ce580..88066d559d5d 100644 --- a/spanner/spansql/parser_test.go +++ b/spanner/spansql/parser_test.go @@ -26,6 +26,42 @@ import ( "cloud.google.com/go/civil" ) +func TestFQProtoMsgName(t *testing.T) { + for _, tbl := range []struct { + in string + expMatch bool + }{ + { + in: "fizzle", + expMatch: true, + }, + { + in: "fizzle.bit", + expMatch: true, + }, + { + in: "fizzle.boo1.boop333", + expMatch: true, + }, + { + in: "fizz9le.boo1.boop333", + expMatch: true, + }, + { + in: "9fizz9le", + expMatch: false, + }, + { + in: "99.999", + expMatch: false, + }, + } { + if matches := fqProtoMsgName.MatchString(tbl.in); matches != tbl.expMatch { + t.Errorf("expected %q to match %t; got %t", tbl.in, tbl.expMatch, matches) + } + } +} + func TestParseQuery(t *testing.T) { tests := []struct { in string @@ -408,6 +444,8 @@ func TestParseExpr(t *testing.T) { // Functions {`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}}, {`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}}, + {`CAST(Bar AS ENUM)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Enum}}}}}, + {`CAST(Bar AS PROTO)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: 
ID("Bar"), Type: Type{Base: Proto}}}}}, {`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}}, {`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}}, {`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), Type: Type{Base: Int64}}}}}, @@ -1864,6 +1902,46 @@ func TestParseDDL(t *testing.T) { }, }, }, + { + `CREATE TABLE IF NOT EXISTS tname (id INT64, name foo.bar.baz.ProtoName) PRIMARY KEY (id)`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &CreateTable{ + Name: "tname", + IfNotExists: true, + Columns: []ColumnDef{ + {Name: "id", Type: Type{Base: Int64}, Position: line(1)}, + {Name: "name", Type: Type{Base: Proto, ProtoRef: "foo.bar.baz.ProtoName"}, Position: line(1)}, + }, + PrimaryKey: []KeyPart{ + {Column: "id"}, + }, + Position: line(1), + }, + }, + }, + }, + { + "CREATE TABLE IF NOT EXISTS tname (id INT64, name `foo.bar.baz.ProtoName`) PRIMARY KEY (id)", + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &CreateTable{ + Name: "tname", + IfNotExists: true, + Columns: []ColumnDef{ + {Name: "id", Type: Type{Base: Int64}, Position: line(1)}, + {Name: "name", Type: Type{Base: Proto, ProtoRef: "foo.bar.baz.ProtoName"}, Position: line(1)}, + }, + PrimaryKey: []KeyPart{ + {Column: "id"}, + }, + Position: line(1), + }, + }, + }, + }, { `CREATE INDEX IF NOT EXISTS iname ON tname (cname)`, &DDL{ @@ -2023,6 +2101,71 @@ func TestParseDDL(t *testing.T) { }, }, }, + { + `CREATE PROTO BUNDLE (foo.bar.baz.Fiddle, ` + "`foo.bar.baz.Foozle`" + `); + ALTER PROTO BUNDLE INSERT (a.b.c, b.d.e, k) UPDATE (foo.bar.baz.Fiddle) DELETE (foo.bar.baz.Foozle); + DROP PROTO BUNDLE;`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &CreateProtoBundle{ + Types: 
[]string{"foo.bar.baz.Fiddle", "foo.bar.baz.Foozle"}, + Position: line(1), + }, + &AlterProtoBundle{ + AddTypes: []string{"a.b.c", "b.d.e", "k"}, + UpdateTypes: []string{"foo.bar.baz.Fiddle"}, + DeleteTypes: []string{"foo.bar.baz.Foozle"}, + Position: line(2), + }, + &DropProtoBundle{ + Position: line(3), + }, + }, + }, + }, + { + `ALTER PROTO BUNDLE UPDATE (foo.bar.baz.Fiddle) INSERT (a.b.c, b.d.e, k) DELETE (foo.bar.baz.Foozle);`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &AlterProtoBundle{ + AddTypes: []string{"a.b.c", "b.d.e", "k"}, + UpdateTypes: []string{"foo.bar.baz.Fiddle"}, + DeleteTypes: []string{"foo.bar.baz.Foozle"}, + Position: line(1), + }, + }, + }, + }, + { + `ALTER PROTO BUNDLE DELETE (foo.bar.baz.Foozle) UPDATE (foo.bar.baz.Fiddle) INSERT (a.b.c, b.d.e, k)`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &AlterProtoBundle{ + AddTypes: []string{"a.b.c", "b.d.e", "k"}, + UpdateTypes: []string{"foo.bar.baz.Fiddle"}, + DeleteTypes: []string{"foo.bar.baz.Foozle"}, + Position: line(1), + }, + }, + }, + }, + { + `ALTER PROTO BUNDLE INSERT (a.b.c, b.d.e, k) DELETE (foo.bar.baz.Foozle);`, + &DDL{ + Filename: "filename", + List: []DDLStmt{ + &AlterProtoBundle{ + AddTypes: []string{"a.b.c", "b.d.e", "k"}, + UpdateTypes: nil, + DeleteTypes: []string{"foo.bar.baz.Foozle"}, + Position: line(1), + }, + }, + }, + }, } for _, test := range tests { got, err := ParseDDL("filename", test.in) diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go index 577a45e2ef7f..f044e312cfd0 100644 --- a/spanner/spansql/sql.go +++ b/spanner/spansql/sql.go @@ -98,6 +98,11 @@ func (ci CreateIndex) SQL() string { return str } +func (cp CreateProtoBundle) SQL() string { + // Backtick-quote all the types so we don't need to check for SQL keywords + return "CREATE PROTO BUNDLE (`" + strings.Join(cp.Types, "`, `") + "`)" +} + func (cv CreateView) SQL() string { str := "CREATE" if cv.OrReplace { @@ -558,6 +563,23 @@ func (d *Delete) SQL() string { return "DELETE 
FROM " + d.Table.SQL() + " WHERE " + d.Where.SQL() } +func (do DropProtoBundle) SQL() string { + return "DROP PROTO BUNDLE" +} +func (ap AlterProtoBundle) SQL() string { + str := "ALTER PROTO BUNDLE" + if len(ap.AddTypes) > 0 { + str += " INSERT (`" + strings.Join(ap.AddTypes, "`, `") + "`)" + } + if len(ap.UpdateTypes) > 0 { + str += " UPDATE (`" + strings.Join(ap.UpdateTypes, "`, `") + "`)" + } + if len(ap.DeleteTypes) > 0 { + str += " DELETE (`" + strings.Join(ap.DeleteTypes, "`, `") + "`)" + } + return str +} + func (u *Update) SQL() string { str := "UPDATE " + u.Table.SQL() + " SET " for i, item := range u.Items { @@ -651,6 +673,14 @@ func (c Check) SQL() string { func (t Type) SQL() string { str := t.Base.SQL() + + // If ProtoRef is empty, and Base is an Enum or Proto, we're probably + // in an expression where PROTO or ENUM are valid type names, in which + // case we can just fall through. + if t.Base == Proto || t.Base == Enum && t.ProtoRef != "" { + // If ProtoRef is non-empty, backtick-quote that and declare victory. 
+ return "`" + t.ProtoRef + "`" + } if t.Len > 0 && (t.Base == String || t.Base == Bytes) { str += "(" if t.Len == MaxLen { @@ -686,6 +716,10 @@ func (tb TypeBase) SQL() string { return "TIMESTAMP" case JSON: return "JSON" + case Proto: + return "PROTO" + case Enum: + return "ENUM" } panic("unknown TypeBase") } diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go index 805f28b8352c..20561f723257 100644 --- a/spanner/spansql/sql_test.go +++ b/spanner/spansql/sql_test.go @@ -94,6 +94,7 @@ func TestSQL(t *testing.T) { {Name: "Cm", Type: Type{Base: Int64}, Generated: Func{Name: "CHAR_LENGTH", Args: []Expr{ID("Ce")}}, Position: line(14)}, {Name: "Cn", Type: Type{Base: JSON}, Position: line(15)}, {Name: "Co", Type: Type{Base: Int64}, Default: IntegerLiteral(1), Position: line(16)}, + {Name: "Cp", Type: Type{Base: Proto, ProtoRef: "a.b.c"}, Position: line(17)}, }, PrimaryKey: []KeyPart{ {Column: "Ca"}, @@ -117,6 +118,7 @@ func TestSQL(t *testing.T) { Cm INT64 AS (CHAR_LENGTH(Ce)) STORED, Cn JSON, Co INT64 DEFAULT (1), + Cp ` + "`a.b.c`" + `, ) PRIMARY KEY(Ca, Cb DESC)`, reparseDDL, }, @@ -806,6 +808,13 @@ func TestSQL(t *testing.T) { "DROP INDEX IF EXISTS iname", reparseDDL, }, + { + &DropProtoBundle{ + Position: line(1), + }, + "DROP PROTO BUNDLE", + reparseDDL, + }, { &CreateTable{ Name: "tname1", @@ -920,6 +929,70 @@ func TestSQL(t *testing.T) { `DROP SEQUENCE sname`, reparseDDL, }, + { + &AlterProtoBundle{ + Position: line(1), + }, + "ALTER PROTO BUNDLE", + reparseDDL, + }, + { + &CreateProtoBundle{ + Types: []string{"a.b.c", "b.d.e"}, + Position: line(1), + }, + "CREATE PROTO BUNDLE (`a.b.c`, `b.d.e`)", + reparseDDL, + }, + { + &CreateProtoBundle{ + Types: []string{"a"}, + Position: line(1), + }, + "CREATE PROTO BUNDLE (`a`)", + reparseDDL, + }, + { + &CreateProtoBundle{ + Types: []string{"a.b.c"}, + Position: line(1), + }, + "CREATE PROTO BUNDLE (`a.b.c`)", + reparseDDL, + }, + { + &AlterProtoBundle{ + AddTypes: []string{"a.b.c", "b.d.e"}, + 
Position: line(1), + }, + "ALTER PROTO BUNDLE INSERT (`a.b.c`, `b.d.e`)", + reparseDDL, + }, + { + &AlterProtoBundle{ + UpdateTypes: []string{"a.b.c", "b.d.e"}, + Position: line(1), + }, + "ALTER PROTO BUNDLE UPDATE (`a.b.c`, `b.d.e`)", + reparseDDL, + }, + { + &AlterProtoBundle{ + DeleteTypes: []string{"a.b.c", "b.d.e"}, + Position: line(1), + }, + "ALTER PROTO BUNDLE DELETE (`a.b.c`, `b.d.e`)", + reparseDDL, + }, + { + &AlterProtoBundle{ + AddTypes: []string{"e.f.g"}, + DeleteTypes: []string{"a.b.c", "b.d.e"}, + Position: line(1), + }, + "ALTER PROTO BUNDLE INSERT (`e.f.g`) DELETE (`a.b.c`, `b.d.e`)", + reparseDDL, + }, { &Insert{ Table: "Singers", @@ -1039,6 +1112,18 @@ func TestSQL(t *testing.T) { `SELECT CAST(7 AS STRING)`, reparseQuery, }, + { + Query{ + Select: Select{ + List: []Expr{Func{ + Name: "CAST", + Args: []Expr{TypedExpr{Expr: IntegerLiteral(7), Type: Type{Base: Enum}}}, + }}, + }, + }, + `SELECT CAST(7 AS ENUM)`, + reparseQuery, + }, { Query{ Select: Select{ diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go index 481d83f10f20..6a0bbc5167b7 100644 --- a/spanner/spansql/types.go +++ b/spanner/spansql/types.go @@ -526,6 +526,10 @@ type Type struct { Array bool Base TypeBase // Bool, Int64, Float64, Numeric, String, Bytes, Date, Timestamp Len int64 // if Base is String or Bytes; may be MaxLen + // fully-qualified Protocol Buffer Message or Enum type-name (including + // leading dot-separated namespace) + // only meaningful if Base is Proto or Enum; empty for the bare PROTO/ENUM keywords used in CAST + ProtoRef string } // MaxLen is a sentinel for Type's Len field, representing the MAX value. 
@@ -543,6 +547,8 @@ const ( Date Timestamp JSON + Proto + Enum // Enum used in CAST expressions ) type PrivilegeType int @@ -1392,3 +1398,38 @@ func (ds *DropSequence) String() string { return fmt.Sprintf("%#v", ds) } func (*DropSequence) isDDLStmt() {} func (ds *DropSequence) Pos() Position { return ds.Position } func (ds *DropSequence) clearOffset() { ds.Position.Offset = 0 } + +// CreateProtoBundle represents a CREATE PROTO BUNDLE statement. +// https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#create-proto-bundle +type CreateProtoBundle struct { + Types []string + Position Position +} + +func (cp *CreateProtoBundle) String() string { return fmt.Sprintf("%#v", cp) } +func (*CreateProtoBundle) isDDLStmt() {} +func (cp *CreateProtoBundle) Pos() Position { return cp.Position } +func (cp *CreateProtoBundle) clearOffset() { cp.Position.Offset = 0 } + +// AlterProtoBundle represents an ALTER PROTO BUNDLE statement. +// https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#alter-proto-bundle +type AlterProtoBundle struct { + AddTypes, UpdateTypes, DeleteTypes []string + Position Position +} + +func (ap *AlterProtoBundle) String() string { return fmt.Sprintf("%#v", ap) } +func (*AlterProtoBundle) isDDLStmt() {} +func (ap *AlterProtoBundle) Pos() Position { return ap.Position } +func (ap *AlterProtoBundle) clearOffset() { ap.Position.Offset = 0 } + +// DropProtoBundle represents a DROP PROTO BUNDLE statement. 
+// https://cloud.google.com/spanner/docs/reference/standard-sql/data-definition-language#drop-proto-bundle +type DropProtoBundle struct { + Position Position +} + +func (dp *DropProtoBundle) String() string { return fmt.Sprintf("%#v", dp) } +func (*DropProtoBundle) isDDLStmt() {} +func (dp *DropProtoBundle) Pos() Position { return dp.Position } +func (dp *DropProtoBundle) clearOffset() { dp.Position.Offset = 0 } From 8771f2ea9807ab822083808e0678392edff3b4f2 Mon Sep 17 00:00:00 2001 From: Chris Smith Date: Wed, 6 Nov 2024 09:02:03 -0700 Subject: [PATCH 2/2] fix(auth): restore Application Default Credentials support to idtoken (#11083) --- auth/credentials/idtoken/file.go | 33 +++-------- auth/credentials/idtoken/idtoken.go | 24 ++++++-- auth/credentials/idtoken/idtoken_test.go | 75 ++++++++++++++++-------- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/auth/credentials/idtoken/file.go b/auth/credentials/idtoken/file.go index 333521c91940..c160c514339a 100644 --- a/auth/credentials/idtoken/file.go +++ b/auth/credentials/idtoken/file.go @@ -21,7 +21,6 @@ import ( "strings" "cloud.google.com/go/auth" - "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/credentials/impersonate" "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/credsfile" @@ -32,14 +31,8 @@ const ( iamCredAud = "https://iamcredentials.googleapis.com/" ) -var ( - defaultScopes = []string{ - "https://iamcredentials.googleapis.com/", - "https://www.googleapis.com/auth/cloud-platform", - } -) - -func credsFromBytes(b []byte, opts *Options) (*auth.Credentials, error) { +func credsFromDefault(creds *auth.Credentials, opts *Options) (*auth.Credentials, error) { + b := creds.JSON() t, err := credsfile.ParseFileType(b) if err != nil { return nil, err @@ -93,33 +86,23 @@ func credsFromBytes(b []byte, opts *Options) (*auth.Credentials, error) { account := filepath.Base(accountURL.ServiceAccountImpersonationURL) account = strings.Split(account, ":")[0] - 
baseCreds, err := credentials.DetectDefault(&credentials.DetectOptions{ - Scopes: defaultScopes, - CredentialsJSON: b, - Client: opts.client(), - UseSelfSignedJWT: true, - }) - if err != nil { - return nil, err - } - config := impersonate.IDTokenOptions{ Audience: opts.Audience, TargetPrincipal: account, IncludeEmail: true, Client: opts.client(), - Credentials: baseCreds, + Credentials: creds, } - creds, err := impersonate.NewIDTokenCredentials(&config) + idTokenCreds, err := impersonate.NewIDTokenCredentials(&config) if err != nil { return nil, err } return auth.NewCredentials(&auth.CredentialsOptions{ - TokenProvider: creds, + TokenProvider: idTokenCreds, JSON: b, - ProjectIDProvider: auth.CredentialsPropertyFunc(baseCreds.ProjectID), - UniverseDomainProvider: auth.CredentialsPropertyFunc(baseCreds.UniverseDomain), - QuotaProjectIDProvider: auth.CredentialsPropertyFunc(baseCreds.QuotaProjectID), + ProjectIDProvider: auth.CredentialsPropertyFunc(creds.ProjectID), + UniverseDomainProvider: auth.CredentialsPropertyFunc(creds.UniverseDomain), + QuotaProjectIDProvider: auth.CredentialsPropertyFunc(creds.QuotaProjectID), }), nil default: return nil, fmt.Errorf("idtoken: unsupported credentials type: %v", t) diff --git a/auth/credentials/idtoken/idtoken.go b/auth/credentials/idtoken/idtoken.go index a79890be9151..37947f90eb8d 100644 --- a/auth/credentials/idtoken/idtoken.go +++ b/auth/credentials/idtoken/idtoken.go @@ -22,11 +22,11 @@ package idtoken import ( "errors" - "fmt" "net/http" "os" "cloud.google.com/go/auth" + "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/internal" "cloud.google.com/go/auth/internal/credsfile" "cloud.google.com/go/compute/metadata" @@ -52,6 +52,11 @@ const ( ) var ( + defaultScopes = []string{ + "https://iamcredentials.googleapis.com/", + "https://www.googleapis.com/auth/cloud-platform", + } + errMissingOpts = errors.New("idtoken: opts must be provided") errMissingAudience = errors.New("idtoken: Audience must be provided") 
errBothFileAndJSON = errors.New("idtoken: CredentialsFile and CredentialsJSON must not both be provided") @@ -113,13 +118,20 @@ func NewCredentials(opts *Options) (*auth.Credentials, error) { if err := opts.validate(); err != nil { return nil, err } - if b := opts.jsonBytes(); b != nil { - return credsFromBytes(b, opts) - } - if metadata.OnGCE() { + b := opts.jsonBytes() + if b == nil && metadata.OnGCE() { return computeCredentials(opts) } - return nil, fmt.Errorf("idtoken: couldn't find any credentials") + creds, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: defaultScopes, + CredentialsJSON: b, + Client: opts.client(), + UseSelfSignedJWT: true, + }) + if err != nil { + return nil, err + } + return credsFromDefault(creds, opts) } func (o *Options) jsonBytes() []byte { diff --git a/auth/credentials/idtoken/idtoken_test.go b/auth/credentials/idtoken/idtoken_test.go index a59b35c81992..e50ca0a03d1e 100644 --- a/auth/credentials/idtoken/idtoken_test.go +++ b/auth/credentials/idtoken/idtoken_test.go @@ -66,7 +66,7 @@ func TestNewCredentials_Validate(t *testing.T) { } } -func TestNewCredentials_ServiceAccount(t *testing.T) { +func TestNewCredentials_ServiceAccount_NoClient(t *testing.T) { wantTok, _ := createRS256JWT(t) b, err := os.ReadFile("../../internal/testdata/sa.json") if err != nil { @@ -116,30 +116,57 @@ func (m mockTransport) RoundTrip(r *http.Request) (*http.Response, error) { return rw.Result(), nil } -func TestNewCredentials_ImpersonatedServiceAccount(t *testing.T) { - wantTok, _ := createRS256JWT(t) - client := internal.DefaultClient() - client.Transport = mockTransport{ - handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(fmt.Sprintf(`{"token": %q}`, wantTok))) - }), - } - creds, err := NewCredentials(&Options{ - Audience: "aud", - CredentialsFile: "../../internal/testdata/imp.json", - CustomClaims: map[string]interface{}{ - "foo": "bar", +func TestNewCredentials_ImpersonatedAndExternal(t 
*testing.T) { + tests := []struct { + name string + adc string + file string + }{ + { + name: "ADC external account", + adc: "../../internal/testdata/exaccount_url.json", + }, + { + name: "CredentialsFile impersonated service account", + file: "../../internal/testdata/imp.json", }, - Client: client, - }) - if err != nil { - t.Fatal(err) - } - tok, err := creds.Token(context.Background()) - if err != nil { - t.Fatalf("tp.Token() = %v", err) } - if tok.Value != wantTok { - t.Errorf("got %q, want %q", tok.Value, wantTok) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantTok, _ := createRS256JWT(t) + client := internal.DefaultClient() + client.Transport = mockTransport{ + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(fmt.Sprintf(`{"token": %q}`, wantTok))) + }), + } + + opts := &Options{ + Audience: "aud", + CustomClaims: map[string]interface{}{ + "foo": "bar", + }, + Client: client, + } + if tt.file != "" { + opts.CredentialsFile = tt.file + } else if tt.adc != "" { + t.Setenv(credsfile.GoogleAppCredsEnvVar, tt.adc) + } else { + t.Fatal("test fixture must have adc or file") + } + + creds, err := NewCredentials(opts) + if err != nil { + t.Fatal(err) + } + tok, err := creds.Token(context.Background()) + if err != nil { + t.Fatalf("tp.Token() = %v", err) + } + if tok.Value != wantTok { + t.Errorf("got %q, want %q", tok.Value, wantTok) + } + }) } }