
Commit 42a7284

Authored by TlexCypher (with 荒木 太一, hirosassa, and kitagry)
Support functionalities to enhance task traceability with metadata for dependency search. (#450)
* WIP: implement the logic to gather the required task output paths.
* WIP: output path is added in nest mode, but some other cases still need to be handled.
* WIP: no CI applied.
* feat: fix to pass labels and has_seen_keys.
* CI: apply ruff and mypy.
* feat: add implementation of nest mode.
* feat: address kitagry's comments.
* feat: remove CLI dependencies.
* feat: remove redundant statements.
* feat: change serialization expression for a single FlattenableItems[RequiredTaskOutput].
* CI: fix tests and apply CI.
* feat: refactoring: fix mypy error.
* feat: make _list_flatten an inner function.
* feat: fix nits and add __ prefix to avoid conflicts.
* feat: rename _list_flatten.
* feat: convert map object to list; any iterable object that will be hashed should be a list.
* feat: add newline to end of param.ini.
* feat: remove redundant expressions.
* feat: use yield for memory efficiency and functools.reduce for readability.
* feat: fix type of normalized_labels_list.
* chore: change custom_labels type.

---------

Co-authored-by: 荒木 太一 <[email protected]>
Co-authored-by: hirosassa <[email protected]>
Co-authored-by: Ryo Kitagawa <[email protected]>
1 parent d94754c commit 42a7284

10 files changed: +202 -36 lines changed
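For orientation, here is a minimal sketch of what the change enables, assuming a gs:// workspace_directory; the task names below are hypothetical and only illustrate the flow. When a task dumps its output to GCS, the outputs of its required tasks are now recorded in the object's metadata under the __required_task_outputs key, alongside the existing task-parameter labels, so outputs can later be searched by their dependencies.

import gokart


class MakeData(gokart.TaskOnKart):  # hypothetical upstream task
    def run(self):
        self.dump([1, 2, 3])


class Train(gokart.TaskOnKart):  # hypothetical downstream task
    def requires(self):
        return MakeData()

    def run(self):
        data = self.load()
        # With workspace_directory pointing at a gs:// bucket, this dump also patches
        # the GCS object's metadata with '__required_task_outputs', serializing
        # MakeData's task name and output path for later dependency search.
        self.dump(sum(data))


gokart.build(Train())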

examples/logging.ini

+40
@@ -0,0 +1,40 @@
+[loggers]
+keys=root,luigi,luigi-interface,gokart
+
+[handlers]
+keys=stderrHandler
+
+[formatters]
+keys=simpleFormatter
+
+[logger_root]
+level=INFO
+handlers=stderrHandler
+
+[logger_gokart]
+level=INFO
+handlers=stderrHandler
+qualname=gokart
+propagate=0
+
+[logger_luigi]
+level=INFO
+handlers=stderrHandler
+qualname=luigi
+propagate=0
+
+[logger_luigi-interface]
+level=INFO
+handlers=stderrHandler
+qualname=luigi-interface
+propagate=0
+
+[handler_stderrHandler]
+class=StreamHandler
+formatter=simpleFormatter
+args=(sys.stdout,)
+
+[formatter_simpleFormatter]
+format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s
+datefmt=%Y/%m/%d %H:%M:%S
+class=logging.Formatter

examples/param.ini

+7
@@ -0,0 +1,7 @@
+[TaskOnKart]
+workspace_directory=./resource
+local_temporary_directory=./resource/tmp
+
+[core]
+logging_conf_file=logging.ini
+
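The two example files work together: param.ini keeps gokart outputs under ./resource and points luigi's [core] logging_conf_file at logging.ini, which routes gokart and luigi logs through the logfmt-style formatter above. A rough sketch of how they might be picked up, assuming it is run from the examples/ directory; the Sample task is hypothetical.

import gokart


class Sample(gokart.TaskOnKart):  # hypothetical task for illustration
    def run(self):
        self.dump('hello')


if __name__ == '__main__':
    gokart.add_config('./param.ini')  # loads the [TaskOnKart] and [core] sections above
    gokart.build(Sample())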

gokart/gcs_obj_metadata_client.py

+49 -23
@@ -1,14 +1,19 @@
 from __future__ import annotations
 
 import copy
+import functools
+import json
 import re
+from collections.abc import Iterable
 from logging import getLogger
 from typing import Any
 from urllib.parse import urlsplit
 
 from googleapiclient.model import makepatch
 
 from gokart.gcs_config import GCSConfig
+from gokart.required_task_output import RequiredTaskOutput
+from gokart.utils import FlattenableItems
 
 logger = getLogger(__name__)
 
@@ -21,7 +26,7 @@ class GCSObjectMetadataClient:
 
     @staticmethod
     def _is_log_related_path(path: str) -> bool:
-        return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
+        return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
 
     # This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
     @staticmethod
@@ -32,7 +37,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
         return netloc, path_without_initial_slash
 
     @staticmethod
-    def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def add_task_state_labels(
+        path: str,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         if GCSObjectMetadataClient._is_log_related_path(path):
             return
         # In gokart/object_storage.get_time_stamp, could find same call.
@@ -42,20 +52,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
         if _response is None:
             logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
             return
-
         response: dict[str, Any] = dict(_response)
         original_metadata: dict[Any, Any] = {}
         if 'metadata' in response.keys():
             _metadata = response.get('metadata')
             if _metadata is not None:
                 original_metadata = dict(_metadata)
-
         patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
             copy.deepcopy(original_metadata),
             task_params,
             custom_labels,
+            required_task_outputs,
         )
-
         if original_metadata != patched_metadata:
             # If we use update api, existing object metadata are removed, so should use patch api.
             # See the official document descriptions.
@@ -71,7 +79,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
                 )
                 .execute()
             )
-
             if update_response is None:
                 logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')
 
@@ -83,14 +90,14 @@ def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
     def _get_patched_obj_metadata(
         metadata: Any,
         task_params: dict[str, str] | None = None,
-        custom_labels: dict[str, Any] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
     ) -> dict | Any:
         # If metadata from response when getting bucket and object information is not dictionary,
         # something wrong might be happened, so return original metadata, no patched.
         if not isinstance(metadata, dict):
             logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
             return metadata
-
         if not task_params and not custom_labels:
             return metadata
         # Maximum size of metadata for each object is 8 KiB.
@@ -101,24 +108,45 @@
         # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
         # Instead, users are expected to search using the labels they provided.
        # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
-        _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
+        normalized_labels = [normalized_custom_labels, normalized_task_params_labels]
+        if required_task_outputs:
+            normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))})
+
+        _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
         return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)
 
+    @staticmethod
+    def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
+        def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
+            for item in nested_list:
+                if isinstance(item, Iterable):
+                    yield from _iterable_flatten(item)
+                else:
+                    yield item
+
+        if isinstance(required_task_outputs, dict):
+            return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
+        if isinstance(required_task_outputs, Iterable):
+            return list(_iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]))
+        return [required_task_outputs.serialize()]
+
     @staticmethod
     def _merge_custom_labels_and_task_params_labels(
-        normalized_task_params: dict[str, str],
-        normalized_custom_labels: dict[str, Any],
+        normalized_labels_list: list[dict[str, str]],
     ) -> dict[str, str]:
-        merged_labels = copy.deepcopy(normalized_custom_labels)
-        for label_name, label_value in normalized_task_params.items():
-            if len(label_value) == 0:
-                logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
-                continue
-            if label_name in merged_labels.keys():
-                logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
-                continue
-            merged_labels[label_name] = label_value
-        return merged_labels
+        def __merge_two_dicts_helper(merged: dict[str, str], current_labels: dict[str, str]) -> dict[str, str]:
+            next_merged = copy.deepcopy(merged)
+            for label_name, label_value in current_labels.items():
+                if len(label_value) == 0:
+                    logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
+                    continue
+                if label_name in next_merged:
+                    logger.warning(f'label_name={label_name} is already seen. So skip to add as metadata.')
+                    continue
+                next_merged[label_name] = label_value
+            return next_merged
+
+        return functools.reduce(__merge_two_dicts_helper, normalized_labels_list, {})
 
     # Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
     # So, we need to adjust the size of metadata.
@@ -132,10 +160,8 @@ def _get_label_size(label_name: str, label_value: str) -> int:
             8 * 1024,
             sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
         )
-
        if current_total_metadata_size <= max_gcs_metadata_size:
            return labels
-
        for label_name, label_value in reversed(labels.items()):
            size = _get_label_size(label_name, label_value)
            del labels[label_name]
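To make the merge order concrete, here is a small self-contained sketch of the precedence implemented above; the values are illustrative and the loop only mimics the reduce-based helper rather than calling the private methods. custom_labels are merged first, then the task-parameter labels, then the serialized required-task outputs, and a key that has already been seen is skipped, so user-provided labels win on conflicts.

import json

from gokart.required_task_output import RequiredTaskOutput

custom_labels = {'owner': 'ml-team'}  # illustrative user-provided labels
task_params = {'owner': 'from_task_params', 'date': '2024-01-01'}  # illustrative task-parameter labels
required = [RequiredTaskOutput(task_name='MakeData', output_path='gs://bucket/data.pkl')]

label_dicts = [
    custom_labels,
    task_params,
    {'__required_task_outputs': json.dumps([r.serialize() for r in required])},
]

merged: dict[str, str] = {}
for labels in label_dicts:
    for name, value in labels.items():
        merged.setdefault(name, value)  # the first occurrence of a key wins, like the reduce helper

print(merged['owner'])  # 'ml-team': custom_labels take precedence over task-parameter labels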

gokart/in_memory/target.py

+9 -1
@@ -4,7 +4,9 @@
 from typing import Any
 
 from gokart.in_memory.repository import InMemoryCacheRepository
+from gokart.required_task_output import RequiredTaskOutput
 from gokart.target import TargetOnKart, TaskLockParams
+from gokart.utils import FlattenableItems
 
 _repository = InMemoryCacheRepository()
 
@@ -26,7 +28,13 @@ def _get_task_lock_params(self) -> TaskLockParams:
     def _load(self) -> Any:
         return _repository.get_value(self._data_key)
 
-    def _dump(self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def _dump(
+        self,
+        obj: Any,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         return _repository.set_value(self._data_key, obj)
 
     def _remove(self) -> None:

gokart/required_task_output.py

+10
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class RequiredTaskOutput:
+    task_name: str
+    output_path: str
+
+    def serialize(self) -> dict[str, str]:
+        return {'__gokart_task_name': self.task_name, '__gokart_output_path': self.output_path}
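For reference, a quick sketch of what serialize() returns; the task name and path are illustrative.

from gokart.required_task_output import RequiredTaskOutput

ro = RequiredTaskOutput(task_name='MakeData', output_path='gs://bucket/output/data.pkl')
print(ro.serialize())
# {'__gokart_task_name': 'MakeData', '__gokart_output_path': 'gs://bucket/output/data.pkl'}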

gokart/target.py

+42 -8
@@ -18,6 +18,8 @@
 from gokart.file_processor import FileProcessor, make_file_processor
 from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
 from gokart.object_storage import ObjectStorage
+from gokart.required_task_output import RequiredTaskOutput
+from gokart.utils import FlattenableItems
 from gokart.zip_client_util import make_zip_client
 
 logger = getLogger(__name__)
@@ -30,13 +32,23 @@ def exists(self) -> bool:
     def load(self) -> Any:
         return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()
 
-    def dump(self, obj, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def dump(
+        self,
+        obj,
+        lock_at_dump: bool = True,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         if lock_at_dump:
             wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
-                obj=obj, task_params=task_params, custom_labels=custom_labels
+                obj=obj,
+                task_params=task_params,
+                custom_labels=custom_labels,
+                required_task_outputs=required_task_outputs,
             )
         else:
-            self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels)
+            self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)
 
     def remove(self) -> None:
         if self.exists():
@@ -61,7 +73,13 @@ def _load(self) -> Any:
         pass
 
     @abstractmethod
-    def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def _dump(
+        self,
+        obj,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         pass
 
     @abstractmethod
@@ -98,11 +116,19 @@ def _load(self) -> Any:
         with self._target.open('r') as f:
             return self._processor.load(f)
 
-    def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def _dump(
+        self,
+        obj,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         with self._target.open('w') as f:
             self._processor.dump(obj, f)
         if self.path().startswith('gs://'):
-            GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels)
+            GCSObjectMetadataClient.add_task_state_labels(
+                path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
+            )
 
     def _remove(self) -> None:
         self._target.remove()
@@ -142,10 +168,18 @@ def _load(self) -> Any:
         self._remove_temporary_directory()
         return model
 
-    def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
+    def _dump(
+        self,
+        obj,
+        task_params: dict[str, str] | None = None,
+        custom_labels: dict[str, str] | None = None,
+        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
+    ) -> None:
         self._make_temporary_directory()
         self._save_function(obj, self._model_path())
-        make_target(self._load_function_path()).dump(self._load_function, task_params=task_params)
+        make_target(self._load_function_path()).dump(
+            self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
+        )
         self._zip_client.make_archive()
         self._remove_temporary_directory()
 
gokart/task.py

+8 -1
@@ -22,9 +22,10 @@
 from gokart.file_processor import FileProcessor
 from gokart.pandas_type_config import PandasTypeConfigMap
 from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
+from gokart.required_task_output import RequiredTaskOutput
 from gokart.target import TargetOnKart
 from gokart.task_complete_check import task_complete_check_wrapper
-from gokart.utils import FlattenableItems, flatten
+from gokart.utils import FlattenableItems, flatten, map_flattenable_items
 
 logger = getLogger(__name__)
 
@@ -337,11 +338,17 @@ def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels
         if isinstance(obj, pd.DataFrame) and obj.empty:
             raise EmptyDumpError()
 
+        required_task_outputs = map_flattenable_items(
+            lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()),
+            self.requires(),
+        )
+
         self._get_output_target(target).dump(
             obj,
             lock_at_dump=self._lock_at_dump,
             task_params=super().to_str_params(only_significant=True, only_public=True),
             custom_labels=custom_labels,
+            required_task_outputs=required_task_outputs,
         )
 
     @staticmethod
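To see what dump() now collects, consider a task whose requires() returns a nested structure; the tasks below are hypothetical. Because map_flattenable_items preserves the shape of requires(), the collected required_task_outputs mirror that structure, with each required task's outputs turned into RequiredTaskOutput entries.

import gokart


class A(gokart.TaskOnKart):  # hypothetical upstream task
    def run(self):
        self.dump('a')


class B(gokart.TaskOnKart):  # hypothetical upstream task
    def run(self):
        self.dump('b')


class Downstream(gokart.TaskOnKart):
    def requires(self):
        return {'a': A(), 'b': [B()]}

    def run(self):
        # During this dump(), required_task_outputs is built as
        # {'a': RequiredTaskOutput(...), 'b': [RequiredTaskOutput(...)]}
        # and, for gs:// targets, ends up serialized under '__required_task_outputs'.
        self.dump('done')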

gokart/utils.py

+16 -1
@@ -4,7 +4,7 @@
 import sys
 from collections.abc import Iterable
 from io import BytesIO
-from typing import Any, Protocol, TypeVar, Union
+from typing import Any, Callable, Protocol, TypeVar, Union
 
 import dill
 import luigi
@@ -72,6 +72,21 @@ def flatten(targets: FlattenableItems[T]) -> list[T]:
     return flat
 
 
+K = TypeVar('K')
+
+
+def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
+    if isinstance(items, dict):
+        return {k: map_flattenable_items(func, v) for k, v in items.items()}
+    if isinstance(items, tuple):
+        return tuple(map_flattenable_items(func, i) for i in items)
+    if isinstance(items, str):
+        return func(items)  # type: ignore
+    if isinstance(items, Iterable):
+        return list(map(lambda item: map_flattenable_items(func, item), items))
+    return func(items)
+
+
 def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any:
     """Load binary dumped by dill with pandas backward compatibility.
     pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle.
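A quick sketch of map_flattenable_items on a nested structure; the values are illustrative. Dicts and tuples keep their type, other iterables come back as lists, and scalars are passed straight to the function.

from gokart.utils import map_flattenable_items

nested = {'x': [1, 2], 'y': (3, 4), 'z': 5}
print(map_flattenable_items(lambda v: v * 10, nested))
# {'x': [10, 20], 'y': (30, 40), 'z': 50}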
