Skip to content

Commit cf63f6c

Browse files
Add retry logic to token acquisition for auth credentials
Token acquisition for OIDC, M2M OAuth, and Azure client secret credentials now retries on transient failures (429, 5xx, network errors). Retries happen at the token source level rather than the HTTP transport level. The existing RoundTrip implementation on ApiClient violates the standard http.RoundTripper contract by retrying requests and interpreting HTTP status codes -- an anti-pattern that arguably led to #1398 in the first place, since the transport-level retry cannot rewind request bodies created by the oauth2 library. By retrying at the application level, each attempt creates a fresh HTTP request, sidestepping the body-reset problem entirely. Fixes #1398 Fixes #1072 Co-authored-by: Isaac Signed-off-by: Ubuntu <renaud.hartert@databricks.com>
1 parent 78469e6 commit cf63f6c

File tree

10 files changed

+382
-5
lines changed

10 files changed

+382
-5
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
### Bug Fixes
1515

16+
* Add retry logic to token acquisition for OIDC, M2M, and Azure client secret credentials ([#1398](https://github.com/databricks/databricks-sdk-go/issues/1398), [#1072](https://github.com/databricks/databricks-sdk-go/issues/1072)).
1617
* Fix double-caching of OAuth tokens in Azure client secret credentials ([#1549](https://github.com/databricks/databricks-sdk-go/issues/1549)).
1718
* Disable async token refresh for GCP credential providers to avoid wasted refresh attempts caused by double-caching with Google's internal `oauth2.ReuseTokenSource` ([#1549](https://github.com/databricks/databricks-sdk-go/issues/1549)).
1819
* Fixed double-caching in M2M OAuth that prevented the proactive async token refresh from reaching the HTTP endpoint until ~10s before expiry, causing bursts of 401 errors at token rotation boundaries ([#1549](https://github.com/databricks/databricks-sdk-go/issues/1549)).

config/auth_azure_client_secret.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
6464
aadEndpoint := env.AzureActiveDirectoryEndpoint()
6565
managementEndpoint := env.AzureServiceManagementEndpoint()
6666
opts := cacheOptions(cfg)
67-
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID), opts...)
68-
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint), opts...)
67+
inner := azureReuseTokenSource(nil, auth.NewRetryingTokenSource(c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID)), opts...)
68+
management := azureReuseTokenSource(nil, auth.NewRetryingTokenSource(c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint)), opts...)
6969
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken, false, opts...))
7070
return newVisitorOAuthCredentials(visitor, inner), nil
7171
}

config/auth_m2m.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
5050
return ccfg.Token(ctx)
5151
})
5252
return credentials.NewOAuthCredentialsProviderFromTokenSource(
53-
auth.NewCachedTokenSource(directTS, cacheOptions(cfg)...),
53+
auth.NewCachedTokenSource(auth.NewRetryingTokenSource(directTS), cacheOptions(cfg)...),
5454
), nil
5555
}

config/auth_oidc.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"github.com/databricks/databricks-sdk-go/config/credentials"
7+
"github.com/databricks/databricks-sdk-go/config/experimental/auth"
78
"github.com/databricks/databricks-sdk-go/config/experimental/auth/oidc"
89
)
910

@@ -100,7 +101,7 @@ func oidcStrategy(cfg *Config, name string, ts oidc.IDTokenSource) CredentialsSt
100101
}
101102
oidcConfig.SetScopes(cfg.GetScopes())
102103
tokenSource := oidc.NewDatabricksOIDCTokenSource(oidcConfig)
103-
return NewTokenSourceStrategy(name, tokenSource)
104+
return NewTokenSourceStrategy(name, auth.NewRetryingTokenSource(tokenSource))
104105
}
105106

