-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.py
More file actions
112 lines (92 loc) · 3.39 KB
/
cli.py
File metadata and controls
112 lines (92 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# File: cli.py
# Created Date: Saturday July 27th 2024
# Author: Steven Atkinson (steven@atkinson.mn)
"""
Command line interface entry points (GUI trainer, full trainer)
"""
# This must happen first
def _ensure_graceful_shutdowns():
    """
    Work around lost graceful (Ctrl+C) shutdowns on Windows.

    On Windows this sets the ``FOR_DISABLE_CONSOLE_CTRL_HANDLER`` environment
    variable to ``"1"``; on every other OS it does nothing. It must run as
    early as possible, before anything that might read that variable is
    loaded. Background:
    https://github.com/sdatkinson/neural-amp-modeler/issues/105
    https://stackoverflow.com/a/44822794
    """
    import os

    if os.name != "nt":  # Only Windows needs the workaround
        return
    os.environ["FOR_DISABLE_CONSOLE_CTRL_HANDLER"] = "1"


_ensure_graceful_shutdowns()
# This must happen ASAP but not before the graceful shutdown hack
def _apply_extensions():
    """
    Find and apply extensions to NAM.

    Every entry of ``<home>/.neural-amp-modeler/extensions`` except
    ``__pycache__`` and ``.DS_Store`` is imported as a module (a trailing
    ``.py`` is stripped to get the module name), which runs its top-level code.

    The extensions directory is appended to ``sys.path`` only for the duration
    of the imports, and is removed again afterwards if this function added it.
    A failing extension is reported and skipped; the rest are still applied.
    Nothing is imported if the path does not exist, or if it exists but is not
    a directory (a warning is printed in that case).

    :raises RuntimeError: if the path this function appended to ``sys.path``
        can no longer be found there afterwards.
    """

    def removesuffix(s: str, suffix: str) -> str:
        # Remove once 3.8 is dropped (str.removesuffix is 3.9+)
        if len(suffix) == 0:
            return s
        return s[: -len(suffix)] if s.endswith(suffix) else s

    import importlib
    import os
    import sys

    # DRY: Make sure this matches the test!
    # NOTE(review): on Windows HOMEPATH carries no drive letter (that is
    # HOMEDRIVE), so the path resolves against the current drive. Kept as-is
    # to stay in sync with the test.
    home_path = os.environ["HOMEPATH"] if os.name == "nt" else os.environ["HOME"]
    extensions_path = os.path.join(home_path, ".neural-amp-modeler", "extensions")
    if not os.path.exists(extensions_path):
        return
    if not os.path.isdir(extensions_path):
        print(
            f"WARNING: non-directory object found at expected extensions path {extensions_path}; skip"
        )
        # The warning promises a skip. Without this return, os.listdir() below
        # raises NotADirectoryError and crashes the CLI at import time.
        return
    print("Applying extensions...")
    added_to_sys_path = extensions_path not in sys.path
    if added_to_sys_path:
        # Lets importlib resolve modules/packages living in the extensions dir.
        sys.path.append(extensions_path)
    try:
        for name in os.listdir(extensions_path):
            if name in {"__pycache__", ".DS_Store"}:
                continue
            try:
                importlib.import_module(removesuffix(name, ".py"))  # Runs it
                print(f" {name} [SUCCESS]")
            except Exception as e:
                # Best-effort: one broken extension must not block the others.
                print(f" {name} [FAILED]")
                print(e)
    finally:
        # Undo our sys.path change even if something (e.g. KeyboardInterrupt)
        # escapes the loop above.
        if added_to_sys_path:
            try:
                sys.path.remove(extensions_path)  # Removes first occurrence
            except ValueError:
                raise RuntimeError("Failed to remove extensions path from sys.path?")
    print("Done!")


_apply_extensions()
import json
from argparse import ArgumentParser
from pathlib import Path
from nam.train.full import main as _nam_full
from nam.train.gui import run as nam_gui # noqa F401 Used as an entry point
from nam.util import timestamp
def nam_full():
    """
    Command-line entry point for the full (non-GUI) trainer.

    Parses the command line, loads the data, model and learning configs (JSON),
    creates a fresh timestamped subdirectory under ``outdir`` and passes
    everything to ``nam.train.full.main``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "data_config_path", type=str, help="Path to the data config (JSON)"
    )
    parser.add_argument(
        "model_config_path", type=str, help="Path to the model config (JSON)"
    )
    parser.add_argument(
        "learning_config_path", type=str, help="Path to the learning config (JSON)"
    )
    parser.add_argument(
        "outdir",
        help="Parent directory; a new timestamped subdirectory is created in it",
    )
    parser.add_argument("--no-show", action="store_true", help="Don't show plots")
    args = parser.parse_args()

    def read_json(path: str):
        # JSON is UTF-8 (RFC 8259); don't depend on the locale codec (e.g.
        # cp1252 on Windows). "utf-8-sig" also accepts a leading BOM, as
        # written by some Windows editors, and is otherwise identical to UTF-8.
        with open(path, "r", encoding="utf-8-sig") as fp:
            return json.load(fp)

    def ensure_outdir(outdir: str) -> Path:
        # exist_ok=False: never write into an existing run directory.
        outdir = Path(outdir, timestamp())
        outdir.mkdir(parents=True, exist_ok=False)
        return outdir

    # Read every config before touching the filesystem, so a missing or
    # malformed config doesn't leave an empty timestamped directory behind.
    data_config = read_json(args.data_config_path)
    model_config = read_json(args.model_config_path)
    learning_config = read_json(args.learning_config_path)

    outdir = ensure_outdir(args.outdir)
    _nam_full(data_config, model_config, learning_config, outdir, args.no_show)