Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hoisted userManagementClient and providerConfigClient into baseClient #317

Merged
merged 2 commits into from
Dec 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,25 @@ func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error)
return nil, err
}

hc, _, err := transport.NewHTTPClient(ctx, conf.Opts...)
transport, _, err := transport.NewHTTPClient(ctx, conf.Opts...)
if err != nil {
return nil, err
}

hc := internal.WithDefaultRetryConfig(transport)
hc.CreateErrFn = handleHTTPError
hc.SuccessFn = internal.HasSuccessStatus
hc.Opts = []internal.HTTPOption{
internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)),
}

base := &baseClient{
userManagementClient: newUserManagementClient(hc, conf),
providerConfigClient: newProviderConfigClient(hc, conf),
idTokenVerifier: idTokenVerifier,
cookieVerifier: cookieVerifier,
userManagementEndpoint: idToolkitV1Endpoint,
providerConfigEndpoint: providerConfigEndpoint,
projectID: conf.ProjectID,
httpClient: hc,
idTokenVerifier: idTokenVerifier,
cookieVerifier: cookieVerifier,
}
return &Client{
baseClient: base,
Expand Down Expand Up @@ -183,7 +192,7 @@ func (c *Client) SessionCookie(
idToken string,
expiresIn time.Duration,
) (string, error) {
return c.baseClient.userManagementClient.createSessionCookie(ctx, idToken, expiresIn)
return c.baseClient.createSessionCookie(ctx, idToken, expiresIn)
}

// Token represents a decoded Firebase ID token.
Expand Down Expand Up @@ -213,22 +222,21 @@ type FirebaseInfo struct {
Identities map[string]interface{} `json:"identities"`
}

// baseClient exposes the APIs common to both auth.Client and auth.TenantClient.
type baseClient struct {
*userManagementClient
*providerConfigClient
idTokenVerifier *tokenVerifier
cookieVerifier *tokenVerifier
tenantID string
userManagementEndpoint string
providerConfigEndpoint string
projectID string
tenantID string
httpClient *internal.HTTPClient
idTokenVerifier *tokenVerifier
cookieVerifier *tokenVerifier
}

func (c *baseClient) withTenantID(tenantID string) *baseClient {
return &baseClient{
userManagementClient: c.userManagementClient.withTenantID(tenantID),
providerConfigClient: c.providerConfigClient.withTenantID(tenantID),
idTokenVerifier: c.idTokenVerifier,
cookieVerifier: c.cookieVerifier,
tenantID: tenantID,
}
copy := *c
copy.tenantID = tenantID
return &copy
}

// VerifyIDToken verifies the signature and payload of the provided ID token.
Expand Down
27 changes: 15 additions & 12 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ func TestNewClientWithServiceAccountCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, creds.ProjectID); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, creds.ProjectID); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, creds.ProjectID); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand All @@ -127,8 +127,8 @@ func TestNewClientWithoutCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand All @@ -155,8 +155,8 @@ func TestNewClientWithServiceAccountID(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand Down Expand Up @@ -194,8 +194,8 @@ func TestNewClientWithUserCredentials(t *testing.T) {
if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
t.Errorf("NewClient().cookieVerifier: %v", err)
}
if err := checkUserManagementClient(client, ""); err != nil {
t.Errorf("NewClient().userManagementClient: %v", err)
if err := checkBaseClient(client, ""); err != nil {
t.Errorf("NewClient().baseClient: %v", err)
}
if client.clock != internal.SystemClock {
t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
Expand Down Expand Up @@ -1057,10 +1057,13 @@ func checkCookieVerifier(tv *tokenVerifier, projectID string) error {
return nil
}

func checkUserManagementClient(client *Client, wantProjectID string) error {
umc := client.userManagementClient
if umc.baseURL != idToolkitV1Endpoint {
return fmt.Errorf("baseURL = %q; want = %q", umc.baseURL, idToolkitV1Endpoint)
func checkBaseClient(client *Client, wantProjectID string) error {
umc := client.baseClient
if umc.userManagementEndpoint != idToolkitV1Endpoint {
return fmt.Errorf("userManagementEndpoint = %q; want = %q", umc.userManagementEndpoint, idToolkitV1Endpoint)
}
if umc.providerConfigEndpoint != providerConfigEndpoint {
return fmt.Errorf("providerConfigEndpoint = %q; want = %q", umc.providerConfigEndpoint, providerConfigEndpoint)
}
if umc.projectID != wantProjectID {
return fmt.Errorf("projectID = %q; want = %q", umc.projectID, wantProjectID)
Expand Down
12 changes: 6 additions & 6 deletions auth/email_action_links.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,38 +71,38 @@ const (

// EmailVerificationLink generates the out-of-band email action link for email verification flows for the specified
// email address.
func (c *userManagementClient) EmailVerificationLink(ctx context.Context, email string) (string, error) {
func (c *baseClient) EmailVerificationLink(ctx context.Context, email string) (string, error) {
return c.EmailVerificationLinkWithSettings(ctx, email, nil)
}

// EmailVerificationLinkWithSettings generates the out-of-band email action link for email verification flows for the
// specified email address, using the action code settings provided.
func (c *userManagementClient) EmailVerificationLinkWithSettings(
func (c *baseClient) EmailVerificationLinkWithSettings(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, emailVerification, email, settings)
}

// PasswordResetLink generates the out-of-band email action link for password reset flows for the specified email
// address.
func (c *userManagementClient) PasswordResetLink(ctx context.Context, email string) (string, error) {
func (c *baseClient) PasswordResetLink(ctx context.Context, email string) (string, error) {
return c.PasswordResetLinkWithSettings(ctx, email, nil)
}

// PasswordResetLinkWithSettings generates the out-of-band email action link for password reset flows for the
// specified email address, using the action code settings provided.
func (c *userManagementClient) PasswordResetLinkWithSettings(
func (c *baseClient) PasswordResetLinkWithSettings(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, passwordReset, email, settings)
}

// EmailSignInLink generates the out-of-band email action link for email link sign-in flows, using the action
// code settings provided.
func (c *userManagementClient) EmailSignInLink(
func (c *baseClient) EmailSignInLink(
ctx context.Context, email string, settings *ActionCodeSettings) (string, error) {
return c.generateEmailActionLink(ctx, emailLinkSignIn, email, settings)
}

func (c *userManagementClient) generateEmailActionLink(
func (c *baseClient) generateEmailActionLink(
ctx context.Context, linkType linkType, email string, settings *ActionCodeSettings) (string, error) {

if email == "" {
Expand Down
2 changes: 1 addition & 1 deletion auth/email_action_links_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func TestEmailVerificationLinkError(t *testing.T) {
}
s := echoServer(testActionLinkResponse, t)
defer s.Close()
s.Client.userManagementClient.httpClient.RetryConfig = nil
s.Client.baseClient.httpClient.RetryConfig = nil
s.Status = http.StatusInternalServerError

for code, check := range cases {
Expand Down
4 changes: 2 additions & 2 deletions auth/export_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const maxReturnedResults = 1000
//
// If nextPageToken is empty, the iterator will start at the beginning.
// If the nextPageToken is not empty, the iterator starts after the token.
func (c *userManagementClient) Users(ctx context.Context, nextPageToken string) *UserIterator {
func (c *baseClient) Users(ctx context.Context, nextPageToken string) *UserIterator {
it := &UserIterator{
ctx: ctx,
client: c,
Expand All @@ -49,7 +49,7 @@ func (c *userManagementClient) Users(ctx context.Context, nextPageToken string)
//
// Also see: https://github.com/GoogleCloudPlatform/google-cloud-go/wiki/Iterator-Guidelines
type UserIterator struct {
client *userManagementClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down
2 changes: 1 addition & 1 deletion auth/import_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type ErrorInfo struct {
//
// No more than 1000 users can be imported in a single call. If at least one user specifies a
// password, a UserImportHash must be specified as an option.
func (c *userManagementClient) ImportUsers(
func (c *baseClient) ImportUsers(
ctx context.Context, users []*UserToImport, opts ...UserImportOption) (*UserImportResult, error) {

if len(users) == 0 {
Expand Down
60 changes: 17 additions & 43 deletions auth/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (config *OIDCProviderConfigToUpdate) buildRequest() (nestedMap, error) {

// OIDCProviderConfigIterator is an iterator over OIDC provider configurations.
type OIDCProviderConfigIterator struct {
client *providerConfigClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down Expand Up @@ -538,7 +538,7 @@ func (config *SAMLProviderConfigToUpdate) buildRequest() (nestedMap, error) {

// SAMLProviderConfigIterator is an iterator over SAML provider configurations.
type SAMLProviderConfigIterator struct {
client *providerConfigClient
client *baseClient
ctx context.Context
nextFunc func() error
pageInfo *iterator.PageInfo
Expand Down Expand Up @@ -595,36 +595,8 @@ func (it *SAMLProviderConfigIterator) fetch(pageSize int, pageToken string) (str
return result.NextPageToken, nil
}

type providerConfigClient struct {
endpoint string
projectID string
tenantID string
httpClient *internal.HTTPClient
}

func newProviderConfigClient(client *http.Client, conf *internal.AuthConfig) *providerConfigClient {
hc := internal.WithDefaultRetryConfig(client)
hc.CreateErrFn = handleHTTPError
hc.SuccessFn = internal.HasSuccessStatus
hc.Opts = []internal.HTTPOption{
internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)),
}

return &providerConfigClient{
endpoint: providerConfigEndpoint,
projectID: conf.ProjectID,
httpClient: hc,
}
}

func (c *providerConfigClient) withTenantID(tenantID string) *providerConfigClient {
copy := *c
copy.tenantID = tenantID
return &copy
}

// OIDCProviderConfig returns the OIDCProviderConfig with the given ID.
func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) {
func (c *baseClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) {
if err := validateOIDCConfigID(id); err != nil {
return nil, err
}
Expand All @@ -642,7 +614,7 @@ func (c *providerConfigClient) OIDCProviderConfig(ctx context.Context, id string
}

// CreateOIDCProviderConfig creates a new OIDC provider config from the given parameters.
func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) {
func (c *baseClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) {
if config == nil {
return nil, errors.New("config must not be nil")
}
Expand All @@ -669,7 +641,7 @@ func (c *providerConfigClient) CreateOIDCProviderConfig(ctx context.Context, con
}

// UpdateOIDCProviderConfig updates an existing OIDC provider config with the given parameters.
func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) {
func (c *baseClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) {
if err := validateOIDCConfigID(id); err != nil {
return nil, err
}
Expand Down Expand Up @@ -704,7 +676,7 @@ func (c *providerConfigClient) UpdateOIDCProviderConfig(ctx context.Context, id
}

// DeleteOIDCProviderConfig deletes the OIDCProviderConfig with the given ID.
func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error {
func (c *baseClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error {
if err := validateOIDCConfigID(id); err != nil {
return err
}
Expand All @@ -721,7 +693,7 @@ func (c *providerConfigClient) DeleteOIDCProviderConfig(ctx context.Context, id
//
// If nextPageToken is empty, the iterator will start at the beginning. Otherwise,
// iterator starts after the token.
func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator {
func (c *baseClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator {
it := &OIDCProviderConfigIterator{
ctx: ctx,
client: c,
Expand All @@ -736,7 +708,7 @@ func (c *providerConfigClient) OIDCProviderConfigs(ctx context.Context, nextPage
}

// SAMLProviderConfig returns the SAMLProviderConfig with the given ID.
func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) {
func (c *baseClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) {
if err := validateSAMLConfigID(id); err != nil {
return nil, err
}
Expand All @@ -754,7 +726,7 @@ func (c *providerConfigClient) SAMLProviderConfig(ctx context.Context, id string
}

// CreateSAMLProviderConfig creates a new SAML provider config from the given parameters.
func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) {
func (c *baseClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) {
if config == nil {
return nil, errors.New("config must not be nil")
}
Expand All @@ -781,7 +753,7 @@ func (c *providerConfigClient) CreateSAMLProviderConfig(ctx context.Context, con
}

// UpdateSAMLProviderConfig updates an existing SAML provider config with the given parameters.
func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) {
func (c *baseClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) {
if err := validateSAMLConfigID(id); err != nil {
return nil, err
}
Expand Down Expand Up @@ -816,7 +788,7 @@ func (c *providerConfigClient) UpdateSAMLProviderConfig(ctx context.Context, id
}

// DeleteSAMLProviderConfig deletes the SAMLProviderConfig with the given ID.
func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error {
func (c *baseClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error {
if err := validateSAMLConfigID(id); err != nil {
return err
}
Expand All @@ -833,7 +805,7 @@ func (c *providerConfigClient) DeleteSAMLProviderConfig(ctx context.Context, id
//
// If nextPageToken is empty, the iterator will start at the beginning. Otherwise,
// iterator starts after the token.
func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator {
func (c *baseClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator {
it := &SAMLProviderConfigIterator{
ctx: ctx,
client: c,
Expand All @@ -847,15 +819,17 @@ func (c *providerConfigClient) SAMLProviderConfigs(ctx context.Context, nextPage
return it
}

func (c *providerConfigClient) makeRequest(ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) {
func (c *baseClient) makeRequest(
ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) {

if c.projectID == "" {
return nil, errors.New("project id not available")
}

if c.tenantID != "" {
req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.endpoint, c.projectID, c.tenantID, req.URL)
req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.providerConfigEndpoint, c.projectID, c.tenantID, req.URL)
} else {
req.URL = fmt.Sprintf("%s/projects/%s%s", c.endpoint, c.projectID, req.URL)
req.URL = fmt.Sprintf("%s/projects/%s%s", c.providerConfigEndpoint, c.projectID, req.URL)
}

return c.httpClient.DoAndUnmarshal(ctx, req, v)
Expand Down
Loading