Skip to content

Commit 8fa5ed5

Browse files
committed
Store state with object
1 parent d9d50b7 commit 8fa5ed5

File tree

5 files changed

+191
-140
lines changed

5 files changed

+191
-140
lines changed

zstash/create.py

Lines changed: 50 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
import os.path
77
import sqlite3
88
import sys
9-
from typing import Any, List, Tuple
9+
from typing import Any, List
1010

1111
from six.moves.urllib.parse import urlparse
1212

1313
from .globus import globus_activate, globus_finalize
1414
from .hpss import hpss_put
1515
from .hpss_utils import add_files
16-
from .settings import DEFAULT_CACHE, config, get_db_filename, logger
16+
from .settings import logger
1717
from .utils import (
18+
CommandInfo,
19+
HPSSType,
1820
create_tars_table,
1921
get_files_to_archive,
2022
run_command,
@@ -24,60 +26,46 @@
2426

2527

2628
def create():
27-
cache: str
28-
cache, args = setup_create()
29-
30-
# Check config fields
31-
if config.path is not None:
32-
path: str = config.path
33-
else:
34-
raise TypeError("Invalid config.path={}".format(config.path))
35-
if config.hpss is not None:
36-
hpss: str = config.hpss
37-
else:
38-
raise TypeError("Invalid config.hpss={}".format(config.hpss))
29+
command_info = CommandInfo("create")
30+
args = setup_create(command_info)
3931

4032
# Start doing actual work
4133
logger.debug(f"{ts_utc()}: Running zstash create")
42-
logger.debug("Local path : {}".format(path))
43-
logger.debug("HPSS path : {}".format(hpss))
44-
logger.debug("Max size : {}".format(config.maxsize))
45-
logger.debug("Keep local tar files : {}".format(args.keep))
34+
logger.debug(f"Local path: {command_info.dir_to_archive_absolute}")
35+
logger.debug(f"HPSS path: {command_info.hpss_path}")
36+
logger.debug(f"Max size: {command_info.maxsize}")
37+
logger.debug(f"Keep local tar files: {command_info.keep}")
4638

4739
# Make sure input path exists and is a directory
4840
logger.debug("Making sure input path exists and is a directory")
49-
if not os.path.isdir(path):
41+
if not os.path.isdir(command_info.dir_to_archive_absolute):
5042
# Input path is not a directory
51-
input_path_error_str: str = "Input path should be a directory: {}".format(path)
43+
input_path_error_str: str = f"Input path should be a directory: {command_info.dir_to_archive_absolute}"
5244
logger.error(input_path_error_str)
5345
raise NotADirectoryError(input_path_error_str)
5446

55-
if hpss != "none":
56-
url = urlparse(hpss)
57-
if url.scheme == "globus":
58-
# identify globus endpoints
59-
logger.debug(f"{ts_utc()}:Calling globus_activate(hpss)")
60-
globus_activate(hpss)
61-
else:
62-
# config.hpss is not "none", so we need to
63-
# create target HPSS directory
64-
logger.debug(f"{ts_utc()}: Creating target HPSS directory {hpss}")
65-
mkdir_command: str = "hsi -q mkdir -p {}".format(hpss)
66-
mkdir_error_str: str = "Could not create HPSS directory: {}".format(hpss)
67-
run_command(mkdir_command, mkdir_error_str)
47+
if command_info.hpss_type == HPSSType.GLOBUS:
48+
# identify globus endpoints
49+
logger.debug(f"{ts_utc()}: Calling globus_activate")
50+
globus_activate(command_info.globus_info)
51+
elif command_info.hpss_type == HPSSType.SAME_MACHINE_HPSS:
52+
logger.debug(f"{ts_utc()}: Creating target HPSS directory {command_info.hpss_path}")
53+
mkdir_command: str = f"hsi -q mkdir -p {command_info.hpss_path}"
54+
mkdir_error_str: str = f"Could not create HPSS directory: {command_info.hpss_path}"
55+
run_command(mkdir_command, mkdir_error_str)
6856

69-
# Make sure it is exists and is empty
70-
logger.debug("Making sure target HPSS directory exists and is empty")
57+
# Make sure it exists and is empty
58+
logger.debug("Making sure target HPSS directory exists and is empty")
7159

72-
ls_command: str = 'hsi -q "cd {}; ls -l"'.format(hpss)
73-
ls_error_str: str = "Target HPSS directory is not empty"
74-
run_command(ls_command, ls_error_str)
60+
ls_command: str = f'hsi -q "cd {command_info.hpss_path}; ls -l"'
61+
ls_error_str: str = "Target HPSS directory is not empty"
62+
run_command(ls_command, ls_error_str)
7563

7664
# Create cache directory
7765
logger.debug(f"{ts_utc()}: Creating local cache directory")
78-
os.chdir(path)
66+
os.chdir(command_info.dir_to_archive_absolute)
7967
try:
80-
os.makedirs(cache)
68+
os.makedirs(command_info.cache_dir)
8169
except OSError as exc:
8270
if exc.errno != errno.EEXIST:
8371
cache_error_str: str = "Cannot create local cache directory"
@@ -88,11 +76,12 @@ def create():
8876

8977
# Create and set up the database
9078
logger.debug(f"{ts_utc()}: Calling create_database()")
91-
failures: List[str] = create_database(cache, args)
79+
failures: List[str] = create_database(command_info, args)
9280

93-
# Transfer to HPSS. Always keep a local copy.
94-
logger.debug(f"{ts_utc()}: calling hpss_put() for {get_db_filename(cache)}")
95-
hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True)
81+
# Transfer to HPSS. Always keep a local copy of the database.
82+
logger.debug(f"{ts_utc()}: calling hpss_put() for {command_info.get_db_name()}")
83+
# TODO: (A) Continue refactoring from here
84+
hpss_put(command_info, command_info.get_db_name(), is_index=True)
9685

