Skip to content

Commit 8eb5b55

Browse files
authored
Merge branch 'master' into add-precommit-config-yaml
2 parents 246008d + 4a1137b commit 8eb5b55

File tree

4 files changed

+67
-13
lines changed

4 files changed

+67
-13
lines changed

docs/misc/changelog.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a7 (WIP)
6+
Release 2.4.0a8 (WIP)
77
--------------------------
88

99
.. note::
@@ -19,6 +19,7 @@ Breaking Changes:
1919
New Features:
2020
^^^^^^^^^^^^^
2121
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
22+
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
2223

2324
Bug Fixes:
2425
^^^^^^^^^^

stable_baselines3/common/logger.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, .
412412
else:
413413
self.writer.add_scalar(key, value, step)
414414

415-
if isinstance(value, th.Tensor):
416-
self.writer.add_histogram(key, value, step)
415+
if isinstance(value, (th.Tensor, np.ndarray)):
416+
# Convert to Torch so it works with numpy<1.24 and torch<2.0
417+
self.writer.add_histogram(key, th.as_tensor(value), step)
417418

418419
if isinstance(value, Video):
419420
self.writer.add_video(key, value.frames, step, value.fps)

stable_baselines3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a7
1+
2.4.0a8

tests/test_logger.py

+61-9
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"f": np.array(1),
4545
"g": np.array([[[1]]]),
4646
"h": 'this ", ;is a \n tes:,t',
47+
"i": th.ones(3),
4748
}
4849

4950
KEY_EXCLUDED = {}
@@ -176,6 +177,9 @@ def test_main(tmp_path):
176177
logger.record_mean("b", -22.5)
177178
logger.record_mean("b", -44.4)
178179
logger.record("a", 5.5)
180+
# Converted to string:
181+
logger.record("hist1", th.ones(2))
182+
logger.record("hist2", np.ones(2))
179183
logger.dump()
180184

181185
logger.record("a", "longasslongasslongasslongasslongasslongassvalue")
@@ -241,7 +245,7 @@ def is_moviepy_installed():
241245

242246

243247
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
244-
def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format):
248+
def test_unsupported_video_format(tmp_path, unsupported_format):
245249
writer = make_output_format(unsupported_format, tmp_path)
246250

247251
with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -251,6 +255,54 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
251255
writer.close()
252256

253257

258+
@pytest.mark.parametrize(
259+
"histogram",
260+
[
261+
th.rand(100),
262+
np.random.rand(100),
263+
np.ones(1),
264+
np.ones(1, dtype="int"),
265+
],
266+
)
267+
def test_log_histogram(tmp_path, read_log, histogram):
268+
pytest.importorskip("tensorboard")
269+
270+
writer = make_output_format("tensorboard", tmp_path)
271+
writer.write({"data": histogram}, key_excluded={"data": ()})
272+
273+
log = read_log("tensorboard")
274+
275+
assert not log.empty
276+
assert any("data" in line for line in log.lines)
277+
assert any("Histogram" in line for line in log.lines)
278+
279+
writer.close()
280+
281+
282+
@pytest.mark.parametrize(
283+
"histogram",
284+
[
285+
list(np.random.rand(100)),
286+
tuple(np.random.rand(100)),
287+
"1 2 3 4",
288+
np.ones(1).item(),
289+
th.ones(1).item(),
290+
],
291+
)
292+
def test_unsupported_type_histogram(tmp_path, read_log, histogram):
293+
"""
294+
Check that other types aren't accidentally logged as a Histogram
295+
"""
296+
pytest.importorskip("tensorboard")
297+
298+
writer = make_output_format("tensorboard", tmp_path)
299+
writer.write({"data": histogram}, key_excluded={"data": ()})
300+
301+
assert all("Histogram" not in line for line in read_log("tensorboard").lines)
302+
303+
writer.close()
304+
305+
254306
def test_report_image_to_tensorboard(tmp_path, read_log):
255307
pytest.importorskip("tensorboard")
256308

@@ -263,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log):
263315

264316

265317
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
266-
def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format):
318+
def test_unsupported_image_format(tmp_path, unsupported_format):
267319
writer = make_output_format(unsupported_format, tmp_path)
268320

269321
with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -287,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log):
287339

288340

289341
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
290-
def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format):
342+
def test_unsupported_figure_format(tmp_path, unsupported_format):
291343
writer = make_output_format(unsupported_format, tmp_path)
292344

293345
with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -300,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
300352

301353

302354
@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
303-
def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
355+
def test_unsupported_hparam(tmp_path, unsupported_format):
304356
writer = make_output_format(unsupported_format, tmp_path)
305357

306358
with pytest.raises(FormatUnsupportedError) as exec_info:
@@ -419,9 +471,9 @@ def test_fps_no_div_zero(algo):
419471
model.learn(total_timesteps=100)
420472

421473

422-
def test_human_output_format_no_crash_on_same_keys_different_tags():
423-
o = HumanOutputFormat(sys.stdout, max_length=60)
424-
o.write(
474+
def test_human_output_same_keys_different_tags():
475+
human_out = HumanOutputFormat(sys.stdout, max_length=60)
476+
human_out.write(
425477
{"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"},
426478
{"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None},
427479
)
@@ -439,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
439491

440492

441493
@pytest.mark.parametrize("base_class", [object, TextIOBase])
442-
def test_human_output_format_custom_test_io(base_class):
494+
def test_human_out_custom_text_io(base_class):
443495
class DummyTextIO(base_class):
444496
def __init__(self) -> None:
445497
super().__init__()
@@ -531,7 +583,7 @@ def step(self, action):
531583
return self.observation_space.sample(), 0.0, False, truncated, info
532584

533585

534-
def test_rollout_success_rate_on_policy_algorithm(tmp_path):
586+
def test_rollout_success_rate_onpolicy_algo(tmp_path):
535587
"""
536588
Test if the rollout/success_rate information is correctly logged with on policy algorithms
537589

0 commit comments

Comments
 (0)