Skip to content

Commit 54ff9da

Browse files
authored
InputOptions and optional errors (#8)
* errors optional * InputOptions Class * additional tests
1 parent 327c928 commit 54ff9da

File tree

9 files changed

+423
-175
lines changed

9 files changed

+423
-175
lines changed

irnl_rdt_correction/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
:license: MIT, see the LICENSE.md file for details.
99
"""
1010
from irnl_rdt_correction.main import irnl_rdt_correction
11+
from irnl_rdt_correction.input_options import InputOptions
1112

1213
__title__ = "irnl-rdt-correction"
1314
__description__ = "Correction script to power the nonlinear correctors in the (HL-)LHC insertion regions based on RDTs."
1415
__url__ = "https://github.com/pylhc/irnl_rdt_correction"
15-
__version__ = "1.0.0"
16+
__version__ = "1.1.0"
1617
__author__ = "pylhc"
1718
__author_email__ = "[email protected]"
1819
__license__ = "MIT"
1920

20-
__all__ = [irnl_rdt_correction, __version__]
21+
__all__ = [irnl_rdt_correction, InputOptions, __version__]

irnl_rdt_correction/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@
3434

3535
# Types ---
3636
StrOrPathOrDataFrame = Union[str, Path, DataFrame, TfsDataFrame]
37+
StrOrPathOrDataFrameOrNone = Union[str, Path, DataFrame, TfsDataFrame, None]
3738
RDTInputTypes = Union[Sequence[str], Dict[str, Sequence[str]]]

irnl_rdt_correction/input_options.py

Lines changed: 160 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,33 @@
66
and checks for their validity.
77
"""
88
import argparse
9+
from dataclasses import dataclass, fields
910
import logging
1011
from pathlib import Path
11-
from typing import Iterable, Sized, Union
12+
from typing import Iterable, Optional, Sequence, Sized, Union, Tuple
1213

1314
import pandas as pd
1415
import tfs
1516

16-
from irnl_rdt_correction.constants import EXT_TFS, EXT_MADX
17+
from irnl_rdt_correction.constants import EXT_TFS, EXT_MADX, StrOrPathOrDataFrame, StrOrPathOrDataFrameOrNone
1718
from irnl_rdt_correction.equation_system import SOLVER_MAP
18-
from irnl_rdt_correction.utilities import list2str, DotDict
19+
from irnl_rdt_correction.utilities import list2str
1920

2021
LOG = logging.getLogger(__name__)
2122

2223

23-
# Default Values ---------------------------------------------------------------
24-
25-
DEFAULTS = {'feeddown': 0,
26-
'ips': [1, 2, 5, 8],
27-
'accel': 'lhc',
28-
'solver': 'lstsq',
29-
'update_optics': True,
30-
'iterations': 1,
31-
'ignore_corrector_settings': False,
32-
'rdts2': None,
33-
'ignore_missing_columns': False,
34-
'output': None,
35-
}
36-
37-
DEFAULT_RDTS = {
38-
'lhc': ('F0003', 'F0003*', # correct a3 errors with F0003
39-
'F1002', 'F1002*', # correct b3 errors with F1002
40-
'F1003', 'F3001', # correct a4 errors with F1003 and F3001
41-
'F4000', 'F0004', # correct b4 errors with F4000 and F0004
42-
'F6000', 'F0006', # correct b6 errors with F6000 and F0006
43-
),
44-
'hllhc': ('F0003', 'F0003*', # correct a3 errors with F0003
45-
'F1002', 'F1002*', # correct b3 errors with F1002
46-
'F1003', 'F3001', # correct a4 errors with F1003 and F3001
47-
'F0004', 'F4000', # correct b4 errors with F0004 and F4000
48-
'F0005', 'F0005*', # correct a5 errors with F0005
49-
'F5000', 'F5000*', # correct b5 errors with F5000
50-
'F5001', 'F1005', # correct a6 errors with F5001 and F1005
51-
'F6000', 'F0006', # correct b6 errors with F6000 and F0006
52-
),
53-
}
54-
55-
5624
# Parser -----------------------------------------------------------------------
5725

5826
def get_parser() -> argparse.ArgumentParser:
5927
parser = argparse.ArgumentParser()
28+
parser.add_argument(
29+
"--beams",
30+
dest="beams",
31+
type=int,
32+
nargs="+",
33+
help="Which beam the files come from (1, 2 or 4)",
34+
required=True,
35+
)
6036
parser.add_argument(
6137
"--twiss",
6238
dest="twiss",
@@ -71,15 +47,6 @@ def get_parser() -> argparse.ArgumentParser:
7147
dest="errors",
7248
nargs="+",
7349
help="Path(s) to error file(s), in the format of MAD-X `esave` output.",
74-
required=True,
75-
)
76-
parser.add_argument(
77-
"--beams",
78-
dest="beams",
79-
type=int,
80-
nargs="+",
81-
help="Which beam the files come from (1, 2 or 4)",
82-
required=True,
8350
)
8451
parser.add_argument(
8552
"--output",
@@ -109,32 +76,32 @@ def get_parser() -> argparse.ArgumentParser:
10976
"--accel",
11077
dest="accel",
11178
type=str.lower,
112-
choices=list(DEFAULT_RDTS.keys()),
113-
default=DEFAULTS['accel'],
79+
choices=list(InputOptions.DEFAULT_RDTS.keys()),
80+
default=InputOptions.accel,
11481
help="Which accelerator we have.",
11582
)
11683
parser.add_argument(
11784
"--feeddown",
11885
dest="feeddown",
11986
type=int,
12087
help="Order of Feeddown to include.",
121-
default=DEFAULTS['feeddown'],
88+
default=InputOptions.feeddown,
12289
)
12390
parser.add_argument(
12491
"--ips",
12592
dest="ips",
12693
nargs="+",
12794
help="In which IPs to correct.",
12895
type=int,
129-
default=DEFAULTS['ips'],
96+
default=list(InputOptions.ips),
13097
)
13198
parser.add_argument(
13299
"--solver",
133100
dest="solver",
134101
help="Solving method to use.",
135102
type=str.lower,
136103
choices=list(SOLVER_MAP.keys()),
137-
default=DEFAULTS['solver'],
104+
default=InputOptions.solver,
138105
)
139106
parser.add_argument(
140107
"--update_optics",
@@ -144,15 +111,15 @@ def get_parser() -> argparse.ArgumentParser:
144111
"corrector strengths in the optics after calculation, so the "
145112
"feeddown to lower order correctors is included."
146113
),
147-
default=DEFAULTS["update_optics"]
114+
default=InputOptions.update_optics
148115
)
149116
parser.add_argument(
150117
"--iterations",
151118
dest="iterations",
152119
type=int,
153120
help=("Reiterate correction, "
154121
"starting with the previously calculated values."),
155-
default=DEFAULTS["iterations"]
122+
default=InputOptions.iterations
156123
)
157124
parser.add_argument(
158125
"--ignore_corrector_settings",
@@ -170,78 +137,159 @@ def get_parser() -> argparse.ArgumentParser:
170137
)
171138
return parser
172139

