Skip to content

Commit

Permalink
Refactor HTTP client creation to be more generic (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Jan 14, 2022
1 parent cd9c996 commit e2aec9a
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 80 deletions.
2 changes: 1 addition & 1 deletion cmd/credential_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func generateCredentialProcessConfig(destination string) error {
if destination == "" {
return fmt.Errorf("no destination provided")
}
client, err := creds.GetClient(region)
client, err := creds.GetClient()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func preInteractiveCheck(region string, client *creds.Client) (*creds.Client, er
// If a client was not provided, create one using the provided region
if client == nil {
var err error
client, err = creds.GetClient(region)
client, err = creds.GetClient()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ var listCmd = &cobra.Command{
}

func roleList() (string, error) {
client, err := creds.GetClient(region)
client, err := creds.GetClient()
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/open.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func runOpen(cmd *cobra.Command, args []string) error {
return errors.New("Resource type sns and sqs require region in the arn")
}
var resourceURL string
client, err := creds.GetClient(region)
client, err := creds.GetClient()
if err != nil {
logging.LogError(err, "Error getting client")
return err
Expand Down
82 changes: 10 additions & 72 deletions pkg/creds/consoleme.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ import (
"strings"
"time"

"github.com/netflix/weep/pkg/httpAuth"
"github.com/netflix/weep/pkg/httpAuth/custom"

"github.com/netflix/weep/pkg/util"

"github.com/netflix/weep/pkg/aws"
"github.com/netflix/weep/pkg/config"
werrors "github.com/netflix/weep/pkg/errors"
"github.com/netflix/weep/pkg/httpAuth/challenge"
"github.com/netflix/weep/pkg/httpAuth/mtls"
"github.com/netflix/weep/pkg/logging"
"github.com/netflix/weep/pkg/metadata"

Expand All @@ -48,8 +50,6 @@ import (
var clientVersion = fmt.Sprintf("%s", metadata.Version)

var userAgent = "weep/" + clientVersion + " Go-http-client/1.1"
var clientFactoryOverride ClientFactory
var preflightFunctions = make([]RequestPreflight, 0)

// HTTPClient is the interface we expect HTTP clients to implement.
type HTTPClient interface {
Expand All @@ -66,65 +66,15 @@ type Client struct {
Region string
}

type ClientFactory func() (*http.Client, error)

// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client
// creation with a ClientFactory. This function will be called during the creation
// of all ConsoleMe clients.
func RegisterClientFactory(factory ClientFactory) {
clientFactoryOverride = factory
}

type RequestPreflight func(req *http.Request) error

// RegisterRequestPreflight adds a RequestPreflight function which will be called in the
// order of registration during the creation of a ConsoleMe request.
func RegisterRequestPreflight(preflight RequestPreflight) {
preflightFunctions = append(preflightFunctions, preflight)
}

// GetClient creates an authenticated ConsoleMe client
func GetClient(region string) (*Client, error) {
func GetClient() (*Client, error) {
var client *Client
consoleMeUrl := viper.GetString("consoleme_url")
authenticationMethod := viper.GetString("authentication_method")

if clientFactoryOverride != nil {
customClient, err := clientFactoryOverride()
if err != nil {
return client, err
}
client, err = NewClient(consoleMeUrl, "", customClient)
if err != nil {
return client, err
}
} else if authenticationMethod == "mtls" {
mtlsClient, err := mtls.NewHTTPClient()
if err != nil {
return client, err
}
client, err = NewClient(consoleMeUrl, "", mtlsClient)
if err != nil {
return client, err
}
} else if authenticationMethod == "challenge" {
err := challenge.RefreshChallenge()
if err != nil {
return client, err
}
httpClient, err := challenge.NewHTTPClient(consoleMeUrl)
if err != nil {
return client, err
}
client, err = NewClient(consoleMeUrl, "", httpClient)
if err != nil {
return client, err
}
} else {
return nil, fmt.Errorf("Authentication method unsupported or not provided.")
httpClient, err := httpAuth.GetAuthenticatedClient()
if err != nil {
return client, err
}

return client, nil
return NewClient(consoleMeUrl, "", httpClient)
}

// NewClient takes a ConsoleMe hostname and *http.Client, and returns a
Expand All @@ -147,18 +97,6 @@ func NewClient(hostname string, region string, httpc *http.Client) (*Client, err
return c, nil
}

func runPreflightFunctions(req *http.Request) error {
var err error
if preflightFunctions != nil {
for _, preflight := range preflightFunctions {
if err = preflight(req); err != nil {
return err
}
}
}
return nil
}

func (c *Client) buildRequest(method string, resource string, body io.Reader, apiPrefix string) (*http.Request, error) {
urlStr := c.Host + apiPrefix + resource
req, err := http.NewRequest(method, urlStr, body)
Expand All @@ -167,7 +105,7 @@ func (c *Client) buildRequest(method string, resource string, body io.Reader, ap
}
req.Header.Set("User-Agent", userAgent)
req.Header.Add("Content-Type", "application/json")
err = runPreflightFunctions(req)
err = custom.RunPreflightFunctions(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -579,7 +517,7 @@ func GetCredentialsC(client HTTPClient, role string, ipRestrict bool, assumeRole
// GetCredentials requests credentials from ConsoleMe then follows the provided chain of roles to
// assume. Roles are assumed in the order in which they appear in the assumeRole slice.
func GetCredentials(role string, ipRestrict bool, assumeRole []string, region string) (*aws.Credentials, error) {
client, err := GetClient(region)
client, err := GetClient()
if err != nil {
return nil, err
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/httpAuth/custom/custom.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package custom

import "net/http"

var isOverridden bool
var clientFactoryOverride ClientFactory
var preflightFunctions = make([]RequestPreflight, 0)

type ClientFactory func() (*http.Client, error)

func UseCustom() bool {
return isOverridden
}

func NewHTTPClient() (*http.Client, error) {
return clientFactoryOverride()
}

// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client
// creation with a ClientFactory. This function will be called during the creation
// of all ConsoleMe clients.
func RegisterClientFactory(factory ClientFactory) {
clientFactoryOverride = factory
isOverridden = true
}

type RequestPreflight func(req *http.Request) error

// RegisterRequestPreflight adds a RequestPreflight function which will be called in the
// order of registration during the creation of a ConsoleMe request.
func RegisterRequestPreflight(preflight RequestPreflight) {
preflightFunctions = append(preflightFunctions, preflight)
}

func RunPreflightFunctions(req *http.Request) error {
var err error
if preflightFunctions != nil {
for _, preflight := range preflightFunctions {
if err = preflight(req); err != nil {
return err
}
}
}
return nil
}
28 changes: 28 additions & 0 deletions pkg/httpAuth/httpAuth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package httpAuth

import (
"fmt"
"net/http"

"github.com/netflix/weep/pkg/httpAuth/challenge"
"github.com/netflix/weep/pkg/httpAuth/custom"
"github.com/netflix/weep/pkg/httpAuth/mtls"
"github.com/spf13/viper"
)

func GetAuthenticatedClient() (*http.Client, error) {
authenticationMethod := viper.GetString("authentication_method")
consoleMeUrl := viper.GetString("consoleme_url")
if custom.UseCustom() {
return custom.NewHTTPClient()
} else if authenticationMethod == "mtls" {
return mtls.NewHTTPClient()
} else if authenticationMethod == "challenge" {
err := challenge.RefreshChallenge()
if err != nil {
return nil, err
}
return challenge.NewHTTPClient(consoleMeUrl)
}
return nil, fmt.Errorf("Authentication method unsupported or not provided.")
}
2 changes: 1 addition & 1 deletion pkg/server/ecsCredentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func parseAssumeRoleQuery(r *http.Request) ([]string, error) {

func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
var client, err = creds.GetClient(region)
var client, err = creds.GetClient()
if err != nil {
logging.Log.Error(err)
util.WriteError(w, err.Error(), http.StatusBadRequest)
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er

if isServingIMDS {
logging.Log.Infof("Configuring weep IMDS service for role %s", role)
client, err := creds.GetClient(region)
client, err := creds.GetClient()
if err != nil {
return err
}
Expand Down
9 changes: 7 additions & 2 deletions pkg/swag/swag.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"fmt"
"net/http"

"github.com/netflix/weep/pkg/httpAuth/mtls"
"github.com/netflix/weep/pkg/creds"

"github.com/spf13/viper"
)

Expand All @@ -15,7 +16,11 @@ type SwagResponse struct {

func getClient() (*http.Client, error) {
if viper.GetBool("swag.use_mtls") {
return mtls.NewHTTPClient()
consoleMeClient, err := creds.GetClient()
if err != nil {
return nil, err
}
return &consoleMeClient.Client, nil
}
return http.DefaultClient, nil
}
Expand Down

0 comments on commit e2aec9a

Please sign in to comment.