diff --git a/datajoint/autopopulate.py b/datajoint/autopopulate.py index 0e16ee29b..6d72b7aa7 100644 --- a/datajoint/autopopulate.py +++ b/datajoint/autopopulate.py @@ -12,6 +12,7 @@ import signal import multiprocessing as mp import contextlib +import deepdiff # noinspection PyExceptionInherit,PyCallingNonCallable @@ -309,17 +310,46 @@ def _populate1( ): return False - self.connection.start_transaction() + # if make is a generator, it transaction can be delayed until the final stage + is_generator = inspect.isgeneratorfunction(make) + if not is_generator: + self.connection.start_transaction() + if key in self.target: # already populated - self.connection.cancel_transaction() + if not is_generator: + self.connection.cancel_transaction() if jobs is not None: jobs.complete(self.target.table_name, self._job_key(key)) return False logger.debug(f"Making {key} -> {self.target.full_table_name}") self.__class__._allow_insert = True + try: - make(dict(key), **(make_kwargs or {})) + if not is_generator: + make(dict(key), **(make_kwargs or {})) + else: + # tripartite make - transaction is delayed until the final stage + gen = make(dict(key), **(make_kwargs or {})) + fetched_data = next(gen) + fetch_hash = deepdiff.DeepHash( + fetched_data, ignore_iterable_order=False + )[fetched_data] + computed_result = next(gen) # perform the computation + # fetch and insert inside a transaction + self.connection.start_transaction() + gen = make(dict(key), **(make_kwargs or {})) # restart make + fetched_data = next(gen) + if ( + fetch_hash + != deepdiff.DeepHash(fetched_data, ignore_iterable_order=False)[ + fetched_data + ] + ): # rollback due to referential integrity fail + self.connection.cancel_transaction() + return False + gen.send(computed_result) # insert + except (KeyboardInterrupt, SystemExit, Exception) as error: try: self.connection.cancel_transaction() diff --git a/datajoint/blob.py b/datajoint/blob.py index 891522fd2..f38525477 100644 --- a/datajoint/blob.py +++ b/datajoint/blob.py @@ -204,7 +204,7 @@ def pack_blob(self, obj): return self.pack_dict(obj) if isinstance(obj, str): return self.pack_string(obj) - if isinstance(obj, collections.abc.ByteString): + if isinstance(obj, (bytes, bytearray)): return self.pack_bytes(obj) if isinstance(obj, collections.abc.MutableSequence): return self.pack_list(obj) diff --git a/datajoint/connection.py b/datajoint/connection.py index 7536e7af2..5d2fbc27e 100644 --- a/datajoint/connection.py +++ b/datajoint/connection.py @@ -12,7 +12,7 @@ import pathlib from .settings import config -from . import errors +from . import errors, __version__ from .dependencies import Dependencies from .blob import pack, unpack from .hash import uuid_from_buffer @@ -190,15 +190,20 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None) self.conn_info["ssl_input"] = use_tls self.conn_info["host_input"] = host_input self.init_fun = init_fun - logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info)) self._conn = None self._query_cache = None connect_host_hook(self) if self.is_connected: - logger.info("Connected {user}@{host}:{port}".format(**self.conn_info)) + logger.info( + "DataJoint {version} connected to {user}@{host}:{port}".format( + version=__version__, **self.conn_info + ) + ) self.connection_id = self.query("SELECT connection_id()").fetchone()[0] else: - raise errors.LostConnectionError("Connection failed.") + raise errors.LostConnectionError( + "Connection failed {user}@{host}:{port}".format(**self.conn_info) + ) self._in_transaction = False self.schemas = dict() self.dependencies = Dependencies(self) @@ -344,7 +349,7 @@ def query( except errors.LostConnectionError: if not reconnect: raise - logger.warning("MySQL server has gone away. Reconnecting to the server.") + logger.warning("Reconnecting to MySQL server.") connect_host_hook(self) if self._in_transaction: self.cancel_transaction() diff --git a/datajoint/external.py b/datajoint/external.py index a3a546e22..08787ca7f 100644 --- a/datajoint/external.py +++ b/datajoint/external.py @@ -22,7 +22,7 @@ def subfold(name, folds): """ - subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde'] + subfolding for external storage: e.g. subfold('aBCdefg', (2, 3)) --> ['ab','cde'] """ return ( (name[: folds[0]].lower(),) + subfold(name[folds[0] :], folds[1:]) @@ -278,7 +278,7 @@ def upload_filepath(self, local_filepath): # check if the remote file already exists and verify that it matches check_hash = (self & {"hash": uuid}).fetch("contents_hash") - if check_hash: + if check_hash.size: # the tracking entry exists, check that it's the same file as before if contents_hash != check_hash[0]: raise DataJointError( diff --git a/datajoint/schemas.py b/datajoint/schemas.py index c3894ba29..7ea40724f 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -482,8 +482,8 @@ def list_tables(self): return [ t for d, t in ( - full_t.replace("`", "").split(".") - for full_t in self.connection.dependencies.topo_sort() + table_name.replace("`", "").split(".") + for table_name in self.connection.dependencies.topo_sort() ) if d == self.database ] diff --git a/datajoint/settings.py b/datajoint/settings.py index cdf27891d..f1c300029 100644 --- a/datajoint/settings.py +++ b/datajoint/settings.py @@ -1,5 +1,5 @@ """ -Settings for DataJoint. +Settings for DataJoint """ from contextlib import contextmanager @@ -48,7 +48,8 @@ "database.use_tls": None, "enable_python_native_blobs": True, # python-native/dj0 encoding support "add_hidden_timestamp": False, - "filepath_checksum_size_limit": None, # file size limit for when to disable checksums + # file size limit for when to disable checksums + "filepath_checksum_size_limit": None, } ) @@ -117,6 +118,7 @@ def load(self, filename): if filename is None: filename = LOCALCONFIG with open(filename, "r") as fid: + logger.info(f"DataJoint is configured from {os.path.abspath(filename)}") self._conf.update(json.load(fid)) def save_local(self, verbose=False): @@ -236,7 +238,8 @@ class __Config: def __init__(self, *args, **kwargs): self._conf = dict(default) - self._conf.update(dict(*args, **kwargs)) # use the free update to set keys + # use the free update to set keys + self._conf.update(dict(*args, **kwargs)) def __getitem__(self, key): return self._conf[key] @@ -250,7 +253,9 @@ def __setitem__(self, key, value): valid_logging_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} if key == "loglevel": if value not in valid_logging_levels: - raise ValueError(f"{'value'} is not a valid logging value") + raise ValueError( + f"'{value}' is not a valid logging value {tuple(valid_logging_levels)}" + ) logger.setLevel(value) @@ -260,11 +265,9 @@ def __setitem__(self, key, value): os.path.expanduser(n) for n in (LOCALCONFIG, os.path.join("~", GLOBALCONFIG)) ) try: - config_file = next(n for n in config_files if os.path.exists(n)) + config.load(next(n for n in config_files if os.path.exists(n))) except StopIteration: - pass -else: - config.load(config_file) + logger.info("No config file was found.") # override login credentials with environment variables mapping = { @@ -292,6 +295,8 @@ def __setitem__(self, key, value): ) if v is not None } -config.update(mapping) +if mapping: + logger.info(f"Overloaded settings {tuple(mapping)} from environment variables.") + config.update(mapping) logger.setLevel(log_levels[config["loglevel"]]) diff --git a/datajoint/version.py b/datajoint/version.py index 6bcf0e20a..cc1d88710 100644 --- a/datajoint/version.py +++ b/datajoint/version.py @@ -1,3 +1,3 @@ -__version__ = "0.14.3" +__version__ = "0.14.4" assert len(__version__) <= 10 # The log table limits version to the 10 characters diff --git a/pyproject.toml b/pyproject.toml index 097d168e1..1eb8c723d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.14.3" dependencies = [ "numpy", "pymysql>=0.7.2", + "deepdiff", "pyparsing", "ipython", "pandas", diff --git a/tests/test_declare.py b/tests/test_declare.py index b3c928294..6e66f4c81 100644 --- a/tests/test_declare.py +++ b/tests/test_declare.py @@ -359,7 +359,6 @@ class WithSuchALongPartNameThatItCrashesMySQL(dj.Part): def test_regex_mismatch(schema_any): - class IndexAttribute(dj.Manual): definition = """ index: int diff --git a/tests/test_relational_operand.py b/tests/test_relational_operand.py index 8ff8286e1..f2c16e9cd 100644 --- a/tests/test_relational_operand.py +++ b/tests/test_relational_operand.py @@ -574,7 +574,6 @@ def test_union_multiple(schema_simp_pop): class TestDjTop: - def test_restrictions_by_top(self, schema_simp_pop): a = L() & dj.Top() b = L() & dj.Top(order_by=["cond_in_l", "KEY"])