|
19 | 19 |
|
20 | 20 | import time |
21 | 21 | from typing import Sequence |
| 22 | +import itertools |
22 | 23 |
|
23 | 24 | from absl import app |
24 | 25 | from absl import logging |
@@ -75,22 +76,29 @@ def main(_): |
75 | 76 | next_timestep = environment.step(action) |
76 | 77 | adder.add(action, next_timestep, extras=()) |
77 | 78 |
|
78 | | - for batch_size in [256, 256 * 8, 256 * 64]: |
79 | | - for prefetch_size in [0, 1, 4]: |
80 | | - print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') |
81 | | - ds = datasets.make_reverb_dataset( |
82 | | - table='default', |
83 | | - server_address=replay_client.server_address, |
84 | | - batch_size=batch_size, |
85 | | - prefetch_size=prefetch_size, |
86 | | - ) |
87 | | - it = ds.as_numpy_iterator() |
88 | | - |
89 | | - for iteration in range(3): |
90 | | - t = time.time() |
91 | | - for _ in range(1000): |
92 | | - _ = next(it) |
93 | | - print(f'Iteration {iteration} finished in {time.time() - t}s') |
| 79 | + batch_sizes = [256, 256 * 8, 256 * 64] |
| 80 | + prefetch_sizes = [0, 1, 4] |
| 81 | + num_batches_per_iteration = 1000 |
| 82 | + |
| 83 | + for batch_size, prefetch_size in itertools.product(batch_sizes, prefetch_sizes): |
| 84 | + print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}') |
| 85 | + ds = datasets.make_reverb_dataset( |
| 86 | + table='default', |
| 87 | + server_address=replay_client.server_address, |
| 88 | + batch_size=batch_size, |
| 89 | + prefetch_size=prefetch_size, |
| 90 | + ) |
| 91 | + it = ds.as_numpy_iterator() |
| 92 | + |
| 93 | + for iteration in range(3): |
| 94 | + start = time.time() |
| 95 | + for _ in range(num_batches_per_iteration): |
| 96 | + _ = next(it) |
| 97 | + end = time.time() |
| 98 | + duration_s = end - start |
| 99 | + samples_per_second = batch_size * num_batches_per_iteration / duration_s |
| 100 | + print(f'Iteration {iteration} finished in {duration_s:_.02}s with ' |
| 101 | + f'{samples_per_second:_.2f} samples/s.') |
94 | 102 |
|
95 | 103 |
|
96 | 104 | if __name__ == '__main__': |
|
0 commit comments