Automatic Differentiation for Rigid-Body Dynamics Algorithms
adam computes rigid-body dynamics for floating-base robots. It is built on Featherstone's algorithms and is available across multiple backends:
- **JAX** – compile, vectorize, and differentiate with XLA
- **CasADi** – symbolic computation for optimization and control
- **PyTorch** – GPU acceleration and batched operations
- **NumPy** – simple numerical evaluation
All backends share the same interface and produce numerically consistent results, letting you pick the tool that fits your use case.
Install with pip, selecting the backend(s) you need:

```bash
# JAX backend
pip install adam-robotics[jax]
# CasADi backend
pip install adam-robotics[casadi]
# PyTorch backend
pip install adam-robotics[pytorch]
# MuJoCo support
pip install adam-robotics[mujoco]
# OpenUSD support
pip install adam-robotics[usd]
# Visualization support
pip install adam-robotics[visualization]
# All backends
pip install adam-robotics[all]
```

Or with conda:

```bash
# CasADi backend
conda create -n adamenv -c conda-forge adam-robotics-casadi
# JAX backend (Linux/macOS only)
conda create -n adamenv -c conda-forge adam-robotics-jax
# PyTorch backend (Linux/macOS only)
conda create -n adamenv -c conda-forge adam-robotics-pytorch
# OpenUSD support
conda install -n adamenv -c conda-forge openusd
# All backends (Linux/macOS only)
conda create -n adamenv -c conda-forge adam-robotics-all
```

Or from source:

```bash
git clone https://github.com/ami-iit/adam.git
cd adam
pip install .[jax]  # or [casadi], [pytorch], [mujoco], [usd], [visualization], [all]
```

Load a robot model and compute dynamics quantities:
> **Note**: check the JAX installation guide before installing the JAX backend.
```python
import adam
from adam.jax import KinDynComputations
import icub_models
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

# Use icub-models (https://github.com/robotology/icub-models) to retrieve the URDF
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
    'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
    'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
    'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
    'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
    'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]

kinDyn = KinDynComputations(model_path, joints_name_list)

# Set the velocity representation (3 options available):
# 1. MIXED_REPRESENTATION (default) - time derivative of the base origin position
#    (expressed in the world frame) + world-frame angular velocity
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)
# 2. BODY_FIXED_REPRESENTATION - both linear and angular velocity in the body frame
# kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# 3. INERTIAL_FIXED_REPRESENTATION - world-frame linear and angular velocity
# kinDyn.set_frame_velocity_representation(adam.Representations.INERTIAL_FIXED_REPRESENTATION)

w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))
M = kinDyn.mass_matrix(w_H_b, joints)
print(M)

w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b, joints)

# JAX functions can also be jitted!
def frame_forward_kinematics(w_H_b, joints):
    # Wrapping fixes the frame name, since str is not a valid JAX type
    return kinDyn.forward_kinematics('frame_name', w_H_b, joints)

jitted_frame_fk = jit(frame_forward_kinematics)
w_H_f = jitted_frame_fk(w_H_b, joints)

# JAX natively supports batching
joints_batch = jnp.tile(joints, (1024, 1))
w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1))
w_H_f_batch = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch)
```

> **Note**: the first call of a jitted function can be slow, since JAX needs to compile it. Subsequent calls are fast.
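Beyond `jit` and `vmap`, the same wrappers can be differentiated. The following is a minimal sketch of the pattern on a toy forward-kinematics function (a planar two-link arm, not adam's API), showing how `jax.jacfwd`, `jit`, and `vmap` compose:

```python
import jax
import jax.numpy as jnp

def toy_fk(q):
    # Toy forward kinematics: planar 2-link arm with unit link lengths.
    # Stands in for a wrapper around kinDyn.forward_kinematics.
    x = jnp.cos(q[0]) + jnp.cos(q[0] + q[1])
    y = jnp.sin(q[0]) + jnp.sin(q[0] + q[1])
    return jnp.array([x, y])

# Analytical Jacobian via forward-mode autodiff, compiled with jit
jac_fk = jax.jit(jax.jacfwd(toy_fk))
q = jnp.array([0.3, 0.5])
J = jac_fk(q)          # shape (2, 2)

# The differentiated function also vectorizes over a batch of configurations
q_batch = jnp.stack([q] * 8)
J_batch = jax.vmap(jac_fk)(q_batch)   # shape (8, 2, 2)
```

The same composition works on any adam JAX-backend function whose non-array arguments (such as frame names) are fixed by a wrapper.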
```python
import casadi as cs
import adam
from adam.casadi import KinDynComputations
import icub_models
import numpy as np

# Use icub-models (https://github.com/robotology/icub-models) to retrieve the URDF
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
    'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
    'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
    'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
    'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
    'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]

kinDyn = KinDynComputations(model_path, joints_name_list)

# Set the velocity representation (3 options available):
# 1. MIXED_REPRESENTATION (default) - time derivative of the base position + world-frame angular velocity
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)
# 2. BODY_FIXED_REPRESENTATION - both linear and angular velocity in the body frame
# kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# 3. INERTIAL_FIXED_REPRESENTATION - world-frame linear and angular velocity
# kinDyn.set_frame_velocity_representation(adam.Representations.INERTIAL_FIXED_REPRESENTATION)

w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))

# If you want the symbolic version
w_H_b = cs.SX.eye(4)
joints = cs.SX.sym('joints', len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))

# This is also usable with casadi.MX
w_H_b = cs.MX.eye(4)
joints = cs.MX.sym('joints', len(joints_name_list))
M = kinDyn.mass_matrix_fun()
print(M(w_H_b, joints))
```
```python
import adam
from adam.pytorch import KinDynComputations
import icub_models
import numpy as np

# Use icub-models (https://github.com/robotology/icub-models) to retrieve the URDF
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
    'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
    'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
    'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
    'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
    'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]

kinDyn = KinDynComputations(model_path, joints_name_list)

# Choose the velocity representation: body-fixed...
kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# ...or mixed (the default)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)

w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))
M = kinDyn.mass_matrix(w_H_b, joints)
print(M)
```

`pytorch.KinDynComputations` can also process multiple configurations at once.
> **Note**: the class `pytorch.KinDynComputationsBatch` provides the same functionality as `pytorch.KinDynComputations`. It exists only to avoid API changes in existing code; new users should prefer `pytorch.KinDynComputations` for both single and batched computations.
```python
import adam
from adam.pytorch import KinDynComputations
import icub_models
import numpy as np
import torch

# Use icub-models (https://github.com/robotology/icub-models) to retrieve the URDF
model_path = icub_models.get_model_file("iCubGazeboV2_5")
# The joint list
joints_name_list = [
    'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
    'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
    'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll',
    'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch',
    'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll'
]

kinDyn = KinDynComputations(model_path, joints_name_list)

# Choose the velocity representation: body-fixed...
kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION)
# ...or mixed (the default)
kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)

w_H_b = np.eye(4)
joints = np.ones(len(joints_name_list))

num_samples = 1024
w_H_b_batch = torch.tensor(np.tile(w_H_b, (num_samples, 1, 1)), dtype=torch.float32)
joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float32)

M = kinDyn.mass_matrix(w_H_b_batch, joints_batch)
w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch)
```

adam supports loading models directly from MuJoCo `MjModel` objects. This is useful when working with MuJoCo simulations or with models from `robot_descriptions`.
```python
import mujoco
import numpy as np
from adam import Representations
from adam.numpy import KinDynComputations

# Load a MuJoCo model (e.g., from robot_descriptions)
from robot_descriptions.loaders.mujoco import load_robot_description
mj_model = load_robot_description("g1_mj_description")

# Create KinDynComputations directly from the MuJoCo model
kinDyn = KinDynComputations.from_mujoco_model(mj_model)

# Set the velocity representation (the default is mixed)
kinDyn.set_frame_velocity_representation(Representations.MIXED_REPRESENTATION)

# Set gravity to match the MuJoCo settings
kinDyn.g = np.concatenate([mj_model.opt.gravity, np.zeros(3)])

# Create MuJoCo data and set the state
mj_data = mujoco.MjData(mj_model)
mj_data.qpos[:] = your_qpos  # Your configuration
mj_data.qvel[:] = your_qvel  # Your velocities
mujoco.mj_forward(mj_model, mj_data)

# Extract the base transform from the MuJoCo state (for floating-base robots)
from scipy.spatial.transform import Rotation as R
base_rot = R.from_quat(mj_data.qpos[3:7], scalar_first=True).as_matrix()
base_pos = mj_data.qpos[0:3]
w_H_b = np.eye(4)
w_H_b[:3, :3] = base_rot
w_H_b[:3, 3] = base_pos

# Joint positions (excluding the free joint).
# Make sure the joint serialization in MuJoCo and adam is the same
joints = mj_data.qpos[7:]

# Compute dynamics quantities
M = kinDyn.mass_matrix(w_H_b, joints)
com_pos = kinDyn.CoM_position(w_H_b, joints)
J = kinDyn.jacobian('frame_name', w_H_b, joints)
```

> **Note**: when using `from_mujoco_model`, adam automatically extracts the joint names from the MuJoCo model. You can also specify `use_mujoco_actuators=True` to use actuator names instead of joint names.

> **Warning**: MuJoCo uses a different velocity representation for the floating base. The free-joint velocity in MuJoCo is $[{}^I \dot{p}_B, {}^B \omega_B]$, while the mixed representation uses $[{}^I \dot{p}_B, {}^I \omega_B]$. Make sure to handle this transformation when comparing with MuJoCo computations.
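The transformation in the warning only affects the angular part: the body-frame angular velocity must be rotated into the world frame. A minimal numpy sketch, assuming `base_rot` is the world-frame base rotation extracted from `qpos` as above:

```python
import numpy as np

def mujoco_to_mixed(qvel, base_rot):
    """Convert a MuJoCo free-joint velocity [I p_dot_B, B omega_B]
    to the mixed representation [I p_dot_B, I omega_B]."""
    v_lin = qvel[0:3]                    # already expressed in the world frame
    omega_world = base_rot @ qvel[3:6]   # rotate body-frame angular velocity to world
    return np.concatenate([v_lin, omega_world])

# Example: base rotated 90 degrees about z, spinning about its own z axis
base_rot = np.array([[0., -1., 0.],
                     [1.,  0., 0.],
                     [0.,  0., 1.]])
qvel_free = np.array([0.1, 0.0, 0.0, 0.0, 0.0, 0.5])
v_mixed = mujoco_to_mixed(qvel_free, base_rot)
```

Here the linear part passes through unchanged and the angular part is mapped by the base rotation; for a z-axis spin under a z-axis rotation it is unchanged as well.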
adam supports loading robot models directly from OpenUSD files and exporting models to OpenUSD.
Loading directly from a USD file:
```python
import numpy as np
from adam import Representations
from adam.numpy import KinDynComputations

# Load directly from any existing USD file
kinDyn = KinDynComputations.from_usd(
    "robot.usd",
    robot_prim_path="/Robot",
    joints_name_list=["joint_1", "joint_2"],
)
kinDyn.set_frame_velocity_representation(Representations.MIXED_REPRESENTATION)

# Compute quantities as usual
w_H_b = np.eye(4)
q = np.zeros(kinDyn.NDoF)
M = kinDyn.mass_matrix(w_H_b, q)
com = kinDyn.CoM_position(w_H_b, q)
```

Exporting a model to USD (e.g. to convert a URDF to USD):
```python
from adam.model import Model, build_model_factory
from adam.numpy.numpy_like import SpatialMath

model_path = "robot.urdf"
joints_name_list = ["joint_1", "joint_2"]
factory = build_model_factory(description=model_path, math=SpatialMath())
model = Model.build(factory=factory, joints_name_list=joints_name_list)

# Export to USD
model.to_usd("robot.usd", robot_prim_path="/Robot")
```

adam also provides a lightweight visualization layer based on viser. It works with the same normalized model API, so URDF, MuJoCo, and USD models can all be rendered through the same interface.
For quick inspection from the terminal, use the bundled viewer command:
```bash
adam-model-view --urdf path/to/robot.urdf
adam-model-view --mujoco path/to/model.xml
adam-model-view --usd path/to/robot.usd --robot-prim-path /Robot
```
```python
import numpy as np
import icub_models
from adam.numpy import KinDynComputations
from adam.visualization import Visualizer

kindyn = KinDynComputations.from_urdf(
    icub_models.get_model_file("iCubGazeboV2_5")
)

visualizer = Visualizer(
    world_axes=True,
    ground=True,
    camera_position=(2.5, -2.0, 1.5),
    camera_look_at=(0.0, 0.0, 0.6),
)
robot = visualizer.add_model(kindyn, root_name="/icub")

w_H_b = np.eye(4)
w_H_b[2, 3] = 0.6
q = np.zeros(kindyn.NDoF)
robot.update(w_H_b, q)
robot.add_joint_sliders(folder_name="iCub")
```

With other model sources, only the loader changes:
```python
# MuJoCo
kindyn = KinDynComputations.from_mujoco_model(mj_model)
# USD
kindyn = KinDynComputations.from_usd("robot.usd", robot_prim_path="/Robot")
```

Batched visualization is available through the same `ModelHandle` API by passing `num_instances` to `add_model()`. The model is rendered with viser batched meshes, and each `update()` call accepts base transforms with shape `(B, 4, 4)` and joint positions with shape `(B, N)`:
```python
num_instances = 16
robot = visualizer.add_model(
    kindyn,
    root_name="/g1_batch",
    num_instances=num_instances,
)
w_H_b = np.repeat(np.eye(4)[None, :, :], num_instances, axis=0)
q = np.zeros((num_instances, kindyn.NDoF))
robot.update(w_H_b, q)
```

For a ready-to-run MuJoCo example with a batch of Unitree G1 robots:

```bash
python examples/visualization/visualize_g1_batch.py
```

The batched example animates `left_hip_pitch_joint` by default with a phase offset per instance. Frames and joint sliders are scalar-model conveniences and are not enabled for batched models.
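Per-instance phase offsets like the one used in the batched example can be generated with plain numpy. This is a hypothetical sketch: the joint index and amplitude below are illustrative, not values from the example script.

```python
import numpy as np

num_instances = 16
ndof = 23               # e.g. the iCub joint list shown earlier
animated_idx = 11       # hypothetical index of the animated joint

t = 0.5                 # current animation time in seconds
# One phase per instance, evenly spread over a full cycle
phases = np.linspace(0.0, 2.0 * np.pi, num_instances, endpoint=False)

q = np.zeros((num_instances, ndof))
q[:, animated_idx] = 0.4 * np.sin(2.0 * np.pi * 0.5 * t + phases)

w_H_b = np.repeat(np.eye(4)[None, :, :], num_instances, axis=0)
# robot.update(w_H_b, q)  # shapes (B, 4, 4) and (B, N), as described above
```

Each instance then traces the same sinusoid, offset in phase from its neighbors.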
Examples are available in:
- `examples/visualization/visualize_mujoco.py`
- `examples/visualization/visualize_g1_batch.py`
- `examples/visualization/visualize_usd.py`
- `examples/visualization/visualize_multi_robot.py`
- `examples/visualization/visualize_urdf.py`
By default adam uses the root link of the URDF as the floating base. You can choose any other link as the floating base at construction time or at runtime:
```python
from adam.numpy import KinDynComputations

# Construction-time: use "l_ankle_2" as the floating base
kinDyn = KinDynComputations(model_path, joints_name_list, root_link="l_ankle_2")

# Runtime setter (rebuilds the kinematic tree)
kinDyn.set_root_link("chest")
```

adam re-roots the kinematic tree internally by reversing the joints along the path from the new root to the original URDF root. All dynamics quantities (mass matrix, Jacobians, bias forces, …) are then consistent with the new floating base. Results match iDynTree's `setFloatingBase` API.

> **Important**: joint serialization is independent of the floating base. The order of joints in the joints vector is fixed by `joints_name_list` at construction time and never changes – only the base state inputs (`w_H_b` and the base velocity) reflect the new floating base.

> **Note**: `root_link` must be a link name, not a frame name. Passing a frame name raises a `ValueError` listing the valid link names. Frames remain valid as targets for forward kinematics and Jacobians.
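The re-rooting idea can be sketched on a toy tree: walk from the new root up to the old root and flip each parent pointer along that path, leaving every other edge untouched. This is only an illustration of the concept, not adam's internal implementation:

```python
def reroot(parent, new_root):
    """Re-root a tree given as a child -> parent dict.

    Edges on the path from new_root to the old root are reversed;
    all other parent pointers are unchanged."""
    # Collect the path new_root -> ... -> old_root
    path = [new_root]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    new_parent = dict(parent)
    del new_parent[new_root]          # the new root has no parent
    for child, par in zip(path, path[1:]):
        new_parent[par] = child       # reverse the edge
    return new_parent

# Toy kinematic tree rooted at 'root_link'
parent = {'chest': 'root_link', 'l_leg': 'root_link', 'l_ankle': 'l_leg'}
rerooted = reroot(parent, 'l_ankle')
# Now 'l_ankle' is the root: l_ankle -> l_leg -> root_link -> chest
```

In the real library the reversed joints also need their transforms and axes inverted, but the tree surgery follows this shape.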
```python
import casadi as cs
import numpy as np
import adam
from adam.casadi import KinDynComputations
from adam.casadi.inverse_kinematics import InverseKinematics, TargetType

# Load model
model_path = ...
joints_name_list = [...]

# Create the IK solver
ik = InverseKinematics(model_path, joints_name_list)
ik.add_target("l_sole", target_type=TargetType.POSE, as_soft_constraint=True, weight=1.0)

# Update the target and solve
desired_position = np.array([0.3, 0.2, 1.0])
desired_orientation = np.eye(3)
ik.update_target("l_sole", (desired_position, desired_orientation))
ik.solve()

# Get the solution
w_H_b_sol, q_sol = ik.get_solution()
print("Base pose:\n", w_H_b_sol)
print("Joint values:\n", q_sol)
```

- Kinematics: Forward kinematics, Jacobians (frame and base)
- Dynamics: Mass matrix, Coriolis/centrifugal forces and gravity, Articulated body algorithm
- Centroidal: Centroidal momentum matrix and derivatives
- Configurable floating base: Any link can be set as the floating base
- Differentiation: Get gradients, Jacobians, and Hessians automatically
- Symbolic: Build computation graphs with CasADi for optimization
- Batched: Process multiple configurations in parallel with PyTorch
- Visualization: Render URDF, MuJoCo, and USD robot models with viser
See the full documentation for detailed API reference, more examples, and theory.
Run tests to verify the installation:

```bash
pip install .[test]  # Install test dependencies
pytest tests/
```

See the `tests/` folder for comprehensive examples across all backends.
Found a bug or have a feature idea? Open an issue or submit a pull request!
> **Warning**: this is a project under active development. The API may change.
BSD 3-Clause License – see the LICENSE file.
Built on Roy Featherstone's *Rigid Body Dynamics Algorithms* and references like Traversaro's *A Unified View of the Equations of Motion*.