106107
// failedStrategy is a CredentialsStrategy that always fails.
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net/url"
7+
"strings"
8+
"time"
9+
10+
"github.com/databricks/databricks-sdk-go/experimental/api"
11+
"golang.org/x/oauth2"
12+
)
13+
14+
// retryingTokenSource wraps a TokenSource with retry logic for transient
15+
// failures during token acquisition. Each retry calls the underlying Token()
16+
// method, which creates a fresh HTTP request -- avoiding the body-rewinding
17+
// problems that occur when retrying at the transport level.
18+
type retryingTokenSource struct {
19+
inner TokenSource
20+
opts []api.Option
21+
}
22+
23+
var defaultOptions = []api.Option{
24+
api.WithTimeout(1 * time.Minute),
25+
api.WithRetrier(func() api.Retrier {
26+
return api.RetryOn(api.BackoffPolicy{}, isRetriableTokenError)
27+
}),
28+
}
29+
30+
// NewRetryingTokenSource wraps inner with retry logic for transient failures.
31+
// By default it uses exponential backoff with a 1-minute timeout and a 30-second maximum delay.
32+
// Additional api.Option values can be provided to override the defaults.
33+
func NewRetryingTokenSource(inner TokenSource, opts ...api.Option) TokenSource {
34+
return &retryingTokenSource{
35+
inner: inner,
36+
opts: append(defaultOptions, opts...),
37+
}
38+
}
39+
40+
// Token returns a token from the underlying source, retrying on transient
41+
// errors.
42+
func (r *retryingTokenSource) Token(ctx context.Context) (*oauth2.Token, error) {
43+
return api.ExecuteWithResult(ctx, r.inner.Token, r.opts...)
44+
}
45+
46+
// httpStatusCoder describes any error that can report the HTTP status code it
// was created from.
//
// It is declared locally so that errors such as httpclient.HttpError can be
// classified by status code without importing httpclient, which would
// introduce an import cycle.
type httpStatusCoder interface {
	// HTTPStatusCode returns the HTTP status code carried by the error.
	HTTPStatusCode() int
}
52+
53+
// isRetriableTokenError returns true if the error is a transient failure that
54+
// should be retried. This covers HTTP errors from the SDK's transport layer,
55+
// OAuth2 token endpoint errors, and transient network errors.
56+
func isRetriableTokenError(err error) bool {
57+
if code := httpStatusCode(err); code != 0 {
58+
return code == 429 || code >= 500
59+
}
60+
var urlErr *url.Error
61+
if errors.As(err, &urlErr) {
62+
msg := urlErr.Error()
63+
for _, s := range transientNetworkErrors {
64+
if strings.Contains(msg, s) {
65+
return true
66+
}
67+
}
68+
}
69+
return false
70+
}
71+
72+
// httpStatusCode extracts the HTTP status code from an error, if available.
73+
func httpStatusCode(err error) int {
74+
// Check oauth2.RetrieveError (has Response field with StatusCode).
75+
var retrieveErr *oauth2.RetrieveError
76+
if errors.As(err, &retrieveErr) && retrieveErr.Response != nil {
77+
return retrieveErr.Response.StatusCode
78+
}
79+
// Check for any error that exposes a StatusCode() method.
80+
var sc httpStatusCoder
81+
if errors.As(err, &sc) {
82+
return sc.HTTPStatusCode()
83+
}
84+
return 0
85+
}
86+
87+
// transientNetworkErrors holds the message fragments that mark a *url.Error
// as a transient network failure. isRetriableTokenError treats an error as
// retriable when its message contains any one of them.
//
// The list is meant to stay in sync with isRetriableUrlError in
// httpclient/errors.go.
var transientNetworkErrors = []string{
	"connection reset by peer",
	"TLS handshake timeout",
	"connection refused",
	"i/o timeout",
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"net/url"
9+
"testing"
10+
"time"
11+
12+
"golang.org/x/oauth2"
13+
)
14+
15+
// testHTTPError is a minimal error used by the tests in this file. It
// satisfies httpStatusCoder, standing in for httpclient.HttpError so the tests
// do not need to import that package.
type testHTTPError struct {
	code    int    // HTTP status code reported by HTTPStatusCode.
	message string // Free-form text included in the Error output.
}

// Error renders the error as "http <code>: <message>".
func (e *testHTTPError) Error() string {
	return fmt.Sprintf("http %d: %s", e.code, e.message)
}

