Skip to content

Commit d23d98b

Browse files
author
Jan Michelfeit
committed
#625 use batching for entropy computation to avoid memory issues
1 parent 7c3470e commit d23d98b

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def has_pretraining(self) -> bool:
8484
The value can be used, e.g., when allocating time-steps for pre-training.
8585
8686
By default, True is returned if the unsupervised_pretrain() method is not
87-
overriden, bud subclasses may choose to override this behavior.
87+
overridden, but subclasses may choose to override this behavior.
8888
8989
Returns:
9090
True if this generator has a pre-training phase, False otherwise

src/imitation/util/util.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -377,16 +377,19 @@ def compute_state_entropy(
377377
A tensor containing the state entropy for `obs`.
378378
"""
379379
assert obs.shape[1:] == all_obs.shape[1:]
380+
batch_size = 500
380381
with th.no_grad():
381382
non_batch_dimensions = tuple(range(2, len(obs.shape) + 1))
382-
distances_tensor = th.linalg.vector_norm(
383-
obs[:, None] - all_obs[None, :],
384-
dim=non_batch_dimensions,
385-
ord=2,
386-
)
387-
388-
# Note that we take the k+1'th value because the closest neighbor to
389-
# a point is itself, which we want to skip.
390-
assert distances_tensor.shape[-1] > k
391-
knn_dists = th.kthvalue(distances_tensor, k=k + 1, dim=1).values
383+
dists: List[th.Tensor] = []
384+
for idx in range(len(all_obs) // batch_size + 1):
385+
start = idx * batch_size
386+
end = (idx + 1) * batch_size
387+
distances_tensor = th.linalg.vector_norm(
388+
obs[:, None] - all_obs[None, start:end],
389+
dim=non_batch_dimensions,
390+
ord=2,
391+
)
392+
dists.append(distances_tensor)
393+
all_dists = th.cat(dists, dim=1)
394+
knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values
392395
return knn_dists

0 commit comments

Comments
 (0)