|
1 | 1 | import importlib.util |
| 2 | +import json |
2 | 3 | import os |
3 | 4 | import sys |
4 | 5 | import time |
|
8 | 9 |
|
9 | 10 | import gymnasium as gym |
10 | 11 | import numpy as np |
| 12 | +import pandas |
11 | 13 | import pytest |
12 | 14 | import torch as th |
13 | 15 | from gymnasium import spaces |
|
30 | 32 | Video, |
31 | 33 | configure, |
32 | 34 | make_output_format, |
33 | | - read_csv, |
34 | | - read_json, |
35 | 35 | ) |
36 | 36 | from stable_baselines3.common.monitor import Monitor |
37 | 37 |
|
| 38 | + |
| 39 | +def read_csv(filename: str): |
| 40 | + """ |
| 41 | + read a csv file using pandas |
| 42 | +
|
| 43 | + :param filename: the file path to read |
| 44 | + :return: the data in the csv |
| 45 | + """ |
| 46 | + return pandas.read_csv(filename, index_col=None, comment="#") |
| 47 | + |
| 48 | + |
| 49 | +def read_json(filename: str): |
| 50 | + """ |
| 51 | + read a json file using pandas |
| 52 | +
|
| 53 | + :param filename: the file path to read |
| 54 | + :return: the data in the json |
| 55 | + """ |
| 56 | + data = [] |
| 57 | + with open(filename) as file_handler: |
| 58 | + for line in file_handler: |
| 59 | + data.append(json.loads(line)) |
| 60 | + return pandas.DataFrame(data) |
| 61 | + |
| 62 | + |
38 | 63 | KEY_VALUES = { |
39 | 64 | "test": 1, |
40 | 65 | "b": -3.14, |
@@ -634,3 +659,58 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): |
634 | 659 | assert logger.name_to_value["rollout/success_rate"] == 0.5 |
635 | 660 | model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) |
636 | 661 | assert logger.name_to_value["rollout/success_rate"] == 0.8 |
| 662 | + |
| 663 | + |
| 664 | +def test_pandas_import_error(tmp_path): |
| 665 | + """Test that a clear ImportError is raised when pandas is not available.""" |
| 666 | + # Mock the import to simulate pandas not being installed |
| 667 | + with mock.patch.dict("sys.modules", {"pandas": None}): |
| 668 | + # First, remove the modules from cache if they exist |
| 669 | + if "stable_baselines3.common.results_plotter" in sys.modules: |
| 670 | + del sys.modules["stable_baselines3.common.results_plotter"] |
| 671 | + if "stable_baselines3.common.monitor" in sys.modules: |
| 672 | + del sys.modules["stable_baselines3.common.monitor"] |
| 673 | + |
| 674 | + # Test results_plotter raises ImportError at import time |
| 675 | + with pytest.raises(ImportError, match="pandas is required for plotting"): |
| 676 | + import stable_baselines3.common.results_plotter # noqa: F401 |
| 677 | + |
| 678 | + # Test load_results raises ImportError at call time |
| 679 | + # monitor module can still be imported (pandas import is lazy) |
| 680 | + from stable_baselines3.common.monitor import load_results |
| 681 | + |
| 682 | + with pytest.raises(ImportError, match="pandas is required for loading results"): |
| 683 | + load_results(str(tmp_path)) |
| 684 | + |
| 685 | + |
| 686 | +def test_matplotlib_import_error(): |
| 687 | + """Test that a clear ImportError is raised when matplotlib is not available.""" |
| 688 | + # Mock the import to simulate matplotlib not being installed |
| 689 | + with mock.patch.dict("sys.modules", {"matplotlib": None, "matplotlib.pyplot": None}): |
| 690 | + # First, remove the module from cache if it exists |
| 691 | + if "stable_baselines3.common.results_plotter" in sys.modules: |
| 692 | + del sys.modules["stable_baselines3.common.results_plotter"] |
| 693 | + |
| 694 | + # Test results_plotter raises ImportError at import time |
| 695 | + with pytest.raises(ImportError, match="matplotlib is required for plotting"): |
| 696 | + import stable_baselines3.common.results_plotter # noqa: F401 |
| 697 | + |
| 698 | + |
| 699 | +def test_sb3_import_without_optional_deps(): |
| 700 | + """Test that SB3 core can be imported without matplotlib and pandas.""" |
| 701 | + # Mock the imports to simulate optional dependencies not being installed |
| 702 | + with mock.patch.dict("sys.modules", {"pandas": None, "matplotlib": None, "matplotlib.pyplot": None}): |
| 703 | + # First, remove the modules from cache if they exist |
| 704 | + modules_to_remove = [key for key in sys.modules.keys() if key.startswith("stable_baselines3")] |
| 705 | + for module in modules_to_remove: |
| 706 | + del sys.modules[module] |
| 707 | + |
| 708 | + # Core SB3 should still be importable |
| 709 | + from stable_baselines3 import A2C, DQN, PPO # noqa: F401 |
| 710 | + |
| 711 | + # Monitor should be importable (pandas import is lazy in load_results) |
| 712 | + from stable_baselines3.common.monitor import Monitor # noqa: F401 |
| 713 | + |
| 714 | + # But plotting module should fail |
| 715 | + with pytest.raises(ImportError): |
| 716 | + import stable_baselines3.common.results_plotter # noqa: F401 |
0 commit comments