Skip to content

Commit

Permalink
[Feature] Introduced Plugin Framework (#3920)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->
This PR introduces plugin framework to the Databricks Terraform
Provider.
- Introduce plugin framework
- Add provider schema
- Add support to configure databricks client
- Add support for provider fixture
- Add support to get raw config to be used in provider tests and further
in unit testing framework.
- Fix bugs in configuring client + other misc bugs

Note: It is a subpart of:
#3751.
Rest would be added once support for schema customisation and SDK
converted are added.

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->
Unit tests, nightly were run over the above PRs. 
- [ ] `make test` run locally
- [ ] relevant change in `docs/` folder
- [ ] covered with integration tests in `internal/acceptance`
- [ ] relevant acceptance tests are passing
- [ ] using Go SDK
  • Loading branch information
tanmay-db authored Aug 22, 2024
1 parent 7950e92 commit 8899beb
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 126 deletions.
165 changes: 165 additions & 0 deletions internal/pluginframework/provider/provider_plugin_framework.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package provider

import (
"context"
"fmt"
"log"
"reflect"
"sort"
"strings"

"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/terraform-provider-databricks/commands"
"github.com/databricks/terraform-provider-databricks/common"

"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/path"
"github.com/hashicorp/terraform-plugin-framework/provider"
"github.com/hashicorp/terraform-plugin-framework/provider/schema"
"github.com/hashicorp/terraform-plugin-framework/resource"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-log/tflog"
)

func GetProviderName() string {
return "databricks-tf-provider"
}

func GetDatabricksProviderPluginFramework() provider.Provider {
p := &DatabricksProviderPluginFramework{}
return p
}

type DatabricksProviderPluginFramework struct {
}

var _ provider.Provider = (*DatabricksProviderPluginFramework)(nil)

func (p *DatabricksProviderPluginFramework) Resources(ctx context.Context) []func() resource.Resource {
return []func() resource.Resource{}
}

func (p *DatabricksProviderPluginFramework) DataSources(ctx context.Context) []func() datasource.DataSource {
return []func() datasource.DataSource{}
}

func (p *DatabricksProviderPluginFramework) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
resp.Schema = providerSchemaPluginFramework()
}

func (p *DatabricksProviderPluginFramework) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) {
resp.TypeName = GetProviderName()
resp.Version = common.Version()
}

func (p *DatabricksProviderPluginFramework) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
client := configureDatabricksClient_PluginFramework(ctx, req, resp)
resp.DataSourceData = client
resp.ResourceData = client
}

// Function returns a schema.Schema based on config attributes where each attribute is mapped to the appropriate
// schema type (BoolAttribute, StringAttribute, Int64Attribute).
func providerSchemaPluginFramework() schema.Schema {
ps := map[string]schema.Attribute{}
for _, attr := range config.ConfigAttributes {
switch attr.Kind {
case reflect.Bool:
ps[attr.Name] = schema.BoolAttribute{
Optional: true,
Sensitive: attr.Sensitive,
}
case reflect.String:
ps[attr.Name] = schema.StringAttribute{
Optional: true,
Sensitive: attr.Sensitive,
}
case reflect.Int:
ps[attr.Name] = schema.Int64Attribute{
Optional: true,
Sensitive: attr.Sensitive,
}
}
}
return schema.Schema{
Attributes: ps,
}
}

func configureDatabricksClient_PluginFramework(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) any {
cfg := &config.Config{}
attrsUsed := []string{}
for _, attr := range config.ConfigAttributes {
switch attr.Kind {
case reflect.Bool:
var attrValue types.Bool
diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return nil
}
err := attr.Set(cfg, attrValue.ValueBool())
if err != nil {
resp.Diagnostics.Append(diag.NewErrorDiagnostic("Failed to set attribute", err.Error()))
return nil
}
case reflect.Int:
var attrValue types.Int64
diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return nil
}
err := attr.Set(cfg, int(attrValue.ValueInt64()))
if err != nil {
resp.Diagnostics.Append(diag.NewErrorDiagnostic("Failed to set attribute", err.Error()))
return nil
}
case reflect.String:
var attrValue types.String
diags := req.Config.GetAttribute(ctx, path.Root(attr.Name), &attrValue)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return nil
}
err := attr.Set(cfg, attrValue.ValueString())
if err != nil {
resp.Diagnostics.AddError(fmt.Sprintf("Failed to set attribute: %s", attr.Name), err.Error())
return nil
}
}
if attr.Kind == reflect.String {
attrsUsed = append(attrsUsed, attr.Name)
}
}
sort.Strings(attrsUsed)
tflog.Info(ctx, fmt.Sprintf("Explicit and implicit attributes: %s", strings.Join(attrsUsed, ", ")))
if cfg.AuthType != "" {
// mapping from previous Google authentication types
// and current authentication types from Databricks Go SDK
oldToNewerAuthType := map[string]string{
"google-creds": "google-credentials",
"google-accounts": "google-id",
"google-workspace": "google-id",
}
newer, ok := oldToNewerAuthType[cfg.AuthType]
if ok {
log.Printf("[INFO] Changing required auth_type from %s to %s", cfg.AuthType, newer)
cfg.AuthType = newer
}
}
client, err := client.New(cfg)
if err != nil {
resp.Diagnostics.Append(diag.NewErrorDiagnostic(err.Error(), ""))
return nil
}
pc := &common.DatabricksClient{
DatabricksClient: client,
}
pc.WithCommandExecutor(func(ctx context.Context, client *common.DatabricksClient) common.CommandExecutor {
return commands.NewCommandsAPI(ctx, client)
})
return pc
}
20 changes: 2 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"github.com/databricks/terraform-provider-databricks/provider"
"github.com/hashicorp/terraform-plugin-go/tfprotov6"
"github.com/hashicorp/terraform-plugin-go/tfprotov6/tf6server"
"github.com/hashicorp/terraform-plugin-mux/tf5to6server"
"github.com/hashicorp/terraform-plugin-mux/tf6muxserver"
)

