Skip to content

Commit a8ec9d2

Browse files
authored
Plotting Windows (#416)
* New Classes for GUI-Windows to organize plots * Plotting from plot_correction_test now in window * Plotting from plot_tfs now in window * bumpversion, changelog * save-opt only if output-dir is set
1 parent ad7861e commit a8ec9d2

File tree

9 files changed

+539
-49
lines changed

9 files changed

+539
-49
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# OMC3 Changelog
22

3+
#### 2023-06-05 - v0.11.0 - _jdilly_
4+
5+
- Added:
6+
- `omc3.plotting.utils.windows`: Qt-based windows and widgets for matplotlib-figure organization.
7+
- Using the new windows in `omc3.plotting.plot_checked_corrections` and `omc3.plotting.plot_tfs`
8+
39
#### 2023-05-15 - v0.10.0 - _jdilly_
10+
411
- Added:
512
- `omc3.check_corrections`: A new feature to check the validity of corrections.
613
- `omc3.plotting.plot_checked_corrections`: Function to plot the checked corrections.

doc/modules/plotting.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,6 @@ Common Plot Functionality
5454
:members:
5555
:noindex:
5656

57+
.. automodule:: omc3.plotting.utils.windows
58+
:members:
59+
:noindex:

omc3/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
__title__ = "omc3"
1212
__description__ = "An accelerator physics tools package for the OMC team at CERN."
1313
__url__ = "https://github.com/pylhc/omc3"
14-
__version__ = "0.10.0"
14+
__version__ = "0.11.0"
1515
__author__ = "pylhc"
1616
__author_email__ = "[email protected]"
1717
__license__ = "MIT"

omc3/plotting/optics_measurements/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Iterable, Optional
1111

1212
from matplotlib import pyplot as plt
13+
from matplotlib.figure import Figure
1314
from numpy.typing import ArrayLike
1415

1516

@@ -33,8 +34,9 @@ class DataSet:
3334
class FigureContainer:
3435
"""Container for attaching additional information to one figure."""
3536
def __init__(self, id_: str, path: Path, axes_ids: Iterable[str]) -> None:
36-
self.fig, axs = plt.subplots(nrows=len(axes_ids))
37-
self.fig.canvas.manager.set_window_title(id_)
37+
self.fig = Figure()
38+
axs = self.fig.subplots(nrows=len(axes_ids))
39+
self.id = id_
3840

3941
if len(axes_ids) == 1:
4042
axs = [axs]
@@ -55,6 +57,9 @@ class FigureCollector:
5557
def __init__(self) -> None:
5658
self.fig_dict = OrderedDict() # dictionary of matplotlib figures, for output
5759
self.figs = OrderedDict() # dictionary of FigureContainers, used internally
60+
61+
def __len__(self) -> int:
62+
return len(self.figs)
5863

5964
def add_data_for_id(self, figure_id: str, label: str, data: DataSet,
6065
x_label: str, y_label: str,
@@ -93,4 +98,3 @@ def safe_format(label: str, insert: str) -> Optional[str]:
9398
return label
9499
except AttributeError: # label is None
95100
return None
96-

omc3/plotting/plot_checked_corrections.py

Lines changed: 139 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -152,29 +152,43 @@
152152
153153
154154
"""
155+
import re
155156
from pathlib import Path
156-
from typing import Dict, Iterable, Set
157+
from typing import Dict, Iterable, Set, Union
157158

158159
from matplotlib import pyplot as plt
160+
from matplotlib.figure import Figure
159161

160162
import tfs
161163
from generic_parser import DotDict, EntryPointParameters, entrypoint
162-
from omc3.correction.constants import (CORRECTED_LABEL, UNCORRECTED_LABEL, CORRECTION_LABEL, EXPECTED_LABEL,
163-
COUPLING_NAME_TO_MODEL_COLUMN_SUFFIX)
164-
from omc3.definitions.optics import FILE_COLUMN_MAPPING, ColumnsAndLabels, RDT_COLUMN_MAPPING
164+
from omc3.correction.constants import (
165+
CORRECTED_LABEL, UNCORRECTED_LABEL, CORRECTION_LABEL, EXPECTED_LABEL,
166+
COUPLING_NAME_TO_MODEL_COLUMN_SUFFIX
167+
)
168+
from omc3.definitions.optics import (
169+
FILE_COLUMN_MAPPING, ColumnsAndLabels, RDT_COLUMN_MAPPING, RDT_PHASE_COLUMN,
170+
RDT_IMAG_COLUMN, RDT_REAL_COLUMN, RDT_AMPLITUDE_COLUMN
171+
)
165172
from omc3.optics_measurements.constants import EXT
166-
from omc3.plotting.plot_optics_measurements import (_get_x_axis_column_and_label, _get_ip_positions,
167-
get_optics_style_params, get_plottfs_style_params)
173+
from omc3.plotting.plot_optics_measurements import (
174+
_get_x_axis_column_and_label, _get_ip_positions,
175+
get_optics_style_params, get_plottfs_style_params
176+
)
168177
from omc3.plotting.plot_tfs import plot as plot_tfs, get_full_output_path
169-
from omc3.plotting.utils import (annotations as pannot)
178+
from omc3.plotting.utils.windows import (
179+
PlotWidget, TabWidget, VerticalTabWindow,
180+
log_no_qtpy_many_windows, create_pyplot_window_from_fig
181+
)
182+
from omc3.plotting.utils import annotations as pannot
170183
from omc3.utils import logging_tools
171-
from omc3.utils.iotools import PathOrStr
184+
from omc3.utils.iotools import PathOrStr, save_config
172185

173186
LOG = logging_tools.get_logger(__name__)
174187

175188
SPLIT_ID = "#_#" # will appear in the figure ID, but should be fine to read
176189
PREFIX = "plot_corrections"
177190

191+
178192
def get_plotting_params() -> EntryPointParameters:
179193
params = EntryPointParameters()
180194
params.add_parameter(name="input_dir",
@@ -220,6 +234,12 @@ def get_plotting_style_parameters():
220234
def plot_checked_corrections(opt: DotDict):
221235
""" Entrypoint for the plotting function. """
222236
LOG.info("Plotting checked corrections.")
237+
238+
opt.input_dir = Path(opt.input_dir)
239+
if opt.output_dir:
240+
opt.output_dir = Path(opt.output_dir)
241+
save_config(opt.output_dir, opt, __file__)
242+
223243
# Preparations -------------------------------------------------------------
224244
correction_dirs: Dict[str, Path] = {}
225245
if len(opt.corrections) == 1 and not opt.corrections[0]:
@@ -248,7 +268,7 @@ def plot_checked_corrections(opt: DotDict):
248268
# get F1### column map without the I, R, A, P part based on the rdt-filename:
249269
LOG.debug(f"Plotting coupling correction for {filename}")
250270
new_figs = {}
251-
for idx, y_colmap in enumerate(RDT_COLUMN_MAPPING.values()): # AMP, PHASE, REAL or IMAG as column-map
271+
for y_colmap in RDT_COLUMN_MAPPING.values(): # AMP, PHASE, REAL or IMAG as column-map
252272
y_colmap = ColumnsAndLabels(
253273
_column=y_colmap.column,
254274
_label=y_colmap.label.format(filename), # this one needs additional info
@@ -281,11 +301,27 @@ def plot_checked_corrections(opt: DotDict):
281301
fig_dict.update(new_figs)
282302

283303
# Output -------------------------------------------------------------------
284-
save_plots(opt.output_dir, figure_dict=fig_dict, input_dir=opt.input_dir if opt.individual_to_input else None)
285-
show_plots(opt.show)
304+
if opt.output_dir:
305+
save_plots(
306+
opt.output_dir,
307+
figure_dict=fig_dict,
308+
input_dir=opt.input_dir if opt.individual_to_input else None
309+
)
310+
311+
if opt.show:
312+
show_plots(fig_dict)
286313
return fig_dict
287314

288-
def _create_correction_plots_per_filename(filename, measurements, correction_dirs, x_colmap, y_colmap, ip_positions, opt):
315+
316+
def _create_correction_plots_per_filename(
317+
filename: str,
318+
measurements: Path,
319+
correction_dirs: Dict[str, Path],
320+
x_colmap: ColumnsAndLabels,
321+
y_colmap: ColumnsAndLabels,
322+
ip_positions: Union[str, Dict[str, float], Path],
323+
opt: DotDict
324+
):
289325
""" Plot measurements and all different correction scenarios into a single plot. """
290326
full_filename = f"{filename}{EXT}"
291327
file_label = filename
@@ -305,7 +341,7 @@ def _create_correction_plots_per_filename(filename, measurements, correction_dir
305341
x_labels=[x_colmap.label],
306342
vertical_lines=ip_positions + opt.lines_manual,
307343
same_axes=["files"],
308-
output_prefix=f"{PREFIX}_{file_label}_", # used in the id, which is the fig_dict key
344+
output_prefix=f"{file_label}_", # used in the id, which is the fig_dict key
309345
**opt.get_subdict([
310346
'plot_styles', 'manual_style',
311347
'change_marker', 'errorbar_alpha',
@@ -366,41 +402,110 @@ def _create_correction_plots_per_filename(filename, measurements, correction_dir
366402
return figs
367403

368404

369-
def show_plots(show: bool):
370-
""" Show plots if so desired. """
371-
# plt.show()
372-
if show:
373-
plt.show()
374-
375-
376-
def save_plots(output_dir, figure_dict, input_dir=None):
405+
def save_plots(output_dir: Path, figure_dict: Dict[str, Figure], input_dir: Path = None):
377406
""" Save the plots. """
378-
if not output_dir and not input_dir:
379-
return
380-
381407
for figname, fig in figure_dict.items():
382408
outdir = output_dir
383409
figname_parts = figname.split(SPLIT_ID)
384410
if len(figname_parts) == 1: # no SPLIT_ID
385411
# these are the combined plots. They have the column name at the end,
386412
# which we do not care for here at the moment.
387413
# In case of multiple columns per file, this could be brought back
388-
# (then we would also not need the RDT check).
389-
figname = "_".join(figname.split("_")[:-1])
414+
figname = "_".join([PREFIX] + figname.split("_")[:-1])
390415
else:
391416
# this is then the individual plots
392417
if input_dir:
393418
# files go directly into the correction-scenario folders
394419
outdir = input_dir / figname_parts[0]
395420
figname = f"{PREFIX}_{figname_parts[1]}"
396421
else:
397-
# everything goes into the output-dir (if given), but needs prefix
422+
# everything goes into the output-dir (if given),
423+
# but with correction-name as additional prefix
398424
figname = "_".join([PREFIX] + figname_parts)
399425

400426
output_path = get_full_output_path(outdir, figname)
401-
if output_path is not None:
402-
LOG.info(f"Saving Corrections Plot to '{output_path}'")
403-
fig.savefig(output_path)
427+
LOG.debug(f"Saving corrections plot to '{output_path}'")
428+
fig.savefig(output_path)
429+
430+
if input_dir:
431+
LOG.info(f"Saved all correction plots in '{output_dir}'\n"
432+
f"and into the correction-scenario in '{input_dir}'.")
433+
else:
434+
LOG.info(f"Saved all correction plots in '{output_dir}'.")
435+
436+
437+
def show_plots(figure_dict: Dict[str, Figure]):
438+
"""Displays the provided figures.
439+
If `qtpy` is installed, they are shown in a single window.
440+
The individual corrections are sorted into vertical tabs,
441+
the optics parameter into horizontal tabs.
442+
If `qtpy` is not installed, they are simply shown as individual figures.
443+
This is not recommended
444+
"""
445+
try:
446+
window = VerticalTabWindow("Correction Check")
447+
except TypeError:
448+
log_no_qtpy_many_windows()
449+
for name, fig in figure_dict.items():
450+
create_pyplot_window_from_fig(fig, title=name.replace(SPLIT_ID, " "))
451+
plt.show()
452+
return
453+
454+
rdt_pattern = re.compile(r"f\d{4}")
455+
rdt_complement = {
456+
RDT_REAL_COLUMN.text_label: RDT_IMAG_COLUMN,
457+
RDT_AMPLITUDE_COLUMN.text_label: RDT_PHASE_COLUMN,
458+
}
459+
460+
correction_names = sorted(set([k.split(SPLIT_ID)[0] for k in figure_dict.keys() if SPLIT_ID in k]))
461+
for correction_name in [None] + list(correction_names):
462+
if not correction_name:
463+
parameter_names = iter(sorted(k for k in figure_dict.keys() if SPLIT_ID not in k))
464+
correction_tab_name = "All Corrections"
465+
else:
466+
parameter_names = iter(sorted(k for k in figure_dict.keys() if correction_name in k))
467+
correction_tab_name = correction_name
468+
469+
current_tab = TabWidget(title=correction_tab_name)
470+
window.add_tab(current_tab)
471+
472+
for name_x in parameter_names:
473+
# extract the filename (and column-name in case of multi-correction-file)
474+
tab_prename = name_x.split(SPLIT_ID)[-1]
475+
476+
if rdt_pattern.match(tab_prename):
477+
# Handle RDTs: Get the rdt column and if it's amplitude or real,
478+
# we look for the respective complement column (phase, imag).
479+
# Both, column and complement column are then added to the tab,
480+
# which is named after the rdt followed by either AP (amp/phase) or RI (real/imag)).
481+
rdt, column = tab_prename.split("_")[:2]
482+
try:
483+
complement_column: ColumnsAndLabels = rdt_complement[column]
484+
except KeyError:
485+
# skip phase and imag as they will become name_y for amp and real.
486+
continue
487+
488+
if not correction_name:
489+
name_y = "_".join([rdt, complement_column.text_label, complement_column.expected_column])
490+
else:
491+
name_y = "_".join(name_x.split("_")[:-1] + [complement_column.text_label,])
492+
493+
tab_name = f"{rdt} {column[0].upper()}/{complement_column.text_label[0].upper()}"
494+
495+
else:
496+
# Handle non-RDT columns: As they are sorted alphabetically, the current column
497+
# is x and the following column is y. They are added to the tab, which
498+
# is named by the optics parameter without plane.
499+
tab_name = " ".join(tab_prename.split("_")[:-1 if correction_name else -2]) # remove plane (and column-name)
500+
name_y = next(parameter_names)
501+
502+
new_tab = PlotWidget(
503+
figure_dict[name_x],
504+
figure_dict[name_y],
505+
title=tab_name,
506+
)
507+
current_tab.add_tab(new_tab)
508+
window.show()
404509

405510

406511
def _get_corrected_measurement_names(correction_dirs: Iterable[Path]) -> Set[str]:
@@ -415,3 +520,7 @@ def _get_corrected_measurement_names(correction_dirs: Iterable[Path]) -> Set[str
415520
tfs_files &= new_files
416521
# tfs_files -= {Path(MODEL_MATCHED_FILENAME).stem, Path(MODEL_NOMINAL_FILENAME).stem} # no need, filtered later anyway
417522
return tfs_files
523+
524+
525+
if __name__ == "__main__":
526+
plot_checked_corrections()

omc3/plotting/plot_tfs.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,26 +105,31 @@
105105
- **y_lim** *(float, int None)*: Limits on the y axis (Tupel)
106106
"""
107107
from collections import OrderedDict
108-
from dataclasses import dataclass
109108
from pathlib import Path
110-
from typing import Any, Optional, Dict
109+
from typing import Dict
111110

112111
import matplotlib
112+
from matplotlib import pyplot as plt, rcParams
113113
from matplotlib.axes import Axes
114-
from numpy.typing import ArrayLike
115114

116115
import tfs
117-
from generic_parser import EntryPointParameters, entrypoint, DotDict
118-
from generic_parser.entry_datatypes import DictAsString, get_multi_class
119-
from matplotlib import pyplot as plt, rcParams
116+
from generic_parser import EntryPointParameters, entrypoint
117+
from generic_parser.entry_datatypes import DictAsString
120118

121119
from omc3.definitions.constants import PLANES
122120
from omc3.optics_measurements.constants import EXT
123121
from omc3.plotting.optics_measurements.constants import DEFAULTS
124122
from omc3.plotting.optics_measurements.utils import FigureCollector, DataSet, IDMap, safe_format
123+
from omc3.plotting.utils.windows import (
124+
PlotWidget, SimpleTabWindow, is_qtpy_installed, log_no_qtpy_many_windows, create_pyplot_window_from_fig
125+
)
125126
from omc3.plotting.spectrum.utils import get_unique_filenames, output_plot
126-
from omc3.plotting.utils import (annotations as pannot, lines as plines,
127-
style as pstyle, colors as pcolors)
127+
from omc3.plotting.utils import (
128+
annotations as pannot,
129+
lines as plines,
130+
style as pstyle,
131+
colors as pcolors,
132+
)
128133
from omc3.plotting.utils.lines import VERTICAL_LINES_TEXT_LOCATIONS
129134
from omc3.utils.iotools import PathOrStr, save_config, OptionalStr, OptionalFloat
130135
from omc3.utils.logging_tools import get_logger, list2str
@@ -302,7 +307,7 @@ def get_params():
302307
@entrypoint(get_params(), strict=True)
303308
def plot(opt):
304309
"""Main plotting function."""
305-
LOG.info(f"Starting plotting of tfs files: {list2str(opt.files):s}")
310+
LOG.debug(f"Starting plotting of tfs files: {list2str(opt.files):s}")
306311
if opt.output is not None:
307312
save_config(Path(opt.output), opt, __file__)
308313

@@ -531,8 +536,22 @@ def _create_plots(fig_collection, opt):
531536
output_plot(fig_container)
532537

533538
if opt.show:
539+
if len(fig_collection) > 1 and is_qtpy_installed():
540+
window = SimpleTabWindow("Tfs Plots")
541+
for fig_container in fig_collection.figs.values():
542+
tab = PlotWidget(fig_container.fig, title=fig_container.id)
543+
window.add_tab(tab)
544+
window.show()
545+
return
546+
547+
if len(fig_collection) > rcParams['figure.max_open_warning']:
548+
log_no_qtpy_many_windows()
549+
550+
for fig_container in fig_collection.figs.values():
551+
create_pyplot_window_from_fig(fig_container.fig, title=fig_container.id)
534552
plt.show()
535553

554+
536555

537556
def _plot_data(ax: Axes, data: Dict[str, DataSet], change_marker: bool, ebar_alpha: float):
538557
for idx, (label, values) in enumerate(data.items()):

0 commit comments

Comments
 (0)