Skip to content

Commit 3683fd6

Browse files
authored
Merge pull request #26 from martinohansen/main
fix: token is invalid or expired
2 parents b9f2fcc + 0aa9f5f commit 3683fd6

File tree

3 files changed

+110
-82
lines changed

3 files changed

+110
-82
lines changed

client.go

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package nordigen
22

33
import (
4+
"context"
5+
"errors"
46
"fmt"
57
"net/http"
68
"strings"
@@ -12,67 +14,84 @@ const baseUrl = "bankaccountdata.gocardless.com"
1214
const apiPath = "/api/v2"
1315

1416
type Client struct {
15-
c *http.Client
16-
secretId string
17-
secretKey string
18-
expiration time.Time
19-
token *Token
20-
m *sync.Mutex
21-
stopChan chan struct{}
17+
c *http.Client
18+
secretId string
19+
secretKey string
20+
21+
token *Token
22+
nextRefresh time.Time
23+
24+
m *sync.RWMutex
25+
stopChan chan struct{}
2226
}
2327

2428
type Transport struct {
2529
rt http.RoundTripper
2630
cli *Client
2731
}
2832

29-
func (c *Client) refreshTokenIfNeeded() error {
33+
// refreshTokenIfNeeded refreshes the token if refresh time has passed
34+
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
3035
c.m.Lock()
3136
defer c.m.Unlock()
3237

33-
if time.Now().Add(time.Minute).Before(c.expiration) {
34-
return nil
35-
} else {
36-
// Refresh the token if its expiration is less than a minute away
37-
newToken, err := c.refreshToken(c.token.Refresh)
38-
if err != nil {
39-
return err
40-
}
41-
c.token = newToken
42-
c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second)
38+
if time.Now().Before(c.nextRefresh) {
4339
return nil
4440
}
41+
42+
newToken, err := c.refreshToken(ctx, c.token.Refresh)
43+
if err != nil {
44+
return err
45+
}
46+
c.updateToken(newToken)
47+
return nil
4548
}
4649

47-
func (c *Client) StartTokenHandler() {
48-
c.stopChan = make(chan struct{})
50+
// updateToken updates the client token and sets the refresh time to half the
51+
// access token lifetime.
52+
func (c *Client) updateToken(t *Token) {
53+
c.token = t
54+
c.nextRefresh = time.Now().Add(time.Duration(t.AccessExpires/2) * time.Second)
55+
}
4956

50-
// Initialize the first token and start the token handler
51-
token, err := c.newToken()
57+
// StartTokenHandler handles token refreshes in the background
58+
func (c *Client) StartTokenHandler(ctx context.Context) error {
59+
// Initialize the first token
60+
token, err := c.newToken(ctx)
5261
if err != nil {
53-
panic("Failed to get initial token: " + err.Error())
62+
return errors.New("failed to get initial token: " + err.Error())
5463
}
55-
c.token = token
56-
57-
go func() {
58-
for {
59-
timeToWait := time.Until(c.expiration) - time.Minute
60-
if timeToWait < 0 {
61-
// If the token is already expired, try to refresh immediately
62-
timeToWait = 0
63-
}
6464

65-
select {
66-
case <-c.stopChan:
67-
return
68-
case <-time.After(timeToWait):
69-
if err := c.refreshTokenIfNeeded(); err != nil {
70-
// TODO(Martin): add retry logic
71-
panic("Failed to refresh token: " + err.Error())
72-
}
65+
c.m.Lock()
66+
c.updateToken(token)
67+
c.m.Unlock()
68+
69+
go c.tokenRefreshLoop(ctx)
70+
return nil
71+
}
72+
73+
func (c *Client) tokenRefreshLoop(ctx context.Context) {
74+
for {
75+
c.m.RLock()
76+
refreshTime := c.nextRefresh
77+
c.m.RUnlock()
78+
79+
timeToWait := time.Until(refreshTime)
80+
if timeToWait < 0 {
81+
timeToWait = 0
82+
}
83+
84+
select {
85+
case <-c.stopChan:
86+
return
87+
case <-time.After(timeToWait):
88+
if err := c.refreshTokenIfNeeded(ctx); err != nil {
89+
panic(fmt.Sprintf("failed to refresh token: %s", err))
7390
}
91+
case <-ctx.Done():
92+
return
7493
}
75-
}()
94+
}
7695
}
7796

7897
func (c *Client) StopTokenHandler() {
@@ -98,17 +117,22 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
98117
// NewClient creates a new Nordigen client that handles token refreshes and adds
99118
// the necessary headers, host, and path to all requests.
100119
func NewClient(secretId, secretKey string) (*Client, error) {
101-
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{},
120+
c := &Client{
121+
c: &http.Client{Timeout: 60 * time.Second},
102122
secretId: secretId,
103123
secretKey: secretKey,
124+
125+
m: &sync.RWMutex{},
126+
stopChan: make(chan struct{}),
104127
}
105128

106129
// Add transport to handle headers, host and path for all requests
107130
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
108131

109132
// Start token handler
110-
c.StartTokenHandler()
111-
defer c.StopTokenHandler()
133+
if err := c.StartTokenHandler(context.Background()); err != nil {
134+
return nil, err
135+
}
112136

113137
return c, nil
114138
}

client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestClientTokenRefresh(t *testing.T) {
2222
t.Fatalf("NewClient: %s", err)
2323
}
2424

25-
c.expiration = time.Now().Add(-time.Hour)
25+
c.nextRefresh = time.Now().Add(-time.Hour)
2626
_, err = c.ListRequisitions()
2727
if err != nil {
2828
t.Fatalf("ListRequisitions: %s", err)

token.go

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package nordigen
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"io"
78
"net/http"
@@ -16,87 +17,90 @@ type Token struct {
1617
RefreshExpires int `json:"refresh_expires"`
1718
}
1819

20+
type TokenRefresh struct {
21+
Refresh string `json:"refresh"`
22+
}
23+
1924
type Secret struct {
2025
SecretId string `json:"secret_id"`
2126
AccessId string `json:"secret_key"`
2227
}
2328

2429
const tokenPath = "token"
2530
const tokenNewPath = "new/"
26-
const tokenRefreshPath = "refresh"
27-
28-
func (c Client) newToken() (*Token, error) {
29-
req := http.Request{
30-
Method: http.MethodPost,
31-
URL: &url.URL{
32-
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
33-
},
34-
}
31+
const tokenRefreshPath = "refresh/"
3532

33+
func (c Client) newToken(ctx context.Context) (*Token, error) {
3634
data, err := json.Marshal(Secret{
3735
SecretId: c.secretId,
3836
AccessId: c.secretKey,
3937
})
4038
if err != nil {
4139
return nil, err
4240
}
43-
req.Body = io.NopCloser(bytes.NewBuffer(data))
44-
resp, err := c.c.Do(&req)
4541

46-
if err != nil {
47-
return nil, err
42+
req := &http.Request{
43+
Method: http.MethodPost,
44+
Body: io.NopCloser(bytes.NewBuffer(data)),
45+
URL: &url.URL{
46+
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
47+
},
4848
}
49-
body, err := io.ReadAll(resp.Body)
49+
req = req.WithContext(ctx)
5050

51+
resp, err := c.c.Do(req)
5152
if err != nil {
5253
return nil, err
5354
}
55+
defer resp.Body.Close()
56+
57+
body, readErr := io.ReadAll(resp.Body)
58+
if readErr != nil {
59+
return nil, readErr
60+
}
5461
if resp.StatusCode != http.StatusOK {
55-
return nil, &APIError{resp.StatusCode, string(body), err}
62+
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
5663
}
57-
t := &Token{}
58-
err = json.Unmarshal(body, &t)
5964

60-
if err != nil {
65+
t := &Token{}
66+
if err := json.Unmarshal(body, t); err != nil {
6167
return nil, err
6268
}
63-
6469
return t, nil
6570
}
6671

67-
func (c Client) refreshToken(refresh string) (*Token, error) {
68-
req := http.Request{
72+
func (c Client) refreshToken(ctx context.Context, refresh string) (*Token, error) {
73+
data, err := json.Marshal(TokenRefresh{Refresh: refresh})
74+
if err != nil {
75+
return nil, err
76+
}
77+
78+
req := &http.Request{
6979
Method: http.MethodPost,
80+
Body: io.NopCloser(bytes.NewBuffer(data)),
7081
URL: &url.URL{
7182
Path: strings.Join([]string{tokenPath, tokenRefreshPath}, "/"),
7283
},
7384
}
74-
data, err := json.Marshal(refresh)
75-
76-
if err != nil {
77-
return &Token{}, err
78-
}
79-
req.Body = io.NopCloser(bytes.NewBuffer(data))
80-
81-
resp, err := c.c.Do(&req)
85+
req = req.WithContext(ctx)
8286

87+
resp, err := c.c.Do(req)
8388
if err != nil {
8489
return nil, err
8590
}
86-
body, err := io.ReadAll(resp.Body)
91+
defer resp.Body.Close()
8792

88-
if err != nil {
89-
return nil, err
93+
body, readErr := io.ReadAll(resp.Body)
94+
if readErr != nil {
95+
return nil, readErr
9096
}
9197
if resp.StatusCode != http.StatusOK {
92-
return nil, &APIError{resp.StatusCode, string(body), err}
98+
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
9399
}
94-
t := &Token{}
95-
err = json.Unmarshal(body, &t)
96100

97-
if err != nil {
101+
t := &Token{}
102+
if err := json.Unmarshal(body, t); err != nil {
98103
return nil, err
99104
}
100-
101105
return t, nil
102106
}

0 commit comments

Comments
 (0)