Description
I get the error below when running either of the control examples, examples/cart_pole.py or examples/double_cart_pole.py:
Traceback (most recent call last):
  File "/home/alinjar/hydrax/examples/cart_pole.py", line 71, in <module>
    run_interactive(
  File "/home/alinjar/hydrax/hydrax/simulation/deterministic.py", line 94, in run_interactive
    policy_params, rollouts = jit_optimize(mjx_data, policy_params)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/hydrax/hydrax/alg_base.py", line 143, in optimize
    new_mean = self.interp_func(new_tk, tk, params.mean[None, ...])[0]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/hydrax/hydrax/utils/spline.py", line 36, in interp_cubic
    return interp1d(tq, tk, knots, method="cubic2", extrap=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_spline.py", line 570, in interp1d
    fx = approx_df(x, f, method, axis, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 49, in approx_df
    out = _cubic2(x, f, axis, bc=bc, dtype=dtype)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 286, in _cubic2
    fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/vectorize.py", line 347, in wrapped
    result = vectorized_func(*squeezed_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/vectorize.py", line 144, in wrapped
    out = func(*args)
          ^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/interpax/_fd_derivs.py", line 285, in <lambda>
    solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/lineax/_solve.py", line 820, in linear_solve
    solution, result, stats = eqxi.filter_primitive_bind(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alinjar/miniconda3/envs/hydrax/lib/python3.12/site-packages/equinox/internal/_primitive.py", line 273, in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot create weak reference to 'Flatten' object
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
I installed hydrax following the standard instructions in Readme.md.
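Since the exception is raised inside equinox's `filter_primitive_bind` when it calls JAX's `prim.bind`, my unconfirmed guess is a version incompatibility between JAX and equinox/lineax. To make that easier to check, here is a sketch (standard library only) that prints the installed versions of the packages appearing in the traceback. The package list is my own selection based on the site-packages paths above, plus `jaxlib`; adjust it as needed.

```python
"""Report versions of the packages involved in the traceback (stdlib only)."""
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the site-packages paths in the traceback,
# plus jaxlib (JAX's compiled backend). This list is an assumption; edit as needed.
PACKAGES = ["jax", "jaxlib", "equinox", "lineax", "interpax"]


def package_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"{'python':10s} {platform.python_version()}")
    for name, ver in package_versions(PACKAGES).items():
        print(f"{name:10s} {ver or 'not installed'}")
```

Pasting this output into the issue should make it easier to see which combination of versions triggers the error.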