Skip to content

Commit 34afc3b

Browse files
authored
Merge pull request #27 from martinohansen/main
fix: refresh this field may not be blank
2 parents 3683fd6 + e7584e1 commit 34afc3b

File tree

3 files changed

+115
-64
lines changed

3 files changed

+115
-64
lines changed

client.go

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,84 +18,80 @@ type Client struct {
1818
secretId string
1919
secretKey string
2020

21-
token *Token
22-
nextRefresh time.Time
23-
24-
m *sync.RWMutex
25-
stopChan chan struct{}
21+
m *sync.RWMutex
22+
token *Token
2623
}
2724

2825
type Transport struct {
2926
rt http.RoundTripper
3027
cli *Client
3128
}
3229

33-
// refreshTokenIfNeeded refreshes the token if refresh time has passed
34-
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
35-
c.m.Lock()
36-
defer c.m.Unlock()
37-
38-
if time.Now().Before(c.nextRefresh) {
39-
return nil
40-
}
41-
42-
newToken, err := c.refreshToken(ctx, c.token.Refresh)
43-
if err != nil {
44-
return err
45-
}
46-
c.updateToken(newToken)
47-
return nil
48-
}
49-
50-
// updateToken updates the client token and sets the refresh time to half the
51-
// access token lifetime.
52-
func (c *Client) updateToken(t *Token) {
53-
c.token = t
54-
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
55-
}
56-
5730
// StartTokenHandler handles token refreshes in the background
5831
func (c *Client) StartTokenHandler(ctx context.Context) error {
5932
// Initialize the first token
6033
token, err := c.newToken(ctx)
6134
if err != nil {
62-
return errors.New("failed to get initial token: " + err.Error())
35+
return errors.New("getting initial token: " + err.Error())
6336
}
64-
6537
c.m.Lock()
66-
c.updateToken(token)
38+
c.token = token
6739
c.m.Unlock()
6840

69-
go c.tokenRefreshLoop(ctx)
41+
go c.tokenHandler(ctx)
7042
return nil
7143
}
7244

