Skip to content

Commit 037ad79

Browse files
committed
Update with code from sensAI 1.5.0
1 parent 5c1eec0 commit 037ad79

13 files changed

Lines changed: 583 additions & 58 deletions

File tree

src/sensai/util/cache.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import sqlite3
99
import threading
1010
import time
11-
from abc import abstractmethod, ABC
11+
from abc import ABC, abstractmethod
1212
from collections import OrderedDict
1313
from functools import wraps
1414
from pathlib import Path
15-
from typing import Any, Callable, Iterator, List, Optional, TypeVar, Generic, Union, Hashable
15+
from typing import Any, Callable, Generic, Hashable, Iterator, List, Optional, TypeVar, Union
1616

1717
from .hash import pickle_hash
18-
from .pickle import load_pickle, dump_pickle, setstate
18+
from .pickle import dump_pickle, load_pickle, setstate
1919

2020
log = logging.getLogger(__name__)
2121

@@ -443,22 +443,25 @@ class KeyType(enum.Enum):
443443
INTEGER = ("LONG", )
444444

445445
def __init__(self, path, table_name="cache", deferred_commit_delay_secs=1.0, key_type: KeyType = KeyType.STRING,
446-
max_key_length=255):
446+
max_key_length=255, max_entries_for_commit: int = 500):
447447
"""
448448
:param path: the path to the file that is to hold the SQLite database
449449
:param table_name: the name of the table to create in the database
450450
:param deferred_commit_delay_secs: the time frame during which no new data must be added for a pending transaction to be committed
451451
:param key_type: the type to use for keys; for complex keys (i.e. tuples), use STRING (conversions to string are automatic)
452452
:param max_key_length: the maximum key length for the case where the key_type can be parametrised (e.g. STRING)
453453
"""
454+
os.makedirs(os.path.dirname(path), exist_ok=True)
454455
self.path = path
455456
self.conn = SqliteConnectionManager.open_connection(path)
456457
self.table_name = table_name
457458
self.max_key_length = 255
458459
self.key_type = key_type
459460
self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
460461
self._num_entries_to_be_committed = 0
462+
self._num_entries_total_changed = 0
461463
self._conn_mutex = threading.Lock()
464+
self._max_entries_for_commit = max_entries_for_commit
462465

463466
cursor = self.conn.cursor()
464467
cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table';")
@@ -484,9 +487,10 @@ def _key_db_value(self, key):
484487
def _commit(self):
485488
self._conn_mutex.acquire()
486489
try:
487-
log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the SQLite database {self.path}")
488-
self.conn.commit()
489-
self._num_entries_to_be_committed = 0
490+
if self._num_entries_to_be_committed > 0:
491+
log.info(f"Committing {self._num_entries_to_be_committed} cache entries ({self._num_entries_total_changed} total in session) to the SQLite database {self.path}")
492+
self.conn.commit()
493+
self._num_entries_to_be_committed = 0
490494
finally:
491495
self._conn_mutex.release()
492496

@@ -502,11 +506,15 @@ def set(self, key: TKey, value: TValue):
502506
else:
503507
cursor.execute(f"UPDATE {self.table_name} SET cache_value=? WHERE cache_key=?", (pickle.dumps(value), key))
504508
self._num_entries_to_be_committed += 1
509+
self._num_entries_total_changed += 1
505510
cursor.close()
506511
finally:
507512
self._conn_mutex.release()
508513

509-
self._update_hook.handle_update()
514+
if self._num_entries_to_be_committed >= self._max_entries_for_commit:
515+
self._commit()
516+
else:
517+
self._update_hook.handle_update()
510518

511519
def _execute(self, cursor, *query):
512520
try:
@@ -553,6 +561,12 @@ def iter_items(self):
553561
finally:
554562
self._conn_mutex.release()
555563

564+
def finalise(self):
565+
"""
566+
Commits any entries that have not yet been committed to the database
567+
"""
568+
self._commit()
569+
556570

