
AttributeError raised by tag_graph_matcher.py when using Ferminet #314

@danielepremarini

Description

  • STEP 1
    Working on Google Colab (runtime type: T4 GPU), I install FermiNet with the following code, which runs without errors:
!git clone https://github.com/google-deepmind/ferminet.git
%cd ferminet
!pip install -e .
!pip install --upgrade pip

# NVIDIA CUDA 12 installation
!pip install --upgrade "jax[cuda12]"
  • STEP 2
    Run the following code, the example script from the FermiNet GitHub page, to check that the network is installed and working correctly:
import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

When this script runs, pretraining completes correctly, but the run then fails with the following error:

AttributeError: module 'jax.core' has no attribute 'DebugInfo'

The error is raised in kfac_jax's tag_graph_matcher.py module:

/usr/local/lib/python3.11/dist-packages/kfac_jax/_src/tag_graph_matcher.py in make_jax_graph(func, func_args, params_index, name, compute_only_loss_tags, clean_broadcasts, tag_ctor)
    386     debug_info = closed_jaxpr.jaxpr.debug_info
    387     if debug_info is not None:
--> 388       debug_info = jax.core.DebugInfo(
    389           debug_info.traced_for,
    390           debug_info.func_src_info,

Specifications

import jax; jax.print_environment_info()
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
device info: Tesla T4-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='f2740d6493af', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')
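`jax.print_environment_info()` does not report the kfac_jax version, even though the failing call lives in kfac_jax. A minimal diagnostic sketch like the one below could fill that gap. The helpers `package_versions` and `jax_core_has_debuginfo` are hypothetical, not part of FermiNet or kfac_jax, and the distribution names `kfac-jax` and `ferminet` are assumed to match how pip registered the packages. The script prints the relevant versions and checks directly whether `jax.core` exposes `DebugInfo`:

```python
# Hypothetical diagnostic helpers (not part of FermiNet or kfac_jax).
# They report installed versions and whether jax.core exposes DebugInfo,
# the attribute kfac_jax's tag_graph_matcher.py fails to find.
import importlib.metadata


def package_versions(names=("jax", "jaxlib", "kfac-jax", "ferminet")):
    """Return {distribution name: version string, or None if not installed}.

    The distribution names are assumptions about how pip registered them.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_core_has_debuginfo():
    """Return True/False if jax is importable, or None if it is not installed."""
    try:
        import jax.core
    except ImportError:
        return None
    return hasattr(jax.core, "DebugInfo")


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
    print("jax.core.DebugInfo available:", jax_core_has_debuginfo())
```

If the last line prints `False` in the same Colab session, that would suggest the installed JAX and kfac_jax versions are out of step, rather than a problem in the FermiNet script itself.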
