@@ -18,8 +18,11 @@ package unit
1818
1919import (
2020 "context"
21+ "fmt"
22+ "net/http"
2123 "testing"
2224
25+ "github.com/jarcoal/httpmock"
2326 "github.com/okta/okta-sdk-golang/v2/okta"
2427
2528 "github.com/okta/okta-sdk-golang/v2/tests"
@@ -90,3 +93,110 @@ func Test_will_error_if_private_key_authorization_type_with_missing_properties(t
9093 _ , _ , err := tests .NewClient (context .TODO (), okta .WithAuthorizationMode ("PrivateKey" ), okta .WithClientId ("" ))
9194 assert .Error (t , err , "Does not error if private key selected with no other required options" )
9295}
96+
97+ type InterceptingRoundTripperTest struct {
98+ Name string
99+ Blocking bool
100+ Interceptor func (* http.Request ) error
101+ ExpectedTransportCalls int
102+ ExpectInterceptorCalled bool
103+ ExpectSdkErrorThrown bool
104+ }
105+
106+ func Test_Intercepting_RoundTripper (t * testing.T ) {
107+ interceptorCalled := false
108+ testsToRun := []InterceptingRoundTripperTest {
109+ {
110+ Name : "Calls interceptor" ,
111+ Blocking : false ,
112+ Interceptor : func (r * http.Request ) error {
113+ interceptorCalled = true
114+ return nil
115+ },
116+ ExpectedTransportCalls : 1 ,
117+ ExpectInterceptorCalled : true ,
118+ ExpectSdkErrorThrown : false ,
119+ },
120+ {
121+ Name : "Does not call transport when interceptor panics when blocking" ,
122+ Blocking : true ,
123+ Interceptor : func (r * http.Request ) error {
124+ interceptorCalled = true
125+ panic ("Some err" )
126+ },
127+ ExpectedTransportCalls : 0 ,
128+ ExpectInterceptorCalled : true ,
129+ ExpectSdkErrorThrown : true ,
130+ },
131+ {
132+ Name : "Calls transport when interceptor panics when non blocking" ,
133+ Blocking : false ,
134+ Interceptor : func (r * http.Request ) error {
135+ interceptorCalled = true
136+ panic ("Some err" )
137+ },
138+ ExpectedTransportCalls : 1 ,
139+ ExpectInterceptorCalled : true ,
140+ ExpectSdkErrorThrown : false ,
141+ },
142+ {
143+ Name : "Does not call transport when interceptor throws err when blocking" ,
144+ Blocking : true ,
145+ Interceptor : func (r * http.Request ) error {
146+ interceptorCalled = true
147+ return fmt .Errorf ("Some error" )
148+ },
149+ ExpectedTransportCalls : 0 ,
150+ ExpectInterceptorCalled : true ,
151+ ExpectSdkErrorThrown : true ,
152+ },
153+ {
154+ Name : "Calls transport when interceptor throws err when not blocking" ,
155+ Blocking : false ,
156+ Interceptor : func (r * http.Request ) error {
157+ interceptorCalled = true
158+ return fmt .Errorf ("Some error" )
159+ },
160+ ExpectedTransportCalls : 1 ,
161+ ExpectInterceptorCalled : true ,
162+ ExpectSdkErrorThrown : false ,
163+ },
164+ }
165+
166+ for _ , test := range testsToRun {
167+ t .Run (
168+ test .Name ,
169+ func (t * testing.T ) {
170+ mockHttpClient := http .DefaultClient
171+ mockTransport := httpmock .DefaultTransport
172+ mockTransport .RegisterNoResponder (func (r * http.Request ) (* http.Response , error ) {
173+ return & http.Response {StatusCode : 200 }, nil
174+ })
175+ mockHttpClient .Transport = mockTransport
176+
177+ _ , oktaClient , err := tests .NewClient (
178+ context .TODO (),
179+ okta .WithHttpInterceptorAndHttpClientPtr (test .Interceptor , mockHttpClient , test .Blocking ),
180+ )
181+ assert .NoError (t , err )
182+
183+ _ , _ , err = oktaClient .IdentityProvider .ActivateIdentityProvider (context .TODO (), "Anything" )
184+
185+ if test .ExpectSdkErrorThrown {
186+ assert .Error (t , err )
187+ } else {
188+ assert .NoError (t , err )
189+ }
190+
191+ assert .Equal (t , test .ExpectInterceptorCalled , interceptorCalled )
192+
193+ callCount := mockTransport .GetTotalCallCount ()
194+
195+ assert .Equal (t , test .ExpectedTransportCalls , callCount )
196+
197+ interceptorCalled = false
198+ mockTransport .ZeroCallCounters ()
199+ },
200+ )
201+ }
202+ }
0 commit comments