Skip to content

Commit b9f2fcc

Browse files
authored
Merge pull request #25 from martinohansen/martin/token-handler
fix: deadlock in RoundTrip
2 parents 9afe5bb + 4dafb8e commit b9f2fcc

File tree

3 files changed

+105
-34
lines changed

3 files changed

+105
-34
lines changed

client.go

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,50 +18,97 @@ type Client struct {
1818
expiration time.Time
1919
token *Token
2020
m *sync.Mutex
21+
stopChan chan struct{}
2122
}
2223

23-
type refreshTokenTransport struct {
24+
type Transport struct {
2425
rt http.RoundTripper
2526
cli *Client
2627
}
2728

28-
func (t refreshTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
29-
var err error
29+
func (c *Client) refreshTokenIfNeeded() error {
30+
c.m.Lock()
31+
defer c.m.Unlock()
3032

33+
if time.Now().Add(time.Minute).Before(c.expiration) {
34+
return nil
35+
} else {
36+
// Refresh the token if its expiration is less than a minute away
37+
newToken, err := c.refreshToken(c.token.Refresh)
38+
if err != nil {
39+
return err
40+
}
41+
c.token = newToken
42+
c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second)
43+
return nil
44+
}
45+
}
46+
47+
func (c *Client) StartTokenHandler() {
48+
c.stopChan = make(chan struct{})
49+
50+
// Initialize the first token and start the token handler
51+
token, err := c.newToken()
52+
if err != nil {
53+
panic("Failed to get initial token: " + err.Error())
54+
}
55+
c.token = token
56+
57+
go func() {
58+
for {
59+
timeToWait := time.Until(c.expiration) - time.Minute
60+
if timeToWait < 0 {
61+
// If the token is already expired, try to refresh immediately
62+
timeToWait = 0
63+
}
64+
65+
select {
66+
case <-c.stopChan:
67+
return
68+
case <-time.After(timeToWait):
69+
if err := c.refreshTokenIfNeeded(); err != nil {
70+
// TODO(Martin): add retry logic
71+
panic("Failed to refresh token: " + err.Error())
72+
}
73+
}
74+
}
75+
}()
76+
}
77+
78+
func (c *Client) StopTokenHandler() {
79+
close(c.stopChan)
80+
}
81+
82+
func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
3183
req.URL.Scheme = "https"
3284
req.URL.Host = baseUrl
3385
req.URL.Path = strings.Join([]string{apiPath, req.URL.Path}, "/")
3486

3587
req.Header.Add("Content-Type", "application/json")
3688
req.Header.Add("Accept", "application/json")
3789

38-
t.cli.m.Lock()
39-
40-
if t.cli.expiration.Before(time.Now()) {
41-
t.cli.token, err = t.cli.refreshToken(t.cli.token.Refresh)
42-
43-
if err != nil {
44-
return nil, err
45-
}
46-
t.cli.expiration = t.cli.expiration.Add(time.Duration(t.cli.token.RefreshExpires-60) * time.Second)
90+
// Add the access token to the request if it exists
91+
if t.cli.token != nil {
92+
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access))
4793
}
48-
t.cli.m.Unlock()
49-
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access))
5094

5195
return t.rt.RoundTrip(req)
5296
}
5397

98+
// NewClient creates a new Nordigen client that handles token refreshes and adds
99+
// the necessary headers, host, and path to all requests.
54100
func NewClient(secretId, secretKey string) (*Client, error) {
55-
var err error
101+
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{},
102+
secretId: secretId,
103+
secretKey: secretKey,
104+
}
56105

57-
c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}}
58-
c.token, err = c.newToken(secretId, secretKey)
106+
// Add transport to handle headers, host and path for all requests
107+
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
59108

60-
if err != nil {
61-
return nil, err
62-
}
63-
c.c.Transport = refreshTokenTransport{rt: http.DefaultTransport, cli: c}
64-
c.expiration = time.Now().Add(time.Duration(c.token.AccessExpires-60) * time.Second)
109+
// Start token handler
110+
c.StartTokenHandler()
111+
defer c.StopTokenHandler()
65112

66113
return c, nil
67114
}

client_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package nordigen
2+
3+
import (
4+
"os"
5+
"testing"
6+
"time"
7+
)
8+
9+
// TestClientTokenRefresh should do a successful token refresh. We force this by
10+
// setting the expiration to a time in the past and then calling any method.
11+
// This test will only run if you have a valid secretId and secretKey in your
12+
// environment.
13+
func TestClientTokenRefresh(t *testing.T) {
14+
id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID")
15+
key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY")
16+
if !id_exists || !key_exists {
17+
t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set")
18+
}
19+
20+
c, err := NewClient(id, key)
21+
if err != nil {
22+
t.Fatalf("NewClient: %s", err)
23+
}
24+
25+
c.expiration = time.Now().Add(-time.Hour)
26+
_, err = c.ListRequisitions()
27+
if err != nil {
28+
t.Fatalf("ListRequisitions: %s", err)
29+
}
30+
}

token.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"encoding/json"
66
"io"
7-
"io/ioutil"
87
"net/http"
98
"net/url"
109
"strings"
@@ -26,22 +25,17 @@ const tokenPath = "token"
2625
const tokenNewPath = "new/"
2726
const tokenRefreshPath = "refresh"
2827

29-
func (c Client) newToken(secretId, secretKey string) (*Token, error) {
28+
func (c Client) newToken() (*Token, error) {
3029
req := http.Request{
3130
Method: http.MethodPost,
3231
URL: &url.URL{
33-
Scheme: "https",
34-
Host: baseUrl,
35-
Path: strings.Join([]string{apiPath, tokenPath, tokenNewPath}, "/"),
32+
Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"),
3633
},
3734
}
38-
req.Header = http.Header{}
39-
req.Header.Add("Content-Type", "application/json")
40-
req.Header.Add("Accept", "application/json")
4135

4236
data, err := json.Marshal(Secret{
43-
SecretId: secretId,
44-
AccessId: secretKey,
37+
SecretId: c.secretId,
38+
AccessId: c.secretKey,
4539
})
4640
if err != nil {
4741
return nil, err
@@ -52,7 +46,7 @@ func (c Client) newToken(secretId, secretKey string) (*Token, error) {
5246
if err != nil {
5347
return nil, err
5448
}
55-
body, err := ioutil.ReadAll(resp.Body)
49+
body, err := io.ReadAll(resp.Body)
5650

5751
if err != nil {
5852
return nil, err
@@ -89,7 +83,7 @@ func (c Client) refreshToken(refresh string) (*Token, error) {
8983
if err != nil {
9084
return nil, err
9185
}
92-
body, err := ioutil.ReadAll(resp.Body)
86+
body, err := io.ReadAll(resp.Body)
9387

9488
if err != nil {
9589
return nil, err

0 commit comments

Comments
 (0)