12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- """A couple of authentication types in ODPS.
16
- """
15
+ """A couple of authentication types in ODPS."""
17
16
18
17
import base64
19
- import hmac
18
+ import calendar
20
19
import hashlib
20
+ import hmac
21
+ import json
21
22
import logging
22
23
import os
23
24
import threading
27
28
28
29
import requests
29
30
30
- from .compat import six , cgi , urlparse , unquote , parse_qsl
31
31
from . import options , utils
32
-
32
+ from . compat import cgi , parse_qsl , six , unquote , urlparse
33
33
34
34
logger = logging .getLogger (__name__ )
35
35
36
- DEFAULT_BEARER_TOKEN_HOURS = 5
36
+ DEFAULT_TEMP_ACCOUNT_HOURS = 5
37
37
38
38
39
39
class BaseAccount (object ):
@@ -138,20 +138,6 @@ def sign_request(self, req, endpoint, region_name=None):
138
138
logger .debug ('headers after signing: %r' , req .headers )
139
139
140
140
141
- class StsAccount (AliyunAccount ):
142
- """
143
- Account of sts
144
- """
145
- def __init__ (self , access_id , secret_access_key , sts_token ):
146
- super (StsAccount , self ).__init__ (access_id , secret_access_key )
147
- self .sts_token = sts_token
148
-
149
- def sign_request (self , req , endpoint , region_name = None ):
150
- super (StsAccount , self ).sign_request (req , endpoint , region_name = region_name )
151
- if self .sts_token :
152
- req .headers ['authorization-sts-token' ] = self .sts_token
153
-
154
-
155
141
class AppAccount (BaseAccount ):
156
142
"""
157
143
Account for applications.
@@ -352,77 +338,152 @@ def sign_request(self, req, endpoint, region_name=None):
352
338
)
353
339
354
340
355
- class BearerTokenAccount (BaseAccount ):
341
+ class TempAccountMixin (object ):
342
+ def __init__ (self , expired_hours = DEFAULT_TEMP_ACCOUNT_HOURS ):
343
+ self ._last_modified_time = datetime .now ()
344
+ if expired_hours is not None :
345
+ self ._expired_time = timedelta (hours = expired_hours )
346
+ else :
347
+ self ._expired_time = None
348
+ self .reload ()
349
+
350
+ def _is_account_valid (self ):
351
+ raise NotImplementedError
352
+
353
+ def _reload_account (self ):
354
+ raise NotImplementedError
355
+
356
+ def reload (self , force = False ):
357
+ t = datetime .now ()
358
+ if (
359
+ force
360
+ or not self ._is_account_valid ()
361
+ or (
362
+ self ._last_modified_time is not None
363
+ and self ._expired_time is not None
364
+ and (t - self ._last_modified_time ) > self ._expired_time
365
+ )
366
+ ):
367
+ self ._last_modified_time = self ._reload_account () or datetime .now ()
368
+
369
+
370
+ class StsAccount (TempAccountMixin , AliyunAccount ):
371
+ """
372
+ Account of sts
373
+ """
374
+
356
375
def __init__ (
357
- self , token = None , expired_hours = DEFAULT_BEARER_TOKEN_HOURS , get_bearer_token_fun = None
376
+ self ,
377
+ access_id ,
378
+ secret_access_key ,
379
+ sts_token ,
380
+ expired_hours = DEFAULT_TEMP_ACCOUNT_HOURS ,
358
381
):
359
- self ._get_bearer_token = get_bearer_token_fun or self . get_default_bearer_token
360
- self . _token = token or self . _get_bearer_token ( )
361
- self . _reload_bearer_token_time ( )
382
+ self .sts_token = sts_token
383
+ AliyunAccount . __init__ ( self , access_id , secret_access_key )
384
+ TempAccountMixin . __init__ ( self , expired_hours = expired_hours )
362
385
363
- self ._expired_time = timedelta (hours = expired_hours )
386
+ @classmethod
387
+ def from_environments (cls ):
388
+ expired_hours = int (
389
+ os .getenv ("ODPS_STS_TOKEN_HOURS" , str (DEFAULT_TEMP_ACCOUNT_HOURS ))
390
+ )
391
+ if "ODPS_STS_ACCOUNT_FILE" in os .environ or "ODPS_STS_TOKEN" in os .environ :
392
+ if "ODPS_STS_ACCOUNT_FILE" not in os .environ :
393
+ expired_hours = None
394
+ return cls (None , None , None , expired_hours = expired_hours )
395
+ return None
396
+
397
+ def sign_request (self , req , endpoint , region_name = None ):
398
+ self .reload ()
399
+ super (StsAccount , self ).sign_request (req , endpoint , region_name = region_name )
400
+ if self .sts_token :
401
+ req .headers ["authorization-sts-token" ] = self .sts_token
402
+
403
+ def _is_account_valid (self ):
404
+ return self .sts_token is not None
405
+
406
+ def _resolve_expiration (self , exp_data ):
407
+ if exp_data is None or self ._expired_time is None :
408
+ return None
409
+ try :
410
+ ts = calendar .timegm (time .strptime (exp_data , "%Y-%m-%dT%H:%M:%SZ" ))
411
+ return ts - self ._expired_time .total_seconds ()
412
+ except :
413
+ return None
414
+
415
+ def _reload_account (self ):
416
+ ts = None
417
+ if "ODPS_STS_ACCOUNT_FILE" in os .environ :
418
+ token_file_name = os .getenv ("ODPS_STS_ACCOUNT_FILE" )
419
+ if token_file_name and os .path .exists (token_file_name ):
420
+ with open (token_file_name , "r" ) as token_file :
421
+ token_json = json .load (token_file )
422
+ self .access_id = token_json ["accessKeyId" ]
423
+ self .secret_access_key = token_json ["accessKeySecret" ]
424
+ self .sts_token = token_json ["securityToken" ]
425
+ ts = self ._resolve_expiration (token_json .get ("expiration" ))
426
+ elif "ODPS_STS_ACCESS_KEY_ID" in os .environ :
427
+ self .access_id = os .getenv ("ODPS_STS_ACCESS_KEY_ID" )
428
+ self .secret_access_key = os .getenv ("ODPS_STS_ACCESS_KEY_SECRET" )
429
+ self .sts_token = os .getenv ("ODPS_STS_TOKEN" )
430
+
431
+ return datetime .fromtimestamp (ts ) if ts is not None else None
432
+
433
+
434
+ class BearerTokenAccount (TempAccountMixin , BaseAccount ):
435
+ def __init__ (
436
+ self , token = None , expired_hours = DEFAULT_TEMP_ACCOUNT_HOURS , get_bearer_token_fun = None
437
+ ):
438
+ self .token = token
439
+ self ._custom_bearer_token_func = get_bearer_token_fun
440
+ TempAccountMixin .__init__ (self , expired_hours = expired_hours )
364
441
365
442
@classmethod
366
443
def from_environments (cls ):
367
- expired_hours = int (os .getenv ('ODPS_BEARER_TOKEN_HOURS' , str (DEFAULT_BEARER_TOKEN_HOURS )))
444
+ expired_hours = int (os .getenv ('ODPS_BEARER_TOKEN_HOURS' , str (DEFAULT_TEMP_ACCOUNT_HOURS )))
368
445
kwargs = {"expired_hours" : expired_hours }
369
- if 'ODPS_BEARER_TOKEN' in os .environ :
370
- return cls (os .environ ['ODPS_BEARER_TOKEN' ], ** kwargs )
371
- elif 'ODPS_BEARER_TOKEN_FILE' in os .environ :
446
+ if "ODPS_BEARER_TOKEN_FILE" in os .environ :
372
447
return cls (** kwargs )
448
+ elif "ODPS_BEARER_TOKEN" in os .environ :
449
+ kwargs ["expired_hours" ] = None
450
+ return cls (os .environ ["ODPS_BEARER_TOKEN" ], ** kwargs )
373
451
return None
374
452
375
- @staticmethod
376
- def get_default_bearer_token ():
453
+ def _get_bearer_token (self ):
454
+ if self ._custom_bearer_token_func is not None :
455
+ return self ._custom_bearer_token_func ()
456
+
377
457
token_file_name = os .getenv ("ODPS_BEARER_TOKEN_FILE" )
378
458
if token_file_name and os .path .exists (token_file_name ):
379
459
with open (token_file_name , "r" ) as token_file :
380
460
return token_file .read ().strip ()
461
+ else : # pragma: no cover
462
+ from cupid .runtime import context , RuntimeContext
381
463
382
- from cupid .runtime import context , RuntimeContext
383
-
384
- if not RuntimeContext .is_context_ready ():
385
- return
386
- cupid_context = context ()
387
- return cupid_context .get_bearer_token ()
388
-
389
- def get_bearer_token_and_timestamp (self ):
390
- self ._check_bearer_token ()
391
- return self ._token , self ._last_modified_time .timestamp ()
392
-
393
- def _reload_bearer_token_time (self ):
394
- if "ODPS_BEARER_TOKEN_TIMESTAMP_FILE" in os .environ :
395
- with open (os .getenv ("ODPS_BEARER_TOKEN_TIMESTAMP_FILE" ), "r" ) as ts_file :
396
- self ._last_modified_time = datetime .fromtimestamp (float (ts_file .read ()))
397
- else :
398
- self ._last_modified_time = datetime .now ()
399
-
400
- def _check_bearer_token (self ):
401
- t = datetime .now ()
402
- if self ._last_modified_time is None :
403
- token = self ._get_bearer_token ()
404
- if token is None :
405
- return
406
- if token != self ._token :
407
- self ._token = token
408
- self ._reload_bearer_token_time ()
409
- elif (t - self ._last_modified_time ) > self ._expired_time :
410
- token = self ._get_bearer_token ()
411
- if token is None :
464
+ if not RuntimeContext .is_context_ready ():
412
465
return
413
- self . _token = token
414
- self . _reload_bearer_token_time ()
466
+ cupid_context = context ()
467
+ return cupid_context . get_bearer_token ()
415
468
416
- @property
417
- def token (self ):
418
- return self ._token
469
+ def _is_account_valid (self ):
470
+ return self .token is not None
471
+
472
+ def _reload_account (self ):
473
+ token = self ._get_bearer_token ()
474
+ self .token = token
475
+ try :
476
+ resolved_token_parts = base64 .b64decode (token ).decode ().split ("," )
477
+ return datetime .fromtimestamp (int (resolved_token_parts [2 ]))
478
+ except :
479
+ return None
419
480
420
481
def sign_request (self , req , endpoint , region_name = None ):
421
- self ._check_bearer_token ()
482
+ self .reload ()
422
483
url = req .url [len (endpoint ):]
423
484
url_components = urlparse (unquote (url ), allow_fragments = False )
424
485
self ._build_canonical_str (url_components , req )
425
- req .headers ['x-odps-bearer-token' ] = self ._token
486
+ req .headers ['x-odps-bearer-token' ] = self .token
426
487
logger .debug ('headers after signing: %r' , req .headers )
427
488
428
489
@@ -432,11 +493,22 @@ def __init__(self, credential_provider):
432
493
super (CredentialProviderAccount , self ).__init__ (None , None , None )
433
494
434
495
def sign_request (self , req , endpoint , region_name = None ):
435
- credential = self .provider .get_credentials ()
496
+ get_cred_method = getattr (self .provider , "get_credential" , None ) or getattr (
497
+ self .provider , "get_credentials"
498
+ )
499
+ credential = get_cred_method ()
436
500
437
501
self .access_id = credential .get_access_key_id ()
438
502
self .secret_access_key = credential .get_access_key_secret ()
439
503
self .sts_token = credential .get_security_token ()
440
504
return super (CredentialProviderAccount , self ).sign_request (
441
505
req , endpoint , region_name = region_name
442
506
)
507
+
508
+
509
+ def from_environments ():
510
+ for account_cls in (StsAccount , BearerTokenAccount ):
511
+ account = account_cls .from_environments ()
512
+ if account is not None :
513
+ break
514
+ return account
0 commit comments