
Bug: "TypeError: cannot create weak reference to 'Flatten' object" during Cart-Pole and Double Cart-Pole Simulation #58

@alinjar1996

Description


I encountered an error while running either of the control examples:

examples/cart_pole.py
or
examples/double_cart_pole.py

The full traceback is below:

```
Traceback (most recent call last):
  File "/home/alinjar/hydrax/examples/cart_pole.py", line 71, in
    run_interactive(
  File "/home/alinjar/hydrax/hydrax/simulation/deterministic.py", line 94, in run_interactive
    policy_params, rollouts = jit_optimize(mjx_data, policy_params)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/hydrax/hydrax/alg_base.py", line 143, in optimize
    new_mean = self.interp_func(new_tk, tk, params.mean[None, ...])[0]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/hydrax/hydrax/utils/spline.py", line 36, in interp_cubic
    return interp1d(tq, tk, knots, method="cubic2", extrap=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_spline.py", line 570, in interp1d
    fx = approx_df(x, f, method, axis, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 49, in approx_df
    out = _cubic2(x, f, axis, bc=bc, dtype=dtype)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 286, in _cubic2
    fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/vectorize.py", line 347, in wrapped
    result = vectorized_func(*squeezed_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/vectorize.py", line 144, in wrapped
    out = func(*args)
          ^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 285, in
    solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/lineax/_solve.py", line 820, in linear_solve
    solution, result, stats = eqxi.filter_primitive_bind(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/equinox/internal/_primitive.py", line 273, in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot create weak reference to 'Flatten' object

For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

The installation follows the standard instructions in Readme.md.
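A "standard" installation can still resolve to different package versions, and the traceback passes through jax, equinox, lineax, and interpax. A minimal sketch for reporting the exact installed versions uses only the standard library's `importlib.metadata`. The `PACKAGES` list and the `collect_versions` helper are hypothetical and are not part of hydrax. jaxlib is included on the assumption that it is relevant because it is released alongside jax.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages on the traceback's call path, plus jaxlib (paired with jax).
PACKAGES = ["jax", "jaxlib", "equinox", "lineax", "interpax"]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version}; missing packages map to None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting this output into the issue makes it possible to compare environments directly.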
