Description
Hi,
I can train the model on my dataset, but when running in eval mode I always get an out-of-memory error. Is there any solution?
Traceback (most recent call last):
File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 31, in <module>
app.run(main)
File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 26, in main
eval.evaluate(FLAGS.config)
File "/scratch1/rwibawa/cvit/ns/eval.py", line 81, in evaluate
pred = model.apply(state.params, x, coords)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/scratch1/rwibawa/cvit/ns/src/model.py", line 364, in __call__
x = CrossAttnBlock(
^^^^^^^^^^^^^^^
File "/scratch1/rwibawa/cvit/ns/src/model.py", line 133, in __call__
x = nn.MultiHeadDotProductAttention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 674, in __call__
x = self.attention_fn(*attn_args, **attn_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 266, in dot_product_attention
attn_weights = dot_product_attention_weights(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 132, in dot_product_attention_weights
attn_weights = einsum('...qhd,...khd->...hqk', query, key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 9747, in einsum
return einsum(operands, contractions, precision,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359738368 bytes.
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
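The allocation that fails (~34 GB) is the full attention-weights tensor built by the einsum in `dot_product_attention_weights`, which suggests the whole evaluation set is being pushed through `model.apply` in one call. A common workaround is to split the eval inputs into smaller chunks and concatenate the predictions. A minimal sketch, assuming `apply_fn(params, x, coords)` maps a leading batch dimension of `x` to a matching batch of predictions (the helper name `predict_in_chunks` and the toy model below are hypothetical, not part of this repo):

```python
import numpy as np

def predict_in_chunks(apply_fn, params, x, coords, chunk_size=8):
    # Run the forward pass on small slices of the batch dimension so that
    # each call allocates a much smaller attention matrix, then stitch the
    # per-chunk predictions back together.
    outputs = []
    for start in range(0, x.shape[0], chunk_size):
        x_chunk = x[start:start + chunk_size]
        outputs.append(apply_fn(params, x_chunk, coords))
    return np.concatenate(outputs, axis=0)

# Toy usage with a stand-in "model" that just scales its input by `params`;
# in the real code, apply_fn would wrap model.apply from eval.py.
dummy_apply = lambda params, x_chunk, coords: x_chunk * params
x = np.arange(10.0).reshape(10, 1)
preds = predict_in_chunks(dummy_apply, 2.0, x, None, chunk_size=4)
```

The trade-off is purely speed vs. peak memory: smaller `chunk_size` lowers the size of the `...hqk` attention tensor per call at the cost of more forward passes.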