From 47fd9688789eb2cf58578566d3b69c2395831880 Mon Sep 17 00:00:00 2001
From: Shehtab Zaman
Date: Wed, 11 Jan 2023 15:09:21 -0800
Subject: [PATCH 1/2] Adding CI check for batchnorm

---
 .../test_unit_layer_batch_normalization.py   | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/ci_test/unit_tests/test_unit_layer_batch_normalization.py b/ci_test/unit_tests/test_unit_layer_batch_normalization.py
index 4bf41fc9c46..9d1225eae8d 100644
--- a/ci_test/unit_tests/test_unit_layer_batch_normalization.py
+++ b/ci_test/unit_tests/test_unit_layer_batch_normalization.py
@@ -116,6 +116,43 @@ def construct_model(lbann):
     obj.append(z)
     metrics.append(lbann.Metric(z, name='global statistics'))
 
+    # NumPy Implementation
+
+    vals = []
+    mb_size = num_samples() // 2
+    i = 0
+    running_mean = 0
+    running_var = 1
+    scale = 0.8
+    bias = -0.25
+
+    while (i < num_samples()):
+        k = i + mb_size if (i + mb_size) < num_samples() else num_samples()
+        sample = _samples[i:k].reshape((k-i,7,5,3))
+
+        local_mean = sample.mean((0, 2, 3))
+        local_var = sample.var((0, 2, 3))
+
+        running_mean = decay * running_mean + (1-decay)*local_mean[None,:,None,None]
+        running_var = decay * running_var + (1-decay)*local_var[None,:,None,None]
+
+        inv_stdev = 1 / (np.sqrt(running_var + epsilon))
+
+        normalized = (sample - running_mean) * inv_stdev
+        y = scale * normalized + bias
+        z = tools.numpy_l2norm2(y)
+        vals.append(z)
+        i += mb_size
+    val = np.mean(vals)
+
+    tol = 8 * val * np.finfo(np.float32).eps
+    callbacks.append(lbann.CallbackCheckMetric(
+        metric=metrics[-1].name,
+        lower_bound=val-tol,
+        upper_bound=val+tol,
+        error_on_failure=True,
+        execution_modes='test'))
+
     # ------------------------------------------
     # Gradient checking
     # ------------------------------------------

From 3d1068df490ed1b045793e1e77806e017abef75a Mon Sep 17 00:00:00 2001
From: Shehtab Zaman
Date: Wed, 25 Jan 2023 22:39:31 -0800
Subject: [PATCH 2/2] Updating batchnorm CI. Still failing

---
 .../unit_tests/test_unit_layer_batch_normalization.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/ci_test/unit_tests/test_unit_layer_batch_normalization.py b/ci_test/unit_tests/test_unit_layer_batch_normalization.py
index 9d1225eae8d..51d48fedce7 100644
--- a/ci_test/unit_tests/test_unit_layer_batch_normalization.py
+++ b/ci_test/unit_tests/test_unit_layer_batch_normalization.py
@@ -104,14 +104,15 @@ def construct_model(lbann):
     # LBANN implementation
     decay = 0.9
     epsilon = 1e-5
-    x = x_lbann
+    x = lbann.Identity(x_lbann, name='input_layer')
     y = lbann.BatchNormalization(x,
                                  decay=decay,
                                  epsilon=epsilon,
                                  scale_init=0.8,
                                  bias_init=-0.25,
                                  statistics_group_size=-1,
-                                 data_layout='data_parallel')
+                                 data_layout='data_parallel',
+                                 name="global_bn_layer")
     z = lbann.L2Norm2(y)
     obj.append(z)
     metrics.append(lbann.Metric(z, name='global statistics'))
@@ -146,6 +147,12 @@ def construct_model(lbann):
     val = np.mean(vals)
 
     tol = 8 * val * np.finfo(np.float32).eps
+    callbacks.append(lbann.CallbackDumpOutputs(
+        layers="input_layer"
+    ))
+    callbacks.append(lbann.CallbackDumpOutputs(
+        layers="global_bn_layer"
+    ))
     callbacks.append(lbann.CallbackCheckMetric(
         metric=metrics[-1].name,
         lower_bound=val-tol,
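
Note on the NumPy reference: the sketch below re-runs the PATCH 1 reference computation as a standalone script so the expected metric value and tolerance can be inspected outside the CI harness. The sample count (23), per-sample shape (7x5x3), RNG seed, and the local numpy_l2norm2 helper are assumptions standing in for the test's _samples array and tools module; treat it as a sketch of the reference math, not the test itself.

# Standalone sketch of the NumPy reference computation from PATCH 1.
# Assumed for illustration: 23 samples of 7*5*3 = 105 values each drawn from a
# fixed-seed normal distribution, and a local numpy_l2norm2 helper standing in
# for tools.numpy_l2norm2 (squared L2 norm accumulated in float64).
import numpy as np

rng = np.random.default_rng(20230111)
num_samples = 23                                     # assumed sample count
_samples = rng.normal(size=(num_samples, 105)).astype(np.float32)

def numpy_l2norm2(x):
    x = x.reshape(-1).astype(np.float64)
    return np.inner(x, x)

decay, epsilon = 0.9, 1e-5
scale, bias = 0.8, -0.25
mb_size = num_samples // 2
running_mean, running_var = 0.0, 1.0

vals = []
i = 0
while i < num_samples:
    k = min(i + mb_size, num_samples)
    sample = _samples[i:k].reshape((k - i, 7, 5, 3))

    # Per-channel statistics over the batch and spatial dimensions.
    local_mean = sample.mean((0, 2, 3))
    local_var = sample.var((0, 2, 3))

    # Exponential-moving-average update of the running statistics.
    running_mean = decay * running_mean + (1 - decay) * local_mean[None, :, None, None]
    running_var = decay * running_var + (1 - decay) * local_var[None, :, None, None]

    # Normalize with the running statistics, then apply scale and bias.
    normalized = (sample - running_mean) / np.sqrt(running_var + epsilon)
    y = scale * normalized + bias

    vals.append(numpy_l2norm2(y))
    i += mb_size

val = np.mean(vals)
tol = 8 * val * np.finfo(np.float32).eps
print(f"expected metric: {val:.6f} +/- {tol:.3e}")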