-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathgs_storage_client_factory.py
More file actions
79 lines (59 loc) · 2.43 KB
/
gs_storage_client_factory.py
File metadata and controls
79 lines (59 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import threading
# Module-level cache of storage clients. Entries are keyed by the tuple that
# _get_cache_key() returns (process id, thread id).
_client_cache = {}
def _get_cache_key():
    """Return the ``(process id, thread id)`` tuple for the calling thread."""
    pid = os.getpid()
    thread_id = threading.get_ident()
    return pid, thread_id
def _get_gs_storage_client_default():
    """Return a ``google.cloud.storage.Client``, creating it on first use.

    Clients are memoized in ``_client_cache`` under the key produced by
    ``_get_cache_key()``, so each (process id, thread id) pair gets its own
    client and later calls from the same thread return the same object.
    """
    key = _get_cache_key()
    try:
        # Fast path: a client was already built for this process/thread.
        return _client_cache[key]
    except KeyError:
        pass

    from google.cloud import storage

    if not os.environ.get("STORAGE_EMULATOR_HOST"):
        import google.auth

        credentials, project_id = google.auth.default(scopes=storage.Client.SCOPE)
        client = storage.Client(credentials=credentials, project=project_id)
    else:
        # A storage emulator is configured: build a plain Client(), which
        # auto-detects the emulator and uses anonymous credentials.
        # google.auth.default() would fail here without real GCP credentials.
        client = storage.Client()

    _client_cache[key] = client
    return client
class GcpDefaultClientProvider(object):
    """Client provider identified by the name ``"gcp-default"``.

    Both operations are static methods, so the class is used directly and
    never needs to be instantiated.
    """

    # Identifier that the module-level lookup compares against
    # DEFAULT_GCP_CLIENT_PROVIDER when selecting a provider.
    name = "gcp-default"

    @staticmethod
    def get_gs_storage_client(*args, **kwargs):
        """Return the cached default storage client.

        Positional and keyword arguments are accepted but not used.
        """
        client = _get_gs_storage_client_default()
        return client

    @staticmethod
    def get_credentials(scopes, *args, **kwargs):
        """Return the result of ``google.auth.default`` for ``scopes``.

        Extra positional and keyword arguments are accepted but not used.
        """
        import google.auth

        load_default = google.auth.default
        return load_default(scopes=scopes)
# Provider class selected by DEFAULT_GCP_CLIENT_PROVIDER. It stays None until
# the first call that needs it, after which the resolved class is reused.
cached_provider_class = None


def _resolve_provider_class():
    """Return the configured GCP client provider class, resolving it once.

    On the first call, GCP_CLIENT_PROVIDERS is searched for the entry whose
    ``name`` equals DEFAULT_GCP_CLIENT_PROVIDER, and the match is stored in
    the module-level ``cached_provider_class``. Later calls return the stored
    class without searching again.

    Raises:
        ValueError: if no registered provider has the configured name.
    """
    global cached_provider_class
    if cached_provider_class is None:
        # Imported at call time rather than module import time, as in the
        # original code (presumably to avoid an import cycle with metaflow;
        # TODO confirm).
        from metaflow.metaflow_config import DEFAULT_GCP_CLIENT_PROVIDER
        from metaflow.plugins import GCP_CLIENT_PROVIDERS

        for provider in GCP_CLIENT_PROVIDERS:
            if provider.name == DEFAULT_GCP_CLIENT_PROVIDER:
                cached_provider_class = provider
                break
        else:
            # The loop finished without a match.
            raise ValueError(
                "Cannot find GCP Client provider %s" % DEFAULT_GCP_CLIENT_PROVIDER
            )
    return cached_provider_class


def get_gs_storage_client():
    """Return a storage client from the configured GCP client provider.

    Raises:
        ValueError: if the configured provider name is not registered.
    """
    return _resolve_provider_class().get_gs_storage_client()


def get_credentials(scopes, *args, **kwargs):
    """Return credentials from the configured GCP client provider.

    Args:
        scopes: forwarded unchanged to the provider's ``get_credentials``.
        *args: extra positional arguments forwarded to the provider.
        **kwargs: extra keyword arguments forwarded to the provider.

    Returns:
        Whatever the provider's ``get_credentials`` returns.

    Raises:
        ValueError: if the configured provider name is not registered.
    """
    return _resolve_provider_class().get_credentials(scopes, *args, **kwargs)