Skip to content

Commit 9efb7f0

Browse files
m-Bilalhgiasac
andauthored
return only the credentials property from the credential provider response (#185)
Co-authored-by: Toan Nguyen <hgiasac@gmail.com>
1 parent 4033419 commit 9efb7f0

3 files changed

Lines changed: 54 additions & 10 deletions

File tree

credentials/credentials_provider.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package credentials
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
7-
"io"
88
"net/http"
99
"net/url"
1010
"os"
@@ -17,7 +17,10 @@ import (
1717
"go.opentelemetry.io/otel/trace"
1818
)
1919

20-
var errAuthWebhookUriRequired = errors.New("the env var HASURA_CREDENTIALS_PROVIDER_URI must be set and non-empty")
20+
var (
21+
errAuthWebhookUriRequired = errors.New("the env var HASURA_CREDENTIALS_PROVIDER_URI must be set and non-empty")
22+
errEmptyCredentials = errors.New("empty credentials")
23+
)
2124

2225
var defaultClient = CredentialClient{
2326
httpClient: http.DefaultClient,
@@ -33,6 +36,11 @@ func AcquireCredentials(ctx context.Context, key string, forceRefresh bool) (str
3336
return defaultClient.AcquireCredentials(ctx, key, forceRefresh)
3437
}
3538

39+
// Payload is the credentials provider webhook response payload.
40+
type Payload struct {
41+
Credentials string `json:"credentials"`
42+
}
43+
3644
// CredentialClient is an HTTP client that can requests the credentials provider webhook to get the credentials.
3745
type CredentialClient struct {
3846
providerUri *url.URL
@@ -128,15 +136,21 @@ func (cc *CredentialClient) AcquireCredentials(ctx context.Context, key string,
128136

129137
span.SetAttributes(attribute.Int("http.response.status_code", resp.StatusCode))
130138

131-
body, err := io.ReadAll(resp.Body)
139+
var payload Payload
140+
err = json.NewDecoder(resp.Body).Decode(&payload)
132141
if err != nil {
133142
span.SetStatus(codes.Error, "failed to read the response")
134143
span.RecordError(err)
135144

136145
return "", fmt.Errorf("error reading response: %w", err)
137146
}
138147

139-
span.SetAttributes(attribute.Int64("http.response.size", int64(len(body))))
148+
if payload.Credentials == "" {
149+
span.SetStatus(codes.Error, errEmptyCredentials.Error())
150+
span.RecordError(err)
151+
152+
return "", errEmptyCredentials
153+
}
140154

141-
return string(body), nil
155+
return payload.Credentials, nil
142156
}

credentials/credentials_provider_test.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ import (
55
"fmt"
66
"net/http"
77
"net/http/httptest"
8+
"net/url"
89
"os"
910
"testing"
11+
12+
"go.opentelemetry.io/otel"
1013
)
1114

1215
func TestAcquireCredentials(t *testing.T) {
@@ -36,26 +39,28 @@ func TestAcquireCredentials(t *testing.T) {
3639
t.Errorf("expected Authorization=Bearer %s; got %s", bearerToken, r.Header.Get("Authorization"))
3740
}
3841

39-
fmt.Fprint(w, "credentials")
42+
fmt.Fprint(w, "{ \"credentials\": \"api-key\" }")
4043
}))
4144

4245
defer server.Close()
4346

4447
// Set the environment variable
4548
os.Setenv("HASURA_CREDENTIALS_PROVIDER_URI", server.URL)
4649
os.Setenv("HASURA_CREDENTIALS_PROVIDER_BEARER_TOKEN", bearerToken)
50+
defer os.Unsetenv("HASURA_CREDENTIALS_PROVIDER_URI")
51+
defer os.Unsetenv("HASURA_CREDENTIALS_PROVIDER_BEARER_TOKEN")
4752

4853
credentials, err := AcquireCredentials(context.TODO(), "key", false)
4954
if err != nil {
5055
t.Errorf("unexpected error: %v", err)
5156
}
5257

53-
if credentials != "credentials" {
54-
t.Errorf("expected credentials to be 'credentials', got '%s'", credentials)
58+
if credentials != "api-key" {
59+
t.Errorf("expected credentials to be 'api-key', got '%s'", credentials)
5560
}
5661
})
5762

58-
t.Run("when the request fails", func(t *testing.T) {
63+
t.Run("when the server does not exist", func(t *testing.T) {
5964
os.Setenv("HASURA_CREDENTIALS_PROVIDER_URI", "http://localhost:0000")
6065

6166
_, err := AcquireCredentials(context.TODO(), "key", false)
@@ -64,4 +69,28 @@ func TestAcquireCredentials(t *testing.T) {
6469
}
6570
})
6671
})
72+
73+
t.Run("when the response does not have credentials", func(t *testing.T) {
74+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
75+
fmt.Fprint(w, "{}")
76+
}))
77+
78+
defer server.Close()
79+
80+
serverUri, err := url.Parse(server.URL)
81+
if err != nil {
82+
t.Fatalf("unexpected error: %v", err)
83+
}
84+
85+
client := &CredentialClient{
86+
providerUri: serverUri,
87+
httpClient: server.Client(),
88+
propagator: otel.GetTextMapPropagator(),
89+
}
90+
91+
_, err = client.AcquireCredentials(context.TODO(), "key", false)
92+
if err != errEmptyCredentials {
93+
t.Errorf("expected an empty credentails error, got: %s\n", err)
94+
}
95+
})
6796
}

utils/decode_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package utils
22

33
import (
44
"fmt"
5+
"math"
56
"reflect"
67
"testing"
78
"time"
@@ -217,7 +218,7 @@ func TestDecodeDateTime(t *testing.T) {
217218
iNow := float64(now.UnixNano()) / float64(1000)
218219
value, err := DecodeDateTime(iNow, WithBaseUnix(time.Microsecond))
219220
assert.NilError(t, err)
220-
assert.Equal(t, int64(now.UnixNano()/1000), int64(value.UnixNano()/1000))
221+
assert.Assert(t, math.Abs(float64(int64(now.UnixNano()/1000)-int64(value.UnixNano()/1000))) <= 1)
221222
})
222223

223224
t.Run("from_string", func(t *testing.T) {

0 commit comments

Comments
 (0)