diff --git a/tensorflow_datasets/core/logging/__init__.py b/tensorflow_datasets/core/logging/__init__.py index ece4b67dc76..984166c8776 100644 --- a/tensorflow_datasets/core/logging/__init__.py +++ b/tensorflow_datasets/core/logging/__init__.py @@ -20,7 +20,7 @@ import collections import functools import threading -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, TypeVar from absl import flags from tensorflow_datasets.core.logging import base_logger @@ -33,15 +33,15 @@ _LoggerMethod = Callable[..., None] -_registered_loggers: Optional[List[base_logger.Logger]] = None +_registered_loggers: list[base_logger.Logger] | None = None -_import_operations: List[Tuple[call_metadata.CallMetadata, int, int]] = [] +_import_operations: list[tuple[call_metadata.CallMetadata, int, int, int]] = [] _import_operations_lock = threading.Lock() _thread_id_to_builder_init_count = collections.Counter() -def _init_registered_loggers() -> List[base_logger.Logger]: +def _init_registered_loggers() -> list[base_logger.Logger]: """Initializes the registered loggers if they are not set yet.""" global _registered_loggers if _registered_loggers is None: @@ -54,17 +54,23 @@ def _init_registered_loggers() -> List[base_logger.Logger]: def _log_import_operation(): """Log import operations (most of time maximum one), if any.""" with _import_operations_lock: - for metadata, import_time_tf, import_time_builders in _import_operations: + for ( + metadata, + import_time_tf, + import_time_builders, + import_time_ms, + ) in _import_operations: for logger in _init_registered_loggers(): logger.tfds_import( metadata=metadata, import_time_ms_tensorflow=import_time_tf, import_time_ms_dataset_builders=import_time_builders, + import_time_ms=import_time_ms, ) _import_operations.clear() -def _get_registered_loggers() -> List[base_logger.Logger]: +def _get_registered_loggers() -> list[base_logger.Logger]: _log_import_operation() return _init_registered_loggers() @@ -181,7 +187,7 @@ class _DsbuilderMethodDecorator(_FunctionDecorator): IS_PROPERTY: bool = False @staticmethod - def _get_info(dsbuilder: Any) -> Tuple[str, str, str, str]: + def _get_info(dsbuilder: Any) -> tuple[str, str, str, str]: """Gets information about the builder. Args: @@ -256,8 +262,9 @@ def register(logger: base_logger.Logger) -> None: def tfds_import( *, metadata: call_metadata.CallMetadata, - import_time_ms_tensorflow: int, - import_time_ms_dataset_builders: int, + import_time_ms_tensorflow: int = 0, + import_time_ms_dataset_builders: int = 0, + import_time_ms: int = 0, ): """Call `tfds_import` on registered loggers. @@ -271,11 +278,15 @@ def tfds_import( import_time_ms_tensorflow: time (ms) it took to import TF. import_time_ms_dataset_builders: time (ms) it took to import DatasetBuilder modules. + import_time_ms: time (ms) it took to import the module. """ with _import_operations_lock: - _import_operations.append( - (metadata, import_time_ms_tensorflow, import_time_ms_dataset_builders) - ) + _import_operations.append(( + metadata, + import_time_ms_tensorflow, + import_time_ms_dataset_builders, + import_time_ms, + )) def builder_init(is_read_only_builder: bool = False) -> _Decorator: diff --git a/tensorflow_datasets/core/logging/base_logger.py b/tensorflow_datasets/core/logging/base_logger.py index 6b87aeb520b..d1469f73c34 100644 --- a/tensorflow_datasets/core/logging/base_logger.py +++ b/tensorflow_datasets/core/logging/base_logger.py @@ -43,6 +43,7 @@ def tfds_import( metadata: call_metadata.CallMetadata, import_time_ms_tensorflow: int, import_time_ms_dataset_builders: int, + import_time_ms: int, ): """Callback called when user calls `import tensorflow_datasets`.""" pass