Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 44 additions & 49 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,84 +18,80 @@ type Client struct {
secretId string
secretKey string

token *Token
nextRefresh time.Time

m *sync.RWMutex
stopChan chan struct{}
m *sync.RWMutex
token *Token
}

type Transport struct {
rt http.RoundTripper
cli *Client
}

// refreshTokenIfNeeded refreshes the token if refresh time has passed
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
c.m.Lock()
defer c.m.Unlock()

if time.Now().Before(c.nextRefresh) {
return nil
}

newToken, err := c.refreshToken(ctx, c.token.Refresh)
if err != nil {
return err
}
c.updateToken(newToken)
return nil
}

// updateToken updates the client token and sets the refresh time to half the
// access token lifetime.
func (c *Client) updateToken(t *Token) {
c.token = t
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
}

// StartTokenHandler handles token refreshes in the background
func (c *Client) StartTokenHandler(ctx context.Context) error {
// Initialize the first token
token, err := c.newToken(ctx)
if err != nil {
return errors.New("failed to get initial token: " + err.Error())
return errors.New("getting initial token: " + err.Error())
}

c.m.Lock()
c.updateToken(token)
c.token = token
c.m.Unlock()

go c.tokenRefreshLoop(ctx)
go c.tokenHandler(ctx)
return nil
}

func (c *Client) tokenRefreshLoop(ctx context.Context) {
// tokenHandler gets a new token using the refresh token and a new pair when the
// refresh token expires.
func (c *Client) tokenHandler(ctx context.Context) {
newTokenTimer := time.NewTimer(0) // Start immediately
refreshTokenTimer := time.NewTimer(0) // Start immediately
defer func() {
newTokenTimer.Stop()
refreshTokenTimer.Stop()
}()

resetTimer := func(timer *time.Timer, expiryTime time.Time) {
if !timer.Stop() {
<-timer.C
}
timer.Reset(time.Until(expiryTime))
}

for {
c.m.RLock()
refreshTime := c.nextRefresh
newTokenExpiry := c.token.accessExpires(2)
refreshTokenExpiry := c.token.refreshExpires(2)
c.m.RUnlock()

timeToWait := time.Until(refreshTime)
if timeToWait < 0 {
timeToWait = 0
}
resetTimer(newTokenTimer, newTokenExpiry)
resetTimer(refreshTokenTimer, refreshTokenExpiry)

select {
case <-c.stopChan:
return
case <-time.After(timeToWait):
if err := c.refreshTokenIfNeeded(ctx); err != nil {
panic(fmt.Sprintf("failed to refresh token: %s", err))
}
case <-ctx.Done():
return
case <-newTokenTimer.C:
if token, err := c.newToken(ctx); err != nil {
panic(fmt.Sprintf("getting new token: %s", err))
} else {
c.updateToken(token)
}
case <-refreshTokenTimer.C:
if token, err := c.refreshToken(ctx); err != nil {
panic(fmt.Sprintf("refreshing token: %s", err))
} else {
c.updateToken(token)
}
}
}
}

func (c *Client) StopTokenHandler() {
close(c.stopChan)
// updateToken updates the client's token
func (c *Client) updateToken(t *Token) {
c.m.Lock()
defer c.m.Unlock()
c.token = t
}

func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -122,8 +118,7 @@ func NewClient(secretId, secretKey string) (*Client, error) {
secretId: secretId,
secretKey: secretKey,

m: &sync.RWMutex{},
stopChan: make(chan struct{}),
m: &sync.RWMutex{},
}

// Add transport to handle headers, host and path for all requests
Expand Down
69 changes: 57 additions & 12 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,75 @@
package nordigen

import (
"context"
"net/http"
"os"
"sync"
"testing"
"time"
)

// TestClientTokenRefresh should do a successful token refresh. We force this by
// setting the expiration to a time in the past and then calling any method.
// This test will only run if you have a valid secretId and secretKey in your
// environment.
func TestClientTokenRefresh(t *testing.T) {
id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID")
key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY")
if !id_exists || !key_exists {
var (
sharedClient *Client
initOnce sync.Once
)

func initTestClient(t *testing.T) *Client {
id, idExists := os.LookupEnv("NORDIGEN_SECRET_ID")
key, keyExists := os.LookupEnv("NORDIGEN_SECRET_KEY")
if !idExists || !keyExists {
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
}

c, err := NewClient(id, key)
initOnce.Do(func() {
c := &Client{
c: &http.Client{Timeout: 60 * time.Second},
secretId: id,
secretKey: key,

m: &sync.RWMutex{},
}
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}

// Initialize the first token
token, err := c.newToken(context.Background())
if err != nil {
t.Fatalf("newToken: %s", err)
}

c.token = token
sharedClient = c
})

return sharedClient
}

func TestAccessRefresh(t *testing.T) {
c := initTestClient(t)

// Expire token immediately
c.token.AccessExpires = 0

ctx, cancel := context.WithCancel(context.Background())
go c.tokenHandler(ctx)
_, err := c.ListRequisitions()
if err != nil {
t.Fatalf("NewClient: %s", err)
t.Fatalf("ListRequisitions: %s", err)
}
cancel() // Stop handler again
}

func TestRefreshRefresh(t *testing.T) {
c := initTestClient(t)

// Expire token immediately
c.token.RefreshExpires = 0

c.nextRefresh = time.Now().Add(-time.Hour)
_, err = c.ListRequisitions()
ctx, cancel := context.WithCancel(context.Background())
go c.tokenHandler(ctx)
_, err := c.ListRequisitions()
if err != nil {
t.Fatalf("ListRequisitions: %s", err)
}
cancel() // Stop handler again
}
17 changes: 14 additions & 3 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"strings"
"time"
)

type Token struct {
Expand All @@ -30,7 +31,17 @@ const tokenPath = "token"
const tokenNewPath = "new/"
const tokenRefreshPath = "refresh/"

func (c Client) newToken(ctx context.Context) (*Token, error) {
// accessExpires returns the time when access token expires divided by divisor
func (t *Token) accessExpires(divisor int) time.Time {
return time.Now().Add(time.Second * time.Duration(t.AccessExpires/divisor))
}

// refreshExpires returns the time when refresh token expires divided by divisor
func (t *Token) refreshExpires(divisor int) time.Time {
return time.Now().Add(time.Second * time.Duration(t.RefreshExpires/divisor))
}

func (c *Client) newToken(ctx context.Context) (*Token, error) {
data, err := json.Marshal(Secret{
SecretId: c.secretId,
AccessId: c.secretKey,
Expand Down Expand Up @@ -69,8 +80,8 @@ func (c Client) newToken(ctx context.Context) (*Token, error) {
return t, nil
}

func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
data, err := json.Marshal(TokenRefresh{Refresh: c.token.Refresh})
if err != nil {
return nil, err
}
Expand Down
Loading