Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make reflect_resource and customizable_schema work for data source #3880

Open
wants to merge 5 commits into
base: terraform-plugin-framework
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 24 additions & 245 deletions common/customizable_schema_plugin_framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,30 @@ package common

import (
"fmt"
"reflect"

"github.com/hashicorp/terraform-plugin-framework/resource/schema"
)

type CustomizableSchemaPluginFramework struct {
attr schema.Attribute
attr Attribute
}

func ConstructCustomizableSchema(attributes map[string]schema.Attribute) *CustomizableSchemaPluginFramework {
attr := schema.Attribute(schema.SingleNestedAttribute{Attributes: attributes})
func ConstructCustomizableSchema(attributes map[string]Attribute) *CustomizableSchemaPluginFramework {
attr := Attribute(SingleNestedAttribute{Attributes: attributes})
return &CustomizableSchemaPluginFramework{attr: attr}
}

// Converts CustomizableSchema into a map from string to Attribute.
func (s *CustomizableSchemaPluginFramework) ToAttributeMap() map[string]schema.Attribute {
func (s *CustomizableSchemaPluginFramework) ToAttributeMap() map[string]Attribute {
return attributeToMap(&s.attr)
}

func attributeToMap(attr *schema.Attribute) map[string]schema.Attribute {
var m map[string]schema.Attribute
func attributeToMap(attr *Attribute) map[string]Attribute {
var m map[string]Attribute
switch attr := (*attr).(type) {
case schema.SingleNestedAttribute:
case SingleNestedAttribute:
m = attr.Attributes
case schema.ListNestedAttribute:
case ListNestedAttribute:
m = attr.NestedObject.Attributes
case schema.MapNestedAttribute:
case MapNestedAttribute:
m = attr.NestedObject.Attributes
default:
panic(fmt.Errorf("cannot convert to map, attribute is not nested"))
Expand All @@ -38,37 +35,8 @@ func attributeToMap(attr *schema.Attribute) map[string]schema.Attribute {
}

func (s *CustomizableSchemaPluginFramework) AddValidator(v any, path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
val := reflect.ValueOf(attr)

// Make a copy of the existing attr.
newAttr := reflect.New(val.Type()).Elem()
newAttr.Set(val)
val = newAttr

field := val.FieldByName("Validators")
if !field.IsValid() {
panic(fmt.Sprintf("Validators field not found in %T", attr))
}
if field.Kind() != reflect.Slice {
panic(fmt.Sprintf("Validators field is not a slice in %T", attr))
}
if !field.CanSet() {
panic(fmt.Sprintf("Validators field cannot be set in %T", attr))
}

elemType := field.Type().Elem()
value := reflect.ValueOf(v)

if !value.Type().AssignableTo(elemType) {
panic(fmt.Sprintf("Value of type %T is not assignable to slice of %s", v, elemType))
}

// Append the value
newSlice := reflect.Append(field, value)
field.Set(newSlice)

return val.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idea: expose type-specific methods here instead of any?

return attr.AddValidators(v)
}

navigateSchemaWithCallback(&s.attr, cb, path...)
Expand All @@ -77,38 +45,8 @@ func (s *CustomizableSchemaPluginFramework) AddValidator(v any, path ...string)
}

func (s *CustomizableSchemaPluginFramework) SetOptional(path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("Required")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(false)
} else {
panic(fmt.Sprintf("Required is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Required field not found or cannot be set in %T", attr))
}

field = v.FieldByName("Optional")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(true)
} else {
panic(fmt.Sprintf("Optional is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Optional field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetOptional()
}

navigateSchemaWithCallback(&s.attr, cb, path...)
Expand All @@ -117,38 +55,8 @@ func (s *CustomizableSchemaPluginFramework) SetOptional(path ...string) *Customi
}

func (s *CustomizableSchemaPluginFramework) SetRequired(path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("Required")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(true)
} else {
panic(fmt.Sprintf("Required is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Required field not found or cannot be set in %T", attr))
}

field = v.FieldByName("Optional")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(false)
} else {
panic(fmt.Sprintf("Optional is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Optional field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetRequired()
}

navigateSchemaWithCallback(&s.attr, cb, path...)
Expand All @@ -157,55 +65,17 @@ func (s *CustomizableSchemaPluginFramework) SetRequired(path ...string) *Customi
}

func (s *CustomizableSchemaPluginFramework) SetSensitive(path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("Sensitive")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(true)
} else {
panic(fmt.Sprintf("Sensitive is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Sensitive field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetSensitive()
}

navigateSchemaWithCallback(&s.attr, cb, path...)
return s
}

func (s *CustomizableSchemaPluginFramework) SetDeprecated(msg string, path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("DeprecationMessage")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.String {
field.SetString(msg)
} else {
panic(fmt.Sprintf("DeprecationMessage is not a string field in %T", attr))
}
} else {
panic(fmt.Sprintf("DeprecationMessage field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetDeprecated(msg)
}

navigateSchemaWithCallback(&s.attr, cb, path...)
Expand All @@ -214,27 +84,8 @@ func (s *CustomizableSchemaPluginFramework) SetDeprecated(msg string, path ...st
}

func (s *CustomizableSchemaPluginFramework) SetComputed(path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("Computed")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(true)
} else {
panic(fmt.Sprintf("Computed is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Computed field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetComputed()
}

navigateSchemaWithCallback(&s.attr, cb, path...)
Expand All @@ -245,89 +96,17 @@ func (s *CustomizableSchemaPluginFramework) SetComputed(path ...string) *Customi
// This should be used for fields that are not user-configurable but are returned
// by the platform.
func (s *CustomizableSchemaPluginFramework) SetReadOnly(path ...string) *CustomizableSchemaPluginFramework {
cb := func(attr schema.Attribute) schema.Attribute {
// Get the concrete value stored in the interface
v := reflect.ValueOf(attr)

// Make a new addressable value and copy the original value into it
newAttr := reflect.New(v.Type()).Elem()
newAttr.Set(v)
v = newAttr

field := v.FieldByName("Computed")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(true)
} else {
panic(fmt.Sprintf("Computed is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Computed field not found or cannot be set in %T", attr))
}

field = v.FieldByName("Optional")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(false)
} else {
panic(fmt.Sprintf("Optional is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Optional field not found or cannot be set in %T", attr))
}

field = v.FieldByName("Required")
if field.IsValid() && field.CanSet() {
if field.Kind() == reflect.Bool {
field.SetBool(false)
} else {
panic(fmt.Sprintf("Required is not a bool field in %T", attr))
}
} else {
panic(fmt.Sprintf("Required field not found or cannot be set in %T", attr))
}

return v.Interface().(schema.Attribute)
cb := func(attr Attribute) Attribute {
return attr.SetReadOnly()
}

navigateSchemaWithCallback(&s.attr, cb, path...)

return s
}

// Given a attribute map, navigate through the given path, panics if the path is not valid.
func MustSchemaAttributePath(attrs map[string]schema.Attribute, path ...string) schema.Attribute {
attr := ConstructCustomizableSchema(attrs).attr

res, err := navigateSchema(&attr, path...)
if err != nil {
panic(err)
}

return res
}

// Helper function for navigating through schema attributes, panics if path does not exist or invalid.
func navigateSchema(s *schema.Attribute, path ...string) (schema.Attribute, error) {
current_scm := s
for i, p := range path {
m := attributeToMap(current_scm)

v, ok := m[p]
if !ok {
return nil, fmt.Errorf("missing key %s", p)
}

if i == len(path)-1 {
return v, nil
}
current_scm = &v
}
return nil, fmt.Errorf("path %v is incomplete", path)
}

// Helper function for navigating through schema attributes, panics if path does not exist or invalid.
func navigateSchemaWithCallback(s *schema.Attribute, cb func(schema.Attribute) schema.Attribute, path ...string) (schema.Attribute, error) {
func navigateSchemaWithCallback(s *Attribute, cb func(Attribute) Attribute, path ...string) (Attribute, error) {
current_scm := s
for i, p := range path {
m := attributeToMap(current_scm)
Expand Down
9 changes: 5 additions & 4 deletions common/customizable_schema_plugin_framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ func TestCustomizeSchema(t *testing.T) {
c.AddValidator(stringLengthBetweenValidator{}, "description")
return c
})
assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "enabled").IsRequired())
assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "name").IsSensitive())
assert.True(t, MustSchemaAttributePath(scm.Attributes, "map").GetDeprecationMessage() == "deprecated")

assert.True(t, scm.Attributes["nested"].(schema.SingleNestedAttribute).Attributes["enabled"].IsRequired())
assert.True(t, scm.Attributes["nested"].(schema.SingleNestedAttribute).Attributes["name"].IsSensitive())
assert.True(t, scm.Attributes["map"].GetDeprecationMessage() == "deprecated")
assert.True(t, scm.Attributes["description"].IsOptional())
assert.True(t, !scm.Attributes["map"].IsOptional())
assert.True(t, !scm.Attributes["map"].IsRequired())
assert.True(t, scm.Attributes["map"].IsComputed())
assert.True(t, len(MustSchemaAttributePath(scm.Attributes, "description").(schema.StringAttribute).Validators) == 1)
assert.True(t, len(scm.Attributes["description"].(schema.StringAttribute).Validators) == 1)
}
Loading
Loading