🐛 Bug
The following GPT-2 code OOMs on a TPU v4-8 with more than 4 attention layers (i.e. `n_layer > 4`). The issue only occurs when I use the `DDP()` wrapper; on a single TPU chip, with the default `n_layer = 12`, the code runs successfully.
Code to repro is available at this link. The preceding commit is the single-chip version and runs successfully.
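For context, here is a minimal sketch of the multi-chip setup that triggers this. The real model is the GPT-2 implementation in the linked repo; `ToyModel`, the shapes, and the optimizer settings below are placeholders, and the `torch.compile(..., backend='openxla')` call is my reading of the Dynamo/AOTAutograd frames in the traceback, not necessarily the exact code path in the repo.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    """Placeholder standing in for the GPT-2 model from the linked repo."""

    def __init__(self, n_layer: int = 12, d_model: int = 768):
        super().__init__()
        self.layers = nn.Sequential(
            *[nn.Linear(d_model, d_model) for _ in range(n_layer)])

    def forward(self, x):
        return self.layers(x)


def _mp_fn(index):
    # One process per TPU chip; the 'xla' backend is what DDP() uses here.
    dist.init_process_group('xla', init_method='xla://')
    device = xm.xla_device()

    model = ToyModel(n_layer=12).to(device)
    model = DDP(model, gradient_as_bucket_view=True)  # the DDP() wrapper mentioned above
    model = torch.compile(model, backend='openxla')   # assumption based on the Dynamo frames below
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for _ in range(10):
        x = torch.randn(8, 768, device=device)
        optimizer.zero_grad()
        loss = model(x).mean()
        loss.backward()  # with the real GPT-2 model, this is where the OOM is raised on v4-8
        optimizer.step()
        xm.mark_step()


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```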
Error:
"""
Traceback (most recent call last):
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
replica_results = list(
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
return fn()
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 190, in __call__
self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
File "/root/build-nanogpt-ptxla/train_gpt2.py", line 412, in _mp_fn
train_gpt()
File "/root/build-nanogpt-ptxla/train_gpt2.py", line 382, in train_gpt
loss.backward()
File "/usr/local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 288, in backward
_engine_run_backward(
File "/usr/local/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
return user_fn(self, *args)
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1879, in backward
out = call_compiled_backward()
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1818, in call_compiled_backward
out = call_func_at_runtime_with_args(
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 94, in g
return f(*args)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 36, in fwd
compiled_graph = bridge.extract_compiled_graph(model, args)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 610, in extract_compiled_graph
return extract_compiled_graph_helper(xla_model, xla_args)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 704, in extract_compiled_graph_helper
extract_internal(fused_module), node.args, None)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 432, in extract_internal
xm.mark_step(reset_scope=False)
File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 514.06M. That was not possible. There are 8.81M free.; (0x0x0_HBM0)
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/build-nanogpt-ptxla/train_gpt2.py", line 418, in <module>
xmp.spawn(_mp_fn)
File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 214, in spawn
run_multiprocess(spawn_fn, start_method=start_method)
File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 174, in run_multiprocess
replica_results = list(
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 175, in <genexpr>
itertools.chain.from_iterable(
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 514.06M. That was not possible. There are 8.81M free.; (0x0x0_HBM0)