|
5 | 5 |
|
6 | 6 | import requests_mock |
7 | 7 | from requests.exceptions import HTTPError |
| 8 | +from oauthlib.oauth2 import TokenExpiredError |
8 | 9 |
|
9 | 10 | from parsons import Newmode, Table |
10 | 11 | from test.test_newmode import test_newmode_data |
@@ -337,3 +338,38 @@ def test_checked_response_http_error(self, m): |
337 | 338 | ) |
338 | 339 | with self.assertRaises(HTTPError): |
339 | 340 | self.nm.checked_response(response, self.nm.default_client) |
| 341 | + |
| 342 | + @requests_mock.Mocker() |
| 343 | + @patch("parsons.newmode.newmode.NewmodeV2.get_default_oauth_client") |
| 344 | + def test_token_refresh_on_expired_token(self, m, mock_get_default_oauth_client): |
| 345 | + m.post(V2_API_AUTH_URL, json={"access_token": "fakeAccessToken"}) |
| 346 | + m.get(f"{V2_API_URL}v2.1/test-endpoint", status_code=401) |
| 347 | + |
| 348 | + mock_new_client = mock.MagicMock() |
| 349 | + mock_get_default_oauth_client.return_value = mock_new_client |
| 350 | + self.nm.default_client.request = mock.MagicMock() |
| 351 | + |
| 352 | + mock_response = mock.MagicMock() |
| 353 | + mock_response.raise_for_status = mock.MagicMock() |
| 354 | + mock_response.status_code = 200 |
| 355 | + mock_response.json.return_value = {"data": "success"} |
| 356 | + |
| 357 | + # Simulate token expiration and successful response |
| 358 | + def oauth_side_effect(*args, **kwargs): |
| 359 | + if not hasattr(self, "call_count"): |
| 360 | + self.call_count = 0 |
| 361 | + if self.call_count == 0: |
| 362 | + self.call_count += 1 |
| 363 | + raise TokenExpiredError() |
| 364 | + return mock_response |
| 365 | + |
| 366 | + self.nm.default_client.request.side_effect = oauth_side_effect |
| 367 | + |
| 368 | + response = self.nm.base_request( |
| 369 | + method="GET", |
| 370 | + url=f"{V2_API_URL}v2.1/test-endpoint", |
| 371 | + client=self.nm.default_client, |
| 372 | + ) |
| 373 | + |
| 374 | + mock_get_default_oauth_client.assert_called_once() |
| 375 | + self.assertEqual(response, {"data": "success"}) |
0 commit comments