
Commit c9d643b

Add an LBANN Python unit test wrapper and utilities (#2264)
* Add an LBANN Python unit test wrapper and utilities
* Add capability for extra metrics and callbacks
* Add a simple test that uses the new interface
* Relax bounds of NASNet further
* Improve support for multidimensional tensors
* Add single-tensor test data reader
* Improve readability of pytest assertions
* Fix weighted sum operation and add test
* Make weighted sum in-place-capable, ensure backprop runs all the way through in testing
1 parent cffea66 commit c9d643b

7 files changed: +490 additions, −7 deletions
ci_test/common_python/single_tensor_data_reader.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
################################################################################
# Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. <[email protected]>
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LLNL/LBANN.
#
# Licensed under the Apache License, Version 2.0 (the "Licensee"); you
# may not use this file except in compliance with the License. You may
# obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the license.
#
################################################################################
"""
Simple data reader that opens one file with one tensor. Used for unit testing.
"""
import numpy as np

# Lazy-load tensor
tensor = None


def lazy_load():
    # This file operates under the assumption that the working directory is set
    # to a specific experiment.
    global tensor
    if tensor is None:
        tensor = np.load('data.npy')
        assert len(tensor.shape) == 2


def get_sample(idx):
    lazy_load()
    return tensor[idx]


def num_samples():
    lazy_load()
    return tensor.shape[0]


def sample_dims():
    lazy_load()
    return (tensor.shape[1], )

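For context (not part of this commit), here is a minimal sketch of how the reader above is exercised: the wrapper in test_util.py (below) saves a 2-D ``data.npy`` with one flattened sample per row into the experiment's working directory, and LBANN's Python data reader then calls the three module functions. Assuming the module is importable and the current directory is writable, something like this would work:

# Hypothetical smoke test for the reader above; run from a scratch directory.
import numpy as np
import single_tensor_data_reader as reader

# The wrapper in test_util.py stores samples as [minibatch, flattened features].
np.save('data.npy', np.arange(12, dtype=np.float32).reshape(3, 4))

assert reader.num_samples() == 3        # one row per sample
assert reader.sample_dims() == (4, )    # flattened feature count per sample
assert np.allclose(reader.get_sample(0), [0.0, 1.0, 2.0, 3.0])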
ci_test/common_python/test_util.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
################################################################################
# Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory.
# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
# the CONTRIBUTORS file. <[email protected]>
#
# LLNL-CODE-697807.
# All rights reserved.
#
# This file is part of LBANN: Livermore Big Artificial Neural Network
# Toolkit. For details, see http://software.llnl.gov/LBANN or
# https://github.com/LLNL/LBANN.
#
# Licensed under the Apache License, Version 2.0 (the "Licensee"); you
# may not use this file except in compliance with the License. You may
# obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the license.
#
################################################################################
import lbann
from dataclasses import dataclass, field
import functools
import inspect
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import os
import re
import tools
import single_tensor_data_reader


def lbann_test(check_gradients=False, **decorator_kwargs):
    """
    A decorator that wraps an LBANN-enabled model unit test.
    Use it before a function named ``test_*`` to run it automatically in pytest.
    The unit test in the wrapped function must return a ``test_util.ModelTester``
    object, which contains all the necessary information to test the model (e.g.,
    model, input/reference tensors).

    The decorator wraps the test with the appropriate setup phase, data reading,
    callbacks, and metrics so that the test functions properly.
    """

    def internal_tester(f):

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Call model constructor
            tester = f(*args, **kwargs)

            # Check return value
            if not isinstance(tester, ModelTester):
                raise ValueError('LBANN test must return a ModelTester object')
            if tester.loss is None:
                raise ValueError(
                    'LBANN test did not define a loss function, '
                    'use ``ModelTester.set_loss`` or ``set_loss_function``.')
            if tester.input_tensor is None:
                raise ValueError('LBANN test did not define an input, call '
                                 '``ModelTester.inputs`` or ``inputs_like``.')
            if (tester.reference_tensor is not None
                    and tester.reference_tensor.shape[0] !=
                    tester.input_tensor.shape[0]):
                raise ValueError(
                    'Input and reference tensors in LBANN test '
                    'must match in the first (minibatch) dimension')
            full_graph = lbann.traverse_layer_graph(tester.loss)
            callbacks = []
            callbacks.append(
                lbann.CallbackCheckMetric(metric='test',
                                          lower_bound=0,
                                          upper_bound=tester.tolerance,
                                          error_on_failure=True,
                                          execution_modes='test'))
            if check_gradients:
                callbacks.append(
                    lbann.CallbackCheckGradients(error_on_failure=True))
            callbacks.extend(tester.extra_callbacks)

            metrics = [lbann.Metric(tester.loss, name='test')]
            metrics.extend(tester.extra_metrics)
            model = lbann.Model(epochs=0,
                                layers=full_graph,
                                metrics=metrics,
                                callbacks=callbacks)

            # Get file
            file = inspect.getfile(f)

            def setup_func(lbann, weekly):
                # Get minibatch size from tensor
                mini_batch_size = tester.input_tensor.shape[0]

                # Save combined input/reference data to file
                work_dir = _get_work_dir(file)
                os.makedirs(work_dir, exist_ok=True)
                if tester.reference_tensor is not None:
                    flat_inp = tester.input_tensor.reshape(mini_batch_size, -1)
                    flat_ref = tester.reference_tensor.reshape(
                        mini_batch_size, -1)
                    np.save(os.path.join(work_dir, 'data.npy'),
                            np.concatenate((flat_inp, flat_ref), axis=1))
                else:
                    np.save(os.path.join(work_dir, 'data.npy'),
                            tester.input_tensor.reshape(mini_batch_size, -1))

                # Setup data reader
                data_reader = lbann.reader_pb2.DataReader()
                data_reader.reader.extend([
                    tools.create_python_data_reader(
                        lbann, single_tensor_data_reader.__file__,
                        'get_sample', 'num_samples', 'sample_dims', 'train'),
                    tools.create_python_data_reader(
                        lbann, single_tensor_data_reader.__file__,
                        'get_sample', 'num_samples', 'sample_dims', 'test')
                ])

                trainer = lbann.Trainer(mini_batch_size)
                optimizer = lbann.NoOptimizer()
                # Don't request any specific number of nodes
                return trainer, model, data_reader, optimizer, None

            test = tools.create_tests(setup_func, file, **decorator_kwargs)[0]
            cluster = kwargs.get('cluster', 'unset')
            weekly = kwargs.get('weekly', False)
            test(cluster, weekly, False, **decorator_kwargs)

        return wrapped

    return internal_tester


@dataclass
class ModelTester:
    """
    An object that is constructed within an ``lbann_test``-wrapped unit test.
    """

    # Input tensor (required for test to construct)
    input_tensor: Optional[Any] = None

    reference: Optional[lbann.Layer] = None  #: Reference LBANN node (optional)
    reference_tensor: Optional[
        Any] = None  #: Optional reference tensor to compare with

    loss: Optional[lbann.Layer] = None  # Optional loss test
    tolerance: float = 0.0  #: Tolerance value for loss test

    # Optional additional metrics to use in test
    extra_metrics: List[lbann.Metric] = field(default_factory=list)

    # Optional additional callbacks to use in test
    extra_callbacks: List[lbann.Callback] = field(default_factory=list)

    def inputs(self, tensor: Any) -> lbann.Layer:
        """
        Marks the given tensor as an input of the tested LBANN model, and
        returns a matching LBANN Input node (or a Slice/Reshape thereof).

        :param tensor: The input NumPy array to use.
        :return: An LBANN layer object that will serve as the input.
        """
        self.input_tensor = tensor
        inp = lbann.Input(data_field='samples')
        return slice_to_tensors(inp, tensor)

    def inputs_like(self, *tensors) -> List[lbann.Layer]:
        """
        Marks the given tensors as inputs of the tested LBANN model, and
        returns a list of matching LBANN Slice nodes, potentially reshaped to
        be like the input tensors.

        :param tensors: The input NumPy arrays to use.
        :return: A list of LBANN layer objects that will serve as the inputs.
        """
        minibatch_size = tensors[0].shape[0]  # Assume the first dimension

        # All tensors concatenated on the non-batch dimension
        all_tensors_combined = np.concatenate(
            [t.reshape(minibatch_size, -1) for t in tensors], axis=1)

        self.input_tensor = all_tensors_combined
        x = lbann.Input(data_field='samples')
        return slice_to_tensors(x, *tensors)

    def make_reference(self, ref: Any) -> lbann.Input:
        """
        Marks the given tensor as a reference output of the tested LBANN model,
        and returns a matching LBANN node.

        :param ref: The reference NumPy array to use.
        :return: An LBANN layer object that will serve as the reference.
        """
        # The reference is the second part of the input "samples"
        refnode = lbann.Input(data_field='samples')
        if self.input_tensor is None:
            raise ValueError('Please call ``inputs`` or ``inputs_like`` prior '
                             'to calling ``make_reference`` for correctness.')
        mbsize = self.input_tensor.shape[0]

        # Obtain reference
        refnode = lbann.Reshape(
            lbann.Identity(
                lbann.Slice(
                    refnode,
                    slice_points=[
                        numel(self.input_tensor) // mbsize,
                        (numel(self.input_tensor) + numel(ref)) // mbsize
                    ],
                )),
            dims=ref.shape[1:])

        # Store reference
        self.reference = refnode
        self.reference_tensor = ref
        return self.reference

    def set_loss_function(self,
                          func: Callable[[lbann.Layer, lbann.Layer],
                                         lbann.Layer],
                          output: lbann.Layer,
                          tolerance=None):
        """
        Sets a loss function and the LBANN test output to be measured for the
        test.
        This assumes that the first argument has two parameters (e.g.,
        ``MeanSquaredError``), where the first argument will be used for the
        LBANN output and the second will be used for the reference.

        :param func: The loss function.
        :param output: The LBANN model output to use.
        :param tolerance: Optional tolerance to set for the test. If ``None``,
                          the default tolerance of ``8*eps*mean(reference)``
                          will be used.
        """
        return self.set_loss(func(output, self.reference), tolerance)

    def set_loss(self,
                 loss: lbann.Layer,
                 tolerance: Optional[float] = None) -> None:
        """
        Sets an LBANN node to be measured for the test.

        :param loss: The LBANN graph node to use for the test.
        :param tolerance: Optional tolerance to set for the test. If ``None``,
                          the default tolerance of ``8*eps*mean(reference)``
                          will be used.
        """
        # Set loss node
        self.loss = loss

        # Set tolerance
        if tolerance is not None:
            self.tolerance = tolerance
        else:
            if self.reference_tensor is None:
                raise ValueError(
                    'Cannot set tolerance on loss function automatically '
                    'without a reference tensor. Either set tolerance '
                    'explicitly or call ``ModelTester.make_reference``.')
            # Default tolerance
            self.tolerance = abs(8 * np.mean(self.reference_tensor) *
                                 np.finfo(self.reference_tensor.dtype).eps)


def slice_to_tensors(x: lbann.Layer, *tensors) -> List[lbann.Layer]:
    """
    Slices an LBANN layer into multiple tensors that match the dimensions of
    the given NumPy arrays.
    """
    slice_points = [0]
    offset = 0
    for tensor in tensors:
        offset += numel(tensor) // tensor.shape[0]
        slice_points.append(offset)
    lslice = lbann.Slice(x, slice_points=slice_points)
    return [
        lbann.Reshape(_ensure_bp(t, lbann.Identity(lslice)), dims=t.shape[1:])
        for t in tensors
    ]


def numel(tensor) -> int:
    """
    Returns the number of elements in a NumPy array, PyTorch tensor, or integer.
    """
    if isinstance(tensor, int):  # Integer
        return tensor
    elif hasattr(tensor, 'numel'):  # PyTorch tensor
        return tensor.numel()
    else:  # NumPy array
        return tensor.size


# Mimics the other tester's determination of the working directory
def _get_work_dir(test_file: str) -> str:
    test_fname = os.path.realpath(test_file)
    # Create test name by removing '.py' from file name
    test_fname = os.path.splitext(os.path.basename(test_fname))[0]
    if not re.match('^test_.', test_fname):
        # Make sure test name is prefixed with 'test_'
        test_fname = 'test_' + test_fname
    return os.path.join(os.path.dirname(test_file), 'experiments', test_fname)


# Ensures that backpropagation is run through the entire model
def _ensure_bp(tensor: Any, node: lbann.Layer) -> lbann.Sum:
    # Note: Sum with a weights layer so that gradient checking will
    # verify that error signals are correct.
    x_weights = lbann.Weights(initializer=lbann.ConstantInitializer(value=0.0))
    return lbann.Sum(
        node,
        lbann.WeightsLayer(
            weights=x_weights,
            dims=[numel(tensor) // tensor.shape[0]],
        ))

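To show how the pieces above fit together, here is a hypothetical usage sketch. The ``lbann_test`` decorator, ``ModelTester``, and its methods come from this commit; the ``lbann.Relu`` operation, the random data, and the test name are illustrative assumptions rather than part of the change:

# Hypothetical usage sketch of the lbann_test wrapper; not part of this commit.
import numpy as np
import lbann
import test_util


@test_util.lbann_test(check_gradients=True)
def test_relu():
    # Prepare input and reference data with NumPy ([minibatch, features]).
    np.random.seed(20230807)
    x = np.random.rand(5, 7).astype(np.float32)
    ref = np.maximum(x, 0.0)

    # Construct the LBANN graph under test.
    tester = test_util.ModelTester()
    x_lbann = tester.inputs(x)  # per the docstring, an Input-backed node for x
    tester.make_reference(ref)  # registers the reference tensor and node
    y = lbann.Relu(x_lbann)     # illustrative operation being tested

    # The wrapper turns this loss into a 'test' metric checked by CheckMetric
    # (and CheckGradients, since check_gradients=True).
    tester.set_loss_function(lbann.MeanSquaredError, y)
    return tester

Under this sketch, the wrapper concatenates the input and reference into ``data.npy``, builds the Model/Trainer with a 'test' metric and the CheckMetric/CheckGradients callbacks, and runs the result through ``tools.create_tests``.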