Skip to content

Commit 0fc15e9

Browse files
committed
Update with code from sensAI v1.3.0
1 parent 2824267 commit 0fc15e9

9 files changed

Lines changed: 427 additions & 59 deletions

File tree

src/sensai/util/cache.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
import threading
1010
import time
1111
from abc import abstractmethod, ABC
12+
from collections import OrderedDict
1213
from functools import wraps
1314
from pathlib import Path
14-
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union
15+
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union, Hashable
1516

1617
from .hash import pickle_hash
1718
from .pickle import load_pickle, dump_pickle, setstate
@@ -20,6 +21,7 @@
2021

2122
T = TypeVar("T")
2223
TKey = TypeVar("TKey")
24+
THashableKey = TypeVar("THashableKey", bound=Hashable)
2325
TValue = TypeVar("TValue")
2426
TData = TypeVar("TData")
2527

@@ -788,3 +790,28 @@ def load(cls, path: Union[str, Path], backend="pickle"):
788790
if not isinstance(result, cls):
789791
raise Exception(f"Excepted instance of {cls}, instead got: {result.__class__.__name__}")
790792
return result
793+
794+
795+
class LRUCache(KeyValueCache[THashableKey, TValue], Generic[THashableKey, TValue]):
796+
def __init__(self, capacity: int) -> None:
797+
self._cache = OrderedDict()
798+
self._capacity = capacity
799+
800+
def get(self, key: THashableKey) -> TValue:
801+
if key not in self._cache:
802+
return None
803+
self._cache.move_to_end(key)
804+
return self._cache[key]
805+
806+
def set(self, key: THashableKey, value: TValue):
807+
if key in self._cache:
808+
self._cache.move_to_end(key)
809+
self._cache[key] = value
810+
if len(self._cache) > self._capacity:
811+
self._cache.popitem(last=False)
812+
813+
def __len__(self) -> int:
814+
return len(self._cache)
815+
816+
def clear(self) -> None:
817+
self._cache.clear()

src/sensai/util/helper.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This module contains various helper functions.
33
"""
44
import math
5-
from typing import Any, Sequence, Union, TypeVar, List
5+
from typing import Any, Sequence, Union, TypeVar, List, Optional, Dict, Container, Iterable
66

77
T = TypeVar("T")
88

@@ -58,6 +58,13 @@ def check_not_nan_dict(d: dict):
5858
raise ValueError(f"Got one or more NaN values: {invalid_keys}")
5959

6060

61+
def contains_any(container: Union[Container, Iterable], items: Sequence) -> bool:
62+
for item in items:
63+
if item in container:
64+
return True
65+
return False
66+
67+
6168
# noinspection PyUnusedLocal
6269
def mark_used(*args):
6370
"""
@@ -83,4 +90,23 @@ def flatten_arguments(args: Sequence[Union[T, Sequence[T]]]) -> List[T]:
8390
result.extend(arg)
8491
else:
8592
result.append(arg)
86-
return result
93+
return result
94+
95+
96+
def kwarg_if_not_none(arg_name: str, arg_value: Any) -> Dict[str, Any]:
97+
"""
98+
Supports the optional passing of a kwarg, returning a non-empty dictionary with the kwarg only
99+
if the value is not None.
100+
101+
This can be helpful to improve backward compatibility for cases where a kwarg was added later
102+
but old implementations weren't updated to have it, such that an exception will be raised only
103+
if the kwarg is actually used.
104+
105+
:param arg_name: the argument name
106+
:param arg_value: the value
107+
:return: a dictionary containing the kwarg or, if the value is None, an empty dictionary
108+
"""
109+
if arg_value is None:
110+
return {}
111+
else:
112+
return {arg_name: arg_value}