140+
# InputOptions and Defaults ---------------------------------------------------------------
173141

174-
# Checks -----------------------------------------------------------------------
175-
176-
def check_opt(opt: Union[dict, DotDict]) -> DotDict:
177-
""" Asserts that the input parameters make sense and adds what's missing.
178-
If the input is empty, arguments will be parsed from commandline.
142+
@dataclass
143+
class InputOptions:
144+
""" DataClass to store the input options.
145+
On creation it asserts that the input parameters make sense and adds what's missing.
179146
Checks include:
180-
- Set defaults (see ``DEFAULTS``) if option not given.
181147
- Check accelerator name is valid
182148
- Set default RDTs if not given (see ``DEFAULT_RDTS``)
183149
- Check required parameters are present (twiss, errors, beams, rdts)
184150
- Check feeddown and iterations
185151
186-
TODO: Replace DotDict with dataclass and have class check most of this...
152+
"""
153+
DEFAULT_RDTS = {
154+
'lhc': ('F0003', 'F0003*', # correct a3 errors with F0003
155+
'F1002', 'F1002*', # correct b3 errors with F1002
156+
'F1003', 'F3001', # correct a4 errors with F1003 and F3001
157+
'F4000', 'F0004', # correct b4 errors with F4000 and F0004
158+
'F6000', 'F0006', # correct b6 errors with F6000 and F0006
159+
),
160+
'hllhc': ('F0003', 'F0003*', # correct a3 errors with F0003
161+
'F1002', 'F1002*', # correct b3 errors with F1002
162+
'F1003', 'F3001', # correct a4 errors with F1003 and F3001
163+
'F0004', 'F4000', # correct b4 errors with F0004 and F4000
164+
'F0005', 'F0005*', # correct a5 errors with F0005
165+
'F5000', 'F5000*', # correct b5 errors with F5000
166+
'F5001', 'F1005', # correct a6 errors with F5001 and F1005
167+
'F6000', 'F0006', # correct b6 errors with F6000 and F0006
168+
),
169+
}
187170

188-
Args:
189-
opt (Union[dict, DotDict]): Function options in dictionary format.
190-
Description of the arguments are given in
191-
:func:`irnl_rdt_correction.main.irnl_rdt_correction`.
171+
beams: Sequence[int]
172+
twiss: Sequence[StrOrPathOrDataFrame]
173+
errors: Sequence[StrOrPathOrDataFrameOrNone] = None
174+
rdts: Sequence[str] = None
175+
rdts2: Sequence[str] = None
176+
accel: str = 'lhc'
177+
feeddown: int = 0
178+
ips: Sequence[int] = (1, 2, 5, 8)
179+
solver: str ='lstsq'
180+
update_optics: bool = True
181+
iterations: int = 1
182+
ignore_corrector_settings: bool = False
183+
ignore_missing_columns: bool = False
184+
output: str = None
192185

193-
Returns:
194-
DotDict: (Parsed and) checked options.
186+
def __post_init__(self):
187+
self.check_all()
195188

196-
"""
197-
# check for unkown input
198-
parser = get_parser()
199-
if not len(opt):
200-
opt = vars(parser.parse_args())
201-
opt = DotDict(opt)
202-
known_opts = [a.dest for a in parser._actions if not isinstance(a, argparse._HelpAction)] # best way I could figure out
203-
unknown_opts = [k for k in opt.keys() if k not in known_opts]
204-
if len(unknown_opts):
205-
raise AttributeError(f"Unknown arguments found: '{list2str(unknown_opts)}'.\n"
206-
f"Allowed input parameters are: '{list2str(known_opts)}'")
207-
208-
# Set defaults
209-
for name, default in DEFAULTS.items():
210-
if opt.get(name) is None:
211-
LOG.debug(f"Setting input '{name}' to default value '{default}'.")
212-
opt[name] = default
213-
214-
# check accel
215-
opt.accel = opt.accel.lower() # let's not care about case
216-
if opt.accel not in DEFAULT_RDTS.keys():
217-
raise ValueError(f"Parameter 'accel' needs to be one of '{list2str(list(DEFAULT_RDTS.keys()))}' "
218-
f"but was '{opt.accel}' instead.")
219-
220-
# Set rdts:
221-
if opt.get('rdts') is None:
222-
opt.rdts = DEFAULT_RDTS[opt.accel]
223-
224-
# Check required and rdts:
225-
for name in ('twiss', 'errors', 'beams', 'rdts'):
226-
inputs = opt.get(name)
227-
if inputs is None or isinstance(inputs, str) or not isinstance(inputs, (Iterable, Sized)):
228-
raise ValueError(f"Parameter '{name}' is required and needs to be "
229-
"iterable, even if only of length 1. "
230-
f"Instead was '{inputs}'.")
231-
232-
# Check twiss and errors input type
233-
for name in ('twiss', 'errors'):
234-
inputs = opt.get(name)
235-
for element in inputs:
189+
def __getitem__(self, item):
190+
return getattr(self, item)
191+
192+
@classmethod
193+
def keys(cls):
194+
return (f.name for f in fields(cls))
195+
196+
def values(self):
197+
return (getattr(self, f.name) for f in fields(self))
198+
199+
def items(self):
200+
return ((f.name, getattr(self, f.name)) for f in fields(self))
201+
202+
def check_all(self):
203+
self.check_accel()
204+
self.check_twiss()
205+
self.check_errors()
206+
self.check_beams()
207+
self.check_rdts()
208+
self.check_feeddown()
209+
self.check_iterations()
210+
211+
def check_accel(self):
212+
if self.accel not in self.DEFAULT_RDTS:
213+
raise ValueError(f"Parameter 'accel' needs to be one of '{list2str(list(self.DEFAULT_RDTS.keys()))}' "
214+
f"but was '{self.accel}' instead.")
215+
216+
def check_twiss(self):
217+
if self.twiss is None:
218+
raise ValueError("Parameter 'twiss' needs to be set.")
219+
220+
self._check_iterable('twiss')
221+
for element in self.twiss:
236222
if not isinstance(element, (str, Path, pd.DataFrame, tfs.TfsDataFrame)):
237-
raise TypeError(f"Not all elements of '{name}' are DataFrames or paths to DataFrames!")
223+
raise TypeError(f"Not all elements of 'twiss' are DataFrames or paths to DataFrames!")
224+
225+
def check_errors(self):
226+
if self.errors is None:
227+
self.errors = tuple([None] * len(self.twiss))
228+
return
229+
230+
self._check_iterable('errors')
231+
for element in self.errors:
232+
if not isinstance(element, (str, Path, pd.DataFrame, tfs.TfsDataFrame, type(None))):
233+
raise TypeError(f"Not all elements of 'errors' are DataFrames or paths to DataFrames or None!")
234+
235+
def check_beams(self):
236+
if self.beams is None:
237+
raise ValueError("Parameter 'beams' needs to be set.")
238+
self._check_iterable('beams')
239+
240+
def check_rdts(self):
241+
if self.rdts is None:
242+
self.rdts = self.DEFAULT_RDTS[self.accel]
243+
else:
244+
self._check_iterable('rdts')
245+
246+
def check_feeddown(self):
247+
if self.feeddown < 0 or not (self.feeddown == int(self.feeddown)):
248+
raise ValueError("'feeddown' needs to be a positive integer.")
249+
250+
def check_iterations(self):
251+
if self.iterations < 1:
252+
raise ValueError("At least one iteration (see: 'iterations') needs to "
253+
"be done for correction.")
238254

