Skip to content

Commit 7af7aa8

Browse files
Nemantor authored and Mayankm96 committed
Fixes RSL-RL ONNX exporter for empirical normalization (#78)
The current onnx exporter does not export the empirical normalization layer. This MR adds the empirical normalization exporting to the JIT and ONNX exporters for RSL-RL. - Bug fix (non-breaking change which fixes an issue) - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./orbit.sh --format` - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [x] I have run all the tests with `./orbit.sh --test` and they pass (some did timeout) - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [ ] I have added my name to the `CONTRIBUTORS.md` or my name already exists there --------- Co-authored-by: Mayank Mittal <[email protected]>
1 parent 90f6fb1 commit 7af7aa8

File tree

4 files changed

+46
-12
lines changed

4 files changed

+46
-12
lines changed

source/extensions/omni.isaac.orbit_tasks/config/extension.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
# Description
77
title = "ORBIT Environments"

source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
Changelog
22
---------
33

4+
0.6.2 (2024-05-31)
5+
~~~~~~~~~~~~~~~~~~
6+
7+
Added
8+
^^^^^
9+
10+
* Added exporting of empirical normalization layer to ONNX and JIT when exporting the model using
11+
the :meth:`export_policy_as_jit` and :meth:`export_policy_as_onnx` functions in :mod:`omni.isaac.orbit_tasks.utils.wrappers.rsl_rl`. Previously, the normalization layer
12+
was not exported to the ONNX and JIT models. This caused the exported model to not work properly
13+
when used for inference.
14+
15+
416
0.6.1 (2024-04-16)
517
~~~~~~~~~~~~~~~~~~
618

source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/exporter.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,37 @@
88
import torch
99

1010

11-
def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"):
11+
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
1212
"""Export policy into a Torch JIT file.
1313
1414
Args:
1515
actor_critic: The actor-critic torch module.
16+
normalizer: The empirical normalizer module. If None, Identity is used.
1617
path: The path to the saving directory.
1718
filename: The name of exported JIT file. Defaults to "policy.pt".
1819
1920
Reference:
2021
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
2122
"""
22-
policy_exporter = _TorchPolicyExporter(actor_critic)
23+
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
2324
policy_exporter.export(path, filename)
2425

2526

26-
def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False):
27+
def export_policy_as_onnx(
28+
actor_critic: object, normalizer: object | None, path: str, filename="policy.onnx", verbose=False
29+
):
2730
"""Export policy into a Torch ONNX file.
2831
2932
Args:
3033
actor_critic: The actor-critic torch module.
34+
normalizer: The empirical normalizer module. If None, Identity is used.
3135
path: The path to the saving directory.
32-
filename: The name of exported JIT file. Defaults to "policy.pt".
36+
filename: The name of exported ONNX file. Defaults to "policy.onnx".
3337
verbose: Whether to print the model summary. Defaults to False.
3438
"""
3539
if not os.path.exists(path):
3640
os.makedirs(path, exist_ok=True)
37-
policy_exporter = _OnnxPolicyExporter(actor_critic, verbose)
41+
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
3842
policy_exporter.export(path, filename)
3943

4044

@@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module):
5054
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193
5155
"""
5256

53-
def __init__(self, actor_critic):
57+
def __init__(self, actor_critic, normalizer=None):
5458
super().__init__()
5559
self.actor = copy.deepcopy(actor_critic.actor)
5660
self.is_recurrent = actor_critic.is_recurrent
@@ -61,16 +65,22 @@ def __init__(self, actor_critic):
6165
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
6266
self.forward = self.forward_lstm
6367
self.reset = self.reset_memory
68+
# copy normalizer if exists
69+
if normalizer:
70+
self.normalizer = copy.deepcopy(normalizer)
71+
else:
72+
self.normalizer = torch.nn.Identity()
6473

6574
def forward_lstm(self, x):
75+
x = self.normalizer(x)
6676
x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
6777
self.hidden_state[:] = h
6878
self.cell_state[:] = c
6979
x = x.squeeze(0)
7080
return self.actor(x)
7181

7282
def forward(self, x):
73-
return self.actor(x)
83+
return self.actor(self.normalizer(x))
7484

7585
@torch.jit.export
7686
def reset(self):
@@ -91,7 +101,7 @@ def export(self, path, filename):
91101
class _OnnxPolicyExporter(torch.nn.Module):
92102
"""Exporter of actor-critic into ONNX file."""
93103

94-
def __init__(self, actor_critic, verbose=False):
104+
def __init__(self, actor_critic, normalizer=None, verbose=False):
95105
super().__init__()
96106
self.verbose = verbose
97107
self.actor = copy.deepcopy(actor_critic.actor)
@@ -100,14 +110,20 @@ def __init__(self, actor_critic, verbose=False):
100110
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
101111
self.rnn.cpu()
102112
self.forward = self.forward_lstm
113+
# copy normalizer if exists
114+
if normalizer:
115+
self.normalizer = copy.deepcopy(normalizer)
116+
else:
117+
self.normalizer = torch.nn.Identity()
103118

104119
def forward_lstm(self, x_in, h_in, c_in):
120+
x_in = self.normalizer(x_in)
105121
x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
106122
x = x.squeeze(0)
107123
return self.actor(x), h, c
108124

109125
def forward(self, x):
110-
return self.actor(x)
126+
return self.actor(self.normalizer(x))
111127

112128
def export(self, path, filename):
113129
self.to("cpu")

source/standalone/workflows/rsl_rl/play.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
4747
RslRlOnPolicyRunnerCfg,
4848
RslRlVecEnvWrapper,
49+
export_policy_as_jit,
4950
export_policy_as_onnx,
5051
)
5152

@@ -78,9 +79,14 @@ def main():
7879
# obtain the trained policy for inference
7980
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
8081

81-
# export policy to onnx
82+
# export policy to onnx/jit
8283
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
83-
export_policy_as_onnx(ppo_runner.alg.actor_critic, export_model_dir, filename="policy.onnx")
84+
export_policy_as_jit(
85+
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt"
86+
)
87+
export_policy_as_onnx(
88+
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
89+
)
8490

8591
# reset environment
8692
obs, _ = env.get_observations()

0 commit comments

Comments (0)