diff --git a/.github/workflows/build_workflow.yml b/.github/workflows/build_workflow.yml index b2a5d2b6..72c0f9f1 100644 --- a/.github/workflows/build_workflow.yml +++ b/.github/workflows/build_workflow.yml @@ -51,7 +51,7 @@ jobs: hashFiles('conda/dev.yml') }} - name: Build Conda Environment - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: zstash_dev miniforge-variant: Miniforge3 @@ -94,7 +94,7 @@ jobs: hashFiles('conda/dev.yml') }} - name: Build Conda Environment - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: zstash_dev miniforge-variant: Miniforge3 diff --git a/.github/workflows/release_workflow.yml b/.github/workflows/release_workflow.yml index 12b068a9..3a259552 100644 --- a/.github/workflows/release_workflow.yml +++ b/.github/workflows/release_workflow.yml @@ -32,7 +32,7 @@ jobs: hashFiles('conda/dev.yml') }} - name: Build Conda Environment - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: activate-environment: zstash_dev miniforge-variant: Miniforge3 diff --git a/.gitignore b/.gitignore index f0c2eeda..1679901a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ build/ dist/ +examples/*_run/ tests/test_follow_symlinks/ tests/test_follow_symlinks_non_archived/ zstash.egg-info/ diff --git a/conda/dev.yml b/conda/dev.yml index a70d6164..19ce5fbf 100644 --- a/conda/dev.yml +++ b/conda/dev.yml @@ -10,6 +10,10 @@ dependencies: - six=1.16.0 - globus-sdk=3.2.1 - fair-research-login=0.2.6 + # Testing + # ======================= + - pytest + - pytest-cov # Developer Tools # ================= # If versions are updated, also update 'rev' in `.pre-commit.config.yaml` diff --git a/examples/simple_globus.py b/examples/simple_globus.py new file mode 100644 index 00000000..a3e08e1e --- /dev/null +++ b/examples/simple_globus.py @@ -0,0 +1,218 @@ +import configparser +import os +import re +import 
shutil +from typing import List, Optional +from urllib.parse import ParseResult, urlparse + +from fair_research_login.client import NativeClient +from globus_sdk import TransferAPIError, TransferClient, TransferData +from globus_sdk.response import GlobusHTTPResponse + +# Minimal example of how Globus is used in zstash +# 1. Log into endpoints at globus.org +# File Manager > Add the endpoints in the "Collection" fields +# 2. To start fresh, with no consents: +# https://auth.globus.org/v2/web/consents > Manage Your Consents > Globus Endpoint Performance Monitoring > rescind all" + +HSI_DIR: str = "zstash_test_370_20250723" +ENDPOINT_NAME: str = ( + "LCRC Improv DTN" # Change this to the name of the endpoint you want to use +) +REQUEST_SCOPES_EARLY: bool = True # False will emulate zstash behavior + +# Globus-specific settings #################################################### +GLOBUS_CFG: str = os.path.expanduser("~/.globus-native-apps.cfg") +INI_PATH: str = os.path.expanduser("~/.zstash.ini") +ZSTASH_CLIENT_ID: str = "6c1629cf-446c-49e7-af95-323c6412397f" +NAME_TO_ENDPOINT_MAP = { + # "Globus Tutorial Collection 1": "6c54cade-bde5-45c1-bdea-f4bd71dba2cc", # The Unit test endpoint + "NERSC HPSS": "9cd89cfd-6d04-11e5-ba46-22000b92c6ec", + "NERSC Perlmutter": "6bdc7956-fc0f-4ad2-989c-7aa5ee643a79", + "LCRC Improv DTN": "15288284-7006-4041-ba1a-6b52501e49f1", +} + + +# Functions ################################################################### +def main(): + base_dir = os.getcwd() + print(f"Starting in {base_dir}") + if os.path.exists(INI_PATH): + os.remove(INI_PATH) + if os.path.exists(GLOBUS_CFG): + os.remove(GLOBUS_CFG) + skipped_second_auth: bool = False + try: + skipped_second_auth = simple_transfer("toy_run") + except RuntimeError: + print("Now that we have the authentications, let's re-run.") + print(f"For toy_run, skipped_second_auth={skipped_second_auth}") + if skipped_second_auth: + # We want to enter this block! 
+ print( + "We didn't need to authenticate a second time! That means we don't have to re-run the previous command to start the transfer!" + ) + else: + # Without `get_all_endpoint_scopes`, we ended up in this block! + # + # /global/homes/f/forsyth/.globus-native-apps.cfg does not exist. zstash will need to prompt for authentications twice, and then you will need to re-run. + # + # Might ask for 1st authentication prompt: + # Please paste the following URL in a browser: + # Authenticated for the 1st time! + # + # Might ask for 2nd authentication prompt: + # Please paste the following URL in a browser: + # Authenticated for the 2nd time! + # Consents added, please re-run the previous command to start transfer + # Now that we have the authentications, let's re-run. + os.chdir(base_dir) + print(f"Now in {os.getcwd()}") + assert os.path.exists(INI_PATH) + assert os.path.exists(GLOBUS_CFG) + skipped_second_auth = simple_transfer("real_run") + print(f"For real_run, skipped_second_auth={skipped_second_auth}") + # /global/homes/f/forsyth/.globus-native-apps.cfg exists. If this file does not have the proper settings, it may cause a TransferAPIError (e.g., 'Token is not active', 'No credentials supplied') + # + # Might ask for 1st authentication prompt: + # Authenticated for the 1st time! + # + # Bypassed 2nd authentication. 
+ # + # Wait for task to complete, wait_timeout=300 + print(f"To see transferred files, run: hsi ls {HSI_DIR}") + # To see transferred files, run: hsi ls zstash_debugging_20250415_v2 + # Shows file0.txt + assert skipped_second_auth + + +def simple_transfer(run_dir: str) -> bool: + hpss_path = f"globus://{NAME_TO_ENDPOINT_MAP['NERSC HPSS']}/~/{HSI_DIR}" + if os.path.exists(run_dir): + shutil.rmtree(run_dir) + os.mkdir(run_dir) + os.chdir(run_dir) + print(f"Now in {os.getcwd()}") + dir_to_archive: str = "dir_to_archive" + txt_file: str = "file0.txt" + os.mkdir(dir_to_archive) + with open(f"{dir_to_archive}/{txt_file}", "w") as f: + f.write("file contents") + url: ParseResult = urlparse(hpss_path) + assert url.scheme == "globus" + if os.path.exists(GLOBUS_CFG): + print( + f"{GLOBUS_CFG} exists. If this file does not have the proper settings, it may cause a TransferAPIError (e.g., 'Token is not active', 'No credentials supplied')" + ) + else: + print( + f"{GLOBUS_CFG} does not exist. zstash will need to prompt for authentications twice, and then you will need to re-run." + ) + config_path: str = os.path.abspath(dir_to_archive) + assert os.path.isdir(config_path) + remote_endpoint: str = url.netloc + # Simulate globus_activate > set_local_endpoint + ini = configparser.ConfigParser() + local_endpoint: Optional[str] = None + if ini.read(INI_PATH): + if "local" in ini.sections(): + local_endpoint = ini["local"].get("globus_endpoint_uuid") + else: + ini["local"] = {"globus_endpoint_uuid": ""} + with open(INI_PATH, "w") as f: + ini.write(f) + if not local_endpoint: + # nersc_hostname = os.environ.get("NERSC_HOST") + # assert nersc_hostname == "perlmutter" + local_endpoint = NAME_TO_ENDPOINT_MAP[ENDPOINT_NAME] + native_client = NativeClient( + client_id=ZSTASH_CLIENT_ID, + app_name="Zstash", + default_scopes="openid urn:globus:auth:scope:transfer.api.globus.org:all", + ) + # May print 'Please Paste your Auth Code Below:' + # This is the 1st authentication prompt! 
+ print("Might ask for 1st authentication prompt:") + if REQUEST_SCOPES_EARLY: + all_scopes: str = get_all_endpoint_scopes(list(NAME_TO_ENDPOINT_MAP.values())) + native_client.login( + requested_scopes=all_scopes, no_local_server=True, refresh_tokens=True + ) + else: + native_client.login(no_local_server=True, refresh_tokens=True) + print("Authenticated for the 1st time!") + transfer_authorizer = native_client.get_authorizers().get("transfer.api.globus.org") + transfer_client: TransferClient = TransferClient(authorizer=transfer_authorizer) + for ep_id in [ + local_endpoint, + remote_endpoint, + ]: + r = transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600) + assert r.get("code") != "AutoActivationFailed" + os.chdir(config_path) + print(f"Now in {os.getcwd()}") + url_path: str = str(url.path) + assert local_endpoint is not None + src_path: str = os.path.join(os.getcwd(), txt_file) + dst_path: str = os.path.join(url_path, txt_file) + subdir = os.path.basename(os.path.normpath(url_path)) + subdir_label = re.sub("[^A-Za-z0-9_ -]", "", subdir) + filename = txt_file.split(".")[0] + label = subdir_label + " " + filename + transfer_data: TransferData = TransferData( + transfer_client, + local_endpoint, # src_ep + remote_endpoint, # dst_ep + label=label, + verify_checksum=True, + preserve_timestamp=True, + fail_on_quota_errors=True, + ) + transfer_data.add_item(src_path, dst_path) + transfer_data["label"] = label + task: GlobusHTTPResponse + skipped_second_auth: bool = False + try: + task = transfer_client.submit_transfer(transfer_data) + print("Bypassed 2nd authentication.") + skipped_second_auth = True + except TransferAPIError as err: + if err.info.consent_required: + scopes = "urn:globus:auth:scope:transfer.api.globus.org:all[" + for ep_id in [remote_endpoint, local_endpoint]: + scopes += f" *https://auth.globus.org/scopes/{ep_id}/data_access" + scopes += " ]" + native_client = NativeClient(client_id=ZSTASH_CLIENT_ID, app_name="Zstash") + # May print 'Please 
Paste your Auth Code Below:' + # This is the 2nd authentication prompt! + print("Might ask for 2nd authentication prompt:") + native_client.login(requested_scopes=scopes) + print("Authenticated for the 2nd time!") + print( + "Consents added, please re-run the previous command to start transfer" + ) + raise RuntimeError("Re-run now that authentications are set up!") + else: + if err.info.authorization_parameters: + print("Error is in authorization parameters") + raise err + task_id = task.get("task_id") + wait_timeout = 300 # 300 sec = 5 min + print(f"Wait for task to complete, wait_timeout={wait_timeout}") + transfer_client.task_wait(task_id, timeout=wait_timeout, polling_interval=10) + curr_task: GlobusHTTPResponse = transfer_client.get_task(task_id) + task_status = curr_task["status"] + assert task_status == "SUCCEEDED" + return skipped_second_auth + + +def get_all_endpoint_scopes(endpoints: List[str]) -> str: + inner = " ".join( + [f"*https://auth.globus.org/scopes/{ep}/data_access" for ep in endpoints] + ) + return f"urn:globus:auth:scope:transfer.api.globus.org:all[{inner}]" + + +# Run ######################################################################### +if __name__ == "__main__": + main() diff --git a/examples/zstash_create_globus.py b/examples/zstash_create_globus.py new file mode 100644 index 00000000..5341eb81 --- /dev/null +++ b/examples/zstash_create_globus.py @@ -0,0 +1,318 @@ +import configparser +import os +import re +import shutil +from typing import Optional +from urllib.parse import ParseResult, urlparse + +from fair_research_login.client import NativeClient +from globus_sdk import TransferAPIError, TransferClient, TransferData +from globus_sdk.response import GlobusHTTPResponse + +# Minimal example of how Globus is used in zstash +# 1. Log into endpoints at globus.org +# 2. 
To start fresh, with no consents: +# https://app.globus.org/settings/consents > Manage Your Consents > Globus Endpoint Performance Monitoring > rescind all" + +HSI_DIR = "zstash_debugging_20250414_v4" + +# Globus-specific settings #################################################### +GLOBUS_CFG: str = os.path.expanduser("~/.globus-native-apps.cfg") +INI_PATH: str = os.path.expanduser("~/.zstash.ini") +ZSTASH_CLIENT_ID: str = "6c1629cf-446c-49e7-af95-323c6412397f" +NAME_TO_ENDPOINT_MAP = { + "Globus Tutorial Collection 1": "6c54cade-bde5-45c1-bdea-f4bd71dba2cc", # The Unit test endpoint + "NERSC HPSS": "9cd89cfd-6d04-11e5-ba46-22000b92c6ec", + "NERSC Perlmutter": "6bdc7956-fc0f-4ad2-989c-7aa5ee643a79", +} + + +class GlobusInfo(object): + def __init__(self, hpss_path: str): + url: ParseResult = urlparse(hpss_path) + assert url.scheme == "globus" + self.hpss_path: str = hpss_path + self.url: ParseResult = url + self.remote_endpoint: Optional[str] = None + self.local_endpoint: Optional[str] = None + self.transfer_client: Optional[TransferClient] = None + self.transfer_data: Optional[TransferData] = None + self.task_id = None + + +# zstash general settings ##################################################### +class Config(object): + def __init__(self): + self.path: Optional[str] = None + self.hpss: Optional[str] = None + self.maxsize: int = int(1024 * 1024 * 1024 * 256) + + +class CommandInfo(object): + def __init__(self, dir_to_archive: str, hpss_path: str): + self.config: Config = Config() + # Simulate CommandInfo.set_dir_to_archive + self.config.path = os.path.abspath(dir_to_archive) + # Simulate CommandInfo.set_hpss_parameters + self.config.hpss = hpss_path + url: ParseResult = urlparse(hpss_path) + assert url.scheme == "globus" + self.globus_info: GlobusInfo = GlobusInfo(hpss_path) + if os.path.exists(GLOBUS_CFG): + print( + f"{GLOBUS_CFG} exists. 
If this file does not have the proper settings, it may cause a TransferAPIError (e.g., 'Token is not active', 'No credentials supplied')" + ) + else: + print( + f"{GLOBUS_CFG} does not exist. zstash will need to prompt for authentications twice, and then you will need to re-run." + ) + + +# Functions ################################################################### +def main(): + hpss_path = f"globus://{NAME_TO_ENDPOINT_MAP['NERSC HPSS']}/~/{HSI_DIR}" + dir_to_archive: str = "dir_to_archive" + base_dir = os.getcwd() + toy_run(hpss_path, dir_to_archive) + # /global/homes/f/forsyth/.globus-native-apps.cfg does not exist. zstash will need to prompt for authentications twice, and then you will need to re-run. + # + # Might ask for 1st authentication prompt: + # Please paste the following URL in a browser: + # Authenticated for the 1st time! + # + # Might ask for 2nd authentication prompt: + # Please paste the following URL in a browser: + # Authenticated for the 2nd time! + # Consents added, please re-run the previous command to start transfer + # Now that we have the authentications, let's re-run. + os.chdir(base_dir) + real_run(hpss_path, dir_to_archive) + # /global/homes/f/forsyth/.globus-native-apps.cfg exists. If this file does not have the proper settings, it may cause a TransferAPIError (e.g., 'Token is not active', 'No credentials supplied') + # + # Might ask for 1st authentication prompt: + # Authenticated for the 1st time! + # + # distilled_globus_block_wait. task_wait, retry_count=0 + # To see transferred files, run: hsi ls zstash_debugging_20250414_v2 + # Shows file0.txt + print(f"To see transferred files, run: hsi ls {HSI_DIR}") + + +# Toy run to get everything set up correctly. 
+def toy_run( + hpss_path: str, + dir_to_archive: str, +): + # Start fresh + if os.path.exists(INI_PATH): + os.remove(INI_PATH) + if os.path.exists(GLOBUS_CFG): + os.remove(GLOBUS_CFG) + set_up_dirs("toy_run", dir_to_archive) + try: + distilled_create( + hpss_path, + dir_to_archive, + ) + except RuntimeError: + print("Now that we have the authentications, let's re-run.") + + +def real_run( + hpss_path: str, + dir_to_archive: str, +): + # Start fresh + assert os.path.exists(INI_PATH) + assert os.path.exists(GLOBUS_CFG) + set_up_dirs("real_run", dir_to_archive) + distilled_create( + hpss_path, + dir_to_archive, + ) + + +def set_up_dirs(run_dir: str, dir_to_archive: str): + if os.path.exists(run_dir): + shutil.rmtree(run_dir) + os.mkdir(run_dir) + os.chdir(run_dir) + os.mkdir(dir_to_archive) + with open(f"{dir_to_archive}/file0.txt", "w") as f: + f.write("file0 stuff") + + +# Distilled versions of zstash functions ###################################### + + +def distilled_create(hpss_path: str, dir_to_archive: str): + command_info = CommandInfo(dir_to_archive, hpss_path) + print(command_info.config.path) + assert command_info.config.path is not None + assert os.path.isdir(command_info.config.path) + + # Begin simulating globus_activate ######################################## + command_info.globus_info.remote_endpoint = command_info.globus_info.url.netloc + # Simulate globus_activate > set_local_endpoint + ini = configparser.ConfigParser() + if ini.read(INI_PATH): + if "local" in ini.sections(): + command_info.globus_info.local_endpoint = ini["local"].get( + "globus_endpoint_uuid" + ) + else: + ini["local"] = {"globus_endpoint_uuid": ""} + with open(INI_PATH, "w") as f: + ini.write(f) + if not command_info.globus_info.local_endpoint: + nersc_hostname = os.environ.get("NERSC_HOST") + assert nersc_hostname == "perlmutter" + command_info.globus_info.local_endpoint = NAME_TO_ENDPOINT_MAP[ + "NERSC Perlmutter" + ] + # Simulate globus_activate > set_clients + native_client = 
NativeClient( + client_id=ZSTASH_CLIENT_ID, + app_name="Zstash", + default_scopes="openid urn:globus:auth:scope:transfer.api.globus.org:all", + ) + # May print 'Please Paste your Auth Code Below:' + # This is the 1st authentication prompt! + print("Might ask for 1st authentication prompt:") + native_client.login(no_local_server=True, refresh_tokens=True) + print("Authenticated for the 1st time!") + transfer_authorizer = native_client.get_authorizers().get("transfer.api.globus.org") + command_info.globus_info.transfer_client = TransferClient( + authorizer=transfer_authorizer + ) + # Continue globus_activate + for ep_id in [ + command_info.globus_info.local_endpoint, + command_info.globus_info.remote_endpoint, + ]: + r = command_info.globus_info.transfer_client.endpoint_autoactivate( + ep_id, if_expires_in=600 + ) + assert r.get("code") != "AutoActivationFailed" + # End simulating globus_activate ########################################## + + os.chdir(command_info.config.path) + file_path = os.path.join(command_info.config.path, "file0.txt") + + # Begin simulating hpss_put ############################################### + url = urlparse(command_info.config.hpss) + url_path: str = str(url.path) + path: str + name: str + path, name = os.path.split(file_path) + cwd: str = os.getcwd() + if path != "": + # This directory contains the file we want to transfer to HPSS. 
+ os.chdir(path) + _ = distilled_globus_transfer(command_info.globus_info, url_path, name) + if path != "": + os.chdir(cwd) + # End simulating hpss_put ################################################# + + assert command_info.globus_info.transfer_data is None + + +def distilled_globus_transfer( + globus_info: GlobusInfo, remote_path: str, name: str +) -> str: + assert globus_info.local_endpoint is not None + src_ep: str = globus_info.local_endpoint + src_path: str = os.path.join(os.getcwd(), name) + assert globus_info.remote_endpoint is not None + dst_ep: str = globus_info.remote_endpoint + dst_path: str = os.path.join(remote_path, name) + subdir = os.path.basename(os.path.normpath(remote_path)) + subdir_label = re.sub("[^A-Za-z0-9_ -]", "", subdir) + filename = name.split(".")[0] + label = subdir_label + " " + filename + assert globus_info.transfer_data is None + globus_info.transfer_data = TransferData( + globus_info.transfer_client, + src_ep, + dst_ep, + label=label, + verify_checksum=True, + preserve_timestamp=True, + fail_on_quota_errors=True, + ) + globus_info.transfer_data.add_item(src_path, dst_path) + globus_info.transfer_data["label"] = label + task: GlobusHTTPResponse + if globus_info.task_id: + task = globus_info.transfer_client.get_task(globus_info.task_id) + prev_task_status = task["status"] + if prev_task_status == "ACTIVE": + return "ACTIVE" + + # Begin simulating submit_transfer_with_checks ############################ + try: + assert globus_info.transfer_client is not None + task = globus_info.transfer_client.submit_transfer(globus_info.transfer_data) + except TransferAPIError as err: + if err.info.consent_required: + scopes = "urn:globus:auth:scope:transfer.api.globus.org:all[" + for ep_id in [globus_info.remote_endpoint, globus_info.local_endpoint]: + scopes += f" *https://auth.globus.org/scopes/{ep_id}/data_access" + scopes += " ]" + native_client = NativeClient(client_id=ZSTASH_CLIENT_ID, app_name="Zstash") + # May print 'Please Paste your 
Auth Code Below:' + # This is the 2nd authentication prompt! + print("Might ask for 2nd authentication prompt:") + native_client.login(requested_scopes=scopes) + print("Authenticated for the 2nd time!") + print( + "Consents added, please re-run the previous command to start transfer" + ) + raise RuntimeError("Re-run now that authentications are set up!") + else: + if err.info.authorization_parameters: + print("Error is in authorization parameters") + raise err + # End simulating submit_transfer_with_checks ############################## + + globus_info.task_id = task.get("task_id") + # Nullify the submitted transfer data structure so that a new one will be created on next call. + globus_info.transfer_data = None + # wait_timeout = 7200 sec = 120 min = 2h + # wait_timeout = 300 sec = 5 min + wait_timeout = 300 + max_retries = 2 + + # Begin simulating globus_block_wait ###################################### + task_status = "UNKNOWN" + retry_count = 0 + while retry_count < max_retries: + try: + # Wait for the task to complete + assert globus_info.transfer_client is not None + print(f"task_wait, retry_count={retry_count}") + globus_info.transfer_client.task_wait( + globus_info.task_id, timeout=wait_timeout, polling_interval=10 + ) + except Exception as e: + print(f"Unexpected Exception: {e}") + else: + assert globus_info.transfer_client is not None + curr_task: GlobusHTTPResponse = globus_info.transfer_client.get_task( + globus_info.task_id + ) + task_status = curr_task["status"] + if task_status == "SUCCEEDED": + break + finally: + retry_count += 1 + if retry_count == max_retries: + task_status = "EXHAUSTED_TIMEOUT_RETRIES" + # End simulating globus_block_wait ###################################### + + return task_status + + +# Run ######################################################################### +if __name__ == "__main__": + main() diff --git a/tests/test_functions.py b/tests/test_functions.py deleted file mode 100644 index cc494e9c..00000000 --- 
a/tests/test_functions.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest - -from zstash.extract import parse_tars_option - - -class TestFunctions(unittest.TestCase): - def testParseTarsOption(self): - # Starting at 00005a until the end - tar_list = parse_tars_option("00005a-", "000000", "00005d") - self.assertEqual(tar_list, ["00005a", "00005b", "00005c", "00005d"]) - - # Starting from the beginning to 00005a (included) - tar_list = parse_tars_option("-00005a", "000058", "00005d") - self.assertEqual(tar_list, ["000058", "000059", "00005a"]) - - # Specific range - tar_list = parse_tars_option("00005a-00005c", "00000", "000005d") - self.assertEqual(tar_list, ["00005a", "00005b", "00005c"]) - - # Selected tar files - tar_list = parse_tars_option("00003e,00004e,000059", "000000", "00005d") - self.assertEqual(tar_list, ["00003e", "00004e", "000059"]) - - # Mix and match - tar_list = parse_tars_option("000030-00003e,00004e,00005a-", "000000", "00005d") - self.assertEqual( - tar_list, - [ - "000030", - "000031", - "000032", - "000033", - "000034", - "000035", - "000036", - "000037", - "000038", - "000039", - "00003a", - "00003b", - "00003c", - "00003d", - "00003e", - "00004e", - "00005a", - "00005b", - "00005c", - "00005d", - ], - ) - - # Check removal of duplicates and sorting - tar_list = parse_tars_option("000009,000003,-000005", "000000", "00005d") - self.assertEqual( - tar_list, - ["000000", "000001", "000002", "000003", "000004", "000005", "000009"], - ) - - # Remove .tar suffix - tar_list = parse_tars_option( - "000009.tar-00000a,000003.tar,-000002.tar", "000000", "00005d" - ) - self.assertEqual( - tar_list, ["000000", "000001", "000002", "000003", "000009", "00000a"] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests_unit/test_chgrp.py b/tests_unit/test_chgrp.py new file mode 100644 index 00000000..ec3735fe --- /dev/null +++ b/tests_unit/test_chgrp.py @@ -0,0 +1,31 @@ +from zstash.chgrp import setup_chgrp + + +def test_setup_chgrp(): + 
args_str = "zstash chgrp 775 my_path".split(" ") + args = setup_chgrp(args_str) + assert args.group == "775" + assert args.hpss == "my_path" + assert args.R is None + assert args.verbose is False + + args_str = "zstash chgrp 775 my_path -R".split(" ") + args = setup_chgrp(args_str) + assert args.group == "775" + assert args.hpss == "my_path" + assert args.R is True + assert args.verbose is False + + args_str = "zstash chgrp 775 my_path -v".split(" ") + args = setup_chgrp(args_str) + assert args.group == "775" + assert args.hpss == "my_path" + assert args.R is None + assert args.verbose is True + + args_str = "zstash chgrp 775 my_path -R -v".split(" ") + args = setup_chgrp(args_str) + assert args.group == "775" + assert args.hpss == "my_path" + assert args.R is True + assert args.verbose is True diff --git a/tests_unit/test_create.py b/tests_unit/test_create.py new file mode 100644 index 00000000..127c9d14 --- /dev/null +++ b/tests_unit/test_create.py @@ -0,0 +1,122 @@ +from zstash.create import setup_create +from zstash.utils import CommandInfo, HPSSType + + +def test_setup_create(): + # Test required parameters + args_str = "zstash create dir_to_archive --hpss=none".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.path == "dir_to_archive" + assert args.hpss == "none" + assert args.include is None + assert args.exclude is None + assert args.maxsize == 256 + assert args.keep is True + assert args.cache is None + assert args.non_blocking is False + assert args.verbose is False + assert args.no_tars_md5 is False + assert args.follow_symlinks is False + assert command_info.cache_dir == "zstash" + assert command_info.keep is True + assert command_info.config.path.endswith("dir_to_archive") + assert command_info.dir_to_archive_relative == "dir_to_archive" + assert command_info.config.maxsize == 274877906944 + assert command_info.config.hpss == "none" + assert command_info.hpss_type == HPSSType.NO_HPSS + + 
args_str = "zstash create dir_to_archive --hpss=my_path".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.hpss == "my_path" + assert args.keep is False + assert command_info.keep is False + assert command_info.config.hpss == "my_path" + assert command_info.hpss_type == HPSSType.SAME_MACHINE_HPSS + + args_str = "zstash create dir_to_archive --hpss=globus://my_path".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.hpss == "globus://my_path" + assert args.keep is False + assert command_info.keep is False + assert command_info.config.hpss == "globus://my_path" + assert command_info.hpss_type == HPSSType.GLOBUS + assert command_info.globus_info.hpss_path == "globus://my_path" + + # Test required parameters, with --keep + args_str = "zstash create dir_to_archive --hpss=none --keep".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.keep is True + assert command_info.keep is True + + args_str = "zstash create dir_to_archive --hpss=my_path --keep".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.keep is True + assert command_info.keep is True + + args_str = "zstash create dir_to_archive --hpss=globus://my_path --keep".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.keep is True + assert command_info.keep is True + + # Test optional parameters + args_str = "zstash create dir_to_archive --hpss=none --include=file1.txt,file2.txt,file3.txt".split( + " " + ) + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.include == "file1.txt,file2.txt,file3.txt" + assert args.exclude is None + + args_str = "zstash create dir_to_archive --hpss=none --exclude=file1.txt,file2.txt,file3.txt".split( + " " + ) + command_info = CommandInfo("create") + 
args = setup_create(command_info, args_str) + assert args.include is None + assert args.exclude == "file1.txt,file2.txt,file3.txt" + + args_str = "zstash create dir_to_archive --hpss=none --include=file1.txt,file2.txt --exclude=file3.txt,file4.txt".split( + " " + ) + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.include == "file1.txt,file2.txt" + assert args.exclude == "file3.txt,file4.txt" + + args_str = "zstash create dir_to_archive --hpss=none --maxsize=1024".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.maxsize == 1024 + assert command_info.config.maxsize == 1024**4 + + args_str = "zstash create dir_to_archive --hpss=none --cache=my_cache".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.cache == "my_cache" + assert command_info.cache_dir == "my_cache" + + args_str = "zstash create dir_to_archive --hpss=none --non-blocking".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.non_blocking is True + + args_str = "zstash create dir_to_archive --hpss=none --verbose".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.verbose is True + + args_str = "zstash create dir_to_archive --hpss=none --no_tars_md5".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.no_tars_md5 is True + + args_str = "zstash create dir_to_archive --hpss=none --follow-symlinks".split(" ") + command_info = CommandInfo("create") + args = setup_create(command_info, args_str) + assert args.follow_symlinks is True diff --git a/tests_unit/test_extract.py b/tests_unit/test_extract.py new file mode 100644 index 00000000..cce3a984 --- /dev/null +++ b/tests_unit/test_extract.py @@ -0,0 +1,233 @@ +from datetime import datetime + +from zstash.extract import 
def test_setup_extract():
    """Exercise setup_extract argument parsing: defaults, positional file
    patterns, and each optional flag, including the CommandInfo side effects."""

    def parse(cmd: str):
        # Build a fresh CommandInfo and parse the given command line.
        info = CommandInfo("extract")
        return info, setup_extract(info, cmd.split(" "))

    # Defaults: no flags at all.
    command_info, args = parse("zstash extract")
    assert args.files == ["*"]
    assert args.hpss is None
    assert args.workers == 1
    assert args.keep is False
    assert args.cache is None
    assert args.retries == 1
    assert args.verbose is False
    assert command_info.cache_dir == "zstash"
    assert command_info.keep is False
    # If running from the top level of the git repo:
    assert command_info.config.path.endswith("zstash")
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.maxsize is None
    assert command_info.config.hpss is None
    assert command_info.hpss_type == HPSSType.UNDEFINED

    # Positional file patterns.
    _, args = parse("zstash extract file*.txt")
    assert args.files == ["file*.txt"]

    _, args = parse("zstash extract file1.txt fileA.txt")
    assert args.files == ["file1.txt", "fileA.txt"]

    # --hpss determines the HPSS type; everything else keeps its default.
    hpss_cases = [
        ("none", HPSSType.NO_HPSS),
        ("my_path", HPSSType.SAME_MACHINE_HPSS),
        ("globus://my_path", HPSSType.GLOBUS),
    ]
    for hpss_arg, expected_type in hpss_cases:
        command_info, args = parse(f"zstash extract --hpss={hpss_arg}")
        assert args.hpss == hpss_arg
        assert args.keep is False
        assert command_info.keep is False
        # If running from the top level of the git repo:
        assert command_info.config.path.endswith("zstash")
        assert command_info.dir_to_archive_relative == command_info.config.path
        assert command_info.config.hpss == hpss_arg
        assert command_info.hpss_type == expected_type

    # Remaining optional flags, one at a time.
    _, args = parse("zstash extract --workers=5")
    assert args.workers == 5

    _, args = parse("zstash extract --keep")
    assert args.keep is True

    command_info, args = parse("zstash extract --cache=my_cache")
    assert args.cache == "my_cache"
    assert command_info.cache_dir == "my_cache"

    _, args = parse("zstash extract --retries=3")
    assert args.retries == 3

    _, args = parse("zstash extract --verbose")
    assert args.verbose is True
def test_process_matches():
    """process_matches should drop duplicate file names (keeping the later
    entry) and return rows ordered by (tar, offset)."""
    # TupleFilesRow = Tuple[int, str, int, datetime.datetime, Optional[str], str, int]
    raw_rows = [
        (2, "file2.txt", 10, datetime.fromordinal(3), "md5_2", "tar0", 20),
        (0, "file0.txt", 10, datetime.fromordinal(1), "md5_0", "tar0", 0),
        (0, "file0.txt", 10, datetime.fromordinal(5), "md5_4", "tar1", 10),
        (1, "file1.txt", 10, datetime.fromordinal(2), "md5_1", "tar0", 10),
        (3, "file3.txt", 10, datetime.fromordinal(4), "md5_3", "tar1", 0),
    ]

    # Sorts on name [1], then on tar [5], then on offset [6];
    # removes duplicates (dropping the earlier entry);
    # then sorts on tar [5], then on offset [6].
    matches = process_matches(raw_rows)

    # "file0.txt" appears twice; only its later (tar1) entry survives.
    assert len(matches) == 4

    # Expected final order: (identifier, name, size, mtime, md5, tar, offset).
    expected_rows = [
        (1, "file1.txt", 10, datetime.fromordinal(2), "md5_1", "tar0", 10),
        (2, "file2.txt", 10, datetime.fromordinal(3), "md5_2", "tar0", 20),
        (3, "file3.txt", 10, datetime.fromordinal(4), "md5_3", "tar1", 0),
        (0, "file0.txt", 10, datetime.fromordinal(5), "md5_4", "tar1", 10),
    ]
    for row, (identifier, name, size, mtime, md5, tar, offset) in zip(
        matches, expected_rows
    ):
        assert row.identifier == identifier
        assert row.name == name
        assert row.size == size
        assert row.mtime == mtime
        assert row.md5 == md5
        assert row.tar == tar
        assert row.offset == offset
def test_parse_tars_option():
    """Exercise the tar-selection mini-language accepted by --tars:
    open-ended ranges, explicit ranges, comma lists, duplicate removal,
    sorting, and stripping of the .tar suffix."""
    # Starting at 00005a until the end
    tar_list = parse_tars_option("00005a-", "000000", "00005d")
    assert tar_list == ["00005a", "00005b", "00005c", "00005d"]

    # Starting from the beginning to 00005a (included)
    tar_list = parse_tars_option("-00005a", "000058", "00005d")
    assert tar_list == ["000058", "000059", "00005a"]

    # Specific range
    # Fixed: first/last were previously "00000" (5 hex digits) and
    # "000005d" (7 digits); tar names are always 6 hex digits. They are
    # unused for a fully-specified range, but keep them well-formed so the
    # test matches every other call here.
    tar_list = parse_tars_option("00005a-00005c", "000000", "00005d")
    assert tar_list == ["00005a", "00005b", "00005c"]

    # Selected tar files
    tar_list = parse_tars_option("00003e,00004e,000059", "000000", "00005d")
    assert tar_list == ["00003e", "00004e", "000059"]

    # Mix and match
    tar_list = parse_tars_option("000030-00003e,00004e,00005a-", "000000", "00005d")
    assert tar_list == [
        "000030",
        "000031",
        "000032",
        "000033",
        "000034",
        "000035",
        "000036",
        "000037",
        "000038",
        "000039",
        "00003a",
        "00003b",
        "00003c",
        "00003d",
        "00003e",
        "00004e",
        "00005a",
        "00005b",
        "00005c",
        "00005d",
    ]

    # Check removal of duplicates and sorting
    tar_list = parse_tars_option("000009,000003,-000005", "000000", "00005d")
    assert tar_list == [
        "000000",
        "000001",
        "000002",
        "000003",
        "000004",
        "000005",
        "000009",
    ]

    # Remove .tar suffix
    tar_list = parse_tars_option(
        "000009.tar-00000a,000003.tar,-000002.tar", "000000", "00005d"
    )
    assert tar_list == ["000000", "000001", "000002", "000003", "000009", "00000a"]
def test_setup_ls():
    """Exercise setup_ls argument parsing: defaults, positional file
    patterns, and each optional flag, including the CommandInfo side effects."""
    # Test required parameters
    args_str = "zstash ls".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.files == ["*"]
    assert args.hpss is None
    assert args.long is None
    assert args.cache is None
    assert args.tars is False
    assert args.verbose is False
    assert command_info.cache_dir == "zstash"
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.maxsize is None
    assert command_info.config.hpss is None
    assert command_info.hpss_type == HPSSType.UNDEFINED

    # Fixed: these two cases previously built "zstash extract ..."; only
    # argv[2:] is parsed so behavior was unchanged, but the command line
    # should name the command under test.
    args_str = "zstash ls file*.txt".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.files == ["file*.txt"]

    args_str = "zstash ls file1.txt fileA.txt".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.files == ["file1.txt", "fileA.txt"]

    # Test optional parameters
    args_str = "zstash ls --hpss=none".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.hpss == "none"
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "none"
    assert command_info.hpss_type == HPSSType.NO_HPSS

    args_str = "zstash ls --hpss=my_path".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.hpss == "my_path"
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "my_path"
    assert command_info.hpss_type == HPSSType.SAME_MACHINE_HPSS

    args_str = "zstash ls --hpss=globus://my_path".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.hpss == "globus://my_path"
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "globus://my_path"
    assert command_info.hpss_type == HPSSType.GLOBUS

    args_str = "zstash ls -l".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.long is True

    args_str = "zstash ls --cache=my_cache".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.cache == "my_cache"
    assert command_info.cache_dir == "my_cache"

    args_str = "zstash ls --tars".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.tars is True

    args_str = "zstash ls --verbose".split(" ")
    command_info = CommandInfo("ls")
    args = setup_ls(command_info, args_str)
    assert args.verbose is True
def test_setup_update():
    """Exercise setup_update argument parsing: defaults, the --hpss/--keep
    interaction, and each remaining optional flag."""
    # Test required parameters
    args_str = "zstash update".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.hpss == "none"
    assert args.include is None
    assert args.exclude is None
    assert args.dry_run is False
    assert args.maxsize == 256
    assert args.keep is True
    assert args.cache is None
    assert args.non_blocking is False
    assert args.verbose is False
    assert args.follow_symlinks is False
    assert command_info.cache_dir == "zstash"
    assert command_info.keep is True
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.maxsize == 274877906944
    assert command_info.config.hpss == "none"
    assert command_info.hpss_type == HPSSType.NO_HPSS

    # Test --hpss, without --keep
    args_str = "zstash update --hpss=none".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.hpss == "none"
    assert args.keep is True  # hpss=none forces keep
    assert command_info.keep is True
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "none"
    assert command_info.hpss_type == HPSSType.NO_HPSS

    args_str = "zstash update --hpss=my_path".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.hpss == "my_path"
    assert args.keep is False
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "my_path"
    assert command_info.hpss_type == HPSSType.SAME_MACHINE_HPSS

    args_str = "zstash update --hpss=globus://my_path".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.hpss == "globus://my_path"
    assert args.keep is False
    assert command_info.keep is False
    assert command_info.config.path.endswith(
        "zstash"
    )  # If running from top level of git repo
    assert command_info.dir_to_archive_relative == command_info.config.path
    assert command_info.config.hpss == "globus://my_path"
    assert command_info.hpss_type == HPSSType.GLOBUS

    # Test --hpss, with --keep
    args_str = "zstash update --hpss=none --keep".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.keep is True
    assert command_info.keep is True

    args_str = "zstash update --hpss=my_path --keep".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.keep is True
    assert command_info.keep is True

    # Fixed: this case previously omitted --keep and asserted keep is False,
    # exactly duplicating the no---keep globus case above instead of testing
    # --keep with a globus URL as the section comment promises.
    args_str = "zstash update --hpss=globus://my_path --keep".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.keep is True
    assert command_info.keep is True

    # Test other optional parameters
    args_str = "zstash update --include=file1.txt,file2.txt,file3.txt".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.include == "file1.txt,file2.txt,file3.txt"
    assert args.exclude is None

    args_str = "zstash update --exclude=file1.txt,file2.txt,file3.txt".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.include is None
    assert args.exclude == "file1.txt,file2.txt,file3.txt"

    args_str = "zstash update --include=file1.txt,file2.txt --exclude=file3.txt,file4.txt".split(
        " "
    )
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.include == "file1.txt,file2.txt"
    assert args.exclude == "file3.txt,file4.txt"

    args_str = "zstash update --dry-run".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.dry_run is True

    args_str = "zstash update --maxsize=1024".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.maxsize == 1024
    assert command_info.config.maxsize == 1024**4

    args_str = "zstash update --keep".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.keep is True
    assert command_info.keep is True

    args_str = "zstash update --cache=my_cache".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.cache == "my_cache"
    assert command_info.cache_dir == "my_cache"

    args_str = "zstash update --non-blocking".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.non_blocking is True

    args_str = "zstash update --verbose".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.verbose is True

    args_str = "zstash update --follow-symlinks".split(" ")
    command_info = CommandInfo("update")
    args = setup_update(command_info, args_str)
    assert args.follow_symlinks is True
# NOTE: CommandInfo methods are tested implicitly via other tests in this directory.
# Test filter_files indirectly, through its two public wrappers.
def test_filter_files():
    """exclude_files/include_files partition a file list using a
    comma-separated pattern string (glob patterns supported)."""
    patterns = "file1.txt,file2.txt,file3.txt"
    candidates = [f"file{i}.txt" for i in range(5)]
    # Excluding the middle three leaves the first and last files.
    assert exclude_files(patterns, candidates) == ["file0.txt", "file4.txt"]
    # Including them selects exactly those three.
    assert include_files(patterns, candidates) == [
        "file1.txt",
        "file2.txt",
        "file3.txt",
    ]

    # Glob pattern matching on a directory prefix.
    patterns = "path1/*"
    candidates = ["path1/f1.txt", "path1/f2.txt", "path2/f1.txt", "path2/f2.txt"]
    assert exclude_files(patterns, candidates) == ["path2/f1.txt", "path2/f2.txt"]
    assert include_files(patterns, candidates) == ["path1/f1.txt", "path1/f2.txt"]
recurse) + return args diff --git a/zstash/create.py b/zstash/create.py index c544896c..6a6b5c8a 100644 --- a/zstash/create.py +++ b/zstash/create.py @@ -6,15 +6,15 @@ import os.path import sqlite3 import sys -from typing import Any, List, Tuple - -from six.moves.urllib.parse import urlparse +from typing import Any, List from .globus import globus_activate, globus_finalize from .hpss import hpss_put from .hpss_utils import add_files -from .settings import DEFAULT_CACHE, config, get_db_filename, logger +from .settings import logger from .utils import ( + CommandInfo, + HPSSType, create_tars_table, get_files_to_archive, run_command, @@ -24,60 +24,54 @@ def create(): - cache: str - cache, args = setup_create() - - # Check config fields - if config.path is not None: - path: str = config.path - else: - raise TypeError("Invalid config.path={}".format(config.path)) - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) + command_info = CommandInfo("create") + args = setup_create(command_info, sys.argv) # Start doing actual work logger.debug(f"{ts_utc()}: Running zstash create") - logger.debug("Local path : {}".format(path)) - logger.debug("HPSS path : {}".format(hpss)) - logger.debug("Max size : {}".format(config.maxsize)) - logger.debug("Keep local tar files : {}".format(args.keep)) + logger.debug(f"Local path: {command_info.config.path}") + logger.debug(f"HPSS path: {command_info.config.hpss}") + logger.debug(f"Max size: {command_info.config.maxsize}") + logger.debug(f"Keep local tar files: {command_info.keep}") # Make sure input path exists and is a directory logger.debug("Making sure input path exists and is a directory") - if not os.path.isdir(path): + if not command_info.config.path: + raise ValueError("config.path is undefined") + if not os.path.isdir(command_info.config.path): # Input path is not a directory - input_path_error_str: str = "Input path should be a directory: {}".format(path) + 
input_path_error_str: str = ( + f"Input path should be a directory: {command_info.config.path}" + ) logger.error(input_path_error_str) raise NotADirectoryError(input_path_error_str) - if hpss != "none": - url = urlparse(hpss) - if url.scheme == "globus": - # identify globus endpoints - logger.debug(f"{ts_utc()}:Calling globus_activate(hpss)") - globus_activate(hpss) - else: - # config.hpss is not "none", so we need to - # create target HPSS directory - logger.debug(f"{ts_utc()}: Creating target HPSS directory {hpss}") - mkdir_command: str = "hsi -q mkdir -p {}".format(hpss) - mkdir_error_str: str = "Could not create HPSS directory: {}".format(hpss) - run_command(mkdir_command, mkdir_error_str) + if command_info.globus_info: + # identify globus endpoints + logger.debug(f"{ts_utc()}: Calling globus_activate") + globus_activate(command_info.globus_info) + elif command_info.hpss_type == HPSSType.SAME_MACHINE_HPSS: + logger.debug( + f"{ts_utc()}: Creating target HPSS directory {command_info.config.hpss}" + ) + mkdir_command: str = f"hsi -q mkdir -p {command_info.config.hpss}" + mkdir_error_str: str = ( + f"Could not create HPSS directory: {command_info.config.hpss}" + ) + run_command(mkdir_command, mkdir_error_str) - # Make sure it is exists and is empty - logger.debug("Making sure target HPSS directory exists and is empty") + # Make sure it is exists and is empty + logger.debug("Making sure target HPSS directory exists and is empty") - ls_command: str = 'hsi -q "cd {}; ls -l"'.format(hpss) - ls_error_str: str = "Target HPSS directory is not empty" - run_command(ls_command, ls_error_str) + ls_command: str = f'hsi -q "cd {command_info.config.hpss}; ls -l"' + ls_error_str: str = "Target HPSS directory is not empty" + run_command(ls_command, ls_error_str) # Create cache directory logger.debug(f"{ts_utc()}: Creating local cache directory") - os.chdir(path) + os.chdir(command_info.config.path) try: - os.makedirs(cache) + os.makedirs(command_info.cache_dir) except OSError as 
exc: if exc.errno != errno.EEXIST: cache_error_str: str = "Cannot create local cache directory" @@ -88,23 +82,24 @@ def create(): # Create and set up the database logger.debug(f"{ts_utc()}: Calling create_database()") - failures: List[str] = create_database(cache, args) + failures: List[str] = create_database(command_info, args) - # Transfer to HPSS. Always keep a local copy. - logger.debug(f"{ts_utc()}: calling hpss_put() for {get_db_filename(cache)}") - hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True) + # Transfer to HPSS. Always keep a local copy of the database. + logger.debug(f"{ts_utc()}: calling hpss_put() for {command_info.get_db_name()}") + hpss_put(command_info, command_info.get_db_name()) - logger.debug(f"{ts_utc()}: calling globus_finalize()") - globus_finalize(non_blocking=args.non_blocking) + if command_info.globus_info: + logger.debug(f"{ts_utc()}: calling globus_finalize()") + globus_finalize(command_info.globus_info, non_blocking=args.non_blocking) if len(failures) > 0: # List the failures logger.warning("Some files could not be archived") for file_path in failures: - logger.error("Failed to archive {}".format(file_path)) + logger.error(f"Failed to archive {file_path}") -def setup_create() -> Tuple[str, argparse.Namespace]: +def setup_create(command_info: CommandInfo, arg_list: List[str]) -> argparse.Namespace: # Parser parser: argparse.ArgumentParser = argparse.ArgumentParser( usage="zstash create [] path", description="Create a new zstash archive" @@ -168,34 +163,32 @@ def setup_create() -> Tuple[str, argparse.Namespace]: ) # Now that we're inside a subcommand, ignore the first two argvs # (zstash create) - args: argparse.Namespace = parser.parse_args(sys.argv[2:]) + args: argparse.Namespace = parser.parse_args(arg_list[2:]) if (not args.hpss) or (args.hpss.lower() == "none"): args.hpss = "none" args.keep = True if args.verbose: logger.setLevel(logging.DEBUG) - # Copy configuration - config.path = 
os.path.abspath(args.path) - config.hpss = args.hpss - config.maxsize = int(1024 * 1024 * 1024 * args.maxsize) - cache: str if args.cache: - cache = args.cache - else: - cache = DEFAULT_CACHE + command_info.cache_dir = args.cache + command_info.keep = args.keep + command_info.set_dir_to_archive(args.path) + command_info.set_and_scale_maxsize(args.maxsize) + command_info.set_hpss_parameters(args.hpss) - return cache, args + return args -def create_database(cache: str, args: argparse.Namespace) -> List[str]: +def create_database(command_info: CommandInfo, args: argparse.Namespace) -> List[str]: # Create new database logger.debug(f"{ts_utc()}:Creating index database") - if os.path.exists(get_db_filename(cache)): + db_name: str = command_info.get_db_name() + if os.path.exists(db_name): # Remove old database - os.remove(get_db_filename(cache)) + os.remove(db_name) con: sqlite3.Connection = sqlite3.connect( - get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + db_name, detect_types=sqlite3.PARSE_DECLTYPES ) cur: sqlite3.Cursor = con.cursor() @@ -233,8 +226,8 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]: # Store configuration in database # Loop through all attributes of config. - for attr in dir(config): - value: Any = getattr(config, attr) + for attr in dir(command_info.config): + value: Any = getattr(command_info.config, attr) if not callable(value) and not attr.startswith("__"): # config.{attr} is not a function. 
# The attribute name does not start with "__" @@ -244,19 +237,20 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]: cur.execute("insert into config values (?,?)", (attr, value)) con.commit() - files: List[str] = get_files_to_archive(cache, args.include, args.exclude) + files: List[str] = get_files_to_archive( + command_info.cache_dir, args.include, args.exclude + ) failures: List[str] if args.follow_symlinks: try: # Add files to archive failures = add_files( + command_info, cur, con, -1, files, - cache, - args.keep, args.follow_symlinks, skip_tars_md5=args.no_tars_md5, non_blocking=args.non_blocking, @@ -266,12 +260,11 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]: else: # Add files to archive failures = add_files( + command_info, cur, con, -1, files, - cache, - args.keep, args.follow_symlinks, skip_tars_md5=args.no_tars_md5, non_blocking=args.non_blocking, diff --git a/zstash/extract.py b/zstash/extract.py index a0446cde..e9164eb0 100644 --- a/zstash/extract.py +++ b/zstash/extract.py @@ -20,29 +20,18 @@ from . import parallel from .hpss import hpss_get -from .settings import ( - BLOCK_SIZE, - DEFAULT_CACHE, - TIME_TOL, - FilesRow, - TupleFilesRow, - config, - get_db_filename, - logger, -) -from .utils import tars_table_exists, update_config - - -def extract(keep_files: bool = True): +from .settings import BLOCK_SIZE, TIME_TOL, FilesRow, TupleFilesRow, logger +from .utils import CommandInfo, HPSSType, tars_table_exists + + +def extract(do_extract_files: bool = True): """ Given an HPSS path in the zstash database or passed via the command line, extract the archived data based on the file pattern (if given). 
""" - args: argparse.Namespace - cache: str - args, cache = setup_extract() - - failures: List[FilesRow] = extract_database(args, cache, keep_files) + command_info = CommandInfo("extract") + args: argparse.Namespace = setup_extract(command_info, sys.argv) + failures: List[FilesRow] = extract_database(command_info, args, do_extract_files) if failures: logger.error("Encountered an error for files:") @@ -56,7 +45,7 @@ def extract(keep_files: bool = True): for tar in broken_tars: logger.error(tar) else: - verb: str = "extracting" if keep_files else "checking" + verb: str = "extracting" if do_extract_files else "checking" logger.info( 'No failures detected when {} the files. If you have a log file, run "grep -i Exception " to double check.'.format( verb @@ -64,7 +53,7 @@ def extract(keep_files: bool = True): ) -def setup_extract() -> Tuple[argparse.Namespace, str]: +def setup_extract(command_info: CommandInfo, arg_list: List[str]) -> argparse.Namespace: parser: argparse.ArgumentParser = argparse.ArgumentParser( usage="zstash extract [] [files]", description="Extract files from existing archive", @@ -102,112 +91,64 @@ def setup_extract() -> Tuple[argparse.Namespace, str]: "-v", "--verbose", action="store_true", help="increase output verbosity" ) parser.add_argument("files", nargs="*", default=["*"]) - args: argparse.Namespace = parser.parse_args(sys.argv[2:]) - if args.hpss and args.hpss.lower() == "none": + args: argparse.Namespace = parser.parse_args(arg_list[2:]) + + if args.hpss and (args.hpss.lower() == "none"): args.hpss = "none" - if args.cache: - cache = args.cache - else: - cache = DEFAULT_CACHE # Note: setting logging level to anything other than DEBUG doesn't work with # multiple workers. This must have someting to do with the custom logger # implemented for multiple workers. 
if args.verbose or args.workers > 1: logger.setLevel(logging.DEBUG) - return args, cache + if args.cache: + command_info.cache_dir = args.cache + command_info.keep = args.keep + command_info.set_dir_to_archive(os.getcwd()) + command_info.set_hpss_parameters(args.hpss, null_hpss_allowed=True) - -def parse_tars_option(tars: str, first_tar: str, last_tar: str) -> List[str]: - tar_str_list: List[str] = tars.split(",") - tar_list: List[str] = [] - tar_str: str - for tar_str in tar_str_list: - if tar_str.startswith('"'): - tar_str = tar_str[1:] - if tar_str.endswith('"'): - tar_str = tar_str[:-1] - if tar_str.startswith("-"): - tar_str = "{}{}".format(first_tar, tar_str) - elif tar_str.endswith("-"): - tar_str = "{}{}".format(tar_str, last_tar) - m: Optional[re.Match] - m = re.match("(.*)-(.*)", tar_str) - if m: - m1: str = m.group(1) - m2: str = m.group(2) - # Remove .tar suffix - if m1.endswith(".tar"): - m1 = m1[:-4] - if m2.endswith(".tar"): - m2 = m2[:-4] - beginning_tar: int = int(m1, 16) - ending_tar: int = int(m2, 16) - t: int - for t in range(beginning_tar, ending_tar + 1): - tar_list.append("{:06x}".format(t)) - else: - # Remove .tar suffix - if tar_str.endswith(".tar"): - tar_str = tar_str[:-4] - tar_list.append(tar_str) - # Remove duplicates and sort tar_list - tar_list = sorted(list(set(tar_list))) - return tar_list + return args def extract_database( - args: argparse.Namespace, cache: str, keep_files: bool + command_info: CommandInfo, args: argparse.Namespace, do_extract_files: bool ) -> List[FilesRow]: # Open database logger.debug("Opening index database") - if not os.path.exists(get_db_filename(cache)): + if not os.path.exists(command_info.get_db_name()): # Will need to retrieve from HPSS - if args.hpss is not None: - config.hpss = args.hpss - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - hpss_get(hpss, get_db_filename(cache), cache) + if command_info.hpss_type != 
HPSSType.UNDEFINED: + hpss_get(command_info, command_info.get_db_name()) else: error_str: str = ( "--hpss argument is required when local copy of database is unavailable" ) logger.error(error_str) - raise ValueError(error_str) con: sqlite3.Connection = sqlite3.connect( - get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + command_info.get_db_name(), detect_types=sqlite3.PARSE_DECLTYPES ) cur: sqlite3.Cursor = con.cursor() - update_config(cur) - if config.maxsize is not None: - maxsize = config.maxsize - else: - raise TypeError("Invalid config.maxsize={}".format(config.maxsize)) - config.maxsize = int(maxsize) - - # The command line arg should always have precedence - if args.hpss is not None: - config.hpss = args.hpss - keep: bool - if config.hpss == "none": - # If no HPSS is available, always keep the files. - keep = True - else: - keep = args.keep + command_info.update_config_using_db(cur) + command_info.validate_maxsize() + + if command_info.hpss_type in [HPSSType.NO_HPSS, HPSSType.UNDEFINED]: + if command_info.config.hpss != "none": + raise ValueError(f"Invalid config.hpss={command_info.config.hpss}") + # If not using HPSS, always keep the files. 
+ command_info.keep = True + # else: keep command_info.keep set to args.keep # Start doing actual work - cmd: str = "extract" if keep_files else "check" + cmd: str = "extract" if do_extract_files else "check" logger.debug("Running zstash " + cmd) - logger.debug("Local path : {}".format(config.path)) - logger.debug("HPSS path : {}".format(config.hpss)) - logger.debug("Max size : {}".format(config.maxsize)) - logger.debug("Keep local tar files : {}".format(keep)) + logger.debug(f"Local path : {command_info.config.path}") + logger.debug(f"HPSS path : {command_info.config.hpss}") + logger.debug(f"Max size : {command_info.config.maxsize}") + logger.debug(f"Keep local tar files : {command_info.keep}") matches_: List[TupleFilesRow] = [] if args.tars is not None: @@ -244,6 +185,69 @@ def extract_database( if matches_ == []: raise FileNotFoundError("There was nothing to extract.") + matches: List[FilesRow] = process_matches(matches_) + + # Retrieve from tapes + failures: List[FilesRow] + if args.workers > 1: + logger.debug("Running zstash {} with multiprocessing".format(cmd)) + failures = multiprocess_extract( + args.workers, + command_info, + matches, + do_extract_files, + cur, + args, + ) + else: + failures = extractFiles(command_info, matches, do_extract_files, cur, args) + + # Close database + logger.debug("Closing index database") + con.close() + + return failures + + +def parse_tars_option(tars: str, first_tar: str, last_tar: str) -> List[str]: + tar_str_list: List[str] = tars.split(",") + tar_list: List[str] = [] + tar_str: str + for tar_str in tar_str_list: + if tar_str.startswith('"'): + tar_str = tar_str[1:] + if tar_str.endswith('"'): + tar_str = tar_str[:-1] + if tar_str.startswith("-"): + tar_str = "{}{}".format(first_tar, tar_str) + elif tar_str.endswith("-"): + tar_str = "{}{}".format(tar_str, last_tar) + m: Optional[re.Match] + m = re.match("(.*)-(.*)", tar_str) + if m: + m1: str = m.group(1) + m2: str = m.group(2) + # Remove .tar suffix + if 
m1.endswith(".tar"): + m1 = m1[:-4] + if m2.endswith(".tar"): + m2 = m2[:-4] + beginning_tar: int = int(m1, 16) + ending_tar: int = int(m2, 16) + t: int + for t in range(beginning_tar, ending_tar + 1): + tar_list.append("{:06x}".format(t)) + else: + # Remove .tar suffix + if tar_str.endswith(".tar"): + tar_str = tar_str[:-4] + tar_list.append(tar_str) + # Remove duplicates and sort tar_list + tar_list = sorted(list(set(tar_list))) + return tar_list + + +def process_matches(matches_: List[TupleFilesRow]) -> List[FilesRow]: matches: List[FilesRow] = list(map(lambda match: FilesRow(match), matches_)) # Sort by the filename, tape (so the tar archive), @@ -271,30 +275,14 @@ def extract_database( # Sort by tape and offset, so that we make sure # that extract the files by tape order. matches.sort(key=lambda t: (t.tar, t.offset)) - - # Retrieve from tapes - failures: List[FilesRow] - if args.workers > 1: - logger.debug("Running zstash {} with multiprocessing".format(cmd)) - failures = multiprocess_extract( - args.workers, matches, keep_files, keep, cache, cur, args - ) - else: - failures = extractFiles(matches, keep_files, keep, cache, cur, args) - - # Close database - logger.debug("Closing index database") - con.close() - - return failures + return matches def multiprocess_extract( num_workers: int, + command_info: CommandInfo, matches: List[FilesRow], - keep_files: bool, - keep_tars: Optional[bool], - cache: str, + do_extract_files: bool, cur: sqlite3.Cursor, args: argparse.Namespace, ) -> List[FilesRow]: @@ -304,6 +292,42 @@ def multiprocess_extract( A single unit of work is a tar and all of the files in it to extract. """ + tar_ordering, workers_to_matches = prepare_multiprocess(num_workers, matches) + monitor: parallel.PrintMonitor = parallel.PrintMonitor(tar_ordering) + + # The return value for extractFiles will be added here. 
+ failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() + processes: List[multiprocessing.Process] = [] + for matches in workers_to_matches: + tars_for_this_worker: List[str] = list(set(match.tar for match in matches)) + worker: parallel.ExtractWorker = parallel.ExtractWorker( + monitor, tars_for_this_worker, failure_queue + ) + process: multiprocessing.Process = multiprocessing.Process( + target=extractFiles, + args=(command_info, matches, do_extract_files, cur, args, worker), + daemon=True, + ) + process.start() + processes.append(process) + + # While the processes are running, we need to empty the queue. + # Otherwise, it causes hanging. + # No need to join() each of the processes when doing this, + # because we'll be in this loop until completion. + failures: List[FilesRow] = [] + while any(p.is_alive() for p in processes): + while not failure_queue.empty(): + failures.append(failure_queue.get()) + + # Sort the failures, since they can come in at any order. + failures.sort(key=lambda t: (t.name, t.tar, t.offset)) + return failures + + +def prepare_multiprocess( + num_workers: int, matches: List[FilesRow] +) -> Tuple[List[str], List[List[FilesRow]]]: # A dict of tar -> size of files in it. # This is because we're trying to balance the load between # the processes. @@ -357,66 +381,13 @@ def multiprocess_extract( workers_to_matches[workers_idx].append(db_row) tar_ordering: List[str] = sorted([tar for tar in tar_to_size]) - monitor: parallel.PrintMonitor = parallel.PrintMonitor(tar_ordering) - - # The return value for extractFiles will be added here. 
- failure_queue: multiprocessing.Queue[FilesRow] = multiprocessing.Queue() - processes: List[multiprocessing.Process] = [] - for matches in workers_to_matches: - tars_for_this_worker: List[str] = list(set(match.tar for match in matches)) - worker: parallel.ExtractWorker = parallel.ExtractWorker( - monitor, tars_for_this_worker, failure_queue - ) - process: multiprocessing.Process = multiprocessing.Process( - target=extractFiles, - args=(matches, keep_files, keep_tars, cache, cur, args, worker), - daemon=True, - ) - process.start() - processes.append(process) - - # While the processes are running, we need to empty the queue. - # Otherwise, it causes hanging. - # No need to join() each of the processes when doing this, - # because we'll be in this loop until completion. - failures: List[FilesRow] = [] - while any(p.is_alive() for p in processes): - while not failure_queue.empty(): - failures.append(failure_queue.get()) - - # Sort the failures, since they can come in at any order. - failures.sort(key=lambda t: (t.name, t.tar, t.offset)) - return failures + return tar_ordering, workers_to_matches -def check_sizes_match(cur, tfname): - match: bool - if cur and tars_table_exists(cur): - logger.info(f"{tfname} exists. Checking expected size matches actual size.") - actual_size = os.path.getsize(tfname) - name_only = os.path.split(tfname)[1] - cur.execute(f"select size from tars where name is '{name_only}';") - expected_size: int = cur.fetchall()[0][0] - if expected_size != actual_size: - logger.info( - f"{name_only}: expected size={expected_size} != {actual_size}=actual_size" - ) - match = False - else: - # Sizes match - match = True - else: - # Cannot access size information; assume the sizes match. 
- match = True - return match - - -# FIXME: C901 'extractFiles' is too complex (33) -def extractFiles( # noqa: C901 +def extractFiles( + command_info: CommandInfo, files: List[FilesRow], - keep_files: bool, - keep_tars: Optional[bool], - cache: str, + do_extract_files: bool, cur: sqlite3.Cursor, args: argparse.Namespace, multiprocess_worker: Optional[parallel.ExtractWorker] = None, @@ -425,11 +396,11 @@ def extractFiles( # noqa: C901 Given a list of database rows, extract the files from the tar archives to the current location on disk. - If keep_files is False, the files are not extracted. + If do_extract_files is False, the files are not extracted. This is used for when checking if the files in an HPSS repository are valid. - If keep_tars is True, the tar archives that are downloaded are kept, + If command_info.keep is True, the tar archives that are downloaded are kept, even after the program has terminated. Otherwise, they are deleted. If running in parallel, then multiprocess_worker is the Worker @@ -437,20 +408,14 @@ def extractFiles( # noqa: C901 We need a reference to it so we can signal it to print the contents of what's in its print queue. """ + failures: List[FilesRow] = [] tfname: str newtar: bool = True + tar: Optional[tarfile.TarFile] = None nfiles: int = len(files) if multiprocess_worker: - # All messages to the logger will now be sent to - # this queue, instead of sys.stdout. - sh = logging.StreamHandler(multiprocess_worker.print_queue) - sh.setLevel(logging.DEBUG) - formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s") - sh.setFormatter(formatter) - logger.addHandler(sh) - # Don't have the logger print to the console as the message come in. 
- logger.propagate = False + setup_multiprocess_logging(multiprocess_worker) for i in range(nfiles): files_row: FilesRow = files[i] @@ -458,152 +423,25 @@ def extractFiles( # noqa: C901 # Open new tar archive if newtar: newtar = False - tfname = os.path.join(cache, files_row.tar) - # Everytime we're extracting a new tar, if running in parallel, - # let the process know. - # This is to synchronize the print statements. - if multiprocess_worker: - multiprocess_worker.set_curr_tar(files_row.tar) - - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - tries: int = args.retries + 1 - # Set to True to test the `--retries` option with a forced failure. - # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` - test_retry: bool = False - while tries > 0: - tries -= 1 - do_retrieve: bool - - if not os.path.exists(tfname): - do_retrieve = True - else: - do_retrieve = not check_sizes_match(cur, tfname) - - try: - if test_retry: - test_retry = False - raise RuntimeError - if do_retrieve: - hpss_get(hpss, tfname, cache) - if not check_sizes_match(cur, tfname): - raise RuntimeError( - f"{tfname} size does not match expected size." - ) - # `hpss_get` successful or not needed: no more tries needed - break - except RuntimeError as e: - if tries > 0: - logger.info(f"Retrying HPSS get: {tries} tries remaining.") - # Run the try-except block again - continue - else: - raise e - - logger.info("Opening tar archive %s" % (tfname)) - tar: tarfile.TarFile = tarfile.open(tfname, "r") + tfname, tar = open_tar_with_retries( + command_info, files_row, args, cur, multiprocess_worker + ) + elif not tar: + # At first, newtar is True, so tar gets set. + # From then on, if newtar is False, we reuse the tar from the last iteration. + # So, tar is guaranteed to be set. 
+ raise RuntimeError("tar was never set!") # Extract file - cmd: str = "Extracting" if keep_files else "Checking" + cmd: str = "Extracting" if do_extract_files else "Checking" logger.info(cmd + " %s" % (files_row.name)) # if multiprocess_worker: # print('{} is {} {} from {}'.format(multiprocess_worker, cmd, file[1], file[5])) - if keep_files and not should_extract_file(files_row): - # If we were going to extract, but aren't - # because a matching file is on disk - msg: str = "Not extracting {}, because it" - msg += " already exists on disk with the same" - msg += " size and modification date." - logger.info(msg.format(files_row.name)) - - # True if we should actually extract the file from the tar - extract_this_file: bool = keep_files and should_extract_file(files_row) + extract_this_file: bool = should_extract_this_file(do_extract_files, files_row) try: - # Seek file position - if tar.fileobj is not None: - fileobj: _io.BufferedReader = tar.fileobj - else: - raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj)) - fileobj.seek(files_row.offset) - - # Get next member - tarinfo: tarfile.TarInfo = tar.tarinfo.fromtarfile(tar) - - if tarinfo.isfile(): - # fileobj to extract - # error: Name 'tarfile.ExFileObject' is not defined - extracted_file: Optional[tarfile.ExFileObject] = tar.extractfile(tarinfo) # type: ignore - if extracted_file: - fin: tarfile.ExFileObject = extracted_file - else: - raise TypeError("Invalid extracted_file={}".format(extracted_file)) - try: - fname: str = tarinfo.name - path: str - name: str - path, name = os.path.split(fname) - if path != "" and extract_this_file: - if not os.path.isdir(path): - # The path doesn't exist, so create it. 
- os.makedirs(path) - if extract_this_file: - # If we're keeping the files, - # then have an output file - fout: _io.BufferedWriter = open(fname, "wb") - - hash_md5: _hashlib.HASH = hashlib.md5() - while True: - s: bytes = fin.read(BLOCK_SIZE) - if len(s) > 0: - hash_md5.update(s) - if extract_this_file: - fout.write(s) - if len(s) < BLOCK_SIZE: - break - finally: - fin.close() - if extract_this_file: - fout.close() - - md5: str = hash_md5.hexdigest() - if extract_this_file: - # numeric_owner is a required arg in Python 3. - # If True, "only the numbers for user/group names - # are used and not the names". - tar.chown(tarinfo, fname, numeric_owner=False) - tar.chmod(tarinfo, fname) - tar.utime(tarinfo, fname) - # Verify size - if os.path.getsize(fname) != files_row.size: - logger.error("size mismatch for: {}".format(fname)) - - # Verify md5 checksum - files_row_md5: Optional[str] = files_row.md5 - if md5 != files_row_md5: - logger.error("md5 mismatch for: {}".format(fname)) - logger.error("md5 of extracted file: {}".format(md5)) - logger.error("md5 of original file: {}".format(files_row_md5)) - - failures.append(files_row) - else: - logger.debug("Valid md5: {} {}".format(md5, fname)) - - elif extract_this_file: - tar.extract(tarinfo) - # Note: tar.extract() will not restore time stamps of symbolic - # links. Could not find a Python-way to restore it either, so - # relying here on 'touch'. This is not the prettiest solution. - # Maybe a better one can be implemented later. - if tarinfo.issym(): - tmp1: int = tarinfo.mtime - tmp2: datetime = datetime.fromtimestamp(tmp1) - tmp3: str = tmp2.strftime("%Y%m%d%H%M.%S") - os.system("touch -h -t %s %s" % (tmp3, tarinfo.name)) - + extract_file_from_tar(tar, files_row, extract_this_file, failures) except Exception: # Catch all exceptions here. traceback.print_exc() @@ -616,24 +454,11 @@ def extractFiles( # noqa: C901 # Close current archive? 
if i == nfiles - 1 or files[i].tar != files[i + 1].tar: # We're either on the last file or the tar is distinct from the tar of the next file. - - # Close current archive file - logger.debug("Closing tar archive {}".format(tfname)) - tar.close() - - if multiprocess_worker: - multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) - - # Open new archive next time + close_and_cleanup_tar( + tar, tfname, command_info, files_row, multiprocess_worker + ) newtar = True - # Delete this tar if the corresponding command-line arg was used. - if not keep_tars: - if tfname is not None: - os.remove(tfname) - else: - raise TypeError("Invalid tfname={}".format(tfname)) - if multiprocess_worker: # If there are things left to print, print them. multiprocess_worker.print_all_contents() @@ -646,6 +471,110 @@ def extractFiles( # noqa: C901 return failures +def setup_multiprocess_logging(multiprocess_worker: parallel.ExtractWorker): + # All messages to the logger will now be sent to + # this queue, instead of sys.stdout. + sh = logging.StreamHandler(multiprocess_worker.print_queue) + sh.setLevel(logging.DEBUG) + formatter: logging.Formatter = logging.Formatter("%(levelname)s: %(message)s") + sh.setFormatter(formatter) + logger.addHandler(sh) + # Don't have the logger print to the console as the message come in. + logger.propagate = False + + +def open_tar_with_retries( + command_info: CommandInfo, + files_row: FilesRow, + args: argparse.Namespace, + cur: sqlite3.Cursor, + multiprocess_worker: Optional[parallel.ExtractWorker] = None, +) -> Tuple[str, tarfile.TarFile]: + tfname: str = os.path.join(command_info.cache_dir, files_row.tar) + # Everytime we're extracting a new tar, if running in parallel, + # let the process know. + # This is to synchronize the print statements. + if multiprocess_worker: + multiprocess_worker.set_curr_tar(files_row.tar) + + tries: int = args.retries + 1 + # Set to True to test the `--retries` option with a forced failure. 
+ # Then run `python -m unittest tests.test_extract.TestExtract.testExtractRetries` + test_retry: bool = False + while tries > 0: + tries -= 1 + do_retrieve: bool + + if not os.path.exists(tfname): + do_retrieve = True + else: + do_retrieve = not check_sizes_match(cur, tfname) + + try: + if test_retry: + test_retry = False + raise RuntimeError + if do_retrieve: + hpss_get(command_info, tfname) + if not check_sizes_match(cur, tfname): + raise RuntimeError(f"{tfname} size does not match expected size.") + # `hpss_get` successful or not needed: no more tries needed + break + except RuntimeError as e: + if tries > 0: + logger.info(f"Retrying HPSS get: {tries} tries remaining.") + # Run the try-except block again + continue + else: + raise e + + logger.info("Opening tar archive %s" % (tfname)) + tar: tarfile.TarFile = tarfile.open(tfname, "r") + return tfname, tar + + +def check_sizes_match(cur, tfname): + match: bool + if cur and tars_table_exists(cur): + logger.info(f"{tfname} exists. Checking expected size matches actual size.") + actual_size = os.path.getsize(tfname) + name_only = os.path.split(tfname)[1] + cur.execute(f"select size from tars where name is '{name_only}';") + expected_size: int = cur.fetchall()[0][0] + if expected_size != actual_size: + logger.info( + f"{name_only}: expected size={expected_size} != {actual_size}=actual_size" + ) + match = False + else: + # Sizes match + match = True + else: + # Cannot access size information; assume the sizes match. + match = True + return match + + +def should_extract_this_file(do_extract_files: bool, files_row: FilesRow) -> bool: + extract_this_file: bool + if do_extract_files: + if should_extract_file(files_row): + extract_this_file = True + else: + # If we were going to extract, but aren't + # because a matching file is on disk + msg: str = "Not extracting {}, because it" + msg += " already exists on disk with the same" + msg += " size and modification date." 
+ logger.info(msg.format(files_row.name)) + extract_this_file = False + else: + extract_this_file = False + + # True if we should actually extract the file from the tar + return extract_this_file + + def should_extract_file(db_row: FilesRow) -> bool: """ If a file is on disk already with the correct @@ -670,3 +599,117 @@ def should_extract_file(db_row: FilesRow) -> bool: (size_disk == size_db) and (abs(mod_time_disk - mod_time_db).total_seconds() < TIME_TOL) ) + + +def extract_file_from_tar( + tar: tarfile.TarFile, + files_row: FilesRow, + extract_this_file: bool, + failures: List[FilesRow], +): + # Seek file position + if tar.fileobj is not None: + fileobj: _io.BufferedReader = tar.fileobj + else: + raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj)) + fileobj.seek(files_row.offset) + + # Get next member + tarinfo: tarfile.TarInfo = tar.tarinfo.fromtarfile(tar) + + if tarinfo.isfile(): + # fileobj to extract + # error: Name 'tarfile.ExFileObject' is not defined + extracted_file: Optional[tarfile.ExFileObject] = tar.extractfile(tarinfo) # type: ignore + if extracted_file: + fin: tarfile.ExFileObject = extracted_file + else: + raise TypeError("Invalid extracted_file={}".format(extracted_file)) + try: + fname: str = tarinfo.name + path: str + name: str + path, name = os.path.split(fname) + if path != "" and extract_this_file: + if not os.path.isdir(path): + # The path doesn't exist, so create it. + os.makedirs(path) + if extract_this_file: + # If we're keeping the files, + # then have an output file + fout: _io.BufferedWriter = open(fname, "wb") + + hash_md5: _hashlib.HASH = hashlib.md5() + while True: + s: bytes = fin.read(BLOCK_SIZE) + if len(s) > 0: + hash_md5.update(s) + if extract_this_file: + fout.write(s) + if len(s) < BLOCK_SIZE: + break + finally: + fin.close() + if extract_this_file: + fout.close() + + md5: str = hash_md5.hexdigest() + if extract_this_file: + # numeric_owner is a required arg in Python 3. 
+ # If True, "only the numbers for user/group names + # are used and not the names". + tar.chown(tarinfo, fname, numeric_owner=False) + tar.chmod(tarinfo, fname) + tar.utime(tarinfo, fname) + # Verify size + if os.path.getsize(fname) != files_row.size: + logger.error("size mismatch for: {}".format(fname)) + + # Verify md5 checksum + files_row_md5: Optional[str] = files_row.md5 + if md5 != files_row_md5: + logger.error("md5 mismatch for: {}".format(fname)) + logger.error("md5 of extracted file: {}".format(md5)) + logger.error("md5 of original file: {}".format(files_row_md5)) + + failures.append(files_row) + else: + logger.debug("Valid md5: {} {}".format(md5, fname)) + + elif extract_this_file: + tar.extract(tarinfo) + # Note: tar.extract() will not restore time stamps of symbolic + # links. Could not find a Python-way to restore it either, so + # relying here on 'touch'. This is not the prettiest solution. + # Maybe a better one can be implemented later. + if tarinfo.issym(): + tmp1: int = tarinfo.mtime + tmp2: datetime = datetime.fromtimestamp(tmp1) + tmp3: str = tmp2.strftime("%Y%m%d%H%M.%S") + os.system("touch -h -t %s %s" % (tmp3, tarinfo.name)) + + +def close_and_cleanup_tar( + tar: tarfile.TarFile, + tfname: str, + command_info: CommandInfo, + files_row: FilesRow, + multiprocess_worker: Optional[parallel.ExtractWorker] = None, +): + # Close current archive file + logger.debug("Closing tar archive {}".format(tfname)) + tar.close() + + if multiprocess_worker: + multiprocess_worker.done_enqueuing_output_for_tar(files_row.tar) + + # Delete this tar if the corresponding command-line arg was used. 
+ logger.debug(f"hpss type={command_info.hpss_type}") + if not command_info.keep: + if tfname is not None: + logger.debug(f"Removing tar archive {tfname}") + os.remove(tfname) + else: + raise TypeError(f"Invalid tfname={tfname}") + else: + logger.debug(f"Keeping tar archive {tfname}") diff --git a/zstash/globus.py b/zstash/globus.py index 84c12cdf..d43dca85 100644 --- a/zstash/globus.py +++ b/zstash/globus.py @@ -6,22 +6,26 @@ import re import socket import sys +from typing import Dict, List, Optional, Tuple +from urllib.parse import urlparse from fair_research_login.client import NativeClient from globus_sdk import TransferAPIError, TransferClient, TransferData -from globus_sdk.services.transfer.response.iterable import IterableTransferResponse -from six.moves.urllib.parse import urlparse from .settings import logger -from .utils import ts_utc +from .utils import GlobusInfo, ts_utc -hpss_endpoint_map = { +# Constants ################################################################### + +ZSTASH_CLIENT_ID: str = "6c1629cf-446c-49e7-af95-323c6412397f" + +HPSS_ENDPOINT_MAP: Dict[str, str] = { "ALCF": "de463ec4-6d04-11e5-ba46-22000b92c6ec", "NERSC": "9cd89cfd-6d04-11e5-ba46-22000b92c6ec", } # This is used if the `globus_endpoint_uuid` is not set in `~/.zstash.ini` -regex_endpoint_map = { +REGEX_ENDPOINT_MAP: Dict[str, str] = { r"theta.*\.alcf\.anl\.gov": "08925f04-569f-11e7-bef8-22000b9a448b", r"blueslogin.*\.lcrc\.anl\.gov": "15288284-7006-4041-ba1a-6b52501e49f1", r"chrlogin.*\.lcrc\.anl\.gov": "15288284-7006-4041-ba1a-6b52501e49f1", @@ -31,16 +35,79 @@ r"perlmutter.*\.nersc\.gov": "6bdc7956-fc0f-4ad2-989c-7aa5ee643a79", } -remote_endpoint = None -local_endpoint = None -transfer_client: TransferClient = None -transfer_data: TransferData = None -task_id = None -archive_directory_listing: IterableTransferResponse = None +# Last updated 2025-04-08 +ENDPOINT_TO_NAME_MAP: Dict[str, str] = { + "08925f04-569f-11e7-bef8-22000b9a448b": "Invalid, presumably Theta", + 
"15288284-7006-4041-ba1a-6b52501e49f1": "LCRC Improv DTN", + "68fbd2fa-83d7-11e9-8e63-029d279f7e24": "pic#compty-dtn", + "6bdc7956-fc0f-4ad2-989c-7aa5ee643a79": "NERSC Perlmutter", + "6c54cade-bde5-45c1-bdea-f4bd71dba2cc": "Globus Tutorial Collection 1", # The Unit test endpoint + "9cd89cfd-6d04-11e5-ba46-22000b92c6ec": "NERSC HPSS", + "de463ec4-6d04-11e5-ba46-22000b92c6ec": "Invalid, presumably ALCF HPSS", +} + +# Helper functions ############################################################ + + +def ep_to_name(endpoint_id: str) -> str: + if endpoint_id in ENDPOINT_TO_NAME_MAP: + return ENDPOINT_TO_NAME_MAP[endpoint_id] + else: + return endpoint_id # Just use the endpoint_id itself + + +def log_current_endpoints(globus_info: GlobusInfo): + local: str + remote: str + if globus_info.local_endpoint: + local = ep_to_name(globus_info.local_endpoint) + else: + local = "undefined" + logger.debug(f"local endpoint={local}") + if globus_info.remote_endpoint: + remote = ep_to_name(globus_info.remote_endpoint) + else: + remote = "undefined" + logger.debug(f"remote endpoint={remote}") + + +def get_all_endpoint_scopes(endpoints: List[str]) -> str: + inner = " ".join( + [f"*https://auth.globus.org/scopes/{ep}/data_access" for ep in endpoints] + ) + return f"urn:globus:auth:scope:transfer.api.globus.org:all[{inner}]" + + +def set_clients(globus_info: GlobusInfo): + native_client = NativeClient( + client_id=ZSTASH_CLIENT_ID, + app_name="Zstash", + default_scopes="openid urn:globus:auth:scope:transfer.api.globus.org:all", + ) + log_current_endpoints(globus_info) + logger.debug( + "set_clients. 
Calling login, which may print 'Please Paste your Auth Code Below:'" + ) + if globus_info.local_endpoint and globus_info.remote_endpoint: + all_scopes: str = get_all_endpoint_scopes( + [globus_info.local_endpoint, globus_info.remote_endpoint] + ) + native_client.login( + requested_scopes=all_scopes, no_local_server=True, refresh_tokens=True + ) + else: + native_client.login(no_local_server=True, refresh_tokens=True) + transfer_authorizer = native_client.get_authorizers().get("transfer.api.globus.org") + globus_info.transfer_client = TransferClient(authorizer=transfer_authorizer) -def check_endpoint_version_5(ep_id): - output = transfer_client.get_endpoint(ep_id) +# Used exclusively by check_consents +def check_endpoint_version_5(globus_info: GlobusInfo, ep_id: str) -> bool: + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + log_current_endpoints(globus_info) + logger.debug(f"check_endpoint_version_5. endpoint={ep_to_name(ep_id)}") + output = globus_info.transfer_client.get_endpoint(ep_id) version = output.get("gcs_version", "0.0") if output["gcs_version"] is None: return False @@ -49,50 +116,33 @@ def check_endpoint_version_5(ep_id): return False -def submit_transfer_with_checks(transfer_data): - try: - task = transfer_client.submit_transfer(transfer_data) - except TransferAPIError as err: - if err.info.consent_required: - scopes = "urn:globus:auth:scope:transfer.api.globus.org:all[" - for ep_id in [remote_endpoint, local_endpoint]: - if check_endpoint_version_5(ep_id): - scopes += f" *https://auth.globus.org/scopes/{ep_id}/data_access" - scopes += " ]" - native_client = NativeClient( - client_id="6c1629cf-446c-49e7-af95-323c6412397f", app_name="Zstash" - ) - native_client.login(requested_scopes=scopes) - # Quit here and tell user to re-try - print( - "Consents added, please re-run the previous command to start transfer" - ) - sys.exit(0) - else: - raise err - return task - - -def globus_activate(hpss: str): - """ - Read the 
local globus endpoint UUID from ~/.zstash.ini. - If the ini file does not exist, create an ini file with empty values, - and try to find the local endpoint UUID based on the FQDN - """ - global transfer_client - global local_endpoint - global remote_endpoint +# Used exclusively by submit_transfer_with_checks, exclusively when there is a TransferAPIError +# This function is really to diagnose an error: are the consents ok? +# That is, we don't *need* to check consents or endpoint versions if everything worked out fine. +def check_consents(globus_info: GlobusInfo): + scopes = "urn:globus:auth:scope:transfer.api.globus.org:all[" + for ep_id in [globus_info.remote_endpoint, globus_info.local_endpoint]: + if ep_id and check_endpoint_version_5(globus_info, ep_id): + scopes += f" *https://auth.globus.org/scopes/{ep_id}/data_access" + scopes += " ]" + native_client = NativeClient(client_id=ZSTASH_CLIENT_ID, app_name="Zstash") + log_current_endpoints(globus_info) + logger.debug( + "check_consents. 
Calling login, which may print 'Please Paste your Auth Code Below:'" + ) + native_client.login(requested_scopes=scopes) - url = urlparse(hpss) - if url.scheme != "globus": - return - remote_endpoint = url.netloc +# Used exclusively in globus_activate +def set_local_endpoint(globus_info: GlobusInfo): ini_path = os.path.expanduser("~/.zstash.ini") ini = configparser.ConfigParser() if ini.read(ini_path): if "local" in ini.sections(): - local_endpoint = ini["local"].get("globus_endpoint_uuid") + globus_info.local_endpoint = ini["local"].get("globus_endpoint_uuid") + logger.debug( + f"globus endpoint in ~/.zstash.ini: {ep_to_name(globus_info.local_endpoint)}" + ) else: ini["local"] = {"globus_endpoint_uuid": ""} try: @@ -101,219 +151,128 @@ def globus_activate(hpss: str): except Exception as e: logger.error(e) sys.exit(1) - if not local_endpoint: + if not globus_info.local_endpoint: fqdn = socket.getfqdn() if re.fullmatch(r"n.*\.local", fqdn) and os.getenv("HOSTNAME", "NA").startswith( "compy" ): fqdn = "compy.pnl.gov" - for pattern in regex_endpoint_map.keys(): + for pattern in REGEX_ENDPOINT_MAP.keys(): if re.fullmatch(pattern, fqdn): - local_endpoint = regex_endpoint_map.get(pattern) + globus_info.local_endpoint = REGEX_ENDPOINT_MAP.get(pattern) break # FQDN is not set on Perlmutter at NERSC - if not local_endpoint: + if not globus_info.local_endpoint: nersc_hostname = os.environ.get("NERSC_HOST") if nersc_hostname and ( nersc_hostname == "perlmutter" or nersc_hostname == "unknown" ): - local_endpoint = regex_endpoint_map.get(r"perlmutter.*\.nersc\.gov") - if not local_endpoint: - logger.error( - "{} does not have the local Globus endpoint set nor could one be found in regex_endpoint_map.".format( - ini_path + globus_info.local_endpoint = REGEX_ENDPOINT_MAP.get( + r"perlmutter.*\.nersc\.gov" ) + if not globus_info.local_endpoint: + logger.error( + f"{ini_path} does not have the local Globus endpoint set nor could one be found in REGEX_ENDPOINT_MAP." 
) sys.exit(1) - if remote_endpoint.upper() in hpss_endpoint_map.keys(): - remote_endpoint = hpss_endpoint_map.get(remote_endpoint.upper()) - native_client = NativeClient( - client_id="6c1629cf-446c-49e7-af95-323c6412397f", - app_name="Zstash", - default_scopes="openid urn:globus:auth:scope:transfer.api.globus.org:all", - ) - native_client.login(no_local_server=True, refresh_tokens=True) - transfer_authorizer = native_client.get_authorizers().get("transfer.api.globus.org") - transfer_client = TransferClient(authorizer=transfer_authorizer) - - for ep_id in [local_endpoint, remote_endpoint]: - r = transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600) - if r.get("code") == "AutoActivationFailed": - logger.error( - "The {} endpoint is not activated or the current activation expires soon. Please go to https://app.globus.org/file-manager/collections/{} and (re)activate the endpoint.".format( - ep_id, ep_id - ) - ) - sys.exit(1) - - -def file_exists(name: str) -> bool: - global archive_directory_listing - - for entry in archive_directory_listing: +# Used exclusively in globus_transfer +def file_exists(globus_info: GlobusInfo, name: str) -> bool: + if not globus_info.archive_directory_listing: + raise ValueError("archive_directory_listing is undefined") + for entry in globus_info.archive_directory_listing: if entry.get("name") == name: return True return False -global_variable_tarfiles_pushed = 0 - - -# C901 'globus_transfer' is too complex (20) -def globus_transfer( # noqa: C901 - remote_ep: str, remote_path: str, name: str, transfer_type: str, non_blocking: bool -): - global transfer_client - global local_endpoint - global remote_endpoint - global transfer_data - global task_id - global archive_directory_listing - global global_variable_tarfiles_pushed - - logger.info(f"{ts_utc()}: Entered globus_transfer() for name = {name}") - logger.debug(f"{ts_utc()}: non_blocking = {non_blocking}") - if not transfer_client: - globus_activate("globus://" + remote_ep) - if 
not transfer_client: - sys.exit(1) - - if transfer_type == "get": - if not archive_directory_listing: - archive_directory_listing = transfer_client.operation_ls( - remote_endpoint, remote_path - ) - if not file_exists(name): - logger.error( - "Remote file globus://{}{}/{} does not exist".format( - remote_ep, remote_path, name - ) - ) - sys.exit(1) - +# Used exclusively in globus_transfer +def compute_transfer_paths( + globus_info: GlobusInfo, remote_path: str, name: str, transfer_type: str +) -> Tuple[str, str, str, str]: + if not (globus_info.local_endpoint and globus_info.remote_endpoint): + raise ValueError( + f"Undefined value: local_endpoint={globus_info.local_endpoint} & remote_endpoint={globus_info.remote_endpoint}" + ) if transfer_type == "get": - src_ep = remote_endpoint + src_ep = globus_info.remote_endpoint src_path = os.path.join(remote_path, name) - dst_ep = local_endpoint + dst_ep = globus_info.local_endpoint dst_path = os.path.join(os.getcwd(), name) - else: - src_ep = local_endpoint + elif transfer_type == "put": + src_ep = globus_info.local_endpoint src_path = os.path.join(os.getcwd(), name) - dst_ep = remote_endpoint + dst_ep = globus_info.remote_endpoint dst_path = os.path.join(remote_path, name) - - subdir = os.path.basename(os.path.normpath(remote_path)) - subdir_label = re.sub("[^A-Za-z0-9_ -]", "", subdir) - filename = name.split(".")[0] - label = subdir_label + " " + filename - - if not transfer_data: - transfer_data = TransferData( - transfer_client, - src_ep, - dst_ep, - label=label, - verify_checksum=True, - preserve_timestamp=True, - fail_on_quota_errors=True, + else: + raise ValueError(f"Invalid transfer_type: {transfer_type}") + return src_ep, src_path, dst_ep, dst_path + + +# Used exclusively in globus_transfer +def build_transfer_label(remote_path: str, name: str) -> str: + subdir: str = os.path.basename(os.path.normpath(remote_path)) + subdir_label: str = re.sub("[^A-Za-z0-9_ -]", "", subdir) + filename: str = name.split(".")[0] + 
label: str = subdir_label + " " + filename + return label + + +# Used exclusively in globus_transfer +def handle_previous_task(globus_info: GlobusInfo) -> Optional[str]: + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + task = globus_info.transfer_client.get_task(globus_info.task_id) + prev_task_status = task["status"] + # one of {ACTIVE, SUCCEEDED, FAILED, CANCELED, PENDING, INACTIVE} + # NOTE: How we behave here depends upon whether we want to support mutliple active transfers. + # Presently, we do not, except inadvertantly (if status == PENDING) + if prev_task_status == "ACTIVE": + logger.info( + f"{ts_utc()}: Previous task_id {globus_info.task_id} Still Active. Returning ACTIVE." ) - transfer_data.add_item(src_path, dst_path) - transfer_data["label"] = label - try: - if task_id: - task = transfer_client.get_task(task_id) - prev_task_status = task["status"] - # one of {ACTIVE, SUCCEEDED, FAILED, CANCELED, PENDING, INACTIVE} - # NOTE: How we behave here depends upon whether we want to support mutliple active transfers. - # Presently, we do not, except inadvertantly (if status == PENDING) - if prev_task_status == "ACTIVE": - logger.info( - f"{ts_utc()}: Previous task_id {task_id} Still Active. Returning ACTIVE." - ) - return "ACTIVE" - elif prev_task_status == "SUCCEEDED": - logger.info( - f"{ts_utc()}: Previous task_id {task_id} status = SUCCEEDED." - ) - src_ep = task["source_endpoint_id"] - dst_ep = task["destination_endpoint_id"] - label = task["label"] - ts = ts_utc() - logger.info( - "{}:Globus transfer {}, from {} to {}: {} succeeded".format( - ts, task_id, src_ep, dst_ep, label - ) - ) - else: - logger.error( - f"{ts_utc()}: Previous task_id {task_id} status = {prev_task_status}." 
- ) - - # DEBUG: review accumulated items in TransferData - logger.info(f"{ts_utc()}: TransferData: accumulated items:") - attribs = transfer_data.__dict__ - for item in attribs["data"]["DATA"]: - if item["DATA_TYPE"] == "transfer_item": - global_variable_tarfiles_pushed += 1 - print( - f" (routine) PUSHING (#{global_variable_tarfiles_pushed}) STORED source item: {item['source_path']}", - flush=True, - ) - - # SUBMIT new transfer here - logger.info(f"{ts_utc()}: DIVING: Submit Transfer for {transfer_data['label']}") - task = submit_transfer_with_checks(transfer_data) - task_id = task.get("task_id") - # NOTE: This log message is misleading. If we have accumulated multiple tar files for transfer, - # the "lable" given here refers only to the LAST tarfile in the TransferData list. + return "ACTIVE" + elif prev_task_status == "SUCCEEDED": logger.info( - f"{ts_utc()}: SURFACE Submit Transfer returned new task_id = {task_id} for label {transfer_data['label']}" + f"{ts_utc()}: Previous task_id {globus_info.task_id} status = SUCCEEDED." ) - - # Nullify the submitted transfer data structure so that a new one will be created on next call. - transfer_data = None - except TransferAPIError as e: - if e.code == "NoCredException": - logger.error( - "{}. 
Please go to https://app.globus.org/endpoints and activate the endpoint.".format( - e.message - ) - ) - else: - logger.error(e) - sys.exit(1) - except Exception as e: - logger.error("Exception: {}".format(e)) - sys.exit(1) - - # test for blocking on new task_id - task_status = "UNKNOWN" - if not non_blocking: - task_status = globus_block_wait( - task_id=task_id, wait_timeout=7200, polling_interval=10, max_retries=5 + src_ep = task["source_endpoint_id"] + dst_ep = task["destination_endpoint_id"] + label = task["label"] + ts = ts_utc() + logger.info( + f"{ts}:Globus transfer {globus_info.task_id}, from {src_ep} to {dst_ep}: {label} succeeded" ) else: - logger.info(f"{ts_utc()}: NO BLOCKING (task_wait) for task_id {task_id}") - - if transfer_type == "put": - return task_status + logger.error( + f"{ts_utc()}: Previous task_id {globus_info.task_id} status = {prev_task_status}." + ) + return None - if transfer_type == "get" and task_id: - globus_wait(task_id) - return task_status +# Used exclusively in globus_transfer +def log_accumulated_transfer_items(globus_info: GlobusInfo): + # DEBUG: review accumulated items in TransferData + logger.info(f"{ts_utc()}: TransferData: accumulated items:") + attribs = globus_info.transfer_data.__dict__ + for item in attribs["data"]["DATA"]: + if item["DATA_TYPE"] == "transfer_item": + globus_info.tarfiles_pushed += 1 + print( + f" (routine) PUSHING (#{globus_info.tarfiles_pushed}) STORED source item: {item['source_path']}", + flush=True, + ) +# Used exclusively in globus_transfer def globus_block_wait( - task_id: str, wait_timeout: int, polling_interval: int, max_retries: int + globus_info: GlobusInfo, wait_timeout: int, polling_interval: int, max_retries: int ): - global transfer_client # poll every "polling_interval" seconds to speed up small transfers. 
Report every 2 hours, stop waiting aftert 5*2 = 10 hours logger.info( - f"{ts_utc()}: BLOCKING START: invoking task_wait for task_id = {task_id}" + f"{ts_utc()}: BLOCKING START: invoking task_wait for task_id = {globus_info.task_id}" ) task_status = "UNKNOWN" retry_count = 0 @@ -323,14 +282,18 @@ def globus_block_wait( logger.info( f"{ts_utc()}: on task_wait try {retry_count+1} out of {max_retries}" ) - transfer_client.task_wait( - task_id, timeout=wait_timeout, polling_interval=10 + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + globus_info.transfer_client.task_wait( + globus_info.task_id, timeout=wait_timeout, polling_interval=10 ) logger.info(f"{ts_utc()}: done with wait") except Exception as e: logger.error(f"Unexpected Exception: {e}") else: - curr_task = transfer_client.get_task(task_id) + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + curr_task = globus_info.transfer_client.get_task(globus_info.task_id) task_status = curr_task["status"] if task_status == "SUCCEEDED": break @@ -347,15 +310,39 @@ def globus_block_wait( task_status = "EXHAUSTED_TIMEOUT_RETRIES" logger.info( - f"{ts_utc()}: BLOCKING ENDS: task_id {task_id} returned from task_wait with status {task_status}" + f"{ts_utc()}: BLOCKING ENDS: task_id {globus_info.task_id} returned from task_wait with status {task_status}" ) return task_status -def globus_wait(task_id: str): - global transfer_client +# Used exclusively in globus_transfer, globus_finalize +def submit_transfer_with_checks(globus_info: GlobusInfo): + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + try: + task = globus_info.transfer_client.submit_transfer(globus_info.transfer_data) + except TransferAPIError as err: + if err.info.consent_required: + check_consents(globus_info) + # Quit here and tell user to re-try + print( + "Consents added, please re-run the previous command to start transfer" + ) + sys.exit(0) + 
else: + if err.info.authorization_parameters: + print("Error is in authorization parameters") + raise err + return task + +# Used exclusively in globus_transfer, globus_finalize +def globus_wait(globus_info: GlobusInfo, alternative_task_id=None): + if alternative_task_id: + task_id = alternative_task_id + else: + task_id = globus_info.task_id try: """ A Globus transfer job (task) can be in one of the three states: @@ -364,80 +351,219 @@ def globus_wait(task_id: str): with 20 second timeout limit. If the task is ACTIVE after time runs out 'task_wait' returns False, and True otherwise. """ - while not transfer_client.task_wait(task_id, timeout=300, polling_interval=20): + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + while not globus_info.transfer_client.task_wait( + task_id, timeout=300, polling_interval=20 + ): pass """ The Globus transfer job (task) has been finished (SUCCEEDED or FAILED). Check if the transfer SUCCEEDED or FAILED. """ - task = transfer_client.get_task(task_id) + if not globus_info.transfer_client: + raise ValueError("transfer_client is undefined") + task = globus_info.transfer_client.get_task(task_id) if task["status"] == "SUCCEEDED": src_ep = task["source_endpoint_id"] dst_ep = task["destination_endpoint_id"] label = task["label"] logger.info( - "Globus transfer {}, from {} to {}: {} succeeded".format( - task_id, src_ep, dst_ep, label - ) + f"Globus transfer {task_id}, from {src_ep} to {dst_ep}: {label} succeeded" ) else: logger.error("Transfer FAILED") except TransferAPIError as e: if e.code == "NoCredException": logger.error( - "{}. Please go to https://app.globus.org/endpoints and activate the endpoint.".format( - e.message + f"{e.message}. Please go to https://app.globus.org/endpoints and activate the endpoint." 
+ ) + else: + logger.error(e) + sys.exit(1) + except Exception as e: + logger.error(f"Exception: {e}") + sys.exit(1) + + +# Primary functions ########################################################### + + +def globus_activate(globus_info: GlobusInfo, alt_hpss: str = ""): + """ + Read the local globus endpoint UUID from ~/.zstash.ini. + If the ini file does not exist, create an ini file with empty values, + and try to find the local endpoint UUID based on the FQDN + """ + if alt_hpss != "": + globus_info.hpss_path = alt_hpss + globus_info.url = urlparse(alt_hpss) + if globus_info.url.scheme != "globus": + raise ValueError(f"Invalid url.scheme={globus_info.url.scheme}") + globus_info.remote_endpoint = globus_info.url.netloc + else: + globus_info.remote_endpoint = globus_info.url.netloc + set_local_endpoint(globus_info) + if globus_info.remote_endpoint.upper() in HPSS_ENDPOINT_MAP.keys(): + globus_info.remote_endpoint = HPSS_ENDPOINT_MAP.get( + globus_info.remote_endpoint.upper() + ) + log_current_endpoints(globus_info) + set_clients(globus_info) + log_current_endpoints(globus_info) + for ep_id in [globus_info.local_endpoint, globus_info.remote_endpoint]: + if ep_id: + ep_name = ep_to_name(ep_id) + else: + ep_name = "undefined" + logger.debug(f"globus_activate. endpoint={ep_name}") + if not globus_info.transfer_client: + raise ValueError("Was unable to instantiate transfer_client") + r = globus_info.transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600) + if r.get("code") == "AutoActivationFailed": + logger.error( + f"The {ep_id} endpoint is not activated or the current activation expires soon. Please go to https://app.globus.org/file-manager/collections/{ep_id} and (re)activate the endpoint." 
+ ) + sys.exit(1) + + +def globus_transfer( + globus_info: GlobusInfo, + remote_ep: str, + remote_path: str, + name: str, + transfer_type: str, + non_blocking: bool, +): + + logger.info(f"{ts_utc()}: Entered globus_transfer() for name = {name}") + logger.debug(f"{ts_utc()}: non_blocking = {non_blocking}") + if not globus_info.transfer_client: + globus_activate(globus_info, "globus://" + remote_ep) + # Try again: + if not globus_info.transfer_client: + logger.info(f"{ts_utc()}: Could not instantiate transfer client.") + sys.exit(1) + + if transfer_type == "get": + if not globus_info.archive_directory_listing: + globus_info.archive_directory_listing = ( + globus_info.transfer_client.operation_ls( + globus_info.remote_endpoint, remote_path ) ) + if not file_exists(globus_info, name): + logger.error( + f"Remote file globus://{remote_ep}{remote_path}/{name} does not exist" + ) + sys.exit(1) + src_ep: str + src_path: str + dst_ep: str + dst_path: str + src_ep, src_path, dst_ep, dst_path = compute_transfer_paths( + globus_info, remote_path, name, transfer_type + ) + label: str = build_transfer_label(remote_path, name) + if not globus_info.transfer_data: + globus_info.transfer_data = TransferData( + globus_info.transfer_client, + src_ep, + dst_ep, + label=label, + verify_checksum=True, + preserve_timestamp=True, + fail_on_quota_errors=True, + ) + globus_info.transfer_data.add_item(src_path, dst_path) + globus_info.transfer_data["label"] = label + try: + if globus_info.task_id: + relevant_prev_task_status: Optional[str] = handle_previous_task(globus_info) + if relevant_prev_task_status: + return relevant_prev_task_status + log_accumulated_transfer_items(globus_info) + # SUBMIT new transfer here + logger.info( + f"{ts_utc()}: DIVING: Submit Transfer for {globus_info.transfer_data['label']}" + ) + task = submit_transfer_with_checks(globus_info) + globus_info.task_id = task.get("task_id") + # NOTE: This log message is misleading. 
If we have accumulated multiple tar files for transfer, + # the "label" given here refers only to the LAST tarfile in the TransferData list. + logger.info( + f"{ts_utc()}: SURFACE Submit Transfer returned new task_id = {globus_info.task_id} for label {globus_info.transfer_data['label']}" + ) + # Nullify the submitted transfer data structure so that a new one will be created on next call. + globus_info.transfer_data = None + except TransferAPIError as e: + if e.code == "NoCredException": + logger.error( + f"{e.message}. Please go to https://app.globus.org/endpoints and activate the endpoint." + ) else: logger.error(e) sys.exit(1) except Exception as e: - logger.error("Exception: {}".format(e)) + logger.error(f"Exception: {e}") sys.exit(1) + # test for blocking on new task_id + task_status = "UNKNOWN" + if not non_blocking: + task_status = globus_block_wait( + globus_info, wait_timeout=7200, polling_interval=10, max_retries=5 + ) + else: + logger.info( + f"{ts_utc()}: NO BLOCKING (task_wait) for task_id {globus_info.task_id}" + ) -def globus_finalize(non_blocking: bool = False): - global transfer_client - global transfer_data - global task_id - global global_variable_tarfiles_pushed + if transfer_type == "put": + return task_status + if transfer_type == "get" and globus_info.task_id: + globus_wait(globus_info) + + return task_status + + +def globus_finalize(globus_info: GlobusInfo, non_blocking: bool = False): last_task_id = None - if transfer_data: + if globus_info.transfer_data: # DEBUG: review accumulated items in TransferData logger.info(f"{ts_utc()}: FINAL TransferData: accumulated items:") - attribs = transfer_data.__dict__ + attribs = globus_info.transfer_data.__dict__ for item in attribs["data"]["DATA"]: if item["DATA_TYPE"] == "transfer_item": - global_variable_tarfiles_pushed += 1 + globus_info.tarfiles_pushed += 1 print( - f" (finalize) PUSHING ({global_variable_tarfiles_pushed}) source item: {item['source_path']}", + f" (finalize) PUSHING 
({globus_info.tarfiles_pushed}) source item: {item['source_path']}", flush=True, ) # SUBMIT new transfer here - logger.info(f"{ts_utc()}: DIVING: Submit Transfer for {transfer_data['label']}") + logger.info( + f"{ts_utc()}: DIVING: Submit Transfer for {globus_info.transfer_data['label']}" + ) try: - last_task = submit_transfer_with_checks(transfer_data) + last_task = submit_transfer_with_checks(globus_info) last_task_id = last_task.get("task_id") except TransferAPIError as e: if e.code == "NoCredException": logger.error( - "{}. Please go to https://app.globus.org/endpoints and activate the endpoint.".format( - e.message - ) + f"{e.message}. Please go to https://app.globus.org/endpoints and activate the endpoint." ) else: logger.error(e) sys.exit(1) except Exception as e: - logger.error("Exception: {}".format(e)) + logger.error(f"Exception: {e}") sys.exit(1) if not non_blocking: - if task_id: - globus_wait(task_id) + if globus_info.task_id: + globus_wait(globus_info) if last_task_id: - globus_wait(last_task_id) + globus_wait(globus_info, last_task_id) diff --git a/zstash/hpss.py b/zstash/hpss.py index 24603388..c5424a10 100644 --- a/zstash/hpss.py +++ b/zstash/hpss.py @@ -7,40 +7,31 @@ from six.moves.urllib.parse import urlparse from .globus import globus_transfer -from .settings import get_db_filename, logger -from .utils import run_command, ts_utc - -prev_transfers: List[str] = list() -curr_transfers: List[str] = list() +from .settings import logger +from .utils import CommandInfo, HPSSType, run_command, ts_utc def hpss_transfer( - hpss: str, + command_info: CommandInfo, file_path: str, transfer_type: str, - cache: str, - keep: bool = False, non_blocking: bool = False, - is_index: bool = False, ): - global prev_transfers - global curr_transfers logger.info( - f"{ts_utc()}: in hpss_transfer, prev_transfers is starting as {prev_transfers}" + f"{ts_utc()}: in hpss_transfer, prev_transfers is starting as {command_info.prev_transfers}" ) + # TODO: Expected output for 
tests needs to be changed if we uncomment this: # logger.debug( - # f"{ts_utc()}: in hpss_transfer, curr_transfers is starting as {curr_transfers}" + # f"{ts_utc()}: in hpss_transfer, curr_transfers is starting as {command_info.curr_transfers}" # ) - if hpss == "none": - logger.info("{}: HPSS is unavailable".format(transfer_type)) - if transfer_type == "put" and file_path != get_db_filename(cache): - # We are adding a file (that is not the cache) to the local non-HPSS archive + if command_info.hpss_type == HPSSType.NO_HPSS: + logger.info(f"{transfer_type}: HPSS is unavailable") + if transfer_type == "put" and file_path != command_info.get_db_name(): + # We are adding a file (that is NOT the database) to the local non-HPSS archive logger.info( - "{}: Keeping tar files locally and removing write permissions".format( - transfer_type - ) + f"{transfer_type}: Keeping tar files locally and removing write permissions" ) # https://unix.stackexchange.com/questions/46915/get-the-chmod-numerical-value-for-a-file display_mode_command: List[str] = "stat --format '%a' {}".format( @@ -74,21 +65,19 @@ def hpss_transfer( transfer_command = "get" else: raise ValueError("Invalid transfer_type={}".format(transfer_type)) - logger.info("Transferring file {} HPSS: {}".format(transfer_word, file_path)) - scheme: str - endpoint: str - path: str - name: str + logger.info(f"Transferring file {transfer_word} HPSS: {file_path}") - url = urlparse(hpss) - scheme = url.scheme - endpoint = url.netloc - url_path = url.path + url = urlparse(command_info.config.hpss) + endpoint: str = str(url.netloc) + url_path: str = str(url.path) - curr_transfers.append(file_path) + command_info.curr_transfers.append(file_path) + # TODO: Expected output for tests needs to be changed if we uncomment this: # logger.debug( - # f"{ts_utc()}: curr_transfers has been appended to, is now {curr_transfers}" + # f"{ts_utc()}: curr_transfers has been appended to, is now {command_info.curr_transfers}" # ) + path: str + name: 
str path, name = os.path.split(file_path) # Need to be in local directory for `hsi` to work @@ -104,12 +93,20 @@ def hpss_transfer( # For `get`, this directory is where the file we get from HPSS will go. os.chdir(path) - if scheme == "globus": + globus_status = None + if command_info.hpss_type == HPSSType.GLOBUS: globus_status = "UNKNOWN" # Transfer file using the Globus Transfer Service logger.info(f"{ts_utc()}: DIVING: hpss calls globus_transfer(name={name})") + if not command_info.globus_info: + raise ValueError("globus_info is undefined") globus_status = globus_transfer( - endpoint, url_path, name, transfer_type, non_blocking + command_info.globus_info, + endpoint, + url_path, + name, + transfer_type, + non_blocking, ) logger.info( f"{ts_utc()}: SURFACE hpss globus_transfer(name={name}) returns {globus_status}" @@ -120,8 +117,10 @@ def hpss_transfer( # return a tuple (task_id, status). else: # Transfer file using `hsi` - command: str = 'hsi -q "cd {}; {} {}"'.format(hpss, transfer_command, name) - error_str: str = "Transferring file {} HPSS: {}".format(transfer_word, name) + command: str = ( + f'hsi -q "cd {command_info.config.hpss}; {transfer_command} {name}"' + ) + error_str: str = f"Transferring file {transfer_word} HPSS: {name}" run_command(command, error_str) # Return to original working directory @@ -129,41 +128,40 @@ def hpss_transfer( os.chdir(cwd) if transfer_type == "put": - if not keep: - if (scheme != "globus") or (globus_status == "SUCCEEDED"): + if not command_info.keep: + if (command_info.hpss_type != HPSSType.GLOBUS) or ( + globus_status == "SUCCEEDED" + ): # Note: This is intended to fulfill the default removal of successfully-transfered # tar files when keep=False, irrespective of non-blocking status logger.debug( - f"{ts_utc()}: deleting transfered files {prev_transfers}" + f"{ts_utc()}: deleting transfered files {command_info.prev_transfers}" ) - for src_path in prev_transfers: + for src_path in command_info.prev_transfers: 
os.remove(src_path) - prev_transfers = curr_transfers - curr_transfers = list() + command_info.prev_transfers = command_info.curr_transfers + command_info.curr_transfers = list() logger.info( - f"{ts_utc()}: prev_transfers has been set to {prev_transfers}" + f"{ts_utc()}: prev_transfers has been set to {command_info.prev_transfers}" ) def hpss_put( - hpss: str, + command_info: CommandInfo, file_path: str, - cache: str, - keep: bool = True, non_blocking: bool = False, - is_index=False, ): """ Put a file to the HPSS archive. """ - hpss_transfer(hpss, file_path, "put", cache, keep, non_blocking, is_index) + hpss_transfer(command_info, file_path, "put", non_blocking) -def hpss_get(hpss: str, file_path: str, cache: str): +def hpss_get(command_info: CommandInfo, file_path: str): """ Get a file from the HPSS archive. """ - hpss_transfer(hpss, file_path, "get", cache, False) + hpss_transfer(command_info, file_path, "get") def hpss_chgrp(hpss: str, group: str, recurse: bool = False): diff --git a/zstash/hpss_utils.py b/zstash/hpss_utils.py index 87325f4f..20038d7c 100644 --- a/zstash/hpss_utils.py +++ b/zstash/hpss_utils.py @@ -13,8 +13,8 @@ import _io from .hpss import hpss_put -from .settings import BLOCK_SIZE, TupleFilesRowNoId, TupleTarsRowNoId, config, logger -from .utils import create_tars_table, tars_table_exists, ts_utc +from .settings import BLOCK_SIZE, TupleFilesRowNoId, TupleTarsRowNoId, logger +from .utils import CommandInfo, create_tars_table, tars_table_exists, ts_utc # Minimum output file object @@ -55,12 +55,11 @@ def close(self): def add_files( + command_info: CommandInfo, cur: sqlite3.Cursor, con: sqlite3.Connection, itar: int, files: List[str], - cache: str, - keep: bool, follow_symlinks: bool, skip_tars_md5: bool = False, non_blocking: bool = False, @@ -96,7 +95,9 @@ def add_files( do_hash = True else: do_hash = False - tarFileObject = HashIO(os.path.join(cache, tfname), "wb", do_hash) + tarFileObject = HashIO( + os.path.join(command_info.cache_dir, 
tfname), "wb", do_hash + ) # FIXME: error: Argument "fileobj" to "open" has incompatible type "HashIO"; expected "Optional[IO[bytes]]" tar = tarfile.open(mode="w", fileobj=tarFileObject, dereference=follow_symlinks) # type: ignore @@ -130,11 +131,12 @@ def add_files( # Close tar archive if current file is the last one or # if adding one more would push us over the limit. next_file_size: int = tar.gettarinfo(current_file).size - if config.maxsize is not None: - maxsize: int = config.maxsize - else: - raise TypeError("Invalid config.maxsize={}".format(config.maxsize)) - if i == nfiles - 1 or tarsize + next_file_size > maxsize: + command_info.validate_maxsize() + if not command_info.config.maxsize: + raise ValueError("config.maxsize is undefined") + if (i == nfiles - 1) or ( + tarsize + next_file_size > command_info.config.maxsize + ): # Close current temporary file logger.debug(f"{ts_utc()}: Closing tar archive {tfname}") @@ -153,21 +155,19 @@ def add_files( cur.execute("insert into tars values (NULL,?,?,?)", tar_tuple) con.commit() - # Transfer tar to HPSS - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - # NOTE: These lines could be added under an "if debug" condition # logger.info(f"{ts_utc()}: CONTENTS of CACHE upon call to hpss_put:") # process = subprocess.run(["ls", "-l", "zstash"], capture_output=True, text=True) # print(process.stdout) logger.info( - f"{ts_utc()}: DIVING: (add_files): Calling hpss_put to dispatch archive file {tfname} [keep, non_blocking] = [{keep}, {non_blocking}]" + f"{ts_utc()}: DIVING: (add_files): Calling hpss_put to dispatch archive file {tfname} [keep, non_blocking] = [{command_info.keep}, {non_blocking}]" + ) + hpss_put( + command_info, + os.path.join(command_info.cache_dir, tfname), + non_blocking, ) - hpss_put(hpss, os.path.join(cache, tfname), cache, keep, non_blocking) logger.info( f"{ts_utc()}: SURFACE (add_files): Called hpss_put to dispatch archive 
file {tfname}" ) diff --git a/zstash/ls.py b/zstash/ls.py index 8b6ad6e4..f4f8218f 100644 --- a/zstash/ls.py +++ b/zstash/ls.py @@ -5,20 +5,11 @@ import os import sqlite3 import sys -from typing import List, Tuple, Union +from typing import List, Union from .hpss import hpss_get -from .settings import ( - DEFAULT_CACHE, - FilesRow, - TarsRow, - TupleFilesRow, - TupleTarsRow, - config, - get_db_filename, - logger, -) -from .utils import tars_table_exists, update_config +from .settings import FilesRow, TarsRow, TupleFilesRow, TupleTarsRow, logger +from .utils import CommandInfo, HPSSType, tars_table_exists def ls(): @@ -26,21 +17,18 @@ def ls(): List all of the files in the HPSS path. Supports the '-l' argument for more information. """ - - args: argparse.Namespace - cache: str - args, cache = setup_ls() - - matches: List[FilesRow] = ls_database(args, cache) + command_info = CommandInfo("ls") + args: argparse.Namespace = setup_ls(command_info, sys.argv) + matches: List[FilesRow] = ls_database(command_info, args) print_matches(args, matches) if args.tars: - tar_matches: List[TarsRow] = ls_tars_database(args, cache) + tar_matches: List[TarsRow] = ls_tars_database(command_info, args) print_matches(args, tar_matches) -def setup_ls() -> Tuple[argparse.Namespace, str]: +def setup_ls(command_info: CommandInfo, arg_list: List[str]) -> argparse.Namespace: parser: argparse.ArgumentParser = argparse.ArgumentParser( usage="zstash ls [] [files]", description="List the files from an existing archive. If `files` is specified, then only the files specified will be listed. 
If `hpss=none`, then this will list the directories and files in the current directory excluding the cache.", @@ -75,34 +63,30 @@ def setup_ls() -> Tuple[argparse.Namespace, str]: ) parser.add_argument("files", nargs="*", default=["*"]) - args: argparse.Namespace = parser.parse_args(sys.argv[2:]) - if args.hpss and args.hpss.lower() == "none": + args: argparse.Namespace = parser.parse_args(arg_list[2:]) + + if args.hpss and (args.hpss.lower() == "none"): args.hpss = "none" - cache: str - if args.cache: - cache = args.cache - else: - cache = DEFAULT_CACHE if args.verbose: logger.setLevel(logging.DEBUG) - return args, cache + if args.cache: + command_info.cache_dir = args.cache + command_info.set_dir_to_archive(os.getcwd()) + command_info.set_hpss_parameters(args.hpss, null_hpss_allowed=True) + return args -def ls_database(args: argparse.Namespace, cache: str) -> List[FilesRow]: + +def ls_database(command_info: CommandInfo, args: argparse.Namespace) -> List[FilesRow]: # Open database logger.debug("Opening index database") - if not os.path.exists(get_db_filename(cache)): + if not os.path.exists(command_info.get_db_name()): # Will need to retrieve from HPSS - if args.hpss is not None: - config.hpss = args.hpss - if config.hpss is not None: - hpss = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) + if command_info.hpss_type != HPSSType.UNDEFINED: try: # Retrieve from HPSS - hpss_get(hpss, get_db_filename(cache), cache) + hpss_get(command_info, command_info.get_db_name()) except RuntimeError: raise FileNotFoundError("There was nothing to ls.") else: @@ -113,25 +97,16 @@ def ls_database(args: argparse.Namespace, cache: str) -> List[FilesRow]: raise ValueError(error_str) con: sqlite3.Connection = sqlite3.connect( - get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + command_info.get_db_name(), detect_types=sqlite3.PARSE_DECLTYPES ) cur: sqlite3.Cursor = con.cursor() - update_config(cur) - - if config.maxsize is not None: - 
maxsize: int = config.maxsize - else: - raise TypeError("Invalid config.maxsize={}".format(config.maxsize)) - config.maxsize = maxsize - - # The command line arg should always have precedence - if args.hpss is not None: - config.hpss = args.hpss + command_info.update_config_using_db(cur) + command_info.validate_maxsize() # Start doing actual work logger.debug("Running zstash ls") - logger.debug("HPSS path : %s" % (config.hpss)) + logger.debug(f"HPSS path : {command_info.config.hpss}") # Find matching files matches_: List[TupleFilesRow] = [] @@ -145,12 +120,7 @@ def ls_database(args: argparse.Namespace, cache: str) -> List[FilesRow]: if matches_ == []: raise FileNotFoundError("There was nothing to ls.") - # Remove duplicates - matches_ = list(set(matches_)) - matches: List[FilesRow] = list(map(FilesRow, matches_)) - - # Sort by tape and order within tapes (offset) - matches = sorted(matches, key=lambda t: (t.tar, t.offset)) + matches: List[FilesRow] = process_matches_files(matches_) if args.long: # Get the names of the columns @@ -165,9 +135,37 @@ def ls_database(args: argparse.Namespace, cache: str) -> List[FilesRow]: return matches -def ls_tars_database(args: argparse.Namespace, cache: str) -> List[TarsRow]: +def process_matches_files(matches_: List[TupleFilesRow]) -> List[FilesRow]: + # Remove duplicates + matches_ = list(set(matches_)) + matches: List[FilesRow] = list(map(FilesRow, matches_)) + + # Sort by tape and order within tapes (offset) + matches = sorted(matches, key=lambda t: (t.tar, t.offset)) + return matches + + +def print_matches( + args: argparse.Namespace, matches: Union[List[FilesRow], List[TarsRow]] +): + # Print the results + match: Union[FilesRow, TarsRow] + for match in matches: + if args.long: + # Print all contents of each match + for col in match.to_tuple(): + print(col, end="\t") + print("") + else: + # Just print the name + print(match.name) + + +def ls_tars_database( + command_info: CommandInfo, args: argparse.Namespace +) -> 
List[TarsRow]: con: sqlite3.Connection = sqlite3.connect( - get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + command_info.get_db_name(), detect_types=sqlite3.PARSE_DECLTYPES ) cur: sqlite3.Cursor = con.cursor() @@ -179,12 +177,7 @@ def ls_tars_database(args: argparse.Namespace, cache: str) -> List[TarsRow]: cur.execute("select * from tars") matches_: List[TupleTarsRow] = cur.fetchall() - # Remove duplicates - matches_ = list(set(matches_)) - matches: List[TarsRow] = list(map(TarsRow, matches_)) - - # Sort by name - matches = sorted(matches, key=lambda t: (t.name)) + matches: List[TarsRow] = process_matches_tars(matches_) if matches != []: print("\nTars:") @@ -201,17 +194,11 @@ def ls_tars_database(args: argparse.Namespace, cache: str) -> List[TarsRow]: return matches -def print_matches( - args: argparse.Namespace, matches: Union[List[FilesRow], List[TarsRow]] -): - # Print the results - match: Union[FilesRow, TarsRow] - for match in matches: - if args.long: - # Print all contents of each match - for col in match.to_tuple(): - print(col, end="\t") - print("") - else: - # Just print the name - print(match.name) +def process_matches_tars(matches_: List[TupleTarsRow]) -> List[TarsRow]: + # Remove duplicates + matches_ = list(set(matches_)) + matches: List[TarsRow] = list(map(TarsRow, matches_)) + + # Sort by name + matches = sorted(matches, key=lambda t: (t.name)) + return matches diff --git a/zstash/settings.py b/zstash/settings.py index fe39c853..f5369187 100644 --- a/zstash/settings.py +++ b/zstash/settings.py @@ -2,34 +2,14 @@ import datetime import logging -import os.path from typing import Optional, Tuple - -# Class to hold configuration -class Config(object): - path: Optional[str] = None - hpss: Optional[str] = None - maxsize: Optional[int] = None - - -def get_db_filename(cache: str) -> str: - # Database filename - return os.path.join(cache, "index.db") - - # Block size BLOCK_SIZE: int = 1024 * 1014 -# Default sub-directory to hold cache 
-DEFAULT_CACHE: str = "zstash" - # Time tolerance (in seconds) for file modification time TIME_TOL: float = 1.0 -# Initialize config -config: Config = Config() - # Initialize logger logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO) logger: logging.Logger = logging.getLogger(__name__) diff --git a/zstash/update.py b/zstash/update.py index 9ab565b8..8f17b3ce 100644 --- a/zstash/update.py +++ b/zstash/update.py @@ -12,49 +12,32 @@ from .globus import globus_activate, globus_finalize from .hpss import hpss_get, hpss_put from .hpss_utils import add_files -from .settings import ( - DEFAULT_CACHE, - TIME_TOL, - FilesRow, - TupleFilesRow, - config, - get_db_filename, - logger, -) -from .utils import get_files_to_archive, update_config +from .settings import TIME_TOL, FilesRow, TupleFilesRow, logger +from .utils import CommandInfo, HPSSType, get_files_to_archive def update(): + command_info = CommandInfo("update") + args: argparse.Namespace = setup_update(command_info, sys.argv) - args: argparse.Namespace - cache: str - args, cache = setup_update() - - result: Optional[List[str]] = update_database(args, cache) - - if result is None: + failures: Optional[List[str]] = update_database(command_info, args) + if failures is None: # There was either nothing to update or `--dry-run` was set. return - else: - failures = result - # Transfer to HPSS. Always keep a local copy of the database. 
- if config.hpss is not None: - hpss = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True) + hpss_put(command_info, command_info.get_db_name()) - globus_finalize(non_blocking=args.non_blocking) + if command_info.globus_info: + globus_finalize(command_info.globus_info, non_blocking=args.non_blocking) # List failures if len(failures) > 0: logger.warning("Some files could not be archived") for file_path in failures: - logger.error("Archiving {}".format(file_path)) + logger.error(f"Archiving {file_path}") -def setup_update() -> Tuple[argparse.Namespace, str]: +def setup_update(command_info: CommandInfo, arg_list: List[str]) -> argparse.Namespace: # Parser parser: argparse.ArgumentParser = argparse.ArgumentParser( usage="zstash update []", description="Update an existing zstash archive" @@ -113,89 +96,73 @@ def setup_update() -> Tuple[argparse.Namespace, str]: action="store_true", help="Hard copy symlinks. This is useful for preventing broken links. 
Note that a broken link will result in a failed update.", ) - args: argparse.Namespace = parser.parse_args(sys.argv[2:]) + args: argparse.Namespace = parser.parse_args(arg_list[2:]) if (not args.hpss) or (args.hpss.lower() == "none"): args.hpss = "none" args.keep = True - - # Copy configuration - # config.path = os.path.abspath(args.path) - config.hpss = args.hpss - config.maxsize = int(1024 * 1024 * 1024 * args.maxsize) - - cache: str - if args.cache: - cache = args.cache - else: - cache = DEFAULT_CACHE if args.verbose: logger.setLevel(logging.DEBUG) - return args, cache + if args.cache: + command_info.cache_dir = args.cache + command_info.keep = args.keep + command_info.set_dir_to_archive(os.getcwd()) + command_info.set_and_scale_maxsize(args.maxsize) + command_info.set_hpss_parameters(args.hpss) + + return args -# C901 'update_database' is too complex (20) -def update_database( # noqa: C901 - args: argparse.Namespace, cache: str +def update_database( + command_info: CommandInfo, args: argparse.Namespace ) -> Optional[List[str]]: # Open database logger.debug("Opening index database") - if not os.path.exists(get_db_filename(cache)): + if not os.path.exists(command_info.get_db_name()): # The database file doesn't exist in the cache. # We need to retrieve it from HPSS - if args.hpss is not None: - config.hpss = args.hpss - if config.hpss is not None: - hpss: str = config.hpss - else: - raise TypeError("Invalid config.hpss={}".format(config.hpss)) - globus_activate(hpss) - hpss_get(hpss, get_db_filename(cache), cache) + if command_info.hpss_type != HPSSType.NO_HPSS: + if command_info.globus_info: + globus_activate(command_info.globus_info) + hpss_get(command_info, command_info.get_db_name()) else: + # NOTE: while --hpss is required in `create`, it is optional in `update`! 
+ # If --hpss is not provided, we assume it is 'none' => HPSSType.NO_HPSS error_str: str = ( - "--hpss argument is required when local copy of database is unavailable" + "--hpss argument (!= none) is required when local copy of database is unavailable" ) logger.error(error_str) raise ValueError(error_str) con: sqlite3.Connection = sqlite3.connect( - get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES + command_info.get_db_name(), detect_types=sqlite3.PARSE_DECLTYPES ) cur: sqlite3.Cursor = con.cursor() - update_config(cur) - - if config.maxsize is not None: - maxsize = config.maxsize - else: - raise TypeError("Invalid config.maxsize={}".format(config.maxsize)) - config.maxsize = int(maxsize) - - keep: bool - # The command line arg should always have precedence - if args.hpss == "none": - # If no HPSS is available, always keep the files. - keep = True - else: - # If HPSS is used, let the user specify whether or not to keep the files. - keep = args.keep + command_info.update_config_using_db(cur) + command_info.validate_maxsize() - if args.hpss is not None: - config.hpss = args.hpss + if command_info.hpss_type == HPSSType.NO_HPSS: + # If not using HPSS, always keep the files. 
+ command_info.keep = True + # else: keep command_info.keep set to args.keep # Start doing actual work logger.debug("Running zstash update") - logger.debug("Local path : {}".format(config.path)) - logger.debug("HPSS path : {}".format(config.hpss)) - logger.debug("Max size : {}".format(maxsize)) - logger.debug("Keep local tar files : {}".format(keep)) + logger.debug(f"Local path : {command_info.config.path}") + logger.debug(f"HPSS path : {command_info.config.hpss}") + logger.debug(f"Max size : {command_info.config.maxsize}") + logger.debug(f"Keep local tar files : {command_info.keep}") - files: List[str] = get_files_to_archive(cache, args.include, args.exclude) + files: List[str] = get_files_to_archive( + command_info.cache_dir, args.include, args.exclude + ) # Eliminate files that are already archived and up to date newfiles: List[str] = [] for file_path in files: + # logger.debug(f"file_path={file_path}") statinfo: os.stat_result = os.lstat(file_path) mdtime_new: datetime = datetime.utcfromtimestamp(statinfo.st_mtime) mode: int = statinfo.st_mode @@ -212,6 +179,7 @@ def update_database( # noqa: C901 while True: # Get the corresponding row in the 'files' table match_: Optional[TupleFilesRow] = cur.fetchone() + # logger.debug(f"match_={match_}") if match_ is None: break else: @@ -233,6 +201,10 @@ def update_database( # noqa: C901 con.commit() con.close() return None + # else: + # logger.debug(f"Number of files to update: {len(newfiles)}") + # for f in newfiles: + # logger.debug(f) # --dry-run option if args.dry_run: @@ -257,12 +229,12 @@ def update_database( # noqa: C901 try: # Add files failures = add_files( + command_info, cur, con, itar, newfiles, - cache, - keep, + command_info.keep, args.follow_symlinks, non_blocking=args.non_blocking, ) @@ -271,12 +243,11 @@ def update_database( # noqa: C901 else: # Add files failures = add_files( + command_info, cur, con, itar, newfiles, - cache, - keep, args.follow_symlinks, non_blocking=args.non_blocking, ) diff --git 
a/zstash/utils.py b/zstash/utils.py index ea793603..07f70c03 100644 --- a/zstash/utils.py +++ b/zstash/utils.py @@ -5,10 +5,145 @@ import sqlite3 import subprocess from datetime import datetime, timezone +from enum import Enum from fnmatch import fnmatch -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple +from urllib.parse import ParseResult, urlparse -from .settings import TupleTarsRow, config, logger +from globus_sdk import TransferClient, TransferData +from globus_sdk.services.transfer.response.iterable import IterableTransferResponse + +from .settings import TupleTarsRow, logger + + +class HPSSType(Enum): + NO_HPSS = 1 + SAME_MACHINE_HPSS = 2 + GLOBUS = 3 + UNDEFINED = 4 + + +class GlobusInfo(object): + def __init__(self, hpss_path: str): + url: ParseResult = urlparse(hpss_path) + if url.scheme != "globus": + raise ValueError(f"Invalid Globus hpss_path={hpss_path}") + self.hpss_path: str = hpss_path + self.url: ParseResult = url + + # Set in globus.globus_activate + self.remote_endpoint: Optional[str] = None + self.local_endpoint: Optional[str] = None + self.transfer_client: Optional[TransferClient] = None + + # Set in globus.globus_transfer + self.archive_directory_listing: Optional[IterableTransferResponse] = None + self.transfer_data: Optional[TransferData] = None + self.task_id = None + self.tarfiles_pushed: int = 0 + + +# Class to hold configuration, as it appears in the database +class Config(object): + path: Optional[str] = None + hpss: Optional[str] = None + maxsize: Optional[int] = None + + +class CommandInfo(object): + + def __init__(self, command_name: str): + self.command_name: str = command_name + self.dir_called_from: str = os.getcwd() + self.cache_dir: str = "zstash" # Default sub-directory to hold cache + self.config: Config = Config() + self.keep: bool = False # Defaults to False + self.prev_transfers: List[str] = [] + self.curr_transfers: List[str] = [] + # Use set_dir_to_archive: + 
self.dir_to_archive_relative: Optional[str] = None + # Use set_hpss_parameters: + self.hpss_type: Optional[HPSSType] = None + self.globus_info: Optional[GlobusInfo] = None + + def set_dir_to_archive(self, path: str): + abs_path = os.path.abspath(path) + if abs_path is not None: + self.config.path = abs_path + self.dir_to_archive_relative = path + else: + raise ValueError(f"Invalid path={path}") + + def set_and_scale_maxsize(self, maxsize): + self.config.maxsize = int(1024 * 1024 * 1024 * maxsize) + + def validate_maxsize(self): + if self.config.maxsize is not None: + self.config.maxsize = int(self.config.maxsize) + else: + raise ValueError("config.maxsize is undefined") + + def set_hpss_parameters(self, hpss_path: str, null_hpss_allowed=False): + self.config.hpss = hpss_path + if hpss_path == "none": + self.hpss_type = HPSSType.NO_HPSS + elif hpss_path is not None: + url = urlparse(hpss_path) + if url.scheme == "globus": + self.hpss_type = HPSSType.GLOBUS + self.globus_info = GlobusInfo(hpss_path) + globus_cfg: str = os.path.expanduser("~/.globus-native-apps.cfg") + logger.info(f"Checking if {globus_cfg} exists") + if os.path.exists(globus_cfg): + logger.info( + f"{globus_cfg} exists. If this file does not have the proper settings, it may cause a TransferAPIError (e.g., 'Token is not active', 'No credentials supplied')" + ) + else: + logger.info( + f"{globus_cfg} does not exist. zstash will need to prompt for authentications twice, and then you will need to re-run." + ) + else: + self.hpss_type = HPSSType.SAME_MACHINE_HPSS + elif null_hpss_allowed: + self.hpss_type = HPSSType.UNDEFINED + else: + raise ValueError("hpss_path is undefined") + logger.debug(f"Setting hpss_type={self.hpss_type}") + logger.debug(f"Setting hpss={self.config.hpss}") + + def update_config_using_db(self, cur: sqlite3.Cursor): + # Retrieve some configuration settings from database + # Loop through all attributes of config. 
+ for attr in dir(self.config): + value: Any = getattr(self.config, attr) + if not callable(value) and not attr.startswith("__"): + # config.{attr} is not a function. + # The attribute name does not start with "__" + # Get the value (column 2) for attribute `attr` (column 1) + # i.e., for the row where column 1 is the attribute, get the value from column 2 + cur.execute("select value from config where arg=?", (attr,)) + value = cur.fetchone()[0] + # Update config with the new attribute-value pair + setattr(self.config, attr, value) + logger.debug( + f"Updated config using db. Now, maxsize={self.config.maxsize}, path={self.config.path}, hpss={self.config.hpss}, hpss_type={self.hpss_type}" + ) + + def get_db_name(self) -> str: + return os.path.join(self.cache_dir, "index.db") + + def list_cache_dir(self): + logger.info( + f"Contents of cache {self.cache_dir} = {os.listdir(self.cache_dir)}" + ) + + def list_hpss_path(self): + if self.hpss_type == HPSSType.SAME_MACHINE_HPSS: + command = "hsi ls -l {}".format(self.config.hpss) + error_str = f"Attempted to list contents at config.hpss={self.config.hpss}" + run_command(command, error_str) + else: + logger.info("No HPSS path to list") def ts_utc(): @@ -109,22 +244,6 @@ def get_files_to_archive(cache: str, include: str, exclude: str) -> List[str]: return files -def update_config(cur: sqlite3.Cursor): - # Retrieve some configuration settings from database - # Loop through all attributes of config. - for attr in dir(config): - value: Any = getattr(config, attr) - if not callable(value) and not attr.startswith("__"): - # config.{attr} is not a function. 
- # The attribute name does not start with "__" - # Get the value (column 2) for attribute `attr` (column 1) - # i.e., for the row where column 1 is the attribute, get the value from column 2 - cur.execute("select value from config where arg=?", (attr,)) - value = cur.fetchone()[0] - # Update config with the new attribute-value pair - setattr(config, attr, value) - - def create_tars_table(cur: sqlite3.Cursor, con: sqlite3.Connection): # Create 'tars' table cur.execute(