diff --git a/common/force_send_fields.go b/common/force_send_fields.go index f8cc2f2a78..bddcb82278 100644 --- a/common/force_send_fields.go +++ b/common/force_send_fields.go @@ -29,22 +29,22 @@ func SetForceSendFields(req any, d attributeGetter, fields []string) { if !ok { panic(fmt.Errorf("request argument to setForceSendFields must have ForceSendFields field of type []string (got %s)", forceSendFieldsField.Type())) } - fs := listAllFields(rv) + fs := ListAllFields(rv) for _, fieldName := range fields { found := false var structField reflect.StructField for _, f := range fs { - fn := chooseFieldName(f.sf) + fn := chooseFieldName(f.Sf) if fn != "-" && fn == fieldName { found = true - structField = f.sf + structField = f.Sf break } } if !found { allFieldNames := make([]string, 0) for _, f := range fs { - fn := chooseFieldName(f.sf) + fn := chooseFieldName(f.Sf) if fn == "-" || fn == "force_send_fields" { continue } diff --git a/common/reflect_resource.go b/common/reflect_resource.go index 7c28e9ec5d..f5b1a4f9cd 100644 --- a/common/reflect_resource.go +++ b/common/reflect_resource.go @@ -382,22 +382,22 @@ func diffSuppressor(fieldName string, v *schema.Schema) func(k, old, new string, } } -type field struct { - sf reflect.StructField - v reflect.Value +type Field struct { + Sf reflect.StructField + V reflect.Value } -func listAllFields(v reflect.Value) []field { +func ListAllFields(v reflect.Value) []Field { t := v.Type() - fields := make([]field, 0, v.NumField()) + fields := make([]Field, 0, v.NumField()) for i := 0; i < v.NumField(); i++ { f := t.Field(i) if f.Anonymous { - fields = append(fields, listAllFields(v.Field(i))...) + fields = append(fields, ListAllFields(v.Field(i))...) } else { - fields = append(fields, field{ - sf: f, - v: v.Field(i), + fields = append(fields, Field{ + Sf: f, + V: v.Field(i), }) } } @@ -418,9 +418,9 @@ func typeToSchema(v reflect.Value, aliases map[string]map[string]string, tc trac panic(fmt.Errorf("Schema value of Struct is expected, but got %s: %#v", reflectKind(rk), v)) } tc = tc.visit(v) - fields := listAllFields(v) + fields := ListAllFields(v) for _, field := range fields { - typeField := field.sf + typeField := field.Sf if tc.depthExceeded(typeField) { // Skip the field if recursion depth is over the limit. log.Printf("[TRACE] over recursion limit, skipping field: %s, max depth: %d", getNameForType(typeField.Type), tc.getMaxDepthForTypeField(typeField)) @@ -597,9 +597,9 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al return fmt.Errorf("%s: got invalid reflect value %#v", path, rv) } isGoSDK := isGoSdk(rv) - fields := listAllFields(rv) + fields := ListAllFields(rv) for _, field := range fields { - typeField := field.sf + typeField := field.Sf fieldName := chooseFieldNameWithAliases(typeField, rv.Type(), aliases) if fieldName == "-" { continue @@ -617,7 +617,7 @@ func iterFields(rv reflect.Value, path []string, s map[string]*schema.Schema, al if !isGoSDK && fieldSchema.Optional && defaultEmpty && !omitEmpty { return fmt.Errorf("inconsistency: %s is optional, default is empty, but has no omitempty", fieldName) } - valueField := field.v + valueField := field.V err := cb(fieldSchema, append(path, fieldName), &valueField) if err != nil { return fmt.Errorf("%s: %s", fieldName, err) diff --git a/pluginframework/common/converters_test.go b/pluginframework/common/converters_test.go new file mode 100644 index 0000000000..4b8732924f --- /dev/null +++ b/pluginframework/common/converters_test.go @@ -0,0 +1,370 @@ +package pluginframework + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/hashicorp/terraform-plugin-framework/tfsdk" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/stretchr/testify/assert" +) + +type DummyTfSdk struct { + Enabled types.Bool `tfsdk:"enabled" tf:"optional"` + Workers types.Int64 `tfsdk:"workers" tf:""` // Test required field + Floats types.Float64 `tfsdk:"floats" tf:""` // Test required field + Description types.String `tfsdk:"description" tf:""` + Tasks types.String `tfsdk:"task" tf:"optional"` + Nested *DummyNestedTfSdk `tfsdk:"nested" tf:"optional"` + NoPointerNested DummyNestedTfSdk `tfsdk:"no_pointer_nested" tf:"optional"` + NestedList []DummyNestedTfSdk `tfsdk:"nested_list" tf:"optional"` + NestedPointerList []*DummyNestedTfSdk `tfsdk:"nested_pointer_list" tf:"optional"` + Map map[string]types.String `tfsdk:"map" tf:"optional"` + NestedMap map[string]DummyNestedTfSdk `tfsdk:"nested_map" tf:"optional"` + Repeated []types.Int64 `tfsdk:"repeated" tf:"optional"` + Attributes map[string]types.String `tfsdk:"attributes" tf:"optional"` + EnumField types.String `tfsdk:"enum_field" tf:"optional"` + AdditionalField types.String `tfsdk:"additional_field" tf:"optional"` + DistinctField types.String `tfsdk:"distinct_field" tf:"optional"` // distinct field that the gosdk struct doesn't have + Irrelevant types.String `tfsdk:"-"` +} + +type TestEnum string + +const TestEnumA TestEnum = `TEST_ENUM_A` + +const TestEnumB TestEnum = `TEST_ENUM_B` + +const TestEnumC TestEnum = `TEST_ENUM_C` + +func (f *TestEnum) String() string { + return string(*f) +} + +// Set raw string value and validate it against allowed values +func (f *TestEnum) Set(v string) error { + switch v { + case `TEST_ENUM_A`, `TEST_ENUM_B`, `TEST_ENUM_C`: + *f = TestEnum(v) + return nil + default: + return fmt.Errorf(`value "%s" is not one of "TEST_ENUM_A", "TEST_ENUM_B", "TEST_ENUM_C"`, v) + } +} + +func (f *TestEnum) Type() string { + return "MonitorInfoStatus" +} + +type DummyNestedTfSdk struct { + Name types.String `tfsdk:"name" tf:"optional"` + Enabled types.Bool `tfsdk:"enabled" tf:"optional"` +} + +type DummyGoSdk struct { + Enabled bool `json:"enabled"` + Workers int64 `json:"workers"` + Floats float64 `json:"floats"` + Description string `json:"description"` + Tasks string `json:"tasks"` + Nested *DummyNestedGoSdk `json:"nested"` + NoPointerNested DummyNestedGoSdk `json:"no_pointer_nested"` + NestedList []DummyNestedGoSdk `json:"nested_list"` + NestedPointerList []*DummyNestedGoSdk `json:"nested_pointer_list"` + Map map[string]string `json:"map"` + NestedMap map[string]DummyNestedGoSdk `json:"nested_map"` + Repeated []int64 `json:"repeated"` + Attributes map[string]string `json:"attributes"` + EnumField TestEnum `json:"enum_field"` + AdditionalField string `json:"additional_field"` + DistinctField string `json:"distinct_field"` // distinct field that the tfsdk struct doesn't have + ForceSendFields []string `json:"-"` +} + +type DummyNestedGoSdk struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + ForceSendFields []string `json:"-"` +} + +type emptyCtx struct{} + +func (emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (emptyCtx) Done() <-chan struct{} { + return nil +} + +func (emptyCtx) Err() error { + return nil +} + +func (emptyCtx) Value(key any) any { + return nil +} + +var ctx = emptyCtx{} + +// Constructing a dummy tfsdk struct. +var tfSdkStructComplex = DummyTfSdk{ + Enabled: types.BoolValue(false), + Workers: types.Int64Value(12), + Description: types.StringValue("abc"), + Tasks: types.StringNull(), + Nested: &DummyNestedTfSdk{ + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + NoPointerNested: DummyNestedTfSdk{ + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + NestedList: []DummyNestedTfSdk{ + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + }, + Map: map[string]types.String{ + "key1": types.StringValue("value1"), + "key2": types.StringValue("value2"), + }, + NestedMap: map[string]DummyNestedTfSdk{ + "key1": { + Name: types.StringValue("abc"), + Enabled: types.BoolValue(false), + }, + "key2": { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + }, + NestedPointerList: []*DummyNestedTfSdk{ + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + }, + Attributes: map[string]types.String{"key": types.StringValue("value")}, + EnumField: types.StringValue("TEST_ENUM_A"), + Repeated: []types.Int64{types.Int64Value(12), types.Int64Value(34)}, +} + +func TestGetAndSetPluginFramework(t *testing.T) { + // Also test StructToSchema. + scm := PluginFrameworkResourceStructToSchema(DummyTfSdk{}, nil) + state := tfsdk.State{ + Schema: scm, + } + + // Test optional is being handled properly. + assert.True(t, scm.Attributes["enabled"].IsOptional()) + assert.True(t, scm.Attributes["workers"].IsRequired()) + assert.True(t, scm.Attributes["floats"].IsRequired()) + + // Assert that we can set state from the tfsdk struct. + diags := state.Set(ctx, tfSdkStructComplex) + assert.Len(t, diags, 0) + + // Assert that we can get a struct from the state. + getterStruct := DummyTfSdk{} + diags = state.Get(ctx, &getterStruct) + assert.Len(t, diags, 0) + + // Assert the struct populated from .Get is exactly the same as the original tfsdk struct. + assert.True(t, reflect.DeepEqual(getterStruct, tfSdkStructComplex)) +} + +func TestStructConversion(t *testing.T) { + // Convert from tfsdk to gosdk struct using the converter function + convertedGoSdkStruct := DummyGoSdk{} + diags := TfSdkToGoSdkStruct(tfSdkStructComplex, &convertedGoSdkStruct, ctx) + assert.True(t, !diags.HasError()) + + // Convert the gosdk struct back to tfsdk struct + convertedTfSdkStruct := DummyTfSdk{} + diags = GoSdkToTfSdkStruct(convertedGoSdkStruct, &convertedTfSdkStruct, ctx) + assert.True(t, !diags.HasError()) + + // Assert that the struct is exactly the same after tfsdk --> gosdk --> tfsdk + assert.True(t, reflect.DeepEqual(tfSdkStructComplex, convertedTfSdkStruct)) +} + +// Function to construct individual test case with a pair of matching tfSdkStruct and gosdkStruct. +// Verifies that the conversion both ways are working as expected. +func ConverterTestCase(t *testing.T, description string, tfSdkStruct DummyTfSdk, goSdkStruct DummyGoSdk) { + convertedGoSdkStruct := DummyGoSdk{} + assert.True(t, !TfSdkToGoSdkStruct(tfSdkStruct, &convertedGoSdkStruct, ctx).HasError()) + assert.True(t, reflect.DeepEqual(convertedGoSdkStruct, goSdkStruct), fmt.Sprintf("tfsdk to gosdk conversion - %s", description)) + + convertedTfSdkStruct := DummyTfSdk{} + assert.True(t, !GoSdkToTfSdkStruct(goSdkStruct, &convertedTfSdkStruct, ctx).HasError()) + assert.True(t, reflect.DeepEqual(convertedTfSdkStruct, tfSdkStruct), fmt.Sprintf("gosdk to tfsdk conversion - %s", description)) +} + +func TestConverter(t *testing.T) { + ConverterTestCase( + t, + "string conversion", + DummyTfSdk{Description: types.StringValue("abc")}, + DummyGoSdk{Description: "abc", ForceSendFields: []string{"Description"}}, + ) + ConverterTestCase( + t, + "bool conversion", + DummyTfSdk{Enabled: types.BoolValue(true)}, + DummyGoSdk{Enabled: true, ForceSendFields: []string{"Enabled"}}, + ) + ConverterTestCase( + t, + "int64 conversion", + DummyTfSdk{Workers: types.Int64Value(123)}, + DummyGoSdk{Workers: 123, ForceSendFields: []string{"Workers"}}, + ) + ConverterTestCase( + t, + "tf null value conversion", + DummyTfSdk{Workers: types.Int64Null()}, + DummyGoSdk{}, + ) + ConverterTestCase( + t, + "float64 conversion", + DummyTfSdk{Floats: types.Float64Value(1.1)}, + DummyGoSdk{Floats: 1.1, ForceSendFields: []string{"Floats"}}, + ) + ConverterTestCase( + t, + "enum conversion", + DummyTfSdk{EnumField: types.StringValue("TEST_ENUM_A")}, + DummyGoSdk{EnumField: TestEnumA}, + ) + ConverterTestCase( + t, + "struct conversion", + DummyTfSdk{NoPointerNested: DummyNestedTfSdk{ + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }}, + DummyGoSdk{NoPointerNested: DummyNestedGoSdk{ + Name: "def", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }}, + ) + ConverterTestCase( + t, + "pointer conversion", + DummyTfSdk{Nested: &DummyNestedTfSdk{ + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }}, + DummyGoSdk{Nested: &DummyNestedGoSdk{ + Name: "def", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }}, + ) + ConverterTestCase( + t, + "list conversion", + DummyTfSdk{Repeated: []types.Int64{types.Int64Value(12), types.Int64Value(34)}}, + DummyGoSdk{Repeated: []int64{12, 34}}, + ) + ConverterTestCase( + t, + "map conversion", + DummyTfSdk{Attributes: map[string]types.String{"key": types.StringValue("value")}}, + DummyGoSdk{Attributes: map[string]string{"key": "value"}}, + ) + ConverterTestCase( + t, + "nested list conversion", + DummyTfSdk{NestedList: []DummyNestedTfSdk{ + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(true), + }, + }}, + DummyGoSdk{NestedList: []DummyNestedGoSdk{ + { + Name: "def", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }, + { + Name: "def", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }, + }}, + ) + ConverterTestCase( + t, + "nested list conversion", + DummyTfSdk{NestedList: []DummyNestedTfSdk{ + { + Name: types.StringValue("abc"), + Enabled: types.BoolValue(true), + }, + { + Name: types.StringValue("def"), + Enabled: types.BoolValue(false), + }, + }}, + DummyGoSdk{NestedList: []DummyNestedGoSdk{ + { + Name: "abc", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }, + { + Name: "def", + Enabled: false, + ForceSendFields: []string{"Name", "Enabled"}, + }, + }}, + ) + ConverterTestCase( + t, + "nested map conversion", + DummyTfSdk{NestedMap: map[string]DummyNestedTfSdk{ + "key1": { + Name: types.StringValue("abc"), + Enabled: types.BoolValue(true), + }, + "key2": { + Name: types.StringValue("def"), + Enabled: types.BoolValue(false), + }, + }}, + DummyGoSdk{NestedMap: map[string]DummyNestedGoSdk{ + "key1": { + Name: "abc", + Enabled: true, + ForceSendFields: []string{"Name", "Enabled"}, + }, + "key2": { + Name: "def", + Enabled: false, + ForceSendFields: []string{"Name", "Enabled"}, + }, + }}, + ) +} diff --git a/pluginframework/common/customizable_schema.go b/pluginframework/common/customizable_schema.go new file mode 100644 index 0000000000..dbaf2f0755 --- /dev/null +++ b/pluginframework/common/customizable_schema.go @@ -0,0 +1,126 @@ +package pluginframework + +import ( + "fmt" +) + +type CustomizableSchemaPluginFramework struct { + attr Attribute +} + +func ConstructCustomizableSchema(attributes map[string]Attribute) *CustomizableSchemaPluginFramework { + attr := Attribute(SingleNestedAttribute{Attributes: attributes}) + return &CustomizableSchemaPluginFramework{attr: attr} +} + +// Converts CustomizableSchema into a map from string to Attribute. +func (s *CustomizableSchemaPluginFramework) ToAttributeMap() map[string]Attribute { + return attributeToMap(&s.attr) +} + +func attributeToMap(attr *Attribute) map[string]Attribute { + var m map[string]Attribute + switch attr := (*attr).(type) { + case SingleNestedAttribute: + m = attr.Attributes + case ListNestedAttribute: + m = attr.NestedObject.Attributes + case MapNestedAttribute: + m = attr.NestedObject.Attributes + default: + panic(fmt.Errorf("cannot convert to map, attribute is not nested")) + } + + return m +} + +func (s *CustomizableSchemaPluginFramework) AddValidator(v any, path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.AddValidators(v) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetOptional(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetOptional() + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetRequired(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetRequired() + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetSensitive(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetSensitive() + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + return s +} + +func (s *CustomizableSchemaPluginFramework) SetDeprecated(msg string, path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetDeprecated(msg) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetComputed(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetComputed() + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + return s +} + +// SetReadOnly sets the schema to be read-only (i.e. computed, non-optional). +// This should be used for fields that are not user-configurable but are returned +// by the platform. +func (s *CustomizableSchemaPluginFramework) SetReadOnly(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr Attribute) Attribute { + return attr.SetReadOnly() + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +// Helper function for navigating through schema attributes, panics if path does not exist or invalid. +func navigateSchemaWithCallback(s *Attribute, cb func(Attribute) Attribute, path ...string) (Attribute, error) { + current_scm := s + for i, p := range path { + m := attributeToMap(current_scm) + + v, ok := m[p] + if !ok { + return nil, fmt.Errorf("missing key %s", p) + } + + if i == len(path)-1 { + m[p] = cb(v) + return m[p], nil + } + current_scm = &v + } + return nil, fmt.Errorf("path %v is incomplete", path) +} diff --git a/pluginframework/common/databricks_attribute.go b/pluginframework/common/databricks_attribute.go new file mode 100644 index 0000000000..dde24a1c33 --- /dev/null +++ b/pluginframework/common/databricks_attribute.go @@ -0,0 +1,571 @@ +package pluginframework + +import ( + "github.com/hashicorp/terraform-plugin-framework/attr" + dataschema "github.com/hashicorp/terraform-plugin-framework/datasource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" +) + +// Common interface for all attributes, we need this because in terraform plugin framework, the datasource schema and resource +// schema are in two separate packages. This common interface prevents us from keeping two copies of StructToSchema and CustomizableSchema. +type Attribute interface { + ToDataSourceAttribute() dataschema.Attribute + ToResourceAttribute() schema.Attribute + SetOptional() Attribute + SetRequired() Attribute + SetSensitive() Attribute + SetComputed() Attribute + SetReadOnly() Attribute + SetDeprecated(string) Attribute + AddValidators(any) Attribute +} + +type StringAttribute struct { + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.String +} + +func (a StringAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.StringAttribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a StringAttribute) ToResourceAttribute() schema.Attribute { + return schema.StringAttribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a StringAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a StringAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a StringAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a StringAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a StringAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a StringAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a StringAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.String)) + return a +} + +type Float64Attribute struct { + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Float64 +} + +func (a Float64Attribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.Float64Attribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a Float64Attribute) ToResourceAttribute() schema.Attribute { + return schema.Float64Attribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a Float64Attribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a Float64Attribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a Float64Attribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a Float64Attribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a Float64Attribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a Float64Attribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a Float64Attribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Float64)) + return a +} + +type Int64Attribute struct { + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Int64 +} + +func (a Int64Attribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.Int64Attribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a Int64Attribute) ToResourceAttribute() schema.Attribute { + return schema.Int64Attribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a Int64Attribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a Int64Attribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a Int64Attribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a Int64Attribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a Int64Attribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a Int64Attribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a Int64Attribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Int64)) + return a +} + +type BoolAttribute struct { + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Bool +} + +func (a BoolAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.BoolAttribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a BoolAttribute) ToResourceAttribute() schema.Attribute { + return schema.BoolAttribute{Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a BoolAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a BoolAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a BoolAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a BoolAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a BoolAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a BoolAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a BoolAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Bool)) + return a +} + +type MapAttribute struct { + ElementType attr.Type + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Map +} + +func (a MapAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.MapAttribute{ElementType: a.ElementType, Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a MapAttribute) ToResourceAttribute() schema.Attribute { + return schema.MapAttribute{ElementType: a.ElementType, Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a MapAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a MapAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a MapAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a MapAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a MapAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a MapAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a MapAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Map)) + return a +} + +type ListAttribute struct { + ElementType attr.Type + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.List +} + +func (a ListAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.ListAttribute{ElementType: a.ElementType, Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a ListAttribute) ToResourceAttribute() schema.Attribute { + return schema.ListAttribute{ElementType: a.ElementType, Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a ListAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a ListAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a ListAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a ListAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a ListAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a ListAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a ListAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.List)) + return a +} + +type SingleNestedAttribute struct { + Attributes map[string]Attribute + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Object +} + +func (a SingleNestedAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.SingleNestedAttribute{Attributes: ToDataSourceAttributeMap(a.Attributes), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a SingleNestedAttribute) ToResourceAttribute() schema.Attribute { + return schema.SingleNestedAttribute{Attributes: ToResourceAttributeMap(a.Attributes), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a SingleNestedAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a SingleNestedAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a SingleNestedAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a SingleNestedAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a SingleNestedAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a SingleNestedAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a SingleNestedAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Object)) + return a +} + +type ListNestedAttribute struct { + NestedObject NestedAttributeObject + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.List +} + +func (a ListNestedAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.ListNestedAttribute{NestedObject: a.NestedObject.ToDataSourceAttribute(), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a ListNestedAttribute) ToResourceAttribute() schema.Attribute { + return schema.ListNestedAttribute{NestedObject: a.NestedObject.ToResourceAttribute(), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a ListNestedAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a ListNestedAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a ListNestedAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a ListNestedAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a ListNestedAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a ListNestedAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a ListNestedAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.List)) + return a +} + +type NestedAttributeObject struct { + Attributes map[string]Attribute +} + +func (a NestedAttributeObject) ToDataSourceAttribute() dataschema.NestedAttributeObject { + dataSourceAttributes := ToDataSourceAttributeMap(a.Attributes) + + return dataschema.NestedAttributeObject{ + Attributes: dataSourceAttributes, + } +} + +func (a NestedAttributeObject) ToResourceAttribute() schema.NestedAttributeObject { + resourceAttributes := ToResourceAttributeMap(a.Attributes) + + return schema.NestedAttributeObject{ + Attributes: resourceAttributes, + } +} + +func ToDataSourceAttributeMap(attributes map[string]Attribute) map[string]dataschema.Attribute { + dataSourceAttributes := make(map[string]dataschema.Attribute) + + for key, attribute := range attributes { + dataSourceAttributes[key] = attribute.ToDataSourceAttribute() + } + + return dataSourceAttributes +} + +func ToResourceAttributeMap(attributes map[string]Attribute) map[string]schema.Attribute { + resourceAttributes := make(map[string]schema.Attribute) + + for key, attribute := range attributes { + resourceAttributes[key] = attribute.ToResourceAttribute() + } + + return resourceAttributes +} + +type MapNestedAttribute struct { + NestedObject NestedAttributeObject + Optional bool + Required bool + Sensitive bool + Computed bool + DeprecationMessage string + Validators []validator.Map +} + +func (a MapNestedAttribute) ToDataSourceAttribute() dataschema.Attribute { + return dataschema.MapNestedAttribute{NestedObject: a.NestedObject.ToDataSourceAttribute(), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a MapNestedAttribute) ToResourceAttribute() schema.Attribute { + return schema.MapNestedAttribute{NestedObject: a.NestedObject.ToResourceAttribute(), Optional: a.Optional, Required: a.Required, Sensitive: a.Sensitive, DeprecationMessage: a.DeprecationMessage, Computed: a.Computed, Validators: a.Validators} +} + +func (a MapNestedAttribute) SetOptional() Attribute { + a.Optional = true + a.Required = false + return a +} + +func (a MapNestedAttribute) SetRequired() Attribute { + a.Optional = false + a.Required = true + return a +} + +func (a MapNestedAttribute) SetSensitive() Attribute { + a.Sensitive = true + return a +} + +func (a MapNestedAttribute) SetComputed() Attribute { + a.Computed = true + return a +} + +func (a MapNestedAttribute) SetReadOnly() Attribute { + a.Computed = true + a.Optional = false + a.Required = false + return a +} + +func (a MapNestedAttribute) SetDeprecated(msg string) Attribute { + a.DeprecationMessage = msg + return a +} + +func (a MapNestedAttribute) AddValidators(v any) Attribute { + a.Validators = append(a.Validators, v.(validator.Map)) + return a +} diff --git a/pluginframework/common/go_to_tf.go b/pluginframework/common/go_to_tf.go new file mode 100644 index 0000000000..9f6ce66932 --- /dev/null +++ b/pluginframework/common/go_to_tf.go @@ -0,0 +1,207 @@ +package pluginframework + +import ( + "context" + "fmt" + "reflect" + + "github.com/databricks/databricks-sdk-go/logger" + "github.com/databricks/terraform-provider-databricks/common" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +// Converts a go-sdk struct into a tfsdk struct. +func GoSdkToTfSdkStruct(gosdk interface{}, tfsdk interface{}, ctx context.Context) diag.Diagnostics { + + srcVal := reflect.ValueOf(gosdk) + destVal := reflect.ValueOf(tfsdk) + + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if destVal.Kind() != reflect.Ptr { + return diag.Diagnostics{diag.NewErrorDiagnostic("please provide a pointer for the tfsdk struct", "tfsdk to gosdk struct conversion failure")} + } + destVal = destVal.Elem() + + if srcVal.Kind() != reflect.Struct || destVal.Kind() != reflect.Struct { + return diag.Diagnostics{diag.NewErrorDiagnostic(fmt.Sprintf("input should be structs %s, %s", srcVal.Type().Name(), destVal.Type().Name()), "tfsdk to gosdk struct conversion failure")} + } + + var forceSendFieldVal []string + + forceSendField := srcVal.FieldByName("ForceSendFields") + if !forceSendField.IsValid() { + // If no forceSendField, just use an empty list. + forceSendFieldVal = []string{} + } else { + switch forceSendField.Interface().(type) { + case []string: + default: + panic(fmt.Errorf("ForceSendField is not of type []string")) + } + forceSendFieldVal = forceSendField.Interface().([]string) + } + + for _, field := range common.ListAllFields(srcVal) { + srcField := field.V + srcFieldName := field.Sf.Name + + destField := destVal.FieldByName(srcFieldName) + srcFieldTag := field.Sf.Tag.Get("json") + if srcFieldTag == "-" { + continue + } + err := goSdkToTfSdkSingleField(srcField, destField, srcFieldName, forceSendFieldVal, false, ctx) + if err != nil { + return diag.Diagnostics{diag.NewErrorDiagnostic(err.Error(), "gosdk to tfsdk field conversion failure")} + } + } + return nil +} + +func goSdkToTfSdkSingleField(srcField reflect.Value, destField reflect.Value, srcFieldName string, forceSendField []string, alwaysAdd bool, ctx context.Context) error { + + if !destField.IsValid() { + // Skip field that destination struct does not have. + logger.Tracef(ctx, fmt.Sprintf("field skipped in gosdk to tfsdk conversion: destination struct does not have field %s", srcFieldName)) + return nil + } + + if !destField.CanSet() { + panic(fmt.Errorf("destination field can not be set: %s", destField.Type().Name())) + } + + srcFieldValue := srcField.Interface() + + if srcFieldValue == nil { + return nil + } + + switch srcField.Kind() { + case reflect.Ptr: + if srcField.IsNil() { + // Skip nils + return nil + } + destField.Set(reflect.New(destField.Type().Elem())) + + // Recursively populate the nested struct. + if diags := GoSdkToTfSdkStruct(srcFieldValue, destField.Interface(), ctx); diags.ErrorsCount() > 0 { + panic("Error converting gosdk to tfsdk struct") + } + case reflect.Bool: + boolVal := srcFieldValue.(bool) + // check if alwaysAdd is false or the value is zero or if the field is in the forceSendFields list + if alwaysAdd || !(!boolVal && !checkTheStringInForceSendFields(srcFieldName, forceSendField)) { + destField.Set(reflect.ValueOf(types.BoolValue(boolVal))) + } else { + destField.Set(reflect.ValueOf(types.BoolNull())) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // convert any kind of integer to int64 + intVal := srcField.Convert(reflect.TypeOf(int64(0))).Interface().(int64) + // check if alwaysAdd is true or the value is zero or if the field is in the forceSendFields list + if alwaysAdd || !(intVal == 0 && !checkTheStringInForceSendFields(srcFieldName, forceSendField)) { + destField.Set(reflect.ValueOf(types.Int64Value(int64(intVal)))) + } else { + destField.Set(reflect.ValueOf(types.Int64Null())) + } + case reflect.Float32, reflect.Float64: + // convert any kind of float to float64 + float64Val := srcField.Convert(reflect.TypeOf(float64(0))).Interface().(float64) + // check if alwaysAdd is true or the value is zero or if the field is in the forceSendFields list + if alwaysAdd || !(float64Val == 0 && !checkTheStringInForceSendFields(srcFieldName, forceSendField)) { + destField.Set(reflect.ValueOf(types.Float64Value(float64Val))) + } else { + destField.Set(reflect.ValueOf(types.Float64Null())) + } + case reflect.String: + var strVal string + if srcField.Type().Name() != "string" { + // This case is for Enum Types. + var stringMethod reflect.Value + if srcField.CanAddr() { + stringMethod = srcField.Addr().MethodByName("String") + } else { + // Create a new addressable variable to call the String method + addr := reflect.New(srcField.Type()).Elem() + addr.Set(srcField) + stringMethod = addr.Addr().MethodByName("String") + } + if stringMethod.IsValid() { + stringResult := stringMethod.Call(nil) + if len(stringResult) == 1 { + strVal = stringResult[0].Interface().(string) + } else { + panic("num get string has more than one result") + } + } else { + panic("enum does not have valid .String() method") + } + } else { + strVal = srcFieldValue.(string) + } + // check if alwaysAdd is false or the value is zero or if the field is in the forceSendFields list + if alwaysAdd || !(strVal == "" && !checkTheStringInForceSendFields(srcFieldName, forceSendField)) { + destField.Set(reflect.ValueOf(types.StringValue(strVal))) + } else { + destField.Set(reflect.ValueOf(types.StringNull())) + } + case reflect.Struct: + if srcField.IsZero() { + // Skip zeros + return nil + } + // resolve the nested struct by recursively calling the function + if GoSdkToTfSdkStruct(srcFieldValue, destField.Addr().Interface(), ctx).ErrorsCount() > 0 { + panic("Error converting gosdk to tfsdk struct") + } + case reflect.Slice: + if srcField.IsNil() { + // Skip nils + return nil + } + destSlice := reflect.MakeSlice(destField.Type(), srcField.Len(), srcField.Cap()) + for j := 0; j < srcField.Len(); j++ { + + srcElem := srcField.Index(j) + + destElem := destSlice.Index(j) + if err := goSdkToTfSdkSingleField(srcElem, destElem, "", nil, true, ctx); err != nil { + return err + } + } + destField.Set(destSlice) + case reflect.Map: + if srcField.IsNil() { + // Skip nils + return nil + } + destMap := reflect.MakeMap(destField.Type()) + for _, key := range srcField.MapKeys() { + srcMapValue := srcField.MapIndex(key) + destMapValue := reflect.New(destField.Type().Elem()).Elem() + destMapKey := reflect.ValueOf(key.Interface()) + if err := goSdkToTfSdkSingleField(srcMapValue, destMapValue, "", nil, true, ctx); err != nil { + return err + } + destMap.SetMapIndex(destMapKey, destMapValue) + } + destField.Set(destMap) + default: + panic(fmt.Errorf("unknown type for field: %s", srcField.Type().Name())) + } + return nil +} + +func checkTheStringInForceSendFields(fieldName string, forceSendFields []string) bool { + for _, field := range forceSendFields { + if field == fieldName { + return true + } + } + return false +} diff --git a/pluginframework/common/reflect_resources.go b/pluginframework/common/reflect_resources.go new file mode 100644 index 0000000000..2db3fbd0ef --- /dev/null +++ b/pluginframework/common/reflect_resources.go @@ -0,0 +1,149 @@ +package pluginframework + +import ( + "fmt" + "reflect" + "strings" + + "github.com/databricks/terraform-provider-databricks/common" + dataschema "github.com/hashicorp/terraform-plugin-framework/datasource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +func pluginFrameworkTypeToSchema(v reflect.Value) map[string]Attribute { + scm := map[string]Attribute{} + rk := v.Kind() + if rk == reflect.Ptr { + v = v.Elem() + rk = v.Kind() + } + if rk != reflect.Struct { + panic(fmt.Errorf("Schema value of Struct is expected, but got %s: %#v", rk.String(), v)) + } + fields := common.ListAllFields(v) + for _, field := range fields { + typeField := field.Sf + fieldName := typeField.Tag.Get("tfsdk") + if fieldName == "-" { + continue + } + isOptional := fieldIsOptional(typeField) + // For now everything is marked as optional. TODO: add tf annotations for computed, optional, etc. + kind := typeField.Type.Kind() + if kind == reflect.Ptr { + elem := typeField.Type.Elem() + sv := reflect.New(elem).Elem() + nestedScm := pluginFrameworkTypeToSchema(sv) + scm[fieldName] = SingleNestedAttribute{Attributes: nestedScm, Optional: isOptional, Required: !isOptional} + } else if kind == reflect.Slice { + elem := typeField.Type.Elem() + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if elem.Kind() != reflect.Struct { + panic(fmt.Errorf("unsupported slice value for %s: %s", fieldName, elem.Kind().String())) + } + switch elem { + case reflect.TypeOf(types.Bool{}): + scm[fieldName] = ListAttribute{ElementType: types.BoolType, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.Int64{}): + scm[fieldName] = ListAttribute{ElementType: types.Int64Type, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.Float64{}): + scm[fieldName] = ListAttribute{ElementType: types.Float64Type, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.String{}): + scm[fieldName] = ListAttribute{ElementType: types.StringType, Optional: isOptional, Required: !isOptional} + default: + // Nested struct + nestedScm := pluginFrameworkTypeToSchema(reflect.New(elem).Elem()) + scm[fieldName] = ListNestedAttribute{NestedObject: NestedAttributeObject{Attributes: nestedScm}, Optional: isOptional, Required: !isOptional} + } + } else if kind == reflect.Map { + elem := typeField.Type.Elem() + if elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + if elem.Kind() != reflect.Struct { + panic(fmt.Errorf("unsupported map value for %s: %s", fieldName, elem.Kind().String())) + } + switch elem { + case reflect.TypeOf(types.Bool{}): + scm[fieldName] = MapAttribute{ElementType: types.BoolType, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.Int64{}): + scm[fieldName] = MapAttribute{ElementType: types.Int64Type, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.Float64{}): + scm[fieldName] = MapAttribute{ElementType: types.Float64Type, Optional: isOptional, Required: !isOptional} + case reflect.TypeOf(types.String{}): + scm[fieldName] = MapAttribute{ElementType: types.StringType, Optional: isOptional, Required: !isOptional} + default: + // Nested struct + nestedScm := pluginFrameworkTypeToSchema(reflect.New(elem).Elem()) + scm[fieldName] = MapNestedAttribute{NestedObject: NestedAttributeObject{Attributes: nestedScm}, Optional: isOptional, Required: !isOptional} + } + } else if kind == reflect.Struct { + switch field.V.Interface().(type) { + case types.Bool: + scm[fieldName] = BoolAttribute{Optional: isOptional, Required: !isOptional} + case types.Int64: + scm[fieldName] = Int64Attribute{Optional: isOptional, Required: !isOptional} + case types.Float64: + scm[fieldName] = Float64Attribute{Optional: isOptional, Required: !isOptional} + case types.String: + scm[fieldName] = StringAttribute{Optional: isOptional, Required: !isOptional} + case types.List: + panic("types.List should never be used in tfsdk structs") + case types.Map: + panic("types.Map should never be used in tfsdk structs") + default: + // If it is a real stuct instead of a tfsdk type, recursively resolve it. + elem := typeField.Type + sv := reflect.New(elem) + nestedScm := pluginFrameworkTypeToSchema(sv) + scm[fieldName] = SingleNestedAttribute{Attributes: nestedScm, Optional: isOptional, Required: !isOptional} + } + } else if kind == reflect.String { + // This case is for Enum types. + scm[fieldName] = StringAttribute{Optional: isOptional, Required: !isOptional} + } else { + panic(fmt.Errorf("unknown type for field: %s", typeField.Name)) + } + } + return scm +} + +func fieldIsOptional(field reflect.StructField) bool { + tagValue := field.Tag.Get("tf") + return strings.Contains(tagValue, "optional") +} + +func PluginFrameworkResourceStructToSchema(v any, customizeSchema func(CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework) schema.Schema { + attributes := PluginFrameworkResourceStructToSchemaMap(v, customizeSchema) + return schema.Schema{Attributes: attributes} +} + +func PluginFrameworkDataSourceStructToSchema(v any, customizeSchema func(CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework) dataschema.Schema { + attributes := PluginFrameworkDataSourceStructToSchemaMap(v, customizeSchema) + return dataschema.Schema{Attributes: attributes} +} + +func PluginFrameworkResourceStructToSchemaMap(v any, customizeSchema func(CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework) map[string]schema.Attribute { + attributes := pluginFrameworkTypeToSchema(reflect.ValueOf(v)) + + if customizeSchema != nil { + cs := customizeSchema(*ConstructCustomizableSchema(attributes)) + return ToResourceAttributeMap(cs.ToAttributeMap()) + } else { + return ToResourceAttributeMap(attributes) + } +} + +func PluginFrameworkDataSourceStructToSchemaMap(v any, customizeSchema func(CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework) map[string]dataschema.Attribute { + attributes := pluginFrameworkTypeToSchema(reflect.ValueOf(v)) + + if customizeSchema != nil { + cs := customizeSchema(*ConstructCustomizableSchema(attributes)) + return ToDataSourceAttributeMap(cs.ToAttributeMap()) + } else { + return ToDataSourceAttributeMap(attributes) + } +} diff --git a/pluginframework/common/tf_to_go.go b/pluginframework/common/tf_to_go.go new file mode 100644 index 0000000000..420d956808 --- /dev/null +++ b/pluginframework/common/tf_to_go.go @@ -0,0 +1,214 @@ +package pluginframework + +import ( + "context" + "fmt" + "log" + "reflect" + + "github.com/databricks/databricks-sdk-go/logger" + "github.com/databricks/terraform-provider-databricks/common" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/types" +) + +// Converts a tfsdk struct into a go-sdk struct, with the folowing rules. +// +// types.String -> string +// types.Bool -> bool +// types.Int64 -> int64 +// types.Float64 -> float64 +// types.String -> string +// +// NOTE: +// +// ForceSendFields are populated for string, bool, int64, float64 on non null values +// types.list and types.map are not supported +// map keys should always be a string +// tfsdk structs use types.String for all enum values +func TfSdkToGoSdkStruct(tfsdk interface{}, gosdk interface{}, ctx context.Context) diag.Diagnostics { + srcVal := reflect.ValueOf(tfsdk) + destVal := reflect.ValueOf(gosdk) + + if srcVal.Kind() == reflect.Ptr { + srcVal = srcVal.Elem() + } + + if destVal.Kind() != reflect.Ptr { + return diag.Diagnostics{diag.NewErrorDiagnostic(fmt.Sprintf("please provide a pointer for the gosdk struct, got %s", destVal.Type().Name()), "tfsdk to gosdk struct conversion failure")} + } + destVal = destVal.Elem() + + if srcVal.Kind() != reflect.Struct { + return diag.Diagnostics{diag.NewErrorDiagnostic(fmt.Sprintf("input should be structs, got %s,", srcVal.Type().Name()), "tfsdk to gosdk struct conversion failure")} + } + + forceSendFieldsField := destVal.FieldByName("ForceSendFields") + + allFields := common.ListAllFields(srcVal) + for _, field := range allFields { + srcField := field.V + srcFieldName := field.Sf.Name + + srcFieldTag := field.Sf.Tag.Get("tfsdk") + if srcFieldTag == "-" { + continue + } + + destField := destVal.FieldByName(srcFieldName) + + err := tfSdkToGoSdkSingleField(srcField, destField, srcFieldName, &forceSendFieldsField, ctx) + if err != nil { + return diag.Diagnostics{diag.NewErrorDiagnostic(err.Error(), "tfsdk to gosdk field conversion failure")} + } + } + + return diag.Diagnostics{} +} + +func tfSdkToGoSdkSingleField(srcField reflect.Value, destField reflect.Value, srcFieldName string, forceSendFieldsField *reflect.Value, ctx context.Context) error { + + if !destField.IsValid() { + // Skip field that destination struct does not have. + logger.Tracef(ctx, fmt.Sprintf("field skipped in tfsdk to gosdk conversion: destination struct does not have field %s", srcFieldName)) + return nil + } + + if !destField.CanSet() { + panic(fmt.Errorf("destination field can not be set: %s", destField.Type().Name())) + } + srcFieldValue := srcField.Interface() + + if srcFieldValue == nil { + return nil + } else if srcField.Kind() == reflect.Ptr { + if srcField.IsNil() { + // Skip nils + return nil + } + // Allocate new memory for the destination field + destField.Set(reflect.New(destField.Type().Elem())) + + // Recursively populate the nested struct. + if TfSdkToGoSdkStruct(srcFieldValue, destField.Interface(), ctx).ErrorsCount() > 0 { + panic("Error converting tfsdk to gosdk struct") + } + } else if srcField.Kind() == reflect.Struct { + switch v := srcFieldValue.(type) { + case types.Bool: + destField.SetBool(v.ValueBool()) + if !v.IsNull() { + addToForceSendFields(srcFieldName, forceSendFieldsField) + } + case types.Int64: + destField.SetInt(v.ValueInt64()) + if !v.IsNull() { + addToForceSendFields(srcFieldName, forceSendFieldsField) + } + case types.Float64: + destField.SetFloat(v.ValueFloat64()) + if !v.IsNull() { + addToForceSendFields(srcFieldName, forceSendFieldsField) + } + case types.String: + if destField.Type().Name() != "string" { + // This is the case for enum. + + // Skip unset value. + if srcField.IsZero() { + return nil + } + + // Find the Set method + destVal := reflect.New(destField.Type()) + setMethod := destVal.MethodByName("Set") + if !setMethod.IsValid() { + panic(fmt.Sprintf("set method not found on enum type: %s", destField.Type().Name())) + } + + // Prepare the argument for the Set method + arg := reflect.ValueOf(v.ValueString()) + + // Call the Set method + result := setMethod.Call([]reflect.Value{arg}) + if len(result) != 0 { + if err, ok := result[0].Interface().(error); ok && err != nil { + panic(err) + } + } + // We don't need to set ForceSendFields for enums because the value is never going to be a zero value (empty string). + destField.Set(destVal.Elem()) + } else { + destField.SetString(v.ValueString()) + if !v.IsNull() { + addToForceSendFields(srcFieldName, forceSendFieldsField) + } + } + case types.List: + panic("types.List should never be used, use go native slices instead") + case types.Map: + panic("types.Map should never be used, use go native maps instead") + default: + if srcField.IsZero() { + // Skip zeros + return nil + } + // If it is a real stuct instead of a tfsdk type, recursively resolve it. + if TfSdkToGoSdkStruct(srcFieldValue, destField.Addr().Interface(), ctx).ErrorsCount() > 0 { + panic("Error converting tfsdk to gosdk struct") + } + } + } else if srcField.Kind() == reflect.Slice { + if srcField.IsNil() { + // Skip nils + return nil + } + destSlice := reflect.MakeSlice(destField.Type(), srcField.Len(), srcField.Cap()) + for j := 0; j < srcField.Len(); j++ { + + srcElem := srcField.Index(j) + + destElem := destSlice.Index(j) + if err := tfSdkToGoSdkSingleField(srcElem, destElem, "", nil, ctx); err != nil { + return err + } + } + destField.Set(destSlice) + } else if srcField.Kind() == reflect.Map { + if srcField.IsNil() { + // Skip nils + return nil + } + destMap := reflect.MakeMap(destField.Type()) + for _, key := range srcField.MapKeys() { + srcMapValue := srcField.MapIndex(key) + destMapValue := reflect.New(destField.Type().Elem()).Elem() + destMapKey := reflect.ValueOf(key.Interface()) + if err := tfSdkToGoSdkSingleField(srcMapValue, destMapValue, "", nil, ctx); err != nil { + return err + } + destMap.SetMapIndex(destMapKey, destMapValue) + } + destField.Set(destMap) + } else { + panic(fmt.Errorf("unknown type for field: %s", srcField.Type().Name())) + } + return nil + +} + +func addToForceSendFields(fieldName string, forceSendFieldsField *reflect.Value) { + if forceSendFieldsField == nil || !forceSendFieldsField.IsValid() || !forceSendFieldsField.CanSet() { + log.Printf("[Debug] forceSendFieldsField is nil, invalid or not settable. %s", fieldName) + return + } + // Initialize forceSendFields if it is a zero Value + if forceSendFieldsField.IsZero() { + // Create a new slice of strings + newSlice := []string{} + forceSendFieldsField.Set(reflect.ValueOf(&newSlice).Elem()) + } + forceSendFields := forceSendFieldsField.Interface().([]string) + forceSendFields = append(forceSendFields, fieldName) + forceSendFieldsField.Set(reflect.ValueOf(forceSendFields)) +}