Skip to content

Support functionalities to enhance task traceability with metadata for dependency search. #450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
79a2881
WIP: End to implement the logic to gather the required task output path.
Mar 4, 2025
0cfe7ee
WIP: success to add output path in nest mode, but some other case sho…
Mar 4, 2025
3eee422
WIP: no ci apply.
Mar 5, 2025
ec3bf4f
feat: fix to pass labels and has_seen_keys.
Mar 5, 2025
22a69d0
feat: fix conflicts
Mar 5, 2025
08e3f59
CI: apply ruff and mypy
Mar 5, 2025
9b19a1c
feat: add implementation of nest mode.
Mar 5, 2025
accbf1d
feat: deal with kitagry comments.
Mar 6, 2025
6719f4d
feat: Remove CLI dependencies.
Mar 6, 2025
0bcc16c
feat: remove redundant statements.
Mar 6, 2025
5c41035
feat: change serialization expression for single FlattenableItems[Req…
Mar 6, 2025
0b951ab
CI: fix test and apply CI.
Mar 6, 2025
10795a2
feat: fix mypy error.
Mar 6, 2025
32b4343
feat: refactoring make _list_flatten inner function.
Mar 6, 2025
6f70a41
feat: fix nits miss and add __ prefix to avoid conflicts.
Mar 6, 2025
637f5da
feat: rename _list_flatten
Mar 7, 2025
b607926
Merge: fix conflicts.
TlexCypher Mar 15, 2025
a8059a1
Merge: fix conflicts.
TlexCypher Mar 15, 2025
27b1abd
feat: convert map object to list, any iterable objects that would be …
TlexCypher Apr 17, 2025
5ac1c4d
Merge remote-tracking branch 'origin/master' into feat/nestmode
TlexCypher Apr 17, 2025
f4479da
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
TlexCypher Apr 17, 2025
e71833b
feat: add new line to end of param.ini
TlexCypher Apr 22, 2025
46aabcf
feat: remove redundant expressions
TlexCypher Apr 22, 2025
7bde3b0
Merge branch 'master' into feat/nestmode
hirosassa Apr 24, 2025
4c44cea
feat: use yiled to make memory efficient and use functools.reduce to …
TlexCypher Apr 28, 2025
dd6a629
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
TlexCypher Apr 28, 2025
d884c79
Merge branch 'master' into feat/nestmode
hirosassa Apr 28, 2025
f1418f8
feat: fix type of normalized_labeles_list
TlexCypher Apr 28, 2025
6a1c4c2
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
TlexCypher Apr 28, 2025
0b06455
chore: change custom_labels type
kitagry Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/logging.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[loggers]
keys=root,luigi,luigi-interface,gokart

[handlers]
keys=stderrHandler

[formatters]
keys=simpleFormatter

[logger_root]
level=INFO
handlers=stderrHandler

[logger_gokart]
level=INFO
handlers=stderrHandler
qualname=gokart
# propagate=0 keeps gokart records from also reaching the root logger (no duplicates).
propagate=0

[logger_luigi]
level=INFO
handlers=stderrHandler
qualname=luigi
propagate=0

[logger_luigi-interface]
level=INFO
handlers=stderrHandler
qualname=luigi-interface
propagate=0

[handler_stderrHandler]
# NOTE(review): this handler is named "stderrHandler" but args point to
# sys.stdout — confirm which stream is actually intended.
class=StreamHandler
formatter=simpleFormatter
args=(sys.stdout,)

[formatter_simpleFormatter]
# logfmt-style single-line output, e.g. "level=INFO time=... name=... message=...".
format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s
datefmt=%Y/%m/%d %H:%M:%S
class=logging.Formatter
7 changes: 7 additions & 0 deletions examples/param.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[TaskOnKart]
workspace_directory=./resource
local_temporary_directory=./resource/tmp

[core]
logging_conf_file=logging.ini

72 changes: 49 additions & 23 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from __future__ import annotations

import copy
import functools
import json
import re
from collections.abc import Iterable
from logging import getLogger
from typing import Any
from urllib.parse import urlsplit

from googleapiclient.model import makepatch

from gokart.gcs_config import GCSConfig
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems

logger = getLogger(__name__)

Expand All @@ -21,7 +26,7 @@ class GCSObjectMetadataClient:

@staticmethod
def _is_log_related_path(path: str) -> bool:
return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None

# This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
@staticmethod
Expand All @@ -32,7 +37,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
return netloc, path_without_initial_slash

@staticmethod
def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def add_task_state_labels(
path: str,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
if GCSObjectMetadataClient._is_log_related_path(path):
return
# In gokart/object_storage.get_time_stamp, could find same call.
Expand All @@ -42,20 +52,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
if _response is None:
logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
return

response: dict[str, Any] = dict(_response)
original_metadata: dict[Any, Any] = {}
if 'metadata' in response.keys():
_metadata = response.get('metadata')
if _metadata is not None:
original_metadata = dict(_metadata)

patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
copy.deepcopy(original_metadata),
task_params,
custom_labels,
required_task_outputs,
)

if original_metadata != patched_metadata:
# If we use update api, existing object metadata are removed, so should use patch api.
# See the official document descriptions.
Expand All @@ -71,7 +79,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
)
.execute()
)