9786
logger.debug(f"{ts_utc()}: calling globus_finalize()")
9887
globus_finalize(non_blocking=args.non_blocking)
@@ -104,7 +93,7 @@ def create():
10493
logger.error("Failed to archive {}".format(file_path))
10594

10695

107-
def setup_create() -> Tuple[str, argparse.Namespace]:
96+
def setup_create(ci: CommandInfo) -> argparse.Namespace:
10897
# Parser
10998
parser: argparse.ArgumentParser = argparse.ArgumentParser(
11099
usage="zstash create [<args>] path", description="Create a new zstash archive"
@@ -175,27 +164,25 @@ def setup_create() -> Tuple[str, argparse.Namespace]:
175164
if args.verbose:
176165
logger.setLevel(logging.DEBUG)
177166

178-
# Copy configuration
179-
config.path = os.path.abspath(args.path)
180-
config.hpss = args.hpss
181-
config.maxsize = int(1024 * 1024 * 1024 * args.maxsize)
182-
cache: str
183167
if args.cache:
184-
cache = args.cache
185-
else:
186-
cache = DEFAULT_CACHE
168+
ci.cache_dir = args.cache
169+
ci.keep = args.keep
170+
ci.set_dir_to_archive(args.path)
171+
ci.set_maxsize(args.maxsize)
172+
ci.set_hpss_parameters(args.hpss)
187173

188-
return cache, args
174+
return args
189175

190176

191-
def create_database(cache: str, args: argparse.Namespace) -> List[str]:
177+
def create_database(command_info: CommandInfo, args: argparse.Namespace) -> List[str]:
192178
# Create new database
193179
logger.debug(f"{ts_utc()}:Creating index database")
194-
if os.path.exists(get_db_filename(cache)):
180+
db_name: str = command_info.get_db_name()
181+
if os.path.exists(db_name):
195182
# Remove old database
196-
os.remove(get_db_filename(cache))
183+
os.remove(db_name)
197184
con: sqlite3.Connection = sqlite3.connect(
198-
get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES
185+
db_name, detect_types=sqlite3.PARSE_DECLTYPES
199186
)
200187
cur: sqlite3.Cursor = con.cursor()
201188

@@ -233,8 +220,8 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
233220

234221
# Store configuration in database
235222
# Loop through all attributes of config.
236-
for attr in dir(config):
237-
value: Any = getattr(config, attr)
223+
for attr in dir(command_info.config):
224+
value: Any = getattr(command_info.config, attr)
238225
if not callable(value) and not attr.startswith("__"):
239226
# config.{attr} is not a function.
240227
# The attribute name does not start with "__"
@@ -244,7 +231,7 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
244231
cur.execute("insert into config values (?,?)", (attr, value))
245232
con.commit()
246233

247-
files: List[str] = get_files_to_archive(cache, args.include, args.exclude)
234+
files: List[str] = get_files_to_archive(command_info.cache_dir, args.include, args.exclude)
248235

249236
failures: List[str]
250237
if args.follow_symlinks:
@@ -255,7 +242,7 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
255242
con,
256243
-1,
257244
files,
258-
cache,
245+
command_info.cache_dir,
259246
args.keep,
260247
args.follow_symlinks,
261248
skip_tars_md5=args.no_tars_md5,
@@ -270,7 +257,7 @@ def create_database(cache: str, args: argparse.Namespace) -> List[str]:
270257
con,
271258
-1,
272259
files,
273-
cache,
260+
command_info.cache_dir,
274261
args.keep,
275262
args.follow_symlinks,
276263
skip_tars_md5=args.no_tars_md5,

zstash/globus.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
from six.moves.urllib.parse import urlparse
1414

1515
from .settings import logger
16-
from .utils import ts_utc
16+
from .utils import GlobusInfo, ts_utc
1717

18-
hpss_endpoint_map = {
18+
HPSS_ENDPOINT_MAP = {
1919
"ALCF": "de463ec4-6d04-11e5-ba46-22000b92c6ec",
2020
"NERSC": "9cd89cfd-6d04-11e5-ba46-22000b92c6ec",
2121
}
2222

2323
# This is used if the `globus_endpoint_uuid` is not set in `~/.zstash.ini`
24-
regex_endpoint_map = {
24+
REGEX_ENDPOINT_MAP = {
2525
r"theta.*\.alcf\.anl\.gov": "08925f04-569f-11e7-bef8-22000b9a448b",
2626
r"blueslogin.*\.lcrc\.anl\.gov": "15288284-7006-4041-ba1a-6b52501e49f1",
2727
r"chrlogin.*\.lcrc\.anl\.gov": "15288284-7006-4041-ba1a-6b52501e49f1",
@@ -73,26 +73,19 @@ def submit_transfer_with_checks(transfer_data):
7373
return task
7474

7575

76-
def globus_activate(hpss: str):
76+
def globus_activate(globus_info: GlobusInfo):
7777
"""
7878
Read the local globus endpoint UUID from ~/.zstash.ini.
7979
If the ini file does not exist, create an ini file with empty values,
8080
and try to find the local endpoint UUID based on the FQDN
8181
"""
82-
global transfer_client
83-
global local_endpoint
84-
global remote_endpoint
85-
86-
url = urlparse(hpss)
87-
if url.scheme != "globus":
88-
return
89-
remote_endpoint = url.netloc
82+
globus_info.remote_endpoint = globus_info.url.netloc
9083

9184
ini_path = os.path.expanduser("~/.zstash.ini")
9285
ini = configparser.ConfigParser()
9386
if ini.read(ini_path):
9487
if "local" in ini.sections():
95-
local_endpoint = ini["local"].get("globus_endpoint_uuid")
88+
globus_info.local_endpoint = ini["local"].get("globus_endpoint_uuid")
9689
else:
9790
ini["local"] = {"globus_endpoint_uuid": ""}
9891
try:
@@ -101,33 +94,31 @@ def globus_activate(hpss: str):
10194
except Exception as e:
10295
logger.error(e)
10396
sys.exit(1)
104-
if not local_endpoint:
97+
if not globus_info.local_endpoint:
10598
fqdn = socket.getfqdn()
10699
if re.fullmatch(r"n.*\.local", fqdn) and os.getenv("HOSTNAME", "NA").startswith(
107100
"compy"
108101
):
109102
fqdn = "compy.pnl.gov"
110-
for pattern in regex_endpoint_map.keys():
103+
for pattern in REGEX_ENDPOINT_MAP.keys():
111104
if re.fullmatch(pattern, fqdn):
112-
local_endpoint = regex_endpoint_map.get(pattern)
105+
globus_info.local_endpoint = REGEX_ENDPOINT_MAP.get(pattern)
113106
break
114107
# FQDN is not set on Perlmutter at NERSC
115-
if not local_endpoint:
108+
if not globus_info.local_endpoint:
116109
nersc_hostname = os.environ.get("NERSC_HOST")
117110
if nersc_hostname and (
118111
nersc_hostname == "perlmutter" or nersc_hostname == "unknown"
119112
):
120-
local_endpoint = regex_endpoint_map.get(r"perlmutter.*\.nersc\.gov")
121-
if not local_endpoint:
113+
globus_info.local_endpoint = REGEX_ENDPOINT_MAP.get(r"perlmutter.*\.nersc\.gov")
114+
if not globus_info.local_endpoint:
122115
logger.error(
123-
"{} does not have the local Globus endpoint set nor could one be found in regex_endpoint_map.".format(
124-
ini_path
125-
)
116+
f"{ini_path} does not have the local Globus endpoint set nor could one be found in REGEX_ENDPOINT_MAP."
126117
)
127118
sys.exit(1)
128119

129-
if remote_endpoint.upper() in hpss_endpoint_map.keys():
130-
remote_endpoint = hpss_endpoint_map.get(remote_endpoint.upper())
120+
if globus_info.remote_endpoint.upper() in HPSS_ENDPOINT_MAP.keys():
121+
globus_info.remote_endpoint = HPSS_ENDPOINT_MAP.get(globus_info.remote_endpoint.upper())
131122

132123
native_client = NativeClient(
133124
client_id="6c1629cf-446c-49e7-af95-323c6412397f",
@@ -136,15 +127,13 @@ def globus_activate(hpss: str):
136127
)
137128
native_client.login(no_local_server=True, refresh_tokens=True)
138129
transfer_authorizer = native_client.get_authorizers().get("transfer.api.globus.org")
139-
transfer_client = TransferClient(authorizer=transfer_authorizer)
130+
globus_info.transfer_client = TransferClient(authorizer=transfer_authorizer)
140131

141-
for ep_id in [local_endpoint, remote_endpoint]:
142-
r = transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600)
132+
for ep_id in [globus_info.local_endpoint, globus_info.remote_endpoint]:
133+
r = globus_info.transfer_client.endpoint_autoactivate(ep_id, if_expires_in=600)
143134
if r.get("code") == "AutoActivationFailed":
144135
logger.error(
145-
"The {} endpoint is not activated or the current activation expires soon. Please go to https://app.globus.org/file-manager/collections/{} and (re)activate the endpoint.".format(
146-
ep_id, ep_id
147-
)
136+
f"The {ep_id} endpoint is not activated or the current activation expires soon. Please go to https://app.globus.org/file-manager/collections/{ep_id} and (re)activate the endpoint."
148137
)
149138
sys.exit(1)
150139

0 commit comments

Comments
 (0)