const startMessageFormat = `Databricks Terraform Provider
Expand All @@ -39,21 +37,8 @@ func main() {

log.Printf(startMessageFormat, common.Version())

sdkPluginProvider := provider.DatabricksProvider()

upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
context.Background(),
sdkPluginProvider.GRPCProvider,
)
if err != nil {
log.Fatal(err)
}

ctx := context.Background()
muxServer, err := tf6muxserver.NewMuxServer(ctx, func() tfprotov6.ProviderServer {
return upgradedSdkPluginProvider
})

providerServer, err := provider.GetProviderServer(ctx)
if err != nil {
log.Fatal(err)
}
Expand All @@ -65,10 +50,9 @@ func main() {

err = tf6server.Serve(
"registry.terraform.io/databricks/databricks",
muxServer.ProviderServer,
func() tfprotov6.ProviderServer { return providerServer },
serveOpts...,
)

if err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/databricks/terraform-provider-databricks/commands"
"github.com/databricks/terraform-provider-databricks/common"
"github.com/databricks/terraform-provider-databricks/dashboards"
pluginframeworkprovider "github.com/databricks/terraform-provider-databricks/internal/pluginframework/provider"
"github.com/databricks/terraform-provider-databricks/jobs"
"github.com/databricks/terraform-provider-databricks/logger"
"github.com/databricks/terraform-provider-databricks/mlflow"
Expand All @@ -50,7 +51,7 @@ import (
func init() {
// IMPORTANT: this line cannot be changed, because it's used for
// internal purposes at Databricks.
useragent.WithProduct("databricks-tf-provider", common.Version())
useragent.WithProduct(pluginframeworkprovider.GetProviderName(), common.Version())

userAgentExtraEnv := os.Getenv("DATABRICKS_USER_AGENT_EXTRA")
out, err := parseUserAgentExtra(userAgentExtraEnv)
Expand Down
49 changes: 49 additions & 0 deletions provider/provider_factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package provider

import (
"context"
"log"

pluginframeworkprovider "github.com/databricks/terraform-provider-databricks/internal/pluginframework/provider"
"github.com/hashicorp/terraform-plugin-framework/providerserver"
"github.com/hashicorp/terraform-plugin-go/tfprotov6"
"github.com/hashicorp/terraform-plugin-mux/tf5to6server"
"github.com/hashicorp/terraform-plugin-mux/tf6muxserver"
)

// GetProviderServer initializes and returns a Terraform Protocol v6 ProviderServer.
// The function begins by initializing the Databricks provider using the SDK plugin
// and then upgrades this provider to be compatible with Terraform's Protocol v6 using
// the v5-to-v6 upgrade mechanism. Additionally, it retrieves the Databricks provider
// that is implemented using the Terraform Plugin Framework. These different provider
// implementations are then combined using a multiplexing server, which allows multiple
// Protocol v6 providers to be served together. The function returns the multiplexed
// ProviderServer, or an error if any part of the process fails.
func GetProviderServer(ctx context.Context) (tfprotov6.ProviderServer, error) {
sdkPluginProvider := DatabricksProvider()

upgradedSdkPluginProvider, err := tf5to6server.UpgradeServer(
context.Background(),
sdkPluginProvider.GRPCProvider,
)
if err != nil {
log.Fatal(err)
}

pluginFrameworkProvider := pluginframeworkprovider.GetDatabricksProviderPluginFramework()

providers := []func() tfprotov6.ProviderServer{
func() tfprotov6.ProviderServer {
return upgradedSdkPluginProvider
},
providerserver.NewProtocol6(pluginFrameworkProvider),
}

muxServer, err := tf6muxserver.NewMuxServer(ctx, providers...)

if err != nil {
return nil, err
}

return muxServer.ProviderServer(), nil
}
Loading

0 comments on commit 8899beb

Please sign in to comment.