Skip to content

Commit 4bd54a9

Browse files
authored
Merge pull request #28 from martinohansen/main
fix: 401 invalid token
2 parents 34afc3b + 6dddb28 commit 4bd54a9

File tree

3 files changed

+40
-67
lines changed

3 files changed

+40
-67
lines changed

client.go

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,70 +30,42 @@ type Transport struct {
3030
// StartTokenHandler handles token refreshes in the background
3131
func (c *Client) StartTokenHandler(ctx context.Context) error {
3232
// Initialize the first token
33-
token, err := c.newToken(ctx)
33+
err := c.newToken(ctx)
3434
if err != nil {
3535
return errors.New("getting initial token: " + err.Error())
3636
}
37-
c.m.Lock()
38-
c.token = token
39-
c.m.Unlock()
4037

4138
go c.tokenHandler(ctx)
4239
return nil
4340
}
4441

4542
// tokenHandler gets a new token using the refresh token and a new pair when the
46-
// refresh token expires.
43+
// refresh token expires
4744
func (c *Client) tokenHandler(ctx context.Context) {
48-
newTokenTimer := time.NewTimer(0) // Start immediately
49-
refreshTokenTimer := time.NewTimer(0) // Start immediately
50-
defer func() {
51-
newTokenTimer.Stop()
52-
refreshTokenTimer.Stop()
53-
}()
54-
55-
resetTimer := func(timer *time.Timer, expiryTime time.Time) {
56-
if !timer.Stop() {
57-
<-timer.C
58-
}
59-
timer.Reset(time.Until(expiryTime))
60-
}
45+
refresh := time.NewTicker(time.Hour * 12) // 12 hours
46+
new := time.NewTicker(time.Hour * 24 * 14) // 14 days
47+
defer refresh.Stop()
48+
defer new.Stop()
6149

6250
for {
63-
c.m.RLock()
64-
newTokenExpiry := c.token.accessExpires(2)
65-
refreshTokenExpiry := c.token.refreshExpires(2)
66-
c.m.RUnlock()
67-
68-
resetTimer(newTokenTimer, newTokenExpiry)
69-
resetTimer(refreshTokenTimer, refreshTokenExpiry)
70-
7151
select {
7252
case <-ctx.Done():
7353
return
74-
case <-newTokenTimer.C:
75-
if token, err := c.newToken(ctx); err != nil {
54+
55+
case <-new.C:
56+
if err := c.newToken(ctx); err != nil {
57+
// TODO(Martin): Improve error handling
7658
panic(fmt.Sprintf("getting new token: %s", err))
77-
} else {
78-
c.updateToken(token)
7959
}
80-
case <-refreshTokenTimer.C:
81-
if token, err := c.refreshToken(ctx); err != nil {
60+
61+
case <-refresh.C:
62+
if err := c.refreshToken(ctx); err != nil {
8263
panic(fmt.Sprintf("refreshing token: %s", err))
83-
} else {
84-
c.updateToken(token)
8564
}
8665
}
8766
}
8867
}
8968

90-
// updateToken updates the client's token
91-
func (c *Client) updateToken(t *Token) {
92-
c.m.Lock()
93-
defer c.m.Unlock()
94-
c.token = t
95-
}
96-
9769
func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
9870
req.URL.Scheme = "https"
9971
req.URL.Host = baseUrl

client_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,11 @@ func initTestClient(t *testing.T) *Client {
3232
c.c.Transport = Transport{rt: http.DefaultTransport, cli: c}
3333

3434
// Initialize the first token
35-
token, err := c.newToken(context.Background())
35+
err := c.newToken(context.Background())
3636
if err != nil {
3737
t.Fatalf("newToken: %s", err)
3838
}
3939

40-
c.token = token
4140
sharedClient = c
4241
})
4342

token.go

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"net/http"
99
"net/url"
1010
"strings"
11-
"time"
1211
)
1312

1413
type Token struct {
@@ -31,23 +30,17 @@ const tokenPath = "token"
3130
const tokenNewPath = "new/"
3231
const tokenRefreshPath = "refresh/"
3332

34-
// accessExpires returns the time when access token expires divided by divisor
35-
func (t *Token) accessExpires(divisor int) time.Time {
36-
return time.Now().Add(time.Second * time.Duration(t.AccessExpires/divisor))
37-
}
38-
39-
// refreshExpires returns the time when refresh token expires divided by divisor
40-
func (t *Token) refreshExpires(divisor int) time.Time {
41-
return time.Now().Add(time.Second * time.Duration(t.RefreshExpires/divisor))
42-
}
33+
// newToken gets a new access token
34+
func (c *Client) newToken(ctx context.Context) error {
35+
c.m.Lock()
36+
defer c.m.Unlock()
4337

44-
func (c *Client) newToken(ctx context.Context) (*Token, error) {
4538
data, err := json.Marshal(Secret{
4639
SecretId: c.secretId,
4740
AccessId: c.secretKey,
4841
})
4942
if err != nil {
50-
return nil, err
43+
return err
5144
}
5245

5346
req := &http.Request{
@@ -61,29 +54,35 @@ func (c *Client) newToken(ctx context.Context) (*Token, error) {
6154

6255
resp, err := c.c.Do(req)
6356
if err != nil {
64-
return nil, err
57+
return err
6558
}
6659
defer resp.Body.Close()
6760

6861
body, readErr := io.ReadAll(resp.Body)
6962
if readErr != nil {
70-
return nil, readErr
63+
return readErr
7164
}
7265
if resp.StatusCode != http.StatusOK {
73-
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
66+
return &APIError{StatusCode: resp.StatusCode, Body: string(body)}
7467
}
7568

7669
t := &Token{}
7770
if err := json.Unmarshal(body, t); err != nil {
78-
return nil, err
71+
return err
7972
}
80-
return t, nil
73+
74+
c.token = t
75+
return nil
8176
}
8277

83-
func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
78+
// refreshToken gets a new access token using the refresh token
79+
func (c *Client) refreshToken(ctx context.Context) error {
80+
c.m.Lock()
81+
defer c.m.Unlock()
82+
8483
data, err := json.Marshal(TokenRefresh{Refresh: c.token.Refresh})
8584
if err != nil {
86-
return nil, err
85+
return err
8786
}
8887

8988
req := &http.Request{
@@ -97,21 +96,24 @@ func (c *Client) refreshToken(ctx context.Context) (*Token, error) {
9796

9897
resp, err := c.c.Do(req)
9998
if err != nil {
100-
return nil, err
99+
return err
101100
}
102101
defer resp.Body.Close()
103102

104103
body, readErr := io.ReadAll(resp.Body)
105104
if readErr != nil {
106-
return nil, readErr
105+
return readErr
107106
}
108107
if resp.StatusCode != http.StatusOK {
109-
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
108+
return &APIError{StatusCode: resp.StatusCode, Body: string(body)}
110109
}
111110

112111
t := &Token{}
113112
if err := json.Unmarshal(body, t); err != nil {
114-
return nil, err
113+
return err
115114
}
116-
return t, nil
115+
116+
c.token.Access = t.Access
117+
c.token.AccessExpires = t.AccessExpires
118+
return nil
117119
}

0 commit comments

Comments
 (0)