Skip to content

Commit 08d984c

Browse files
authored
Move pandas and matplotlib to optional extras (#2244)
* Move pandas and matplotlib to optional extras Update setup.py to remove them from core dependencies. Add lazy imports with clear error messages in logger, monitor, and plotter. Move read_csv and read_json helpers to test files. Update documentation and changelog. * Update README (grammar check) * Fix review comments * Add new expression to exclude from code coverage report * Test error cases for monitor wrapper * Add test to check that core SB3 doesn't depends on pandas/matplotlib
1 parent 499b424 commit 08d984c

11 files changed

Lines changed: 176 additions & 56 deletions

File tree

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ These algorithms will make it easier for the research community and industry to
1919

2020
## Main Features
2121

22-
**The performance of each algorithm was tested** (see *Results* section in their respective page),
23-
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
22+
**The performance of each algorithm was tested** (see *Results* section in their respective page).
23+
You can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
2424

2525
We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.
2626

@@ -43,7 +43,7 @@ We also provide detailed logs and reports on the [OpenRL Benchmark](https://wand
4343

4444
### Planned features
4545

46-
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
46+
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3; it is now *stable*.
4747
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).
4848

4949
While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
@@ -116,7 +116,7 @@ Install the Stable Baselines3 package:
116116
pip install 'stable-baselines3[extra]'
117117
```
118118

119-
This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
119+
This includes optional dependencies like Tensorboard, OpenCV, `ale-py` to train on atari games, as well as `pandas` and `matplotlib` for plotting and analyzing results. If you do not need those, you can use:
120120
```sh
121121
pip install stable-baselines3
122122
```
@@ -163,7 +163,7 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
163163
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.
164164

165165

166-
## Try it online with Colab Notebooks !
166+
## Try it online with Colab Notebooks!
167167

168168
All the following examples can be executed online using Google Colab notebooks:
169169

@@ -201,7 +201,7 @@ All the following examples can be executed online using Google Colab notebooks:
201201

202202
Actions `gymnasium.spaces`:
203203
* `Box`: A N-dimensional box that contains every point in the action space.
204-
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
204+
* `Discrete`: A list of possible actions, where only one action can be used per timestep.
205205
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
206206
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
207207

@@ -272,12 +272,12 @@ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv
272272

273273
## How To Contribute
274274

275-
To any interested in making the baselines better, there is still some documentation that needs to be done.
275+
For anyone interested in making the baselines better, there is still some documentation that needs to be done.
276276
If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.
277277

278278
## Acknowledgments
279279

280-
The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).
280+
The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).
281281

282282
The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).
283283

docs/guide/plotting.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress.
66
The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots.
77

8+
:::{note}
9+
Plotting requires `pandas` and `matplotlib`. Install them with:
10+
```bash
11+
pip install pandas matplotlib
12+
```
13+
Or install the extra dependencies:
14+
```bash
15+
pip install 'stable-baselines3[extra]'
16+
```
17+
:::
18+
819
:::{note}
920
We recommend using the
1021
[RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html)

docs/misc/changelog.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
# Changelog
44

5-
## Release 2.9.0a1 (WIP)
5+
## Release 2.9.0a2 (WIP)
66

77
### Breaking Changes:
8-
- Relax Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`)
8+
- Relaxed Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`)
9+
- `pandas` and `matplotlib` are no longer core dependencies; they are now optional and only required for loading results and plotting (moved to `stable-baselines3[extra]`).
10+
- Moved `read_json` and `read_csv` helper functions to test files
911

1012
### New Features:
1113

@@ -21,7 +23,7 @@
2123

2224
### Others:
2325

24-
- Optimize tests (faster to run)
26+
- Optimized tests (faster to run)
2527

2628
### Documentation:
2729

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,5 @@ exclude_lines = [
7272
"pragma: no cover",
7373
"raise NotImplementedError()",
7474
"if typing.TYPE_CHECKING:",
75+
"if TYPE_CHECKING:",
7576
]

setup.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@
8484
"torch>=2.3,<3.0",
8585
# For saving models
8686
"cloudpickle",
87-
# For reading logs
88-
"pandas",
89-
# Plotting learning curves
90-
"matplotlib",
9187
],
9288
extras_require={
9389
"tests": [
@@ -128,6 +124,9 @@
128124
# For atari games,
129125
"ale-py>=0.9.0",
130126
"pillow",
127+
# For plotting and loading results
128+
"pandas",
129+
"matplotlib",
131130
],
132131
},
133132
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",

stable_baselines3/common/logger.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from collections import defaultdict
88
from collections.abc import Mapping, Sequence
99
from io import TextIOBase
10-
from typing import Any, TextIO
10+
from typing import TYPE_CHECKING, Any, TextIO
1111

12-
import matplotlib.figure
1312
import numpy as np
14-
import pandas
1513
import torch as th
1614

15+
if TYPE_CHECKING:
16+
import matplotlib.figure
17+
1718
try:
1819
from torch.utils.tensorboard import SummaryWriter
1920
from torch.utils.tensorboard.summary import hparams
@@ -53,7 +54,7 @@ class Figure:
5354
:param close: if true, close the figure after logging it
5455
"""
5556

56-
def __init__(self, figure: matplotlib.figure.Figure, close: bool):
57+
def __init__(self, figure: "matplotlib.figure.Figure", close: bool):
5758
self.figure = figure
5859
self.close = close
5960

@@ -665,32 +666,3 @@ def configure(folder: str | None = None, format_strings: list[str] | None = None
665666
if len(format_strings) > 0 and format_strings != ["stdout"]:
666667
logger.log(f"Logging to {folder}")
667668
return logger
668-
669-
670-
# ================================================================
671-
# Readers
672-
# ================================================================
673-
674-
675-
def read_json(filename: str) -> pandas.DataFrame:
676-
"""
677-
read a json file using pandas
678-
679-
:param filename: the file path to read
680-
:return: the data in the json
681-
"""
682-
data = []
683-
with open(filename) as file_handler:
684-
for line in file_handler:
685-
data.append(json.loads(line))
686-
return pandas.DataFrame(data)
687-
688-
689-
def read_csv(filename: str) -> pandas.DataFrame:
690-
"""
691-
read a csv file using pandas
692-
693-
:param filename: the file path to read
694-
:return: the data in the csv
695-
"""
696-
return pandas.read_csv(filename, index_col=None, comment="#")

stable_baselines3/common/monitor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import os
66
import time
77
from glob import glob
8-
from typing import Any, SupportsFloat
8+
from typing import TYPE_CHECKING, Any, SupportsFloat
99

1010
import gymnasium as gym
11-
import pandas
1211
from gymnasium.core import ActType, ObsType
1312

13+
if TYPE_CHECKING:
14+
import pandas
15+
1416

1517
class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
1618
"""
@@ -227,13 +229,21 @@ def get_monitor_files(path: str) -> list[str]:
227229
return glob(os.path.join(path, "*" + Monitor.EXT))
228230

229231

230-
def load_results(path: str) -> pandas.DataFrame:
232+
def load_results(path: str) -> "pandas.DataFrame":
231233
"""
232234
Load all Monitor logs from a given directory path matching ``*monitor.csv``
233235
234236
:param path: the directory path containing the log file(s)
235237
:return: the logged data
236238
"""
239+
try:
240+
import pandas
241+
except ImportError as e:
242+
raise ImportError(
243+
"pandas is required for loading results. "
244+
"Install it with `pip install pandas` or install the extra dependencies with "
245+
"`pip install 'stable-baselines3[extra]'`."
246+
) from e
237247
monitor_files = get_monitor_files(path)
238248
if len(monitor_files) == 0:
239249
raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")

stable_baselines3/common/results_plotter.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
11
from collections.abc import Callable
22

33
import numpy as np
4-
import pandas as pd
4+
5+
try:
6+
import pandas as pd
7+
except ImportError as e:
8+
raise ImportError(
9+
"pandas is required for plotting functionality. "
10+
"Install it with `pip install pandas` or install the extra dependencies with "
11+
"`pip install 'stable-baselines3[extra]'`."
12+
) from e
513

614
# import matplotlib
715
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
8-
from matplotlib import pyplot as plt
16+
try:
17+
from matplotlib import pyplot as plt
18+
except ImportError as e:
19+
raise ImportError(
20+
"matplotlib is required for plotting functionality. "
21+
"Install it with `pip install matplotlib` or install the extra dependencies with "
22+
"`pip install 'stable-baselines3[extra]'`."
23+
) from e
924

1025
from stable_baselines3.common.monitor import load_results
1126

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.9.0a1
1+
2.9.0a2

tests/test_logger.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib.util
2+
import json
23
import os
34
import sys
45
import time
@@ -8,6 +9,7 @@
89

910
import gymnasium as gym
1011
import numpy as np
12+
import pandas
1113
import pytest
1214
import torch as th
1315
from gymnasium import spaces
@@ -30,11 +32,34 @@
3032
Video,
3133
configure,
3234
make_output_format,
33-
read_csv,
34-
read_json,
3535
)
3636
from stable_baselines3.common.monitor import Monitor
3737

38+
39+
def read_csv(filename: str):
40+
"""
41+
read a csv file using pandas
42+
43+
:param filename: the file path to read
44+
:return: the data in the csv
45+
"""
46+
return pandas.read_csv(filename, index_col=None, comment="#")
47+
48+
49+
def read_json(filename: str):
50+
"""
51+
read a json file using pandas
52+
53+
:param filename: the file path to read
54+
:return: the data in the json
55+
"""
56+
data = []
57+
with open(filename) as file_handler:
58+
for line in file_handler:
59+
data.append(json.loads(line))
60+
return pandas.DataFrame(data)
61+
62+
3863
KEY_VALUES = {
3964
"test": 1,
4065
"b": -3.14,
@@ -634,3 +659,58 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path):
634659
assert logger.name_to_value["rollout/success_rate"] == 0.5
635660
model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
636661
assert logger.name_to_value["rollout/success_rate"] == 0.8
662+
663+
664+
def test_pandas_import_error(tmp_path):
665+
"""Test that a clear ImportError is raised when pandas is not available."""
666+
# Mock the import to simulate pandas not being installed
667+
with mock.patch.dict("sys.modules", {"pandas": None}):
668+
# First, remove the modules from cache if they exist
669+
if "stable_baselines3.common.results_plotter" in sys.modules:
670+
del sys.modules["stable_baselines3.common.results_plotter"]
671+
if "stable_baselines3.common.monitor" in sys.modules:
672+
del sys.modules["stable_baselines3.common.monitor"]
673+
674+
# Test results_plotter raises ImportError at import time
675+
with pytest.raises(ImportError, match="pandas is required for plotting"):
676+
import stable_baselines3.common.results_plotter # noqa: F401
677+
678+
# Test load_results raises ImportError at call time
679+
# monitor module can still be imported (pandas import is lazy)
680+
from stable_baselines3.common.monitor import load_results
681+
682+
with pytest.raises(ImportError, match="pandas is required for loading results"):
683+
load_results(str(tmp_path))
684+
685+
686+
def test_matplotlib_import_error():
687+
"""Test that a clear ImportError is raised when matplotlib is not available."""
688+
# Mock the import to simulate matplotlib not being installed
689+
with mock.patch.dict("sys.modules", {"matplotlib": None, "matplotlib.pyplot": None}):
690+
# First, remove the module from cache if it exists
691+
if "stable_baselines3.common.results_plotter" in sys.modules:
692+
del sys.modules["stable_baselines3.common.results_plotter"]
693+
694+
# Test results_plotter raises ImportError at import time
695+
with pytest.raises(ImportError, match="matplotlib is required for plotting"):
696+
import stable_baselines3.common.results_plotter # noqa: F401
697+
698+
699+
def test_sb3_import_without_optional_deps():
700+
"""Test that SB3 core can be imported without matplotlib and pandas."""
701+
# Mock the imports to simulate optional dependencies not being installed
702+
with mock.patch.dict("sys.modules", {"pandas": None, "matplotlib": None, "matplotlib.pyplot": None}):
703+
# First, remove the modules from cache if they exist
704+
modules_to_remove = [key for key in sys.modules.keys() if key.startswith("stable_baselines3")]
705+
for module in modules_to_remove:
706+
del sys.modules[module]
707+
708+
# Core SB3 should still be importable
709+
from stable_baselines3 import A2C, DQN, PPO # noqa: F401
710+
711+
# Monitor should be importable (pandas import is lazy in load_results)
712+
from stable_baselines3.common.monitor import Monitor # noqa: F401
713+
714+
# But plotting module should fail
715+
with pytest.raises(ImportError):
716+
import stable_baselines3.common.results_plotter # noqa: F401

0 commit comments

Comments
 (0)