Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardfeng-db committed Aug 21, 2024
1 parent 0e9f500 commit fd472ac
Show file tree
Hide file tree
Showing 3 changed files with 410 additions and 0 deletions.
113 changes: 113 additions & 0 deletions internal/pluginframework/tfschema/customizable_schema_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package tfschema

import (
"context"
"fmt"
"testing"

"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/stretchr/testify/assert"
)

type TestTfSdk struct {
Description types.String `tfsdk:"description" tf:""`
Nested *NestedTfSdk `tfsdk:"nested" tf:"optional"`
Map map[string]types.String `tfsdk:"map" tf:"optional"`
}

type NestedTfSdk struct {
Name types.String `tfsdk:"name" tf:"optional"`
Enabled types.Bool `tfsdk:"enabled" tf:"optional"`
}

type stringLengthBetweenValidator struct {
Max int
Min int
}

// Description returns a plain text description of the validator's behavior, suitable for a practitioner to understand its impact.
func (v stringLengthBetweenValidator) Description(ctx context.Context) string {
return fmt.Sprintf("string length must be between %d and %d", v.Min, v.Max)
}

// MarkdownDescription returns a markdown formatted description of the validator's behavior, suitable for a practitioner to understand its impact.
func (v stringLengthBetweenValidator) MarkdownDescription(ctx context.Context) string {
return fmt.Sprintf("string length must be between `%d` and `%d`", v.Min, v.Max)
}

// Validate runs the main validation logic of the validator, reading configuration data out of `req` and updating `resp` with diagnostics.
func (v stringLengthBetweenValidator) ValidateString(ctx context.Context, req validator.StringRequest, resp *validator.StringResponse) {
// If the value is unknown or null, there is nothing to validate.
if req.ConfigValue.IsUnknown() || req.ConfigValue.IsNull() {
return
}

strLen := len(req.ConfigValue.ValueString())

if strLen < v.Min || strLen > v.Max {
resp.Diagnostics.AddAttributeError(
req.Path,
"Invalid String Length",
fmt.Sprintf("String length must be between %d and %d, got: %d.", v.Min, v.Max, strLen),
)

return
}
}

func TestCustomizeSchemaSetRequired(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.SetRequired("nested", "enabled")
return c
})

assert.True(t, scm.Attributes["nested"].(schema.SingleNestedAttribute).Attributes["enabled"].IsRequired())
}

func TestCustomizeSchemaSetOptional(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.SetOptional("description")
return c
})

assert.True(t, scm.Attributes["description"].IsOptional())
}

func TestCustomizeSchemaSetSensitive(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.SetSensitive("nested", "name")
return c
})

assert.True(t, scm.Attributes["nested"].(schema.SingleNestedAttribute).Attributes["name"].IsSensitive())
}

func TestCustomizeSchemaSetDeprecated(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.SetDeprecated("deprecated", "map")
return c
})

assert.True(t, scm.Attributes["map"].GetDeprecationMessage() == "deprecated")
}

func TestCustomizeSchemaSetReadOnly(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.SetReadOnly("map")
return c
})
assert.True(t, !scm.Attributes["map"].IsOptional())
assert.True(t, !scm.Attributes["map"].IsRequired())
assert.True(t, scm.Attributes["map"].IsComputed())
}

func TestCustomizeSchemaAddValidator(t *testing.T) {
scm := ResourceStructToSchema(TestTfSdk{}, func(c CustomizableSchema) CustomizableSchema {
c.AddValidator(stringLengthBetweenValidator{}, "description")
return c
})

assert.True(t, len(scm.Attributes["description"].(schema.StringAttribute).Validators) == 1)
}
145 changes: 145 additions & 0 deletions internal/pluginframework/tfschema/struct_to_schema.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package tfschema

import (
"fmt"
"reflect"
"strings"

"github.com/databricks/terraform-provider-databricks/internal/tfreflect"
dataschema "github.com/hashicorp/terraform-plugin-framework/datasource/schema"
"github.com/hashicorp/terraform-plugin-framework/resource/schema"
"github.com/hashicorp/terraform-plugin-framework/types"
)

