how to pmap sync batch_stats? maybe a bug? #3201
Unanswered
zhenlan0426 asked this question in Q&A
This simple example shows this empirically:
import numpy as np
x = np.random.randn(128)
# variance over the entire data vs. variance per device first, then averaged across devices
# (each of the 8 columns of the reshaped array plays the role of one device)
print(np.var(x), x.reshape(-1, 8).var(0).mean())
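To put a number on the gap, here is a minimal NumPy sketch of the law of total variance. It assumes equal-sized shards and the population variance (NumPy's default `ddof=0`). The names `shards`, `mean_of_vars`, `var_of_means` and `gap` are mine, not from Flax. The difference between the two printed values above should be exactly the variance of the per-device means.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(128)

# Each of the 8 columns plays the role of one device holding 16 samples.
shards = x.reshape(-1, 8)

global_var = x.var()                      # variance over all 128 samples
mean_of_vars = shards.var(axis=0).mean()  # per-device variance, then averaged
var_of_means = shards.mean(axis=0).var()  # spread of the per-device means

# Law of total variance (equal shard sizes, ddof=0):
# global_var == mean_of_vars + var_of_means
gap = global_var - mean_of_vars
print(global_var, mean_of_vars, var_of_means, gap)
```

So averaging per-device variances underestimates the global variance whenever the per-device means differ, and the shortfall shrinks as the per-device batch gets larger.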
@cgarciae kindly pointed me to this example of pmap with BatchNorm, where the batch_stats are synced across devices with pmean.
And this is how batch_stats is updated inside the BatchNorm module,
The pmean of the per-device variances won't be the same as the true variance calculated over the entire data (device * batch), i.e. the variance we would get if we didn't split the data across devices. So the variance stats in BatchNorm seem off.
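For comparison, here is a hedged NumPy sketch of one way to get the exact global variance from per-device quantities: average the first and second moments across devices, then form the variance from those. The mean over the device axis stands in for `pmean`. The names (`n_devices`, `per_device_batch`, `synced_var`, `naive_var`) are my own, and I am not claiming this is what any particular Flax version does internally; that should be checked against the BatchNorm source.

```python
import numpy as np

rng = np.random.default_rng(0)
n_devices, per_device_batch = 8, 16
# Row i holds the batch that would live on device i.
x = rng.standard_normal((n_devices, per_device_batch))

# Per-device statistics, computed locally on each device.
local_mean = x.mean(axis=1)
local_mean_sq = (x ** 2).mean(axis=1)

# Stand-in for pmean: average each moment across the device axis.
global_mean = local_mean.mean()
global_mean_sq = local_mean_sq.mean()

# Variance from the synced moments: E[x^2] - E[x]^2.
synced_var = global_mean_sq - global_mean ** 2

# What you get if each device computes its own variance and those are averaged.
naive_var = x.var(axis=1).mean()

true_var = x.var()
print(true_var, synced_var, naive_var)
```

With equal per-device batch sizes, `synced_var` matches `true_var` up to floating-point error, while `naive_var` comes out smaller.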