Skip to content

Commit a6013a0

Browse files
Merge branch 'master' into prep-v2.18.0
2 parents c2dde08 + bee258d commit a6013a0

2 files changed

Lines changed: 167 additions & 0 deletions

File tree

okta/config.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ package okta
1818

1919
import (
2020
"errors"
21+
"fmt"
2122
"io/ioutil"
2223
"log"
2324
"net/http"
2425
"os"
2526
"syscall"
27+
"time"
2628

2729
"github.com/okta/okta-sdk-golang/v2/okta/cache"
2830
"gopkg.in/square/go-jose.v2"
@@ -69,6 +71,28 @@ type config struct {
6971

7072
type ConfigSetter func(*config)
7173

74+
type InterceptingRoundTripper struct {
75+
Transport http.RoundTripper
76+
Interceptor func(*http.Request) error
77+
Blocking bool
78+
}
79+
80+
func WithHttpInterceptorAndHttpClientPtr(interceptor func(*http.Request) error, httpClient *http.Client, blocking bool) ConfigSetter {
81+
return func(c *config) {
82+
if httpClient == nil {
83+
httpClient = http.DefaultClient
84+
}
85+
86+
if httpClient.Transport == nil {
87+
httpClient.Transport = &http.Transport{
88+
IdleConnTimeout: 30 * time.Second,
89+
}
90+
}
91+
92+
c.HttpClient.Transport = NewInterceptingRoundTripper(interceptor, httpClient.Transport, blocking)
93+
}
94+
}
95+
7296
func WithCache(cache bool) ConfigSetter {
7397
return func(c *config) {
7498
c.Okta.Client.Cache.Enabled = cache
@@ -240,3 +264,36 @@ func fileExists(filename string) bool {
240264
}
241265
return !info.IsDir()
242266
}
267+
268+
func (c *InterceptingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
269+
interceptError := func() (err error) {
270+
defer func() {
271+
if panicked := recover(); panicked != nil {
272+
if panickedErrString, ok := panicked.(string); ok {
273+
err = fmt.Errorf("recovered panic in Okta HTTP interceptor: %s", panickedErrString)
274+
} else {
275+
err = fmt.Errorf("recovered panic in Okta HTTP interceptor, but failed to parse error string")
276+
}
277+
}
278+
}()
279+
return c.Interceptor(req)
280+
}()
281+
282+
if interceptError != nil && c.Blocking {
283+
return nil, interceptError
284+
}
285+
286+
if c.Transport != nil {
287+
response, roundTripperErr := c.Transport.RoundTrip(req)
288+
return response, roundTripperErr
289+
}
290+
return nil, fmt.Errorf("an error ocurred in Okta SDK, Transport was nil")
291+
}
292+
293+
func NewInterceptingRoundTripper(interceptor func(*http.Request) error, transport http.RoundTripper, blocking bool) *InterceptingRoundTripper {
294+
return &InterceptingRoundTripper{
295+
Interceptor: interceptor,
296+
Blocking: blocking,
297+
Transport: transport,
298+
}
299+
}

tests/unit/client_config_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ package unit
1818

1919
import (
2020
"context"
21+
"fmt"
22+
"net/http"
2123
"testing"
2224

25+
"github.com/jarcoal/httpmock"
2326
"github.com/okta/okta-sdk-golang/v2/okta"
2427

2528
"github.com/okta/okta-sdk-golang/v2/tests"
@@ -90,3 +93,110 @@ func Test_will_error_if_private_key_authorization_type_with_missing_properties(t
9093
_, _, err := tests.NewClient(context.TODO(), okta.WithAuthorizationMode("PrivateKey"), okta.WithClientId(""))
9194
assert.Error(t, err, "Does not error if private key selected with no other required options")
9295
}
96+
97+
type InterceptingRoundTripperTest struct {
98+
Name string
99+
Blocking bool
100+
Interceptor func(*http.Request) error
101+
ExpectedTransportCalls int
102+
ExpectInterceptorCalled bool
103+
ExpectSdkErrorThrown bool
104+
}
105+
106+
func Test_Intercepting_RoundTripper(t *testing.T) {
107+
interceptorCalled := false
108+
testsToRun := []InterceptingRoundTripperTest{
109+
{
110+
Name: "Calls interceptor",
111+
Blocking: false,
112+
Interceptor: func(r *http.Request) error {
113+
interceptorCalled = true
114+
return nil
115+
},
116+
ExpectedTransportCalls: 1,
117+
ExpectInterceptorCalled: true,
118+
ExpectSdkErrorThrown: false,
119+
},
120+
{
121+
Name: "Does not call transport when interceptor panics when blocking",
122+
Blocking: true,
123+
Interceptor: func(r *http.Request) error {
124+
interceptorCalled = true
125+
panic("Some err")
126+
},
127+
ExpectedTransportCalls: 0,
128+
ExpectInterceptorCalled: true,
129+
ExpectSdkErrorThrown: true,
130+
},
131+
{
132+
Name: "Calls transport when interceptor panics when non blocking",
133+
Blocking: false,
134+
Interceptor: func(r *http.Request) error {
135+
interceptorCalled = true
136+
panic("Some err")
137+
},
138+
ExpectedTransportCalls: 1,
139+
ExpectInterceptorCalled: true,
140+
ExpectSdkErrorThrown: false,
141+
},
142+
{
143+
Name: "Does not call transport when interceptor throws err when blocking",
144+
Blocking: true,
145+
Interceptor: func(r *http.Request) error {
146+
interceptorCalled = true
147+
return fmt.Errorf("Some error")
148+
},
149+
ExpectedTransportCalls: 0,
150+
ExpectInterceptorCalled: true,
151+
ExpectSdkErrorThrown: true,
152+
},
153+
{
154+
Name: "Calls transport when interceptor throws err when not blocking",
155+
Blocking: false,
156+
Interceptor: func(r *http.Request) error {
157+
interceptorCalled = true
158+
return fmt.Errorf("Some error")
159+
},
160+
ExpectedTransportCalls: 1,
161+
ExpectInterceptorCalled: true,
162+
ExpectSdkErrorThrown: false,
163+
},
164+
}
165+
166+
for _, test := range testsToRun {
167+
t.Run(
168+
test.Name,
169+
func(t *testing.T) {
170+
mockHttpClient := http.DefaultClient
171+
mockTransport := httpmock.DefaultTransport
172+
mockTransport.RegisterNoResponder(func(r *http.Request) (*http.Response, error) {
173+
return &http.Response{StatusCode: 200}, nil
174+
})
175+
mockHttpClient.Transport = mockTransport
176+
177+
_, oktaClient, err := tests.NewClient(
178+
context.TODO(),
179+
okta.WithHttpInterceptorAndHttpClientPtr(test.Interceptor, mockHttpClient, test.Blocking),
180+
)
181+
assert.NoError(t, err)
182+
183+
_, _, err = oktaClient.IdentityProvider.ActivateIdentityProvider(context.TODO(), "Anything")
184+
185+
if test.ExpectSdkErrorThrown {
186+
assert.Error(t, err)
187+
} else {
188+
assert.NoError(t, err)
189+
}
190+
191+
assert.Equal(t, test.ExpectInterceptorCalled, interceptorCalled)
192+
193+
callCount := mockTransport.GetTotalCallCount()
194+
195+
assert.Equal(t, test.ExpectedTransportCalls, callCount)
196+
197+
interceptorCalled = false
198+
mockTransport.ZeroCallCounters()
199+
},
200+
)
201+
}
202+
}

0 commit comments

Comments
 (0)