Skip to content

Commit a8d0085

Browse files
committed
Store state with object
1 parent 0fb1996 commit a8d0085

File tree

2 files changed

+73
-16
lines changed

2 files changed

+73
-16
lines changed

zstash/create.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
import os.path
77
import sqlite3
88
import sys
9-
from typing import Any, List, Tuple
9+
from typing import Any, List
1010

1111
from six.moves.urllib.parse import urlparse
1212

1313
from .globus import globus_activate, globus_finalize
1414
from .hpss import hpss_put
1515
from .hpss_utils import add_files
16-
from .settings import DEFAULT_CACHE, config, get_db_filename, logger
16+
from .settings import config, get_db_filename, logger
1717
from .utils import (
18+
CommandInfo,
1819
create_tars_table,
1920
get_files_to_archive,
2021
run_command,
@@ -24,8 +25,8 @@
2425

2526

2627
def create():
27-
cache: str
28-
cache, args = setup_create()
28+
ci = CommandInfo("create")
29+
args = setup_create(ci)
2930

3031
# Check config fields
3132
if config.path is not None:
@@ -77,7 +78,7 @@ def create():
7778
logger.debug(f"{ts_utc()}: Creating local cache directory")
7879
os.chdir(path)
7980
try:
80-
os.makedirs(cache)
81+
os.makedirs(ci.cache_dir)
8182
except OSError as exc:
8283
if exc.errno != errno.EEXIST:
8384
cache_error_str: str = "Cannot create local cache directory"
@@ -88,11 +89,11 @@ def create():
8889

8990
# Create and set up the database
9091
logger.debug(f"{ts_utc()}: Calling create_database()")
91-
failures: List[str] = create_database(cache, args)
92+
failures: List[str] = create_database(ci.cache_dir, args)
9293

9394
# Transfer to HPSS. Always keep a local copy.
94-
logger.debug(f"{ts_utc()}: calling hpss_put() for {get_db_filename(cache)}")
95-
hpss_put(hpss, get_db_filename(cache), cache, keep=True)
95+
logger.debug(f"{ts_utc()}: calling hpss_put() for {ci.get_db_name()}")
96+
hpss_put(hpss, ci.get_db_name(), ci.cache_dir, keep=True)
9697

9798
logger.debug(f"{ts_utc()}: calling globus_finalize()")
9899
globus_finalize(non_blocking=args.non_blocking)
@@ -104,7 +105,7 @@ def create():
104105
logger.error("Failed to archive {}".format(file_path))
105106

106107

107-
def setup_create() -> Tuple[str, argparse.Namespace]:
108+
def setup_create(ci: CommandInfo) -> argparse.Namespace:
108109
# Parser
109110
parser: argparse.ArgumentParser = argparse.ArgumentParser(
110111
usage="zstash create [<args>] path", description="Create a new zstash archive"
@@ -180,13 +181,10 @@ def setup_create() -> Tuple[str, argparse.Namespace]:
180181
config.path = os.path.abspath(args.path)
181182
config.hpss = args.hpss
182183
config.maxsize = int(1024 * 1024 * 1024 * args.maxsize)
183-
cache: str
184184
if args.cache:
185-
cache = args.cache
186-
else:
187-
cache = DEFAULT_CACHE
185+
ci.cache_dir = args.cache
188186

189-
return cache, args
187+
return args
190188

191189

192190
def create_database(cache: str, args: argparse.Namespace) -> List[str]:

zstash/utils.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,69 @@
55
import sqlite3
66
import subprocess
77
from datetime import datetime, timezone
8+
from enum import Enum
89
from fnmatch import fnmatch
910
from typing import Any, List, Tuple
10-
11-
from .settings import TupleTarsRow, config, logger
11+
from urllib.parse import urlparse
12+
13+
from .settings import DEFAULT_CACHE, TupleTarsRow, config, logger
14+
15+
16+
class HPSSType(Enum):
17+
NO_HPSS = 1
18+
SAME_MACHINE_HPSS = 2
19+
GLOBUS = 3
20+
21+
22+
class CommandInfo(object):
23+
def __init__(self, command_name: str):
24+
self.command_name = command_name
25+
26+
# Directories
27+
self.dir_called_from = os.getcwd()
28+
self.dir_to_archive = None
29+
self.cache_dir = DEFAULT_CACHE
30+
31+
# HPSS
32+
# self.hpss_type # Use set_hpss_parameters
33+
# self.hpss_path # Use set_hpss_parameters
34+
35+
# Globus-specific
36+
# self.globus_path # Use set_hpss_parameters
37+
# remote_endpoint = None
38+
# local_endpoint = None
39+
# transfer_client = None
40+
# transfer_data = None
41+
# task_id = None
42+
# archive_directory_listing = None
43+
44+
def set_hpss_parameters(self, hpss_path: str):
45+
if hpss_path == "none":
46+
self.hpss_type = HPSSType.NO_HPSS
47+
else:
48+
url = urlparse(hpss_path)
49+
if url.scheme == "globus":
50+
self.hpss_type = HPSSType.GLOBUS
51+
self.globus_path = hpss_path
52+
else:
53+
self.hpss_type = HPSSType.SAME_MACHINE_HPSS
54+
self.hpss_path = hpss_path
55+
56+
def get_db_name(self):
57+
return os.path.join(self.cache_dir, "index.db")
58+
59+
def list_cache_dir(self):
60+
logger.info(
61+
f"Contents of cache {self.cache_dir} = {os.listdir(self.cache_dir)}"
62+
)
63+
64+
def list_hpss_path(self):
65+
if self.hpss_type == HPSSType.SAME_MACHINE_HPSS:
66+
command = "hsi ls -l {}".format(self.hpss_path)
67+
error_str = "Attempted to list contents at hpss_path={hpss_path}"
68+
run_command(command, error_str)
69+
else:
70+
logger.info("No HPSS path to list")
1271

1372

1473
def ts_utc():

0 commit comments

Comments
 (0)