Skip to content

Commit

Permalink
feat(generic): support an option to remove go.tag annotation (#1526)
Browse files Browse the repository at this point in the history
  • Loading branch information
Marina-Sakai authored Oct 9, 2024
1 parent 4943f34 commit 58c5848
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 63 deletions.
19 changes: 11 additions & 8 deletions pkg/generic/descriptor/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ import (

var isGoTagAliasDisabled = os.Getenv("KITEX_GENERIC_GOTAG_ALIAS_DISABLED") == "True"

func init() {
if isGoTagAliasDisabled {
// disable go.tag for dynamicgo
dthrift.RemoveAnnotationMapper(dthrift.AnnoScopeField, "go.tag")
}
}

// FieldDescriptor idl field descriptor
type FieldDescriptor struct {
Name string // field name
Expand All @@ -47,11 +40,21 @@ type FieldDescriptor struct {
Type *TypeDescriptor
HTTPMapping HTTPMapping
ValueMapping ValueMapping
GoTagOpt *GoTagOption
}

type GoTagOption struct {
IsGoAliasDisabled bool
}

// FieldName return field name maybe with an alias
func (d *FieldDescriptor) FieldName() string {
if d.Alias != "" && !isGoTagAliasDisabled {
aliasDisabled := isGoTagAliasDisabled
if d.GoTagOpt != nil {
aliasDisabled = d.GoTagOpt.IsGoAliasDisabled
}

if d.Alias != "" && !aliasDisabled {
return d.Alias
}
return d.Name
Expand Down
45 changes: 27 additions & 18 deletions pkg/generic/thrift/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func SetDefaultParseMode(m ParseMode) {
}

// Parse descriptor from parser.Thrift
func Parse(tree *parser.Thrift, mode ParseMode) (*descriptor.ServiceDescriptor, error) {
func Parse(tree *parser.Thrift, mode ParseMode, opts ...ParseOption) (*descriptor.ServiceDescriptor, error) {
if len(tree.Services) == 0 {
return nil, errors.New("empty serverce from idls")
}
Expand All @@ -90,13 +90,16 @@ func Parse(tree *parser.Thrift, mode ParseMode) (*descriptor.ServiceDescriptor,
sDsc.Name = "CombinedServices"
}

pOpts := &parseOptions{}
pOpts.apply(opts)

visitedSvcs := make(map[*parser.Service]bool, len(tree.Services))
for _, svc := range svcs {
for p := range getAllSvcs(svc, tree, visitedSvcs) {
svc := p.data.(*parser.Service)
structsCache := map[string]*descriptor.TypeDescriptor{}
for _, fn := range svc.Functions {
if err := addFunction(fn, p.tree, sDsc, structsCache); err != nil {
if err := addFunction(fn, p.tree, sDsc, structsCache, pOpts); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -135,7 +138,7 @@ func getAllSvcs(svc *parser.Service, tree *parser.Thrift, visitedSvcs map[*parse
return svcs
}

func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.ServiceDescriptor, structsCache map[string]*descriptor.TypeDescriptor) (err error) {
func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.ServiceDescriptor, structsCache map[string]*descriptor.TypeDescriptor, opts *parseOptions) (err error) {
if sDsc.Functions[fn.Name] != nil {
return fmt.Errorf("duplicate method name: %s", fn.Name)
}
Expand All @@ -156,8 +159,9 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv
FieldsByName: map[string]*descriptor.FieldDescriptor{},
},
}

var reqType *descriptor.TypeDescriptor
reqType, err = parseType(field.Type, tree, structsCache, initRecursionDepth)
reqType, err = parseType(field.Type, tree, structsCache, initRecursionDepth, opts)
if err != nil {
return err
}
Expand All @@ -171,9 +175,10 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv
}
}
reqField := &descriptor.FieldDescriptor{
Name: field.Name,
ID: field.ID,
Type: reqType,
Name: field.Name,
ID: field.ID,
Type: reqType,
GoTagOpt: opts.goTag,
}
req.Struct.FieldsByID[field.ID] = reqField
req.Struct.FieldsByName[field.Name] = reqField
Expand All @@ -186,12 +191,13 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv
},
}
var respType *descriptor.TypeDescriptor
respType, err = parseType(fn.FunctionType, tree, structsCache, initRecursionDepth)
respType, err = parseType(fn.FunctionType, tree, structsCache, initRecursionDepth, opts)
if err != nil {
return err
}
respField := &descriptor.FieldDescriptor{
Type: respType,
Type: respType,
GoTagOpt: opts.goTag,
}
// response has no name or id
resp.Struct.FieldsByID[0] = respField
Expand All @@ -201,7 +207,7 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv
// only support single exception
field := fn.Throws[0]
var exceptionType *descriptor.TypeDescriptor
exceptionType, err = parseType(field.Type, tree, structsCache, initRecursionDepth)
exceptionType, err = parseType(field.Type, tree, structsCache, initRecursionDepth, opts)
if err != nil {
return err
}
Expand All @@ -210,6 +216,7 @@ func addFunction(fn *parser.Function, tree *parser.Thrift, sDsc *descriptor.Serv
ID: field.ID,
IsException: true,
Type: exceptionType,
GoTagOpt: opts.goTag,
}
resp.Struct.FieldsByID[field.ID] = exceptionField
resp.Struct.FieldsByName[field.Name] = exceptionField
Expand Down Expand Up @@ -275,7 +282,7 @@ var builtinTypes = map[string]*descriptor.TypeDescriptor{
// arg cache:
// only support self reference on the same file
// cross file self reference complicate matters
func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor.TypeDescriptor, recursionDepth int) (*descriptor.TypeDescriptor, error) {
func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor.TypeDescriptor, recursionDepth int, opt *parseOptions) (*descriptor.TypeDescriptor, error) {
if ty, ok := builtinTypes[t.Name]; ok {
return ty, nil
}
Expand All @@ -287,20 +294,20 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor
case "list":
ty := &descriptor.TypeDescriptor{Name: t.Name}
ty.Type = descriptor.LIST
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth)
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth, opt)
return ty, err
case "set":
ty := &descriptor.TypeDescriptor{Name: t.Name}
ty.Type = descriptor.SET
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth)
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth, opt)
return ty, err
case "map":
ty := &descriptor.TypeDescriptor{Name: t.Name}
ty.Type = descriptor.MAP
if ty.Key, err = parseType(t.KeyType, tree, cache, nextRecursionDepth); err != nil {
if ty.Key, err = parseType(t.KeyType, tree, cache, nextRecursionDepth, opt); err != nil {
return nil, err
}
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth)
ty.Elem, err = parseType(t.ValueType, tree, cache, nextRecursionDepth, opt)
return ty, err
default:
// check the cache
Expand All @@ -318,7 +325,7 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor
cache = map[string]*descriptor.TypeDescriptor{}
}
if typDef, ok := tree.GetTypedef(typeName); ok {
return parseType(typDef.Type, tree, cache, nextRecursionDepth)
return parseType(typDef.Type, tree, cache, nextRecursionDepth, opt)
}
if _, ok := tree.GetEnum(typeName); ok {
return builtinTypes["i32"], nil
Expand Down Expand Up @@ -359,7 +366,9 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor
Required: field.Requiredness == parser.FieldType_Required,
Optional: field.Requiredness == parser.FieldType_Optional,
}

if opt != nil {
_f.GoTagOpt = opt.goTag
}
for _, ann := range field.Annotations {
for _, v := range ann.GetValues() {
if handle, ok := descriptor.FindAnnotation(ann.GetKey(), v); ok {
Expand All @@ -383,7 +392,7 @@ func parseType(t *parser.Type, tree *parser.Thrift, cache map[string]*descriptor
if _f.HTTPMapping == nil {
_f.HTTPMapping = descriptor.DefaultNewMapping(_f.FieldName())
}
if _f.Type, err = parseType(field.Type, tree, cache, nextRecursionDepth); err != nil {
if _f.Type, err = parseType(field.Type, tree, cache, nextRecursionDepth, opt); err != nil {
return nil, err
}
// set default value
Expand Down
41 changes: 41 additions & 0 deletions pkg/generic/thrift/parse_option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import "github.com/cloudwego/kitex/pkg/generic/descriptor"

type parseOptions struct {
goTag *descriptor.GoTagOption
}

type ParseOption struct {
F func(opt *parseOptions)
}

func (o *parseOptions) apply(opts []ParseOption) {
for _, op := range opts {
op.F(o)
}
}

func WithGoTagDisabled(disable bool) ParseOption {
return ParseOption{F: func(opt *parseOptions) {
opt.goTag = &descriptor.GoTagOption{
IsGoAliasDisabled: disable,
}
}}
}
Loading

0 comments on commit 58c5848

Please sign in to comment.