Skip to content

Commit b3b2195

Browse files
committed
Merge branch 'token-source' into dev
2 parents fde6129 + 0ca81d8 commit b3b2195

8 files changed

+56
-11
lines changed

msal/application.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ class ClientApplication(object):
176176
REMOVE_ACCOUNT_ID = "903"
177177

178178
ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"
179+
_TOKEN_SOURCE = "token_source"
180+
_TOKEN_SOURCE_IDP = "identity_provider"
181+
_TOKEN_SOURCE_CACHE = "cache"
182+
_TOKEN_SOURCE_BROKER = "broker"
179183

180184
def __init__(
181185
self, client_id,
@@ -998,6 +1002,8 @@ def authorize(): # A controller in a web app
9981002
self._client_capabilities,
9991003
auth_code_flow.pop("claims_challenge", None))),
10001004
**kwargs))
1005+
if "access_token" in response:
1006+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
10011007
telemetry_context.update_telemetry(response)
10021008
return response
10031009

@@ -1070,6 +1076,8 @@ def acquire_token_by_authorization_code(
10701076
self._client_capabilities, claims_challenge)),
10711077
nonce=nonce,
10721078
**kwargs))
1079+
if "access_token" in response:
1080+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
10731081
telemetry_context.update_telemetry(response)
10741082
return response
10751083

@@ -1218,6 +1226,8 @@ def _acquire_token_by_cloud_shell(self, scopes, data=None):
12181226
data=data or {},
12191227
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
12201228
))
1229+
if "access_token" in response:
1230+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
12211231
return response
12221232

12231233
def acquire_token_silent(
@@ -1395,6 +1405,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
13951405
"access_token": entry["secret"],
13961406
"token_type": entry.get("token_type", "Bearer"),
13971407
"expires_in": int(expires_in), # OAuth2 specs defines it as int
1408+
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
13981409
}
13991410
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
14001411
refresh_reason = msal.telemetry.AT_AGING
@@ -1437,6 +1448,8 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
14371448
result = self._acquire_token_for_client(
14381449
scopes, refresh_reason, claims_challenge=claims_challenge,
14391450
**kwargs)
1451+
if result and "access_token" in result:
1452+
result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
14401453
if (result and "error" not in result) or (not access_token_from_cache):
14411454
return result
14421455
except http_exceptions:
@@ -1455,6 +1468,7 @@ def _process_broker_response(self, response, scopes, data):
14551468
data=data,
14561469
_account_id=response["_account_id"],
14571470
))
1471+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
14581472
return _clean_up(response)
14591473

14601474
def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
@@ -1611,6 +1625,8 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
16111625
on_updating_rt=False,
16121626
on_removing_rt=lambda rt_item: None, # No OP
16131627
**kwargs))
1628+
if "access_token" in response:
1629+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
16141630
telemetry_context.update_telemetry(response)
16151631
return response
16161632