// HTTPStatusCode returns the status code the error was built with.
func (e *testHTTPError) HTTPStatusCode() int {
	return e.code
}
24+
25+
func TestRetryingTokenSource(t *testing.T) {
26+
validToken := &oauth2.Token{
27+
AccessToken: "test-token",
28+
Expiry: time.Now().Add(time.Hour),
29+
}
30+
31+
testCases := []struct {
32+
name string
33+
callErrors []error
34+
wantToken bool
35+
wantErr bool
36+
wantCalls int
37+
}{
38+
{
39+
name: "success on first call",
40+
callErrors: []error{nil},
41+
wantToken: true,
42+
wantCalls: 1,
43+
},
44+
{
45+
name: "retry on http 429 then succeed",
46+
callErrors: []error{
47+
&testHTTPError{code: 429, message: "rate limited"},
48+
nil,
49+
},
50+
wantToken: true,
51+
wantCalls: 2,
52+
},
53+
{
54+
name: "retry on http 500 then succeed",
55+
callErrors: []error{
56+
&testHTTPError{code: 500, message: "server error"},
57+
nil,
58+
},
59+
wantToken: true,
60+
wantCalls: 2,
61+
},
62+
{
63+
name: "retry on oauth2 retrieve error 429",
64+
callErrors: []error{
65+
&oauth2.RetrieveError{Response: &http.Response{StatusCode: 429}},
66+
nil,
67+
},
68+
wantToken: true,
69+
wantCalls: 2,
70+
},
71+
{
72+
name: "retry on transient network error",
73+
callErrors: []error{
74+
&url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("connection reset by peer")},
75+
nil,
76+
},
77+
wantToken: true,
78+
wantCalls: 2,
79+
},
80+
{
81+
name: "no retry on http 401",
82+
callErrors: []error{
83+
&testHTTPError{code: 401, message: "unauthorized"},
84+
},
85+
wantErr: true,
86+
wantCalls: 1,
87+
},
88+
{
89+
name: "no retry on http 400",
90+
callErrors: []error{
91+
&testHTTPError{code: 400, message: "bad request"},
92+
},
93+
wantErr: true,
94+
wantCalls: 1,
95+
},
96+
}
97+
98+
for _, tc := range testCases {
99+
t.Run(tc.name, func(t *testing.T) {
100+
callCount := 0
101+
inner := TokenSourceFn(func(ctx context.Context) (*oauth2.Token, error) {
102+
err := tc.callErrors[callCount]
103+
callCount++
104+
if err != nil {
105+
return nil, err
106+
}
107+
return validToken, nil
108+
})
109+
110+
ts := NewRetryingTokenSource(inner)
111+
tok, err := ts.Token(context.Background())
112+
113+
if callCount != tc.wantCalls {
114+
t.Errorf("got %d calls, want %d", callCount, tc.wantCalls)
115+
}
116+
if tc.wantErr && err == nil {
117+
t.Error("got nil error, want error")
118+
}
119+
if !tc.wantErr && err != nil {
120+
t.Errorf("got error %v, want nil", err)
121+
}
122+
if tc.wantToken && tok == nil {
123+
t.Error("got nil token, want token")
124+
}
125+
if !tc.wantToken && tok != nil {
126+
t.Errorf("got token %v, want nil", tok)
127+
}
128+
})
129+
}
130+
}
131+
132+
func TestIsRetriableTokenError(t *testing.T) {
133+
testCases := []struct {
134+
name string
135+
err error
136+
want bool
137+
}{
138+
{
139+
name: "http 429",
140+
err: &testHTTPError{code: 429},
141+
want: true,
142+
},
143+
{
144+
name: "http 500",
145+
err: &testHTTPError{code: 500},
146+
want: true,
147+
},
148+
{
149+
name: "http 503",
150+
err: &testHTTPError{code: 503},
151+
want: true,
152+
},
153+
{
154+
name: "http 401",
155+
err: &testHTTPError{code: 401},
156+
want: false,
157+
},
158+
{
159+
name: "http 403",
160+
err: &testHTTPError{code: 403},
161+
want: false,
162+
},
163+
{
164+
name: "oauth2 retrieve error 429",
165+
err: &oauth2.RetrieveError{Response: &http.Response{StatusCode: 429}},
166+
want: true,
167+
},
168+
{
169+
name: "oauth2 retrieve error 500",
170+
err: &oauth2.RetrieveError{Response: &http.Response{StatusCode: 500}},
171+
want: true,
172+
},
173+
{
174+
name: "oauth2 retrieve error 400",
175+
err: &oauth2.RetrieveError{Response: &http.Response{StatusCode: 400}},
176+
want: false,
177+
},
178+
{
179+
name: "connection reset",
180+
err: &url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("connection reset by peer")},
181+
want: true,
182+
},
183+
{
184+
name: "tls handshake timeout",
185+
err: &url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("TLS handshake timeout")},
186+
want: true,
187+
},
188+
{
189+
name: "connection refused",
190+
err: &url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("connection refused")},
191+
want: true,
192+
},
193+
{
194+
name: "i/o timeout",
195+
err: &url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("i/o timeout")},
196+
want: true,
197+
},
198+
{
199+
name: "url error non-transient",
200+
err: &url.Error{Op: "Post", URL: "https://host/token", Err: fmt.Errorf("no such host")},
201+
want: false,
202+
},
203+
{
204+
name: "generic error",
205+
err: errors.New("something went wrong"),
206+
want: false,
207+
},
208+
}
209+
210+
for _, tc := range testCases {
211+
t.Run(tc.name, func(t *testing.T) {
212+
got := isRetriableTokenError(tc.err)
213+
if got != tc.want {
214+
t.Errorf("isRetriableTokenError(%v) = %v, want %v", tc.err, got, tc.want)
215+
}
216+
})
217+
}
218+
}

experimental/api/api.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ func Execute(ctx context.Context, call Call, opts ...Option) error {
2222
return execute(ctx, call, options, sleep)
2323
}
2424

25+
// ExecuteWithResult returns the result of calling op with the given options.
26+
// It is a convenience wrapper around Execute for operations that return a
27+
// value. In case of error, the zero value of T is returned.
28+
func ExecuteWithResult[T any](ctx context.Context, op func(context.Context) (T, error), opts ...Option) (T, error) {
29+
var result T
30+
err := Execute(ctx, func(ctx context.Context) error {
31+
var err error
32+
result, err = op(ctx)
33+
return err
34+
}, opts...)
35+
if err != nil {
36+
var zero T // guarantee zero value on error
37+
return zero, err
38+
}
39+
return result, nil
40+
}
41+
2542
// sleep sleeps for the given duration. It is mostly equivalent to time.Sleep,
2643
// but can be interrupted by ctx.Done() if the context completes before the
2744
// duration elapses.

0 commit comments

Comments
 (0)