Skip to content

Commit 1fdfc74

Browse files
author
Jan Michelfeit
committed
#625 use RunningNorm instead of RunningMeanAndVar
1 parent 50ec092 commit 1fdfc74

File tree

4 files changed

+9
-84
lines changed

4 files changed

+9
-84
lines changed

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from imitation.rewards.reward_function import RewardFn
1111
from imitation.util import util
12+
from imitation.util.networks import RunningNorm
1213

1314

1415
def _samples_to_reward_fn_input(
@@ -148,7 +149,7 @@ def __init__(
148149
self.sample_count = 0
149150
self.k = k
150151
# TODO support n_envs > 1
151-
self.entropy_stats = util.RunningMeanAndVar(shape=(1,))
152+
self.entropy_stats = RunningNorm(1)
152153
self.entropy_as_reward_samples = entropy_as_reward_samples
153154

154155
def sample(self, *args, **kwargs):
@@ -173,10 +174,8 @@ def sample(self, *args, **kwargs):
173174
self.k,
174175
)
175176

176-
# Normalize to have mean of 0 and standard deviation of 1
177-
self.entropy_stats.update(entropies)
178-
entropies -= self.entropy_stats.running_mean
179-
entropies /= self.entropy_stats.std
177+
# Normalize to have mean of 0 and standard deviation of 1 according to running stats
178+
entropies = self.entropy_stats.forward(entropies)
180179

181180
entropies_th = (
182181
util.safe_to_tensor(entropies)

src/imitation/util/networks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ def update_stats(self, batch: th.Tensor) -> None:
121121
tot_count = self.count + batch_count
122122
self.running_mean += delta * batch_count / tot_count
123123

124-
self.running_var *= self.count
125-
self.running_var += batch_var * batch_count
126-
self.running_var += th.square(delta) * self.count * batch_count / tot_count
127-
self.running_var /= tot_count
124+
m_a = self.running_var * self.count
125+
m_b = batch_var * batch_count
126+
M2 = m_a + m_b + th.square(delta) * self.count * batch_count / tot_count
127+
self.running_var = M2 / tot_count
128128

129-
self.count += batch_count
129+
self.count = tot_count
130130

131131

132132
class EMANorm(BaseNorm):

src/imitation/util/util.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -361,46 +361,6 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
361361
return first_element, return_iterable
362362

363363

364-
class RunningMeanAndVar:
365-
"""Stores a running mean and variance using Welford's algorithm."""
366-
367-
def __init__(
368-
self,
369-
shape: Tuple[int, ...] = (),
370-
device: Optional[str] = None,
371-
) -> None:
372-
"""Initialize blank mean, variance, count."""
373-
self.running_mean = th.zeros(shape, device=device)
374-
self.M2 = th.zeros(shape, device=device)
375-
self.count = 0
376-
377-
def update(self, batch: th.Tensor) -> None:
378-
"""Update the mean and variance with a batch `x`."""
379-
with th.no_grad():
380-
batch_mean = th.mean(batch, dim=0)
381-
batch_var = th.var(batch, dim=0, unbiased=False)
382-
batch_count = batch.shape[0]
383-
384-
delta = batch_mean - self.running_mean
385-
tot_count = self.count + batch_count
386-
self.running_mean += delta * batch_count / tot_count
387-
388-
self.M2 += batch_var * batch_count
389-
self.M2 += th.square(delta) * self.count * batch_count / tot_count
390-
391-
self.count += batch_count
392-
393-
@property
394-
def var(self) -> th.Tensor:
395-
"""Returns the unbiased estimate of the variances."""
396-
return self.M2 / (self.count - 1)
397-
398-
@property
399-
def std(self) -> th.Tensor:
400-
"""Returns the unbiased estimate of the standard deviations."""
401-
return np.sqrt(self.var)
402-
403-
404364
def compute_state_entropy(
405365
obs: th.Tensor,
406366
all_obs: th.Tensor,

tests/util/test_util.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -120,40 +120,6 @@ def test_tensor_iter_norm():
120120
util.tensor_iter_norm(tensor_list, ord=0.0)
121121

122122

123-
def test_RunningMeanAndVar():
124-
running_stats = util.RunningMeanAndVar(shape=(3, 4))
125-
data = th.normal(mean=10 * th.ones(size=(20, 3, 4), dtype=th.double))
126-
127-
first_half = data[:10]
128-
running_stats.update(first_half)
129-
np.testing.assert_allclose(
130-
running_stats.running_mean,
131-
first_half.mean(dim=0),
132-
atol=1e-5,
133-
rtol=1e-4,
134-
)
135-
np.testing.assert_allclose(
136-
running_stats.var,
137-
first_half.var(dim=0),
138-
atol=1e-5,
139-
rtol=1e-4,
140-
)
141-
142-
running_stats.update(data[10:])
143-
np.testing.assert_allclose(
144-
running_stats.running_mean,
145-
data.mean(dim=0),
146-
atol=1e-5,
147-
rtol=1e-4,
148-
)
149-
np.testing.assert_allclose(
150-
running_stats.var,
151-
data.var(dim=0),
152-
atol=1e-5,
153-
rtol=1e-4,
154-
)
155-
156-
157123
def test_compute_state_entropy_1d():
158124
all_obs = th.arange(10, dtype=th.float).unsqueeze(1)
159125
obs = all_obs[4:6]

0 commit comments

Comments (0)