Skip to content

Commit 67f3291

Browse files
committed
Merge remote-tracking branch 'upstream/main' into sensor_touch
and updates to utilize ray geom functionality with touch sensor
2 parents 7bc7c8a + 406d46b commit 67f3291

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+12751
-6847
lines changed

.github/workflows/ci.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,22 @@ jobs:
7070
- name: Test with pytest
7171
run: |
7272
pytest
73+
74+
kernel_analyzer:
75+
name: Kernel analyzer
76+
runs-on: ubuntu-latest
77+
steps:
78+
- uses: actions/checkout@v3
79+
- name: Set up Python ${{ matrix.python-version }}
80+
uses: actions/setup-python@v4
81+
with:
82+
python-version: ${{ matrix.python-version }}
83+
- name: Install dependencies
84+
run: |
85+
python -m pip install --upgrade pip
86+
pip install mujoco --pre -f https://py.mujoco.org/
87+
pip install warp-lang --pre --upgrade -f https://pypi.nvidia.com/warp-lang/
88+
pip install -e .[dev,cpu]
89+
- name: Run kernel analyzer
90+
run: |
91+
python contrib/kernel_analyzer/kernel_analyzer/cli.py mujoco_warp/_src/*.py --output=github

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,7 @@ Temporary Items
208208
.apdisk
209209

210210
# Node stuff
211-
node_modules
211+
node_modules
212+
213+
# VS Code
214+
.vscode/

.pre-commit-config.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.4.2 # Use the latest Ruff version
4+
hooks:
5+
- id: ruff
6+
args: [--fix, --exit-non-zero-on-fix]
7+
- id: ruff-format
8+
- repo: https://github.com/pre-commit/pre-commit-hooks
9+
rev: v4.6.0 # Use the latest version
10+
hooks:
11+
- id: check-yaml
12+
- repo: local
13+
hooks:
14+
- id: kernel-analyzer
15+
name: kernel-analyzer
16+
entry: python contrib/kernel_analyzer/kernel_analyzer/cli.py
17+
language: system
18+
types: [python]

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ pytest
3939

4040
Should print out something like `XX passed in XX.XXs` at the end!
4141

42+
If you plan to write Warp kernels for MJWarp, please use the `kernel_analyzer` vscode plugin located in `contrib/kernel_analyzer`.
43+
Please see the `README.md` there for details on how to install it and use it. The same kernel analyzer will be run on any PR
44+
you open, so it's important to fix any issues it reports.
45+
4246
# Compatibility
4347

4448
The following features are implemented:

conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818

1919
def pytest_addoption(parser):
20-
parser.addoption(
21-
"--cpu", action="store_true", default=False, help="run tests with cpu"
22-
)
20+
parser.addoption("--cpu", action="store_true", default=False, help="run tests with cpu")
2321
parser.addoption(
2422
"--verify_cuda",
2523
action="store_true",

contrib/apptronik_apollo_locomotion.ipynb

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@
176176
" params[\"site_xpos\"] = True\n",
177177
" out_batched = mjx.Data(**params)\n",
178178
"\n",
179-
" qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = (\n",
180-
" jax_mjwarp_step(d.ctrl, d.qpos, d.qvel, d.qacc_warmstart)\n",
179+
" qpos, qvel, xpos, xmat, qacc_warmstart, subtree_com, cvel, site_xpos = jax_mjwarp_step(\n",
180+
" d.ctrl, d.qpos, d.qvel, d.qacc_warmstart\n",
181181
" )\n",
182182
" d = d.replace(\n",
183183
" qpos=qpos,\n",
@@ -355,9 +355,7 @@
355355
" obs = self._get_obs(data, state.info, contact)\n",
356356
" done = self._get_termination(data)\n",
357357
"\n",
358-
" rewards = self._get_reward(\n",
359-
" data, action, state.info, state.metrics, done, first_contact, contact\n",
360-
" )\n",
358+
" rewards = self._get_reward(data, action, state.info, state.metrics, done, first_contact, contact)\n",
361359
" rewards = {k: v * self._config.reward_config.scales[k] for k, v in rewards.items()}\n",
362360
" reward = sum(rewards.values()) * self.dt\n",
363361
"\n",
@@ -390,9 +388,7 @@
390388
" fall_termination = data.xpos[self._head_body_id, 2] < 1.0\n",
391389
" return fall_termination\n",
392390
"\n",
393-
" def _get_obs(\n",
394-
" self, data: mjx.Data, info: dict[str, Any], contact: jax.Array\n",
395-
" ) -> mjx_env.Observation:\n",
391+
" def _get_obs(self, data: mjx.Data, info: dict[str, Any], contact: jax.Array) -> mjx_env.Observation:\n",
396392
" cos = jp.cos(info[\"phase\"])\n",
397393
" sin = jp.sin(info[\"phase\"])\n",
398394
" phase = jp.concatenate([cos, sin])\n",
@@ -424,36 +420,24 @@
424420
" del metrics # Unused.\n",
425421
" return {\n",
426422
" # Tracking rewards.\n",
427-
" \"tracking_lin_vel\": self._reward_tracking_lin_vel(\n",
428-
" info[\"command\"], self._get_global_linvel(data, self._torso_id)\n",
429-
" ),\n",
430-
" \"tracking_ang_vel\": self._reward_tracking_ang_vel(\n",
431-
" info[\"command\"], self._get_global_angvel(data, self._torso_id)\n",
432-
" ),\n",
423+
" \"tracking_lin_vel\": self._reward_tracking_lin_vel(info[\"command\"], self._get_global_linvel(data, self._torso_id)),\n",
424+
" \"tracking_ang_vel\": self._reward_tracking_ang_vel(info[\"command\"], self._get_global_angvel(data, self._torso_id)),\n",
433425
" # Base-related rewards.\n",
434-
" \"ang_vel_xy\": self._cost_ang_vel_xy(\n",
435-
" self._get_global_angvel(data, self._torso_id)\n",
436-
" ),\n",
426+
" \"ang_vel_xy\": self._cost_ang_vel_xy(self._get_global_angvel(data, self._torso_id)),\n",
437427
" \"orientation\": self._cost_orientation(self._get_z_frame(data, self._torso_id)),\n",
438428
" # Energy related rewards.\n",
439-
" \"action_rate\": self._cost_action_rate(\n",
440-
" action, info[\"last_act\"], info[\"last_last_act\"]\n",
441-
" ),\n",
429+
" \"action_rate\": self._cost_action_rate(action, info[\"last_act\"], info[\"last_last_act\"]),\n",
442430
" # Feet related rewards.\n",
443431
" \"feet_slip\": self._cost_feet_slip(data, contact, info),\n",
444-
" \"feet_air_time\": self._reward_feet_air_time(\n",
445-
" info[\"feet_air_time\"], first_contact, info[\"command\"]\n",
446-
" ),\n",
432+
" \"feet_air_time\": self._reward_feet_air_time(info[\"feet_air_time\"], first_contact, info[\"command\"]),\n",
447433
" \"feet_phase\": self._reward_feet_phase(\n",
448434
" data,\n",
449435
" info[\"phase\"],\n",
450436
" self._config.reward_config.max_foot_height,\n",
451437
" info[\"command\"],\n",
452438
" ),\n",
453439
" # Pose related rewards.\n",
454-
" \"joint_deviation_hip\": self._cost_joint_deviation_hip(\n",
455-
" data.qpos[7:], info[\"command\"]\n",
456-
" ),\n",
440+
" \"joint_deviation_hip\": self._cost_joint_deviation_hip(data.qpos[7:], info[\"command\"]),\n",
457441
" \"joint_deviation_knee\": self._cost_joint_deviation_knee(data.qpos[7:]),\n",
458442
" \"dof_pos_limits\": self._cost_joint_pos_limits(data.qpos[7:]),\n",
459443
" \"pose\": self._cost_pose(data.qpos[7:]),\n",
@@ -504,17 +488,13 @@
504488
"\n",
505489
" # Energy related rewards.\n",
506490
"\n",
507-
" def _cost_action_rate(\n",
508-
" self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array\n",
509-
" ) -> jax.Array:\n",
491+
" def _cost_action_rate(self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array) -> jax.Array:\n",
510492
" del last_last_act # Unused.\n",
511493
" return jp.sum(jp.square(act - last_act))\n",
512494
"\n",
513495
" # Feet related rewards.\n",
514496
"\n",
515-
" def _cost_feet_slip(\n",
516-
" self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]\n",
517-
" ) -> jax.Array:\n",
497+
" def _cost_feet_slip(self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]) -> jax.Array:\n",
518498
" del info # Unused.\n",
519499
" body_vel = self._get_global_linvel(data, self._torso_id)[:2]\n",
520500
" reward = jp.sum(jp.linalg.norm(body_vel, axis=-1) * contact)\n",
@@ -534,9 +514,7 @@
534514
" reward = jp.sum(air_time)\n",
535515
" return reward\n",
536516
"\n",
537-
" def get_rz(\n",
538-
" phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08\n",
539-
" ) -> jax.Array:\n",
517+
" def get_rz(phi: Union[jax.Array, float], swing_height: Union[jax.Array, float] = 0.08) -> jax.Array:\n",
540518
" def cubic_bezier_interpolation(y_start, y_end, x):\n",
541519
" y_diff = y_end - y_start\n",
542520
" bezier = x**3 + 3 * (x**2 * (1 - x))\n",
@@ -556,9 +534,7 @@
556534
" ) -> jax.Array:\n",
557535
" # Reward for tracking the desired foot height.\n",
558536
" foot_pos = data.site_xpos[self._feet_site_id]\n",
559-
" foot_pos = jp.array(\n",
560-
" [jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)]\n",
561-
" )\n",
537+
" foot_pos = jp.array([jp.mean(foot_pos[0:4], axis=0), jp.mean(foot_pos[4:8], axis=0)])\n",
562538
" foot_z = foot_pos[..., -1]\n",
563539
" rz = Joystick.get_rz(phase, swing_height=foot_height)\n",
564540
" error = jp.sum(jp.square(foot_z - rz))\n",
@@ -606,12 +582,8 @@
606582
" def sample_command(self, rng: jax.Array) -> jax.Array:\n",
607583
" rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)\n",
608584
"\n",
609-
" lin_vel_x = jax.random.uniform(\n",
610-
" rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1]\n",
611-
" )\n",
612-
" lin_vel_y = jax.random.uniform(\n",
613-
" rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1]\n",
614-
" )\n",
585+
" lin_vel_x = jax.random.uniform(rng1, minval=self._config.lin_vel_x[0], maxval=self._config.lin_vel_x[1])\n",
586+
" lin_vel_y = jax.random.uniform(rng2, minval=self._config.lin_vel_y[0], maxval=self._config.lin_vel_y[1])\n",
615587
" ang_vel_yaw = jax.random.uniform(\n",
616588
" rng3,\n",
617589
" minval=self._config.ang_vel_yaw[0],\n",
@@ -704,9 +676,7 @@
704676
"network_factory = ppo_networks.make_ppo_networks\n",
705677
"if \"network_factory\" in ppo_params:\n",
706678
" del ppo_training_params[\"network_factory\"]\n",
707-
" network_factory = functools.partial(\n",
708-
" ppo_networks.make_ppo_networks, **ppo_params.network_factory\n",
709-
" )\n",
679+
" network_factory = functools.partial(ppo_networks.make_ppo_networks, **ppo_params.network_factory)\n",
710680
"\n",
711681
"train_fn = functools.partial(\n",
712682
" ppo.train,\n",

contrib/kernel_analyzer/README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ Kernel Analyzer checks warp kernels to ensure correctness and conformity. It co
55
# CLI usage
66

77
```bash
8-
python contrib/kernel_analyzer/kernel_analyzer/cli.py --files somefile.py --types mujoco_warp/_src/types.py
8+
python contrib/kernel_analyzer/kernel_analyzer/cli.py mujoco_warp/_src/*.py --types mujoco_warp/_src/types.py
99
```
1010

1111
# CLI for github CI
1212

1313
```bash
14-
python contrib/kernel_analyzer/kernel_analyzer/cli.py --files somefile.py --types mujoco_warp/_src/types.py
14+
python contrib/kernel_analyzer/kernel_analyzer/cli.py mujoco_warp/_src/*.py --types mujoco_warp/_src/types.py --output=github
1515
```
1616

1717
# VSCode plugin
@@ -21,16 +21,17 @@ Enjoy kernel analysis directly within vscode.
2121
## Installing kernel analyzer
2222

2323
1. Create a new python env (`python3 -m venv env`) or use your existing mjwarp env (`source env/bin/activate`).
24-
2. Within the python env, install the kernel analyzer python dependencies:
24+
2. Within the python env, install the kernel analyzer's python dependencies by pip installing MJWarp dev:
2525
```bash
26-
pip install -r contrib/kernel_analyzer/kernel_analyzer/requirements.txt
26+
cd mujoco_warp
27+
pip install -e .[dev]
2728
```
2829
3. Inside vscode, navigate to `contrib/kernel_analyzer/`
2930
4. Right click on `kernel-analyzer-{version}.vsix` file
3031
5. Select "Install Extension VSIX"
3132
6. Open vscode settings and navigate to `Extensions > Kernel Analyzer`
32-
7. Set **Python Path** to the `bin/python` of the env you set up in step 1, e.g. `/home/$USER/work/mujoco_warp/env/bin/python`
33-
8. Set **Types Path** to the location of `types.py` in your checked out code, e.g. `/home/$USER/work/mujoco_warp/mujoco_warp/_src/types.py`
33+
7. Set Python Path to the `bin/python` of the env you set up in step 1, e.g. `/home/$USER/work/mujoco_warp/env/bin/python`
34+
8. Set Types Path to the location of `types.py` in your checked out code, e.g. `/home/$USER/work/mujoco_warp/mujoco_warp/_src/types.py`
3435

3536
## Plugin Development
3637

@@ -60,4 +61,4 @@ Create a debug configuration in `.vscode/launch.json`:
6061

6162
```bash
6263
npm run package
63-
```
64+
```

contrib/kernel_analyzer/client/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "kernel-analyzer-client",
33
"description": "Client-side code (TypeScript) for the Kernel Analyzer VS Code extension",
4-
"version": "0.1.0",
4+
"version": "0.2.0",
55
"private": true,
66
"license": "Apache-2.0",
77
"engines": {

0 commit comments

Comments
 (0)