 from __future__ import annotations

 import copy
+import functools
+import json
 import re
+from collections.abc import Iterable
 from logging import getLogger
 from typing import Any
 from urllib.parse import urlsplit

 from googleapiclient.model import makepatch

 from gokart.gcs_config import GCSConfig
+from gokart.required_task_output import RequiredTaskOutput
+from gokart.utils import FlattenableItems

 logger = getLogger(__name__)
@@ -21,7 +26,7 @@ class GCSObjectMetadataClient:
     @staticmethod
     def _is_log_related_path(path: str) -> bool:
-        return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
+        return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None

     # This is a copy of luigi.gcs._path_to_bucket_and_key(path).
     @staticmethod
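
Note: the patched check expects a full GCS URI rather than a bare object key. A minimal sketch of the before/after behavior (the bucket name is invented for illustration):

    import re

    # Mirrors the patched regex above.
    LOG_RE = re.compile(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+')

    assert LOG_RE.match('gs://my-bucket/log/task_info/Task_123.pkl')  # full URI now matches
    assert not LOG_RE.match('log/task_info/Task_123.pkl')             # bare key matched the old pattern, not the new one
    assert not LOG_RE.match('gs://my-bucket/data/output.pkl')         # non-log paths never match
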
@@ -32,7 +37,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
         return netloc, path_without_initial_slash

     @staticmethod
-    def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def add_task_state_labels(
+        path: str,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         if GCSObjectMetadataClient._is_log_related_path(path):
             return
         # The same call can be found in gokart/object_storage.get_time_stamp.
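
A minimal sketch of how a caller might pass the new argument. The path, parameter values, and labels are invented, and it assumes for illustration that RequiredTaskOutput is constructed from a task name and an output path; inside gokart this is driven by the task's target rather than called by hand:

    from gokart.required_task_output import RequiredTaskOutput

    # Hypothetical invocation; all values are illustrative only.
    GCSObjectMetadataClient.add_task_state_labels(
        path='gs://my-bucket/output/data.pkl',
        task_params={'param_a': '1'},
        custom_labels={'team': 'ml-platform'},
        required_task_outputs=[RequiredTaskOutput(task_name='UpstreamTask', output_path='gs://my-bucket/upstream.pkl')],
    )
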
@@ -42,20 +52,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
         if _response is None:
             logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
             return
-
         response: dict[str, Any] = dict(_response)
         original_metadata: dict[Any, Any] = {}
         if 'metadata' in response.keys():
             _metadata = response.get('metadata')
             if _metadata is not None:
                 original_metadata = dict(_metadata)
-
         patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
             copy.deepcopy(original_metadata),
             task_params,
             custom_labels,
+            required_task_outputs,
         )
-
         if original_metadata != patched_metadata:
             # The update API removes existing object metadata, so the patch API should be used instead.
             # See the official documentation for details.
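
The patch body sent in the (elided) call between these hunks is presumably built with the makepatch helper imported at the top, which computes the minimal delta between the original and patched resource. A sketch of the idea, with invented dictionaries:

    from googleapiclient.model import makepatch

    original = {'metadata': {'a': '1', 'b': '2'}}
    patched = {'metadata': {'a': '1', 'b': '3', 'c': '4'}}

    # makepatch returns only the fields that changed, so unrelated
    # object metadata is left untouched by the PATCH request.
    print(makepatch(original, patched))  # {'metadata': {'b': '3', 'c': '4'}}
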
@@ -71,7 +79,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
                 )
                 .execute()
             )
-
             if update_response is None:
                 logger.error(f'failed to patch object {obj} in bucket {bucket}.')
@@ -83,14 +90,14 @@ def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
     def _get_patched_obj_metadata(
         metadata: Any,
         task_params: dict[str, str] | None = None,
-        custom_labels: dict[str, Any] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
     ) -> dict | Any:
         # If the metadata in the bucket/object lookup response is not a dictionary,
         # something has likely gone wrong, so return the original metadata unpatched.
         if not isinstance(metadata, dict):
             logger.warning(f'metadata is not a dict: {metadata}; something went wrong while getting the bucket and object information.')
             return metadata
-
         if not task_params and not custom_labels:
             return metadata
         # Maximum size of metadata for each object is 8 KiB.
@@ -101,24 +108,45 @@ def _get_patched_obj_metadata(
         # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
         # Instead, users are expected to search using the labels they provided.
         # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
-        _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
+        normalized_labels = [normalized_custom_labels, normalized_task_params_labels]
+        if required_task_outputs:
+            normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))})
+
+        _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
         return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)

+    @staticmethod
+    def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
+        def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
+            for item in nested_list:
+                # Exclude str: strings are Iterable, and recursing into them never terminates.
+                if isinstance(item, Iterable) and not isinstance(item, str):
+                    yield from _iterable_flatten(item)
+                else:
+                    yield item
+
+        if isinstance(required_task_outputs, dict):
+            return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
+        if isinstance(required_task_outputs, Iterable):
+            return list(_iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]))
+        return [required_task_outputs.serialize()]
+
     @staticmethod
     def _merge_custom_labels_and_task_params_labels(
-        normalized_task_params: dict[str, str],
-        normalized_custom_labels: dict[str, Any],
+        normalized_labels_list: list[dict[str, str]],
     ) -> dict[str, str]:
-        merged_labels = copy.deepcopy(normalized_custom_labels)
-        for label_name, label_value in normalized_task_params.items():
-            if len(label_value) == 0:
-                logger.warning(f'value of label_name={label_name} is empty, so skipping adding it as metadata.')
-                continue
-            if label_name in merged_labels.keys():
-                logger.warning(f'label_name={label_name} has already been seen, so skipping adding it as metadata.')
-                continue
-            merged_labels[label_name] = label_value
-        return merged_labels
+        def __merge_two_dicts_helper(merged: dict[str, str], current_labels: dict[str, str]) -> dict[str, str]:
+            next_merged = copy.deepcopy(merged)
+            for label_name, label_value in current_labels.items():
+                if len(label_value) == 0:
+                    logger.warning(f'value of label_name={label_name} is empty, so skipping adding it as metadata.')
+                    continue
+                if label_name in next_merged:
+                    logger.warning(f'label_name={label_name} has already been seen, so skipping adding it as metadata.')
+                    continue
+                next_merged[label_name] = label_value
+            return next_merged
+
+        return functools.reduce(__merge_two_dicts_helper, normalized_labels_list, {})

     # Google Cloud Storage (GCS) limits the metadata of each object to 8 KiB,
     # so we need to adjust the size of the metadata.
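
To make the merge precedence concrete, here is a sketch of what the reduce-based merge does with the list built above: dictionaries earlier in the list win on key conflicts, and empty values are dropped. The label names and values are invented, and the helper below is a simplified stand-in for __merge_two_dicts_helper (it skips the warnings):

    import functools

    labels_list = [
        {'team': 'ml-platform', 'env': 'prod'},   # custom labels (highest precedence)
        {'env': 'dev', 'param_a': '1', 'p': ''},  # labels from task parameters
    ]

    def merge(merged: dict[str, str], current: dict[str, str]) -> dict[str, str]:
        out = dict(merged)
        for name, value in current.items():
            # Keep the first-seen value; drop empty ones.
            if value and name not in out:
                out[name] = value
        return out

    print(functools.reduce(merge, labels_list, {}))
    # {'team': 'ml-platform', 'env': 'prod', 'param_a': '1'}
    # 'env' from the custom labels wins; the empty 'p' is dropped.

The serialization helper behaves similarly shape-wise, assuming each RequiredTaskOutput.serialize() returns a string: a nested list of outputs is flattened to a flat list of serialized strings, while a dict keeps its keys and serializes each value.
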
@@ -132,10 +160,8 @@ def _get_label_size(label_name: str, label_value: str) -> int:
             8 * 1024,
             sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
         )
-
         if current_total_metadata_size <= max_gcs_metadata_size:
             return labels
-
         for label_name, label_value in reversed(labels.items()):
             size = _get_label_size(label_name, label_value)
             del labels[label_name]
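
To make the 8 KiB cap concrete, a small sketch of the size accounting. The _get_label_size implementation is not shown in this hunk, so the UTF-8 byte-length model below is an assumption, and the labels are invented:

    def _get_label_size(label_name: str, label_value: str) -> int:
        # Assumed size model: UTF-8 byte length of key plus value.
        return len(label_name.encode('utf-8')) + len(label_value.encode('utf-8'))

    labels = {'param_a': 'x' * 100, 'note': 'y' * 9000}
    total = sum(_get_label_size(k, v) for k, v in labels.items())
    print(total)  # 9111 bytes: over the 8 KiB (8192-byte) cap, so trailing labels get dropped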