From 1562784b232500cd67a49554b36cb773e165334c Mon Sep 17 00:00:00 2001 From: Edward Feng <67326663+edwardfeng-db@users.noreply.github.com> Date: Sat, 27 Jul 2024 19:21:43 +0200 Subject: [PATCH] Make CustomizableSchema work for Plugin Framework (#3819) ## Changes - Added basic supports for CustomizableSchema - Able to add, remove fields and navigate through the schema - Feature support matrix: https://docs.google.com/document/d/1ofoLi9vlu-4L-89kBOYPbpOpPk0GZghZcC5SgnBH4ZM/edit#heading=h.wi3qt4pwk7q4 ## Tests - [x] `make test` run locally - [x] relevant change in `docs/` folder - [x] covered with integration tests in `internal/acceptance` - [x] relevant acceptance tests are passing - [x] using Go SDK --- .../customizable_schema_plugin_framework.go | 339 ++++++++++++++++++ ...stomizable_schema_plugin_framework_test.go | 71 ++++ common/reflect_resource_plugin_framework.go | 11 +- .../reflect_resource_plugin_framework_test.go | 2 +- 4 files changed, 419 insertions(+), 4 deletions(-) create mode 100644 common/customizable_schema_plugin_framework.go create mode 100644 common/customizable_schema_plugin_framework_test.go diff --git a/common/customizable_schema_plugin_framework.go b/common/customizable_schema_plugin_framework.go new file mode 100644 index 0000000000..ae0571b827 --- /dev/null +++ b/common/customizable_schema_plugin_framework.go @@ -0,0 +1,339 @@ +package common + +import ( + "fmt" + "reflect" + + "github.com/hashicorp/terraform-plugin-framework/provider/schema" +) + +type CustomizableSchemaPluginFramework struct { + attr schema.Attribute +} + +func ConstructCustomizableSchema(attributes map[string]schema.Attribute) *CustomizableSchemaPluginFramework { + attr := schema.Attribute(schema.SingleNestedAttribute{Attributes: attributes}) + return &CustomizableSchemaPluginFramework{attr: attr} +} + +// Converts CustomizableSchema into a map from string to Attribute. +func (s *CustomizableSchemaPluginFramework) ToAttributeMap() map[string]schema.Attribute { + return attributeToMap(&s.attr) +} + +func attributeToMap(attr *schema.Attribute) map[string]schema.Attribute { + var m map[string]schema.Attribute + switch attr := (*attr).(type) { + case schema.SingleNestedAttribute: + m = attr.Attributes + case schema.ListNestedAttribute: + m = attr.NestedObject.Attributes + case schema.MapNestedAttribute: + m = attr.NestedObject.Attributes + default: + panic(fmt.Errorf("cannot convert to map, attribute is not nested")) + } + + return m +} + +func (s *CustomizableSchemaPluginFramework) AddNewField(key string, newField schema.Attribute, path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + _, exists := attr.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.Attributes[key] = newField + return attr + case schema.ListNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.NestedObject.Attributes[key] = newField + return attr + case schema.MapNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if exists { + panic("Cannot add new field, " + key + " already exists in the schema") + } + attr.NestedObject.Attributes[key] = newField + return attr + default: + panic("attribute is not nested, cannot add field") + } + } + + if len(path) == 0 { + s.attr = cb(s.attr) + } else { + navigateSchemaWithCallback(&s.attr, cb, path...) + } + return s +} + +func (s *CustomizableSchemaPluginFramework) RemoveField(key string, path ...string) *CustomizableSchemaPluginFramework { + cb := func(a schema.Attribute) schema.Attribute { + switch attr := a.(type) { + case schema.SingleNestedAttribute: + _, exists := attr.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.Attributes, key) + return attr + case schema.ListNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.NestedObject.Attributes, key) + return attr + case schema.MapNestedAttribute: + _, exists := attr.NestedObject.Attributes[key] + if !exists { + panic("Cannot remove field, " + key + " does not exist in the schema") + } + delete(attr.NestedObject.Attributes, key) + return attr + default: + panic("attribute is not nested, cannot add field") + } + } + + if len(path) == 0 { + s.attr = cb(s.attr) + } else { + navigateSchemaWithCallback(&s.attr, cb, path...) + } + return s +} + +func (s *CustomizableSchemaPluginFramework) AddValidator(v any, path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr schema.Attribute) schema.Attribute { + val := reflect.ValueOf(attr) + + // Make a copy of the existing attr. + newAttr := reflect.New(val.Type()).Elem() + newAttr.Set(val) + val = newAttr + + field := val.FieldByName("Validators") + if !field.IsValid() { + panic(fmt.Sprintf("Validators field not found in %T", attr)) + } + if field.Kind() != reflect.Slice { + panic(fmt.Sprintf("Validators field is not a slice in %T", attr)) + } + if !field.CanSet() { + panic(fmt.Sprintf("Validators field cannot be set in %T", attr)) + } + + elemType := field.Type().Elem() + value := reflect.ValueOf(v) + + if !value.Type().AssignableTo(elemType) { + panic(fmt.Sprintf("Value of type %T is not assignable to slice of %s", v, elemType)) + } + + // Append the value + newSlice := reflect.Append(field, value) + field.Set(newSlice) + + return val.Interface().(schema.Attribute) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetOptional(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr schema.Attribute) schema.Attribute { + // Get the concrete value stored in the interface + v := reflect.ValueOf(attr) + + // Make a new addressable value and copy the original value into it + newAttr := reflect.New(v.Type()).Elem() + newAttr.Set(v) + v = newAttr + + field := v.FieldByName("Required") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Bool { + field.SetBool(false) + } else { + panic(fmt.Sprintf("Required is not a bool field in %T", attr)) + } + } else { + panic(fmt.Sprintf("Required field not found or cannot be set in %T", attr)) + } + + field = v.FieldByName("Optional") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Bool { + field.SetBool(true) + } else { + panic(fmt.Sprintf("Optional is not a bool field in %T", attr)) + } + } else { + panic(fmt.Sprintf("Optional field not found or cannot be set in %T", attr)) + } + + return v.Interface().(schema.Attribute) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetRequired(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr schema.Attribute) schema.Attribute { + // Get the concrete value stored in the interface + v := reflect.ValueOf(attr) + + // Make a new addressable value and copy the original value into it + newAttr := reflect.New(v.Type()).Elem() + newAttr.Set(v) + v = newAttr + + field := v.FieldByName("Required") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Bool { + field.SetBool(true) + } else { + panic(fmt.Sprintf("Required is not a bool field in %T", attr)) + } + } else { + panic(fmt.Sprintf("Required field not found or cannot be set in %T", attr)) + } + + field = v.FieldByName("Optional") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Bool { + field.SetBool(false) + } else { + panic(fmt.Sprintf("Optional is not a bool field in %T", attr)) + } + } else { + panic(fmt.Sprintf("Optional field not found or cannot be set in %T", attr)) + } + + return v.Interface().(schema.Attribute) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +func (s *CustomizableSchemaPluginFramework) SetSensitive(path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr schema.Attribute) schema.Attribute { + // Get the concrete value stored in the interface + v := reflect.ValueOf(attr) + + // Make a new addressable value and copy the original value into it + newAttr := reflect.New(v.Type()).Elem() + newAttr.Set(v) + v = newAttr + + field := v.FieldByName("Sensitive") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.Bool { + field.SetBool(true) + } else { + panic(fmt.Sprintf("Sensitive is not a bool field in %T", attr)) + } + } else { + panic(fmt.Sprintf("Sensitive field not found or cannot be set in %T", attr)) + } + + return v.Interface().(schema.Attribute) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + return s +} + +func (s *CustomizableSchemaPluginFramework) SetDeprecated(msg string, path ...string) *CustomizableSchemaPluginFramework { + cb := func(attr schema.Attribute) schema.Attribute { + // Get the concrete value stored in the interface + v := reflect.ValueOf(attr) + + // Make a new addressable value and copy the original value into it + newAttr := reflect.New(v.Type()).Elem() + newAttr.Set(v) + v = newAttr + + field := v.FieldByName("DeprecationMessage") + if field.IsValid() && field.CanSet() { + if field.Kind() == reflect.String { + field.SetString(msg) + } else { + panic(fmt.Sprintf("DeprecationMessage is not a string field in %T", attr)) + } + } else { + panic(fmt.Sprintf("DeprecationMessage field not found or cannot be set in %T", attr)) + } + + return v.Interface().(schema.Attribute) + } + + navigateSchemaWithCallback(&s.attr, cb, path...) + + return s +} + +// Given a attribute map, navigate through the given path, panics if the path is not valid. +func MustSchemaAttributePath(attrs map[string]schema.Attribute, path ...string) schema.Attribute { + attr := ConstructCustomizableSchema(attrs).attr + + res, err := navigateSchema(&attr, path...) + if err != nil { + panic(err) + } + + return res +} + +// Helper function for navigating through schema attributes, panics if path does not exist or invalid. +func navigateSchema(s *schema.Attribute, path ...string) (schema.Attribute, error) { + current_scm := s + for i, p := range path { + m := attributeToMap(current_scm) + + v, ok := m[p] + if !ok { + return nil, fmt.Errorf("missing key %s", p) + } + + if i == len(path)-1 { + return v, nil + } + current_scm = &v + } + return nil, fmt.Errorf("path %v is incomplete", path) +} + +// Helper function for navigating through schema attributes, panics if path does not exist or invalid. +func navigateSchemaWithCallback(s *schema.Attribute, cb func(schema.Attribute) schema.Attribute, path ...string) (schema.Attribute, error) { + current_scm := s + for i, p := range path { + m := attributeToMap(current_scm) + + v, ok := m[p] + if !ok { + return nil, fmt.Errorf("missing key %s", p) + } + + if i == len(path)-1 { + m[p] = cb(v) + return m[p], nil + } + current_scm = &v + } + return nil, fmt.Errorf("path %v is incomplete", path) +} diff --git a/common/customizable_schema_plugin_framework_test.go b/common/customizable_schema_plugin_framework_test.go new file mode 100644 index 0000000000..3d0b31b0e9 --- /dev/null +++ b/common/customizable_schema_plugin_framework_test.go @@ -0,0 +1,71 @@ +package common + +import ( + "context" + "fmt" + "testing" + + "github.com/hashicorp/terraform-plugin-framework/provider/schema" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/stretchr/testify/assert" +) + +type stringLengthBetweenValidator struct { + Max int + Min int +} + +// Description returns a plain text description of the validator's behavior, suitable for a practitioner to understand its impact. +func (v stringLengthBetweenValidator) Description(ctx context.Context) string { + return fmt.Sprintf("string length must be between %d and %d", v.Min, v.Max) +} + +// MarkdownDescription returns a markdown formatted description of the validator's behavior, suitable for a practitioner to understand its impact. +func (v stringLengthBetweenValidator) MarkdownDescription(ctx context.Context) string { + return fmt.Sprintf("string length must be between `%d` and `%d`", v.Min, v.Max) +} + +// Validate runs the main validation logic of the validator, reading configuration data out of `req` and updating `resp` with diagnostics. +func (v stringLengthBetweenValidator) ValidateString(ctx context.Context, req validator.StringRequest, resp *validator.StringResponse) { + // If the value is unknown or null, there is nothing to validate. + if req.ConfigValue.IsUnknown() || req.ConfigValue.IsNull() { + return + } + + strLen := len(req.ConfigValue.ValueString()) + + if strLen < v.Min || strLen > v.Max { + resp.Diagnostics.AddAttributeError( + req.Path, + "Invalid String Length", + fmt.Sprintf("String length must be between %d and %d, got: %d.", v.Min, v.Max, strLen), + ) + + return + } +} + +func TestCustomizeSchema(t *testing.T) { + scm := pluginFrameworkStructToSchema(DummyTfSdk{}, func(c CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework { + c.AddNewField("new_field", schema.StringAttribute{Required: true}) + c.AddNewField("new_field", schema.StringAttribute{Required: true}, "nested") + c.AddNewField("to_be_removed", schema.StringAttribute{Required: true}, "nested") + c.RemoveField("to_be_removed", "nested") + c.SetRequired("nested", "enabled") + c.SetOptional("description") + c.SetSensitive("nested", "name") + c.SetDeprecated("deprecated", "map") + c.AddValidator(stringLengthBetweenValidator{}, "description") + return c + }) + assert.True(t, scm.Attributes["new_field"].IsRequired()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "new_field").IsRequired()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "enabled").IsRequired()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "nested", "name").IsSensitive()) + assert.True(t, MustSchemaAttributePath(scm.Attributes, "map").GetDeprecationMessage() == "deprecated") + assert.True(t, scm.Attributes["description"].IsOptional()) + attr := MustSchemaAttributePath(scm.Attributes, "nested").(schema.SingleNestedAttribute).Attributes + _, ok := attr["to_be_removed"] + assert.True(t, len(MustSchemaAttributePath(scm.Attributes, "description").(schema.StringAttribute).Validators) == 1) + assert.True(t, !ok) +} diff --git a/common/reflect_resource_plugin_framework.go b/common/reflect_resource_plugin_framework.go index ca364521cd..71b57e6227 100644 --- a/common/reflect_resource_plugin_framework.go +++ b/common/reflect_resource_plugin_framework.go @@ -402,8 +402,13 @@ func fieldIsOptional(field reflect.StructField) bool { return strings.Contains(tagValue, "optional") } -func pluginFrameworkStructToSchema(v any) schema.Schema { - return schema.Schema{ - Attributes: pluginFrameworkTypeToSchema(reflect.ValueOf(v)), +func pluginFrameworkStructToSchema(v any, customizeSchema func(CustomizableSchemaPluginFramework) CustomizableSchemaPluginFramework) schema.Schema { + attributes := pluginFrameworkTypeToSchema(reflect.ValueOf(v)) + + if customizeSchema != nil { + cs := customizeSchema(*ConstructCustomizableSchema(attributes)) + return schema.Schema{Attributes: cs.ToAttributeMap()} + } else { + return schema.Schema{Attributes: attributes} } } diff --git a/common/reflect_resource_plugin_framework_test.go b/common/reflect_resource_plugin_framework_test.go index 81ad1b33d6..88fc0a30cf 100644 --- a/common/reflect_resource_plugin_framework_test.go +++ b/common/reflect_resource_plugin_framework_test.go @@ -140,7 +140,7 @@ var tfSdkStruct = DummyTfSdk{ func TestGetAndSetPluginFramework(t *testing.T) { // Also test StructToSchema. - scm := pluginFrameworkStructToSchema(DummyTfSdk{}) + scm := pluginFrameworkStructToSchema(DummyTfSdk{}, nil) state := tfsdk.State{ Schema: scm, }