Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Show error if ELASTICSEARCH_URL is not set when using credentials provider ([#68](https://github.com/hasura/ndc-elasticsearch/pull/68))

## [1.5.0]

- Add support for a credentials provider service ([#65](https://github.com/hasura/ndc-elasticsearch/pull/65))
Expand Down
64 changes: 44 additions & 20 deletions elasticsearch/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,23 @@ var (
credentailsProviderKeyEnvVar = "ELASTICSEARCH_CREDENTIALS_PROVIDER_KEY"
credentailsProviderMechanismEnvVar = "ELASTICSEARCH_CREDENTIALS_PROVIDER_MECHANISM"
credentialsProviderUri = "HASURA_CREDENTIALS_PROVIDER_URI"
elasticsearchUrl = "ELASTICSEARCH_URL"
)

var (
errCredentialProviderKeyNotSet = fmt.Errorf("%s is not set", credentailsProviderKeyEnvVar)
errCredentialProviderMechanismNotSet = fmt.Errorf("%s is not set", credentailsProviderMechanismEnvVar)
errCredentialProviderMechanismInvalid = fmt.Errorf("invalid value for %s, should be either \"api-key\" or \"service-token\"", credentailsProviderMechanismEnvVar)
errElasticsearchUrlNotSet = fmt.Errorf("%s is not set", elasticsearchUrl)
)

// getConfigFromEnv retrieves elastic search configuration from environment variables.
func getConfigFromEnv() (*elasticsearch.Config, error) {
esConfig := elasticsearch.Config{}

// Read the address
address := os.Getenv("ELASTICSEARCH_URL")
if address == "" {
return nil, errors.New("ELASTICSEARCH_URL is not set")
esConfig, err := getBaseConfig()
if err != nil {
return nil, err
}

// Split the address by comma
addresses := make([]string, 0)
addresses = append(addresses, strings.Split(address, ",")...)
esConfig.Addresses = addresses

// Read the credentials if provided
username := os.Getenv("ELASTICSEARCH_USERNAME")
password := os.Getenv("ELASTICSEARCH_PASSWORD")
Expand All @@ -65,42 +59,52 @@ func getConfigFromEnv() (*elasticsearch.Config, error) {
esConfig.CACert = cert
}

return &esConfig, nil
return esConfig, nil
}

func shouldUseCredentialsProvider() bool {
return os.Getenv(credentialsProviderUri) != ""
}

func getConfigFromCredentialsProvider(ctx context.Context, forceRefresh bool) (*elasticsearch.Config, error) {
esConfig, err := getBaseConfig()
if err != nil {
return nil, err
}

key := os.Getenv(credentailsProviderKeyEnvVar)
mechanism := os.Getenv(credentailsProviderMechanismEnvVar)
return useCredentialsProvider(ctx, key, mechanism, forceRefresh)
err = setupCredentailsUsingCredentialsProvider(ctx, esConfig, key, mechanism, forceRefresh)
if err != nil {
return nil, err
}
return esConfig, nil
}

func useCredentialsProvider(ctx context.Context, key string, mechanism string, forceRefresh bool) (*elasticsearch.Config, error) {
// setupCredentailsUsingCredentialsProvider sets up the credentials in the elasticsearch config.
// It returns the updated config.
func setupCredentailsUsingCredentialsProvider(ctx context.Context, esConfig *elasticsearch.Config, key string, mechanism string, forceRefresh bool) error {
if key == "" {
return nil, errCredentialProviderKeyNotSet
return errCredentialProviderKeyNotSet
}
if mechanism == "" {
return nil, errCredentialProviderMechanismNotSet
return errCredentialProviderMechanismNotSet
}
if mechanism != "api-key" && mechanism != "service-token" {
return nil, errCredentialProviderMechanismInvalid
return errCredentialProviderMechanismInvalid
}

credential, err := credentials.AcquireCredentials(ctx, key, forceRefresh)
if err != nil {
return nil, err
return err
}

esConfig := elasticsearch.Config{}
if mechanism == "api-key" {
esConfig.APIKey = credential
} else {
esConfig.ServiceToken = credential
}
return &esConfig, nil
return nil
}

func GetDefaultResultSize() int {
Expand All @@ -116,3 +120,23 @@ func GetDefaultResultSize() int {

return size
}

// getBaseConfig returns a new elasticsearch client with only the address set.
// This function should be used to setup the config with properties
// that will be common across all configs (credentials provieder based configs or env based configs).
func getBaseConfig() (*elasticsearch.Config, error) {
esConfig := elasticsearch.Config{}

// Read the address
address := os.Getenv(elasticsearchUrl)
if address == "" {
return nil, errElasticsearchUrlNotSet
}

// Split the address by comma
addresses := make([]string, 0)
addresses = append(addresses, strings.Split(address, ",")...)
esConfig.Addresses = addresses

return &esConfig, nil
}
13 changes: 9 additions & 4 deletions elasticsearch/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,31 @@ import (
"context"
"errors"
"testing"

"github.com/elastic/go-elasticsearch/v8"
)

func TestUseCredentailsProvider(t *testing.T) {
func TestSetupCredentailsUsingCredentialsProvider(t *testing.T) {
t.Run("should return error if key is not set", func(t *testing.T) {
_, err := useCredentialsProvider(context.Background(), "", "", false)
esConfig := elasticsearch.Config{}
err := setupCredentailsUsingCredentialsProvider(context.Background(), &esConfig, "", "", false)
if !errors.Is(err, errCredentialProviderKeyNotSet) {
t.Errorf("expected error to be errCredentialProviderKeyNotSet, got %v", err)
}
})

t.Run("key is set", func(t *testing.T) {
t.Run("should return error if mechanism is not set", func(t *testing.T) {
_, err := useCredentialsProvider(context.Background(), "key", "", false)
esConfig := elasticsearch.Config{}
err := setupCredentailsUsingCredentialsProvider(context.Background(), &esConfig, "key", "", false)
if !errors.Is(err, errCredentialProviderMechanismNotSet) {
t.Errorf("expected error to be errCredentialProviderMechanismNotSet, got %v", err)
}
})

t.Run("should return error if mechanism is invalid", func(t *testing.T) {
_, err := useCredentialsProvider(context.Background(), "key", "invalid", false)
esConfig := elasticsearch.Config{}
err := setupCredentailsUsingCredentialsProvider(context.Background(), &esConfig, "key", "invalid", false)
if !errors.Is(err, errCredentialProviderMechanismInvalid) {
t.Errorf("expected error to be errCredentialProviderMechanismInvalid, got %v", err)
}
Expand Down