Skip to content

Commit

Permalink
More work
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed Apr 15, 2024
1 parent 3ecb332 commit 7a1cf28
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 33 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -342,4 +342,7 @@ scripts/tt

.metals

provider/completeness.md
provider/completeness.md

# VS Code debugging binaries
__debug_bin*
102 changes: 92 additions & 10 deletions internal/acceptance/entitlements_base_resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"fmt"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/service/iam"
)

type entitlementResource interface {
resourceType() string
setDisplayName(string)
setWorkspaceClient(*databricks.WorkspaceClient)
create(context.Context) error
getEntitlements(context.Context) ([]iam.ComplexValue, error)
cleanUp(context.Context) error
dataSourceTemplate() string
tfReference() string
Expand All @@ -23,6 +26,10 @@ type userResource struct {
w *databricks.WorkspaceClient
}

func (u *userResource) resourceType() string {
return "user"
}

func (u *userResource) setDisplayName(displayName string) {
u.email = displayName + "@example.com"
}
Expand All @@ -42,6 +49,20 @@ func (u *userResource) create(ctx context.Context) error {
return nil
}

func (u *userResource) getEntitlements(ctx context.Context) ([]iam.ComplexValue, error) {
c, err := u.w.Config.NewApiClient()
if err != nil {
return nil, err
}
res := iam.User{}
err = c.Do(ctx, "GET", fmt.Sprintf("/api/2.0/preview/scim/v2/Users/%s?attributes=entitlements", u.id),
httpclient.WithResponseUnmarshal(&res))
if err != nil {
return nil, err
}
return res.Entitlements, nil
}

func (u *userResource) cleanUp(ctx context.Context) error {
return u.w.Users.DeleteById(ctx, u.id)
}
Expand All @@ -54,7 +75,7 @@ func (u *userResource) dataSourceTemplate() string {
}

func (u *userResource) tfReference() string {
return fmt.Sprintf("data.databricks_user.example.id")
return "user_id = data.databricks_user.example.id"
}

type groupResource struct {
Expand All @@ -63,6 +84,10 @@ type groupResource struct {
w *databricks.WorkspaceClient
}

func (g *groupResource) resourceType() string {
return "group"
}

func (g *groupResource) setDisplayName(displayName string) {
g.displayName = displayName
}
Expand All @@ -82,6 +107,20 @@ func (g *groupResource) create(ctx context.Context) error {
return nil
}

func (g *groupResource) getEntitlements(ctx context.Context) ([]iam.ComplexValue, error) {
c, err := g.w.Config.NewApiClient()
if err != nil {
return nil, err
}
res := iam.Group{}
err = c.Do(ctx, "GET", fmt.Sprintf("/api/2.0/preview/scim/v2/Groups/%s?attributes=entitlements", g.id),
httpclient.WithResponseUnmarshal(&res))
if err != nil {
return nil, err
}
return res.Entitlements, nil
}

func (g *groupResource) cleanUp(ctx context.Context) error {
return g.w.Groups.DeleteById(ctx, g.id)
}
Expand All @@ -94,15 +133,21 @@ func (g *groupResource) dataSourceTemplate() string {
}

func (g *groupResource) tfReference() string {
return fmt.Sprintf("data.databricks_group.example.id")
return "group_id = data.databricks_group.example.id"
}

type servicePrincipalResource struct {
applicationId string
cleanup bool
displayName string
id string
w *databricks.WorkspaceClient
}

func (s *servicePrincipalResource) resourceType() string {
return "service_principal"
}

func (s *servicePrincipalResource) setDisplayName(displayName string) {
s.displayName = displayName
}
Expand All @@ -112,28 +157,65 @@ func (s *servicePrincipalResource) setWorkspaceClient(w *databricks.WorkspaceCli
}

func (s *servicePrincipalResource) create(ctx context.Context) error {
sp, err := s.w.ServicePrincipals.Create(ctx, iam.ServicePrincipal{
ApplicationId: s.applicationId,
DisplayName: s.displayName,
})
sp, err := s.create0(ctx)
if err != nil {
return err
}
s.applicationId = sp.ApplicationId
if s.applicationId == "" {
s.applicationId = sp.ApplicationId
}
s.id = sp.Id
return nil
}

func (s *servicePrincipalResource) create0(ctx context.Context) (iam.ServicePrincipal, error) {
if s.applicationId != "" {
sps := s.w.ServicePrincipals.List(ctx, iam.ListServicePrincipalsRequest{
Filter: fmt.Sprintf(`applicationId eq "%s"`, s.applicationId),
})
if !sps.HasNext(ctx) {
return iam.ServicePrincipal{}, fmt.Errorf("service principal with applicationId %s not found", s.applicationId)
}
return sps.Next(ctx)
}
sp, err := s.w.ServicePrincipals.Create(ctx, iam.ServicePrincipal{
DisplayName: s.displayName,
})
return *sp, err
}

func (s *servicePrincipalResource) getEntitlements(ctx context.Context) ([]iam.ComplexValue, error) {
c, err := s.w.Config.NewApiClient()
if err != nil {
return nil, err
}
res := iam.ServicePrincipal{}
err = c.Do(ctx, "GET", fmt.Sprintf("/api/2.0/preview/scim/v2/ServicePrincipals/%s?attributes=entitlements", s.id),
httpclient.WithResponseUnmarshal(&res))
if err != nil {
return nil, err
}
return res.Entitlements, nil
}

func (s *servicePrincipalResource) cleanUp(ctx context.Context) error {
if !s.cleanup {
return nil
}
return s.w.ServicePrincipals.DeleteById(ctx, s.applicationId)
}

func (s *servicePrincipalResource) dataSourceTemplate() string {
fragment := fmt.Sprintf(`display_name = "%s"`, s.displayName)
if s.applicationId != "" {
fragment = fmt.Sprintf(`application_id = "%s"`, s.applicationId)
}
return fmt.Sprintf(`
data "databricks_service_principal" "example" {
display_name = "%s"
}`, s.displayName)
%s
}`, fragment)
}

func (s *servicePrincipalResource) tfReference() string {
return fmt.Sprintf("data.databricks_service_principal.example.id")
return "service_principal_id = data.databricks_service_principal.example.id"
}
36 changes: 14 additions & 22 deletions internal/acceptance/entitlements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"testing"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"github.com/databricks/databricks-sdk-go/service/iam"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
"github.com/stretchr/testify/assert"
)
Expand All @@ -23,7 +21,7 @@ func (e entitlement) String() string {
return fmt.Sprintf("%s = %t", e.name, e.value)
}

func entitlementsStepBuilder(t *testing.T, c **httpclient.ApiClient, r entitlementResource) func(entitlements []entitlement) step {
func entitlementsStepBuilder(t *testing.T, r entitlementResource) func(entitlements []entitlement) step {
return func(entitlements []entitlement) step {
entitlementsBuf := strings.Builder{}
for _, entitlement := range entitlements {
Expand All @@ -33,19 +31,15 @@ func entitlementsStepBuilder(t *testing.T, c **httpclient.ApiClient, r entitleme
Template: fmt.Sprintf(`
%s
resource "databricks_entitlements" "entitlements_users" {
group_id = %s
%s
%s
}
`, r.dataSourceTemplate(), r.tfReference(), entitlementsBuf.String()),
Check: func(s *terraform.State) error {
groupId := s.RootModule().Resources["data.databricks_group.example"].Primary.ID
var res iam.Group
ctx := context.Background()
err := (*c).Do(ctx, "GET", fmt.Sprintf("/api/2.0/preview/scim/v2/Groups/%s?attributes=entitlements", groupId),
httpclient.WithResponseUnmarshal(&res))
remoteEntitlements, err := r.getEntitlements(context.Background())
assert.NoError(t, err)
receivedEntitlements := make([]string, 0, len(res.Entitlements))
for _, entitlement := range res.Entitlements {
receivedEntitlements := make([]string, 0, len(remoteEntitlements))
for _, entitlement := range remoteEntitlements {
receivedEntitlements = append(receivedEntitlements, entitlement.Value)
}
expectedEntitlements := make([]string, 0, len(entitlements))
Expand All @@ -64,28 +58,24 @@ func entitlementsStepBuilder(t *testing.T, c **httpclient.ApiClient, r entitleme

func makeEntitlementsSteps(t *testing.T, r entitlementResource, entitlementsSteps [][]entitlement) []step {
r.setDisplayName(RandomName("entitlements-"))
var c *httpclient.ApiClient
makeEntitlementsStep := entitlementsStepBuilder(t, &c, r)
makeEntitlementsStep := entitlementsStepBuilder(t, r)
steps := make([]step, len(entitlementsSteps))
for i, entitlements := range entitlementsSteps {
steps[i] = makeEntitlementsStep(entitlements)
}
steps[0].PreConfig = makePreconfig(t, &c, r)
steps[0].PreConfig = makePreconfig(t, r)
return steps
}

func makePreconfig(t *testing.T, c **httpclient.ApiClient, r entitlementResource) func() {
func makePreconfig(t *testing.T, r entitlementResource) func() {
logger.DefaultLogger = &logger.SimpleLogger{
Level: logger.LevelDebug,
}
return func() {
w := databricks.Must(databricks.NewWorkspaceClient())
r.setWorkspaceClient(w)
var err error
*c, err = w.Config.NewApiClient()
assert.NoError(t, err)
ctx := context.Background()
err = r.create(ctx)
err := r.create(ctx)
assert.NoError(t, err)
t.Cleanup(func() {
r.cleanUp(ctx)
Expand All @@ -97,15 +87,17 @@ func entitlementsTest(t *testing.T, f func(*testing.T, entitlementResource)) {
loadWorkspaceEnv(t)
sp := &servicePrincipalResource{}
if isAzure(t) {
// A long-lived application is used in Azure.
sp.applicationId = GetEnvOrSkipTest(t, "ACCOUNT_LEVEL_SERVICE_PRINCIPAL_ID")
sp.cleanup = false
}
resources := []entitlementResource{
// &groupResource{},
&groupResource{},
&userResource{},
// sp,
sp,
}
for _, r := range resources {
t.Run(fmt.Sprintf("%T", r), func(t *testing.T) {
t.Run(r.resourceType(), func(t *testing.T) {
f(t, r)
})
}
Expand Down

0 comments on commit 7a1cf28

Please sign in to comment.