44
44
"f" : np .array (1 ),
45
45
"g" : np .array ([[[1 ]]]),
46
46
"h" : 'this ", ;is a \n tes:,t' ,
47
+ "i" : th .ones (3 ),
47
48
}
48
49
49
50
KEY_EXCLUDED = {}
@@ -176,6 +177,9 @@ def test_main(tmp_path):
176
177
logger .record_mean ("b" , - 22.5 )
177
178
logger .record_mean ("b" , - 44.4 )
178
179
logger .record ("a" , 5.5 )
180
+ # Converted to string:
181
+ logger .record ("hist1" , th .ones (2 ))
182
+ logger .record ("hist2" , np .ones (2 ))
179
183
logger .dump ()
180
184
181
185
logger .record ("a" , "longasslongasslongasslongasslongasslongassvalue" )
@@ -241,7 +245,7 @@ def is_moviepy_installed():
241
245
242
246
243
247
@pytest .mark .parametrize ("unsupported_format" , ["stdout" , "log" , "json" , "csv" ])
244
- def test_report_video_to_unsupported_format_raises_error (tmp_path , unsupported_format ):
248
+ def test_unsupported_video_format (tmp_path , unsupported_format ):
245
249
writer = make_output_format (unsupported_format , tmp_path )
246
250
247
251
with pytest .raises (FormatUnsupportedError ) as exec_info :
@@ -251,6 +255,54 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f
251
255
writer .close ()
252
256
253
257
258
+ @pytest .mark .parametrize (
259
+ "histogram" ,
260
+ [
261
+ th .rand (100 ),
262
+ np .random .rand (100 ),
263
+ np .ones (1 ),
264
+ np .ones (1 , dtype = "int" ),
265
+ ],
266
+ )
267
+ def test_log_histogram (tmp_path , read_log , histogram ):
268
+ pytest .importorskip ("tensorboard" )
269
+
270
+ writer = make_output_format ("tensorboard" , tmp_path )
271
+ writer .write ({"data" : histogram }, key_excluded = {"data" : ()})
272
+
273
+ log = read_log ("tensorboard" )
274
+
275
+ assert not log .empty
276
+ assert any ("data" in line for line in log .lines )
277
+ assert any ("Histogram" in line for line in log .lines )
278
+
279
+ writer .close ()
280
+
281
+
282
+ @pytest .mark .parametrize (
283
+ "histogram" ,
284
+ [
285
+ list (np .random .rand (100 )),
286
+ tuple (np .random .rand (100 )),
287
+ "1 2 3 4" ,
288
+ np .ones (1 ).item (),
289
+ th .ones (1 ).item (),
290
+ ],
291
+ )
292
+ def test_unsupported_type_histogram (tmp_path , read_log , histogram ):
293
+ """
294
+ Check that other types aren't accidentally logged as a Histogram
295
+ """
296
+ pytest .importorskip ("tensorboard" )
297
+
298
+ writer = make_output_format ("tensorboard" , tmp_path )
299
+ writer .write ({"data" : histogram }, key_excluded = {"data" : ()})
300
+
301
+ assert all ("Histogram" not in line for line in read_log ("tensorboard" ).lines )
302
+
303
+ writer .close ()
304
+
305
+
254
306
def test_report_image_to_tensorboard (tmp_path , read_log ):
255
307
pytest .importorskip ("tensorboard" )
256
308
@@ -263,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log):
263
315
264
316
265
317
@pytest .mark .parametrize ("unsupported_format" , ["stdout" , "log" , "json" , "csv" ])
266
- def test_report_image_to_unsupported_format_raises_error (tmp_path , unsupported_format ):
318
+ def test_unsupported_image_format (tmp_path , unsupported_format ):
267
319
writer = make_output_format (unsupported_format , tmp_path )
268
320
269
321
with pytest .raises (FormatUnsupportedError ) as exec_info :
@@ -287,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log):
287
339
288
340
289
341
@pytest .mark .parametrize ("unsupported_format" , ["stdout" , "log" , "json" , "csv" ])
290
- def test_report_figure_to_unsupported_format_raises_error (tmp_path , unsupported_format ):
342
+ def test_unsupported_figure_format (tmp_path , unsupported_format ):
291
343
writer = make_output_format (unsupported_format , tmp_path )
292
344
293
345
with pytest .raises (FormatUnsupportedError ) as exec_info :
@@ -300,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
300
352
301
353
302
354
@pytest .mark .parametrize ("unsupported_format" , ["stdout" , "log" , "json" , "csv" ])
303
- def test_report_hparam_to_unsupported_format_raises_error (tmp_path , unsupported_format ):
355
+ def test_unsupported_hparam (tmp_path , unsupported_format ):
304
356
writer = make_output_format (unsupported_format , tmp_path )
305
357
306
358
with pytest .raises (FormatUnsupportedError ) as exec_info :
@@ -419,9 +471,9 @@ def test_fps_no_div_zero(algo):
419
471
model .learn (total_timesteps = 100 )
420
472
421
473
422
- def test_human_output_format_no_crash_on_same_keys_different_tags ():
423
- o = HumanOutputFormat (sys .stdout , max_length = 60 )
424
- o .write (
474
+ def test_human_output_same_keys_different_tags ():
475
+ human_out = HumanOutputFormat (sys .stdout , max_length = 60 )
476
+ human_out .write (
425
477
{"key1/foo" : "value1" , "key1/bar" : "value2" , "key2/bizz" : "value3" , "key2/foo" : "value4" },
426
478
{"key1/foo" : None , "key2/bizz" : None , "key1/bar" : None , "key2/foo" : None },
427
479
)
@@ -439,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size):
439
491
440
492
441
493
@pytest .mark .parametrize ("base_class" , [object , TextIOBase ])
442
- def test_human_output_format_custom_test_io (base_class ):
494
+ def test_human_out_custom_text_io (base_class ):
443
495
class DummyTextIO (base_class ):
444
496
def __init__ (self ) -> None :
445
497
super ().__init__ ()
@@ -531,7 +583,7 @@ def step(self, action):
531
583
return self .observation_space .sample (), 0.0 , False , truncated , info
532
584
533
585
534
- def test_rollout_success_rate_on_policy_algorithm (tmp_path ):
586
+ def test_rollout_success_rate_onpolicy_algo (tmp_path ):
535
587
"""
536
588
Test if the rollout/success_rate information is correctly logged with on policy algorithms
537
589
0 commit comments