func typeToSchema(v reflect.Value) map[string]AttributeBuilder {
scm := map[string]AttributeBuilder{}
rk := v.Kind()
if rk == reflect.Ptr {
v = v.Elem()
rk = v.Kind()
}
if rk != reflect.Struct {
panic(fmt.Errorf("schema value of Struct is expected, but got %s: %#v", rk.String(), v))
}
fields := tfreflect.ListAllFields(v)
for _, field := range fields {
typeField := field.StructField
fieldName := typeField.Tag.Get("tfsdk")
if fieldName == "-" {
continue
}
isOptional := fieldIsOptional(typeField)
kind := typeField.Type.Kind()
if kind == reflect.Ptr {
elem := typeField.Type.Elem()
sv := reflect.New(elem).Elem()
nestedScm := typeToSchema(sv)
scm[fieldName] = SingleNestedAttributeBuilder{Attributes: nestedScm, Optional: isOptional, Required: !isOptional}
} else if kind == reflect.Slice {
elemType := typeField.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() != reflect.Struct {
panic(fmt.Errorf("unsupported slice value for %s: %s", fieldName, elemType.Kind().String()))
}
switch elemType {
case reflect.TypeOf(types.Bool{}):
scm[fieldName] = ListAttributeBuilder{ElementType: types.BoolType, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.Int64{}):
scm[fieldName] = ListAttributeBuilder{ElementType: types.Int64Type, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.Float64{}):
scm[fieldName] = ListAttributeBuilder{ElementType: types.Float64Type, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.String{}):
scm[fieldName] = ListAttributeBuilder{ElementType: types.StringType, Optional: isOptional, Required: !isOptional}
default:
// Nested struct
nestedScm := typeToSchema(reflect.New(elemType).Elem())
scm[fieldName] = ListNestedAttributeBuilder{NestedObject: NestedAttributeObject{Attributes: nestedScm}, Optional: isOptional, Required: !isOptional}
}
} else if kind == reflect.Map {
elemType := typeField.Type.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() != reflect.Struct {
panic(fmt.Errorf("unsupported map value for %s: %s", fieldName, elemType.Kind().String()))
}
switch elemType {
case reflect.TypeOf(types.Bool{}):
scm[fieldName] = MapAttributeBuilder{ElementType: types.BoolType, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.Int64{}):
scm[fieldName] = MapAttributeBuilder{ElementType: types.Int64Type, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.Float64{}):
scm[fieldName] = MapAttributeBuilder{ElementType: types.Float64Type, Optional: isOptional, Required: !isOptional}
case reflect.TypeOf(types.String{}):
scm[fieldName] = MapAttributeBuilder{ElementType: types.StringType, Optional: isOptional, Required: !isOptional}
default:
// Nested struct
nestedScm := typeToSchema(reflect.New(elemType).Elem())
scm[fieldName] = MapNestedAttributeBuilder{NestedObject: NestedAttributeObject{Attributes: nestedScm}, Optional: isOptional, Required: !isOptional}
}
} else if kind == reflect.Struct {
switch field.Value.Interface().(type) {
case types.Bool:
scm[fieldName] = BoolAttributeBuilder{Optional: isOptional, Required: !isOptional}
case types.Int64:
scm[fieldName] = Int64AttributeBuilder{Optional: isOptional, Required: !isOptional}
case types.Float64:
scm[fieldName] = Float64AttributeBuilder{Optional: isOptional, Required: !isOptional}
case types.String:
scm[fieldName] = StringAttributeBuilder{Optional: isOptional, Required: !isOptional}
case types.List:
panic("types.List should never be used in tfsdk structs")
case types.Map:
panic("types.Map should never be used in tfsdk structs")
default:
// If it is a real stuct instead of a tfsdk type, recursively resolve it.
elem := typeField.Type
sv := reflect.New(elem)
nestedScm := typeToSchema(sv)
scm[fieldName] = SingleNestedAttributeBuilder{Attributes: nestedScm, Optional: isOptional, Required: !isOptional}
}
} else {
panic(fmt.Errorf("unknown type for field: %s", typeField.Name))
}
}
return scm
}

func fieldIsOptional(field reflect.StructField) bool {
tagValue := field.Tag.Get("tf")
return strings.Contains(tagValue, "optional")
}

func ResourceStructToSchema(v any, customizeSchema func(CustomizableSchema) CustomizableSchema) schema.Schema {
attributes := ResourceStructToSchemaMap(v, customizeSchema)
return schema.Schema{Attributes: attributes}
}

func DataSourceStructToSchema(v any, customizeSchema func(CustomizableSchema) CustomizableSchema) dataschema.Schema {
attributes := DataSourceStructToSchemaMap(v, customizeSchema)
return dataschema.Schema{Attributes: attributes}
}

func ResourceStructToSchemaMap(v any, customizeSchema func(CustomizableSchema) CustomizableSchema) map[string]schema.Attribute {
attributes := typeToSchema(reflect.ValueOf(v))

if customizeSchema != nil {
cs := customizeSchema(*ConstructCustomizableSchema(attributes))
return BuildResourceAttributeMap(cs.ToAttributeMap())
} else {
return BuildResourceAttributeMap(attributes)
}
}

func DataSourceStructToSchemaMap(v any, customizeSchema func(CustomizableSchema) CustomizableSchema) map[string]dataschema.Attribute {
attributes := typeToSchema(reflect.ValueOf(v))

if customizeSchema != nil {
cs := customizeSchema(*ConstructCustomizableSchema(attributes))
return BuildDataSourceAttributeMap(cs.ToAttributeMap())
} else {
return BuildDataSourceAttributeMap(attributes)
}
}
Loading

0 comments on commit fd472ac

Please sign in to comment.