diff --git a/pkg/config/config.go b/pkg/config/config.go index 758b209..f0699b6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -160,7 +160,7 @@ func SetUser(user string) error { func MtlsEnabled() bool { authMethod := viper.GetString("authentication_method") - return authMethod == "mtls" + return authMethod == "mtls" && viper.GetBool("mtls_settings.enabled") } // BaseWebURL allows the ConsoleMe URL to be overridden for cases where the API diff --git a/pkg/creds/consoleme.go b/pkg/creds/consoleme.go index 48ec051..9354574 100644 --- a/pkg/creds/consoleme.go +++ b/pkg/creds/consoleme.go @@ -48,9 +48,8 @@ import ( var clientVersion = fmt.Sprintf("%s", metadata.Version) var userAgent = "weep/" + clientVersion + " Go-http-client/1.1" - -type Account struct { -} +var clientFactoryOverride ClientFactory +var preflightFunctions = make([]RequestPreflight, 0) // HTTPClient is the interface we expect HTTP clients to implement. type HTTPClient interface { @@ -67,13 +66,39 @@ type Client struct { Region string } +type ClientFactory func() (*http.Client, error) + +// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client +// creation with a ClientFactory. This function will be called during the creation +// of all ConsoleMe clients. +func RegisterClientFactory(factory ClientFactory) { + clientFactoryOverride = factory +} + +type RequestPreflight func(req *http.Request) error + +// RegisterRequestPreflight adds a RequestPreflight function which will be called in the +// order of registration during the creation of a ConsoleMe request. +func RegisterRequestPreflight(preflight RequestPreflight) { + preflightFunctions = append(preflightFunctions, preflight) +} + // GetClient creates an authenticated ConsoleMe client func GetClient(region string) (*Client, error) { var client *Client consoleMeUrl := viper.GetString("consoleme_url") authenticationMethod := viper.GetString("authentication_method") - if authenticationMethod == "mtls" { + if clientFactoryOverride != nil { + customClient, err := clientFactoryOverride() + if err != nil { + return client, err + } + client, err = NewClient(consoleMeUrl, "", customClient) + if err != nil { + return client, err + } + } else if authenticationMethod == "mtls" { mtlsClient, err := mtls.NewHTTPClient() if err != nil { return client, err @@ -122,6 +147,18 @@ func NewClient(hostname string, region string, httpc *http.Client) (*Client, err return c, nil } +func runPreflightFunctions(req *http.Request) error { + var err error + if preflightFunctions != nil { + for _, preflight := range preflightFunctions { + if err = preflight(req); err != nil { + return err + } + } + } + return nil +} + func (c *Client) buildRequest(method string, resource string, body io.Reader, apiPrefix string) (*http.Request, error) { urlStr := c.Host + apiPrefix + resource req, err := http.NewRequest(method, urlStr, body) @@ -130,6 +167,10 @@ func (c *Client) buildRequest(method string, resource string, body io.Reader, ap } req.Header.Set("User-Agent", userAgent) req.Header.Add("Content-Type", "application/json") + err = runPreflightFunctions(req) + if err != nil { + return nil, err + } return req, nil }