|
1 | 1 | from numbers import Number |
2 | | -from typing import List, Union |
| 2 | +from typing import List, Optional, Union |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import torch |
@@ -70,15 +70,31 @@ class RunningMeanStd(object): |
70 | 70 | """Calculates the running mean and std of a data stream. |
71 | 71 |
|
72 | 72 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm |
| 73 | +
|
| 74 | + :param mean: the initial mean estimation for data array. Default to 0. |
| 75 | + :param std: the initial standard error estimation for data array. Default to 1. |
| 76 | + :param float clip_max: the maximum absolute value for data array. Default to |
| 77 | + 10.0. |
| 78 | + :param float epsilon: To avoid division by zero. |
73 | 79 | """ |
74 | 80 |
|
75 | 81 | def __init__( |
76 | 82 | self, |
77 | 83 | mean: Union[float, np.ndarray] = 0.0, |
78 | | - std: Union[float, np.ndarray] = 1.0 |
| 84 | + std: Union[float, np.ndarray] = 1.0, |
| 85 | + clip_max: Optional[float] = 10.0, |
| 86 | + epsilon: float = np.finfo(np.float32).eps.item(), |
79 | 87 | ) -> None: |
80 | 88 | self.mean, self.var = mean, std |
| 89 | + self.clip_max = clip_max |
81 | 90 | self.count = 0 |
| 91 | + self.eps = epsilon |
| 92 | + |
| 93 | + def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]: |
| 94 | + data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps) |
| 95 | + if self.clip_max: |
| 96 | + data_array = np.clip(data_array, -self.clip_max, self.clip_max) |
| 97 | + return data_array |
82 | 98 |
|
83 | 99 | def update(self, data_array: np.ndarray) -> None: |
84 | 100 | """Add a batch of item into RMS with the same shape, modify mean/var/count.""" |
|
0 commit comments