From f9efc20334734e956f6a3958977df035827bb629 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Fri, 21 Nov 2025 16:08:52 -0800
Subject: [PATCH] File compression tracking

---
 src/spyglass/common/common_file_tracking.py | 532 ++++++++++++++++++++
 src/spyglass/common/common_nwbfile.py       |  80 +++-
 tests/common/test_file_tracking.py          | 149 ++++++
 3 files changed, 753 insertions(+), 8 deletions(-)
 create mode 100644 src/spyglass/common/common_file_tracking.py
 create mode 100644 tests/common/test_file_tracking.py

diff --git a/src/spyglass/common/common_file_tracking.py b/src/spyglass/common/common_file_tracking.py
new file mode 100644
index 000000000..572190bba
--- /dev/null
+++ b/src/spyglass/common/common_file_tracking.py
@@ -0,0 +1,532 @@
+"""File compression tracking for Spyglass NWB files.
+
+Provides compression/decompression with transparent access, leveraging
+existing Nwbfile infrastructure and externals table for metadata.
+"""
+
+import gzip
+import lzma
+import os
+import shutil
+import tempfile
+import time
+from contextlib import contextmanager
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import datajoint as dj
+import pynwb
+
+from spyglass.common import LabMember, Nwbfile
+from spyglass.common.common_nwbfile import schema as nwbfile_schema
+from spyglass.utils import logger
+
+schema = dj.schema("common_file_compression")
+
+
+# Compression algorithm implementations
+def _compress_gzip(input_path, output_path, **kwargs):
+    """Compress file using gzip."""
+    with open(input_path, "rb") as f_in:
+        with gzip.open(output_path, "wb", **kwargs) as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+
+def _compress_lzma(input_path, output_path, **kwargs):
+    """Compress file using lzma."""
+    with open(input_path, "rb") as f_in:
+        with lzma.open(output_path, "wb", **kwargs) as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+
+def _decompress_gzip(input_path, output_path):
+    """Decompress gzip file."""
+    with gzip.open(input_path, "rb") as f_in:
+        with open(output_path, "wb") as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+
+def _decompress_lzma(input_path, output_path):
+    """Decompress lzma file."""
+    with lzma.open(input_path, "rb") as f_in:
+        with open(output_path, "wb") as f_out:
+            shutil.copyfileobj(f_in, f_out)
+
+
+# Mapping of algorithm names to compression/decompression functions
+COMPRESSION_FUNCTIONS = {
+    "gzip": _compress_gzip,
+    "lzma": _compress_lzma,
+}
+
+DECOMPRESSION_FUNCTIONS = {
+    "gzip": _decompress_gzip,
+    "lzma": _decompress_lzma,
+}
+
+COMPRESSION_SUFFIXES = {
+    "gzip": ".gz",
+    "lzma": ".xz",
+}
+
+
+@schema
+class CompressionParams(dj.Lookup):
+    """Supported compression algorithms and parameters.
+
+    Maps algorithm names to compression functions with configurable parameters.
+    Adding new algorithms requires implementing compression/decompression functions
+    and adding entries to COMPRESSION_FUNCTIONS/DECOMPRESSION_FUNCTIONS dicts.
+    """
+
+    definition = """
+    param_id: int unsigned auto_increment  # Unique parameter set ID
+    ---
+    algorithm: varchar(32)  # Algorithm name (gzip, lzma, zstd, etc.)
+    kwargs: blob  # Algorithm parameters as dict
+    description: varchar(127)  # Human-readable description
+    """
+
+    contents = [
+        (1, "gzip", {"compresslevel": 1}, "Gzip level 1 (fastest)"),
+        (2, "gzip", {"compresslevel": 4}, "Gzip level 4 (balanced)"),
+        (3, "gzip", {"compresslevel": 6}, "Gzip level 6 (default)"),
+        (4, "gzip", {"compresslevel": 9}, "Gzip level 9 (best)"),
+        (5, "lzma", {"preset": 1}, "LZMA preset 1 (fastest)"),
+        (6, "lzma", {"preset": 6}, "LZMA preset 6 (default)"),
+        (7, "lzma", {"preset": 9}, "LZMA preset 9 (best)"),
+    ]
+
+    default_param_id = 3  # gzip level 6
+
+    def insert1(self, *args, **kwargs):
+        """Require admin privileges to add compression algorithms."""
+        LabMember().check_admin_privilege(
+            "Admin permissions required to add compression algorithms"
+        )
+        super().insert1(*args, **kwargs)
+
+
+@schema
+class CompressedNwbfile(dj.Manual):
+    """Tracks compressed NWB files with transparent decompression.
+
+    Uses externals table for file size and checksum metadata.
+    """
+
+    definition = """
+    -> Nwbfile  # Foreign key to existing Nwbfile table
+    ---
+    is_compressed: bool  # Whether file is currently compressed
+    is_deleted=0: bool  # Whether original file has been deleted
+    compressed_path: varchar(255)  # Path to compressed file (.gz, .xz)
+    -> CompressionParams  # Compression algorithm used
+    compressed_size_bytes: bigint unsigned  # Size of compressed file
+    compression_ratio: float  # Ratio (original/compressed)
+    compressed_time=CURRENT_TIMESTAMP: timestamp  # When compressed
+    """
+
+    _cache = {}  # {nwb_file_name: temp_path}
+
+    def make(self, key, param_id=None):
+        """Compress an NWB file (standard DataJoint make function).
+
+        Parameters
+        ----------
+        key : dict
+            Must contain 'nwb_file_name'
+        param_id : int, optional
+            Compression parameter set ID from CompressionParams
+            (default: uses default_param_id)
+        """
+        nwb_file_name = key["nwb_file_name"]
+
+        # Get original file info from Nwbfile and externals
+        if not (Nwbfile & key):
+            raise ValueError(f"NWB file not found: {nwb_file_name}")
+
+        original_path = (Nwbfile & key).get_abs_path()
+
+        if not os.path.exists(original_path):
+            raise FileNotFoundError(f"File not found: {original_path}")
+
+        # Check if already compressed
+        if self & key:
+            logger.warning(
+                f"{nwb_file_name} already tracked in compression table"
+            )
+            return
+
+        # Get original size from externals table
+        try:
+            # Access the external store 'raw' which tracks file metadata
+            ext_store = nwbfile_schema.external["raw"]
+
+            # Query the external tracking table for file size
+            ext_table = ext_store.tracker
+            file_meta = (ext_table & {"filepath": original_path}).fetch1()
+            original_size = file_meta["size"]
+        except Exception as e:
+            # Fallback to filesystem if externals table query fails
+            logger.warning(
+                f"Could not fetch size from externals table ({e}), "
+                "using filesystem"
+            )
+            original_size = os.path.getsize(original_path)
+
+        # Get compression parameters
+        if param_id is None:
+            param_id = CompressionParams().default_param_id
+
+        params_query = CompressionParams & {"param_id": param_id}
+        if not params_query:
+            raise ValueError(
+                f"Parameter ID not found: {param_id}. "
+                f"Available IDs: {CompressionParams.fetch('param_id')}"
+            )
+
+        params = params_query.fetch1()
+        algorithm = params["algorithm"]
+        kwargs = params["kwargs"]
+        description = params["description"]
+
+        # Get compression function and file suffix
+        compress_func = COMPRESSION_FUNCTIONS.get(algorithm)
+        if not compress_func:
+            raise ValueError(
+                f"Unsupported algorithm: {algorithm}. "
+                f"Available: {list(COMPRESSION_FUNCTIONS.keys())}"
+            )
+
+        suffix = COMPRESSION_SUFFIXES.get(algorithm, ".compressed")
+        compressed_path = str(original_path) + suffix
+
+        logger.info(
+            f"Compressing {nwb_file_name} with {description} "
+            f"(param_id={param_id})..."
+        )
+
+        # Use _safe_compress to create compressed file atomically
+        with _safe_compress(original_path, compressed_path) as temp_path:
+            # Compress using the algorithm-specific function with kwargs
+            compress_func(original_path, temp_path, **kwargs)
+
+        compressed_size = os.path.getsize(compressed_path)
+        ratio = original_size / compressed_size if compressed_size > 0 else 0
+
+        # Insert into table
+        self.insert1(
+            {
+                "nwb_file_name": nwb_file_name,
+                "is_compressed": True,
+                "compressed_path": compressed_path,
+                "param_id": param_id,
+                "compressed_size_bytes": compressed_size,
+                "compression_ratio": ratio,
+            }
+        )
+
+    def decompress(self, nwb_file_name):
+        """Decompress file and update status flags.
+
+        Uses the algorithm-specific decompression function based on the
+        compression parameters stored in the table. Output is written to a
+        temp file and atomically moved into place, so an interrupted run
+        cannot leave a truncated NWB file at the original path.
+
+        Parameters
+        ----------
+        nwb_file_name : str
+            NWB file name
+
+        Returns
+        -------
+        str
+            Path to original (decompressed) file
+
+        Raises
+        ------
+        FileNotFoundError
+            If the compressed file is missing from disk
+        ValueError
+            If the stored algorithm has no decompression function
+        """
+        key = {"nwb_file_name": nwb_file_name}
+
+        # Check if in compression table
+        if not (self & key):
+            logger.warning(f"{nwb_file_name} not in compression table")
+            return (Nwbfile & key).get_abs_path()
+
+        metadata = (self & key).fetch1()
+
+        # If already decompressed, return original path
+        if not metadata["is_compressed"]:
+            logger.info(f"{nwb_file_name} already decompressed")
+            return (Nwbfile & key).get_abs_path()
+
+        compressed_path = metadata["compressed_path"]
+        original_path = (Nwbfile & key).get_abs_path()
+        param_id = metadata["param_id"]
+
+        if not os.path.exists(compressed_path):
+            raise FileNotFoundError(
+                f"Compressed file not found: {compressed_path}"
+            )
+
+        # Get algorithm from params
+        params = (CompressionParams & {"param_id": param_id}).fetch1()
+        algorithm = params["algorithm"]
+
+        # Get decompression function
+        decompress_func = DECOMPRESSION_FUNCTIONS.get(algorithm)
+        if not decompress_func:
+            raise ValueError(
+                f"Unsupported algorithm: {algorithm}. "
+                f"Available: {list(DECOMPRESSION_FUNCTIONS.keys())}"
+            )
+
+        logger.info(f"Decompressing {nwb_file_name} ({algorithm})...")
+
+        # Decompress to a temp file beside the target, then atomically swap
+        # it in so readers never see a partially written NWB file
+        temp_path = f"{original_path}.{os.getpid()}.{time.time_ns()}.tmp"
+        start_time = time.time()
+        try:
+            decompress_func(compressed_path, temp_path)
+            os.replace(temp_path, original_path)
+        finally:
+            # Gone after a successful replace; present only on failure
+            if os.path.exists(temp_path):
+                os.unlink(temp_path)
+        decompress_time_ms = int((time.time() - start_time) * 1000)
+
+        # Original is back on disk, so clear both status flags
+        (self & key)._update("is_compressed", False)
+        (self & key)._update("is_deleted", False)
+
+        logger.info(f"Decompressed {nwb_file_name} in {decompress_time_ms}ms")
+
+        return original_path
+
+    def delete_files(self, age_days=7, dry_run=True, check_recent_access=False):
+        """Delete original files for compressed entries older than specified age.
+
+        Optionally checks AccessLog to avoid deleting recently accessed files.
+        Entries whose compressed file is missing on disk are always skipped.
+
+        Parameters
+        ----------
+        age_days : int
+            Minimum age in days for files to delete (default: 7)
+        dry_run : bool
+            If True, report what would be deleted without deleting (default: True)
+        check_recent_access : bool
+            If True, skip files accessed within age_days (default: False)
+
+        Returns
+        -------
+        list
+            List of files deleted (or would be deleted if dry_run=True)
+        """
+        cutoff_date = datetime.now() - timedelta(days=age_days)
+
+        # Find compressed files older than cutoff and not deleted
+        compressed_entries = (
+            self
+            & "is_compressed = 1"
+            & "is_deleted = 0"
+            & f'compressed_time < "{cutoff_date}"'
+        ).fetch(as_dict=True)
+
+        deleted = []
+        skipped = []
+
+        for entry in compressed_entries:
+            nwb_file_name = entry["nwb_file_name"]
+
+            # Optionally check if file was accessed recently
+            if check_recent_access:
+                recent_access = (
+                    Nwbfile.AccessLog
+                    & {"nwb_file_name": nwb_file_name}
+                    & f'access_time >= "{cutoff_date}"'
+                )
+                if recent_access:
+                    skipped.append(nwb_file_name)
+                    logger.info(
+                        f"Skipping {nwb_file_name} - "
+                        f"accessed within last {age_days} days"
+                    )
+                    continue
+
+            # Never remove an original whose compressed copy is missing;
+            # doing so would destroy the only remaining copy of the data
+            compressed_path = entry["compressed_path"]
+            if not os.path.exists(compressed_path):
+                logger.warning(
+                    f"Skipping {nwb_file_name} - "
+                    f"compressed file not found: {compressed_path}"
+                )
+                continue
+
+            original_path = Nwbfile().get_abs_path(nwb_file_name)
+
+            if os.path.exists(original_path):
+                size_gb = os.path.getsize(original_path) / 1e9
+
+                if dry_run:
+                    logger.info(
+                        f"[DRY RUN] Would delete {nwb_file_name} "
+                        f"({size_gb:.2f} GB, compressed {age_days}+ days ago)"
+                    )
+                    deleted.append(nwb_file_name)
+                else:
+                    os.unlink(original_path)
+                    # Mark as deleted in table
+                    (self & {"nwb_file_name": nwb_file_name})._update(
+                        "is_deleted", True
+                    )
+                    logger.info(
+                        f"Deleted {nwb_file_name} "
+                        f"({size_gb:.2f} GB, compressed {age_days}+ days ago)"
+                    )
+                    deleted.append(nwb_file_name)
+
+        if dry_run:
+            logger.info(
+                f"[DRY RUN] Would delete {len(deleted)} files"
+                + (
+                    f", skipped {len(skipped)} recently accessed"
+                    if skipped
+                    else ""
+                )
+                + ", run with dry_run=False to actually delete"
+            )
+        else:
+            logger.info(
+                f"Deleted {len(deleted)} original files"
+                + (
+                    f", skipped {len(skipped)} recently accessed"
+                    if skipped
+                    else ""
+                )
+            )
+
+        return deleted
+
+    def get_stats(self):
+        """Get compression statistics.
+
+        Retrieves original file sizes from externals table and compressed
+        sizes from this table.
+
+        Returns
+        -------
+        dict
+            Summary statistics
+        """
+        if not self:
+            return {
+                "total_files": 0,
+                "total_original_gb": 0,
+                "total_compressed_gb": 0,
+                "total_saved_gb": 0,
+                "avg_ratio": 0,
+                "currently_compressed": 0,
+            }
+
+        entries = self.fetch(as_dict=True)
+
+        total_orig = 0
+        total_comp = 0
+        ratios = []
+        currently_compressed = 0
+
+        for entry in entries:
+            # Get original size from file system
+            nwb_file_name = entry["nwb_file_name"]
+            original_path = Nwbfile().get_abs_path(nwb_file_name)
+
+            if os.path.exists(original_path):
+                orig_size = os.path.getsize(original_path)
+            else:
+                # If original deleted, calculate from compressed size and ratio
+                orig_size = (
+                    entry["compressed_size_bytes"] * entry["compression_ratio"]
+                )
+
+            total_orig += orig_size
+            total_comp += entry["compressed_size_bytes"]
+            ratios.append(entry["compression_ratio"])
+
+            if entry["is_compressed"]:
+                currently_compressed += 1
+
+        avg_ratio = sum(ratios) / len(ratios) if ratios else 0
+
+        return {
+            "total_files": len(self),
+            "total_original_gb": total_orig / 1e9,
+            "total_compressed_gb": total_comp / 1e9,
+            "total_saved_gb": (total_orig - total_comp) / 1e9,
+            "avg_ratio": avg_ratio,
+            "currently_compressed": currently_compressed,
+        }
+
+
+# ============================================================================
+# Internal Utilities
+# ============================================================================
+
+
+@contextmanager
+def _safe_compress(input_path, output_path):
+    """Context manager for safe compression with locks and temp files.
+
+    Checksums are managed by DataJoint's externals table.
+
+    Yields the temp file path to write compressed data to. On a clean exit
+    the temp file is atomically renamed to ``output_path``; on any error
+    (including KeyboardInterrupt) it is removed. The lock file is always
+    released.
+
+    Raises
+    ------
+    RuntimeError
+        If the lock file for ``output_path`` already exists.
+    """
+    output_path = Path(output_path)
+    lock_path = Path(str(output_path) + ".lock")
+
+    # Create the lock atomically: a separate exists()/touch() pair lets two
+    # processes both pass the check before either creates the file
+    try:
+        lock_path.touch(exist_ok=False)
+    except FileExistsError:
+        raise RuntimeError(
+            f"Lock file exists: {lock_path}. "
+            "Another compression may be in progress."
+        ) from None
+
+    temp_path = None
+    try:
+        temp_fd, temp_output = tempfile.mkstemp(
+            suffix=".tmp", dir=output_path.parent
+        )
+        os.close(temp_fd)
+        temp_path = Path(temp_output)
+
+        # Yield temp path for writing
+        yield temp_path
+
+        # Atomically move temp file to final location
+        temp_path.rename(output_path)
+
+    finally:
+        # After a successful rename the temp path no longer exists, so this
+        # only fires on failure (any exception type, not just Exception)
+        if temp_path is not None and temp_path.exists():
+            temp_path.unlink()
+        if lock_path.exists():
+            lock_path.unlink()
diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py
index c44a49019..cdbb3b030 100644
--- a/src/spyglass/common/common_nwbfile.py
+++ b/src/spyglass/common/common_nwbfile.py
@@ -1,11 +1,7 @@
 import os