@@ -1658,6 +1674,7 @@ def acquire_token_by_username_password(
16581674
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
16591675
headers = telemetry_context.generate_headers()
16601676
data = dict(kwargs.pop("data", {}), claims=claims)
1677+
response = None
16611678
if not self.authority.is_adfs:
16621679
user_realm_result = self.authority.user_realm_discovery(
16631680
username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID])
@@ -1666,13 +1683,14 @@ def acquire_token_by_username_password(
16661683
user_realm_result, username, password, scopes=scopes,
16671684
data=data,
16681685
headers=headers, **kwargs))
1669-
telemetry_context.update_telemetry(response)
1670-
return response
1671-
response = _clean_up(self.client.obtain_token_by_username_password(
1686+
if response is None: # Either ADFS or not federated
1687+
response = _clean_up(self.client.obtain_token_by_username_password(
16721688
username, password, scope=scopes,
16731689
headers=headers,
16741690
data=data,
16751691
**kwargs))
1692+
if "access_token" in response:
1693+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
16761694
telemetry_context.update_telemetry(response)
16771695
return response
16781696

@@ -1859,7 +1877,7 @@ def acquire_token_interactive(
18591877
logger.warning(
18601878
"Ignoring parameter extra_scopes_to_consent, "
18611879
"which is not supported by broker")
1862-
return self._acquire_token_interactive_via_broker(
1880+
response = self._acquire_token_interactive_via_broker(
18631881
scopes,
18641882
parent_window_handle,
18651883
enable_msa_passthrough,
@@ -1870,6 +1888,7 @@ def acquire_token_interactive(
18701888
login_hint=login_hint,
18711889
max_age=max_age,
18721890
)
1891+
return self._process_broker_response(response, scopes, data)
18731892

18741893
on_before_launching_ui(ui="browser")
18751894
telemetry_context = self._build_telemetry_context(
@@ -1892,6 +1911,8 @@ def acquire_token_interactive(
18921911
headers=telemetry_context.generate_headers(),
18931912
browser_name=_preferred_browser(),
18941913
**kwargs))
1914+
if "access_token" in response:
1915+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
18951916
telemetry_context.update_telemetry(response)
18961917
return response
18971918

@@ -1928,7 +1949,7 @@ def _acquire_token_interactive_via_broker(
19281949
claims=claims,
19291950
**data)
19301951
if response and "error" not in response:
1931-
return self._process_broker_response(response, scopes, data)
1952+
return response
19321953
# login_hint undecisive or not exists
19331954
if prompt == "none" or not prompt: # Must/Can attempt _signin_silently()
19341955
logger.debug("Calling broker._signin_silently()")
@@ -1949,9 +1970,7 @@ def _acquire_token_interactive_via_broker(
19491970
if is_wrong_account:
19501971
logger.debug(wrong_account_error_message)
19511972
if prompt == "none":
1952-
return self._process_broker_response( # It is either token or error
1953-
response, scopes, data
1954-
) if not is_wrong_account else {
1973+
return response if not is_wrong_account else {
19551974
"error": "broker_error",
19561975
"error_description": wrong_account_error_message,
19571976
}
@@ -1966,11 +1985,11 @@ def _acquire_token_interactive_via_broker(
19661985
"_broker_status") in recoverable_errors:
19671986
pass # It will fall back to the _signin_interactively()
19681987
else:
1969-
return self._process_broker_response(response, scopes, data)
1988+
return response
19701989

19711990
logger.debug("Falls back to broker._signin_interactively()")
19721991
on_before_launching_ui(ui="broker")
1973-
response = _signin_interactively(
1992+
return _signin_interactively(
19741993
authority, self.client_id, scopes,
19751994
None if parent_window_handle is self.CONSOLE_WINDOW_HANDLE
19761995
else parent_window_handle,
@@ -1981,7 +2000,6 @@ def _acquire_token_interactive_via_broker(
19812000
max_age=max_age,
19822001
enable_msa_pt=enable_msa_passthrough,
19832002
**data)
1984-
return self._process_broker_response(response, scopes, data)
19852003

19862004
def initiate_device_flow(self, scopes=None, **kwargs):
19872005
"""Initiate a Device Flow instance,
@@ -2036,6 +2054,8 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
20362054
),
20372055
headers=telemetry_context.generate_headers(),
20382056
**kwargs))
2057+
if "access_token" in response:
2058+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
20392059
telemetry_context.update_telemetry(response)
20402060
return response
20412061

@@ -2145,5 +2165,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
21452165
headers=telemetry_context.generate_headers(),
21462166
# TBD: Expose a login_hint (or ccs_routing_hint) param for web app
21472167
**kwargs))
2168+
if "access_token" in response:
2169+
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
21482170
telemetry_context.update_telemetry(response)
21492171
return response

sample/confidential_client_certificate_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def acquire_and_use_token():
6363
result = global_app.acquire_token_for_client(scopes=config["scope"])
6464

6565
if "access_token" in result:
66+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
6667
# Calling graph using the access token
6768
graph_data = requests.get( # Use token to call downstream service
6869
config["endpoint"],

sample/confidential_client_secret_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def acquire_and_use_token():
6262
result = global_app.acquire_token_for_client(scopes=config["scope"])
6363

6464
if "access_token" in result:
65+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
6566
# Calling graph using the access token
6667
graph_data = requests.get( # Use token to call downstream service
6768
config["endpoint"],

sample/device_flow_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def acquire_and_use_token():
8484
# and then keep calling acquire_token_by_device_flow(flow) in your own customized loop.
8585

8686
if "access_token" in result:
87+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
8788
# Calling graph using the access token
8889
graph_data = requests.get( # Use token to call downstream service
8990
config["endpoint"],

sample/interactive_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def acquire_and_use_token():
7979
)
8080

8181
if "access_token" in result:
82+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
8283
# Calling graph using the access token
8384
graph_response = requests.get( # Use token to call downstream service
8485
config["endpoint"],

sample/username_password_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def acquire_and_use_token():
6666
config["username"], config["password"], scopes=config["scope"])
6767

6868
if "access_token" in result:
69+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
6970
# Calling graph using the access token
7071
graph_data = requests.get( # Use token to call downstream service
7172
config["endpoint"],

sample/vault_jwt_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def acquire_and_use_token():
125125
result = global_app.acquire_token_for_client(scopes=config["scope"])
126126

127127
if "access_token" in result:
128+
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
128129
# Calling graph using the access token
129130
graph_data = requests.get( # Use token to call downstream service
130131
config["endpoint"],

tests/test_application.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def tester(url, **kwargs):
109109
self.scopes, self.account, post=tester)
110110
self.assertEqual("", result.get("classification"))
111111

112+
112113
class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase):
113114

114115
def setUp(self):
@@ -263,6 +264,7 @@ def test_get_accounts_should_find_accounts_under_different_alias(self):
263264
def test_acquire_token_silent_should_find_at_under_different_alias(self):
264265
result = self.app.acquire_token_silent(self.scopes, self.account)
265266
self.assertNotEqual(None, result)
267+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
266268
self.assertEqual(self.access_token, result.get('access_token'))
267269

268270
def test_acquire_token_silent_should_find_rt_under_different_alias(self):
@@ -360,6 +362,7 @@ def test_fresh_token_should_be_returned_from_cache(self):
360362
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
361363
self.fail("I/O shouldn't happen in cache hit AT scenario")
362364
)
365+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
363366
self.assertEqual(access_token, result.get("access_token"))
364367
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
365368

@@ -374,6 +377,7 @@ def mock_post(url, headers=None, *args, **kwargs):
374377
"refresh_in": 123,
375378
}))
376379
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
380+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
377381
self.assertEqual(new_access_token, result.get("access_token"))
378382
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
379383

@@ -385,6 +389,7 @@ def mock_post(url, headers=None, *args, **kwargs):
385389
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
386390
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
387391
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
392+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
388393
self.assertEqual(old_at, result.get("access_token"))
389394

390395
def test_expired_token_and_unavailable_aad_should_return_error(self):
@@ -409,6 +414,7 @@ def mock_post(url, headers=None, *args, **kwargs):
409414
"refresh_in": 123,
410415
}))
411416
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
417+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
412418
self.assertEqual(new_access_token, result.get("access_token"))
413419
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
414420

@@ -444,6 +450,7 @@ def test_maintaining_offline_state_and_sending_them(self):
444450
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
445451
self.fail("I/O shouldn't happen in cache hit AT scenario")
446452
)
453+
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE)
447454
self.assertEqual(cached_access_token, result.get("access_token"))
448455

449456
error1 = "error_1"
@@ -477,6 +484,7 @@ def mock_post(url, headers=None, *args, **kwargs):
477484
"The previous error should result in same success counter plus latest error info")
478485
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
479486
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
487+
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
480488
self.assertEqual(at, result.get("access_token"))
481489

482490
def mock_post(url, headers=None, *args, **kwargs):
@@ -485,6 +493,7 @@ def mock_post(url, headers=None, *args, **kwargs):
485493
"The previous success should reset all offline telemetry counters")
486494
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
487495
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
496+
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
488497
self.assertEqual(at, result.get("access_token"))
489498

490499

@@ -503,6 +512,7 @@ def mock_post(url, headers=None, *args, **kwargs):
503512
result = self.app.acquire_token_by_auth_code_flow(
504513
{"state": state, "code_verifier": "bar"}, {"state": state, "code": "012"},
505514
post=mock_post)
515+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
506516
self.assertEqual(at, result.get("access_token"))
507517

508518
def test_acquire_token_by_refresh_token(self):
@@ -511,6 +521,7 @@ def mock_post(url, headers=None, *args, **kwargs):
511521
self.assertEqual("4|85,1|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
512522
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
513523
result = self.app.acquire_token_by_refresh_token("rt", ["s"], post=mock_post)
524+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
514525
self.assertEqual(at, result.get("access_token"))
515526

516527

@@ -529,6 +540,7 @@ def mock_post(url, headers=None, *args, **kwargs):
529540
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
530541
result = self.app.acquire_token_by_device_flow(
531542
{"device_code": "123"}, post=mock_post)
543+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
532544
self.assertEqual(at, result.get("access_token"))
533545

534546
def test_acquire_token_by_username_password(self):
@@ -538,6 +550,7 @@ def mock_post(url, headers=None, *args, **kwargs):
538550
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
539551
result = self.app.acquire_token_by_username_password(
540552
"username", "password", ["scope"], post=mock_post)
553+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
541554
self.assertEqual(at, result.get("access_token"))
542555

543556

@@ -556,6 +569,7 @@ def mock_post(url, headers=None, *args, **kwargs):
556569
"expires_in": 0,
557570
}))
558571
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
572+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
559573
self.assertEqual("AT 1", result.get("access_token"), "Shall get a new token")
560574

561575
def mock_post(url, headers=None, *args, **kwargs):
@@ -566,13 +580,15 @@ def mock_post(url, headers=None, *args, **kwargs):
566580
"refresh_in": -100, # A hack to make sure it will attempt refresh
567581
}))
568582
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
583+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
569584
self.assertEqual("AT 2", result.get("access_token"), "Shall get a new token")
570585

571586
def mock_post(url, headers=None, *args, **kwargs):
572587
# 1/0 # TODO: Make sure this was called
573588
self.assertEqual("4|730,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
574589
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
575590
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
591+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
576592
self.assertEqual("AT 2", result.get("access_token"), "Shall get aging token")
577593

578594
def test_acquire_token_on_behalf_of(self):
@@ -581,6 +597,7 @@ def mock_post(url, headers=None, *args, **kwargs):
581597
self.assertEqual("4|523,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
582598
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
583599
result = self.app.acquire_token_on_behalf_of("assertion", ["s"], post=mock_post)
600+
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
584601
self.assertEqual(at, result.get("access_token"))
585602

586603

0 commit comments

Comments
 (0)