src/sensai/util/io.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def path(self, filename_suffix: str, extension_to_add=None, valid_other_extensio
5252
relevant if extensionToAdd is specified
5353
:return: the full path
5454
"""
55+
# replace forbidden characters
56+
filename_suffix = filename_suffix.replace(">=", "gte").replace(">", "gt")
57+
5558
if extension_to_add is not None:
5659
add_ext = True
5760
valid_extensions = set(valid_other_extensions) if valid_other_extensions is not None else set()
@@ -201,3 +204,19 @@ def _get_s3_object(self):
201204
session = boto3.session.Session(profile_name=os.getenv("AWS_PROFILE"))
202205
s3 = session.resource("s3")
203206
return s3.Bucket(self.bucket).Object(self.object)
207+
208+
209+
def create_path(root: str, *path_elems: str, is_dir: bool, make_dirs: bool = False) -> str:
210+
path = os.path.join(root, *path_elems)
211+
if make_dirs:
212+
dir_path = path if is_dir else os.path.dirname(path)
213+
os.makedirs(dir_path, exist_ok=True)
214+
return path
215+
216+
217+
def create_file_path(root, *path_elems, make_dirs: bool = False) -> str:
218+
return create_path(root, *path_elems, is_dir=False, make_dirs=make_dirs)
219+
220+
221+
def create_dir_path(root, *path_elems, make_dirs: bool = False) -> str:
222+
return create_path(root, *path_elems, is_dir=True, make_dirs=make_dirs)

src/sensai/util/logging.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import logging as lg
33
import sys
44
import time
5+
from abc import ABC, abstractmethod
56
from datetime import datetime
67
from io import StringIO
78
from logging import *
8-
from typing import List, Callable, Optional, TypeVar, TYPE_CHECKING
9+
from typing import List, Callable, Optional, TypeVar, TYPE_CHECKING, Generic
910

1011
from .time import format_duration
1112

@@ -17,6 +18,7 @@
1718

1819
LOG_DEFAULT_FORMAT = '%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s:%(lineno)d - %(message)s'
1920
T = TypeVar("T")
21+
THandler = TypeVar("THandler", bound=Handler)
2022

2123
# Holds the log format that is configured by the user (using function `configure`), such
2224
# that it can be reused in other places
@@ -142,18 +144,18 @@ def datetime_tag() -> str:
142144

143145
_fileLoggerPaths: List[str] = []
144146
_isAtExitReportFileLoggerRegistered = False
145-
_memoryLogStream: Optional[StringIO] = None
146147

147148

148149
def _at_exit_report_file_logger():
149150
for path in _fileLoggerPaths:
150151
print(f"A log file was saved to {path}")
151152

152153

153-
def add_file_logger(path, register_atexit=True):
154+
def add_file_logger(path, append=True, register_atexit=True) -> FileHandler:
154155
global _isAtExitReportFileLoggerRegistered
155156
log.info(f"Logging to {path} ...")
156-
handler = FileHandler(path)
157+
mode = "a" if append else "w"
158+
handler = FileHandler(path, mode=mode)
157159
handler.setFormatter(Formatter(_logFormat))
158160
Logger.root.addHandler(handler)
159161
_fileLoggerPaths.append(path)
@@ -163,22 +165,24 @@ def add_file_logger(path, register_atexit=True):
163165
return handler
164166

165167

166-
def add_memory_logger() -> None:
168+
class MemoryStreamHandler(StreamHandler):
169+
def __init__(self, stream: StringIO):
170+
super().__init__(stream)
171+
172+
def get_log(self) -> str:
173+
stream: StringIO = self.stream
174+
return stream.getvalue()
175+
176+
177+
def add_memory_logger() -> MemoryStreamHandler:
167178
"""
168-
Enables in-memory logging (if it is not already enabled), i.e. all log statements are written to a memory buffer and can later be
169-
read via function `get_memory_log()`
179+
Adds an in-memory logger, i.e. all log statements are written to a memory buffer which can be retrieved
180+
using the handler's `get_log` method.
170181
"""
171-
global _memoryLogStream
172-
if _memoryLogStream is not None:
173-
return
174-
_memoryLogStream = StringIO()
175-
handler = StreamHandler(_memoryLogStream)
182+
handler = MemoryStreamHandler(StringIO())
176183
handler.setFormatter(Formatter(_logFormat))
177184
Logger.root.addHandler(handler)
178-
179-
180-
def get_memory_log():
181-
return _memoryLogStream.getvalue()
185+
return handler
182186

183187

184188
class StopWatch:
@@ -334,29 +338,62 @@ def __enter__(self):
334338
return self
335339

336340

337-
class FileLoggerContext:
341+
class LoggerContext(Generic[THandler], ABC):
338342
"""
339-
A context handler to be used in conjunction with Python's `with` statement which enables file-based logging.
343+
Base class for context handlers to be used in conjunction with Python's `with` statement.
340344
"""
341-
def __init__(self, path: str, enabled=True):
345+
346+
def __init__(self, enabled=True):
342347
"""
343-
:param path: the path to the log file
344348
:param enabled: whether to actually perform any logging.
345349
This switch allows the with statement to be applied regardless of whether logging shall be enabled.
346350
"""
347351
self.enabled = enabled
348-
self.path = path
349352
self._log_handler = None
350353

351-
def __enter__(self):
354+
@abstractmethod
355+
def _create_log_handler(self) -> THandler:
356+
pass
357+
358+
def __enter__(self) -> Optional[THandler]:
352359
if self.enabled:
353-
self._log_handler = add_file_logger(self.path, register_atexit=False)
360+
self._log_handler = self._create_log_handler()
361+
return self._log_handler
354362

355363
def __exit__(self, exc_type, exc_value, traceback):
356364
if self._log_handler is not None:
357365
remove_log_handler(self._log_handler)
358366

359367

368+
class FileLoggerContext(LoggerContext[FileHandler]):
369+
"""
370+
A context handler to be used in conjunction with Python's `with` statement which enables file-based logging.
371+
"""
372+
373+
def __init__(self, path: str, append=True, enabled=True):
374+
"""
375+
:param path: the path to the log file
376+
:param append: whether to append in case the file already exists; if False, always create a new file.
377+
:param enabled: whether to actually perform any logging.
378+
This switch allows the with statement to be applied regardless of whether logging shall be enabled.
379+
"""
380+
self.path = path
381+
self.append = append
382+
super().__init__(enabled=enabled)
383+
384+
def _create_log_handler(self) -> FileHandler:
385+
return add_file_logger(self.path, append=self.append, register_atexit=False)
386+
387+
388+
class MemoryLoggerContext(LoggerContext[MemoryStreamHandler]):
389+
"""
390+
A context handler to be used in conjunction with Python's `with` statement which enables in-memory logging.
391+
"""
392+
393+
def _create_log_handler(self) -> MemoryStreamHandler:
394+
return add_memory_logger()
395+
396+
360397
class LoggingDisabledContext:
361398
"""
362399
A context manager that will temporarily disable logging

0 commit comments

Comments
 (0)