Skip to content

Commit 7fe360f

Browse files
committed
fix: refresh this field may not be blank
A race condition caused two refresh calls to be executed at the same time, refactor the handler to use timers instead of time.After to avoid this. Also update the tests which catches the error should we make it again. commit-id:ed164022
1 parent 3683fd6 commit 7fe360f

File tree

3 files changed

+117
-66
lines changed

3 files changed

+117
-66
lines changed

client.go

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,84 +18,80 @@ type Client struct {
1818
secretId string
1919
secretKey string
2020

21-
token *Token
22-
nextRefresh time.Time
23-
24-
m *sync.RWMutex
25-
stopChan chan struct{}
21+
m *sync.RWMutex
22+
token *Token
2623
}
2724

2825
type Transport struct {
2926
rt http.RoundTripper
3027
cli *Client
3128
}
3229

33-
// refreshTokenIfNeeded refreshes the token if refresh time has passed
34-
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
35-
c.m.Lock()
36-
defer c.m.Unlock()
37-
38-
if time.Now().Before(c.nextRefresh) {
39-
return nil
40-
}
41-
42-
newToken, err := c.refreshToken(ctx, c.token.Refresh)
43-
if err != nil {
44-
return err
45-
}
46-
c.updateToken(newToken)
47-
return nil
48-
}
49-
50-
// updateToken updates the client token and sets the refresh time to half the
51-
// access token lifetime.
52-
func (c *Client) updateToken(t *Token) {
53-
c.token = t
54-
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
55-
}
56-
5730
// StartTokenHandler handles token refreshes in the background
5831
func (c *Client) StartTokenHandler(ctx context.Context) error {
5932
// Initialize the first token
6033
token, err := c.newToken(ctx)
6134
if err != nil {
62-
return errors.New("failed to get initial token: " + err.Error())
35+
return errors.New("getting initial token: " + err.Error())
6336
}
64-
6537
c.m.Lock()
66-
c.updateToken(token)
38+
c.token = token
6739
c.m.Unlock()
6840

69-
go c.tokenRefreshLoop(ctx)
41+
go c.tokenHandler(ctx)
7042
return nil
7143
}
7244

