diff --git a/README.md b/README.md index d00e444..b3f04fd 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ If tensorflow or onnx is to be used as inference engine, the yaml file should be model: name: optional_default_name_of_the_network_for_the_storage_path label: optional default label of the network for the plots + color: optional color of the network for the plots version: optional_version_number # (e.g. "1.0.0") inference_engine: name_of_inference_engine # (either "tf" or "onnx") file: path_to_your_pb_or_onnx_model_file @@ -153,6 +154,7 @@ Hence, the yaml file should be in the following format: model: name: optional_default_name_of_the_network_for_the_storage_path label: optional default label of the network for the plots + color: optional color of the network for the plots version: version_number # (e.g. "1.0.0") inference_engine: tfaot saved_model: path_to_your_saved_model_directory @@ -264,7 +266,7 @@ This task merges the .csv output files with the required multiple batch sizes fr - The .csv files from the several occurences of `MeasureRuntime` (one for each batch size). ## Parameters: -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - model-file: str. The absolute path of the yaml file containing the informations of the model to be tested. default: `$MLP_BASE/examples/dnn/model_tf_l10u128.yaml`. @@ -308,21 +310,23 @@ The number of inferences behind one plotted data point is given by `n-events * n - The .csv file from the `MergeRuntimes` task. ## Parameters: -- y-log: bool. Plot the y-axis values logarithmically; default: `False`. +- y-log: bool. Plot the y-axis values logarithmically. default: `False`. -- x-log: bool. Plot the x-axis values logarithmically; default: `False`. +- x-log: bool. Plot the x-axis values logarithmically. default: `False`. - y-min = float. Minimum y-axis value. 
default: empty - y-max: float. Maximum y-axis value. default: empty -- bs-normalized: bool. Normalize the measured values with the batch size before plotting; default: `True`. +- bs-normalized: bool. Normalize the measured values with the batch size before plotting. default: `True`. -- error-style: str. Style of errors / uncertainties due to averaging; choices: `bars`,`band`; default: `band`. +- error-style: str. Style of errors / uncertainties due to averaging. choices: `bars`,`band`. default: `band`. - top-right-label: str. When set, stick this string as label over the top right corner of the plot. default: empty. -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- default-colors: str. Default color cycle to use for plots. choices: `mpl`, `cms_6`, `atlas_10`. default: `cms_6`. + +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - model-file: str. The absolute path of the yaml file containing the informations of the model to be tested. default: `$MLP_BASE/examples/dnn/model_tf_l10u128.yaml`. @@ -406,21 +410,23 @@ The number of inferences behind one plotted data point is given by `n-events * n - model-labels: str. The comma-separated list of model labels. When set, use these strings for the model labels in the plots from the plotting tasks. When empty, the `label` fields in the models yaml data are used when existing, else the `name` fields in the models yaml data are used when existing, and model-names otherwise. default: empty. -- y-log: bool. Plot the y-axis values logarithmically; default: `False`. +- y-log: bool. Plot the y-axis values logarithmically. default: `False`. -- x-log: bool. Plot the x-axis values logarithmically; default: `False`. +- x-log: bool. Plot the x-axis values logarithmically. default: `False`. - y-min = float. Minimum y-axis value. default: empty - y-max: float. Maximum y-axis value. default: empty -- bs-normalized: bool. 
Normalize the measured values with the batch size before plotting; default: `True`. +- bs-normalized: bool. Normalize the measured values with the batch size before plotting. default: `True`. -- error-style: str. Style of errors / uncertainties due to averaging; choices: `bars`,`band`; default: `band`. +- error-style: str. Style of errors / uncertainties due to averaging. choices: `bars`,`band`. default: `band`. - top-right-label: str. When set, stick this string as label over the top right corner of the plot. default: empty. -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- default-colors: str. Default color cycle to use for plots. choices: `mpl`, `cms_6`, `atlas_10`. default: `cms_6`. + +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - n-events: int. The number of events to read from each input file for averaging measurements. default: `1` diff --git a/mlprof/plotting/plotter.py b/mlprof/plotting/plotter.py index 0e59858..571a0f3 100644 --- a/mlprof/plotting/plotter.py +++ b/mlprof/plotting/plotter.py @@ -1,11 +1,30 @@ # coding: utf-8 colors = { - "mpl_standard": [ + "mpl": [ "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", ], - "custom_edgecolor": ["#CC4F1B", "#1B2ACC", "#3F7F4C"], - "custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"], + # Atlas and cms standards correspond to results in: + # Color wheel from https://arxiv.org/pdf/2107.02270 Table 1, 10 color palette + # hex codes in https://github.com/mpetroff/accessible-color-cycles/blob/0a17e754d9f83161baffd803dcea8bee7d95a549/readme.md#final-results # noqa + # as implemented in mplhep + "cms_6": [ + "#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd", + ], + "atlas_10": [ + "#3f90da", + "#ffa90e", + "#bd1f01", + "#94a4a2", + "#832db6", + "#a96b59", + "#e76300", + "#b9ac70", + "#717581", + "#92dadd", + ], + # "custom_edgecolor": ["#CC4F1B", 
"#1B2ACC", "#3F7F4C"], + # "custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"], } @@ -86,7 +105,7 @@ def fill_plot(x, y, y_down, y_up, error_style, color): if error_style == "band": p1 = plt.plot(x, y, "-", color=color) plt.fill_between(x, y - y_down, y + y_up, alpha=0.5, facecolor=color) - p2 = plt.fill(np.NaN, np.NaN, alpha=0.5, color=color) + p2 = plt.fill(np.nan, np.nan, alpha=0.5, color=color) legend = (p1[0], p2[0]) else: # bars p = plt.errorbar(x, y, yerr=(y_down, y_up), capsize=12, marker=".", linestyle="") @@ -100,6 +119,7 @@ def plot_batch_size_several_measurements( input_paths, output_path, measurements, + color_list, plot_params, ): """ @@ -114,6 +134,7 @@ def plot_batch_size_several_measurements( """ import matplotlib.pyplot as plt import mplhep # type: ignore[import-untyped] + from cycler import cycler if isinstance(measurements[0], str): measurements_labels_strs = list(measurements) @@ -136,14 +157,19 @@ def plot_batch_size_several_measurements( # create plot with curves using a single color for each value-error pair legend_entries = [] - for data in plot_data: + if plot_params.get("default_colors"): + # set the color cycle to the custom color cycle + ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("default_colors")])) + + for i, data in enumerate(plot_data): + color_used = color_list[i] if color_list[i] else ax._get_lines.get_next_color() entry = fill_plot( x=batch_sizes, y=data["y"], y_down=data["y_down"], y_up=data["y_up"], error_style=plot_params["error_style"], - color=ax._get_lines.get_next_color(), + color=color_used, ) legend_entries.append(entry) diff --git a/mlprof/tasks/parameters.py b/mlprof/tasks/parameters.py index 01a0a3c..7d41aee 100644 --- a/mlprof/tasks/parameters.py +++ b/mlprof/tasks/parameters.py @@ -211,6 +211,7 @@ class MultiModelParameters(BaseTask): description="when set, use these labels in plots; when empty, the `label` fields in the models " "yaml data are used when existing, else the `name` fields in 
the models yaml data are used when " "existing and model-names otherwise; default: empty", + brace_expand=True, ) def __init__(self, *args, **kwargs): @@ -312,6 +313,13 @@ class CustomPlotParameters(BaseTask): significant=False, description="stick a label over the top right corner of the plot", ) + default_colors = luigi.ChoiceParameter( + choices=["mpl", "cms_6", "atlas_10"], + default="cms_6", + significant=False, + description="default color cycle to use; choices: 'mpl', 'cms_6', 'atlas_10'" + "; default: 'cms_6'", + ) @property def custom_plot_params(self): @@ -323,4 +331,5 @@ def custom_plot_params(self): "bs_normalized": self.bs_normalized, "error_style": self.error_style, "top_right_label": None if self.top_right_label == law.NO_STR else self.top_right_label, + "default_colors": self.default_colors, } diff --git a/mlprof/tasks/runtime.py b/mlprof/tasks/runtime.py index 6bf0395..a1463df 100644 --- a/mlprof/tasks/runtime.py +++ b/mlprof/tasks/runtime.py @@ -150,6 +150,7 @@ def run(self): [self.input().path], output.path, [self.model.full_model_label], + [self.model.color], self.custom_plot_params, ) print("plot saved") @@ -280,5 +281,6 @@ def run(self): input_paths=input_paths, output_path=output.path, measurements=self.params_product_params_to_write, + color_list=[model.color for model in self.models], plot_params=self.custom_plot_params, ) diff --git a/mlprof/util.py b/mlprof/util.py index 5d10f42..7689b52 100644 --- a/mlprof/util.py +++ b/mlprof/util.py @@ -26,6 +26,7 @@ def __init__(self, model_file: str, name: str, label: str, **kwargs) -> None: self.model_file = expand_path(model_file, abs=True) self.name = name self.label = label + self._color = None # cached data self._all_data = None @@ -67,3 +68,9 @@ def full_model_label(self): # fallback to the full model name return self.full_name + + @property + def color(self): + if self._color is None: + self._color = self.data.get("color") + return self._color diff --git a/sandboxes/plotting.txt 
b/sandboxes/plotting.txt index 7a2a915..62976eb 100644 --- a/sandboxes/plotting.txt +++ b/sandboxes/plotting.txt @@ -1,6 +1,6 @@ -# version 1 +# version 2 -numpy -pandas -matplotlib -mplhep +numpy~=2.0.1 +pandas~=2.2.2 +matplotlib~=3.9.0 +mplhep~=0.3.50