Commit 4896b5d

committed
Updated CI test with new environment imports
- CI test passing on Lassen
1 parent 4b57509 commit 4896b5d
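
For quick orientation (not part of the diff itself): the substantive change is where the DistConv test environment comes from when the PyTest functions are generated; the rest of the diff is formatting. A minimal before/after sketch, with full context in the diff below (the local name `env` is illustrative only):

    # before: environment supplied by the test-local tools module
    #   environment=tools.get_distconv_environment()
    # after: environment supplied by lbann.contrib.args
    import lbann.contrib.args
    env = lbann.contrib.args.get_distconv_environment()
    # ... later passed as tools.create_tests(setup_experiment, __file__, environment=env)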

File tree

1 file changed (+83, -55 lines)


ci_test/unit_tests/test_unit_layer_channelwise_softmax_distconv.py

Lines changed: 83 additions & 55 deletions
@@ -4,11 +4,12 @@
 import os.path
 import sys
 import numpy as np
+import lbann.contrib.args

 # Bamboo utilities
 current_file = os.path.realpath(__file__)
 current_dir = os.path.dirname(current_file)
-sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python'))
+sys.path.insert(0, os.path.join(os.path.dirname(current_dir), "common_python"))
 import tools

 # ==============================================
@@ -20,34 +21,45 @@
 # Data
 np.random.seed(20200115)
 _num_samples = 15
-_sample_dims = (15,5,1)
+_sample_dims = (15, 5, 1)
 _sample_size = functools.reduce(operator.mul, _sample_dims)
-_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32)
+_samples = np.random.normal(loc=0.5, size=(_num_samples, _sample_size)).astype(
+    np.float32
+)
+

 # Sample access functions
 def get_sample(index):
-    return _samples[index,:]
+    return _samples[index, :]
+
+
 def num_samples():
     return _num_samples
+
+
 def sample_dims():
     return (_sample_size,)

+
 # ==============================================
 # NumPy implementation
 # ==============================================

+
 def numpy_channelwise_softmax(x):
     if x.dtype is not np.float64:
         x = x.astype(np.float64)
-    axis = tuple(range(1,x.ndim))
+    axis = tuple(range(1, x.ndim))
     shift = np.max(x, axis=axis, keepdims=True)
-    y = np.exp(x-shift)
+    y = np.exp(x - shift)
     return y / np.sum(y, axis=axis, keepdims=True)

+
 # ==============================================
 # Setup LBANN experiment
 # ==============================================

+
 def setup_experiment(lbann, weekly):
     """Construct LBANN experiment.

@@ -60,11 +72,18 @@ def setup_experiment(lbann, weekly):
     model = construct_model(lbann)
     data_reader = construct_data_reader(lbann)
     optimizer = lbann.NoOptimizer()
-    return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes
+    return (
+        trainer,
+        model,
+        data_reader,
+        optimizer,
+        None,
+    )  # Don't request any specific number of nodes
+

 def create_parallel_strategy(num_channel_groups):
-    return {"channel_groups": num_channel_groups,
-            "filter_groups": num_channel_groups}
+    return {"channel_groups": num_channel_groups, "filter_groups": num_channel_groups}
+

 def construct_model(lbann):
     """Construct LBANN model.
@@ -77,22 +96,24 @@ def construct_model(lbann):
     # Input data
     # Note: Sum with a weights layer so that gradient checking will
     # verify that error signals are correct.
-    x_weights = lbann.Weights(optimizer=lbann.SGD(),
-                              initializer=lbann.ConstantInitializer(value=0.0),
-                              name='input_weights')
-    x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'),
-                                dims=_sample_dims),
-                  lbann.WeightsLayer(weights=x_weights,
-                                     dims=_sample_dims))
+    x_weights = lbann.Weights(
+        optimizer=lbann.SGD(),
+        initializer=lbann.ConstantInitializer(value=0.0),
+        name="input_weights",
+    )
+    x = lbann.Sum(
+        lbann.Reshape(lbann.Input(data_field="samples"), dims=_sample_dims),
+        lbann.WeightsLayer(weights=x_weights, dims=_sample_dims),
+    )
     x_lbann = x
     obj = []
     metrics = []
     callbacks = []

     num_channel_groups = tools.gpus_per_node(lbann)
     if num_channel_groups == 0:
-        e = 'this test requires GPUs.'
-        print('Skip - ' + e)
+        e = "this test requires GPUs."
+        print("Skip - " + e)
         pytest.skip(e)

     # ------------------------------------------
@@ -102,13 +123,15 @@ def construct_model(lbann):
     # LBANN implementation
     x = x_lbann

-    y = lbann.ChannelwiseSoftmax(x,
-                                 data_layout='data_parallel',
-                                 parallel_strategy=create_parallel_strategy(num_channel_groups),
-                                 name="Channelwise_softmax_distconv")
+    y = lbann.ChannelwiseSoftmax(
+        x,
+        data_layout="data_parallel",
+        parallel_strategy=create_parallel_strategy(num_channel_groups),
+        name="Channelwise_softmax_distconv",
+    )
     z = lbann.L2Norm2(y)
     obj.append(z)
-    metrics.append(lbann.Metric(z, name='channelwise split distconv'))
+    metrics.append(lbann.Metric(z, name="channelwise split distconv"))

     # NumPy implementation
     vals = []
@@ -119,12 +142,15 @@ def construct_model(lbann):
         vals.append(z)
     val = np.mean(vals)
     tol = 8 * val * np.finfo(np.float32).eps
-    callbacks.append(lbann.CallbackCheckMetric(
-        metric=metrics[-1].name,
-        lower_bound=val-tol,
-        upper_bound=val+tol,
-        error_on_failure=True,
-        execution_modes='test'))
+    callbacks.append(
+        lbann.CallbackCheckMetric(
+            metric=metrics[-1].name,
+            lower_bound=val - tol,
+            upper_bound=val + tol,
+            error_on_failure=True,
+            execution_modes="test",
+        )
+    )

     # ------------------------------------------
     # Gradient checking
@@ -137,11 +163,14 @@ def construct_model(lbann):
     # ------------------------------------------

     num_epochs = 0
-    return lbann.Model(num_epochs,
-                       layers=lbann.traverse_layer_graph(x_lbann),
-                       objective_function=obj,
-                       metrics=metrics,
-                       callbacks=callbacks)
+    return lbann.Model(
+        num_epochs,
+        layers=lbann.traverse_layer_graph(x_lbann),
+        objective_function=obj,
+        metrics=metrics,
+        callbacks=callbacks,
+    )
+

 def construct_data_reader(lbann):
     """Construct Protobuf message for Python data reader.
@@ -157,32 +186,31 @@ def construct_data_reader(lbann):
     # Note: The training data reader should be removed when
     # https://github.com/LLNL/lbann/issues/1098 is resolved.
     message = lbann.reader_pb2.DataReader()
-    message.reader.extend([
-        tools.create_python_data_reader(
-            lbann,
-            current_file,
-            'get_sample',
-            'num_samples',
-            'sample_dims',
-            'train'
-        )
-    ])
-    message.reader.extend([
-        tools.create_python_data_reader(
-            lbann,
-            current_file,
-            'get_sample',
-            'num_samples',
-            'sample_dims',
-            'test'
-        )
-    ])
+    message.reader.extend(
+        [
+            tools.create_python_data_reader(
+                lbann, current_file, "get_sample", "num_samples", "sample_dims", "train"
+            )
+        ]
+    )
+    message.reader.extend(
+        [
+            tools.create_python_data_reader(
+                lbann, current_file, "get_sample", "num_samples", "sample_dims", "test"
+            )
+        ]
+    )
     return message

+
 # ==============================================
 # Setup PyTest
 # ==============================================

 # Create test functions that can interact with PyTest
-for _test_func in tools.create_tests(setup_experiment, __file__, environment=tools.get_distconv_environment()):
+for _test_func in tools.create_tests(
+    setup_experiment,
+    __file__,
+    environment=lbann.contrib.args.get_distconv_environment(),
+):
     globals()[_test_func.__name__] = _test_func
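
For reference, a minimal standalone check of the NumPy reference implementation exercised by this test (illustrative only; the (15, 5, 1) shape mirrors _sample_dims, and the assertion simply confirms that each channel's softmax sums to 1):

    import numpy as np

    def numpy_channelwise_softmax(x):
        # Reference from the test: softmax computed over all non-channel axes.
        if x.dtype is not np.float64:
            x = x.astype(np.float64)
        axis = tuple(range(1, x.ndim))
        shift = np.max(x, axis=axis, keepdims=True)
        y = np.exp(x - shift)
        return y / np.sum(y, axis=axis, keepdims=True)

    x = np.random.normal(loc=0.5, size=(15, 5, 1))  # one sample reshaped to _sample_dims
    y = numpy_channelwise_softmax(x)
    assert np.allclose(y.sum(axis=(1, 2)), 1.0)  # each of the 15 channels sums to 1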
