66and checks for their validity.
77"""
88import argparse
9+ from dataclasses import dataclass , fields
910import logging
1011from pathlib import Path
11- from typing import Iterable , Sized , Union
12+ from typing import Iterable , Optional , Sequence , Sized , Union , Tuple
1213
1314import pandas as pd
1415import tfs
1516
16- from irnl_rdt_correction .constants import EXT_TFS , EXT_MADX
17+ from irnl_rdt_correction .constants import EXT_TFS , EXT_MADX , StrOrPathOrDataFrame , StrOrPathOrDataFrameOrNone
1718from irnl_rdt_correction .equation_system import SOLVER_MAP
18- from irnl_rdt_correction .utilities import list2str , DotDict
19+ from irnl_rdt_correction .utilities import list2str
1920
2021LOG = logging .getLogger (__name__ )
2122
2223
23- # Default Values ---------------------------------------------------------------
24-
25- DEFAULTS = {'feeddown' : 0 ,
26- 'ips' : [1 , 2 , 5 , 8 ],
27- 'accel' : 'lhc' ,
28- 'solver' : 'lstsq' ,
29- 'update_optics' : True ,
30- 'iterations' : 1 ,
31- 'ignore_corrector_settings' : False ,
32- 'rdts2' : None ,
33- 'ignore_missing_columns' : False ,
34- 'output' : None ,
35- }
36-
37- DEFAULT_RDTS = {
38- 'lhc' : ('F0003' , 'F0003*' , # correct a3 errors with F0003
39- 'F1002' , 'F1002*' , # correct b3 errors with F1002
40- 'F1003' , 'F3001' , # correct a4 errors with F1003 and F3001
41- 'F4000' , 'F0004' , # correct b4 errors with F4000 and F0004
42- 'F6000' , 'F0006' , # correct b6 errors with F6000 and F0006
43- ),
44- 'hllhc' : ('F0003' , 'F0003*' , # correct a3 errors with F0003
45- 'F1002' , 'F1002*' , # correct b3 errors with F1002
46- 'F1003' , 'F3001' , # correct a4 errors with F1003 and F3001
47- 'F0004' , 'F4000' , # correct b4 errors with F0004 and F4000
48- 'F0005' , 'F0005*' , # correct a5 errors with F0005
49- 'F5000' , 'F5000*' , # correct b5 errors with F5000
50- 'F5001' , 'F1005' , # correct a6 errors with F5001 and F1005
51- 'F6000' , 'F0006' , # correct b6 errors with F6000 and F0006
52- ),
53- }
54-
55-
5624# Parser -----------------------------------------------------------------------
5725
5826def get_parser () -> argparse .ArgumentParser :
5927 parser = argparse .ArgumentParser ()
28+ parser .add_argument (
29+ "--beams" ,
30+ dest = "beams" ,
31+ type = int ,
32+ nargs = "+" ,
33+ help = "Which beam the files come from (1, 2 or 4)" ,
34+ required = True ,
35+ )
6036 parser .add_argument (
6137 "--twiss" ,
6238 dest = "twiss" ,
@@ -71,15 +47,6 @@ def get_parser() -> argparse.ArgumentParser:
7147 dest = "errors" ,
7248 nargs = "+" ,
7349 help = "Path(s) to error file(s), in the format of MAD-X `esave` output." ,
74- required = True ,
75- )
76- parser .add_argument (
77- "--beams" ,
78- dest = "beams" ,
79- type = int ,
80- nargs = "+" ,
81- help = "Which beam the files come from (1, 2 or 4)" ,
82- required = True ,
8350 )
8451 parser .add_argument (
8552 "--output" ,
@@ -109,32 +76,32 @@ def get_parser() -> argparse.ArgumentParser:
10976 "--accel" ,
11077 dest = "accel" ,
11178 type = str .lower ,
112- choices = list (DEFAULT_RDTS .keys ()),
113- default = DEFAULTS [ ' accel' ] ,
79+ choices = list (InputOptions . DEFAULT_RDTS .keys ()),
80+ default = InputOptions . accel ,
11481 help = "Which accelerator we have." ,
11582 )
11683 parser .add_argument (
11784 "--feeddown" ,
11885 dest = "feeddown" ,
11986 type = int ,
12087 help = "Order of Feeddown to include." ,
121- default = DEFAULTS [ ' feeddown' ] ,
88+ default = InputOptions . feeddown ,
12289 )
12390 parser .add_argument (
12491 "--ips" ,
12592 dest = "ips" ,
12693 nargs = "+" ,
12794 help = "In which IPs to correct." ,
12895 type = int ,
129- default = DEFAULTS [ ' ips' ] ,
96+ default = list ( InputOptions . ips ) ,
13097 )
13198 parser .add_argument (
13299 "--solver" ,
133100 dest = "solver" ,
134101 help = "Solving method to use." ,
135102 type = str .lower ,
136103 choices = list (SOLVER_MAP .keys ()),
137- default = DEFAULTS [ ' solver' ] ,
104+ default = InputOptions . solver ,
138105 )
139106 parser .add_argument (
140107 "--update_optics" ,
@@ -144,15 +111,15 @@ def get_parser() -> argparse.ArgumentParser:
144111 "corrector strengths in the optics after calculation, so the "
145112 "feeddown to lower order correctors is included."
146113 ),
147- default = DEFAULTS [ " update_optics" ]
114+ default = InputOptions . update_optics
148115 )
149116 parser .add_argument (
150117 "--iterations" ,
151118 dest = "iterations" ,
152119 type = int ,
153120 help = ("Reiterate correction, "
154121 "starting with the previously calculated values." ),
155- default = DEFAULTS [ " iterations" ]
122+ default = InputOptions . iterations
156123 )
157124 parser .add_argument (
158125 "--ignore_corrector_settings" ,
@@ -170,78 +137,159 @@ def get_parser() -> argparse.ArgumentParser:
170137 )
171138 return parser
172139
140+ # InputOptions and Defaults ---------------------------------------------------------------
173141
174- # Checks -----------------------------------------------------------------------
175-
176- def check_opt (opt : Union [dict , DotDict ]) -> DotDict :
177- """ Asserts that the input parameters make sense and adds what's missing.
178- If the input is empty, arguments will be parsed from commandline.
142+ @dataclass
143+ class InputOptions :
144+ """ DataClass to store the input options.
145+ On creation it asserts that the input parameters make sense and adds what's missing.
179146 Checks include:
180- - Set defaults (see ``DEFAULTS``) if option not given.
181147 - Check accelerator name is valid
182148 - Set default RDTs if not given (see ``DEFAULT_RDTS``)
183149 - Check required parameters are present (twiss, errors, beams, rdts)
184150 - Check feeddown and iterations
185151
186- TODO: Replace DotDict with dataclass and have class check most of this...
152+ """
153+ DEFAULT_RDTS = {
154+ 'lhc' : ('F0003' , 'F0003*' , # correct a3 errors with F0003
155+ 'F1002' , 'F1002*' , # correct b3 errors with F1002
156+ 'F1003' , 'F3001' , # correct a4 errors with F1003 and F3001
157+ 'F4000' , 'F0004' , # correct b4 errors with F4000 and F0004
158+ 'F6000' , 'F0006' , # correct b6 errors with F6000 and F0006
159+ ),
160+ 'hllhc' : ('F0003' , 'F0003*' , # correct a3 errors with F0003
161+ 'F1002' , 'F1002*' , # correct b3 errors with F1002
162+ 'F1003' , 'F3001' , # correct a4 errors with F1003 and F3001
163+ 'F0004' , 'F4000' , # correct b4 errors with F0004 and F4000
164+ 'F0005' , 'F0005*' , # correct a5 errors with F0005
165+ 'F5000' , 'F5000*' , # correct b5 errors with F5000
166+ 'F5001' , 'F1005' , # correct a6 errors with F5001 and F1005
167+ 'F6000' , 'F0006' , # correct b6 errors with F6000 and F0006
168+ ),
169+ }
187170
188- Args:
189- opt (Union[dict, DotDict]): Function options in dictionary format.
190- Description of the arguments are given in
191- :func:`irnl_rdt_correction.main.irnl_rdt_correction`.
171+ beams : Sequence [int ]
172+ twiss : Sequence [StrOrPathOrDataFrame ]
173+ errors : Sequence [StrOrPathOrDataFrameOrNone ] = None
174+ rdts : Sequence [str ] = None
175+ rdts2 : Sequence [str ] = None
176+ accel : str = 'lhc'
177+ feeddown : int = 0
178+ ips : Sequence [int ] = (1 , 2 , 5 , 8 )
179+ solver : str = 'lstsq'
180+ update_optics : bool = True
181+ iterations : int = 1
182+ ignore_corrector_settings : bool = False
183+ ignore_missing_columns : bool = False
184+ output : str = None
192185
193- Returns :
194- DotDict: (Parsed and) checked options.
186+ def __post_init__ ( self ) :
187+ self . check_all ()
195188
196- """
197- # check for unkown input
198- parser = get_parser ()
199- if not len (opt ):
200- opt = vars (parser .parse_args ())
201- opt = DotDict (opt )
202- known_opts = [a .dest for a in parser ._actions if not isinstance (a , argparse ._HelpAction )] # best way I could figure out
203- unknown_opts = [k for k in opt .keys () if k not in known_opts ]
204- if len (unknown_opts ):
205- raise AttributeError (f"Unknown arguments found: '{ list2str (unknown_opts )} '.\n "
206- f"Allowed input parameters are: '{ list2str (known_opts )} '" )
207-
208- # Set defaults
209- for name , default in DEFAULTS .items ():
210- if opt .get (name ) is None :
211- LOG .debug (f"Setting input '{ name } ' to default value '{ default } '." )
212- opt [name ] = default
213-
214- # check accel
215- opt .accel = opt .accel .lower () # let's not care about case
216- if opt .accel not in DEFAULT_RDTS .keys ():
217- raise ValueError (f"Parameter 'accel' needs to be one of '{ list2str (list (DEFAULT_RDTS .keys ()))} ' "
218- f"but was '{ opt .accel } ' instead." )
219-
220- # Set rdts:
221- if opt .get ('rdts' ) is None :
222- opt .rdts = DEFAULT_RDTS [opt .accel ]
223-
224- # Check required and rdts:
225- for name in ('twiss' , 'errors' , 'beams' , 'rdts' ):
226- inputs = opt .get (name )
227- if inputs is None or isinstance (inputs , str ) or not isinstance (inputs , (Iterable , Sized )):
228- raise ValueError (f"Parameter '{ name } ' is required and needs to be "
229- "iterable, even if only of length 1. "
230- f"Instead was '{ inputs } '." )
231-
232- # Check twiss and errors input type
233- for name in ('twiss' , 'errors' ):
234- inputs = opt .get (name )
235- for element in inputs :
189+ def __getitem__ (self , item ):
190+ return getattr (self , item )
191+
192+ @classmethod
193+ def keys (cls ):
194+ return (f .name for f in fields (cls ))
195+
196+ def values (self ):
197+ return (getattr (self , f .name ) for f in fields (self ))
198+
199+ def items (self ):
200+ return ((f .name , getattr (self , f .name )) for f in fields (self ))
201+
202+ def check_all (self ):
203+ self .check_accel ()
204+ self .check_twiss ()
205+ self .check_errors ()
206+ self .check_beams ()
207+ self .check_rdts ()
208+ self .check_feeddown ()
209+ self .check_iterations ()
210+
211+ def check_accel (self ):
212+ if self .accel not in self .DEFAULT_RDTS :
213+ raise ValueError (f"Parameter 'accel' needs to be one of '{ list2str (list (self .DEFAULT_RDTS .keys ()))} ' "
214+ f"but was '{ self .accel } ' instead." )
215+
216+ def check_twiss (self ):
217+ if self .twiss is None :
218+ raise ValueError ("Parameter 'twiss' needs to be set." )
219+
220+ self ._check_iterable ('twiss' )
221+ for element in self .twiss :
236222 if not isinstance (element , (str , Path , pd .DataFrame , tfs .TfsDataFrame )):
237- raise TypeError (f"Not all elements of '{ name } ' are DataFrames or paths to DataFrames!" )
223+ raise TypeError (f"Not all elements of 'twiss' are DataFrames or paths to DataFrames!" )
224+
225+ def check_errors (self ):
226+ if self .errors is None :
227+ self .errors = tuple ([None ] * len (self .twiss ))
228+ return
229+
230+ self ._check_iterable ('errors' )
231+ for element in self .errors :
232+ if not isinstance (element , (str , Path , pd .DataFrame , tfs .TfsDataFrame , type (None ))):
233+ raise TypeError (f"Not all elements of 'errors' are DataFrames or paths to DataFrames or None!" )
234+
235+ def check_beams (self ):
236+ if self .beams is None :
237+ raise ValueError ("Parameter 'beams' needs to be set." )
238+ self ._check_iterable ('beams' )
239+
240+ def check_rdts (self ):
241+ if self .rdts is None :
242+ self .rdts = self .DEFAULT_RDTS [self .accel ]
243+ else :
244+ self ._check_iterable ('rdts' )
245+
246+ def check_feeddown (self ):
247+ if self .feeddown < 0 or not (self .feeddown == int (self .feeddown )):
248+ raise ValueError ("'feeddown' needs to be a positive integer." )
249+
250+ def check_iterations (self ):
251+ if self .iterations < 1 :
252+ raise ValueError ("At least one iteration (see: 'iterations') needs to "
253+ "be done for correction." )
238254
239- if opt .feeddown < 0 or not (opt .feeddown == int (opt .feeddown )):
240- raise ValueError ("'feeddown' needs to be a positive integer." )
255+ def _check_iterable (self , name ):
256+ inputs = getattr (self , name )
257+ if isinstance (inputs , str ) or not isinstance (inputs , (Iterable , Sized )):
258+ raise ValueError (f"Parameter '{ name } ' needs to be iterable, "
259+ f"even if only of length 1. Instead was '{ inputs } '." )
241260
242- if opt . iterations < 1 :
243- raise ValueError ( "At least one iteration (see: 'iterations') needs to "
244- "be done for correction." )
245- return opt
261+ @ classmethod
262+ def from_args_or_dict ( cls , opt : Optional [ Union [ dict , 'InputOptions' ]] = None ) -> 'InputOptions' :
263+ """Create an InputOptions instance from the given dictionary.
264+ If the input is empty, arguments will be parsed from commandline.
246265
266+ Args:
267+ opt (Union[dict, DotDict]): Function options in dictionary format.
268+ Description of the arguments are given in
269+ :func:`irnl_rdt_correction.main.irnl_rdt_correction`.
270+ Optional, if not given parses commandline args
247271
272+ Returns:
273+ InputOptions: (Parsed and) checked options.
274+ """
275+ if isinstance (opt , InputOptions ):
276+ return opt
277+
278+ if opt is None or not len (opt ):
279+ parser = get_parser ()
280+ opt = vars (parser .parse_args ())
281+
282+ return cls (** opt )
283+
284+
285+ def allow_commandline_and_kwargs (func ):
286+ """ Decorator to allow a function to take options from the commandline
287+ or via kwargs, or given an InputOptions instance.
288+ """
289+ def wrapper (opt : Optional [Union [InputOptions , dict ]] = None , ** kwargs ) -> Tuple [str , tfs .TfsDataFrame ]:
290+ if not isinstance (opt , InputOptions ):
291+ if opt is None :
292+ opt = kwargs
293+ opt = InputOptions .from_args_or_dict (opt )
294+ return func (opt )
295+ return wrapper
0 commit comments