Skip to content

Commit 0aa9f5f

Browse files
committed
fix: token is invalid or expired
The refresh goroutine was stopped immediately after starting, fix that logic error to refresh the token in the background after half the access token lifetime. Fixes martinohansen/ynabber#97
1 parent b9f2fcc commit 0aa9f5f

File tree

3 files changed

+110
-82
lines changed

3 files changed

+110
-82
lines changed

client.go

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package nordigen
22

33
import (
4+
"context"
5+
"errors"
46
"fmt"
57
"net/http"
68
"strings"
@@ -12,67 +14,84 @@ const baseUrl = "bankaccountdata.gocardless.com"
1214
const apiPath = "/api/v2"
1315

1416
type Client struct {
15-
c *http.Client
16-
secretId string
17-
secretKey string
18-
expiration time.Time
19-
token *Token
20-
m *sync.Mutex
21-
stopChan chan struct{}
17+
c *http.Client
18+
secretId string
19+
secretKey string
20+
21+
token *Token
22+
nextRefresh time.Time
23+
24+
m *sync.RWMutex
25+
stopChan chan struct{}
2226
}
2327

2428
type Transport struct {
2529
rt http.RoundTripper
2630
cli *Client
2731
}
2832

29-
func (c *Client) refreshTokenIfNeeded() error {
33+
// refreshTokenIfNeeded refreshes the token if refresh time has passed
34+
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
3035
c.m.Lock()
3136
defer c.m.Unlock()
3237

33-
if time.Now().Add(time.Minute).Before(c.expiration) {
34-
return nil
35-
} else {
36-
// Refresh the token if its expiration is less than a minute away
37-
newToken, err := c.refreshToken(c.token.Refresh)
38-
if err != nil {
39-
return err
40-
}
41-
c.token = newToken
42-
c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second)
38+
if time.Now().Before(c.nextRefresh) {
4339
return nil
4440
}
41+
42+
newToken, err := c.refreshToken(ctx, c.token.Refresh)
43+
if err != nil {
44+
return err
45+
}
46+
c.updateToken(newToken)
47+
return nil
4548
}
4649

47-
func (c *Client) StartTokenHandler() {
48-
c.stopChan = make(chan struct{})
50+
// updateToken updates the client token and sets the refresh time to half the
51+
// access token lifetime.
52+
func (c *Client) updateToken(t *Token) {
53+
c.token = t
54+
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
55+
}
4956

50-
// Initialize the first token and start the token handler
51-
token, err := c.newToken()
57+
// StartTokenHandler handles token refreshes in the background
58+
func (c *Client) StartTokenHandler(ctx context.Context) error {
59+
// Initialize the first token
60+
token, err := c.newToken(ctx)
5261
if err != nil {
53-
panic("Failed to get initial token: " + err.Error())
62+
return errors.New("failed to get initial token: " + err.Error())
5463
}
55-
c.token = token
56-
57-
go func() {
58-
for {
59-
timeToWait := time.Until(c.expiration) - time.Minute
60-
if timeToWait < 0 {
61-
// If the token is already expired, try to refresh immediately
62-
timeToWait = 0
63-
}
6464

65-
select {
66-
case <-c.stopChan:
67-
return
68-
case <-time.After(timeToWait):
69-
if err := c.refreshTokenIfNeeded(); err != nil {
70-
// TODO(Martin): add retry logic
71-
panic("Failed to refresh token: " + err.Error())
72-
}
65+
c.m.Lock()
66+
c.updateToken(token)
67+
c.m.Unlock()
68+
69+
go c.tokenRefreshLoop(ctx)
70+
return nil
71+
}
72+
73+
func (c *Client) tokenRefreshLoop(ctx context.Context) {
74+
for {
75+
c.m.RLock()
76+
refreshTime := c.nextRefresh
77+
c.m.RUnlock()
78+
79+
timeToWait := time.Until(refreshTime)
80+
if timeToWait < 0 {
81+
timeToWait = 0
82+
}
83+
84+
select {
85+
case <-c.stopChan:
86+
return
87+
case <-time.After(timeToWait):
88+
if err := c.refreshTokenIfNeeded(ctx); err != nil {
89+
panic(fmt.Sprintf("failed to refresh token: %s", err))
7390
}
91+
case <-ctx.Done():
92+
return
7493
}
75-
}()
94+
}
7695
}
7796

7897
func (c *Client) StopTokenHandler() {
@@ -98,17 +117,22 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
98117
// NewClient creates a new Nordigen client that handles token refreshes and adds
99118
// the necessary headers, host, and path to all requests.
100119
func NewClient(secretId, secretKey string) (*Client, error) {
101-
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{},
120+
c := &Client{
121+
c: &http.Client{Timeout: 60 * time.Second},
102122
secretId: secretId,
103123
secretKey: secretKey,
124+
125+
m: &sync.RWMutex{},
126+
stopChan: make(chan struct{}),
104127
}
105128

106129
// Add transport to handle headers, host and path for all requests
107130
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
108131

109132
// Start token handler
110-
c.StartTokenHandler()
111-
defer c.StopTokenHandler()
133+
if err := c.StartTokenHandler(context.Background()); err != nil {
134+
return nil, err
135+
}
112136

113137
return c, nil
114138
}

client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestClientTokenRefresh(t *testing.T) {
2222
t.Fatalf("NewClient: %s", err)
2323
}
2424

25-
c.expiration = time.Now().Add(-time.Hour)
25+
c.nextRefresh = time.Now().Add(-time.Hour)
2626
_, err = c.ListRequisitions()
2727
if err != nil {
2828
t.Fatalf("ListRequisitions: %s", err)

token.go

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package nordigen
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"io"
78
"net/http"
@@ -16,87 +17,90 @@ type Token struct {
1617
RefreshExpires int `json:"refresh_expires"`
1718
}
1819

20+
type TokenRefresh struct {
21+
Refresh string `json:"refresh"`
22+
}
23+
1924
type Secret struct {
2025
SecretId string `json:"secret_id"`
2126
AccessId string `json:"secret_key"`
2227
}
2328

2429
const tokenPath = "token"
2530
const tokenNewPath = "new/"
26-
const tokenRefreshPath = "refresh"
27-
28-
func (c Client) newToken() (*Token, error) {
29-
req := http.Request{
30-
Method: http.MethodPost,
31-
URL: &url.URL{
32-
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
33-
},
34-
}
31+
const tokenRefreshPath = "refresh/"
3532

33+
func (c Client) newToken(ctx context.Context) (*Token, error) {
3634
data, err := json.Marshal(Secret{
3735
SecretId: c.secretId,
3836
AccessId: c.secretKey,
3937
})
4038
if err != nil {
4139
return nil, err
4240
}
43-
req.Body = io.NopCloser(bytes.NewBuffer(data))
44-
resp, err := c.c.Do(&req)
4541

46-
if err != nil {
47-
return nil, err
42+
req := &http.Request{
43+
Method: http.MethodPost,
44+
Body: io.NopCloser(bytes.NewBuffer(data)),
45+
URL: &url.URL{
46+
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
47+
},
4848
}
49-
body, err := io.ReadAll(resp.Body)
49+
req = req.WithContext(ctx)
5050

51+
resp, err := c.c.Do(req)
5152
if err != nil {
5253
return nil, err
5354
}
55+
defer resp.Body.Close()
56+
57+
body, readErr := io.ReadAll(resp.Body)
58+
if readErr != nil {
59+
return nil, readErr
60+
}
5461
if resp.StatusCode != http.StatusOK {
55-
return nil, &APIError{resp.StatusCode, string(body), err}
62+
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
5663
}
57-
t := &Token{}
58-
err = json.Unmarshal(body, &t)
5964

60-
if err != nil {
65+
t := &Token{}
66+
if err := json.Unmarshal(body, t); err != nil {
6167
return nil, err
6268
}
63-
6469
return t, nil
6570
}
6671

67-
func (c Client) refreshToken(refresh string) (*Token, error) {
68-
req := http.Request{
72+
func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
73+
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
74+
if err != nil {
75+
return nil, err
76+
}
77+
78+
req := &http.Request{
6979
Method: http.MethodPost,
80+
Body: io.NopCloser(bytes.NewBuffer(data)),
7081
URL: &url.URL{
7182
Path: strings.Join([]string{tokenPath, tokenRefreshPath}, "/"),
7283
},
7384
}
74-
data, err := json.Marshal(refresh)
75-
76-
if err != nil {
77-
return &Token{}, err
78-
}
79-
req.Body = io.NopCloser(bytes.NewBuffer(data))
80-
81-
resp, err := c.c.Do(&req)
85+
req = req.WithContext(ctx)
8286

87+
resp, err := c.c.Do(req)
8388
if err != nil {
8489
return nil, err
8590
}
86-
body, err := io.ReadAll(resp.Body)
91+
defer resp.Body.Close()
8792

88-
if err != nil {
89-
return nil, err
93+
body, readErr := io.ReadAll(resp.Body)
94+
if readErr != nil {
95+
return nil, readErr
9096
}
9197
if resp.StatusCode != http.StatusOK {
92-
return nil, &APIError{resp.StatusCode, string(body), err}
98+
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
9399
}
94-
t := &Token{}
95-
err = json.Unmarshal(body, &t)
96100

97-
if err != nil {
101+
t := &Token{}
102+
if err := json.Unmarshal(body, t); err != nil {
98103
return nil, err
99104
}
100-
101105
return t, nil
102106
}

0 commit comments

Comments
 (0)