Skip to content

Commit 4a58b17

Browse files
de-vri-esfchollet
authored andcommitted
Rethrow original exception in GeneratorEnqueuer. (#8485)
* Fix randint usage in test_multiprocessing * Rethrow original exception in GeneratorEnqueuer. * Use multiprocessing.Manager to obtain a race-free Queue.
1 parent a27b4a5 commit 4a58b17

File tree

3 files changed

+44
-23
lines changed

3 files changed

+44
-23
lines changed

keras/utils/data_utils.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import tarfile
1212
import threading
1313
import time
14+
import traceback
1415
import zipfile
1516
from abc import abstractmethod
1617
from multiprocessing.pool import ThreadPool
@@ -553,7 +554,7 @@ def get(self):
553554
yield inputs
554555
except Exception as e:
555556
self.stop()
556-
raise StopIteration(e)
557+
six.raise_from(StopIteration(e), e)
557558

558559
def _send_sequence(self):
559560
"""Send current Sequence to all workers."""
@@ -614,6 +615,7 @@ def __init__(self, generator,
614615
self._use_multiprocessing = use_multiprocessing
615616
self._threads = []
616617
self._stop_event = None
618+
self._manager = None
617619
self.queue = None
618620
self.seed = seed
619621

@@ -631,18 +633,27 @@ def data_generator_task():
631633
try:
632634
if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
633635
generator_output = next(self._generator)
634-
self.queue.put(generator_output)
636+
self.queue.put((True, generator_output))
635637
else:
636638
time.sleep(self.wait_time)
637639
except StopIteration:
638640
break
639-
except Exception:
641+
except Exception as e:
642+
# Can't pick tracebacks.
643+
# As a compromise, print the traceback and pickle None instead.
644+
if self._use_multiprocessing:
645+
traceback.print_exc()
646+
setattr(e, '__traceback__', None)
647+
elif not hasattr(e, '__traceback__'):
648+
setattr(e, '__traceback__', sys.exc_info()[2])
649+
self.queue.put((False, e))
640650
self._stop_event.set()
641-
raise
651+
break
642652

643653
try:
644654
if self._use_multiprocessing:
645-
self.queue = multiprocessing.Queue(maxsize=max_queue_size)
655+
self._manager = multiprocessing.Manager()
656+
self.queue = self._manager.Queue(maxsize=max_queue_size)
646657
self._stop_event = multiprocessing.Event()
647658
else:
648659
self.queue = queue.Queue()
@@ -686,9 +697,8 @@ def stop(self, timeout=None):
686697
else:
687698
thread.join(timeout)
688699

689-
if self._use_multiprocessing:
690-
if self.queue is not None:
691-
self.queue.close()
700+
if self._manager:
701+
self._manager.shutdown()
692702

693703
self._threads = []
694704
self._stop_event = None
@@ -704,12 +714,22 @@ def get(self):
704714
"""
705715
while self.is_running():
706716
if not self.queue.empty():
707-
inputs = self.queue.get()
708-
if inputs is not None:
709-
yield inputs
717+
success, value = self.queue.get()
718+
# Rethrow any exceptions found in the queue
719+
if not success:
720+
six.reraise(value.__class__, value, value.__traceback__)
721+
# Yield regular values
722+
if value is not None:
723+
yield value
710724
else:
711725
all_finished = all([not thread.is_alive() for thread in self._threads])
712726
if all_finished and self.queue.empty():
713727
raise StopIteration()
714728
else:
715729
time.sleep(self.wait_time)
730+
731+
# Make sure to rethrow the first exception in the queue, if any
732+
while not self.queue.empty():
733+
success, value = self.queue.get()
734+
if not success:
735+
six.reraise(value.__class__, value, value.__traceback__)

tests/keras/utils/data_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_generator_enqueuer_fail_threads():
182182
FaultSequence()), use_multiprocessing=False)
183183
enqueuer.start(3, 10)
184184
gen_output = enqueuer.get()
185-
with pytest.raises(StopIteration):
185+
with pytest.raises(IndexError):
186186
next(gen_output)
187187

188188

@@ -191,7 +191,7 @@ def test_generator_enqueuer_fail_processes():
191191
FaultSequence()), use_multiprocessing=True)
192192
enqueuer.start(3, 10)
193193
gen_output = enqueuer.get()
194-
with pytest.raises(StopIteration):
194+
with pytest.raises(IndexError):
195195
next(gen_output)
196196

197197

tests/test_multiprocessing.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def custom_generator():
232232
"""Raises an exception after a few good batches"""
233233
for i in range(good_batches):
234234
yield (np.random.randint(batch_size, 256, (50, 2)),
235-
np.random.randint(batch_size, 2, 50))
235+
np.random.randint(batch_size, 12, 50))
236236
raise RuntimeError
237237

238238
model = Sequential()
@@ -241,13 +241,13 @@ def custom_generator():
241241

242242
samples = batch_size * (good_batches + 1)
243243

244-
with pytest.raises(StopIteration):
244+
with pytest.raises(RuntimeError):
245245
model.fit_generator(
246246
custom_generator(), samples, 1,
247247
workers=4, use_multiprocessing=True,
248248
)
249249

250-
with pytest.raises(StopIteration):
250+
with pytest.raises(RuntimeError):
251251
model.fit_generator(
252252
custom_generator(), samples, 1,
253253
use_multiprocessing=False,
@@ -258,25 +258,26 @@ def custom_generator():
258258
def test_multiprocessing_evaluate_error():
259259
batch_size = 10
260260
good_batches = 3
261+
workers = 4
261262

262263
def custom_generator():
263264
"""Raises an exception after a few good batches"""
264265
for i in range(good_batches):
265266
yield (np.random.randint(batch_size, 256, (50, 2)),
266-
np.random.randint(batch_size, 2, 50))
267+
np.random.randint(batch_size, 12, 50))
267268
raise RuntimeError
268269

269270
model = Sequential()
270271
model.add(Dense(1, input_shape=(2, )))
271272
model.compile(loss='mse', optimizer='adadelta')
272273

273-
with pytest.raises(StopIteration):
274+
with pytest.raises(RuntimeError):
274275
model.evaluate_generator(
275-
custom_generator(), good_batches + 1, 1,
276-
workers=4, use_multiprocessing=True,
276+
custom_generator(), good_batches * workers + 1, 1,
277+
workers=workers, use_multiprocessing=True,
277278
)
278279

279-
with pytest.raises(StopIteration):
280+
with pytest.raises(RuntimeError):
280281
model.evaluate_generator(
281282
custom_generator(), good_batches + 1, 1,
282283
use_multiprocessing=False,
@@ -299,13 +300,13 @@ def custom_generator():
299300
model.add(Dense(1, input_shape=(5,)))
300301
model.compile(loss='mse', optimizer='adadelta')
301302

302-
with pytest.raises(StopIteration):
303+
with pytest.raises(RuntimeError):
303304
model.predict_generator(
304305
custom_generator(), good_batches * workers + 1, 1,
305306
workers=workers, use_multiprocessing=True,
306307
)
307308

308-
with pytest.raises(StopIteration):
309+
with pytest.raises(RuntimeError):
309310
model.predict_generator(
310311
custom_generator(), good_batches + 1, 1,
311312
use_multiprocessing=False,

0 commit comments

Comments
 (0)