if update_response is None:
logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')

Expand All @@ -83,14 +90,14 @@ def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
def _get_patched_obj_metadata(
metadata: Any,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> dict | Any:
# If metadata from response when getting bucket and object information is not dictionary,
# something wrong might be happened, so return original metadata, no patched.
if not isinstance(metadata, dict):
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
return metadata

if not task_params and not custom_labels:
return metadata
# Maximum size of metadata for each object is 8 KiB.
Expand All @@ -101,24 +108,45 @@ def _get_patched_obj_metadata(
# However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
# Instead, users are expected to search using the labels they provided.
# Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
normalized_labels = [normalized_custom_labels, normalized_task_params_labels]
if required_task_outputs:
normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))})

_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)

@staticmethod
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
    """Recursively serialize ``required_task_outputs`` for GCS object metadata.

    Dicts keep their keys (values serialized recursively), any other iterable
    is flattened into a single flat list of serialized entries, and a lone
    ``RequiredTaskOutput`` becomes a one-element list of its serialized form.
    """

    def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
        for item in nested_list:
            # str and dict are themselves Iterable: recursing into a str never
            # terminates (iterating a 1-char str yields that same str again),
            # and recursing into a dict would reduce it to its keys. Both must
            # therefore be treated as atomic leaf values here — the serialized
            # entries produced by RequiredTaskOutput.serialize() are dicts.
            if isinstance(item, Iterable) and not isinstance(item, (str, dict)):
                yield from _iterable_flatten(item)
            else:
                yield item

    if isinstance(required_task_outputs, dict):
        return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
    if isinstance(required_task_outputs, Iterable):
        return list(_iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]))
    return [required_task_outputs.serialize()]

@staticmethod
def _merge_custom_labels_and_task_params_labels(
normalized_task_params: dict[str, str],
normalized_custom_labels: dict[str, Any],
normalized_labels_list: list[dict[str, str]],
) -> dict[str, str]:
merged_labels = copy.deepcopy(normalized_custom_labels)
for label_name, label_value in normalized_task_params.items():
if len(label_value) == 0:
logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
continue
if label_name in merged_labels.keys():
logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
continue
merged_labels[label_name] = label_value
return merged_labels
def __merge_two_dicts_helper(merged: dict[str, str], current_labels: dict[str, str]) -> dict[str, str]:
next_merged = copy.deepcopy(merged)
for label_name, label_value in current_labels.items():
if len(label_value) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MUST] This code may fail, since it seems to assume that label_value is str.

I prefer checking whether it is a str, and then checking the length, as:

isinstance(label_value, str) and len(label_value)==0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for reviewing my code!
In my opinion, type checking is not necessary, because GCSObjectMetadataClient._normalize_labels converts all values stored in the dictionary into strings.
So, label_value is definitely a string.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TlexCypher
Then maybe the input normalized_labels_list: list[dict[str, Any]] should be normalized_labels_list: list[dict[str, str]] ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TlexCypher Could you check this comment?

If you have confirmed that label_value is str, you should use str instead of Any

Copy link
Member

@kitagry kitagry Apr 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed here 0b06455

logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
continue
if label_name in next_merged:
logger.warning(f'label_name={label_name} is already seen. So skip to add as metadata.')
continue
next_merged[label_name] = label_value
return next_merged

return functools.reduce(__merge_two_dicts_helper, normalized_labels_list, {})

# Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
# So, we need to adjust the size of metadata.
Expand All @@ -132,10 +160,8 @@ def _get_label_size(label_name: str, label_value: str) -> int:
8 * 1024,
sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
)

if current_total_metadata_size <= max_gcs_metadata_size:
return labels

for label_name, label_value in reversed(labels.items()):
size = _get_label_size(label_name, label_value)
del labels[label_name]
Expand Down
10 changes: 9 additions & 1 deletion gokart/in_memory/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Any

from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart, TaskLockParams
from gokart.utils import FlattenableItems

_repository = InMemoryCacheRepository()

