Description
- STEP 1
Working on Google Colab (runtime type: T4 GPU), I install FermiNet with the following code, which runs without errors:
!git clone https://github.com/google-deepmind/ferminet.git
%cd ferminet
!pip install -e .
!pip install --upgrade pip
# NVIDIA CUDA 12 installation
!pip install --upgrade "jax[cuda12]"
- STEP 2
Run the following script, taken from the FermiNet GitHub page, to check that the network is installed and working correctly:
import sys
from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train
# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)
# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1) # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]
# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100
train.train(cfg)
Running this script, pretraining completes correctly, but training then fails with the following error:
AttributeError: module 'jax.core' has no attribute 'DebugInfo'
The error is raised inside kfac_jax, in the module tag_graph_matcher.py:
/usr/local/lib/python3.11/dist-packages/kfac_jax/_src/tag_graph_matcher.py in make_jax_graph(func, func_args, params_index, name, compute_only_loss_tags, clean_broadcasts, tag_ctor)
386 debug_info = closed_jaxpr.jaxpr.debug_info
387 if debug_info is not None:
--> 388 debug_info = jax.core.DebugInfo(
389 debug_info.traced_for,
390 debug_info.func_src_info,
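The failing call is `jax.core.DebugInfo(...)` inside kfac_jax. One plausible explanation (an assumption; the report does not confirm it) is that the installed kfac_jax was written against a JAX release other than the 0.5.0 installed here. A small standard-library probe shows whether the installed `jax.core` exposes `DebugInfo` at all. `has_attribute` is a hypothetical helper, not part of JAX or kfac_jax:

```python
import importlib


def has_attribute(module_name, attr):
    """Return True if `module_name` can be imported and exposes `attr`.

    Hypothetical diagnostic helper (not a JAX or kfac_jax API). Returns False
    when the module itself is missing, so it is safe to run anywhere.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # In the environment reported in this issue, the AttributeError implies
    # this prints False.
    print("jax.core.DebugInfo available:", has_attribute("jax.core", "DebugInfo"))
```

A `False` result points at the JAX/kfac_jax pairing in the environment rather than at the H2 config in the script.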
Specifications
import jax; jax.print_environment_info()
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0]
device info: Tesla T4-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='f2740d6493af', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')
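The environment info above lists jax, jaxlib and numpy but not kfac_jax, whose version is relevant to this traceback. A minimal sketch for reporting them together, assuming the pip distribution names `jax`, `jaxlib`, `kfac-jax`, `ferminet` and `numpy` (an assumption about how the packages were registered). `installed_versions` is a hypothetical helper built on the standard `importlib.metadata` module:

```python
from importlib import metadata

# Assumed pip distribution names; kfac_jax appears in the traceback above.
PACKAGES = ("jax", "jaxlib", "kfac-jax", "ferminet", "numpy")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed, or registered under a different distribution name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output alongside `jax.print_environment_info()` makes version mismatches between JAX and kfac_jax easier to spot.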