diff --git a/internal/http/client.go b/internal/http/client.go index e3daf0664..32705c64d 100644 --- a/internal/http/client.go +++ b/internal/http/client.go @@ -48,6 +48,9 @@ type Client struct { errorValue ErrorResponse logger logging.Logger + + // customHeaders stores headers to be applied to all requests + customHeaders map[string]string } // NewClient is used to create a new instance of Client. @@ -104,11 +107,16 @@ func NewClient(cfg config.Config) Client { } client := Client{ - authStrategy: &ClassicV2Authorizer{}, - client: r, - config: cfg, - errorValue: &DefaultErrorResponse{}, - logger: logger, + authStrategy: &ClassicV2Authorizer{}, + client: r, + config: cfg, + errorValue: &DefaultErrorResponse{}, + logger: logger, + customHeaders: make(map[string]string), + } + + if cfg.CustomHeaders != nil { + client.SetCustomHeaders(cfg.CustomHeaders) } switch cfg.Compression { @@ -139,6 +147,18 @@ func (c *Client) SetErrorValue(v ErrorResponse) *Client { return c } +// SetCustomHeaders is used to set custom headers at the client level +func (c *Client) SetCustomHeaders(headers map[string]string) { + if c.customHeaders == nil { + c.customHeaders = make(map[string]string) + } + + // Merge the new headers with existing ones + for k, v := range headers { + c.customHeaders[k] = v + } +} + // Get represents an HTTP GET request to a New Relic API. // The queryParams argument can be used to add query string parameters to the requested URL. // The respBody argument will be unmarshaled from JSON in the response body to the type provided. @@ -197,14 +217,15 @@ func (c *Client) PostWithContext( reqBody interface{}, respBody interface{}, ) (*http.Response, error) { - req, err := c.NewRequest(http.MethodPost, url, queryParams, reqBody, respBody) - if err != nil { - return nil, err - } - - req.WithContext(ctx) - - return c.Do(req) + //req, err := c.NewRequest(http.MethodPost, url, queryParams, reqBody, respBody) + //if err != nil { + // return nil, err + //} + // + //req.WithContext(ctx) + // + //return c.Do(req) + return c.PostWithContextAndHeaders(ctx, url, queryParams, reqBody, respBody, nil) } // Put represents an HTTP PUT request to a New Relic API. @@ -544,3 +565,61 @@ func (c *Client) NewNerdGraphRequest(query string, vars map[string]interface{}, return req, nil } + +func (c *Client) PostWithHeaders( + url string, + queryParams interface{}, + reqBody interface{}, + respBody interface{}, + customHeaders map[string]string, +) (*http.Response, error) { + return c.PostWithContextAndHeaders(context.Background(), url, queryParams, reqBody, respBody, customHeaders) +} + +// new methods for custom headers + +func (c *Client) NerdGraphQueryWithHeaders(query string, vars map[string]interface{}, respBody interface{}, customHeaders map[string]string) error { + return c.NerdGraphQueryWithContextAndHeaders(context.Background(), query, vars, respBody, customHeaders) +} + +func (c *Client) NerdGraphQueryWithContextAndHeaders(ctx context.Context, query string, vars map[string]interface{}, respBody interface{}, customHeaders map[string]string) error { + req, err := c.NewNerdGraphRequest(query, vars, respBody) + if err != nil { + return err + } + + if customHeaders != nil { + req.SetCustomHeaders(customHeaders) + } + + req.WithContext(ctx) + + _, err = c.Do(req) + if err != nil { + return err + } + + return nil +} + +func (c *Client) PostWithContextAndHeaders( + ctx context.Context, + url string, + queryParams interface{}, + reqBody interface{}, + respBody interface{}, + customHeaders map[string]string, +) (*http.Response, error) { + req, err := c.NewRequest(http.MethodPost, url, queryParams, reqBody, respBody) + if err != nil { + return nil, err + } + + if customHeaders != nil { + req.SetCustomHeaders(customHeaders) + } + + req.WithContext(ctx) + + return c.Do(req) +} diff --git a/internal/http/request.go b/internal/http/request.go index 6e40a9795..7d0b4a662 100644 --- a/internal/http/request.go +++ b/internal/http/request.go @@ -15,15 +15,16 @@ import ( // Request represents a configurable HTTP request. type Request struct { - method string - url string - params interface{} - reqBody interface{} - value interface{} - config config.Config - authStrategy RequestAuthorizer - errorValue ErrorResponse - request *retryablehttp.Request + method string + url string + params interface{} + reqBody interface{} + value interface{} + config config.Config + authStrategy RequestAuthorizer + errorValue ErrorResponse + request *retryablehttp.Request + customHeaders map[string]string } // NewRequest creates a new Request struct. @@ -35,13 +36,14 @@ func (c *Client) NewRequest(method string, url string, params interface{}, reqBo ) req := &Request{ - method: method, - url: url, - params: params, - reqBody: reqBody, - value: value, - authStrategy: c.authStrategy, - errorValue: c.errorValue, + method: method, + url: url, + params: params, + reqBody: reqBody, + value: value, + authStrategy: c.authStrategy, + errorValue: c.errorValue, + customHeaders: make(map[string]string), } // FIXME: We should remove this requirement on the request @@ -119,6 +121,14 @@ func (r *Request) SetErrorValue(e ErrorResponse) { r.errorValue = e } +// SetCustomHeaders is used to the Request struct to set custom headers for a specific request +func (r *Request) SetCustomHeaders(headers map[string]string) *Request { + for k, v := range headers { + r.SetHeader(k, v) + } + return r +} + // SetServiceName sets the service name for the request. func (r *Request) SetServiceName(serviceName string) { serviceName = fmt.Sprintf("%s|%s", serviceName, defaultServiceName) @@ -139,6 +149,13 @@ func (r *Request) makeRequest() (*retryablehttp.Request, error) { return nil, err } + // Apply client-level custom headers if available + if r.config.CustomHeaders != nil { + for key, value := range r.config.CustomHeaders { + r.request.Header.Set(key, value) + } + } + return r.request, nil } diff --git a/newrelic/newrelic.go b/newrelic/newrelic.go index e165f7515..70580ce73 100644 --- a/newrelic/newrelic.go +++ b/newrelic/newrelic.go @@ -241,3 +241,8 @@ func ConfigLogJSON(logJSON bool) ConfigOption { func ConfigLogger(logger logging.Logger) ConfigOption { return config.ConfigLogger(logger) } + +// ConfigCustomHeaders sets custom headers that will be sent with every request. +func ConfigCustomHeaders(headers map[string]string) ConfigOption { + return config.ConfigCustomHeaders(headers) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index c741ece6e..b20429f38 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -58,6 +58,9 @@ type Config struct { // Logger allows customization of the client's underlying logger. Logger logging.Logger + + // CustomHeaders stores headers to be applied to all requests + CustomHeaders map[string]string } // New creates a default configuration and returns it @@ -65,9 +68,10 @@ func New() Config { reg, _ := region.Get(region.Default) return Config{ - region: reg, - LogLevel: "info", - Compression: Compression.None, + region: reg, + LogLevel: "info", + Compression: Compression.None, + CustomHeaders: make(map[string]string), } } @@ -142,3 +146,19 @@ func (c *Config) GetLogger() logging.Logger { return l } + +// ConfigCustomHeaders sets the custom headers to be sent with each request +func ConfigCustomHeaders(headers map[string]string) ConfigOption { + return func(cfg *Config) error { + if cfg.CustomHeaders == nil { + cfg.CustomHeaders = make(map[string]string) + } + + // Merge the provided headers with existing ones + for k, v := range headers { + cfg.CustomHeaders[k] = v + } + + return nil + } +} diff --git a/pkg/dashboards/types.go b/pkg/dashboards/types.go index 7da90f527..692c8a02e 100644 --- a/pkg/dashboards/types.go +++ b/pkg/dashboards/types.go @@ -572,7 +572,9 @@ type DashboardWidgetLayoutInput struct { // DashboardWidgetNRQLQueryInput - NRQL query used by a widget. type DashboardWidgetNRQLQueryInput struct { // New Relic account ID to issue the query against. - AccountID int `json:"accountId"` + AccountID int `json:"accountId,omitempty"` + // New Relic account IDs to issue the query against. + AccountIDS []int `json:"accountIds,omitempty"` // NRQL formatted query. Query nrdb.NRQL `json:"query"` }