Skip to content

Commit e6f569c

Browse files
committed
fix: reset before collect on watch
1 parent b006cd5 commit e6f569c

File tree

10 files changed

+10
-10
lines changed

10 files changed

+10
-10
lines changed

examples/atari/atari_c51.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def watch() -> None:
174174
stack_num=args.frames_stack,
175175
)
176176
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
177-
result = collector.collect(n_step=args.buffer_size)
177+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
178178
print(f"Save buffer into {args.save_buffer_name}")
179179
# Unfortunately, pickle will cause oom with 1M buffer size
180180
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_dqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def watch() -> None:
216216
stack_num=args.frames_stack,
217217
)
218218
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
219-
result = collector.collect(n_step=args.buffer_size)
219+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
220220
print(f"Save buffer into {args.save_buffer_name}")
221221
# Unfortunately, pickle will cause oom with 1M buffer size
222222
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_fqf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def watch() -> None:
187187
stack_num=args.frames_stack,
188188
)
189189
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
190-
result = collector.collect(n_step=args.buffer_size)
190+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
191191
print(f"Save buffer into {args.save_buffer_name}")
192192
# Unfortunately, pickle will cause oom with 1M buffer size
193193
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_iqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def watch() -> None:
184184
stack_num=args.frames_stack,
185185
)
186186
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
187-
result = collector.collect(n_step=args.buffer_size)
187+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
188188
print(f"Save buffer into {args.save_buffer_name}")
189189
# Unfortunately, pickle will cause oom with 1M buffer size
190190
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def watch() -> None:
244244
stack_num=args.frames_stack,
245245
)
246246
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
247-
result = collector.collect(n_step=args.buffer_size)
247+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
248248
print(f"Save buffer into {args.save_buffer_name}")
249249
# Unfortunately, pickle will cause oom with 1M buffer size
250250
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_qrdqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def watch() -> None:
178178
stack_num=args.frames_stack,
179179
)
180180
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
181-
result = collector.collect(n_step=args.buffer_size)
181+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
182182
print(f"Save buffer into {args.save_buffer_name}")
183183
# Unfortunately, pickle will cause oom with 1M buffer size
184184
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_rainbow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def watch() -> None:
219219
beta=args.beta,
220220
)
221221
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
222-
result = collector.collect(n_step=args.buffer_size)
222+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
223223
print(f"Save buffer into {args.save_buffer_name}")
224224
# Unfortunately, pickle will cause oom with 1M buffer size
225225
buffer.save_hdf5(args.save_buffer_name)

examples/atari/atari_sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def watch() -> None:
227227
stack_num=args.frames_stack,
228228
)
229229
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
230-
result = collector.collect(n_step=args.buffer_size)
230+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
231231
print(f"Save buffer into {args.save_buffer_name}")
232232
# Unfortunately, pickle will cause oom with 1M buffer size
233233
buffer.save_hdf5(args.save_buffer_name)

examples/vizdoom/vizdoom_c51.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def watch() -> None:
180180
stack_num=args.frames_stack,
181181
)
182182
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
183-
result = collector.collect(n_step=args.buffer_size)
183+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
184184
print(f"Save buffer into {args.save_buffer_name}")
185185
# Unfortunately, pickle will cause oom with 1M buffer size
186186
buffer.save_hdf5(args.save_buffer_name)

examples/vizdoom/vizdoom_ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def watch() -> None:
246246
stack_num=args.frames_stack,
247247
)
248248
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
249-
result = collector.collect(n_step=args.buffer_size)
249+
result = collector.collect(n_step=args.buffer_size, reset_before_collect=True)
250250
print(f"Save buffer into {args.save_buffer_name}")
251251
# Unfortunately, pickle will cause oom with 1M buffer size
252252
buffer.save_hdf5(args.save_buffer_name)

0 commit comments

Comments
 (0)