diff --git a/internal/pluginframework/provider/provider_plugin_framework.go b/internal/pluginframework/provider/provider_plugin_framework.go
new file mode 100644
index 0000000000..7f04921189
--- /dev/null
+++ b/internal/pluginframework/provider/provider_plugin_framework.go
@@ -0,0 +1,165 @@
+package provider
+
+import (
+    "context"
+    "fmt"
+    "log"
+    "reflect"
+    "sort"
+    "strings"
+
+    "github.com/databricks/databricks-sdk-go/client"
+    "github.com/databricks/databricks-sdk-go/config"
+    "github.com/databricks/terraform-provider-databricks/commands"
+    "github.com/databricks/terraform-provider-databricks/common"
+
+    "github.com/hashicorp/terraform-plugin-framework/datasource"
+    "github.com/hashicorp/terraform-plugin-framework/diag"
+    "github.com/hashicorp/terraform-plugin-framework/path"
+    "github.com/hashicorp/terraform-plugin-framework/provider"
+    "github.com/hashicorp/terraform-plugin-framework/provider/schema"
+    "github.com/hashicorp/terraform-plugin-framework/resource"
+    "github.com/hashicorp/terraform-plugin-framework/types"
+    "github.com/hashicorp/terraform-plugin-log/tflog"
+)
+
+func GetProviderName() string {
+    return "databricks-tf-provider"
+}
+
+func GetDatabricksProviderPluginFramework() provider.Provider {
+    p := &DatabricksProviderPluginFramework{}
+    return p
+}
+
+type DatabricksProviderPluginFramework struct {
+}
+
+var _ provider.Provider = (*DatabricksProviderPluginFramework)(nil)
+
+func (p *DatabricksProviderPluginFramework) Resources(ctx context.Context) []func() resource.Resource {
+    return []func() resource.Resource{}
+}
+
+func (p *DatabricksProviderPluginFramework) DataSources(ctx context.Context) []func() datasource.DataSource {
+    return []func() datasource.DataSource{}
+}
+
+func (p *DatabricksProviderPluginFramework) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
+    resp.Schema = providerSchemaPluginFramework()
+}
+
+func (p *DatabricksProviderPluginFramework) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) {
+    resp.TypeName = GetProviderName()
+    resp.Version = common.Version()
+}
+
+func (p *DatabricksProviderPluginFramework) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
+    client := configureDatabricksClient_PluginFramework(ctx, req, resp)
+    resp.DataSourceData = client
+    resp.ResourceData = client
+}
+
+// providerSchemaPluginFramework returns a schema.Schema built from the SDK config attributes,
+// mapping each attribute to the appropriate schema type (BoolAttribute, StringAttribute, Int64Attribute).
+func providerSchemaPluginFramework() schema.Schema {
+    ps := map[string]schema.Attribute{}
+    for _, attr := range config.ConfigAttributes {
+        switch attr.Kind {
+        case reflect.Bool:
+            ps[attr.Name] = schema.BoolAttribute{
+                Optional:  true,
+                Sensitive: attr.Sensitive,
+            }
+        case reflect.String:
+            ps[attr.Name] = schema.StringAttribute{
+                Optional:  true,
+                Sensitive: attr.Sensitive,
+            }
+        case reflect.Int:
+            ps[attr.Name] = schema.Int64Attribute{
+                Optional:  true,
+                Sensitive: attr.Sensitive,
+            }
+        }
+    }
+    return schema.Schema{
+        Attributes: ps,
+    }
+}
+
+func configureDatabricksClient_PluginFramework(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) any {
+    cfg := &config.Config{}
+    attrsUsed := []string{}
+    for _, attr := range config.ConfigAttributes {
+        switch attr.Kind {
+        case reflect.Bool:
+            var attrValue types.Bool
+            diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
+            resp.Diagnostics.Append(diags...)
+            if resp.Diagnostics.HasError() {
+                return nil
+            }
+            err := attr.Set(cfg, attrValue.ValueBool())
+            if err != nil {
+                resp.Diagnostics.AddError(fmt.Sprintf("Failed to set attribute: %s", attr.Name), err.Error())
+                return nil
+            }
+        case reflect.Int:
+            var attrValue types.Int64
+            diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
+            resp.Diagnostics.Append(diags...)
+            if resp.Diagnostics.HasError() {
+                return nil
+            }
+            err := attr.Set(cfg, int(attrValue.ValueInt64()))
+            if err != nil {
+                resp.Diagnostics.AddError(fmt.Sprintf("Failed to set attribute: %s", attr.Name), err.Error())
+                return nil
+            }
+        case reflect.String:
+            var attrValue types.String
+            diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
+            resp.Diagnostics.Append(diags...)
+            if resp.Diagnostics.HasError() {
+                return nil
+            }
+            err := attr.Set(cfg, attrValue.ValueString())
+            if err != nil {
+                resp.Diagnostics.AddError(fmt.Sprintf("Failed to set attribute: %s", attr.Name), err.Error())
+                return nil
+            }
+        }
+        if attr.Kind == reflect.String {
+            attrsUsed = append(attrsUsed, attr.Name)
+        }
+    }
+    sort.Strings(attrsUsed)
+    tflog.Info(ctx, fmt.Sprintf("Explicit and implicit attributes: %s", strings.Join(attrsUsed, ", ")))
+    if cfg.AuthType != "" {
+        // map previous Google authentication types to the current
+        // authentication types used by the Databricks Go SDK
+        oldToNewerAuthType := map[string]string{
+            "google-creds":     "google-credentials",
+            "google-accounts":  "google-id",
+            "google-workspace": "google-id",
+        }
+        newer, ok := oldToNewerAuthType[cfg.AuthType]
+        if ok {
+            log.Printf("[INFO] Changing required auth_type from %s to %s", cfg.AuthType, newer)
+            cfg.AuthType = newer
+        }
+    }
+    client, err := client.New(cfg)
+    if err != nil {
+        resp.Diagnostics.Append(diag.NewErrorDiagnostic(err.Error(), ""))
+        return nil
+    }
+    pc := &common.DatabricksClient{
+        DatabricksClient: client,
+    }
+    pc.WithCommandExecutor(func(ctx context.Context, client *common.DatabricksClient) common.CommandExecutor {
+        return commands.NewCommandsAPI(ctx, client)
+    })
+    return pc
+}
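Note: Configure above stores the fully constructed *common.DatabricksClient in both resp.DataSourceData and resp.ResourceData, so resources and data sources later migrated to the plugin framework can pick it up in their own Configure methods. A minimal sketch of that consumption, following the usual framework pattern; the exampleDataSource type and its client field are hypothetical and not part of this change:

```go
// exampleDataSource is a hypothetical plugin framework data source.
type exampleDataSource struct {
	client *common.DatabricksClient
}

// Configure receives whatever the provider stored in DataSourceData.
func (d *exampleDataSource) Configure(ctx context.Context, req datasource.ConfigureRequest, resp *datasource.ConfigureResponse) {
	if req.ProviderData == nil {
		// Terraform can call Configure before the provider itself has been configured.
		return
	}
	c, ok := req.ProviderData.(*common.DatabricksClient)
	if !ok {
		resp.Diagnostics.AddError("Unexpected provider data",
			fmt.Sprintf("expected *common.DatabricksClient, got %T", req.ProviderData))
		return
	}
	d.client = c
}
```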
diff --git a/main.go b/main.go
index b64453db94..47825aad22 100644
--- a/main.go
+++ b/main.go
@@ -11,8 +11,6 @@ import (
     "github.com/databricks/terraform-provider-databricks/provider"
     "github.com/hashicorp/terraform-plugin-go/tfprotov6"
     "github.com/hashicorp/terraform-plugin-go/tfprotov6/tf6server"
-    "github.com/hashicorp/terraform-plugin-mux/tf5to6server"
-    "github.com/hashicorp/terraform-plugin-mux/tf6muxserver"
 )
 
 const startMessageFormat = `Databricks Terraform Provider
@@ -39,21 +37,8 @@
 func main() {
     log.Printf(startMessageFormat, common.Version())
 
-    sdkPluginProvider := provider.DatabricksProvider()
-
-    upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
-        context.Background(),
-        sdkPluginProvider.GRPCProvider,
-    )
-    if err != nil {
-        log.Fatal(err)
-    }
-
     ctx := context.Background()
-    muxServer, err := tf6muxserver.NewMuxServer(ctx, func() tfprotov6.ProviderServer {
-        return upgradedSdkPluginProvider
-    })
-
+    providerServer, err := provider.GetProviderServer(ctx)
     if err != nil {
         log.Fatal(err)
     }
@@ -65,10 +50,9 @@ func main() {
 
     err = tf6server.Serve(
         "registry.terraform.io/databricks/databricks",
-        muxServer.ProviderServer,
+        func() tfprotov6.ProviderServer { return providerServer },
         serveOpts...,
     )
-
     if err != nil {
         log.Fatal(err)
     }
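main.go keeps its existing serveOpts handling, which is only partially visible in the hunk above. For reference, a sketch of how the muxed server is typically served in debug mode with terraform-plugin-go; the serve helper and its debug flag are illustrative and not part of this diff:

```go
func serve(ctx context.Context, debug bool) error {
	providerServer, err := provider.GetProviderServer(ctx)
	if err != nil {
		return err
	}
	serveOpts := []tf6server.ServeOpt{}
	if debug {
		// WithManagedDebug lets the plugin run under a debugger; terraform-plugin-go
		// prints the reattach configuration for the Terraform CLI.
		serveOpts = append(serveOpts, tf6server.WithManagedDebug())
	}
	return tf6server.Serve(
		"registry.terraform.io/databricks/databricks",
		func() tfprotov6.ProviderServer { return providerServer },
		serveOpts...,
	)
}
```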
diff --git a/provider/provider.go b/provider/provider.go
index 7c4b1b49b8..0b954fc620 100644
--- a/provider/provider.go
+++ b/provider/provider.go
@@ -26,6 +26,7 @@ import (
     "github.com/databricks/terraform-provider-databricks/commands"
     "github.com/databricks/terraform-provider-databricks/common"
     "github.com/databricks/terraform-provider-databricks/dashboards"
+    pluginframeworkprovider "github.com/databricks/terraform-provider-databricks/internal/pluginframework/provider"
     "github.com/databricks/terraform-provider-databricks/jobs"
     "github.com/databricks/terraform-provider-databricks/logger"
     "github.com/databricks/terraform-provider-databricks/mlflow"
@@ -50,7 +51,7 @@ import (
 
 func init() {
     // IMPORTANT: this line cannot be changed, because it's used for
     // internal purposes at Databricks.
-    useragent.WithProduct("databricks-tf-provider", common.Version())
+    useragent.WithProduct(pluginframeworkprovider.GetProviderName(), common.Version())
     userAgentExtraEnv := os.Getenv("DATABRICKS_USER_AGENT_EXTRA")
     out, err := parseUserAgentExtra(userAgentExtraEnv)
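The user agent product and the plugin framework provider's type name now both come from GetProviderName. A small guard against the two drifting apart could look like this; the test name is illustrative, and it assumes the terraform-plugin-framework provider package is imported as provider, as in test_utils.go below:

```go
func TestProviderName_MatchesPluginFrameworkMetadata(t *testing.T) {
	p := pluginframeworkprovider.GetDatabricksProviderPluginFramework()
	var resp provider.MetadataResponse
	p.Metadata(context.Background(), provider.MetadataRequest{}, &resp)
	assert.Equal(t, pluginframeworkprovider.GetProviderName(), resp.TypeName)
}
```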
diff --git a/provider/provider_factory.go b/provider/provider_factory.go
new file mode 100644
index 0000000000..2e9f5ef32b
--- /dev/null
+++ b/provider/provider_factory.go
@@ -0,0 +1,45 @@
+package provider
+
+import (
+    "context"
+
+    pluginframeworkprovider "github.com/databricks/terraform-provider-databricks/internal/pluginframework/provider"
+    "github.com/hashicorp/terraform-plugin-framework/providerserver"
+    "github.com/hashicorp/terraform-plugin-go/tfprotov6"
+    "github.com/hashicorp/terraform-plugin-mux/tf5to6server"
+    "github.com/hashicorp/terraform-plugin-mux/tf6muxserver"
+)
+
+// GetProviderServer initializes and returns a Terraform Protocol v6 ProviderServer.
+// It upgrades the SDKv2-based Databricks provider to Protocol v6 using the v5-to-v6
+// upgrade mechanism, retrieves the Databricks provider implemented with the Terraform
+// Plugin Framework, and combines the two behind a multiplexing server, which allows
+// multiple Protocol v6 providers to be served together. It returns the multiplexed
+// ProviderServer, or an error if any part of the process fails.
+func GetProviderServer(ctx context.Context) (tfprotov6.ProviderServer, error) {
+    sdkPluginProvider := DatabricksProvider()
+
+    upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
+        ctx,
+        sdkPluginProvider.GRPCProvider,
+    )
+    if err != nil {
+        return nil, err
+    }
+
+    pluginFrameworkProvider := pluginframeworkprovider.GetDatabricksProviderPluginFramework()
+
+    providers := []func() tfprotov6.ProviderServer{
+        func() tfprotov6.ProviderServer {
+            return upgradedSdkPluginProvider
+        },
+        providerserver.NewProtocol6(pluginFrameworkProvider),
+    }
+
+    muxServer, err := tf6muxserver.NewMuxServer(ctx, providers...)
+    if err != nil {
+        return nil, err
+    }
+
+    return muxServer.ProviderServer(), nil
+}
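tf6muxserver expects the provider configuration schemas of the combined servers to agree, and surfaces problems when the merged schema is requested. A quick way to exercise GetProviderServer end to end is therefore to build the muxed server and ask it for its schema; this test is illustrative and not part of the change:

```go
func TestGetProviderServer_ReturnsMergedSchema(t *testing.T) {
	ctx := context.Background()
	srv, err := GetProviderServer(ctx)
	require.NoError(t, err)

	resp, err := srv.GetProviderSchema(ctx, &tfprotov6.GetProviderSchemaRequest{})
	require.NoError(t, err)
	require.Empty(t, resp.Diagnostics)
	require.NotNil(t, resp.Provider)
}
```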
diff --git a/provider/provider_test.go b/provider/provider_test.go
index 7d2b45fe46..6c1c0bbf34 100644
--- a/provider/provider_test.go
+++ b/provider/provider_test.go
@@ -2,98 +2,20 @@ package provider
 
 import (
     "context"
-    "errors"
     "fmt"
     "log"
     "net/http"
     "net/http/httptest"
     "os"
     "path/filepath"
-    "strings"
     "testing"
     "time"
 
     "github.com/databricks/terraform-provider-databricks/common"
-    "github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
 )
 
-type providerFixture struct {
-    host              string
-    token             string
-    username          string
-    password          string
-    configFile        string
-    profile           string
-    azureClientID     string
-    azureClientSecret string
-    azureTenantID     string
-    azureResourceID   string
-    authType          string
-    env               map[string]string
-    assertError       string
-    assertAuth        string
-    assertHost        string
-    assertAzure       bool
-}
-
-func (tt providerFixture) rawConfig() map[string]any {
-    rawConfig := map[string]any{}
-    if tt.host != "" {
-        rawConfig["host"] = tt.host
-    }
-    if tt.token != "" {
-        rawConfig["token"] = tt.token
-    }
-    if tt.username != "" {
-        rawConfig["username"] = tt.username
-    }
-    if tt.password != "" {
-        rawConfig["password"] = tt.password
-    }
-    if tt.configFile != "" {
-        rawConfig["config_file"] = tt.configFile
-    }
-    if tt.profile != "" {
-        rawConfig["profile"] = tt.profile
-    }
-    if tt.azureClientID != "" {
-        rawConfig["azure_client_id"] = tt.azureClientID
-    }
-    if tt.azureClientSecret != "" {
-        rawConfig["azure_client_secret"] = tt.azureClientSecret
-    }
-    if tt.azureTenantID != "" {
-        rawConfig["azure_tenant_id"] = tt.azureTenantID
-    }
-    if tt.azureResourceID != "" {
-        rawConfig["azure_workspace_resource_id"] = tt.azureResourceID
-    }
-    if tt.authType != "" {
-        rawConfig["auth_type"] = tt.authType
-    }
-    return rawConfig
-}
-
-func (tc providerFixture) apply(t *testing.T) *common.DatabricksClient {
-    c, err := configureProviderAndReturnClient(t, tc)
-    if tc.assertError != "" {
-        require.NotNilf(t, err, "Expected to have %s error", tc.assertError)
-        require.True(t, strings.HasPrefix(err.Error(), tc.assertError),
-            "Expected to have '%s' error, but got '%s'", tc.assertError, err)
-        return nil
-    }
-    if err != nil {
-        require.NoError(t, err)
-        return nil
-    }
-    assert.Equal(t, tc.assertAzure, c.IsAzure())
-    assert.Equal(t, tc.assertAuth, c.Config.AuthType)
-    assert.Equal(t, tc.assertHost, c.Config.Host)
-    return c
-}
-
 func TestConfig_NoParams(t *testing.T) {
     if f, err := os.Stat("~/.databrickscfg"); err == nil && !f.IsDir() {
         // the provider should fail to configure if no configuration options are available,
@@ -428,7 +350,7 @@ func TestConfig_OAuthFetchesToken(t *testing.T) {
     ts := httptest.NewServer(http.HandlerFunc(shortLivedOAuthHandler))
     defer ts.Close()
 
-    client := providerFixture{
+    testFixture := providerFixture{
         env: map[string]string{
             "DATABRICKS_HOST":      ts.URL,
             "DATABRICKS_CLIENT_ID": "x",
@@ -436,9 +358,18 @@ func TestConfig_OAuthFetchesToken(t *testing.T) {
         },
         assertAuth: "oauth-m2m",
         assertHost: ts.URL,
-    }.apply(t)
+    }
+
+    client := testFixture.applyWithSDKv2(t)
+    testOAuthFetchesToken(t, client)
+
+    client = testFixture.applyWithPluginFramework(t)
+    testOAuthFetchesToken(t, client)
 
-    ws, err := client.WorkspaceClient()
+}
+
+func testOAuthFetchesToken(t *testing.T, c *common.DatabricksClient) {
+    ws, err := c.WorkspaceClient()
     require.NoError(t, err)
     bgCtx := context.Background()
     {
@@ -455,32 +386,6 @@ func TestConfig_OAuthFetchesToken(t *testing.T) {
     }
 }
 
-func configureProviderAndReturnClient(t *testing.T, tt providerFixture) (*common.DatabricksClient, error) {
-    for k, v := range tt.env {
-        t.Setenv(k, v)
-    }
-    p := DatabricksProvider()
-    ctx := context.Background()
-    diags := p.Configure(ctx, terraform.NewResourceConfigRaw(tt.rawConfig()))
-    if len(diags) > 0 {
-        issues := []string{}
-        for _, d := range diags {
-            issues = append(issues, d.Summary)
-        }
-        return nil, errors.New(strings.Join(issues, ", "))
-    }
-    client := p.Meta().(*common.DatabricksClient)
-    r, err := http.NewRequest("GET", "", nil)
-    if err != nil {
-        return nil, err
-    }
-    err = client.Config.Authenticate(r)
-    if err != nil {
-        return nil, err
-    }
-    return client, nil
-}
-
 type parseUserAgentTestCase struct {
     name string
     env  string
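With the fixture moved into test_utils.go, existing tests keep calling apply(t), which now runs the same assertions once against the SDKv2 provider and once against the plugin framework provider as subtests (see the apply method below). A typical call site; host, token, and expected values are illustrative:

```go
func TestConfig_HostTokenExample(t *testing.T) {
	providerFixture{
		host:        "https://adb-1234567890123456.7.azuredatabricks.net",
		token:       "dapi-example",
		assertAzure: true,
		assertAuth:  "pat",
		assertHost:  "https://adb-1234567890123456.7.azuredatabricks.net",
	}.apply(t)
}
```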
diff --git a/provider/test_utils.go b/provider/test_utils.go
new file mode 100644
index 0000000000..5193835f33
--- /dev/null
+++ b/provider/test_utils.go
@@ -0,0 +1,198 @@
+package provider
+
+import (
+    "context"
+    "errors"
+    "net/http"
+    "strings"
+    "testing"
+
+    "github.com/databricks/terraform-provider-databricks/common"
+    pluginframeworkprovider "github.com/databricks/terraform-provider-databricks/internal/pluginframework/provider"
+    "github.com/hashicorp/terraform-plugin-framework/provider"
+    "github.com/hashicorp/terraform-plugin-framework/tfsdk"
+    "github.com/hashicorp/terraform-plugin-go/tftypes"
+    "github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+)
+
+type providerFixture struct {
+    host              string
+    token             string
+    username          string
+    password          string
+    configFile        string
+    profile           string
+    azureClientID     string
+    azureClientSecret string
+    azureTenantID     string
+    azureResourceID   string
+    authType          string
+    env               map[string]string
+    assertError       string
+    assertAuth        string
+    assertHost        string
+    assertAzure       bool
+}
+
+func (pf providerFixture) rawConfig() map[string]string {
+    rawConfig := map[string]string{}
+    if pf.host != "" {
+        rawConfig["host"] = pf.host
+    }
+    if pf.token != "" {
+        rawConfig["token"] = pf.token
+    }
+    if pf.username != "" {
+        rawConfig["username"] = pf.username
+    }
+    if pf.password != "" {
+        rawConfig["password"] = pf.password
+    }
+    if pf.configFile != "" {
+        rawConfig["config_file"] = pf.configFile
+    }
+    if pf.profile != "" {
+        rawConfig["profile"] = pf.profile
+    }
+    if pf.azureClientID != "" {
+        rawConfig["azure_client_id"] = pf.azureClientID
+    }
+    if pf.azureClientSecret != "" {
+        rawConfig["azure_client_secret"] = pf.azureClientSecret
+    }
+    if pf.azureTenantID != "" {
+        rawConfig["azure_tenant_id"] = pf.azureTenantID
+    }
+    if pf.azureResourceID != "" {
+        rawConfig["azure_workspace_resource_id"] = pf.azureResourceID
+    }
+    if pf.authType != "" {
+        rawConfig["auth_type"] = pf.authType
+    }
+    return rawConfig
+}
+
+func (pf providerFixture) rawConfigSDKv2() map[string]any {
+    rawConfig := pf.rawConfig()
+    rawConfigSDKv2 := map[string]any{}
+    for k, v := range rawConfig {
+        rawConfigSDKv2[k] = v
+    }
+    return rawConfigSDKv2
+}
+
+func (pf providerFixture) rawConfigPluginFramework() tftypes.Value {
+    rawConfig := pf.rawConfig()
+
+    rawConfigTypeMap := map[string]tftypes.Type{}
+    for k := range rawConfig {
+        rawConfigTypeMap[k] = tftypes.String
+    }
+    rawConfigType := tftypes.Object{
+        AttributeTypes: rawConfigTypeMap,
+    }
+
+    rawConfigValueMap := map[string]tftypes.Value{}
+    for k, v := range rawConfig {
+        rawConfigValueMap[k] = tftypes.NewValue(tftypes.String, v)
+    }
+    rawConfigValue := tftypes.NewValue(rawConfigType, rawConfigValueMap)
+    return rawConfigValue
+}
+
+func (pf providerFixture) applyWithSDKv2(t *testing.T) *common.DatabricksClient {
+    c, err := pf.configureProviderAndReturnClient_SDKv2(t)
+    return pf.applyAssertions(c, t, err)
+}
+
+func (pf providerFixture) applyWithPluginFramework(t *testing.T) *common.DatabricksClient {
+    c, err := pf.configureProviderAndReturnClient_PluginFramework(t)
+    return pf.applyAssertions(c, t, err)
+}
+
+func (pf providerFixture) applyAssertions(c *common.DatabricksClient, t *testing.T, err error) *common.DatabricksClient {
+    if pf.assertError != "" {
+        require.NotNilf(t, err, "Expected to have %s error", pf.assertError)
+        require.True(t, strings.HasPrefix(err.Error(), pf.assertError),
+            "Expected to have '%s' error, but got '%s'", pf.assertError, err)
+        return nil
+    }
+    if err != nil {
+        require.NoError(t, err)
+        return nil
+    }
+    assert.Equal(t, pf.assertAzure, c.IsAzure())
+    assert.Equal(t, pf.assertAuth, c.Config.AuthType)
+    assert.Equal(t, pf.assertHost, c.Config.Host)
+    return c
+}
+
+func (pf providerFixture) apply(t *testing.T) {
+    t.Run("sdkv2", func(t *testing.T) { _ = pf.applyWithSDKv2(t) })
+    t.Run("plugin framework", func(t *testing.T) { _ = pf.applyWithPluginFramework(t) })
+}
+
+func (pf providerFixture) configureProviderAndReturnClient_SDKv2(t *testing.T) (*common.DatabricksClient, error) {
+    for k, v := range pf.env {
+        t.Setenv(k, v)
+    }
+    p := DatabricksProvider()
+    ctx := context.Background()
+    diags := p.Configure(ctx, terraform.NewResourceConfigRaw(pf.rawConfigSDKv2()))
+    if len(diags) > 0 {
+        issues := []string{}
+        for _, d := range diags {
+            issues = append(issues, d.Summary)
+        }
+        return nil, errors.New(strings.Join(issues, ", "))
+    }
+    client := p.Meta().(*common.DatabricksClient)
+    r, err := http.NewRequest("GET", "", nil)
+    if err != nil {
+        return nil, err
+    }
+    err = client.Config.Authenticate(r)
+    if err != nil {
+        return nil, err
+    }
+    return client, nil
+}
+
+func (pf providerFixture) configureProviderAndReturnClient_PluginFramework(t *testing.T) (*common.DatabricksClient, error) {
+    for k, v := range pf.env {
+        t.Setenv(k, v)
+    }
+    p := pluginframeworkprovider.GetDatabricksProviderPluginFramework()
+    ctx := context.Background()
+    rawConfig := pf.rawConfigPluginFramework()
+    var providerSchemaResponse provider.SchemaResponse
+    p.Schema(ctx, provider.SchemaRequest{}, &providerSchemaResponse)
+    configRequest := provider.ConfigureRequest{
+        Config: tfsdk.Config{
+            Raw:    rawConfig,
+            Schema: providerSchemaResponse.Schema,
+        },
+    }
+    configResponse := &provider.ConfigureResponse{}
+    p.Configure(ctx, configRequest, configResponse)
+    diags := configResponse.Diagnostics
+    if len(diags) > 0 {
+        issues := []string{}
+        for _, d := range diags {
+            issues = append(issues, d.Summary())
+        }
+        return nil, errors.New(strings.Join(issues, ", "))
+    }
+    client := configResponse.ResourceData.(*common.DatabricksClient)
+    r, err := http.NewRequest("GET", "", nil)
+    if err != nil {
+        return nil, err
+    }
+    err = client.Config.Authenticate(r)
+    if err != nil {
+        return nil, err
+    }
+    return client, nil
+}
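Resources and DataSources in the new plugin framework provider intentionally return empty slices for now. Once the first implementation is migrated, registration would look roughly like this, reusing the hypothetical exampleDataSource from the sketch earlier in this change:

```go
func (p *DatabricksProviderPluginFramework) DataSources(ctx context.Context) []func() datasource.DataSource {
	return []func() datasource.DataSource{
		// exampleDataSource is hypothetical; real migrations would register their own types here.
		func() datasource.DataSource { return &exampleDataSource{} },
	}
}
```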