Skip to content

Commit e6af314

Browse files
committed
Add samples per second logging for reverb_dataset.py
1 parent 98c4204 commit e6af314

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

acme/datasets/reverb_benchmark.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import time
2121
from typing import Sequence
22+
import itertools
2223

2324
from absl import app
2425
from absl import logging
@@ -75,22 +76,29 @@ def main(_):
7576
next_timestep = environment.step(action)
7677
adder.add(action, next_timestep, extras=())
7778

78-
for batch_size in [256, 256 * 8, 256 * 64]:
79-
for prefetch_size in [0, 1, 4]:
80-
print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
81-
ds = datasets.make_reverb_dataset(
82-
table='default',
83-
server_address=replay_client.server_address,
84-
batch_size=batch_size,
85-
prefetch_size=prefetch_size,
86-
)
87-
it = ds.as_numpy_iterator()
88-
89-
for iteration in range(3):
90-
t = time.time()
91-
for _ in range(1000):
92-
_ = next(it)
93-
print(f'Iteration {iteration} finished in {time.time() - t}s')
79+
batch_sizes = [256, 256 * 8, 256 * 64]
80+
prefetch_sizes = [0, 1, 4]
81+
num_batches_per_iteration = 1000
82+
83+
for batch_size, prefetch_size in itertools.product(batch_sizes, prefetch_sizes):
84+
print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
85+
ds = datasets.make_reverb_dataset(
86+
table='default',
87+
server_address=replay_client.server_address,
88+
batch_size=batch_size,
89+
prefetch_size=prefetch_size,
90+
)
91+
it = ds.as_numpy_iterator()
92+
93+
for iteration in range(3):
94+
start = time.time()
95+
for _ in range(num_batches_per_iteration):
96+
_ = next(it)
97+
end = time.time()
98+
duration_s = end - start
99+
samples_per_second = batch_size * num_batches_per_iteration / duration_s
100+
print(f'Iteration {iteration} finished in {duration_s:_.02f}s with '
101+
f'{samples_per_second:_.02f} samples/s.')
94102

95103

96104
if __name__ == '__main__':

0 commit comments

Comments
 (0)