239-
if opt.feeddown < 0 or not (opt.feeddown == int(opt.feeddown)):
240-
raise ValueError("'feeddown' needs to be a positive integer.")
255+
def _check_iterable(self, name):
256+
inputs = getattr(self, name)
257+
if isinstance(inputs, str) or not isinstance(inputs, (Iterable, Sized)):
258+
raise ValueError(f"Parameter '{name}' needs to be iterable, "
259+
f"even if only of length 1. Instead was '{inputs}'.")
241260

242-
if opt.iterations < 1:
243-
raise ValueError("At least one iteration (see: 'iterations') needs to "
244-
"be done for correction.")
245-
return opt
261+
@classmethod
262+
def from_args_or_dict(cls, opt: Optional[Union[dict, 'InputOptions']] = None) -> 'InputOptions':
263+
"""Create an InputOptions instance from the given dictionary.
264+
If the input is empty, arguments will be parsed from commandline.
246265
266+
Args:
267+
opt (Union[dict, DotDict]): Function options in dictionary format.
268+
Description of the arguments are given in
269+
:func:`irnl_rdt_correction.main.irnl_rdt_correction`.
270+
Optional, if not given parses commandline args
247271
272+
Returns:
273+
InputOptions: (Parsed and) checked options.
274+
"""
275+
if isinstance(opt, InputOptions):
276+
return opt
277+
278+
if opt is None or not len(opt):
279+
parser = get_parser()
280+
opt = vars(parser.parse_args())
281+
282+
return cls(**opt)
283+
284+
285+
def allow_commandline_and_kwargs(func):
286+
""" Decorator to allow a function to take options from the commandline
287+
or via kwargs, or given an InputOptions instance.
288+
"""
289+
def wrapper(opt: Optional[Union[InputOptions, dict]] = None, **kwargs) -> Tuple[str, tfs.TfsDataFrame]:
290+
if not isinstance(opt, InputOptions):
291+
if opt is None:
292+
opt = kwargs
293+
opt = InputOptions.from_args_or_dict(opt)
294+
return func(opt)
295+
return wrapper

0 commit comments

Comments
 (0)