Skip to content

Commit f2aa2de

Browse files
committed
zstash update refactored
1 parent a5b1697 commit f2aa2de

File tree

5 files changed

+84
-101
lines changed

5 files changed

+84
-101
lines changed

zstash/create.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def create():
8080

8181
# Transfer to HPSS. Always keep a local copy of the database.
8282
logger.debug(f"{ts_utc()}: calling hpss_put() for {command_info.get_db_name()}")
83-
hpss_put(command_info, command_info.get_db_name(), is_index=True)
83+
hpss_put(command_info, command_info.get_db_name())
8484

8585
if command_info.globus_info:
8686
logger.debug(f"{ts_utc()}: calling globus_finalize()")
@@ -93,7 +93,7 @@ def create():
9393
logger.error(f"Failed to archive {file_path}")
9494

9595

96-
def setup_create(ci: CommandInfo) -> argparse.Namespace:
96+
def setup_create(command_info: CommandInfo) -> argparse.Namespace:
9797
# Parser
9898
parser: argparse.ArgumentParser = argparse.ArgumentParser(
9999
usage="zstash create [<args>] path", description="Create a new zstash archive"
@@ -165,11 +165,11 @@ def setup_create(ci: CommandInfo) -> argparse.Namespace:
165165
logger.setLevel(logging.DEBUG)
166166

167167
if args.cache:
168-
ci.cache_dir = args.cache
169-
ci.keep = args.keep
170-
ci.set_dir_to_archive(args.path)
171-
ci.set_maxsize(args.maxsize)
172-
ci.set_hpss_parameters(args.hpss)
168+
command_info.cache_dir = args.cache
169+
command_info.keep = args.keep
170+
command_info.set_dir_to_archive(args.path)
171+
command_info.set_maxsize(args.maxsize)
172+
command_info.set_hpss_parameters(args.hpss)
173173

174174
return args
175175

zstash/globus.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def globus_activate(globus_info: GlobusInfo):
7272
If the ini file does not exist, create an ini file with empty values,
7373
and try to find the local endpoint UUID based on the FQDN
7474
"""
75+
if not globus_info:
76+
raise ValueError("globus_info is undefined")
7577
globus_info.remote_endpoint = globus_info.url.netloc
7678

7779
ini_path = os.path.expanduser("~/.zstash.ini")
@@ -362,6 +364,8 @@ def globus_wait(globus_info: GlobusInfo, alternative_task_id=None):
362364

363365

364366
def globus_finalize(globus_info: GlobusInfo, non_blocking: bool = False):
367+
if not globus_info:
368+
raise ValueError("globus_info is undefined")
365369
last_task_id = None
366370

367371
if globus_info.transfer_data:

zstash/hpss.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def hpss_transfer(
1616
file_path: str,
1717
transfer_type: str,
1818
non_blocking: bool = False,
19-
is_index: bool = False,
2019
):
2120

2221
logger.info(
@@ -67,21 +66,18 @@ def hpss_transfer(
6766
else:
6867
raise ValueError("Invalid transfer_type={}".format(transfer_type))
6968
logger.info(f"Transferring file {transfer_word} HPSS: {file_path}")
70-
scheme: str
71-
endpoint: str
72-
path: str
73-
name: str
7469

7570
url = urlparse(command_info.hpss_path)
76-
scheme = url.scheme
77-
endpoint = url.netloc
71+
endpoint: str = url.netloc
7872
url_path = url.path
7973

8074
command_info.curr_transfers.append(file_path)
8175
# TODO: Expected output for tests needs to be changed if we uncomment this:
8276
# logger.debug(
8377
# f"{ts_utc()}: curr_transfers has been appended to, is now {command_info.curr_transfers}"
8478
# )
79+
path: str
80+
name: str
8581
path, name = os.path.split(file_path)
8682

8783
# Need to be in local directory for `hsi` to work
@@ -143,19 +139,18 @@ def hpss_put(
143139
command_info: CommandInfo,
144140
file_path: str,
145141
non_blocking: bool = False,
146-
is_index=False,
147142
):
148143
"""
149144
Put a file to the HPSS archive.
150145
"""
151-
hpss_transfer(command_info, file_path, "put", non_blocking, is_index)
146+
hpss_transfer(command_info, file_path, "put", non_blocking)
152147

153148

154-
def hpss_get(hpss: str, file_path: str, cache: str):
149+
def hpss_get(command_info: CommandInfo, file_path: str):
155150
"""
156151
Get a file from the HPSS archive.
157152
"""
158-
hpss_transfer(hpss, file_path, "get", cache, False)
153+
hpss_transfer(command_info, file_path, "get")
159154

160155

161156
def hpss_chgrp(hpss: str, group: str, recurse: bool = False):

zstash/update.py

Lines changed: 52 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,48 +13,44 @@
1313
from .hpss import hpss_get, hpss_put
1414
from .hpss_utils import add_files
1515
from .settings import (
16-
DEFAULT_CACHE,
1716
TIME_TOL,
1817
FilesRow,
1918
TupleFilesRow,
2019
config,
2120
get_db_filename,
2221
logger,
2322
)
24-
from .utils import get_files_to_archive, update_config
23+
from .utils import CommandInfo, HPSSType, get_files_to_archive
2524

2625

2726
def update():
27+
command_info = CommandInfo("update")
28+
args: argparse.Namespace = setup_update(command_info)
2829

29-
args: argparse.Namespace
30-
cache: str
31-
args, cache = setup_update()
32-
33-
result: Optional[List[str]] = update_database(args, cache)
34-
35-
if result is None:
30+
failures: Optional[List[str]] = update_database(command_info, args)
31+
if failures is None:
3632
# There was either nothing to update or `--dry-run` was set.
3733
return
38-
else:
39-
failures = result
4034

4135
# Transfer to HPSS. Always keep a local copy of the database.
42-
if config.hpss is not None:
43-
hpss = config.hpss
36+
if command_info.config.hpss is not None:
37+
hpss = command_info.config.hpss
4438
else:
45-
raise TypeError("Invalid config.hpss={}".format(config.hpss))
46-
hpss_put(hpss, get_db_filename(cache), cache, keep=args.keep, is_index=True)
39+
raise TypeError(f"Invalid config.hpss={command_info.config.hpss}")
40+
hpss_put(hpss, command_info.get_db_name(), command_info.cache_dir, keep=command_info.keep)
4741

48-
globus_finalize(non_blocking=args.non_blocking)
42+
if command_info.hpss_type == HPSSType.GLOBUS:
43+
globus_finalize(command_info.globus_info, non_blocking=args.non_blocking)
4944

5045
# List failures
5146
if len(failures) > 0:
5247
logger.warning("Some files could not be archived")
5348
for file_path in failures:
54-
logger.error("Archiving {}".format(file_path))
49+
logger.error(f"Archiving {file_path}")
50+
# TODO: (A) Continue refactor from here
5551

5652

57-
def setup_update() -> Tuple[argparse.Namespace, str]:
53+
def setup_update(command_info: CommandInfo) -> argparse.Namespace:
5854
# Parser
5955
parser: argparse.ArgumentParser = argparse.ArgumentParser(
6056
usage="zstash update [<args>]", description="Update an existing zstash archive"
@@ -118,80 +114,69 @@ def setup_update() -> Tuple[argparse.Namespace, str]:
118114
if (not args.hpss) or (args.hpss.lower() == "none"):
119115
args.hpss = "none"
120116
args.keep = True
121-
122-
# Copy configuration
123-
# config.path = os.path.abspath(args.path)
124-
config.hpss = args.hpss
125-
config.maxsize = int(1024 * 1024 * 1024 * args.maxsize)
126-
127-
cache: str
128-
if args.cache:
129-
cache = args.cache
130-
else:
131-
cache = DEFAULT_CACHE
132117
if args.verbose:
133118
logger.setLevel(logging.DEBUG)
134119

135-
return args, cache
120+
if args.cache:
121+
command_info.cache_dir = args.cache
122+
command_info.keep = args.keep
123+
command_info.set_dir_to_archive(os.getcwd())
124+
command_info.set_maxsize(args.maxsize)
125+
command_info.set_hpss_parameters(args.hpss)
126+
127+
return args
136128

137129

138130
# C901 'update_database' is too complex (20)
139131
def update_database( # noqa: C901
140-
args: argparse.Namespace, cache: str
132+
command_info: CommandInfo, args: argparse.Namespace
141133
) -> Optional[List[str]]:
142134
# Open database
143135
logger.debug("Opening index database")
144-
if not os.path.exists(get_db_filename(cache)):
136+
if not os.path.exists(command_info.get_db_name()):
145137
# The database file doesn't exist in the cache.
146138
# We need to retrieve it from HPSS
147-
if args.hpss is not None:
148-
config.hpss = args.hpss
149-
if config.hpss is not None:
150-
hpss: str = config.hpss
151-
else:
152-
raise TypeError("Invalid config.hpss={}".format(config.hpss))
153-
globus_activate(hpss)
154-
hpss_get(hpss, get_db_filename(cache), cache)
139+
if command_info.hpss_type != HPSSType.NO_HPSS:
140+
command_info.update_config()
141+
if command_info.hpss_type == HPSSType.GLOBUS:
142+
globus_activate(command_info.globus_info)
143+
hpss_get(command_info, command_info.get_db_name())
155144
else:
145+
# NOTE: while --hpss is required in `create`, it is optional in `update`!
146+
# If --hpss is not provided, we assume it is 'none' => HPSSType.NO_HPSS
156147
error_str: str = (
157-
"--hpss argument is required when local copy of database is unavailable"
148+
"--hpss argument (!= none) is required when local copy of database is unavailable"
158149
)
159150
logger.error(error_str)
160151
raise ValueError(error_str)
161152

162153
con: sqlite3.Connection = sqlite3.connect(
163-
get_db_filename(cache), detect_types=sqlite3.PARSE_DECLTYPES
154+
command_info.get_db_name(), detect_types=sqlite3.PARSE_DECLTYPES
164155
)
165156
cur: sqlite3.Cursor = con.cursor()
166157

167-
update_config(cur)
158+
command_info.update_config_using_db(cur)
168159

169-
if config.maxsize is not None:
170-
maxsize = config.maxsize
171-
else:
172-
raise TypeError("Invalid config.maxsize={}".format(config.maxsize))
173-
config.maxsize = int(maxsize)
174-
175-
keep: bool
176-
# The command line arg should always have precedence
177-
if args.hpss == "none":
178-
# If no HPSS is available, always keep the files.
179-
keep = True
160+
if command_info.config.maxsize is not None:
161+
command_info.maxsize = command_info.config.maxsize
180162
else:
181-
# If HPSS is used, let the user specify whether or not to keep the files.
182-
keep = args.keep
163+
raise TypeError(f"Invalid config.maxsize={command_info.config.maxsize}")
164+
command_info.config.maxsize = int(command_info.maxsize)
165+
166+
if command_info.hpss_type == HPSSType.NO_HPSS:
167+
# If not using HPSS, always keep the files.
168+
command_info.keep = True
169+
# else: keep command_info.keep set to args.keep
183170

184-
if args.hpss is not None:
185-
config.hpss = args.hpss
186171

187172
# Start doing actual work
188173
logger.debug("Running zstash update")
189-
logger.debug("Local path : {}".format(config.path))
190-
logger.debug("HPSS path : {}".format(config.hpss))
191-
logger.debug("Max size : {}".format(maxsize))
192-
logger.debug("Keep local tar files : {}".format(keep))
174+
logger.debug(f"Local path : {command_info.config.path}")
175+
logger.debug(f"HPSS path : {command_info.config.hpss}")
176+
logger.debug(f"Max size : {command_info.maxsize}")
177+
logger.debug(f"Keep local tar files : {command_info.keep}")
193178

194-
files: List[str] = get_files_to_archive(cache, args.include, args.exclude)
179+
files: List[str] = get_files_to_archive(command_info.get_db_name(), args.include, args.exclude)
195180

196181
# Eliminate files that are already archived and up to date
197182
newfiles: List[str] = []
@@ -261,8 +246,8 @@ def update_database( # noqa: C901
261246
con,
262247
itar,
263248
newfiles,
264-
cache,
265-
keep,
249+
command_info.get_db_name(),
250+
command_info.keep,
266251
args.follow_symlinks,
267252
non_blocking=args.non_blocking,
268253
)
@@ -275,8 +260,8 @@ def update_database( # noqa: C901
275260
con,
276261
itar,
277262
newfiles,
278-
cache,
279-
keep,
263+
command_info.get_db_name(),
264+
command_info.keep,
280265
args.follow_symlinks,
281266
non_blocking=args.non_blocking,
282267
)

zstash/utils.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,21 @@ def update_config(self):
9797
self.config.hpss = self.hpss_path
9898
self.config.maxsize = self.maxsize
9999

100+
def update_config_using_db(self, cur: sqlite3.Cursor):
101+
# Retrieve some configuration settings from database
102+
# Loop through all attributes of config.
103+
for attr in dir(self.config):
104+
value: Any = getattr(self.config, attr)
105+
if not callable(value) and not attr.startswith("__"):
106+
# config.{attr} is not a function.
107+
# The attribute name does not start with "__"
108+
# Get the value (column 2) for attribute `attr` (column 1)
109+
# i.e., for the row where column 1 is the attribute, get the value from column 2
110+
cur.execute("select value from config where arg=?", (attr,))
111+
value = cur.fetchone()[0]
112+
# Update config with the new attribute-value pair
113+
setattr(self.config, attr, value)
114+
100115
def get_db_name(self) -> str:
101116
return os.path.join(self.cache_dir, "index.db")
102117

@@ -212,22 +227,6 @@ def get_files_to_archive(cache: str, include: str, exclude: str) -> List[str]:
212227
return files
213228

214229

215-
def update_config(cur: sqlite3.Cursor):
216-
# Retrieve some configuration settings from database
217-
# Loop through all attributes of config.
218-
for attr in dir(config):
219-
value: Any = getattr(config, attr)
220-
if not callable(value) and not attr.startswith("__"):
221-
# config.{attr} is not a function.
222-
# The attribute name does not start with "__"
223-
# Get the value (column 2) for attribute `attr` (column 1)
224-
# i.e., for the row where column 1 is the attribute, get the value from column 2
225-
cur.execute("select value from config where arg=?", (attr,))
226-
value = cur.fetchone()[0]
227-
# Update config with the new attribute-value pair
228-
setattr(config, attr, value)
229-
230-
231230
def create_tars_table(cur: sqlite3.Cursor, con: sqlite3.Connection):
232231
# Create 'tars' table
233232
cur.execute(

0 commit comments

Comments
 (0)