Skip to content

Commit 50a6db9

Browse files
committed
- Added model compile-time checks on the shape of the input when distconv is enabled
- Updated ReleaseNotes
1 parent 99db87e commit 50a6db9

File tree

3 files changed

+206
-0
lines changed

3 files changed

+206
-0
lines changed

ReleaseNotes.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Support for new network structures:
2323
- RoBERTa with pretrained weights
2424

2525
Support for new layers:
26+
- - Added distributed tensor parallelism with channelwise decomposition for channelwise softmax layer
2627
- Added support for 2D Matrices for Scatter and Gather layers
2728
- Added image rotation layer and composite image transformation layer
2829
(rotate, shear, translate)
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import functools
2+
import operator
3+
import os
4+
import os.path
5+
import sys
6+
import numpy as np
7+
8+
# Bamboo utilities
9+
current_file = os.path.realpath(__file__)
10+
current_dir = os.path.dirname(current_file)
11+
sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python'))
12+
import tools
13+
14+
# ==============================================
15+
# Objects for Python data reader
16+
# ==============================================
17+
# Note: The Python data reader imports this file as a module and calls
18+
# the functions below to ingest data.
19+
20+
# Data
21+
np.random.seed(20200115)
22+
_num_samples = 15
23+
_sample_dims = (15,36,1)
24+
_sample_size = functools.reduce(operator.mul, _sample_dims)
25+
_samples = np.random.normal(loc=0.5, size=(_num_samples,_sample_size)).astype(np.float32)
26+
27+
# Sample access functions
28+
def get_sample(index):
29+
return _samples[index,:]
30+
def num_samples():
31+
return _num_samples
32+
def sample_dims():
33+
return (_sample_size,)
34+
35+
# ==============================================
36+
# NumPy implementation
37+
# ==============================================
38+
39+
def numpy_channelwise_softmax(x):
40+
if x.dtype is not np.float64:
41+
x = x.astype(np.float64)
42+
axis = tuple(range(1,x.ndim))
43+
shift = np.max(x, axis=axis, keepdims=True)
44+
y = np.exp(x-shift)
45+
return y / np.sum(y, axis=axis, keepdims=True)
46+
47+
# ==============================================
48+
# Setup LBANN experiment
49+
# ==============================================
50+
51+
def setup_experiment(lbann, weekly):
52+
"""Construct LBANN experiment.
53+
54+
Args:
55+
lbann (module): Module for LBANN Python frontend
56+
57+
"""
58+
mini_batch_size = num_samples() // 2
59+
trainer = lbann.Trainer(mini_batch_size)
60+
model = construct_model(lbann)
61+
data_reader = construct_data_reader(lbann)
62+
optimizer = lbann.NoOptimizer()
63+
return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes
64+
65+
def create_parallel_strategy(num_channel_groups):
66+
return {"channel_groups": num_channel_groups,
67+
"filter_groups": num_channel_groups}
68+
69+
def construct_model(lbann):
70+
"""Construct LBANN model.
71+
72+
Args:
73+
lbann (module): Module for LBANN Python frontend
74+
75+
"""
76+
77+
# Input data
78+
# Note: Sum with a weights layer so that gradient checking will
79+
# verify that error signals are correct.
80+
x_weights = lbann.Weights(optimizer=lbann.SGD(),
81+
initializer=lbann.ConstantInitializer(value=0.0),
82+
name='input_weights')
83+
x = lbann.Sum(lbann.Reshape(lbann.Input(data_field='samples'),
84+
dims=_sample_dims),
85+
lbann.WeightsLayer(weights=x_weights,
86+
dims=_sample_dims))
87+
x_lbann = x
88+
obj = []
89+
metrics = []
90+
callbacks = []
91+
92+
num_channel_groups = tools.gpus_per_node(lbann)
93+
if num_channel_groups == 0:
94+
e = 'this test requires GPUs.'
95+
print('Skip - ' + e)
96+
pytest.skip(e)
97+
98+
# ------------------------------------------
99+
# Data-parallel layout
100+
# ------------------------------------------
101+
102+
# LBANN implementation
103+
x = x_lbann
104+
105+
y = lbann.ChannelwiseSoftmax(x,
106+
parallel_strategy=create_parallel_strategy(num_channel_groups),
107+
name="Channelwise_softmax_distconv")
108+
z = lbann.L2Norm2(y)
109+
obj.append(z)
110+
metrics.append(lbann.Metric(z, name='data-parallel layout'))
111+
112+
# NumPy implementation
113+
vals = []
114+
for i in range(num_samples()):
115+
x = get_sample(i).reshape(_sample_dims).astype(np.float64)
116+
y = numpy_channelwise_softmax(x)
117+
z = tools.numpy_l2norm2(y)
118+
vals.append(z)
119+
val = np.mean(vals)
120+
tol = 8 * val * np.finfo(np.float32).eps
121+
callbacks.append(lbann.CallbackCheckMetric(
122+
metric=metrics[-1].name,
123+
lower_bound=val-tol,
124+
upper_bound=val+tol,
125+
error_on_failure=True,
126+
execution_modes='test'))
127+
128+
# ------------------------------------------
129+
# Gradient checking
130+
# ------------------------------------------
131+
132+
callbacks.append(lbann.CallbackCheckGradients(error_on_failure=True))
133+
134+
# ------------------------------------------
135+
# Construct model
136+
# ------------------------------------------
137+
138+
num_epochs = 0
139+
return lbann.Model(num_epochs,
140+
layers=lbann.traverse_layer_graph(x_lbann),
141+
objective_function=obj,
142+
metrics=metrics,
143+
callbacks=callbacks)
144+
145+
def construct_data_reader(lbann):
146+
"""Construct Protobuf message for Python data reader.
147+
148+
The Python data reader will import the current Python file to
149+
access the sample access functions.
150+
151+
Args:
152+
lbann (module): Module for LBANN Python frontend
153+
154+
"""
155+
156+
# Note: The training data reader should be removed when
157+
# https://github.com/LLNL/lbann/issues/1098 is resolved.
158+
message = lbann.reader_pb2.DataReader()
159+
message.reader.extend([
160+
tools.create_python_data_reader(
161+
lbann,
162+
current_file,
163+
'get_sample',
164+
'num_samples',
165+
'sample_dims',
166+
'train'
167+
)
168+
])
169+
message.reader.extend([
170+
tools.create_python_data_reader(
171+
lbann,
172+
current_file,
173+
'get_sample',
174+
'num_samples',
175+
'sample_dims',
176+
'test'
177+
)
178+
])
179+
return message
180+
181+
# ==============================================
182+
# Setup PyTest
183+
# ==============================================
184+
185+
# Create test functions that can interact with PyTest
186+
for _test_func in tools.create_tests(setup_experiment, __file__):
187+
globals()[_test_func.__name__] = _test_func

include/lbann/layers/misc/channelwise_softmax.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,24 @@ template <typename TensorDataType, data_layout Layout, El::Device Device>
168168
void channelwise_softmax_layer<TensorDataType,Layout,Device>::setup_dims(DataReaderMetaData& dr_metadata) {
169169
data_type_layer<TensorDataType>::setup_dims(dr_metadata);
170170
this->set_output_dims(this->get_input_dims());
171+
172+
#ifdef LBANN_HAS_DISTCONV
173+
174+
if (this->distconv_enabled()){
175+
// Additional checks when distconv mode is enabled
176+
const auto& input_dims = this->get_input_dims();
177+
const auto& output_dims = this->get_output_dims();
178+
179+
if (input_dims.size() != 3 || output_dims.size() != 3){
180+
LBANN_ERROR(this->get_type()," layer \"",this->get_name(),"\" ",
181+
"expects an input and output tensor with 3 dimensions (channel, *, *), "
182+
"but it has been configured as a ",
183+
input_dims.size(), "-D input tensor and ",
184+
output_dims.size(),"-D output tensor");
185+
}
186+
}
187+
188+
#endif
171189
}
172190

173191
#ifdef LBANN_HAS_DISTCONV

0 commit comments

Comments
 (0)