44import os .path
55import sys
66import numpy as np
7+ import lbann .contrib .args
78
89# Bamboo utilities
910current_file = os .path .realpath (__file__ )
1011current_dir = os .path .dirname (current_file )
11- sys .path .insert (0 , os .path .join (os .path .dirname (current_dir ), ' common_python' ))
12+ sys .path .insert (0 , os .path .join (os .path .dirname (current_dir ), " common_python" ))
1213import tools
1314
1415# ==============================================
2021# Data
2122np .random .seed (20200115 )
2223_num_samples = 15
23- _sample_dims = (15 ,5 , 1 )
24+ _sample_dims = (15 , 5 , 1 )
2425_sample_size = functools .reduce (operator .mul , _sample_dims )
25- _samples = np .random .normal (loc = 0.5 , size = (_num_samples ,_sample_size )).astype (np .float32 )
26+ _samples = np .random .normal (loc = 0.5 , size = (_num_samples , _sample_size )).astype (
27+ np .float32
28+ )
29+
2630
2731# Sample access functions
2832def get_sample (index ):
29- return _samples [index ,:]
33+ return _samples [index , :]
34+
35+
3036def num_samples ():
3137 return _num_samples
38+
39+
3240def sample_dims ():
3341 return (_sample_size ,)
3442
43+
3544# ==============================================
3645# NumPy implementation
3746# ==============================================
3847
48+
3949def numpy_channelwise_softmax (x ):
4050 if x .dtype is not np .float64 :
4151 x = x .astype (np .float64 )
42- axis = tuple (range (1 ,x .ndim ))
52+ axis = tuple (range (1 , x .ndim ))
4353 shift = np .max (x , axis = axis , keepdims = True )
44- y = np .exp (x - shift )
54+ y = np .exp (x - shift )
4555 return y / np .sum (y , axis = axis , keepdims = True )
4656
57+
4758# ==============================================
4859# Setup LBANN experiment
4960# ==============================================
5061
62+
5163def setup_experiment (lbann , weekly ):
5264 """Construct LBANN experiment.
5365
@@ -60,11 +72,18 @@ def setup_experiment(lbann, weekly):
6072 model = construct_model (lbann )
6173 data_reader = construct_data_reader (lbann )
6274 optimizer = lbann .NoOptimizer ()
63- return trainer , model , data_reader , optimizer , None # Don't request any specific number of nodes
75+ return (
76+ trainer ,
77+ model ,
78+ data_reader ,
79+ optimizer ,
80+ None ,
81+ ) # Don't request any specific number of nodes
82+
6483
6584def create_parallel_strategy (num_channel_groups ):
66- return {"channel_groups" : num_channel_groups ,
67- "filter_groups" : num_channel_groups }
85+ return {"channel_groups" : num_channel_groups , "filter_groups" : num_channel_groups }
86+
6887
6988def construct_model (lbann ):
7089 """Construct LBANN model.
@@ -77,22 +96,24 @@ def construct_model(lbann):
7796 # Input data
7897 # Note: Sum with a weights layer so that gradient checking will
7998 # verify that error signals are correct.
80- x_weights = lbann .Weights (optimizer = lbann .SGD (),
81- initializer = lbann .ConstantInitializer (value = 0.0 ),
82- name = 'input_weights' )
83- x = lbann .Sum (lbann .Reshape (lbann .Input (data_field = 'samples' ),
84- dims = _sample_dims ),
85- lbann .WeightsLayer (weights = x_weights ,
86- dims = _sample_dims ))
99+ x_weights = lbann .Weights (
100+ optimizer = lbann .SGD (),
101+ initializer = lbann .ConstantInitializer (value = 0.0 ),
102+ name = "input_weights" ,
103+ )
104+ x = lbann .Sum (
105+ lbann .Reshape (lbann .Input (data_field = "samples" ), dims = _sample_dims ),
106+ lbann .WeightsLayer (weights = x_weights , dims = _sample_dims ),
107+ )
87108 x_lbann = x
88109 obj = []
89110 metrics = []
90111 callbacks = []
91112
92113 num_channel_groups = tools .gpus_per_node (lbann )
93114 if num_channel_groups == 0 :
94- e = ' this test requires GPUs.'
95- print (' Skip - ' + e )
115+ e = " this test requires GPUs."
116+ print (" Skip - " + e )
96117 pytest .skip (e )
97118
98119 # ------------------------------------------
@@ -102,13 +123,15 @@ def construct_model(lbann):
102123 # LBANN implementation
103124 x = x_lbann
104125
105- y = lbann .ChannelwiseSoftmax (x ,
106- data_layout = 'data_parallel' ,
107- parallel_strategy = create_parallel_strategy (num_channel_groups ),
108- name = "Channelwise_softmax_distconv" )
126+ y = lbann .ChannelwiseSoftmax (
127+ x ,
128+ data_layout = "data_parallel" ,
129+ parallel_strategy = create_parallel_strategy (num_channel_groups ),
130+ name = "Channelwise_softmax_distconv" ,
131+ )
109132 z = lbann .L2Norm2 (y )
110133 obj .append (z )
111- metrics .append (lbann .Metric (z , name = ' channelwise split distconv' ))
134+ metrics .append (lbann .Metric (z , name = " channelwise split distconv" ))
112135
113136 # NumPy implementation
114137 vals = []
@@ -119,12 +142,15 @@ def construct_model(lbann):
119142 vals .append (z )
120143 val = np .mean (vals )
121144 tol = 8 * val * np .finfo (np .float32 ).eps
122- callbacks .append (lbann .CallbackCheckMetric (
123- metric = metrics [- 1 ].name ,
124- lower_bound = val - tol ,
125- upper_bound = val + tol ,
126- error_on_failure = True ,
127- execution_modes = 'test' ))
145+ callbacks .append (
146+ lbann .CallbackCheckMetric (
147+ metric = metrics [- 1 ].name ,
148+ lower_bound = val - tol ,
149+ upper_bound = val + tol ,
150+ error_on_failure = True ,
151+ execution_modes = "test" ,
152+ )
153+ )
128154
129155 # ------------------------------------------
130156 # Gradient checking
@@ -137,11 +163,14 @@ def construct_model(lbann):
137163 # ------------------------------------------
138164
139165 num_epochs = 0
140- return lbann .Model (num_epochs ,
141- layers = lbann .traverse_layer_graph (x_lbann ),
142- objective_function = obj ,
143- metrics = metrics ,
144- callbacks = callbacks )
166+ return lbann .Model (
167+ num_epochs ,
168+ layers = lbann .traverse_layer_graph (x_lbann ),
169+ objective_function = obj ,
170+ metrics = metrics ,
171+ callbacks = callbacks ,
172+ )
173+
145174
146175def construct_data_reader (lbann ):
147176 """Construct Protobuf message for Python data reader.
@@ -157,32 +186,31 @@ def construct_data_reader(lbann):
157186 # Note: The training data reader should be removed when
158187 # https://github.com/LLNL/lbann/issues/1098 is resolved.
159188 message = lbann .reader_pb2 .DataReader ()
160- message .reader .extend ([
161- tools .create_python_data_reader (
162- lbann ,
163- current_file ,
164- 'get_sample' ,
165- 'num_samples' ,
166- 'sample_dims' ,
167- 'train'
168- )
169- ])
170- message .reader .extend ([
171- tools .create_python_data_reader (
172- lbann ,
173- current_file ,
174- 'get_sample' ,
175- 'num_samples' ,
176- 'sample_dims' ,
177- 'test'
178- )
179- ])
189+ message .reader .extend (
190+ [
191+ tools .create_python_data_reader (
192+ lbann , current_file , "get_sample" , "num_samples" , "sample_dims" , "train"
193+ )
194+ ]
195+ )
196+ message .reader .extend (
197+ [
198+ tools .create_python_data_reader (
199+ lbann , current_file , "get_sample" , "num_samples" , "sample_dims" , "test"
200+ )
201+ ]
202+ )
180203 return message
181204
205+
182206# ==============================================
183207# Setup PyTest
184208# ==============================================
185209
186210# Create test functions that can interact with PyTest
187- for _test_func in tools .create_tests (setup_experiment , __file__ , environment = tools .get_distconv_environment ()):
211+ for _test_func in tools .create_tests (
212+ setup_experiment ,
213+ __file__ ,
214+ environment = lbann .contrib .args .get_distconv_environment (),
215+ ):
188216 globals ()[_test_func .__name__ ] = _test_func
0 commit comments