@@ -377,16 +377,19 @@ def compute_state_entropy(
377
377
A tensor containing the state entropy for `obs`.
378
378
"""
379
379
assert obs .shape [1 :] == all_obs .shape [1 :]
380
+ batch_size = 500
380
381
with th .no_grad ():
381
382
non_batch_dimensions = tuple (range (2 , len (obs .shape ) + 1 ))
382
- distances_tensor = th .linalg .vector_norm (
383
- obs [:, None ] - all_obs [None , :],
384
- dim = non_batch_dimensions ,
385
- ord = 2 ,
386
- )
387
-
388
- # Note that we take the k+1'th value because the closest neighbor to
389
- # a point is itself, which we want to skip.
390
- assert distances_tensor .shape [- 1 ] > k
391
- knn_dists = th .kthvalue (distances_tensor , k = k + 1 , dim = 1 ).values
392
- return knn_dists
383
+ dists = []
384
+ for idx in range (len (all_obs ) // batch_size + 1 ):
385
+ start = idx * batch_size
386
+ end = (idx + 1 ) * batch_size
387
+ distances_tensor = th .linalg .vector_norm (
388
+ obs [:, None ] - all_obs [None , start :end ],
389
+ dim = non_batch_dimensions ,
390
+ ord = 2 ,
391
+ )
392
+ dists .append (distances_tensor )
393
+ dists = th .cat (dists , dim = 1 )
394
+ knn_dists = th .kthvalue (dists , k = k + 1 , dim = 1 ).values
395
+ return knn_dists
0 commit comments