diff --git a/tests/integration/bash_tests/run_from_any/test_globus_tar_deletion.bash b/tests/integration/bash_tests/run_from_any/test_globus_tar_deletion.bash index 4b76f4c5..cd5ac041 100755 --- a/tests/integration/bash_tests/run_from_any/test_globus_tar_deletion.bash +++ b/tests/integration/bash_tests/run_from_any/test_globus_tar_deletion.bash @@ -137,16 +137,17 @@ test_globus_tar_deletion() keep_flag="" fi - zstash create ${blocking_flag} ${keep_flag} --hpss=${globus_path}/${case_name} --maxsize 128 zstash_demo 2>&1 | tee ${case_name}.log + # Use -v so debug logs show up. + zstash create ${blocking_flag} ${keep_flag} --hpss=${globus_path}/${case_name} --maxsize 128 -v zstash_demo 2>&1 | tee ${case_name}.log if [ $? != 0 ]; then - echo "${case_name} failed. Check ${case_name}_create.log for details. Cannot continue." + echo "${case_name} failed. Check ${case_name}.log for details. Cannot continue." return 1 fi echo "${case_name} completed successfully. Checking ${case_name}.log now." check_log_has "Creating new tar archive 000000.tar" ${case_name}.log || return 2 echo "" - echo "Checking directory status after 'zstash create' has completed. src should only have index.db. dst should have tar and index.db." + echo "Checking directory status after 'zstash create' has completed." echo "Checking logs in current directory: ${PWD}" echo "" @@ -181,6 +182,98 @@ test_globus_tar_deletion() return 0 # Success } +test_globus_progressive_deletion() +{ + local path_to_repo=$1 + local dst_endpoint=$2 + local dst_dir=$3 + local blocking_str=$4 + + src_dir=${path_to_repo}/tests/utils/globus_tar_deletion + rm -rf ${src_dir} + mkdir -p ${src_dir} + dst_endpoint_uuid=$(get_endpoint ${dst_endpoint}) + globus_path=globus://${dst_endpoint_uuid}/${dst_dir} + + case_name=${blocking_str}_progressive_deletion + echo "Running test_globus_progressive_deletion on case=${case_name}" + echo "Exit codes: 0 -- success, 1 -- zstash failed, 2 -- grep check failed" + + setup ${case_name} "${src_dir}" + + # Create files totaling >2 GB to trigger multiple tars with maxsize=1 GB + # Each file is ~700 MB, so we'll get 3 tars + echo "Creating large test files (this may take a minute)..." + dd if=/dev/zero of=zstash_demo/file1.dat bs=1M count=700 2>/dev/null # 700 MB + dd if=/dev/zero of=zstash_demo/file2.dat bs=1M count=700 2>/dev/null # 700 MB + dd if=/dev/zero of=zstash_demo/file3.dat bs=1M count=700 2>/dev/null # 700 MB + echo "✓ Test files created" + + if [ "$blocking_str" == "non-blocking" ]; then + blocking_flag="--non-blocking" + else + blocking_flag="" + fi + + # Run with maxsize=1 GB to create multiple tars + echo "Running zstash create (this may take several minutes due to file size and transfers)..." + zstash create ${blocking_flag} --hpss=${globus_path}/${case_name} --maxsize 1 -v zstash_demo 2>&1 | tee ${case_name}.log + if [ $? != 0 ]; then + echo "${case_name} failed." + return 1 + fi + + # Check that multiple tar files were created + tar_count=$(grep -c "Creating new tar archive" ${case_name}.log) + if [ ${tar_count} -lt 2 ]; then + echo "Expected at least 2 tar archives to be created, found ${tar_count}" + return 2 + fi + echo "✓ Created ${tar_count} tar archives" + + # Check that files were deleted progressively + deletion_count=$(grep -c "Deleting .* files from successful transfer" ${case_name}.log) + + if [ "$blocking_str" == "blocking" ]; then + # In blocking mode, we should see deletion after each tar transfer + if [ ${deletion_count} -lt $((tar_count - 1)) ]; then + echo "Expected at least $((tar_count - 1)) deletion events in blocking mode, found ${deletion_count}" + return 2 + fi + echo "✓ Files deleted progressively (${deletion_count} deletion events)" + else + # In non-blocking mode, deletions happen when we check status + if [ ${deletion_count} -lt 1 ]; then + echo "Expected at least 1 deletion event in non-blocking mode, found ${deletion_count}" + return 2 + fi + echo "✓ Files deleted (${deletion_count} deletion events in non-blocking mode)" + fi + + # Verify that NO tar files remain in source after completion + echo "Checking that no tar files remain in source" + ls ${src_dir}/${case_name}/zstash_demo/zstash/*.tar 2>&1 | tee ls_tar_check.log + if grep -q "\.tar" ls_tar_check.log && ! grep -q "No such file" ls_tar_check.log; then + echo "Found tar files that should have been deleted!" + return 2 + fi + echo "✓ All tar files successfully deleted from source" + + # Verify tar files exist in destination + if [ "$blocking_str" == "non-blocking" ]; then + wait_for_directory "${dst_dir}/${case_name}" || return 1 + fi + + dst_tar_count=$(ls ${dst_dir}/${case_name}/*.tar 2>/dev/null | wc -l) + if [ ${dst_tar_count} -ne ${tar_count} ]; then + echo "Expected ${tar_count} tar files in destination, found ${dst_tar_count}" + return 2 + fi + echo "✓ All ${tar_count} tar files present in destination" + + return 0 +} + # Follow these directions ##################################################### # Example usage: @@ -232,7 +325,14 @@ run_test_with_tracking() { echo "Running: ${test_name}" echo "==========================================" - if test_globus_tar_deletion "${args[@]}"; then + # Determine which test function to call based on test name + if [[ "${test_name}" == *"progressive"* ]]; then + test_func=test_globus_progressive_deletion + else + test_func=test_globus_tar_deletion + fi + + if ${test_func} "${args[@]}"; then # Print test result in the output block AND at the end echo "✓ ${test_name} PASSED" test_results+=("✓ ${test_name} PASSED") # Uses Global variable @@ -252,15 +352,29 @@ tests_passed=0 tests_failed=0 test_results=() # Global variable to hold test results -echo "Primary tests: single authentication code tests for each endpoint" +echo "Primary tests: basic functionality tests" echo "If a test hangs, check if https://app.globus.org/activity reports any errors on your transfers." -# Run all tests independently +# Run basic tests +# These check that AT THE END of the run, +# we either still have the files (keep) or the files are deleted (non-keep). run_test_with_tracking "blocking_non-keep" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "blocking" "non-keep" || true run_test_with_tracking "non-blocking_non-keep" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "non-blocking" "non-keep" || true run_test_with_tracking "blocking_keep" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "blocking" "keep" || true run_test_with_tracking "non-blocking_keep" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "non-blocking" "keep" || true +echo "" +echo "Progressive deletion tests: verify files are deleted as transfers complete" +echo "WARNING: These tests create ~2GB of data and will take several minutes" + +# Run progressive deletion tests +# These check that DURING the run, +# files are deleted after successful transfers (non-keep only). +# Blocking -- get files, transfer files, delete at src, start next transfer. +# Non-blocking -- get files, transfer files, get next set of files, transfer those files, check if previous transfer is done (and if so, delete at src). +run_test_with_tracking "blocking_progressive_deletion" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "blocking" || true +run_test_with_tracking "non-blocking_progressive_deletion" ${path_to_repo} ${endpoint_str} ${machine_dst_dir} "non-blocking" || true + # Print summary echo "" echo "==========================================" diff --git a/tests/unit/test_optimized_update.py b/tests/unit/test_optimized_update.py index 59b5d78d..5934209b 100644 --- a/tests/unit/test_optimized_update.py +++ b/tests/unit/test_optimized_update.py @@ -563,40 +563,6 @@ def test_time_tolerance_check(self): assert is_within_tolerance == should_match -class TestBackwardCompatibility: - """Tests to ensure backward compatibility with existing code.""" - - def test_get_files_to_archive_still_works(self, tmp_path): - """Test that legacy get_files_to_archive function still works.""" - from zstash.utils import get_files_to_archive - - (tmp_path / "file.txt").write_text("content") - - os.chdir(tmp_path) - result = get_files_to_archive("cache", None, None) - - # Should return list of strings - assert isinstance(result, list) - assert len(result) > 0 - assert all(isinstance(item, str) for item in result) - - def test_output_format_matches_original(self, tmp_path): - """Test that file paths are normalized the same way as original.""" - subdir = tmp_path / "subdir" - subdir.mkdir() - (subdir / "file.txt").write_text("content") - - os.chdir(tmp_path) - - from zstash.utils import get_files_to_archive - - legacy_result = get_files_to_archive("cache", None, None) - new_result = list(get_files_to_archive_with_stats("cache", None, None).keys()) - - # Should produce same file list - assert legacy_result == new_result - - @pytest.fixture def mock_database(): """Fixture providing a mock database cursor.""" diff --git a/zstash/create.py b/zstash/create.py index b502f1e6..872b91e4 100644 --- a/zstash/create.py +++ b/zstash/create.py @@ -6,17 +6,19 @@ import os.path import sqlite3 import sys -from typing import Any, List, Tuple +from datetime import datetime +from typing import Any, Dict, List, Tuple from six.moves.urllib.parse import urlparse from .globus import globus_activate, globus_finalize from .hpss import hpss_put -from .hpss_utils import add_files +from .hpss_utils import DevOptions, construct_tars from .settings import DEFAULT_CACHE, config, get_db_filename, logger +from .transfer_tracking import TransferManager from .utils import ( create_tars_table, - get_files_to_archive, + get_files_to_archive_with_stats, run_command, tars_table_exists, ts_utc, @@ -52,12 +54,13 @@ def create(): logger.error(input_path_error_str) raise NotADirectoryError(input_path_error_str) + transfer_manager: TransferManager = TransferManager() if hpss != "none": url = urlparse(hpss) if url.scheme == "globus": # identify globus endpoints logger.debug(f"{ts_utc()}:Calling globus_activate(hpss)") - globus_activate(hpss) + transfer_manager.globus_config = globus_activate(hpss) else: # config.hpss is not "none", so we need to # create target HPSS directory @@ -88,14 +91,21 @@ def create(): # Create and set up the database logger.debug(f"{ts_utc()}: Calling create_database()") - failures: List[str] = create_database(cache, args) + failures: List[str] = create_database(cache, args, transfer_manager) # Transfer to HPSS. Always keep a local copy. logger.debug(f"{ts_utc()}: calling hpss_put() for {get_db_filename(cache)}") - hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True) + hpss_put( + hpss, + get_db_filename(cache), + cache, + transfer_manager, + keep=args.keep, + is_index=True, + ) logger.debug(f"{ts_utc()}: calling globus_finalize()") - globus_finalize(non_blocking=args.non_blocking) + globus_finalize(transfer_manager, args.keep) if len(failures) > 0: # List the failures @@ -204,7 +214,9 @@ def setup_create() -> Tuple[str, argparse.Namespace]: return cache, args -def create_database(cache: str, args: argparse.Namespace) -> List[str]: +def create_database( + cache: str, args: argparse.Namespace, transfer_manager: TransferManager +) -> List[str]: # Create new database logger.debug(f"{ts_utc()}:Creating index database") if os.path.exists(get_db_filename(cache)): @@ -260,44 +272,30 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]: cur.execute("insert into config values (?,?)", (attr, value)) con.commit() - files: List[str] = get_files_to_archive(cache, args.include, args.exclude) + file_stats: Dict[str, Tuple[int, datetime]] = get_files_to_archive_with_stats( + cache, args.include, args.exclude + ) failures: List[str] - if args.follow_symlinks: - try: - # Add files to archive - failures = add_files( - cur, - con, - -1, - files, - cache, - args.keep, - args.follow_symlinks, - skip_tars_md5=args.no_tars_md5, - non_blocking=args.non_blocking, - error_on_duplicate_tar=args.error_on_duplicate_tar, - overwrite_duplicate_tars=args.overwrite_duplicate_tars, - force_database_corruption=args.for_developers_force_database_corruption, - ) - except FileNotFoundError: - raise Exception("Archive creation failed due to broken symlink.") - else: - # Add files to archive - failures = add_files( - cur, - con, - -1, - files, - cache, - args.keep, - args.follow_symlinks, - skip_tars_md5=args.no_tars_md5, - non_blocking=args.non_blocking, - error_on_duplicate_tar=args.error_on_duplicate_tar, - overwrite_duplicate_tars=args.overwrite_duplicate_tars, - force_database_corruption=args.for_developers_force_database_corruption, - ) + dev_options: DevOptions = DevOptions( + error_on_duplicate_tar=args.error_on_duplicate_tar, + overwrite_duplicate_tars=args.overwrite_duplicate_tars, + force_database_corruption=args.for_developers_force_database_corruption, + ) + # Add files to archive + failures = construct_tars( + cur, + con, + -1, + file_stats, + cache, + args.keep, + args.follow_symlinks, + dev_options, + transfer_manager, + skip_tars_table=args.no_tars_md5, + non_blocking=args.non_blocking, + ) # Close database con.commit() diff --git a/zstash/globus.py b/zstash/globus.py index 2cacad5f..b99a6392 100644 --- a/zstash/globus.py +++ b/zstash/globus.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function import sys -from typing import List, Optional +from typing import Dict, List, Optional, Set, Tuple from globus_sdk import TransferAPIError, TransferClient, TransferData from globus_sdk.response import GlobusHTTPResponse @@ -10,85 +10,108 @@ from .globus_utils import ( HPSS_ENDPOINT_MAP, + add_file_to_TransferData, check_state_files, + create_TransferData, + get_label, get_local_endpoint_id, get_transfer_client_with_auth, - set_up_TransferData, submit_transfer_with_checks, ) from .settings import logger +from .transfer_tracking import GlobusConfig, TaskStatus, TransferBatch, TransferManager from .utils import ts_utc -remote_endpoint = None -local_endpoint = None -transfer_client: TransferClient = None -transfer_data: TransferData = None -task_id = None -archive_directory_listing: IterableTransferResponse = None - -def globus_activate(hpss: str): +def globus_activate( + hpss: str, globus_config: Optional[GlobusConfig] = None +) -> Optional[GlobusConfig]: """ Read the local globus endpoint UUID from ~/.zstash.ini. If the ini file does not exist, create an ini file with empty values, and try to find the local endpoint UUID based on the FQDN """ - global transfer_client - global local_endpoint - global remote_endpoint url = urlparse(hpss) if url.scheme != "globus": - return + return None + if globus_config is None: + globus_config = GlobusConfig() check_state_files() - remote_endpoint = url.netloc - local_endpoint = get_local_endpoint_id(local_endpoint) - if remote_endpoint.upper() in HPSS_ENDPOINT_MAP.keys(): - remote_endpoint = HPSS_ENDPOINT_MAP.get(remote_endpoint.upper()) - both_endpoints: List[Optional[str]] = [local_endpoint, remote_endpoint] - transfer_client = get_transfer_client_with_auth(both_endpoints) + globus_config.remote_endpoint = url.netloc + globus_config.local_endpoint = get_local_endpoint_id(globus_config.local_endpoint) + upper_remote_ep = globus_config.remote_endpoint.upper() + if upper_remote_ep in HPSS_ENDPOINT_MAP.keys(): + globus_config.remote_endpoint = HPSS_ENDPOINT_MAP[upper_remote_ep] + both_endpoints: List[Optional[str]] = [ + globus_config.local_endpoint, + globus_config.remote_endpoint, + ] + globus_config.transfer_client = get_transfer_client_with_auth(both_endpoints) for ep_id in both_endpoints: - r = transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600) + r = globus_config.transfer_client.endpoint_autoactivate( + ep_id, if_expires_in=600 + ) if r.get("code") == "AutoActivationFailed": logger.error( f"The {ep_id} endpoint is not activated or the current activation expires soon. Please go to https://app.globus.org/file-manager/collections/{ep_id} and (re)activate the endpoint." ) sys.exit(1) + return globus_config -def file_exists(name: str) -> bool: - +def file_exists(archive_directory_listing: IterableTransferResponse, name: str) -> bool: for entry in archive_directory_listing: if entry.get("name") == name: return True return False -global_variable_tarfiles_pushed = 0 +def update_cumulative_tarfiles_pushed( + transfer_manager: TransferManager, transfer_data: TransferData +) -> None: + logger.info(f"{ts_utc()}: TransferData: accumulated items:") + attribs = transfer_data.__dict__ + for item in attribs["data"]["DATA"]: + if item["DATA_TYPE"] == "transfer_item": + transfer_manager.cumulative_tarfiles_pushed += 1 + print( + f"PUSHED (#{transfer_manager.cumulative_tarfiles_pushed}) tars, STORED source item: {item['source_path']}", + flush=True, + ) # C901 'globus_transfer' is too complex (20) def globus_transfer( # noqa: C901 - remote_ep: str, remote_path: str, name: str, transfer_type: str, non_blocking: bool -): - global transfer_data - global task_id - global archive_directory_listing - global global_variable_tarfiles_pushed + transfer_manager: TransferManager, + remote_ep: str, + remote_path: str, + name: str, + transfer_type: str, + non_blocking: bool, +) -> TaskStatus: logger.info(f"{ts_utc()}: Entered globus_transfer() for name = {name}") logger.debug(f"{ts_utc()}: non_blocking = {non_blocking}") - if not transfer_client: - globus_activate("globus://" + remote_ep) - if not transfer_client: + if (not transfer_manager.globus_config) or ( + not transfer_manager.globus_config.transfer_client + ): + transfer_manager.globus_config = globus_activate("globus://" + remote_ep) + if (not transfer_manager.globus_config) or ( + not transfer_manager.globus_config.transfer_client + ): sys.exit(1) if transfer_type == "get": - if not archive_directory_listing: - archive_directory_listing = transfer_client.operation_ls( - remote_endpoint, remote_path + if not transfer_manager.globus_config.archive_directory_listing: + transfer_manager.globus_config.archive_directory_listing = ( + transfer_manager.globus_config.transfer_client.operation_ls( + transfer_manager.globus_config.remote_endpoint, remote_path + ) ) - if not file_exists(name): + if not file_exists( + transfer_manager.globus_config.archive_directory_listing, name + ): logger.error( "Remote file globus://{}{}/{} does not exist".format( remote_ep, remote_path, name @@ -96,70 +119,127 @@ def globus_transfer( # noqa: C901 ) sys.exit(1) - transfer_data = set_up_TransferData( + mrb: Optional[TransferBatch] = transfer_manager.get_most_recent_batch() + if not mrb: + raise RuntimeError( + "The transfer manager should always have at least one batch by the time globus_transfer is called, however, the batch list is empty." + ) + + if transfer_manager.globus_config.local_endpoint: + local_endpoint: str = transfer_manager.globus_config.local_endpoint + else: + raise ValueError("Local endpoint ID is not set.") + if transfer_manager.globus_config.remote_endpoint: + remote_endpoint: str = transfer_manager.globus_config.remote_endpoint + else: + raise ValueError("Remote endpoint ID is not set.") + label: str = get_label(remote_path, name) + transfer_data: TransferData + if mrb.transfer_data: + # We already have a TransferData for this batch. + transfer_data = mrb.transfer_data + else: + # We need to create a new TransferData for this batch. + transfer_data = create_TransferData( + transfer_type, + local_endpoint, + remote_endpoint, + transfer_manager.globus_config.transfer_client, + label, + ) + add_file_to_TransferData( transfer_type, - local_endpoint, # Global - remote_endpoint, # Global + local_endpoint, + remote_endpoint, remote_path, name, - transfer_client, # Global - transfer_data, # Global + transfer_data, + label, ) task: GlobusHTTPResponse try: - if task_id: - task = transfer_client.get_task(task_id) - prev_task_status = task["status"] - # one of {ACTIVE, SUCCEEDED, FAILED, CANCELED, PENDING, INACTIVE} - # NOTE: How we behave here depends upon whether we want to support mutliple active transfers. - # Presently, we do not, except inadvertantly (if status == PENDING) - if prev_task_status == "ACTIVE": + if mrb.task_id: + # This the current transfer task associated with the most recent batch. + task = transfer_manager.globus_config.transfer_client.get_task(mrb.task_id) + # Update the most recent batch's task_status based on the current status from Globus API. + mrb.task_status = TaskStatus.convert_from_status_from_globus_sdk(task) + if mrb.task_status == TaskStatus.ACTIVE: + # The most recent transfer (mrb) is still active. logger.info( - f"{ts_utc()}: Previous task_id {task_id} Still Active. Returning ACTIVE." + f"{ts_utc()}: Previous task_id {mrb.task_id} Still Active. Returning ACTIVE." ) - return "ACTIVE" - elif prev_task_status == "SUCCEEDED": + if non_blocking: + # Globus allows up to 3 simulataneous transfers, + # but zstash is currently configured to only ever allow 1. + # If we're in this block, then we're already at 1 active transfer. + # We will therefore wait to submit a new transfer until it's done. + # So, we'll simply return and the next run of globus_transfer + # (i.e., on the next tar) will evaluate if the active transfer has finished. + return TaskStatus.ACTIVE + else: + # If we're in this block, then the blocking wait + # for the previous transfer to finish was unsuccessful. + # This is an unexpected state and so we raise an error. + error_str: str = ( + "task_status='ACTIVE', but in blocking mode, the previous transfer should have waited through globus_block_wait" + ) + logger.error(error_str) + raise RuntimeError(error_str) + elif mrb.task_status == TaskStatus.SUCCEEDED: logger.info( - f"{ts_utc()}: Previous task_id {task_id} status = SUCCEEDED." + f"{ts_utc()}: Previous task_id {mrb.task_id} status = SUCCEEDED." ) src_ep = task["source_endpoint_id"] dst_ep = task["destination_endpoint_id"] label = task["label"] ts = ts_utc() logger.info( - "{}:Globus transfer {}, from {} to {}: {} succeeded".format( - ts, task_id, src_ep, dst_ep, label - ) + f"{ts}:Globus transfer {mrb.task_id}, from {src_ep} to {dst_ep}: {label} succeeded" ) + # The previous transfer succeeded. + # That means we can transfer the current batch now. else: - logger.error( - f"{ts_utc()}: Previous task_id {task_id} status = {prev_task_status}." + # The previous transfer is in an unexpected state (i.e., "INACTIVE", "FAILED"). + # Either way, the previous transfer is effectively terminated, + # so we will proceed with the current transfer attempt. + # (I.e., we will not return yet). + # Note: any status we manually set + # (I.e., "UNKNOWN", "SUBMITTED", "EXHAUSTED_TIMEOUT_RETRIES") is NOT possible here, + # because we're using `task["status"]` from the globus_sdk TransferClient. + logger.warning( + f"{ts_utc()}: Previous task_id {mrb.task_id} status = {mrb.task_status}." ) - # DEBUG: review accumulated items in TransferData - logger.info(f"{ts_utc()}: TransferData: accumulated items:") - attribs = transfer_data.__dict__ - for item in attribs["data"]["DATA"]: - if item["DATA_TYPE"] == "transfer_item": - global_variable_tarfiles_pushed += 1 - print( - f" (routine) PUSHING (#{global_variable_tarfiles_pushed}) STORED source item: {item['source_path']}", - flush=True, - ) + update_cumulative_tarfiles_pushed(transfer_manager, transfer_data) - # SUBMIT new transfer here logger.info(f"{ts_utc()}: DIVING: Submit Transfer for {transfer_data['label']}") - task = submit_transfer_with_checks(transfer_client, transfer_data) + # Submit the current transfer_data + # ALWAYS submit. If we've gotten to this point, we're ready to submit. + task = submit_transfer_with_checks( + transfer_manager.globus_config.transfer_client, transfer_data + ) task_id = task.get("task_id") - # NOTE: This log message is misleading. If we have accumulated multiple tar files for transfer, - # the "lable" given here refers only to the LAST tarfile in the TransferData list. logger.info( - f"{ts_utc()}: SURFACE Submit Transfer returned new task_id = {task_id} for label {transfer_data['label']}" + f"{ts_utc()}: SURFACE Submit Transfer returned new task_id = {task_id}, with last tarfile having label: {transfer_data['label']}" ) - # Nullify the submitted transfer data structure so that a new one will be created on next call. - transfer_data = None + # Update the current batch with the task info + # The batch was already created in hpss_transfer with files added to it + # We just need to mark it as submitted + if transfer_manager.batches: + # Update these two fields of the most recent batch + # (which is still available in this function as `mrb`). + transfer_manager.batches[-1].task_id = task_id + transfer_manager.batches[-1].task_status = TaskStatus.SUBMITTED + else: + # This block should be impossible to reach. + # By now, we've ensured that `get_most_recent_batch()` returns a batch, + # and we haven't removed any batches since then, + # so there should always be at least one batch in `batches`. + error_str = "transfer_manager has no batches" + logger.error(error_str) + raise RuntimeError(error_str) except TransferAPIError as e: if e.code == "NoCredException": logger.error( @@ -174,51 +254,77 @@ def globus_transfer( # noqa: C901 logger.error("Exception: {}".format(e)) sys.exit(1) - # test for blocking on new task_id - task_status = "UNKNOWN" - if not non_blocking: - task_status = globus_block_wait( - task_id=task_id, wait_timeout=7200, polling_interval=10, max_retries=5 - ) + task_status: TaskStatus = TaskStatus.UNKNOWN + if mrb.task_id: + if not non_blocking: + # If blocking, wait for the task to complete and get the final status, + # before we proceed with any more transfers. + mrb.task_status = globus_block_wait( + transfer_manager.globus_config.transfer_client, + task_id=mrb.task_id, + ) + task_status = mrb.task_status + else: + logger.info( + f"{ts_utc()}: NO BLOCKING (task_wait) for task_id {mrb.task_id}" + ) else: - logger.info(f"{ts_utc()}: NO BLOCKING (task_wait) for task_id {task_id}") + # This block should be impossible to reach. + # By now, we've set `transfer_manager.batches[-1].task_id = task_id` or else raised an error, so `mrb.task_id` should always be set. + error_str = "No task_id found for most recent batch after submission" + logger.error(f"{ts_utc()}: {error_str}") + raise RuntimeError(error_str) if transfer_type == "put": return task_status if transfer_type == "get" and task_id: - globus_wait(task_id) + globus_wait(transfer_manager.globus_config.transfer_client, task_id) return task_status def globus_block_wait( - task_id: str, wait_timeout: int, polling_interval: int, max_retries: int -): - - # poll every "polling_interval" seconds to speed up small transfers. Report every 2 hours, stop waiting aftert 5*2 = 10 hours + transfer_client: TransferClient, + task_id: str, + wait_timeout: int = 7200, # 7200/3600 = 2 hours + max_retries: int = 5, +) -> TaskStatus: + # Poll every "polling_interval" seconds to speed up small transfers. + # Report every "wait_timeout" seconds, and stop waiting after "max_retries" reports. + # By default: report every 2 hours, stop waiting after 5*2 = 10 hours logger.info( f"{ts_utc()}: BLOCKING START: invoking task_wait for task_id = {task_id}" ) - task_status = "UNKNOWN" - retry_count = 0 + task_status: TaskStatus = TaskStatus.UNKNOWN + retry_count: int = 0 while retry_count < max_retries: try: - # Wait for the task to complete logger.info( f"{ts_utc()}: on task_wait try {retry_count + 1} out of {max_retries}" ) - transfer_client.task_wait( + # Wait for the task to complete. This is what makes this function BLOCKING. + # From https://globus-sdk-python.readthedocs.io/en/stable/services/transfer.html#globus_sdk.TransferClient.task_wait: Wait until a Task is complete or fails, with a time limit. If the task is “ACTIVE” after time runs out, returns False. Otherwise returns True. + task_is_not_active: bool = transfer_client.task_wait( task_id, timeout=wait_timeout, polling_interval=10 ) + if task_is_not_active: + curr_task: GlobusHTTPResponse = transfer_client.get_task(task_id) + task_status = TaskStatus.convert_from_status_from_globus_sdk(curr_task) + if task_status == TaskStatus.SUCCEEDED: + break # Break out of the while-loop. The transfer already succeeded, so no need to retry. + elif task_status == TaskStatus.FAILED: + error_str = f"{ts_utc()}: task_wait returned True, but task_status={task_status} for task_id {task_id}. No reason to keep retrying now." + logger.warning(error_str) + # We still need to break, because no matter how long we wait now, nothing will change with the transfer status. + break + else: + error_str = f"{ts_utc()}: task_wait returned True, but task_status={task_status} for task_id {task_id}. Will retry waiting until max_retries is reached." + logger.warning(error_str) + # Don't break -- continue retries logger.info(f"{ts_utc()}: done with wait") except Exception as e: logger.error(f"Unexpected Exception: {e}") - else: - curr_task = transfer_client.get_task(task_id) - task_status = curr_task["status"] - if task_status == "SUCCEEDED": - break finally: retry_count += 1 logger.info( @@ -229,7 +335,7 @@ def globus_block_wait( logger.info( f"{ts_utc()}: BLOCKING EXHAUSTED {max_retries} of timeout {wait_timeout} seconds" ) - task_status = "EXHAUSTED_TIMEOUT_RETRIES" + task_status = TaskStatus.EXHAUSTED_TIMEOUT_RETRIES logger.info( f"{ts_utc()}: BLOCKING ENDS: task_id {task_id} returned from task_wait with status {task_status}" @@ -238,12 +344,13 @@ def globus_block_wait( return task_status -def globus_wait(task_id: str): - +def globus_wait(transfer_client: TransferClient, task_id: str): try: """ - A Globus transfer job (task) can be in one of the three states: - ACTIVE, SUCCEEDED, FAILED. The script every 20 seconds polls a + A Globus transfer job (task) can be in one of the four states: + {ACTIVE, SUCCEEDED, FAILED, INACTIVE} + according to https://docs.globus.org/api/transfer/task/#task_fields. + The script every 20 seconds polls a status of the transfer job (task) from the Globus Transfer service, with 20 second timeout limit. If the task is ACTIVE after time runs out 'task_wait' returns False, and True otherwise. @@ -254,8 +361,8 @@ def globus_wait(task_id: str): The Globus transfer job (task) has been finished (SUCCEEDED or FAILED). Check if the transfer SUCCEEDED or FAILED. """ - task = transfer_client.get_task(task_id) - if task["status"] == "SUCCEEDED": + task: GlobusHTTPResponse = transfer_client.get_task(task_id) + if TaskStatus.convert_from_status_from_globus_sdk(task) == TaskStatus.SUCCEEDED: src_ep = task["source_endpoint_id"] dst_ep = task["destination_endpoint_id"] label = task["label"] @@ -281,44 +388,171 @@ def globus_wait(task_id: str): sys.exit(1) -def globus_finalize(non_blocking: bool = False): - global global_variable_tarfiles_pushed +def _submit_pending_transfer_data( + transfer_client: TransferClient, + transfer_manager: TransferManager, +) -> Optional[str]: + """ + If the most recent batch has unsubmitted TransferData, submit it and return task_id. + Otherwise return None. + """ + transfer: Optional[TransferBatch] = transfer_manager.get_most_recent_batch() + if not transfer or not transfer.transfer_data: + return None - last_task_id = None + update_cumulative_tarfiles_pushed(transfer_manager, transfer.transfer_data) - if transfer_data: - # DEBUG: review accumulated items in TransferData - logger.info(f"{ts_utc()}: FINAL TransferData: accumulated items:") - attribs = transfer_data.__dict__ - for item in attribs["data"]["DATA"]: - if item["DATA_TYPE"] == "transfer_item": - global_variable_tarfiles_pushed += 1 - print( - f" (finalize) PUSHING ({global_variable_tarfiles_pushed}) source item: {item['source_path']}", - flush=True, - ) + logger.info( + f"{ts_utc()}: DIVING: Submit Transfer for {transfer.transfer_data['label']}" + ) + try: + last_task = submit_transfer_with_checks(transfer_client, transfer.transfer_data) + task_id = last_task.get("task_id") - # SUBMIT new transfer here - logger.info(f"{ts_utc()}: DIVING: Submit Transfer for {transfer_data['label']}") - try: - last_task = submit_transfer_with_checks(transfer_client, transfer_data) - last_task_id = last_task.get("task_id") - except TransferAPIError as e: - if e.code == "NoCredException": - logger.error( - "{}. Please go to https://app.globus.org/endpoints and activate the endpoint.".format( - e.message - ) + # Best-effort: if this batch represents the submission, store the task_id. + if task_id and transfer.is_globus and not transfer.task_id: + transfer.task_id = task_id + + return task_id + + except TransferAPIError as e: + if e.code == "NoCredException": + logger.error( + "{}. Please go to https://app.globus.org/endpoints and activate the endpoint.".format( + e.message ) - else: - logger.error(e) - sys.exit(1) - except Exception as e: - logger.error("Exception: {}".format(e)) - sys.exit(1) + ) + else: + logger.error(e) + sys.exit(1) + except Exception as e: + logger.error("Exception: {}".format(e)) + sys.exit(1) + + +def _collect_globus_task_ids( + transfer_manager: TransferManager, extra_task_id: Optional[str], keep: bool +) -> Tuple[List[str], Dict[str, TransferBatch]]: + """ + Return (ordered unique task_ids, task_id->batch mapping for first occurrence). + """ + task_ids: List[str] = [] + seen: Set[str] = set() + task_to_batch: Dict[str, TransferBatch] = {} + + for batch in transfer_manager.batches: + if not keep: + # NOTE: This is always true if `keep` is set, + # since we never track files for deletion if `keep` is set. + already_deleted: bool = not batch.file_paths + if already_deleted: + # This batch has already been processed and files deleted, so we can skip it. + continue + + if (not batch.is_globus) or (not batch.task_id): + continue + + # By this point, we know batch.task_id is not None + tid: str = batch.task_id + if tid in seen: + continue + + seen.add(tid) + task_ids.append(tid) + task_to_batch[tid] = batch + + # Always include extra_task_id (e.g., just-submitted transfer), + # even if not yet reflected in batches. + if extra_task_id and (extra_task_id not in seen): + task_ids.append(extra_task_id) + + return task_ids, task_to_batch + + +def _refresh_batch_status( + transfer_client: TransferClient, + task_id: str, + task_to_batch: Dict[str, TransferBatch], +) -> Optional[TaskStatus]: + """ + Fetch Globus task status and update corresponding batch.task_status if present. + Returns status, or None if fetch fails. + """ + try: + task: GlobusHTTPResponse = transfer_client.get_task(task_id) + status: TaskStatus = TaskStatus.convert_from_status_from_globus_sdk(task) + batch: Optional[TransferBatch] = task_to_batch.get(task_id) + if batch: + batch.task_status = status + return status + except Exception as e: + logger.warning( + f"{ts_utc()}: Could not fetch status for task_id={task_id}; will wait anyway. ({e})" + ) + return None + + +def _wait_for_all_tasks( + transfer_client: TransferClient, + task_ids: List[str], + task_to_batch: Dict[str, TransferBatch], +) -> None: + """ + For each task_id, refresh status; if not SUCCEEDED, block via globus_wait; + then refresh status again for deletion logic. + """ + for tid in task_ids: + status = _refresh_batch_status(transfer_client, tid, task_to_batch) + if status == TaskStatus.SUCCEEDED: + logger.info(f"{ts_utc()}: task_id={tid} already SUCCEEDED; skipping wait") + continue + + logger.info( + f"{ts_utc()}: Waiting for transfer task_id={tid} to complete (status={status})" + ) + globus_wait(transfer_client, tid) + + # After wait returns, task is terminal; refresh once more. + _refresh_batch_status(transfer_client, tid, task_to_batch) + + +def _prune_empty_batches(transfer_manager: TransferManager) -> None: + """ + Remove batches which have no remaining files to manage. + + Note: we only prune batches whose file_paths is empty, regardless of Globus/HPSS. + That matches current semantics where file_paths=[] means "processed". + """ + before = len(transfer_manager.batches) + transfer_manager.batches = [b for b in transfer_manager.batches if b.file_paths] + after = len(transfer_manager.batches) + if after != before: + logger.debug(f"{ts_utc()}: Pruned {before - after} empty transfer batches") + + +def globus_finalize(transfer_manager: TransferManager, keep: bool) -> None: + if transfer_manager.globus_config is None: + logger.debug("No GlobusConfig object provided for finalization") + return + if transfer_manager.globus_config.transfer_client is None: + logger.debug("GlobusConfig provided but transfer_client is None") + return + + # By this point, we know transfer_client is not None + transfer_client: TransferClient = transfer_manager.globus_config.transfer_client + + last_task_id: Optional[str] = _submit_pending_transfer_data( + transfer_client, transfer_manager + ) + + task_ids: List[str] + task_to_batch: Dict[str, TransferBatch] + task_ids, task_to_batch = _collect_globus_task_ids( + transfer_manager, last_task_id, keep + ) + + _wait_for_all_tasks(transfer_client, task_ids, task_to_batch) + + transfer_manager.delete_successfully_transferred_files() - if not non_blocking: - if task_id: - globus_wait(task_id) - if last_task_id: - globus_wait(last_task_id) + _prune_empty_batches(transfer_manager) diff --git a/zstash/globus_utils.py b/zstash/globus_utils.py index e5346f69..9c6f4e2a 100644 --- a/zstash/globus_utils.py +++ b/zstash/globus_utils.py @@ -217,48 +217,61 @@ def save_tokens(token_response): # Primarily used by globus_transfer ########################################### -def set_up_TransferData( +def get_label(remote_path: str, name: str) -> str: + subdir = os.path.basename(os.path.normpath(remote_path)) + subdir_label = re.sub("[^A-Za-z0-9_ -]", "", subdir) + filename = name.split(".")[0] + label = subdir_label + " " + filename + return label + + +def create_TransferData( + transfer_type: str, + local_endpoint: str, + remote_endpoint: str, + transfer_client: TransferClient, + label: str, +) -> TransferData: + if transfer_type == "get": + src_ep = remote_endpoint + dst_ep = local_endpoint + else: + src_ep = local_endpoint + dst_ep = remote_endpoint + transfer_data = TransferData( + transfer_client, + src_ep, + dst_ep, + label=label, + verify_checksum=True, + preserve_timestamp=True, + fail_on_quota_errors=True, + ) + return transfer_data + + +def add_file_to_TransferData( transfer_type: str, local_endpoint: Optional[str], remote_endpoint: Optional[str], remote_path: str, name: str, - transfer_client: TransferClient, - transfer_data: Optional[TransferData] = None, -) -> TransferData: + transfer_data: TransferData, + label: str, +): if not local_endpoint: raise ValueError("Local endpoint ID is not set.") if not remote_endpoint: raise ValueError("Remote endpoint ID is not set.") if transfer_type == "get": - src_ep = remote_endpoint src_path = os.path.join(remote_path, name) - dst_ep = local_endpoint dst_path = os.path.join(os.getcwd(), name) else: - src_ep = local_endpoint src_path = os.path.join(os.getcwd(), name) - dst_ep = remote_endpoint dst_path = os.path.join(remote_path, name) - subdir = os.path.basename(os.path.normpath(remote_path)) - subdir_label = re.sub("[^A-Za-z0-9_ -]", "", subdir) - filename = name.split(".")[0] - label = subdir_label + " " + filename - - if not transfer_data: - transfer_data = TransferData( - transfer_client, - src_ep, - dst_ep, - label=label, - verify_checksum=True, - preserve_timestamp=True, - fail_on_quota_errors=True, - ) transfer_data.add_item(src_path, dst_path) transfer_data["label"] = label - return transfer_data def submit_transfer_with_checks(transfer_client, transfer_data) -> GlobusHTTPResponse: diff --git a/zstash/hpss.py b/zstash/hpss.py index 24603388..19855488 100644 --- a/zstash/hpss.py +++ b/zstash/hpss.py @@ -2,16 +2,93 @@ import os.path import subprocess -from typing import List +from typing import List, Optional from six.moves.urllib.parse import urlparse from .globus import globus_transfer from .settings import get_db_filename, logger +from .transfer_tracking import GlobusConfig, TaskStatus, TransferBatch, TransferManager from .utils import run_command, ts_utc -prev_transfers: List[str] = list() -curr_transfers: List[str] = list() + +def ensure_transfer_batch(transfer_manager: TransferManager, scheme: str): + # Create a new batch if needed (before we start adding files) + if not transfer_manager.batches or transfer_manager.batches[-1].task_id: + # Either no batches exist, or the last batch was already submitted + new_batch = TransferBatch() + new_batch.is_globus = scheme == "globus" + transfer_manager.batches.append(new_batch) + logger.debug( + f"{ts_utc()}: Created new TransferBatch, total batches: {len(transfer_manager.batches)}" + ) + + +def local_put(file_path: str, cache: str): + if file_path != get_db_filename(cache): + # We are adding a file (that is not the cache) to the local non-HPSS archive + logger.info("put: Keeping tar files locally and removing write permissions") + # https://unix.stackexchange.com/questions/46915/get-the-chmod-numerical-value-for-a-file + display_mode_command: List[str] = "stat --format '%a' {}".format( + file_path + ).split() + display_mode_output: bytes = subprocess.check_output( + display_mode_command + ).strip() + logger.info("{!r} original mode={!r}".format(file_path, display_mode_output)) + # https://www.washington.edu/doit/technology-tips-chmod-overview + # Remove write-permission from user, group, and others, + # without changing read or execute permissions for any. + change_mode_command: List[str] = "chmod ugo-w {}".format(file_path).split() + # An error will be raised if this line fails. + subprocess.check_output(change_mode_command) + new_display_mode_output: bytes = subprocess.check_output( + display_mode_command + ).strip() + logger.info("{!r} new mode={!r}".format(file_path, new_display_mode_output)) + # else: no action needed + + +def run_hpss_put(hpss: str, name: str): + command: str = f'hsi -q "cd {hpss}; put {name}"' + error_str: str = f"Transferring file to HPSS: {name}" + run_command(command, error_str) + + +def run_hpss_get(hpss: str, name: str): + command: str = f'hsi -q "cd {hpss}; get {name}"' + error_str: str = f"Transferring file from HPSS: {name}" + run_command(command, error_str) + + +def globus_transfer_wrapper( + transfer_manager: TransferManager, + endpoint: str, + url_path: str, + name: str, + transfer_type: str, + non_blocking: bool, +): + if not transfer_manager.globus_config: + transfer_manager.globus_config = GlobusConfig() + # Transfer file using the Globus Transfer Service + logger.info(f"{ts_utc()}: DIVING: hpss calls globus_transfer(name={name})") + task_status: TaskStatus = globus_transfer( + transfer_manager, endpoint, url_path, name, transfer_type, non_blocking + ) + logger.info( + f"{ts_utc()}: SURFACE: hpss globus_transfer(name={name}) returned task_status={task_status}" + ) + mrb: Optional[TransferBatch] = transfer_manager.get_most_recent_batch() + if mrb and mrb.task_status: + globus_status: TaskStatus = mrb.task_status + logger.info( + f"{ts_utc()}: Most recent globus_transfer returned task_status={globus_status}" + ) + # NOTE: Here, the status could be "EXHAUSTED_TIMEOUT_RETRIES", meaning a very long transfer + # or perhaps transfer is hanging. We should decide whether to ignore it, or cancel it, but + # we'd need the task_id to issue a cancellation. Perhaps we should have globus_transfer + # return a tuple (task_id, status). def hpss_transfer( @@ -22,75 +99,45 @@ def hpss_transfer( keep: bool = False, non_blocking: bool = False, is_index: bool = False, + transfer_manager: Optional[TransferManager] = None, ): - global prev_transfers - global curr_transfers + if not transfer_manager: + transfer_manager = TransferManager() - logger.info( - f"{ts_utc()}: in hpss_transfer, prev_transfers is starting as {prev_transfers}" - ) - # logger.debug( - # f"{ts_utc()}: in hpss_transfer, curr_transfers is starting as {curr_transfers}" - # ) + url = urlparse(hpss) + scheme = url.scheme + + ensure_transfer_batch(transfer_manager, scheme) if hpss == "none": logger.info("{}: HPSS is unavailable".format(transfer_type)) - if transfer_type == "put" and file_path != get_db_filename(cache): - # We are adding a file (that is not the cache) to the local non-HPSS archive - logger.info( - "{}: Keeping tar files locally and removing write permissions".format( - transfer_type - ) - ) - # https://unix.stackexchange.com/questions/46915/get-the-chmod-numerical-value-for-a-file - display_mode_command: List[str] = "stat --format '%a' {}".format( - file_path - ).split() - display_mode_output: bytes = subprocess.check_output( - display_mode_command - ).strip() - logger.info( - "{!r} original mode={!r}".format(file_path, display_mode_output) - ) - # https://www.washington.edu/doit/technology-tips-chmod-overview - # Remove write-permission from user, group, and others, - # without changing read or execute permissions for any. - change_mode_command: List[str] = "chmod ugo-w {}".format(file_path).split() - # An error will be raised if this line fails. - subprocess.check_output(change_mode_command) - new_display_mode_output: bytes = subprocess.check_output( - display_mode_command - ).strip() - logger.info("{!r} new mode={!r}".format(file_path, new_display_mode_output)) + if transfer_type == "put": + local_put(file_path, cache) # else: no action needed else: transfer_word: str - transfer_command: str if transfer_type == "put": transfer_word = "to" - transfer_command = "put" elif transfer_type == "get": transfer_word = "from" - transfer_command = "get" else: raise ValueError("Invalid transfer_type={}".format(transfer_type)) logger.info("Transferring file {} HPSS: {}".format(transfer_word, file_path)) - scheme: str - endpoint: str - path: str - name: str - url = urlparse(hpss) - scheme = url.scheme - endpoint = url.netloc + endpoint: str = url.netloc url_path = url.path - - curr_transfers.append(file_path) - # logger.debug( - # f"{ts_utc()}: curr_transfers has been appended to, is now {curr_transfers}" - # ) + path: str + name: str path, name = os.path.split(file_path) + # Never track index.db for deletion, only the tar files + if (not keep) and (not is_index): + # Add this tar file to the current batch + transfer_manager.batches[-1].file_paths.append(file_path) + logger.debug( + f"{ts_utc()}: Added {file_path} to current batch, batch now has {len(transfer_manager.batches[-1].file_paths)} files" + ) + # Need to be in local directory for `hsi` to work cwd = os.getcwd() if path != "": @@ -105,24 +152,15 @@ def hpss_transfer( os.chdir(path) if scheme == "globus": - globus_status = "UNKNOWN" - # Transfer file using the Globus Transfer Service - logger.info(f"{ts_utc()}: DIVING: hpss calls globus_transfer(name={name})") - globus_status = globus_transfer( - endpoint, url_path, name, transfer_type, non_blocking - ) - logger.info( - f"{ts_utc()}: SURFACE hpss globus_transfer(name={name}) returns {globus_status}" + globus_transfer_wrapper( + transfer_manager, endpoint, url_path, name, transfer_type, non_blocking ) - # NOTE: Here, the status could be "EXHAUSTED_TIMEOUT_RETRIES", meaning a very long transfer - # or perhaps transfer is hanging. We should decide whether to ignore it, or cancel it, but - # we'd need the task_id to issue a cancellation. Perhaps we should have globus_transfer - # return a tuple (task_id, status). else: # Transfer file using `hsi` - command: str = 'hsi -q "cd {}; {} {}"'.format(hpss, transfer_command, name) - error_str: str = "Transferring file {} HPSS: {}".format(transfer_word, name) - run_command(command, error_str) + if transfer_type == "put": + run_hpss_put(hpss, name) + else: + run_hpss_get(hpss, name) # Return to original working directory if path != "": @@ -130,25 +168,15 @@ def hpss_transfer( if transfer_type == "put": if not keep: - if (scheme != "globus") or (globus_status == "SUCCEEDED"): - # Note: This is intended to fulfill the default removal of successfully-transfered - # tar files when keep=False, irrespective of non-blocking status - logger.debug( - f"{ts_utc()}: deleting transfered files {prev_transfers}" - ) - for src_path in prev_transfers: - os.remove(src_path) - prev_transfers = curr_transfers - curr_transfers = list() - logger.info( - f"{ts_utc()}: prev_transfers has been set to {prev_transfers}" - ) + # We never delete if `--keep` is set. + transfer_manager.delete_successfully_transferred_files() def hpss_put( hpss: str, file_path: str, cache: str, + transfer_manager: TransferManager, keep: bool = True, non_blocking: bool = False, is_index=False, @@ -156,14 +184,35 @@ def hpss_put( """ Put a file to the HPSS archive. """ - hpss_transfer(hpss, file_path, "put", cache, keep, non_blocking, is_index) + hpss_transfer( + hpss, + file_path, + "put", + cache, + keep, + non_blocking, + is_index, + transfer_manager, + ) -def hpss_get(hpss: str, file_path: str, cache: str): +def hpss_get( + hpss: str, + file_path: str, + cache: str, + transfer_manager: Optional[TransferManager] = None, +): """ Get a file from the HPSS archive. """ - hpss_transfer(hpss, file_path, "get", cache, False) + url = urlparse(hpss) + if not transfer_manager: + transfer_manager = TransferManager() + if (url.scheme == "globus") and not (transfer_manager.globus_config): + transfer_manager.globus_config = GlobusConfig() + hpss_transfer( + hpss, file_path, "get", cache, False, transfer_manager=transfer_manager + ) def hpss_chgrp(hpss: str, group: str, recurse: bool = False): diff --git a/zstash/hpss_utils.py b/zstash/hpss_utils.py index 2f1158bc..0956cd89 100644 --- a/zstash/hpss_utils.py +++ b/zstash/hpss_utils.py @@ -7,15 +7,240 @@ import tarfile import traceback from datetime import datetime -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import _hashlib from .hpss import hpss_put from .settings import TupleFilesRowNoId, TupleTarsRowNoId, config, logger +from .transfer_tracking import TransferManager from .utils import create_tars_table, tars_table_exists, ts_utc +# This class holds parameters for developer options. +# I.e., these parameters should only ever be activated by developers during debugging and/or testing. +class DevOptions(object): + def __init__( + self, + error_on_duplicate_tar: bool, + overwrite_duplicate_tars: bool, + force_database_corruption: str, + ): + self.error_on_duplicate_tar: bool = error_on_duplicate_tar + self.overwrite_duplicate_tars: bool = overwrite_duplicate_tars + self.force_database_corruption: str = force_database_corruption + + def simulate_row_existing( + self, + tfname: str, + cur: sqlite3.Cursor, + tar_tuple: TupleTarsRowNoId, + tar_size: int, + tar_md5: Optional[str], + ): + if self.force_database_corruption == "simulate_row_existing": + # Tested by database_corruption.bash Cases 3, 5 + logger.info( + f"TESTING/DEBUGGING ONLY: Simulating row existing for {tfname}." + ) + cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) + elif self.force_database_corruption == "simulate_row_existing_bad_size": + # Tested by database_corruption.bash Cases 4, 7 + logger.info( + f"TESTING/DEBUGGING ONLY: Simulating row existing with bad size for {tfname}." + ) + cur.execute( + "INSERT INTO tars VALUES (NULL,?,?,?)", + (tfname, tar_size + 1000, tar_md5), + ) + + +class TarWrapper(object): + def __init__(self, tar_num: int, cache: str, do_hash: bool, follow_symlinks: bool): + # Create a hex value at least 6 digits long + tname: str = "{0:0{1}x}".format(tar_num, 6) + # Create the tar file name by adding ".tar" + self.tfname: str = f"{tname}.tar" + logger.info(f"{ts_utc()}: Creating new tar archive {self.tfname}") + # Open that tar file in the cache + self.tarFileObject = HashIO(os.path.join(cache, self.tfname), "wb", do_hash) + # FIXME: error: Argument "fileobj" to "open" has incompatible type "HashIO"; expected "Optional[IO[bytes]]" + self.tar = tarfile.open(mode="w", fileobj=self.tarFileObject, dereference=follow_symlinks) # type: ignore + + def process_file( + self, + current_file: str, + tar_info: tarfile.TarInfo, + archived: List[TupleFilesRowNoId], + failures: List[str], + ) -> int: + logger.info(f"Archiving {current_file}") + tar_size: int = 0 + try: + offset: int + size: int + mtime: datetime + md5: Optional[str] + offset, size, mtime, md5 = add_file_to_tar_archive( + self.tar, current_file, tar_info + ) + t: TupleFilesRowNoId = ( + current_file, + size, + mtime, + md5, + self.tfname, + offset, + ) + archived.append(t) + # Increase tar_size by the size of the current file. + # Use `tell()` to also include the tar's metadata in the size. + tar_size = self.tarFileObject.tell() + except Exception: + # Catch all exceptions here. + traceback.print_exc() + logger.error(f"Archiving {current_file}") + failures.append(current_file) + return tar_size + + def process_tar( + self, + cache: str, + keep: bool, + non_blocking: bool, + transfer_manager: TransferManager, + skip_tars_table: bool, + cur: sqlite3.Cursor, + con: sqlite3.Connection, + dev_options: DevOptions, + archived: List[TupleFilesRowNoId], + ): + # 1. Close the tar #################################################### + logger.debug(f"{ts_utc()}: Closing tar archive {self.tfname}") + self.tar.close() + + tar_size = self.tarFileObject.tell() + tar_md5: Optional[str] = self.tarFileObject.md5() + self.tarFileObject.close() + logger.info(f"{ts_utc()}: (process_tar): Completed archive file {self.tfname}") + + # 2. Submit the tar to the transfer manager's batch transfer system ### + if config.hpss is not None: + hpss: str = config.hpss + else: + raise TypeError("Invalid config.hpss={}".format(config.hpss)) + + logger.debug(f"Contents of the cache prior to `hpss_put`: {os.listdir(cache)}") + + logger.info( + f"{ts_utc()}: DIVING: (process_tar): Calling hpss_put to dispatch archive file {self.tfname} [keep, non_blocking] = [{keep}, {non_blocking}]" + ) + # Actually submit the tar file + hpss_put( + hpss, + os.path.join(cache, self.tfname), + cache, + transfer_manager, + keep, + non_blocking, + is_index=False, + ) + logger.info( + f"{ts_utc()}: SURFACE (process_tar): Called hpss_put to dispatch archive file {self.tfname}" + ) + + # 3. Add the tar itself to the tars table ############################# + if not skip_tars_table: + tar_tuple: TupleTarsRowNoId = (self.tfname, tar_size, tar_md5) + logger.info("tar name={}, tar size={}, tar md5={}".format(*tar_tuple)) + if not tars_table_exists(cur): + # Need to create tars table + create_tars_table(cur, con) + + # For developers only! For debugging/testing purposes only! + dev_options.simulate_row_existing( + self.tfname, cur, tar_tuple, tar_size, tar_md5 + ) + + # We're done adding files to the tar. + # And we've transferred it to HPSS. + # Now we can insert the tar into the database. + cur.execute("SELECT COUNT(*) FROM tars WHERE name = ?", (self.tfname,)) + tar_count: int = cur.fetchone()[0] + if tar_count != 0: + error_str: str = ( + f"Database corruption detected! {self.tfname} is already in the database." + ) + if dev_options.error_on_duplicate_tar: + # Tested by database_corruption.bash Case 3 + # Exists - error out + logger.error(error_str) + raise RuntimeError(error_str) + elif dev_options.overwrite_duplicate_tars: + # Tested by database_corruption.bash Case 4 + # Exists - update with new size and md5 + logger.warning(error_str) + logger.warning(f"Updating existing tar {self.tfname} to proceed.") + cur.execute( + "UPDATE tars SET size = ?, md5 = ? WHERE name = ?", + (tar_size, tar_md5, self.tfname), + ) + else: + # Tested by database_corruption.bash Cases 5,7 + # Proceed as if we're in the typical case -- insert new + logger.warning(error_str) + logger.warning(f"Adding a new entry for {self.tfname}.") + cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) + elif dev_options.force_database_corruption == "simulate_no_correct_size": + # Tested by database_corruption.bash Case 6 + # For developers only! For debugging purposes only! + # Add this tar twice, with different sizes. + logger.info( + f"TESTING/DEBUGGING ONLY: Simulating no correct size for {self.tfname}." + ) + cur.execute( + "INSERT INTO tars VALUES (NULL,?,?,?)", + (self.tfname, tar_size + 1000, tar_md5), + ) + cur.execute( + "INSERT INTO tars VALUES (NULL,?,?,?)", + (self.tfname, tar_size + 2000, tar_md5), + ) + elif ( + dev_options.force_database_corruption + == "simulate_bad_size_for_most_recent" + ): + # Tested by database_corruption.bash Case 8 + # For developers only! For debugging purposes only! + # Add this tar twice, second time with bad size. + logger.info( + f"TESTING/DEBUGGING ONLY: Simulating bad size for most recent entry for {self.tfname}." + ) + cur.execute( + "INSERT INTO tars VALUES (NULL,?,?,?)", + (self.tfname, tar_size, tar_md5), + ) + cur.execute( + "INSERT INTO tars VALUES (NULL,?,?,?)", + (self.tfname, tar_size + 2000, tar_md5), + ) + else: + # Tested by database_corruption.bash Cases 1,2 + # Typical case + # Doesn't exist - insert new + logger.info(f"Adding {self.tfname} to the database.") + cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) + + con.commit() + + # 4. Add the files included in this tar to the files table ############ + # Update database with the individual files that have been archived + # Add a row to the "files" table, + # the last 6 columns matching the values of `archived` + cur.executemany("insert into files values (NULL,?,?,?,?,?,?)", archived) + con.commit() + + # Minimum output file object class HashIO(object): def __init__(self, name: str, mode: str, do_hash: bool): @@ -32,6 +257,18 @@ def tell(self) -> int: return self.position def write(self, s): + """ + This is called implicitly. + In TarWrapper.__init__: + + ``` + self.tarFileObject = HashIO(os.path.join(cache, self.tfname), "wb", do_hash) + self.tar = tarfile.open(mode="w", fileobj=self.tarFileObject, dereference=follow_symlinks) + ``` + + tarfile.open requires that the fileobj argument has a write() method. + It calls that method to write data to the tar file. + """ self.f.write(s) if self.hash: self.hash.update(s) @@ -53,218 +290,154 @@ def close(self): self.closed = True -def add_files( +def estimate_tar_entry_size(file_size: int) -> int: + """ + Estimate how much space a file of a given size would take in the tar archive, + including metadata and padding. + """ + TAR_BLOCK_SIZE = 512 + TAR_HEADER_SIZE = 512 # per file header + # This formula computes: ceil(file_size / TAR_BLOCK_SIZE) + # But faster and avoiding floats. + data_blocks = (file_size + TAR_BLOCK_SIZE - 1) // TAR_BLOCK_SIZE + return TAR_HEADER_SIZE + (data_blocks * TAR_BLOCK_SIZE) + + +# Add file to tar archive while computing its hash +# Return file offset (in tar archive), size and md5 hash +def add_file_to_tar_archive( + tar: tarfile.TarFile, file_name: str, tar_info: tarfile.TarInfo +) -> Tuple[int, int, datetime, Optional[str]]: + offset = tar.offset + + md5: Optional[str] = None + + # For files/hardlinks + if tar_info.isfile() or tar_info.islnk(): + if tar_info.size > 0: + # Non-empty files: stream with hash computation + hash_md5 = hashlib.md5() + with open(file_name, "rb") as f: + wrapper = HashingFileWrapper(f, hash_md5) + tar.addfile(tar_info, wrapper) + md5 = hash_md5.hexdigest() + else: + # Empty files: just add to tar, compute hash of empty data + tar.addfile(tar_info) + md5 = hashlib.md5(b"").hexdigest() # MD5 of empty bytes + else: + # Directories, symlinks, etc. + # md5 will be None in these cases. + tar.addfile(tar_info) + + size = tar_info.size + mtime = datetime.utcfromtimestamp(tar_info.mtime) + return offset, size, mtime, md5 + + +def construct_tars( cur: sqlite3.Cursor, con: sqlite3.Connection, itar: int, - files: List[str], + file_stats: Dict[str, Tuple[int, datetime]], cache: str, keep: bool, follow_symlinks: bool, - skip_tars_md5: bool = False, + dev_options: DevOptions, + transfer_manager: TransferManager, + skip_tars_table: bool = False, non_blocking: bool = False, - error_on_duplicate_tar: bool = False, - overwrite_duplicate_tars: bool = False, - force_database_corruption: str = "", ) -> List[str]: - # Now, perform the actual archiving failures: List[str] = [] - create_new_tar: bool = True + files: List[str] = list(file_stats.keys()) nfiles: int = len(files) - archived: List[TupleFilesRowNoId] - tarsize: int - tname: str - tfname: str - tarFileObject: HashIO - tar: tarfile.TarFile - for i in range(nfiles): - - # New tar in the local cache - if create_new_tar: - create_new_tar = False - archived = [] - tarsize = 0 - itar += 1 - # Create a hex value at least 6 digits long - tname = "{0:0{1}x}".format(itar, 6) - # Create the tar file name by adding ".tar" - tfname = "{}.tar".format(tname) - logger.info(f"{ts_utc()}: Creating new tar archive {tfname}") - # Open that tar file in the cache - do_hash: bool - if not skip_tars_md5: - # If we're not skipping tars, we want to calculate the hash of the tars. - do_hash = True - else: - do_hash = False - tarFileObject = HashIO(os.path.join(cache, tfname), "wb", do_hash) - # FIXME: error: Argument "fileobj" to "open" has incompatible type "HashIO"; expected "Optional[IO[bytes]]" - tar = tarfile.open(mode="w", fileobj=tarFileObject, dereference=follow_symlinks) # type: ignore - - # Add current file to tar archive - current_file: str = files[i] - logger.info("Archiving {}".format(current_file)) - try: - offset: int - size: int - mtime: datetime - md5: Optional[str] - offset, size, mtime, md5 = add_file(tar, current_file, follow_symlinks) - t: TupleFilesRowNoId = ( - current_file, - size, - mtime, - md5, - tfname, - offset, - ) - archived.append(t) - # Increase tarsize by the size of the current file. - # Use `tell()` to also include the tar's metadata in the size. - tarsize = tarFileObject.tell() - except Exception: - # Catch all exceptions here. - traceback.print_exc() - logger.error("Archiving {}".format(current_file)) - failures.append(current_file) - - # Close tar archive if current file is the last one or - # if adding one more would push us over the limit. - next_file_size: int = tar.gettarinfo(current_file).size - if config.maxsize is not None: - maxsize: int = config.maxsize - else: - raise TypeError("Invalid config.maxsize={}".format(config.maxsize)) - if i == nfiles - 1 or tarsize + next_file_size > maxsize: - - # Close current temporary file - logger.debug(f"{ts_utc()}: Closing tar archive {tfname}") - tar.close() - - tarsize = tarFileObject.tell() - tar_md5: Optional[str] = tarFileObject.md5() - tarFileObject.close() - logger.info(f"{ts_utc()}: (add_files): Completed archive file {tfname}") - - # Transfer tar to HPSS - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - - logger.info( - f"Contents of the cache prior to `hpss_put`: {os.listdir(cache)}" - ) - logger.info( - f"{ts_utc()}: DIVING: (add_files): Calling hpss_put to dispatch archive file {tfname} [keep, non_blocking] = [{keep}, {non_blocking}]" - ) - hpss_put(hpss, os.path.join(cache, tfname), cache, keep, non_blocking) - logger.info( - f"{ts_utc()}: SURFACE (add_files): Called hpss_put to dispatch archive file {tfname}" - ) - - if not skip_tars_md5: - tar_tuple: TupleTarsRowNoId = (tfname, tarsize, tar_md5) - logger.info("tar name={}, tar size={}, tar md5={}".format(*tar_tuple)) - if not tars_table_exists(cur): - # Need to create tars table - create_tars_table(cur, con) - - # For developers only! For debugging purposes only! - if force_database_corruption == "simulate_row_existing": - # Tested by database_corruption.bash Cases 3, 5 - logger.info( - f"TESTING/DEBUGGING ONLY: Simulating row existing for {tfname}." - ) - cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) - elif force_database_corruption == "simulate_row_existing_bad_size": - # Tested by database_corruption.bash CaseS 4, 7 - logger.info( - f"TESTING/DEBUGGING ONLY: Simulating row existing with bad size for {tfname}." - ) - cur.execute( - "INSERT INTO tars VALUES (NULL,?,?,?)", - (tfname, tarsize + 1000, tar_md5), - ) + if config.maxsize is not None: + max_size: int = config.maxsize + else: + raise TypeError(f"Invalid config.maxsize={config.maxsize}") - # We're done adding files to the tar. - # And we've transferred it to HPSS. - # Now we can insert the tar into the database. - cur.execute("SELECT COUNT(*) FROM tars WHERE name = ?", (tfname,)) - tar_count: int = cur.fetchone()[0] - if tar_count != 0: - error_str: str = ( - f"Database corruption detected! {tfname} is already in the database." - ) - if error_on_duplicate_tar: - # Tested by database_corruption.bash Case 3 - # Exists - error out - logger.error(error_str) - raise RuntimeError(error_str) - elif overwrite_duplicate_tars: - # Tested by database_corruption.bash Case 4 - # Exists - update with new size and md5 - logger.warning(error_str) - logger.warning(f"Updating existing tar {tfname} to proceed.") - cur.execute( - "UPDATE tars SET size = ?, md5 = ? WHERE name = ?", - (tarsize, tar_md5, tfname), - ) - else: - # Tested by database_corruption.bash Cases 5,7 - # Proceed as if we're in the typical case -- insert new - logger.warning(error_str) - logger.warning(f"Adding a new entry for {tfname}.") - cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) - elif force_database_corruption == "simulate_no_correct_size": - # Tested by database_corruption.bash Case 6 - # For developers only! For debugging purposes only! - # Add this tar twice, with different sizes. - logger.info( - f"TESTING/DEBUGGING ONLY: Simulating no correct size for {tfname}." - ) - cur.execute( - "INSERT INTO tars VALUES (NULL,?,?,?)", - (tfname, tarsize + 1000, tar_md5), - ) - cur.execute( - "INSERT INTO tars VALUES (NULL,?,?,?)", - (tfname, tarsize + 2000, tar_md5), - ) - elif force_database_corruption == "simulate_bad_size_for_most_recent": - # Tested by database_corruption.bash Case 8 - # For developers only! For debugging purposes only! - # Add this tar twice, second time with bad size. - logger.info( - f"TESTING/DEBUGGING ONLY: Simulating bad size for most recent entry for {tfname}." - ) - cur.execute( - "INSERT INTO tars VALUES (NULL,?,?,?)", - (tfname, tarsize, tar_md5), - ) - cur.execute( - "INSERT INTO tars VALUES (NULL,?,?,?)", - (tfname, tarsize + 2000, tar_md5), + operation: str + if itar == -1: + operation = "creation" + else: + operation = "update" + + i_file: int = 0 + while i_file < nfiles: + # Each iteration of this loop constructs one tar + + # `create` passes in itar=-1, so the first tar will be 000000.tar + # `update` passes in itar=max existing tar number, so the first tar will be max+1 + itar += 1 + cumulative_tar_size: int = 0 + archived: List[TupleFilesRowNoId] = [] + + # Open a new tar + # Note: if we're not skipping the tars table, then we DO want to calculate the hash of the tars. + # That is, we DO want to add the tar to the tars table in the database. + # That means we need to calculate the hash of the tar file as well. + # + # We ALWAYS want to calculate the hashes of the individual files, regardless of skip_tars_table, + # because we need to add those to the files table. + tar_wrapper = TarWrapper( + tar_num=itar, + cache=cache, + do_hash=not skip_tars_table, + follow_symlinks=follow_symlinks, + ) + + # Add files to the tar until we reach the max size + while i_file < nfiles: + current_file: str = files[i_file] + current_file_size: int + current_file_size, _ = file_stats[current_file] + estimated_entry_size: int = estimate_tar_entry_size(current_file_size) + if (cumulative_tar_size != 0) and ( + cumulative_tar_size + estimated_entry_size > max_size + ): + # Over the size limit: time to close and transfer this tar archive. + # Done adding files to this particular tar. + # Break out of the inner while-loop + break + # If we make it this far, + # we know we can add the current file without going over the max size. + # (Either that, or the tar is currently empty, + # in which case we add the file even if it's over the max size.) + try: + tar_info = tar_wrapper.tar.gettarinfo(current_file) + if tar_info.islnk(): + tar_info.size = os.path.getsize(current_file) + except FileNotFoundError: + logger.error(f"Archiving {current_file}") + if follow_symlinks: + raise Exception( + f"Archive {operation} failed due to broken symlink." ) else: - # Tested by database_corruption.bash Cases 1,2 - # Typical case - # Doesn't exist - insert new - logger.info(f"Adding {tfname} to the database.") - cur.execute("INSERT INTO tars VALUES (NULL,?,?,?)", tar_tuple) - - con.commit() - - # Update database with the individual files that have been archived - # Add a row to the "files" table, - # the last 6 columns matching the values of `archived` - cur.executemany("insert into files values (NULL,?,?,?,?,?,?)", archived) - con.commit() - - # Open new tar next time - create_new_tar = True + raise + new_cumulative_tar_size = tar_wrapper.process_file( + current_file, tar_info, archived, failures + ) + if new_cumulative_tar_size != 0: + # Update the cumulative tar size with the new tar size returned by process_file. + cumulative_tar_size = new_cumulative_tar_size + # Else: process_file failed, so we should keep the original cumulative_tar_size + i_file += 1 + + # Close the tar, submit it to the batch transfer system, and update the database with the archived files (and optionally the tar as well, depending on skip_tars_table) + tar_wrapper.process_tar( + cache, + keep, + non_blocking, + transfer_manager, + skip_tars_table, + cur, + con, + dev_options, + archived, + ) return failures @@ -280,38 +453,3 @@ def read(self, size=-1): if data: self.hasher.update(data) return data - - -# Add file to tar archive while computing its hash -# Return file offset (in tar archive), size and md5 hash -def add_file( - tar: tarfile.TarFile, file_name: str, follow_symlinks: bool -) -> Tuple[int, int, datetime, Optional[str]]: - offset = tar.offset - tarinfo = tar.gettarinfo(file_name) - - if tarinfo.islnk(): - tarinfo.size = os.path.getsize(file_name) - - md5 = None - - # For files/hardlinks - if tarinfo.isfile() or tarinfo.islnk(): - if tarinfo.size > 0: - # Non-empty files: stream with hash computation - hash_md5 = hashlib.md5() - with open(file_name, "rb") as f: - wrapper = HashingFileWrapper(f, hash_md5) - tar.addfile(tarinfo, wrapper) - md5 = hash_md5.hexdigest() - else: - # Empty files: just add to tar, compute hash of empty data - tar.addfile(tarinfo) - md5 = hashlib.md5(b"").hexdigest() # MD5 of empty bytes - else: - # Directories, symlinks, etc. - tar.addfile(tarinfo) - - size = tarinfo.size - mtime = datetime.utcfromtimestamp(tarinfo.mtime) - return offset, size, mtime, md5 diff --git a/zstash/transfer_tracking.py b/zstash/transfer_tracking.py new file mode 100644 index 00000000..1cac18d5 --- /dev/null +++ b/zstash/transfer_tracking.py @@ -0,0 +1,107 @@ +import os +from enum import Enum, auto +from typing import List, Optional + +from globus_sdk import TransferClient, TransferData +from globus_sdk.response import GlobusHTTPResponse +from globus_sdk.services.transfer.response.iterable import IterableTransferResponse + +from .settings import logger +from .utils import ts_utc + + +class GlobusConfig: + """Globus connection configuration""" + + def __init__(self): + self.remote_endpoint: Optional[str] = None + self.local_endpoint: Optional[str] = None + self.transfer_client: Optional[TransferClient] = None + self.archive_directory_listing: Optional[IterableTransferResponse] = None + + +class TaskStatus(Enum): + """Enum for Globus transfer task status""" + + # The first 4 values are defined by the Globus API: + # https://docs.globus.org/api/transfer/task/#task_fields + SUCCEEDED = auto() + ACTIVE = auto() + INACTIVE = auto() + FAILED = auto() + # The last 3 values are custom statuses we add on. + UNKNOWN = auto() + SUBMITTED = auto() + EXHAUSTED_TIMEOUT_RETRIES = auto() + + @classmethod + def convert_from_status_from_globus_sdk(cls, globus_task: GlobusHTTPResponse): + """Convert a Globus API status string to a TaskStatus enum value""" + status_from_globus_sdk: str = globus_task["status"] + status_from_globus_sdk = status_from_globus_sdk.upper() + if status_from_globus_sdk == "SUCCEEDED": + return TaskStatus.SUCCEEDED + elif status_from_globus_sdk == "ACTIVE": + return TaskStatus.ACTIVE + elif status_from_globus_sdk == "INACTIVE": + return TaskStatus.INACTIVE + elif status_from_globus_sdk == "FAILED": + return TaskStatus.FAILED + else: + logger.warning( + f"Received unrecognized Globus status: {status_from_globus_sdk}" + ) + return TaskStatus.UNKNOWN + + def __str__(self) -> str: + return self.name + + +class TransferBatch: + """Represents one batch of files being transferred""" + + def __init__(self): + self.file_paths: List[str] = [] + self.task_id: Optional[str] = None + self.task_status: Optional[TaskStatus] = None + self.is_globus: bool = False + self.transfer_data: Optional[TransferData] = None # Only for Globus + + def delete_files(self): + for src_path in self.file_paths: + try: + os.remove(src_path) + except FileNotFoundError: + logger.warning(f"File already deleted: {src_path}") + + +class TransferManager: + def __init__(self): + # All transfer batches (Globus or HPSS) + self.batches: List[TransferBatch] = [] + self.cumulative_tarfiles_pushed: int = 0 + + # Connection state (Globus-specific, None if not using Globus) + self.globus_config: Optional[GlobusConfig] = None + + def get_most_recent_batch(self) -> Optional[TransferBatch]: + """Get the last batch added to the manager, or None if no batches exist""" + return self.batches[-1] if self.batches else None + + def delete_successfully_transferred_files(self): + """Check transfer status and delete files from successful transfers""" + logger.info( + f"{ts_utc()}: Checking for successfully transferred files to delete" + ) + # Clean up empty batches first + self.batches = [batch for batch in self.batches if batch.file_paths] + # Now delete files for successful transfers + for batch in self.batches: + if (not batch.is_globus) or (batch.task_status == TaskStatus.SUCCEEDED): + # The files were transferred successfully, so delete them + logger.info( + f"{ts_utc()}: Deleting {len(batch.file_paths)} files from successful transfer" + ) + batch.delete_files() + logger.debug("Deletion completed") + batch.file_paths = [] # Mark as processed diff --git a/zstash/update.py b/zstash/update.py index ebc4f84c..69f72e0f 100644 --- a/zstash/update.py +++ b/zstash/update.py @@ -10,8 +10,9 @@ from .globus import globus_activate, globus_finalize from .hpss import hpss_get, hpss_put -from .hpss_utils import add_files +from .hpss_utils import DevOptions, construct_tars from .settings import DEFAULT_CACHE, TIME_TOL, config, get_db_filename, logger +from .transfer_tracking import TransferManager from .utils import get_files_to_archive_with_stats, update_config @@ -21,7 +22,8 @@ def update(): cache: str args, cache = setup_update() - result: Optional[List[str]] = update_database(args, cache) + transfer_manager = TransferManager() + result: Optional[List[str]] = update_database(args, cache, transfer_manager) if result is None: # There was either nothing to update or `--dry-run` was set. @@ -34,9 +36,16 @@ def update(): hpss = config.hpss else: raise TypeError("Invalid config.hpss={}".format(config.hpss)) - hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True) + hpss_put( + hpss, + get_db_filename(cache), + cache, + transfer_manager, + keep=args.keep, + is_index=True, + ) - globus_finalize(non_blocking=args.non_blocking) + globus_finalize(transfer_manager, args.keep) # List failures if len(failures) > 0: @@ -138,7 +147,7 @@ def setup_update() -> Tuple[argparse.Namespace, str]: # C901 'update_database' is too complex (20) def update_database( # noqa: C901 - args: argparse.Namespace, cache: str + args: argparse.Namespace, cache: str, transfer_manager: TransferManager ) -> Optional[List[str]]: # Open database logger.debug("Opening index database") @@ -151,8 +160,8 @@ def update_database( # noqa: C901 hpss: str = config.hpss else: raise TypeError("Invalid config.hpss={}".format(config.hpss)) - globus_activate(hpss) - hpss_get(hpss, get_db_filename(cache), cache) + transfer_manager.globus_config = globus_activate(hpss) + hpss_get(hpss, get_db_filename(cache), cache, transfer_manager) else: error_str: str = ( "--hpss argument is required when local copy of database is unavailable" @@ -217,7 +226,7 @@ def update_database( # noqa: C901 else: archived_files[file_path] = (size, mtime) - newfiles: List[str] = [] + newfiles: Dict[str, Tuple[int, datetime]] = {} files_checked = 0 for file_path in files: @@ -227,7 +236,7 @@ def update_database( # noqa: C901 # Check if file exists in database if file_path not in archived_files: # File not in database - it's new - newfiles.append(file_path) + newfiles[file_path] = (size_new, mdtime_new) else: # File exists in database - check if it changed archived_size, archived_mtime = archived_files[file_path] @@ -237,7 +246,7 @@ def update_database( # noqa: C901 and (abs((mdtime_new - archived_mtime).total_seconds()) <= TIME_TOL) ): # File has changed - newfiles.append(file_path) + newfiles[file_path] = (size_new, mdtime_new) files_checked += 1 @@ -252,7 +261,7 @@ def update_database( # noqa: C901 # --dry-run option if args.dry_run: print("List of files to be updated") - for file_path in newfiles: + for file_path in newfiles.keys(): print(file_path) # Close database con.commit() @@ -266,40 +275,24 @@ def update_database( # noqa: C901 for tfile in tfiles: tfile_string: str = tfile[0] itar = max(itar, int(tfile_string[0:6], 16)) - + dev_options: DevOptions = DevOptions( + error_on_duplicate_tar=args.error_on_duplicate_tar, + overwrite_duplicate_tars=args.overwrite_duplicate_tars, + force_database_corruption="", + ) # Add files - failures: List[str] - if args.follow_symlinks: - try: - # Add files - failures = add_files( - cur, - con, - itar, - newfiles, - cache, - keep, - args.follow_symlinks, - non_blocking=args.non_blocking, - error_on_duplicate_tar=args.error_on_duplicate_tar, - overwrite_duplicate_tars=args.overwrite_duplicate_tars, - ) - except FileNotFoundError: - raise Exception("Archive update failed due to broken symlink.") - else: - # Add files - failures = add_files( - cur, - con, - itar, - newfiles, - cache, - keep, - args.follow_symlinks, - non_blocking=args.non_blocking, - error_on_duplicate_tar=args.error_on_duplicate_tar, - overwrite_duplicate_tars=args.overwrite_duplicate_tars, - ) + failures = construct_tars( + cur, + con, + itar, + newfiles, + cache, + keep, + args.follow_symlinks, + dev_options, + transfer_manager, + non_blocking=args.non_blocking, + ) # Close database con.commit() diff --git a/zstash/utils.py b/zstash/utils.py index 1bec4f3e..04154136 100644 --- a/zstash/utils.py +++ b/zstash/utils.py @@ -220,15 +220,6 @@ def get_files_to_archive_with_stats( return file_stats -def get_files_to_archive(cache: str, include: str, exclude: str) -> List[str]: - """ - LEGACY VERSION: Still used for `zstash create`. - Uses the optimized version but returns only the file list. - """ - file_stats = get_files_to_archive_with_stats(cache, include, exclude) - return list(file_stats.keys()) - - def update_config(cur: sqlite3.Cursor): # Retrieve some configuration settings from database # Loop through all attributes of config.