
Colab error with haiku/jax dependency #5

@hstojic


The following error appears when executing the Colab notebook:

AttributeError                            Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
     13     num_ensemble=FLAGS.index_dim,
     14     prior_scale=FLAGS.prior_scale,
---> 15     seed=FLAGS.seed,
     16 )
     17 

4 frames
/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
    137     """Ensemble of MLPs with matched prior functions."""
    138     mlp_priors = make_mlp_ensemble_prior_fns(
--> 139         output_sizes, dummy_input, num_ensemble, seed)
    140     enn = priors.EnnWithAdditivePrior(
    141         enn=MLPEnsembleEnn(

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
     90     return hk.Sequential(layers)(x)
     91 
---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
     93 
     94   prior_fns = []

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
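The last frame explains why `hk.transform` fails for every function, not only for already jitted ones. The tuple `(jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)` is built before `isinstance` runs, so the missing attribute raises `AttributeError` whatever `f` is. This suggests a mismatch between the installed Haiku and jaxlib versions, where Haiku refers to a type that this jaxlib no longer exports. The dependency-free sketch below simulates that situation with a stand-in namespace. `xe`, `CompiledFunction`, `net_fn` and both check functions are hypothetical names for illustration, not Haiku's or JAX's API. It contrasts the strict check with a `getattr`-based variant that tolerates the missing type.

```python
import types


# Hypothetical placeholder for a jaxlib function type.
class CompiledFunction:
    pass


# Stand-in for `jaxlib.xla_extension` in a jaxlib release that does not
# export `PmapFunction` (simulation only; nothing from jax is imported).
xe = types.SimpleNamespace(CompiledFunction=CompiledFunction)


def check_not_jax_transformed_strict(f):
    # Mirrors the failing check: the tuple is evaluated before isinstance,
    # so the missing attribute raises AttributeError for *any* f.
    if isinstance(f, (xe.CompiledFunction, xe.PmapFunction)):
        raise ValueError("function is already jit/pmap transformed")


def check_not_jax_transformed_tolerant(f):
    # Hypothetical defensive variant: only test against types that exist.
    names = ("CompiledFunction", "PmapFunction")
    types_present = tuple(
        t for t in (getattr(xe, n, None) for n in names) if t is not None
    )
    if types_present and isinstance(f, types_present):
        raise ValueError("function is already jit/pmap transformed")


def net_fn(x):
    # Plain, untransformed function, like the `net_fn` in the traceback.
    return x


try:
    check_not_jax_transformed_strict(net_fn)
except AttributeError as err:
    print("strict check failed:", err)

check_not_jax_transformed_tolerant(net_fn)  # no exception for a plain function
print("tolerant check passed")
```

The tolerant variant only illustrates the failure mode. In the notebook itself, the usual remedy is presumably to install mutually compatible versions of `dm-haiku`, `jax` and `jaxlib`.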
