Skip to content

Commit 9426e0b

Browse files
author
Jan Michelfeit
committed
#625 use batching for entropy computation to avoid memory issues
1 parent 7c3470e commit 9426e0b

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

src/imitation/util/util.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,16 +377,19 @@ def compute_state_entropy(
377377
A tensor containing the state entropy for `obs`.
378378
"""
379379
assert obs.shape[1:] == all_obs.shape[1:]
380+
batch_size = 500
380381
with th.no_grad():
381382
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
382-
distances_tensor = th.linalg.vector_norm(
383-
obs[:, None] - all_obs[None, :],
384-
dim=non_batch_dimensions,
385-
ord=2,
386-
)
387-
388-
# Note that we take the k+1'th value because the closest neighbor to
389-
# a point is itself, which we want to skip.
390-
assert distances_tensor.shape[-1] > k
391-
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
392-
return knn_dists
383+
dists = []
384+
for idx in range(len(all_obs) // batch_size + 1):
385+
start = idx * batch_size
386+
end = (idx + 1) * batch_size
387+
distances_tensor = th.linalg.vector_norm(
388+
obs[:, None] - all_obs[None, start:end],
389+
dim=non_batch_dimensions,
390+
ord=2,
391+
)
392+
dists.append(distances_tensor)
393+
dists = th.cat(dists, dim=1)
394+
knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values
395+
return knn_dists

0 commit comments

Comments
 (0)