73-
func (c *Client) tokenRefreshLoop(ctx context.Context) {
45+
// tokenHandler gets a new token using the refresh token and a new pair when the
46+
// refresh token expires.
47+
func (c *Client) tokenHandler(ctx context.Context) {
48+
newTokenTimer := time.NewTimer(0) // Start immediately
49+
refreshTokenTimer := time.NewTimer(0) // Start immediately
50+
defer func() {
51+
newTokenTimer.Stop()
52+
refreshTokenTimer.Stop()
53+
}()
54+
55+
resetTimer := func(timer *time.Timer, expiryTime time.Time) {
56+
if !timer.Stop() {
57+
<-timer.C
58+
}
59+
timer.Reset(time.Until(expiryTime))
60+
}
61+
7462
for {
7563
c.m.RLock()
76-
refreshTime := c.nextRefresh
64+
newTokenExpiry := c.token.accessExpires(2)
65+
refreshTokenExpiry := c.token.refreshExpires(2)
7766
c.m.RUnlock()
7867

79-
timeToWait := time.Until(refreshTime)
80-
if timeToWait < 0 {
81-
timeToWait = 0
82-
}
68+
resetTimer(newTokenTimer, newTokenExpiry)
69+
resetTimer(refreshTokenTimer, refreshTokenExpiry)
8370

8471
select {
85-
case <-c.stopChan:
86-
return
87-
case <-time.After(timeToWait):
88-
if err := c.refreshTokenIfNeeded(ctx); err != nil {
89-
panic(fmt.Sprintf("failed to refresh token: %s", err))
90-
}
9172
case <-ctx.Done():
9273
return
74+
case <-newTokenTimer.C:
75+
if token, err := c.newToken(ctx); err != nil {
76+
panic(fmt.Sprintf("getting new token: %s", err))
77+
} else {
78+
c.updateToken(token)
79+
}
80+
case <-refreshTokenTimer.C:
81+
if token, err := c.refreshToken(ctx); err != nil {
82+
panic(fmt.Sprintf("refreshing token: %s", err))
83+
} else {
84+
c.updateToken(token)
85+
}
9386
}
9487
}
9588
}
9689

97-
func (c *Client) StopTokenHandler() {
98-
close(c.stopChan)
90+
// updateToken updates the client's token
91+
func (c *Client) updateToken(t *Token) {
92+
c.m.Lock()
93+
defer c.m.Unlock()
94+
c.token = t
9995
}
10096

10197
func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -122,8 +118,7 @@ func NewClient(secretId, secretKey string) (*Client, error) {
122118
secretId: secretId,
123119
secretKey: secretKey,
124120

125-
m: &sync.RWMutex{},
126-
stopChan: make(chan struct{}),
121+
m: &sync.RWMutex{},
127122
}
128123

129124
// Add transport to handle headers, host and path for all requests

client_test.go

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,75 @@
11
package nordigen
22

33
import (
4+
"context"
5+
"net/http"
46
"os"
7+
"sync"
58
"testing"
69
"time"
710
)
811

9-
// TestClientTokenRefresh should do a successful token refresh. We force this by
10-
// setting the expiration to a time in the past and then calling any method.
11-
// This test will only run if you have a valid secretId and secretKey in your
12-
// environment.
13-
func TestClientTokenRefresh(t *testing.T) {
14-
id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID")
15-
key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY")
16-
if !id_exists || !key_exists {
17-
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
18-
}
12+
var (
13+
sharedClient *Client
14+
initOnce sync.Once
15+
)
16+
17+
func initTestClient(t *testing.T) *Client {
18+
initOnce.Do(func() {
19+
id, idExists := os.LookupEnv("NORDIGEN_SECRET_ID")
20+
key, keyExists := os.LookupEnv("NORDIGEN_SECRET_KEY")
21+
if !idExists || !keyExists {
22+
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
23+
}
24+
25+
c := &Client{
26+
c: &http.Client{Timeout: 60 * time.Second},
27+
secretId: id,
28+
secretKey: key,
1929

20-
c, err := NewClient(id, key)
30+
m: &sync.RWMutex{},
31+
}
32+
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
33+
34+
// Initialize the first token
35+
token, err := c.newToken(context.Background())
36+
if err != nil {
37+
t.Fatalf("newToken: %s", err)
38+
}
39+
40+
c.token = token
41+
sharedClient = c
42+
})
43+
44+
return sharedClient
45+
}
46+
47+
func TestAccessRefresh(t *testing.T) {
48+
c := initTestClient(t)
49+
50+
// Expire token immediately
51+
c.token.AccessExpires = 0
52+
53+
ctx, cancel := context.WithCancel(context.Background())
54+
go c.tokenHandler(ctx)
55+
_, err := c.ListRequisitions()
2156
if err != nil {
22-
t.Fatalf("NewClient: %s", err)
57+
t.Fatalf("ListRequisitions: %s", err)
2358
}
59+
cancel() // Stop handler again
60+
}
61+
62+
func TestRefreshRefresh(t *testing.T) {
63+
c := initTestClient(t)
64+
65+
// Expire token immediately
66+
c.token.RefreshExpires = 0
2467

25-
c.nextRefresh = time.Now().Add(-time.Hour)
26-
_, err = c.ListRequisitions()
68+
ctx, cancel := context.WithCancel(context.Background())
69+
go c.tokenHandler(ctx)
70+
_, err := c.ListRequisitions()
2771
if err != nil {
2872
t.Fatalf("ListRequisitions: %s", err)
2973
}
74+
cancel() // Stop handler again
3075
}

token.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/url"
1010
"strings"
11+
"time"
1112
)
1213

1314
type Token struct {
@@ -30,7 +31,17 @@ const tokenPath = "token"
3031
const tokenNewPath = "new/"
3132
const tokenRefreshPath = "refresh/"
3233

33-
func (c Client) newToken(ctx context.Context) (*Token, error) {
34+
// accessExpires returns the time when access token expires divided by divisor
35+
func (t *Token) accessExpires(divisor int) time.Time {
36+
return time.Now().Add(time.Second * time.Duration(t.AccessExpires/divisor))
37+
}
38+
39+
// refreshExpires returns the time when refresh token expires divided by divisor
40+
func (t *Token) refreshExpires(divisor int) time.Time {
41+
return time.Now().Add(time.Second * time.Duration(t.RefreshExpires/divisor))
42+
}
43+
44+
func (c *Client) newToken(ctx context.Context) (*Token, error) {
3445
data, err := json.Marshal(Secret{
3546
SecretId: c.secretId,
3647
AccessId: c.secretKey,
@@ -69,8 +80,8 @@ func (c Client) newToken(ctx context.Context) (*Token, error) {
6980
return t, nil
7081
}
7182

72-
func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
73-
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
83+
func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
84+
data, err := json.Marshal(TokenRefresh{Refresh: c.token.Refresh})
7485
if err != nil {
7586
return nil, err
7687
}

0 commit comments

Comments
 (0)