152152
153153
154154"""
155+ import re
155156from pathlib import Path
156- from typing import Dict , Iterable , Set
157+ from typing import Dict , Iterable , Set , Union
157158
158159from matplotlib import pyplot as plt
160+ from matplotlib .figure import Figure
159161
160162import tfs
161163from generic_parser import DotDict , EntryPointParameters , entrypoint
162- from omc3 .correction .constants import (CORRECTED_LABEL , UNCORRECTED_LABEL , CORRECTION_LABEL , EXPECTED_LABEL ,
163- COUPLING_NAME_TO_MODEL_COLUMN_SUFFIX )
164- from omc3 .definitions .optics import FILE_COLUMN_MAPPING , ColumnsAndLabels , RDT_COLUMN_MAPPING
164+ from omc3 .correction .constants import (
165+ CORRECTED_LABEL , UNCORRECTED_LABEL , CORRECTION_LABEL , EXPECTED_LABEL ,
166+ COUPLING_NAME_TO_MODEL_COLUMN_SUFFIX
167+ )
168+ from omc3 .definitions .optics import (
169+ FILE_COLUMN_MAPPING , ColumnsAndLabels , RDT_COLUMN_MAPPING , RDT_PHASE_COLUMN ,
170+ RDT_IMAG_COLUMN , RDT_REAL_COLUMN , RDT_AMPLITUDE_COLUMN
171+ )
165172from omc3 .optics_measurements .constants import EXT
166- from omc3 .plotting .plot_optics_measurements import (_get_x_axis_column_and_label , _get_ip_positions ,
167- get_optics_style_params , get_plottfs_style_params )
173+ from omc3 .plotting .plot_optics_measurements import (
174+ _get_x_axis_column_and_label , _get_ip_positions ,
175+ get_optics_style_params , get_plottfs_style_params
176+ )
168177from omc3 .plotting .plot_tfs import plot as plot_tfs , get_full_output_path
169- from omc3 .plotting .utils import (annotations as pannot )
178+ from omc3 .plotting .utils .windows import (
179+ PlotWidget , TabWidget , VerticalTabWindow ,
180+ log_no_qtpy_many_windows , create_pyplot_window_from_fig
181+ )
182+ from omc3 .plotting .utils import annotations as pannot
170183from omc3 .utils import logging_tools
171- from omc3 .utils .iotools import PathOrStr
184+ from omc3 .utils .iotools import PathOrStr , save_config
172185
173186LOG = logging_tools .get_logger (__name__ )
174187
175188SPLIT_ID = "#_#" # will appear in the figure ID, but should be fine to read
176189PREFIX = "plot_corrections"
177190
191+
178192def get_plotting_params () -> EntryPointParameters :
179193 params = EntryPointParameters ()
180194 params .add_parameter (name = "input_dir" ,
@@ -220,6 +234,12 @@ def get_plotting_style_parameters():
220234def plot_checked_corrections (opt : DotDict ):
221235 """ Entrypoint for the plotting function. """
222236 LOG .info ("Plotting checked corrections." )
237+
238+ opt .input_dir = Path (opt .input_dir )
239+ if opt .output_dir :
240+ opt .output_dir = Path (opt .output_dir )
241+ save_config (opt .output_dir , opt , __file__ )
242+
223243 # Preparations -------------------------------------------------------------
224244 correction_dirs : Dict [str , Path ] = {}
225245 if len (opt .corrections ) == 1 and not opt .corrections [0 ]:
@@ -248,7 +268,7 @@ def plot_checked_corrections(opt: DotDict):
248268 # get F1### column map without the I, R, A, P part based on the rdt-filename:
249269 LOG .debug (f"Plotting coupling correction for { filename } " )
250270 new_figs = {}
251- for idx , y_colmap in enumerate ( RDT_COLUMN_MAPPING .values () ): # AMP, PHASE, REAL or IMAG as column-map
271+ for y_colmap in RDT_COLUMN_MAPPING .values (): # AMP, PHASE, REAL or IMAG as column-map
252272 y_colmap = ColumnsAndLabels (
253273 _column = y_colmap .column ,
254274 _label = y_colmap .label .format (filename ), # this one needs additional info
@@ -281,11 +301,27 @@ def plot_checked_corrections(opt: DotDict):
281301 fig_dict .update (new_figs )
282302
283303 # Output -------------------------------------------------------------------
284- save_plots (opt .output_dir , figure_dict = fig_dict , input_dir = opt .input_dir if opt .individual_to_input else None )
285- show_plots (opt .show )
304+ if opt .output_dir :
305+ save_plots (
306+ opt .output_dir ,
307+ figure_dict = fig_dict ,
308+ input_dir = opt .input_dir if opt .individual_to_input else None
309+ )
310+
311+ if opt .show :
312+ show_plots (fig_dict )
286313 return fig_dict
287314
288- def _create_correction_plots_per_filename (filename , measurements , correction_dirs , x_colmap , y_colmap , ip_positions , opt ):
315+
316+ def _create_correction_plots_per_filename (
317+ filename : str ,
318+ measurements : Path ,
319+ correction_dirs : Dict [str , Path ],
320+ x_colmap : ColumnsAndLabels ,
321+ y_colmap : ColumnsAndLabels ,
322+ ip_positions : Union [str , Dict [str , float ], Path ],
323+ opt : DotDict
324+ ):
289325 """ Plot measurements and all different correction scenarios into a single plot. """
290326 full_filename = f"{ filename } { EXT } "
291327 file_label = filename
@@ -305,7 +341,7 @@ def _create_correction_plots_per_filename(filename, measurements, correction_dir
305341 x_labels = [x_colmap .label ],
306342 vertical_lines = ip_positions + opt .lines_manual ,
307343 same_axes = ["files" ],
308- output_prefix = f"{ PREFIX } _ { file_label } _" , # used in the id, which is the fig_dict key
344+ output_prefix = f"{ file_label } _" , # used in the id, which is the fig_dict key
309345 ** opt .get_subdict ([
310346 'plot_styles' , 'manual_style' ,
311347 'change_marker' , 'errorbar_alpha' ,
@@ -366,41 +402,110 @@ def _create_correction_plots_per_filename(filename, measurements, correction_dir
366402 return figs
367403
368404
369- def show_plots (show : bool ):
370- """ Show plots if so desired. """
371- # plt.show()
372- if show :
373- plt .show ()
374-
375-
376- def save_plots (output_dir , figure_dict , input_dir = None ):
405+ def save_plots (output_dir : Path , figure_dict : Dict [str , Figure ], input_dir : Path = None ):
377406 """ Save the plots. """
378- if not output_dir and not input_dir :
379- return
380-
381407 for figname , fig in figure_dict .items ():
382408 outdir = output_dir
383409 figname_parts = figname .split (SPLIT_ID )
384410 if len (figname_parts ) == 1 : # no SPLIT_ID
385411 # these are the combined plots. They have the column name at the end,
386412 # which we do not care for here at the moment.
387413 # In case of multiple columns per file, this could be brought back
388- # (then we would also not need the RDT check).
389- figname = "_" .join (figname .split ("_" )[:- 1 ])
414+ figname = "_" .join ([PREFIX ] + figname .split ("_" )[:- 1 ])
390415 else :
391416 # this is then the individual plots
392417 if input_dir :
393418 # files go directly into the correction-scenario folders
394419 outdir = input_dir / figname_parts [0 ]
395420 figname = f"{ PREFIX } _{ figname_parts [1 ]} "
396421 else :
397- # everything goes into the output-dir (if given), but needs prefix
422+ # everything goes into the output-dir (if given),
423+ # but with correction-name as additional prefix
398424 figname = "_" .join ([PREFIX ] + figname_parts )
399425
400426 output_path = get_full_output_path (outdir , figname )
401- if output_path is not None :
402- LOG .info (f"Saving Corrections Plot to '{ output_path } '" )
403- fig .savefig (output_path )
427+ LOG .debug (f"Saving corrections plot to '{ output_path } '" )
428+ fig .savefig (output_path )
429+
430+ if input_dir :
431+ LOG .info (f"Saved all correction plots in '{ output_dir } '\n "
432+ f"and into the correction-scenario in '{ input_dir } '." )
433+ else :
434+ LOG .info (f"Saved all correction plots in '{ output_dir } '." )
435+
436+
437+ def show_plots (figure_dict : Dict [str , Figure ]):
438+ """Displays the provided figures.
439+ If `qtpy` is installed, they are shown in a single window.
440+ The individual corrections are sorted into vertical tabs,
441+ the optics parameter into horizontal tabs.
442+ If `qtpy` is not installed, they are simply shown as individual figures.
443+ This is not recommended
444+ """
445+ try :
446+ window = VerticalTabWindow ("Correction Check" )
447+ except TypeError :
448+ log_no_qtpy_many_windows ()
449+ for name , fig in figure_dict .items ():
450+ create_pyplot_window_from_fig (fig , title = name .replace (SPLIT_ID , " " ))
451+ plt .show ()
452+ return
453+
454+ rdt_pattern = re .compile (r"f\d{4}" )
455+ rdt_complement = {
456+ RDT_REAL_COLUMN .text_label : RDT_IMAG_COLUMN ,
457+ RDT_AMPLITUDE_COLUMN .text_label : RDT_PHASE_COLUMN ,
458+ }
459+
460+ correction_names = sorted (set ([k .split (SPLIT_ID )[0 ] for k in figure_dict .keys () if SPLIT_ID in k ]))
461+ for correction_name in [None ] + list (correction_names ):
462+ if not correction_name :
463+ parameter_names = iter (sorted (k for k in figure_dict .keys () if SPLIT_ID not in k ))
464+ correction_tab_name = "All Corrections"
465+ else :
466+ parameter_names = iter (sorted (k for k in figure_dict .keys () if correction_name in k ))
467+ correction_tab_name = correction_name
468+
469+ current_tab = TabWidget (title = correction_tab_name )
470+ window .add_tab (current_tab )
471+
472+ for name_x in parameter_names :
473+ # extract the filename (and column-name in case of multi-correction-file)
474+ tab_prename = name_x .split (SPLIT_ID )[- 1 ]
475+
476+ if rdt_pattern .match (tab_prename ):
477+ # Handle RDTs: Get the rdt column and if it's amplitude or real,
478+ # we look for the respective complement column (phase, imag).
479+ # Both, column and complement column are then added to the tab,
480+ # which is named after the rdt followed by either AP (amp/phase) or RI (real/imag)).
481+ rdt , column = tab_prename .split ("_" )[:2 ]
482+ try :
483+ complement_column : ColumnsAndLabels = rdt_complement [column ]
484+ except KeyError :
485+ # skip phase and imag as they will become name_y for amp and real.
486+ continue
487+
488+ if not correction_name :
489+ name_y = "_" .join ([rdt , complement_column .text_label , complement_column .expected_column ])
490+ else :
491+ name_y = "_" .join (name_x .split ("_" )[:- 1 ] + [complement_column .text_label ,])
492+
493+ tab_name = f"{ rdt } { column [0 ].upper ()} /{ complement_column .text_label [0 ].upper ()} "
494+
495+ else :
496+ # Handle non-RDT columns: As they are sorted alphabetically, the current column
497+ # is x and the following column is y. They are added to the tab, which
498+ # is named by the optics parameter without plane.
499+ tab_name = " " .join (tab_prename .split ("_" )[:- 1 if correction_name else - 2 ]) # remove plane (and column-name)
500+ name_y = next (parameter_names )
501+
502+ new_tab = PlotWidget (
503+ figure_dict [name_x ],
504+ figure_dict [name_y ],
505+ title = tab_name ,
506+ )
507+ current_tab .add_tab (new_tab )
508+ window .show ()
404509
405510
406511def _get_corrected_measurement_names (correction_dirs : Iterable [Path ]) -> Set [str ]:
@@ -415,3 +520,7 @@ def _get_corrected_measurement_names(correction_dirs: Iterable[Path]) -> Set[str
415520 tfs_files &= new_files
416521 # tfs_files -= {Path(MODEL_MATCHED_FILENAME).stem, Path(MODEL_NOMINAL_FILENAME).stem} # no need, filtered later anyway
417522 return tfs_files
523+
524+
525+ if __name__ == "__main__" :
526+ plot_checked_corrections ()
0 commit comments