-import random
 import re
-import string
-import subprocess
 from pathlib import Path
 from typing import Dict, List, Optional, Set, Union
-from uuid import uuid4
 
 import datajoint as dj
 import h5py
@@ -59,6 +55,23 @@ class Nwbfile(SpyglassMixin, dj.Manual):
 
     # NOTE: See #630, #664. Excessive key length.
 
+    class AccessLog(dj.Part):
+        """Track file access for usage analysis and compression decisions.
+
+        Logs all file access events to help identify compression candidates
+        and monitor usage patterns.
+        """
+
+        definition = """
+        -> master
+        access_id: int auto_increment
+        ---
+        dj_user: varchar(64)  # DataJoint user who accessed the file
+        access_time=CURRENT_TIMESTAMP: timestamp
+        access_method: varchar(32)  # Method used (fetch_nwb, direct, etc)
+        decompression_time_ms=null: int unsigned  # Decompression duration
+        """
+
     @classmethod
     def insert_from_relative_file_name(cls, nwb_file_name: str) -> None:
         """Insert a new session from an existing NWB file.
@@ -81,10 +94,61 @@ def insert_from_relative_file_name(cls, nwb_file_name: str) -> None:
         )
 
     def fetch_nwb(self):
-        return [
-            get_nwb_file(self.get_abs_path(file))
-            for file in self.fetch("nwb_file_name")
-        ]
+        """Fetch NWB files, decompressing if needed.
+
+        Checks compression status and automatically decompresses compressed files
+        before opening. Logs all file access events with user tracking for usage
+        analysis and compression decisions.
+
+        Returns
+        -------
+        list
+            List of opened NWB file objects
+        """
+        import time
+
+        from spyglass.common.common_file_tracking import CompressedNwbfile
+
+        # Get current DataJoint user
+        dj_user = dj.config.get("database.user", "unknown")
+
+        file_names = self.fetch("nwb_file_name")
+        nwb_files = []
+
+        for file_name in file_names:
+            decompression_time_ms = None
+
+            # Check if file is compressed and decompress if needed
+            comp_entry = CompressedNwbfile() & {"nwb_file_name": file_name}
+
+            if comp_entry:
+                is_compressed = comp_entry.fetch1("is_compressed")
+
+                if is_compressed:
+                    # Decompress and measure time
+                    logger.info(f"Decompressing {file_name} for access")
+                    start_time = time.time()
+                    CompressedNwbfile().decompress(file_name)
+                    decompression_time_ms = int(
+                        (time.time() - start_time) * 1000
+                    )
+
+            # Log access event for all files
+            Nwbfile.AccessLog.insert1(
+                {
+                    "nwb_file_name": file_name,
+                    "dj_user": dj_user,
+                    "access_method": "fetch_nwb",
+                    "decompression_time_ms": decompression_time_ms,
+                },
+                skip_duplicates=False,  # Allow multiple access logs
+            )
+
+            # Get path and open file
+            file_path = self.get_abs_path(file_name)
+            nwb_files.append(get_nwb_file(file_path))
+
+        return nwb_files
 
     @classmethod
     def get_abs_path(
diff --git a/tests/common/test_file_tracking.py b/tests/common/test_file_tracking.py
new file mode 100644
index 000000000..49c940c5c
--- /dev/null
+++ b/tests/common/test_file_tracking.py
@@ -0,0 +1,149 @@
+"""Unit tests for file compression utility functions.
+
+These tests cover the compression utilities that don't require database access.
+Full integration tests with database are in test_file_tracking_integration.py
+"""
+
+import gzip
+import tempfile
+from pathlib import Path
+
+import pytest
+from datajoint.hash import uuid_from_file
+
+
+@pytest.fixture
+def temp_dir():
+    """Create temporary directory for tests."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir)
+
+
+@pytest.fixture
+def sample_file(temp_dir):
+    """Create a sample file for testing."""
+    file_path = temp_dir / "test_file.nwb"
+    with open(file_path, "wb") as f:
+        f.write(b"Test data content " * 1000)  # ~18 KB
+    return file_path
+
+
+class TestSafeCompress:
+    """Tests for _safe_compress context manager."""
+
+    def test_checksum_computation(self, sample_file, temp_dir):
+        """Test that temp-file contents are moved to the output path."""
+        # Import here to avoid module-level schema creation
+        from spyglass.common.common_file_tracking import _safe_compress
+
+        output_path = temp_dir / "output.gz"
+
+        with _safe_compress(str(sample_file), str(output_path)) as temp_path:
+            # _safe_compress yields the temp file to write to, not checksums
+            assert temp_path != output_path
+            with open(sample_file, "rb") as f_in:
+                with gzip.open(temp_path, "wb") as f_out:
+                    f_out.write(f_in.read())
+
+        # Temp file was renamed into place
+        assert output_path.exists()
+        assert not temp_path.exists()
+
+        # Verify the compressed payload round-trips to the same checksum
+        restored_path = temp_dir / "restored.nwb"
+        with gzip.open(output_path, "rb") as f_in:
+            with open(restored_path, "wb") as f_out:
+                f_out.write(f_in.read())
+        assert uuid_from_file(restored_path) == uuid_from_file(sample_file)
+
+    def test_lock_file_creation(self, sample_file, temp_dir):
+        """Test that lock file is created and removed."""
+        from spyglass.common.common_file_tracking import _safe_compress
+
+        output_path = temp_dir / "output.gz"
+        lock_path = Path(str(output_path) + ".lock")
+
+        with _safe_compress(str(sample_file), str(output_path)) as temp_path:
+            # Lock is held for the duration of the block
+            assert lock_path.exists()
+
+            with open(sample_file, "rb") as f_in:
+                with gzip.open(temp_path, "wb") as f_out:
+                    f_out.write(f_in.read())
+
+        # Lock should be removed after
+        assert not lock_path.exists()
+
+    def test_lock_file_prevents_concurrent(self, sample_file, temp_dir):
+        """Test that existing lock file prevents compression."""
+        from spyglass.common.common_file_tracking import _safe_compress
+
+        output_path = temp_dir / "output.gz"
+        lock_path = Path(str(output_path) + ".lock")
+
+        # Create lock file
+        lock_path.touch()
+
+        with pytest.raises(RuntimeError, match="Lock file exists"):
+            with _safe_compress(str(sample_file), str(output_path)):
+                pass
+
+        # Cleanup
+        lock_path.unlink()
+
+    def test_cleanup_on_error(self, sample_file, temp_dir):
+        """Test that temp file is cleaned up on error."""
+        from spyglass.common.common_file_tracking import _safe_compress
+
+        output_path = temp_dir / "output.gz"
+
+        with pytest.raises(ValueError):
+            with _safe_compress(str(sample_file), str(output_path)):
+                raise ValueError("Test error")
+
+        # Temp file should be cleaned up
+        temp_files = list(temp_dir.glob("*.tmp"))
+        assert len(temp_files) == 0
+
+
+class TestCompressionIntegrity:
+    """Tests for compression/decompression round trip."""
+
+    def test_checksum_verification(self, sample_file, temp_dir):
+        """Test that checksum verification catches corruption."""
+        # Compress file
+        compressed_path = temp_dir / "compressed.gz"
+        with gzip.open(compressed_path, "wb") as f_out:
+            with open(sample_file, "rb") as f_in:
+                f_out.write(f_in.read())
+
+        # Get original checksum
+        original_checksum = uuid_from_file(sample_file)
+
+        # Decompress
+        decompressed_path = temp_dir / "decompressed.nwb"
+        with gzip.open(compressed_path, "rb") as f_in:
+            with open(decompressed_path, "wb") as f_out:
+                f_out.write(f_in.read())
+
+        # Verify checksum matches
+        decompressed_checksum = uuid_from_file(decompressed_path)
+        assert original_checksum == decompressed_checksum
+
+    def test_concurrent_compression_prevention(self, sample_file, temp_dir):
+        """Test that lock files prevent concurrent compression."""
+        from spyglass.common.common_file_tracking import _safe_compress
+
+        output_path = temp_dir / "output.gz"
+        lock_path = Path(str(output_path) + ".lock")
+
+        # Start first compression
+        lock_path.touch()
+
+        # Try second compression
+        with pytest.raises(RuntimeError, match="Lock file exists"):
+            with _safe_compress(str(sample_file), str(output_path)):
+                pass
+
+        # Cleanup
+        lock_path.unlink()