how to pmap sync batch_stats? maybe a bug? #3201
Unanswered
zhenlan0426 asked this question in Q&A
This simple example shows this empirically:
import numpy as np
x = np.random.randn(128)
# variance over the entire data vs. per-device variance (8 devices, 16 samples each) averaged across devices
print(np.var(x), x.reshape(-1, 8).var(0).mean())
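The gap between the two printed numbers can be accounted for exactly. With equal-sized shards and the population variance (NumPy's default `ddof=0`), the law of total variance says the full-data variance equals the mean of the per-shard variances plus the variance of the per-shard means. A minimal sketch of that check, assuming the same layout as above (8 hypothetical devices, 16 samples each; the variable names are illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(128)
per_device = x.reshape(-1, 8)  # column j holds the 16 samples of hypothetical device j

total_var = x.var()                      # variance over the entire data
mean_of_vars = per_device.var(0).mean()  # per-device variances, averaged across devices
var_of_means = per_device.mean(0).var()  # spread of the per-device means

# The averaged per-device variance is short by exactly var_of_means.
print(total_var, mean_of_vars + var_of_means)
```

So the averaged per-device variance never exceeds the full-data variance, and the shortfall shrinks as the per-device batch grows and the per-device means agree more closely.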
@cgarciae kindly pointed me to this example for pmap with BatchNorm, in which the batch_stats are synced across devices with pmean.
And this is how batch_stats is updated inside the BatchNorm module,
The pmean of the per-device variances won't be the same as the true variance calculated over the entire data (device * batch), i.e. what we would get if the data were not split across devices. So the variance stats in BatchNorm seem off.
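One way to recover the full-data variance from per-device quantities is to average the first and second moments across devices and only then form the variance, instead of averaging the variances themselves. The sketch below simulates this in plain NumPy: the `.mean()` over the device axis stands in for a cross-device `pmean`, and the 8-device, 16-samples-per-device layout and all variable names are assumptions for illustration. It is not taken from the BatchNorm code referred to above.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(128)
per_device = x.reshape(-1, 8)  # column j holds the 16 samples of hypothetical device j

# Quantities each device can compute locally.
local_mean = per_device.mean(0)            # E[x] per device
local_mean_sq = (per_device ** 2).mean(0)  # E[x^2] per device

# Stand-in for pmean: average the local moments across the device axis.
global_mean = local_mean.mean()
global_mean_sq = local_mean_sq.mean()

# Variance from the synced moments: E[x^2] - E[x]^2.
synced_var = global_mean_sq - global_mean ** 2

print(synced_var, x.var())
```

This works because means of equal-sized shards average to the global mean, which holds for both `x` and `x ** 2`, whereas the variance is not linear in the data. Whether a particular BatchNorm setup syncs moments this way inside the forward pass, or only averages the stored statistics afterwards, would need to be checked against its source.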