eval memory issue #5

@ramdhan1989

Hi,
I can train the model on my dataset, but in eval mode I always get an out-of-memory error. Is there any solution?

Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 31, in <module>
    app.run(main)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 26, in main
    eval.evaluate(FLAGS.config)
  File "/scratch1/rwibawa/cvit/ns/eval.py", line 81, in evaluate
    pred = model.apply(state.params, x, coords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 364, in __call__
    x = CrossAttnBlock(
        ^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 133, in __call__
    x = nn.MultiHeadDotProductAttention(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 674, in __call__
    x = self.attention_fn(*attn_args, **attn_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 266, in dot_product_attention
    attn_weights = dot_product_attention_weights(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 132, in dot_product_attention_weights
    attn_weights = einsum('...qhd,...khd->...hqk', query, key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 9747, in einsum
    return einsum(operands, contractions, precision,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359738368 bytes.
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
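
For context, the failed allocation is exactly 2^35 bytes (32 GiB), and it happens in the einsum `'...qhd,...khd->...hqk'`, whose output has one entry per (head, query, key) triple. The sketch below is illustrative only and is not code from the repository. It shows how the size of such a tensor scales, and one possible workaround: evaluating the query coordinates in smaller chunks. The helper names `attn_weights_bytes` and `apply_in_chunks`, the example shapes, and the chunk size are all hypothetical. The chunking assumes that the leading axis of `coords` indexes independent query points and that the model output carries those points on axis 1; neither is confirmed by the traceback.

```python
import numpy as np


def attn_weights_bytes(batch, heads, q_len, kv_len, itemsize=4):
    """Bytes needed for a dense [batch, heads, q_len, kv_len] attention-weight tensor."""
    return batch * heads * q_len * kv_len * itemsize


def apply_in_chunks(apply_fn, params, x, coords, chunk_size, out_axis=1):
    """Call apply_fn on slices of coords (axis 0) and stitch the results together.

    Assumption: predictions at different query coordinates do not depend on
    each other, so evaluating them in pieces gives the same result.
    """
    outputs = []
    for start in range(0, coords.shape[0], chunk_size):
        chunk = coords[start:start + chunk_size]
        # np.asarray copies each partial result to host memory, so only one
        # chunk's attention weights need to fit on the device at a time.
        outputs.append(np.asarray(apply_fn(params, x, chunk)))
    return np.concatenate(outputs, axis=out_axis)


# Hypothetical float32 shapes (NOT taken from the report) that happen to
# produce the same 34359738368-byte request: memory grows linearly with q_len.
example_bytes = attn_weights_bytes(batch=8, heads=16, q_len=65536, kv_len=1024)

# Possible use in eval.py (chunk_size is a guess to be tuned):
# pred = apply_in_chunks(model.apply, state.params, x, coords, chunk_size=4096)
```

If the assumptions hold, reducing `chunk_size` (or the evaluation batch size) lowers the peak allocation proportionally, at the cost of more forward passes.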