Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Internal] Add converter functions and tests for plugin framework #3914

Merged
merged 7 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions common/force_send_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"
"strings"

"github.com/databricks/terraform-provider-databricks/internal/tfreflect"
"golang.org/x/exp/slices"
)

Expand All @@ -29,22 +30,22 @@ func SetForceSendFields(req any, d attributeGetter, fields []string) {
if !ok {
panic(fmt.Errorf("request argument to setForceSendFields must have ForceSendFields field of type []string (got %s)", forceSendFieldsField.Type()))
}
fs := listAllFields(rv)
fs := tfreflect.ListAllFields(rv)
for _, fieldName := range fields {
found := false
var structField reflect.StructField
for _, f := range fs {
fn := chooseFieldName(f.sf)
fn := chooseFieldName(f.StructField)
if fn != "-" && fn == fieldName {
found = true
structField = f.sf
structField = f.StructField
break
}
}
if !found {
allFieldNames := make([]string, 0)
for _, f := range fs {
fn := chooseFieldName(f.sf)
fn := chooseFieldName(f.StructField)
if fn == "-" || fn == "force_send_fields" {
continue
}
Expand Down
33 changes: 6 additions & 27 deletions common/reflect_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strconv"
"strings"

"github.com/databricks/terraform-provider-databricks/internal/tfreflect"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

Expand Down Expand Up @@ -382,28 +383,6 @@ func diffSuppressor(fieldName string, v *schema.Schema) func(k, old, new string,
}
}

type field struct {
sf reflect.StructField
v reflect.Value
}

func listAllFields(v reflect.Value) []field {
t := v.Type()
fields := make([]field, 0, v.NumField())
for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if f.Anonymous {
fields = append(fields, listAllFields(v.Field(i))...)
} else {
fields = append(fields, field{
sf: f,
v: v.Field(i),
})
}
}
return fields
}

func typeToSchema(v reflect.Value, aliases map[string]map[string]string, tc trackingContext) map[string]*schema.Schema {
if rpStruct, ok := resourceProviderRegistry[getNonPointerType(v.Type())]; ok {
return resourceProviderStructToSchemaInternal(rpStruct, tc.pathCtx).GetSchemaMap()
Expand All @@ -418,9 +397,9 @@ func typeToSchema(v reflect.Value, aliases map[string]map[string]string, tc trac
panic(fmt.Errorf("Schema value of Struct is expected, but got %s: %#v", reflectKind(rk), v))
}
tc = tc.visit(v)
fields := listAllFields(v)
fields := tfreflect.ListAllFields(v)
for _, field := range fields {
typeField := field.sf
typeField := field.StructField
if tc.depthExceeded(typeField) {
// Skip the field if recursion depth is over the limit.
log.Printf("[TRACE] over recursion limit, skipping field: %s, max depth: %d", getNameForType(typeField.Type), tc.getMaxDepthForTypeField(typeField))
Expand Down Expand Up @@ -597,9 +576,9 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al
return fmt.Errorf("%s: got invalid reflect value %#v", path, rv)
}
isGoSDK := isGoSdk(rv)
fields := listAllFields(rv)
fields := tfreflect.ListAllFields(rv)
for _, field := range fields {
typeField := field.sf
typeField := field.StructField
fieldName := chooseFieldNameWithAliases(typeField, rv.Type(), aliases)
if fieldName == "-" {
continue
Expand All @@ -617,7 +596,7 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al
if !isGoSDK && fieldSchema.Optional && defaultEmpty && !omitEmpty {
return fmt.Errorf("inconsistency: %s is optional, default is empty, but has no omitempty", fieldName)
}
valueField := field.v
valueField := field.Value
err := cb(fieldSchema, append(path, fieldName), &valueField)
if err != nil {
return fmt.Errorf("%s: %s", fieldName, err)
Expand Down
3 changes: 2 additions & 1 deletion common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
)

var (
uuidRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
uuidRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
TerraformBugErrorMessage = "This is a bug. Please report this to the maintainers at github.com/databricks/terraform-provider-databricks."
)

func StringIsUUID(s string) bool {
Expand Down
27 changes: 27 additions & 0 deletions internal/tfreflect/reflect_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package tfreflect

import "reflect"

type Field struct {
StructField reflect.StructField
Value reflect.Value
}

// Given a reflect.Value of a struct, list all of the fields for both struct field and
edwardfeng-db marked this conversation as resolved.
Show resolved Hide resolved
// the value. This function also extracts and flattens the anonymous fields nested inside.
func ListAllFields(v reflect.Value) []Field {
t := v.Type()
fields := make([]Field, 0, v.NumField())
for i := 0; i < v.NumField(); i++ {
f := t.Field(i)
if f.Anonymous {
fields = append(fields, ListAllFields(v.Field(i))...)
} else {
fields = append(fields, Field{
StructField: f,
Value: v.Field(i),
})
}
}
return fields
}
228 changes: 228 additions & 0 deletions pluginframework/converters/converters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
package converters

import (
"context"
"fmt"
"reflect"
"testing"

"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/stretchr/testify/assert"
)

type DummyTfSdk struct {
Enabled types.Bool `tfsdk:"enabled" tf:"optional"`
Workers types.Int64 `tfsdk:"workers" tf:""`
Floats types.Float64 `tfsdk:"floats" tf:""`
Description types.String `tfsdk:"description" tf:""`
Tasks types.String `tfsdk:"task" tf:"optional"`
Nested *DummyNestedTfSdk `tfsdk:"nested" tf:"optional"`
NoPointerNested DummyNestedTfSdk `tfsdk:"no_pointer_nested" tf:"optional"`
NestedList []DummyNestedTfSdk `tfsdk:"nested_list" tf:"optional"`
NestedPointerList []*DummyNestedTfSdk `tfsdk:"nested_pointer_list" tf:"optional"`
Map map[string]types.String `tfsdk:"map" tf:"optional"`
NestedMap map[string]DummyNestedTfSdk `tfsdk:"nested_map" tf:"optional"`
Repeated []types.Int64 `tfsdk:"repeated" tf:"optional"`
Attributes map[string]types.String `tfsdk:"attributes" tf:"optional"`
EnumField types.String `tfsdk:"enum_field" tf:"optional"`
AdditionalField types.String `tfsdk:"additional_field" tf:"optional"`
DistinctField types.String `tfsdk:"distinct_field" tf:"optional"`
Irrelevant types.String `tfsdk:"-"`
}

type TestEnum string

const TestEnumA TestEnum = `TEST_ENUM_A`

const TestEnumB TestEnum = `TEST_ENUM_B`

const TestEnumC TestEnum = `TEST_ENUM_C`

func (f *TestEnum) String() string {
return string(*f)
}

// Set raw string value and validate it against allowed values
func (f *TestEnum) Set(v string) error {
switch v {
case `TEST_ENUM_A`, `TEST_ENUM_B`, `TEST_ENUM_C`:
*f = TestEnum(v)
return nil
default:
return fmt.Errorf(`value "%s" is not one of "TEST_ENUM_A", "TEST_ENUM_B", "TEST_ENUM_C"`, v)
}
}

func (f *TestEnum) Type() string {
return "MonitorInfoStatus"
}

type DummyNestedTfSdk struct {
Name types.String `tfsdk:"name" tf:"optional"`
Enabled types.Bool `tfsdk:"enabled" tf:"optional"`
}

type DummyGoSdk struct {
Enabled bool `json:"enabled"`
Workers int64 `json:"workers"`
Floats float64 `json:"floats"`
Description string `json:"description"`
Tasks string `json:"tasks"`
Nested *DummyNestedGoSdk `json:"nested"`
NoPointerNested DummyNestedGoSdk `json:"no_pointer_nested"`
NestedList []DummyNestedGoSdk `json:"nested_list"`
NestedPointerList []*DummyNestedGoSdk `json:"nested_pointer_list"`
Map map[string]string `json:"map"`
NestedMap map[string]DummyNestedGoSdk `json:"nested_map"`
Repeated []int64 `json:"repeated"`
Attributes map[string]string `json:"attributes"`
EnumField TestEnum `json:"enum_field"`
AdditionalField string `json:"additional_field"`
DistinctField string `json:"distinct_field"` // distinct field that the tfsdk struct doesn't have
ForceSendFields []string `json:"-"`
}

type DummyNestedGoSdk struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
ForceSendFields []string `json:"-"`
}

// Function to construct individual test case with a pair of matching tfSdkStruct and gosdkStruct.
// Verifies that the conversion both ways are working as expected.
func RunConverterTest(t *testing.T, description string, tfSdkStruct DummyTfSdk, goSdkStruct DummyGoSdk) {
convertedGoSdkStruct := DummyGoSdk{}
assert.True(t, !TfSdkToGoSdkStruct(context.Background(), tfSdkStruct, &convertedGoSdkStruct).HasError())
assert.True(t, reflect.DeepEqual(convertedGoSdkStruct, goSdkStruct), fmt.Sprintf("tfsdk to gosdk conversion - %s", description))

convertedTfSdkStruct := DummyTfSdk{}
assert.True(t, !GoSdkToTfSdkStruct(context.Background(), goSdkStruct, &convertedTfSdkStruct).HasError())
assert.True(t, reflect.DeepEqual(convertedTfSdkStruct, tfSdkStruct), fmt.Sprintf("gosdk to tfsdk conversion - %s", description))
}

var tests = []struct {
name string
tfSdkStruct DummyTfSdk
goSdkStruct DummyGoSdk
}{
{
"string conversion",
DummyTfSdk{Description: types.StringValue("abc")},
DummyGoSdk{Description: "abc", ForceSendFields: []string{"Description"}},
},
{
edwardfeng-db marked this conversation as resolved.
Show resolved Hide resolved
"bool conversion",
DummyTfSdk{Enabled: types.BoolValue(true)},
DummyGoSdk{Enabled: true, ForceSendFields: []string{"Enabled"}},
},
{
"int64 conversion",
DummyTfSdk{Workers: types.Int64Value(123)},
DummyGoSdk{Workers: 123, ForceSendFields: []string{"Workers"}},
},
{
"tf null value conversion",
DummyTfSdk{Workers: types.Int64Null()},
DummyGoSdk{},
},
{
"float64 conversion",
DummyTfSdk{Floats: types.Float64Value(1.1)},
DummyGoSdk{Floats: 1.1, ForceSendFields: []string{"Floats"}},
},
{
"enum conversion",
DummyTfSdk{EnumField: types.StringValue("TEST_ENUM_A")},
DummyGoSdk{EnumField: TestEnumA},
},
{
"struct conversion",
DummyTfSdk{NoPointerNested: DummyNestedTfSdk{
Name: types.StringValue("def"),
Enabled: types.BoolValue(true),
}},
DummyGoSdk{NoPointerNested: DummyNestedGoSdk{
Name: "def",
Enabled: true,
ForceSendFields: []string{"Name", "Enabled"},
}},
},
{
"pointer conversion",
DummyTfSdk{Nested: &DummyNestedTfSdk{
Name: types.StringValue("def"),
Enabled: types.BoolValue(true),
}},
DummyGoSdk{Nested: &DummyNestedGoSdk{
Name: "def",
Enabled: true,
ForceSendFields: []string{"Name", "Enabled"},
}},
},
{
"list conversion",
DummyTfSdk{Repeated: []types.Int64{types.Int64Value(12), types.Int64Value(34)}},
DummyGoSdk{Repeated: []int64{12, 34}},
},
{
"map conversion",
DummyTfSdk{Attributes: map[string]types.String{"key": types.StringValue("value")}},
DummyGoSdk{Attributes: map[string]string{"key": "value"}},
},
{
"nested list conversion",
DummyTfSdk{NestedList: []DummyNestedTfSdk{
{
Name: types.StringValue("def"),
Enabled: types.BoolValue(true),
},
{
Name: types.StringValue("def"),
Enabled: types.BoolValue(true),
},
}},
DummyGoSdk{NestedList: []DummyNestedGoSdk{
{
Name: "def",
Enabled: true,
ForceSendFields: []string{"Name", "Enabled"},
},
{
Name: "def",
Enabled: true,
ForceSendFields: []string{"Name", "Enabled"},
},
}},
},
{
"nested map conversion",
DummyTfSdk{NestedMap: map[string]DummyNestedTfSdk{
"key1": {
Name: types.StringValue("abc"),
Enabled: types.BoolValue(true),
},
"key2": {
Name: types.StringValue("def"),
Enabled: types.BoolValue(false),
},
}},
DummyGoSdk{NestedMap: map[string]DummyNestedGoSdk{
"key1": {
Name: "abc",
Enabled: true,
ForceSendFields: []string{"Name", "Enabled"},
},
"key2": {
Name: "def",
Enabled: false,
ForceSendFields: []string{"Name", "Enabled"},
},
}},
},
}

func TestConverter(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) { RunConverterTest(t, test.name, test.tfSdkStruct, test.goSdkStruct) })
}
}
Loading
Loading