Expand All @@ -26,7 +28,13 @@ def _get_task_lock_params(self) -> TaskLockParams:
def _load(self) -> Any:
return _repository.get_value(self._data_key)

def _dump(self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def _dump(
    self,
    obj: Any,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, str] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    # In-memory targets have no object-storage metadata to label, so
    # task_params, custom_labels and required_task_outputs are accepted
    # only to satisfy the TargetOnKart._dump interface and are ignored.
    return _repository.set_value(self._data_key, obj)

def _remove(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions gokart/required_task_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class RequiredTaskOutput:
    """One upstream dependency: a task family name and the path of its output.

    Serialized into object metadata so an artifact can be traced back to the
    tasks whose outputs it was built from.
    """

    # Task family name of the required (upstream) task.
    task_name: str
    # Path of that task's output target.
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Return a flat str-to-str mapping with ``__gokart_``-prefixed keys."""
        return {
            '__gokart_task_name': self.task_name,
            '__gokart_output_path': self.output_path,
        }
50 changes: 42 additions & 8 deletions gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.object_storage import ObjectStorage
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems
from gokart.zip_client_util import make_zip_client

logger = getLogger(__name__)
Expand All @@ -30,13 +32,23 @@ def exists(self) -> bool:
def load(self) -> Any:
return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()

def dump(self, obj, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def dump(
self,
obj,
lock_at_dump: bool = True,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
if lock_at_dump:
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
obj=obj, task_params=task_params, custom_labels=custom_labels
obj=obj,
task_params=task_params,
custom_labels=custom_labels,
required_task_outputs=required_task_outputs,
)
else:
self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels)
self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)

def remove(self) -> None:
if self.exists():
Expand All @@ -61,7 +73,13 @@ def _load(self) -> Any:
pass

@abstractmethod
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def _dump(
self,
obj,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -98,11 +116,19 @@ def _load(self) -> Any:
with self._target.open('r') as f:
return self._processor.load(f)

def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def _dump(
self,
obj,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
with self._target.open('w') as f:
self._processor.dump(obj, f)
if self.path().startswith('gs://'):
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels)
GCSObjectMetadataClient.add_task_state_labels(
path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
)

def _remove(self) -> None:
self._target.remove()
Expand Down Expand Up @@ -142,10 +168,18 @@ def _load(self) -> Any:
self._remove_temporary_directory()
return model

def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def _dump(
self,
obj,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, str] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
self._make_temporary_directory()
self._save_function(obj, self._model_path())
make_target(self._load_function_path()).dump(self._load_function, task_params=task_params)
make_target(self._load_function_path()).dump(
self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
)
self._zip_client.make_archive()
self._remove_temporary_directory()

Expand Down
9 changes: 8 additions & 1 deletion gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from gokart.file_processor import FileProcessor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten
from gokart.utils import FlattenableItems, flatten, map_flattenable_items

logger = getLogger(__name__)

Expand Down Expand Up @@ -337,11 +338,17 @@ def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels
if isinstance(obj, pd.DataFrame) and obj.empty:
raise EmptyDumpError()

required_task_outputs = map_flattenable_items(
lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()),
self.requires(),
)

self._get_output_target(target).dump(
obj,
lock_at_dump=self._lock_at_dump,
task_params=super().to_str_params(only_significant=True, only_public=True),
custom_labels=custom_labels,
required_task_outputs=required_task_outputs,
)

@staticmethod
Expand Down
17 changes: 16 additions & 1 deletion gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from collections.abc import Iterable
from io import BytesIO
from typing import Any, Protocol, TypeVar, Union
from typing import Any, Callable, Protocol, TypeVar, Union

import dill
import luigi
Expand Down Expand Up @@ -72,6 +72,21 @@ def flatten(targets: FlattenableItems[T]) -> list[T]:
return flat


K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Apply ``func`` to every leaf value of ``items``, preserving its nesting shape.

    Dict values are mapped recursively under their original keys, tuples stay
    tuples, any other iterable becomes a list, and strings are treated as
    atomic leaf values rather than iterated character by character.
    """
    if isinstance(items, dict):
        return {key: map_flattenable_items(func, value) for key, value in items.items()}
    if isinstance(items, tuple):
        return tuple(map_flattenable_items(func, element) for element in items)
    if isinstance(items, str):
        # str is Iterable, so it must be caught before the generic branch below.
        return func(items)  # type: ignore
    if isinstance(items, Iterable):
        return [map_flattenable_items(func, element) for element in items]
    return func(items)


def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any:
"""Load binary dumped by dill with pandas backward compatibility.
pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle.
Expand Down
Loading