557571
class SqlitePersistentList(PersistentList):
558572
def __init__(self, path):
@@ -774,6 +788,15 @@ def save(self, path: Union[str, Path], backend="pickle"):
774788
:param path:
775789
:param backend: pickle, cloudpickle, or joblib
776790
"""
791+
if not hasattr(self, "__getstate__"):
792+
log.warning(
793+
f"You are persisting an object {self} without __getstate__. "
794+
f"This may lead to irrecoverable problems with backwards compatibility! "
795+
f"It is highly recommended that you implement __getstate__ for everything you persist (especially for stateless classes, since "
796+
f"adding state later on means they won't be able to load and even implementing __setstate__ won't help then). "
797+
f"The easiest and recommended way to do this is to also inherit from `sensai.util.pickle.PersistableObject` whenever "
798+
f"you inherit from `PickleLoadSaveMixin`"
799+
)
777800
dump_pickle(self, path, backend=backend)
778801

779802
@classmethod

src/sensai/util/cache_mysql.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,74 @@
99

1010

1111
class MySQLPersistentKeyValueCache(PersistentKeyValueCache):
12+
"""
13+
Can cache arbitrary values in a MySQL database.
14+
The keys are always strings at the database level, i.e. if a key is not a string, it is converted to a string using str().
15+
"""
1216

1317
class ValueType(enum.Enum):
14-
DOUBLE = ("DOUBLE", False) # (SQL data type, isCachedValuePickled)
18+
"""
19+
The value type to use within the MySQL database.
20+
Note that the binary BLOB types can be used for all Python types that can be pickled, so the lack
21+
of specific types (e.g. for strings) is not a problem.
22+
"""
23+
# enum values are (SQL data type, isCachedValuePickled)
24+
DOUBLE = ("DOUBLE", False)
1525
BLOB = ("BLOB", True)
26+
"""
27+
for Python data types whose pickled representation is up to 64 KB
28+
"""
29+
MEDIUMBLOB = ("MEDIUMBLOB", True)
30+
"""
31+
for Python data types whose pickled representation is up to 16 MB
32+
"""
1633

17-
def __init__(self, host, db, user, pw, value_type: ValueType, table_name="cache", deferred_commit_delay_secs=1.0, in_memory=False):
18-
import MySQLdb
19-
self.conn = MySQLdb.connect(host=host, database=db, user=user, password=pw)
34+
def __init__(self, host: str, db: str, user: str, pw: str, value_type: ValueType, table_name="cache",
35+
connect_params: dict | None = None, in_memory=False, max_key_length: int = 255, port=3306):
36+
"""
37+
:param host:
38+
:param db:
39+
:param user:
40+
:param pw:
41+
:param value_type: the type of value to store in the cache
42+
:param table_name:
43+
:param connect_params: additional parameters to pass to the pymysql.connect() function (e.g. ssl, etc.)
44+
:param in_memory:
45+
:param max_key_length: maximal length of the cache key string (keys are always strings) stored in the DB
46+
(i.e. the MySQL type is VARCHAR[max_key_length])
47+
:param port: the MySQL server port to connect to
48+
"""
49+
import pymysql
50+
if connect_params is None:
51+
connect_params = {}
52+
self._connect = lambda: pymysql.connect(host=host, database=db, user=user, password=pw, port=port, autocommit=True,
53+
**connect_params)
54+
self._conn = self._connect()
2055
self.table_name = table_name
21-
self.max_key_length = 255
22-
self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
23-
self._num_entries_to_be_committed = 0
56+
self.max_key_length = max_key_length
2457

2558
cache_value_sql_type, self.is_cache_value_pickled = value_type.value
2659

27-
cursor = self.conn.cursor()
60+
cursor = self._conn.cursor()
2861
cursor.execute(f"SHOW TABLES;")
2962
if table_name not in [r[0] for r in cursor.fetchall()]:
63+
log.debug(f"Creating table {table_name}")
3064
cursor.execute(f"CREATE TABLE {table_name} (cache_key VARCHAR({self.max_key_length}) PRIMARY KEY, "
3165
f"cache_value {cache_value_sql_type});")
3266
cursor.close()
3367

3468
self._in_memory_df = None if not in_memory else self._load_table_to_data_frame()
3569

70+
def _cursor(self):
71+
try:
72+
self._conn.ping(reconnect=True)
73+
except Exception as e:
74+
log.error(f"Error while pinging MySQL server: {e}; Reconnecting ...")
75+
self._conn = self._connect()
76+
return self._conn.cursor()
77+
3678
def _load_table_to_data_frame(self):
37-
df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self.conn, index_col="cache_key")
79+
df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self._conn, index_col="cache_key")
3880
if self.is_cache_value_pickled:
3981
df["cache_value"] = df["cache_value"].apply(pickle.loads)
4082
return df
@@ -43,16 +85,25 @@ def set(self, key, value):
4385
key = str(key)
4486
if len(key) > self.max_key_length:
4587
raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
46-
cursor = self.conn.cursor()
88+
cursor = self._cursor()
4789
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=%s", (key,))
4890
stored_value = pickle.dumps(value) if self.is_cache_value_pickled else value
4991
if cursor.fetchone()[0] == 0:
50-
cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
51-
(key, stored_value))
92+
from pymysql.err import IntegrityError
93+
try:
94+
cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
95+
(key, stored_value))
96+
except IntegrityError as e:
97+
if e.args[0] == 1062: # Duplicate entry
98+
# This can only happen when the user is inserting the same value almost simultaneously (race condition)
99+
args = list(e.args)
100+
args[1] = f"{args[1]}; The duplicate entry is due to quasi-simultaneous insertions for the same key; " \
101+
"Check your application logic!"
102+
raise IntegrityError(*args)
103+
else:
104+
raise
52105
else:
53106
cursor.execute(f"UPDATE {self.table_name} SET cache_value=%s WHERE cache_key=%s", (stored_value, key))
54-
self._num_entries_to_be_committed += 1
55-
self._update_hook.handle_update()
56107
cursor.close()
57108
if self._in_memory_df is not None:
58109
self._in_memory_df["cache_value"][str(key)] = value
@@ -64,9 +115,10 @@ def get(self, key):
64115
return value
65116

66117
def _get_from_table(self, key):
67-
cursor = self.conn.cursor()
118+
cursor = self._cursor()
68119
cursor.execute(f"SELECT cache_value FROM {self.table_name} WHERE cache_key=%s", (str(key),))
69120
row = cursor.fetchone()
121+
cursor.close()
70122
if row is None:
71123
return None
72124
stored_value = row[0]
@@ -81,8 +133,3 @@ def _get_from_in_memory_df(self, key):
81133
except Exception as e:
82134
log.debug(f"Unable to load value for key {str(key)} from in-memory dataframe: {e}")
83135
return None
84-
85-
def _commit(self):
86-
log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the database")
87-
self.conn.commit()
88-
self._num_entries_to_be_committed = 0

src/sensai/util/git.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ def is_clean(self) -> bool:
2222
self.has_untracked_files)
2323

2424

25-
def git_status() -> Optional[GitStatus]:
25+
def git_status(log_error: bool = True) -> Optional[GitStatus]:
26+
"""
27+
Gets the git status of the current repository.
28+
29+
:param log_error: whether to log an error if the git status cannot be determined
30+
:return: the git status, or None if it cannot be determined
31+
"""
2632
try:
2733
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
2834
unstaged = bool(subprocess.check_output(['git', 'diff', '--name-only']).decode('ascii').strip())
@@ -35,5 +41,6 @@ def git_status() -> Optional[GitStatus]:
3541
has_untracked_files=untracked
3642
)
3743
except Exception as e:
38-
log.error("Error determining Git status", exc_info=e)
44+
if log_error:
45+
log.error("Error determining Git status", exc_info=e)
3946
return None

src/sensai/util/io.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,26 @@ def write_data_frame_csv_file(self, filename_suffix: str, df: "pd.DataFrame", in
9999
df.to_csv(p, index=index, header=header)
100100
return p
101101

102+
def write_data_frame_excel_file(self, filename_suffix: str, df: "pd.DataFrame", index=True, header=True, sheet_name="Sheet1"):
103+
"""
104+
Saves the given data frame to an Excel spreadsheet.
105+
106+
Requires that an appropriate engine be installed, e.g. openpyxl.
107+
108+
:param filename_suffix: the filename suffix (or full basename for the case where no prefix was specified at construction),
109+
to which the extension .xlsx will be added by default if no extension is given
110+
:param df: the data frame to save
111+
:param index: whether to save the index
112+
:param header: whether to save the header row
113+
:param sheet_name: the name of the spreadsheet
114+
:return: the path to the saved file
115+
"""
116+
p = self.path(filename_suffix, extension_to_add="xlsx")
117+
if self.enabled:
118+
self.log.info(f"Saving data frame Excel file {p}")
119+
df.to_excel(p, index=index, header=header, sheet_name=sheet_name)
120+
return p
121+
102122
def write_figure(self, filename_suffix: str, fig: "plt.Figure", close_figure: Optional[bool] = None):
103123
"""
104124
:param filename_suffix: the filename suffix, which may or may not include a file extension, valid extensions being {"png", "jpg"}
@@ -220,3 +240,34 @@ def create_file_path(root, *path_elems, make_dirs: bool = False) -> str:
220240

221241
def create_dir_path(root, *path_elems, make_dirs: bool = False) -> str:
222242
return create_path(root, *path_elems, is_dir=True, make_dirs=make_dirs)
243+
244+
245+
def filename_compatible(fn: str, replacement: str = "-") -> str:
246+
"""
247+
Converts the given string into a string that can be used in a filename
248+
249+
:param fn: original filename, which may or may not be compatible with common filenames
250+
:param replacement: default character to use as a replacement
251+
:return: adapted filename
252+
"""
253+
# Replace prohibited characters
254+
# Windows prohibits: \ / : * ? " < > |
255+
# Unix/Linux prohibits: /
256+
# macOS prohibits: / and :
257+
replacements = {
258+
">": "gt",
259+
"<": "lt",
260+
"/": replacement,
261+
"\\": replacement,
262+
":": replacement,
263+
"*": replacement,
264+
"?": replacement,
265+
"\"": replacement,
266+
"|": replacement
267+
}
268+
269+
result = fn
270+
for char, repl in replacements.items():
271+
result = result.replace(char, repl)
272+
273+
return result

0 commit comments

Comments
 (0)