Skip to content

Commit 4de0868

Browse files
authierjdennisbadermadtoinou
authored
Color choice for multivariate timeseries (#2680)
* colors arg added, accept lists of length num_comp * color or c in kwargs * check added * changelog entry added * refactoring after discussion with Dennis * CHANGELOG changed Co-authored-by: Dennis Bader <[email protected]> * custom_colors instead of multi_color Co-authored-by: Dennis Bader <[email protected]> * logic check upgraded Co-authored-by: Dennis Bader <[email protected]> * doc corrected * spelling Co-authored-by: Dennis Bader <[email protected]> * change of logic Co-authored-by: Dennis Bader <[email protected]> * change of logic follow-up Co-authored-by: Dennis Bader <[email protected]> * alpha moved to arguments, error if color and c are used at the same time * update plotting --------- Co-authored-by: Dennis Bader <[email protected]> Co-authored-by: madtoinou <[email protected]>
1 parent daba9a9 commit 4de0868

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1313

1414
- Added ONNX support for torch-based models with method `TorchForecastingModel.to_onnx()`. Check out [this example](https://unit8co.github.io/darts/userguide/gpu_and_tpu_usage.html#exporting-model-to-onnx-format-for-inference) from the user guide on how to export and load a model for inference. [#2620](https://github.com/unit8co/darts/pull/2620) by [Antoine Madrona](https://github.com/madtoinou)
1515
- Made method `ForecastingModel.untrained_model()` public. Use this method to get a new (untrained) model instance created with the same parameters. [#2684](https://github.com/unit8co/darts/pull/2684) by [Timon Erhart](https://github.com/turbotimon)
16-
- Made it possbile to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj)
16+
- `TimeSeries.plot()` now supports setting the color for each component in the series. Simply pass a list / sequence of colors with length matching the number of components as parameters "c" or "colors". [#2680](https://github.com/unit8co/darts/pull/2680) by [Jules Authier](https://github.com/authierj)
17+
- Made it possible to run the quickstart notebook `00-quickstart.ipynb` locally. [#2691](https://github.com/unit8co/darts/pull/2691) by [Jules Authier](https://github.com/authierj)
1718

1819
**Fixed**
1920

darts/timeseries.py

+60-26
Original file line numberDiff line numberDiff line change
@@ -4105,6 +4105,9 @@ def plot(
41054105
label: Optional[Union[str, Sequence[str]]] = "",
41064106
max_nr_components: int = 10,
41074107
ax: Optional[matplotlib.axes.Axes] = None,
4108+
alpha: Optional[float] = None,
4109+
color: Optional[Union[str, tuple, Sequence[str, tuple]]] = None,
4110+
c: Optional[Union[str, tuple, Sequence[str, tuple]]] = None,
41084111
*args,
41094112
**kwargs,
41104113
) -> matplotlib.axes.Axes:
@@ -4144,8 +4147,16 @@ def plot(
41444147
Optionally, an axis to plot on. If `None`, and `new_plot=False`, will use the current axis. If
41454148
`new_plot=True`, will create a new axis.
41464149
alpha
4147-
Optionally, set the line alpha for deterministic series, or the confidence interval alpha for
4150+
Optionally, set the line alpha for deterministic series, or the confidence interval alpha for
41484151
probabilistic series.
4152+
color
4153+
Can either be a single color or list of colors. Any matplotlib color is accepted (string, hex string,
4154+
RGB/RGBA tuple). If a single color and the series only has a single component, it is used as the color
4155+
for that component. If a single color and the series has multiple components, it is used as the color
4156+
for each component. If a list of colors with length equal to the number of components in the series, the
4157+
colors will be mapped to the components in order.
4158+
c
4159+
An alias for `color`.
41494160
args
41504161
some positional arguments for the `plot()` method
41514162
kwargs
@@ -4172,40 +4183,63 @@ def plot(
41724183
logger,
41734184
)
41744185

4175-
if new_plot:
4176-
fig, ax = plt.subplots()
4186+
if max_nr_components == -1:
4187+
n_components_to_plot = self.n_components
41774188
else:
4178-
if ax is None:
4179-
ax = plt.gca()
4189+
n_components_to_plot = min(self.n_components, max_nr_components)
41804190

4181-
if not any(lw in kwargs for lw in ["lw", "linewidth"]):
4182-
kwargs["lw"] = 2
4183-
4184-
n_components_to_plot = max_nr_components
4185-
if n_components_to_plot == -1:
4186-
n_components_to_plot = self.n_components
4187-
elif self.n_components > max_nr_components:
4191+
if self.n_components > n_components_to_plot:
41884192
logger.warning(
4189-
f"Number of components is larger than {max_nr_components} ({self.n_components}). "
4190-
f"Plotting only the first {max_nr_components} components."
4191-
f"You can overwrite this in the using the `plot_all_components` argument in plot()"
4192-
f"Beware that plotting a large number of components may cause performance issues."
4193+
f"Number of series components ({self.n_components}) is larger than the maximum number of "
4194+
f"components to plot ({max_nr_components}). Plotting only the first `{max_nr_components}` "
4195+
f"components. You can adjust the number of components to plot using `max_nr_components`."
41934196
)
41944197

41954198
if not isinstance(label, str) and isinstance(label, Sequence):
4196-
raise_if_not(
4197-
len(label) == self.n_components
4198-
or (
4199-
self.n_components > n_components_to_plot
4200-
and len(label) >= n_components_to_plot
4199+
if len(label) != self.n_components and len(label) != n_components_to_plot:
4200+
raise_log(
4201+
ValueError(
4202+
f"The `label` sequence must have the same length as the number of series components "
4203+
f"({self.n_components}) or as the number of plotted components ({n_components_to_plot}). "
4204+
f"Received length `{len(label)}`."
4205+
),
4206+
logger,
4207+
)
4208+
custom_labels = True
4209+
else:
4210+
custom_labels = False
4211+
4212+
if color and c:
4213+
raise_log(
4214+
ValueError(
4215+
"`color` and `c` must not be used simultaneously, use one or the other."
42014216
),
4202-
"The label argument should have the same length as the number of plotted components "
4203-
f"({min(self.n_components, n_components_to_plot)}), only {len(label)} labels were provided",
42044217
logger,
42054218
)
4206-
custom_labels = True
4219+
color = color or c
4220+
if not isinstance(color, (str, tuple)) and isinstance(color, Sequence):
4221+
if len(color) != self.n_components and len(color) != n_components_to_plot:
4222+
raise_log(
4223+
ValueError(
4224+
f"The `color` sequence must have the same length as the number of series components "
4225+
f"({self.n_components}) or as the number of plotted components ({n_components_to_plot}). "
4226+
f"Received length `{len(label)}`."
4227+
),
4228+
logger,
4229+
)
4230+
custom_colors = True
42074231
else:
4208-
custom_labels = False
4232+
custom_colors = False
4233+
4234+
kwargs["alpha"] = alpha
4235+
if not any(lw in kwargs for lw in ["lw", "linewidth"]):
4236+
kwargs["lw"] = 2
4237+
4238+
if new_plot:
4239+
fig, ax = plt.subplots()
4240+
else:
4241+
if ax is None:
4242+
ax = plt.gca()
42094243

42104244
for i, c in enumerate(self._xa.component[:n_components_to_plot]):
42114245
comp_name = str(c.values)
@@ -4219,7 +4253,6 @@ def plot(
42194253
else:
42204254
central_series = comp.mean(dim=DIMS[2])
42214255

4222-
alpha = kwargs["alpha"] if "alpha" in kwargs else None
42234256
if custom_labels:
42244257
label_to_use = label[i]
42254258
else:
@@ -4230,6 +4263,7 @@ def plot(
42304263
else:
42314264
label_to_use = f"{label}_{comp_name}"
42324265
kwargs["label"] = label_to_use
4266+
kwargs["c"] = color[i] if custom_colors else color
42334267

42344268
kwargs_central = deepcopy(kwargs)
42354269
if not self.is_deterministic:

0 commit comments

Comments
 (0)