73-
func (c *Client) tokenRefreshLoop(ctx context.Context) {
45+
// tokenHandler gets a new token using the refresh token and a new pair when the
46+
// refresh token expires.
47+
func (c *Client) tokenHandler(ctx context.Context) {
48+
newTokenTimer := time.NewTimer(0) // Start immediately
49+
refreshTokenTimer := time.NewTimer(0) // Start immediately
50+
defer func() {
51+
newTokenTimer.Stop()
52+
refreshTokenTimer.Stop()
53+
}()
54+
55+
resetTimer := func(timer *time.Timer, expiryTime time.Time) {
56+
if !timer.Stop() {
57+
<-timer.C
58+
}
59+
timer.Reset(time.Until(expiryTime))
60+
}
61+
7462
for {
7563
c.m.RLock()
76-
refreshTime := c.nextRefresh
64+
newTokenExpiry := c.token.accessExpires(2)
65+
refreshTokenExpiry := c.token.refreshExpires(2)
7766
c.m.RUnlock()
7867

79-
timeToWait := time.Until(refreshTime)
80-
if timeToWait < 0 {
81-
timeToWait = 0
82-
}
68+
resetTimer(newTokenTimer, newTokenExpiry)
69+
resetTimer(refreshTokenTimer, refreshTokenExpiry)
8370

8471
select {
85-
case <-c.stopChan:
86-
return
87-
case <-time.After(timeToWait):
88-
if err := c.refreshTokenIfNeeded(ctx); err != nil {
89-
panic(fmt.Sprintf("failed to refresh token: %s", err))
90-
}
9172
case <-ctx.Done():
9273
return
74+
case <-newTokenTimer.C:
75+
if token, err := c.newToken(ctx); err != nil {
76+
panic(fmt.Sprintf("getting new token: %s", err))
77+
} else {
78+
c.updateToken(token)
79+
}
80+
case <-refreshTokenTimer.C:
81+
if token, err := c.refreshToken(ctx); err != nil {
82+
panic(fmt.Sprintf("refreshing token: %s", err))
83+
} else {
84+
c.updateToken(token)
85+
}
9386
}
9487
}
9588
}
9689

97-
func (c *Client) StopTokenHandler() {
98-
close(c.stopChan)
90+
// updateToken updates the client's token
91+
func (c *Client) updateToken(t *Token) {
92+
c.m.Lock()
93+
defer c.m.Unlock()
94+
c.token = t
9995
}
10096

10197
func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -122,8 +118,7 @@ func NewClient(secretId, secretKey string) (*Client, error) {
122118
secretId: secretId,
123119
secretKey: secretKey,
124120

125-
m: &sync.RWMutex{},
126-
stopChan: make(chan struct{}),
121+
m: &sync.RWMutex{},
127122
}
128123

129124
// Add transport to handle headers, host and path for all requests

client_test.go

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,75 @@
11
package nordigen
22

33
import (
4+
"context"
5+
"net/http"
46
"os"
7+
"sync"
58
"testing"
69
"time"
710
)
811

9-
// TestClientTokenRefresh should do a successful token refresh. We force this by
10-
// setting the expiration to a time in the past and then calling any method.
11-
// This test will only run if you have a valid secretId and secretKey in your
12-
// environment.
13-
func TestClientTokenRefresh(t *testing.T) {
14-
id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID")
15-
key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY")
16-
if !id_exists || !key_exists {
12+
var (
13+
sharedClient *Client
14+
initOnce sync.Once
15+
)
16+
17+
func initTestClient(t *testing.T) *Client {
18+
id, idExists := os.LookupEnv("NORDIGEN_SECRET_ID")
19+
key, keyExists := os.LookupEnv("NORDIGEN_SECRET_KEY")
20+
if !idExists || !keyExists {
1721
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
1822
}
1923

20-
c, err := NewClient(id, key)
24+
initOnce.Do(func() {
25+
c := &Client{
26+
c: &http.Client{Timeout: 60 * time.Second},
27+
secretId: id,
28+
secretKey: key,
29+
30+
m: &sync.RWMutex{},
31+
}
32+
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
33+
34+
// Initialize the first token
35+
token, err := c.newToken(context.Background())
36+
if err != nil {
37+
t.Fatalf("newToken: %s", err)
38+
}
39+
40+
c.token = token
41+
sharedClient = c
42+
})
43+
44+
return sharedClient
45+
}
46+
47+
func TestAccessRefresh(t *testing.T) {
48+
c := initTestClient(t)
49+
50+
// Expire token immediately
51+
c.token.AccessExpires = 0
52+
53+
ctx, cancel := context.WithCancel(context.Background())
54+
go c.tokenHandler(ctx)
55+
_, err := c.ListRequisitions()
2156
if err != nil {
22-
t.Fatalf("NewClient: %s", err)
57+
t.Fatalf("ListRequisitions: %s", err)
2358
}
59+
cancel() // Stop handler again
60+
}
61+
62+
func TestRefreshRefresh(t *testing.T) {
63+
c := initTestClient(t)
64+
65+
// Expire token immediately
66+
c.token.RefreshExpires = 0
2467

25-
c.nextRefresh = time.Now().Add(-time.Hour)
26-
_, err = c.ListRequisitions()
68+
ctx, cancel := context.WithCancel(context.Background())
69+
go c.tokenHandler(ctx)
70+
_, err := c.ListRequisitions()
2771
if err != nil {
2872
t.Fatalf("ListRequisitions: %s", err)
2973
}
74+
cancel() // Stop handler again
3075
}

token.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/url"
1010
"strings"
11+
"time"
1112
)
1213

1314
type Token struct {
@@ -30,7 +31,17 @@ const tokenPath = "token"
3031
const tokenNewPath = "new/"
3132
const tokenRefreshPath = "refresh/"
3233

33-
func (c Client) newToken(ctx context.Context) (*Token, error) {
34+
// accessExpires returns the time when access token expires divided by divisor
35+
func (t *Token) accessExpires(divisor int) time.Time {
36+
return time.Now().Add(time.Second * time.Duration(t.AccessExpires/divisor))
37+
}
38+
39+
// refreshExpires returns the time when refresh token expires divided by divisor
40+
func (t *Token) refreshExpires(divisor int) time.Time {
41+
return time.Now().Add(time.Second * time.Duration(t.RefreshExpires/divisor))
42+
}
43+
44+
func (c *Client) newToken(ctx context.Context) (*Token, error) {
3445
data, err := json.Marshal(Secret{
3546
SecretId: c.secretId,
3647
AccessId: c.secretKey,
@@ -69,8 +80,8 @@ func (c Client) newToken(ctx context.Context) (*Token, error) {
6980
return t, nil
7081
}
7182

72-
func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
73-
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
83+
func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
84+
data, err := json.Marshal(TokenRefresh{Refresh: c.token.Refresh})
7485
if err != nil {
7586
return nil, err
7687
}

0 commit comments

Comments
 (0)