GPT-2 OOM when using more than 4 attention blocks #7791

Opened by @miladm

Description

🐛 Bug

The following GPT-2 code OOMs on a TPU v4-8 when the model has more than 4 attention layers (i.e. n_layer > 4). The issue occurs only when the model is wrapped in DDP(); on a single TPU chip, the code runs successfully with the default n_layer = 12.

Code to reproduce is available at this link. The preceding commit is the single-chip version and runs successfully.
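For reference, here is a minimal sketch of the failing setup. It is not the exact train_gpt2.py: TinyGPT, its hyperparameters, and the random-data loop are stand-ins, and the openxla compile backend is assumed from the dynamo frames in the traceback. The DDP-over-'xla'-backend recipe follows the standard torch_xla pattern.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend


class TinyGPT(nn.Module):
    """Stand-in for the nanoGPT model: n_layer attention blocks + LM head."""

    def __init__(self, n_layer=5, n_embd=768, n_head=12, vocab=50304, block=1024):
        super().__init__()
        self.wte = nn.Embedding(vocab, n_embd)
        self.wpe = nn.Embedding(block, n_embd)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(n_embd, n_head, 4 * n_embd, batch_first=True)
            for _ in range(n_layer))
        self.head = nn.Linear(n_embd, vocab, bias=False)

    def forward(self, idx, targets):
        pos = torch.arange(idx.size(1), device=idx.device)
        x = self.wte(idx) + self.wpe(pos)
        for blk in self.blocks:
            x = blk(x)
        logits = self.head(x)
        return nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1))


def _mp_fn(index):
    # One process per TPU chip; the 'xla' backend handles gradient all-reduce.
    dist.init_process_group('xla', init_method='xla://')
    device = xm.xla_device()

    model = TinyGPT(n_layer=5).to(device)  # n_layer > 4 reproduces the OOM
    model = DDP(model, gradient_as_bucket_view=True)
    # Backend name assumed from the torch/_dynamo/backends/torchxla.py frames:
    model = torch.compile(model, backend='openxla')

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    for step in range(10):
        x = torch.randint(0, 50304, (8, 1024), device=device)
        y = torch.randint(0, 50304, (8, 1024), device=device)
        optimizer.zero_grad()
        loss = model(x, y)
        loss.backward()  # fails here with RESOURCE_EXHAUSTED on v4-8
        optimizer.step()
        xm.mark_step()


if __name__ == '__main__':
    xmp.spawn(_mp_fn)
```

Under this setup the forward pass completes, and the RESOURCE_EXHAUSTED error below is raised from loss.backward() when XLA materializes the compiled backward graph.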

Error:

"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
    replica_results = list(
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
    return fn()
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 190, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/root/build-nanogpt-ptxla/train_gpt2.py", line 412, in _mp_fn
    train_gpt()
  File "/root/build-nanogpt-ptxla/train_gpt2.py", line 382, in train_gpt
    loss.backward()
  File "/usr/local/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 288, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1879, in backward
    out = call_compiled_backward()
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1818, in call_compiled_backward
    out = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 94, in g
    return f(*args)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py", line 36, in fwd
    compiled_graph = bridge.extract_compiled_graph(model, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 610, in extract_compiled_graph
    return extract_compiled_graph_helper(xla_model, xla_args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 704, in extract_compiled_graph_helper
    extract_internal(fused_module), node.args, None)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py", line 432, in extract_internal
    xm.mark_step(reset_scope=False)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1055, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 514.06M. That was not possible. There are 8.81M free.; (0x0x0_HBM0)
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/build-nanogpt-ptxla/train_gpt2.py", line 418, in <module>
    xmp.spawn(_mp_fn)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 214, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 174, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 175, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 514.06M. That was not possible. There are 8.81M free.; (0x0x0_HBM0)
