Skip to content

Commit 689cf37

Browse files
authored
Merge pull request #214 from gonzolino/initialtoken
Do not trigger token refresh callback on first Token() call
2 parents 714eee4 + 2fd243e commit 689cf37

3 files changed

Lines changed: 94 additions & 4 deletions

File tree

client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func newClient(ctx context.Context, config *oauth2.Config, token *oauth2.Token)
3232
// The callback will be invoked whenever the OAuth2 token is automatically refreshed.
3333
func newClientWithCallback(ctx context.Context, config *oauth2.Config, token *oauth2.Token, callback TokenRefreshCallback) *client {
3434
tokenSrc := config.TokenSource(ctx, token)
35-
callbackTokenSrc := NewCallbackTokenSource(tokenSrc, callback)
35+
callbackTokenSrc := NewCallbackTokenSource(tokenSrc, callback, token)
3636

3737
return &client{
3838
http: oauth2.NewClient(ctx, callbackTokenSrc),

tokensource.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,15 @@ func (cts *callbackTokenSource) Token() (*oauth2.Token, error) {
9696
// }
9797
//
9898
// tado := gotado.NewWithTokenRefreshCallback(ctx, config, token, callback)
99-
func NewCallbackTokenSource(src oauth2.TokenSource, callback TokenRefreshCallback) oauth2.TokenSource {
99+
func NewCallbackTokenSource(src oauth2.TokenSource, callback TokenRefreshCallback, initialToken ...*oauth2.Token) oauth2.TokenSource {
100+
var token *oauth2.Token
101+
if len(initialToken) > 0 {
102+
token = initialToken[0]
103+
}
100104
return &callbackTokenSource{
101-
src: src,
102-
callback: callback,
105+
src: src,
106+
callback: callback,
107+
lastToken: copyToken(token),
103108
}
104109
}
105110

tokensource_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,97 @@
11
package gotado
22

33
import (
4+
"sync/atomic"
45
"testing"
56
"time"
67

78
"golang.org/x/oauth2"
89
)
910

11+
// staticTokenSource returns the same token every time.
12+
type staticTokenSource struct {
13+
token *oauth2.Token
14+
}
15+
16+
func (s *staticTokenSource) Token() (*oauth2.Token, error) {
17+
return s.token, nil
18+
}
19+
20+
func TestCallbackTokenSource(t *testing.T) {
21+
t.Run("NoCallbackWhenTokenUnchanged", func(t *testing.T) {
22+
token := &oauth2.Token{
23+
AccessToken: "access",
24+
RefreshToken: "refresh",
25+
}
26+
27+
var callCount atomic.Int32
28+
callback := func(newToken *oauth2.Token) {
29+
callCount.Add(1)
30+
}
31+
32+
src := NewCallbackTokenSource(&staticTokenSource{token: token}, callback, token)
33+
34+
// First call should NOT trigger callback since token matches initialToken
35+
_, err := src.Token()
36+
if err != nil {
37+
t.Fatalf("unexpected error: %v", err)
38+
}
39+
if callCount.Load() != 0 {
40+
t.Errorf("callback should not fire when token unchanged, got %d calls", callCount.Load())
41+
}
42+
})
43+
44+
t.Run("CallbackWhenNilInitialToken", func(t *testing.T) {
45+
token := &oauth2.Token{
46+
AccessToken: "access",
47+
RefreshToken: "refresh",
48+
}
49+
50+
var callCount atomic.Int32
51+
callback := func(newToken *oauth2.Token) {
52+
callCount.Add(1)
53+
}
54+
55+
src := NewCallbackTokenSource(&staticTokenSource{token: token}, callback, nil)
56+
57+
// First call should trigger callback since initialToken is nil
58+
_, err := src.Token()
59+
if err != nil {
60+
t.Fatalf("unexpected error: %v", err)
61+
}
62+
if callCount.Load() != 1 {
63+
t.Errorf("callback should fire when initialToken is nil, got %d calls", callCount.Load())
64+
}
65+
})
66+
67+
t.Run("CallbackWhenTokenChanges", func(t *testing.T) {
68+
initialToken := &oauth2.Token{
69+
AccessToken: "access1",
70+
RefreshToken: "refresh1",
71+
}
72+
newToken := &oauth2.Token{
73+
AccessToken: "access2",
74+
RefreshToken: "refresh2",
75+
}
76+
77+
var callCount atomic.Int32
78+
callback := func(token *oauth2.Token) {
79+
callCount.Add(1)
80+
}
81+
82+
src := NewCallbackTokenSource(&staticTokenSource{token: newToken}, callback, initialToken)
83+
84+
// Should trigger callback since source returns a different token
85+
_, err := src.Token()
86+
if err != nil {
87+
t.Fatalf("unexpected error: %v", err)
88+
}
89+
if callCount.Load() != 1 {
90+
t.Errorf("callback should fire when token changes, got %d calls", callCount.Load())
91+
}
92+
})
93+
}
94+
1095
func TestCopyToken(t *testing.T) {
1196
t.Run("NilToken", func(t *testing.T) {
1297
copied := copyToken(nil)

0 commit comments

Comments
 (0)