Skip to content

Commit e16bef0

Browse files
authored
Merge branch 'main' into flake8-pytest-style
2 parents d9c6a6c + 00807a9 commit e16bef0

File tree

2 files changed

+48
-32
lines changed

2 files changed

+48
-32
lines changed

parsons/newmode/newmode.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,23 @@ def __init__(
431431
self.client_secret: str = check_env.check("NEWMODE_API_CLIENT_SECRET", client_secret)
432432
self.headers: Dict[str, str] = {"content-type": "application/json"}
433433
self.default_client: OAuth2APIConnector = self.get_default_oauth_client()
434+
self.campaigns_client: OAuth2APIConnector = self.get_campaigns_oauth_client()
435+
436+
def get_campaigns_oauth_client(self) -> OAuth2APIConnector:
437+
return OAuth2APIConnector(
438+
uri=V2_API_CAMPAIGNS_URL,
439+
auto_refresh_url=None,
440+
client_id=self.client_id,
441+
client_secret=self.client_secret,
442+
headers=V2_API_CAMPAIGNS_HEADERS,
443+
token_url=V2_API_AUTH_URL,
444+
grant_type="client_credentials",
445+
)
434446

435447
def get_default_oauth_client(self) -> OAuth2APIConnector:
436448
return OAuth2APIConnector(
437449
uri=self.base_url,
438-
auto_refresh_url=V2_API_AUTH_URL,
450+
auto_refresh_url=None,
439451
client_id=self.client_id,
440452
client_secret=self.client_secret,
441453
headers=self.headers,
@@ -461,7 +473,7 @@ def base_request(
461473
self,
462474
method: str,
463475
url: str,
464-
client: OAuth2APIConnector,
476+
use_campaigns_client: bool = False,
465477
data: Optional[Dict[str, Any]] = None,
466478
json: Optional[Dict[str, Any]] = None,
467479
params: Optional[Dict[str, Any]] = None,
@@ -474,6 +486,8 @@ def base_request(
474486
if params is None:
475487
params = {}
476488

489+
client = self.default_client if not use_campaigns_client else self.campaigns_client
490+
477491
for attempt in range(retries + 1):
478492
try:
479493
try:
@@ -484,6 +498,10 @@ def base_request(
484498
except TokenExpiredError as e:
485499
logger.warning(f"Token expired: {e}. Refreshing it...")
486500
self.default_client = self.get_default_oauth_client()
501+
self.campaigns_client = self.get_campaigns_oauth_client()
502+
client = (
503+
self.default_client if not use_campaigns_client else self.campaigns_client
504+
)
487505
except Exception as e:
488506
if attempt < retries:
489507
logger.warning(f"Request failed (attempt {attempt + 1}/{retries}). Retrying...")
@@ -496,7 +514,7 @@ def paginate_request(
496514
self,
497515
method: str,
498516
endpoint: str,
499-
client: OAuth2APIConnector,
517+
use_campaigns_client: bool = False,
500518
data_key: str = RESPONSE_DATA_KEY,
501519
data: Optional[Dict[str, Any]] = None,
502520
json: Optional[Dict[str, Any]] = None,
@@ -517,7 +535,7 @@ def paginate_request(
517535
response = self.base_request(
518536
method=method,
519537
url=url,
520-
client=client,
538+
use_campaigns_client=use_campaigns_client,
521539
data=data,
522540
json=json,
523541
params=params,
@@ -531,7 +549,12 @@ def paginate_request(
531549
else:
532550
results.append(response)
533551
# Check for pagination
534-
url = response.get(RESPONSE_LINKS_KEY, {}).get(PAGINATION_NEXT) if response else None
552+
url = None
553+
if response:
554+
url = response.get(RESPONSE_LINKS_KEY, {}).get(PAGINATION_NEXT, {})
555+
if isinstance(url, dict):
556+
url = url.get("href")
557+
535558
return results
536559

537560
def converted_request(
@@ -544,14 +567,13 @@ def converted_request(
544567
params: Optional[Dict[str, Any]] = None,
545568
convert_to_table: bool = True,
546569
data_key: Optional[str] = None,
547-
client: Optional[OAuth2APIConnector] = None,
570+
use_campaigns_client: bool = False,
548571
override_api_version: Optional[str] = None,
549572
) -> Union[Table, Dict[str, Any]]:
550573
"""Internal method to make a call to the Newmode API and convert the result to a Parsons table."""
551574

552575
if params is None:
553576
params = {}
554-
client = client if client else self.default_client
555577
response = self.paginate_request(
556578
method=method,
557579
json=json,
@@ -560,12 +582,12 @@ def converted_request(
560582
data_key=data_key,
561583
supports_version=supports_version,
562584
endpoint=endpoint,
563-
client=client,
585+
use_campaigns_client=use_campaigns_client,
564586
override_api_version=override_api_version,
565587
)
566588
if response:
567589
if convert_to_table:
568-
return client.convert_to_table(data=response)
590+
return self.default_client.convert_to_table(data=response)
569591
else:
570592
return response
571593

@@ -607,22 +629,13 @@ def get_campaign_ids(self, params: Optional[Dict[str, Any]] = None) -> List[str]
607629
if params is None:
608630
params = {}
609631
endpoint = "node/action"
610-
campaigns_client = OAuth2APIConnector(
611-
uri=V2_API_CAMPAIGNS_URL,
612-
auto_refresh_url=V2_API_AUTH_URL,
613-
client_id=self.client_id,
614-
client_secret=self.client_secret,
615-
headers=V2_API_CAMPAIGNS_HEADERS,
616-
token_url=V2_API_AUTH_URL,
617-
grant_type="client_credentials",
618-
)
619632

620633
data = self.converted_request(
621634
endpoint=endpoint,
622635
method="GET",
623636
params=params,
624637
data_key=RESPONSE_DATA_KEY,
625-
client=campaigns_client,
638+
use_campaigns_client=True,
626639
override_api_version=V2_API_CAMPAIGNS_VERSION,
627640
)
628641
return data["id"]
@@ -726,7 +739,10 @@ def get_submissions(self, campaign_id: str, params: Optional[Dict[str, Any]] = N
726739
params = {}
727740
params = {"action": campaign_id}
728741
response = self.converted_request(
729-
endpoint="submission", method="GET", params=params, data_key=RESPONSE_DATA_KEY
742+
endpoint="submission",
743+
method="GET",
744+
params=params,
745+
data_key=RESPONSE_DATA_KEY,
730746
)
731747
return response
732748

@@ -765,7 +781,9 @@ def __new__(
765781
api_version = check_env.check("NEWMODE_API_VERSION", api_version)
766782
if api_version.startswith("v2"):
767783
return NewmodeV2(
768-
client_id=client_id, client_secret=client_secret, api_version=api_version
784+
client_id=client_id,
785+
client_secret=client_secret,
786+
api_version=api_version,
769787
)
770788
if api_version.startswith("v1"):
771789
return NewmodeV1(api_user=api_user, api_password=api_password, api_version=api_version)

test/test_newmode/test_newmode.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def test_base_request_retries(self, m, mock_logger):
285285
self.nm.base_request(
286286
method="GET",
287287
url=f"{V2_API_URL}v2.1/test-endpoint",
288-
client=self.nm.default_client,
289288
retries=2,
290289
)
291290

@@ -345,18 +344,18 @@ def test_checked_response_http_error(self, m):
345344
@patch("parsons.newmode.newmode.NewmodeV2.get_default_oauth_client")
346345
def test_token_refresh_on_expired_token(self, m, mock_get_default_oauth_client):
347346
m.post(V2_API_AUTH_URL, json={"access_token": "fakeAccessToken"})
348-
m.get(f"{V2_API_URL}v2.1/test-endpoint", status_code=401)
349347

350348
mock_new_client = mock.MagicMock()
351349
mock_get_default_oauth_client.return_value = mock_new_client
352-
self.nm.default_client.request = mock.MagicMock()
353350

354351
mock_response = mock.MagicMock()
355352
mock_response.raise_for_status = mock.MagicMock()
356353
mock_response.status_code = 200
357354
mock_response.json.return_value = {"data": "success"}
355+
mock_new_client.request.return_value = mock_response
356+
357+
mock_new_client.json_check.return_value = True
358358

359-
# Simulate token expiration and successful response
360359
def oauth_side_effect(*args, **kwargs):
361360
if not hasattr(self, "call_count"):
362361
self.call_count = 0
@@ -365,13 +364,12 @@ def oauth_side_effect(*args, **kwargs):
365364
raise TokenExpiredError()
366365
return mock_response
367366

368-
self.nm.default_client.request.side_effect = oauth_side_effect
369-
370-
response = self.nm.base_request(
371-
method="GET",
372-
url=f"{V2_API_URL}v2.1/test-endpoint",
373-
client=self.nm.default_client,
374-
)
367+
with patch.object(self.nm.default_client, "request", side_effect=oauth_side_effect):
368+
m.get(f"{V2_API_URL}v2.1/test-endpoint", json={"data": "success"}, status_code=200)
369+
response = self.nm.base_request(method="GET", url=f"{V2_API_URL}v2.1/test-endpoint")
375370

376371
mock_get_default_oauth_client.assert_called_once()
377372
assert response == {"data": "success"}
373+
mock_new_client.request.assert_called_with(
374+
url=f"{V2_API_URL}v2.1/test-endpoint", req_type="GET", json=None, data=None, params={}
375+
)

0 commit comments

Comments
 (0)