Commit 8ca546f

m-abr, Miguel, and araffin authored
Docs: update ONNX export for SBX (#2214)
* Docs: update ONNX export for SBX
* Update doc requirements
* Update example to add debug info
* Update changelog
* Update example

  Updated example code for loading and saving a PPO model.

---------

Co-authored-by: Miguel <miguel.abreu@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
1 parent 9c7b0f2 commit 8ca546f

4 files changed

+103
-5
lines changed

docs/conda_env.yml

Lines changed: 2 additions & 2 deletions
@@ -14,6 +14,6 @@ dependencies:
   - pandas
   - numpy>=1.20,<3.0
   - matplotlib
-  - sphinx>=5,<9
-  - sphinx_rtd_theme>=1.3.0
+  - sphinx>=5,<10
+  - sphinx_rtd_theme>=3.0
   - sphinx_copybutton

docs/guide/export.rst

Lines changed: 97 additions & 0 deletions
@@ -395,3 +395,100 @@ in their respective folders.
 
 In most cases, we recommend using PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy,
 unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``.
+
+
+SBX (SB3 + Jax) Export
+----------------------
+
+As an example of manual export, :ref:`Stable Baselines Jax (SBX) <sbx>` policies can be exported to ONNX
+by using an intermediate PyTorch representation, as shown in the following example:
+
+.. code-block:: python
+
+  import numpy as np
+  import sbx
+  import torch as th
+
+
+  class TorchPolicy(th.nn.Module):
+      def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int):
+          super().__init__()
+          self.net = th.nn.Sequential(
+              th.nn.Linear(obs_dim, hidden_dim),
+              th.nn.Tanh(),
+              th.nn.Linear(hidden_dim, hidden_dim),
+              th.nn.Tanh(),
+              th.nn.Linear(hidden_dim, act_dim),
+          )
+
+      def forward(self, x: th.Tensor) -> th.Tensor:
+          return self.net(x)
+
+
+  model = sbx.PPO("MlpPolicy", "Pendulum-v1")
+  # Also possible: load a trained model
+  # model = sbx.PPO.load("PathToTrainedModel.zip")
+
+  params = model.policy.actor_state.params["params"]
+  # For debug:
+  print("=== SBX params ===")
+  for key, value in params.items():
+      if isinstance(value, dict):
+          for name, val in value.items():
+              print(f"{key}.{name}: {val.shape}", end=" ")
+      else:
+          print(f"{key}: {value.shape}", end=" ")
+  print("\n" + "=" * 20 + "\n")
+
+  obs_dim = model.observation_space.shape
+  act_dim = model.action_space.shape
+
+  # Number of units in the hidden layers (assume a network architecture like [64, 64])
+  hidden_dim = params["Dense_0"]["kernel"].shape[1]
+
+  # map params to torch state_dict keys
+  num_layers = len([k for k in params.keys() if k.startswith("Dense_")])
+  state_dict = {}
+  for i in range(num_layers):
+      layer_name = f"Dense_{i}"
+      state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"]))
+      state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T))
+
+  torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0])
+  print("=== Torch params ===")
+  print(" ".join(f"{key}:{tuple(value.shape)}" for key, value in torch_policy.named_parameters()))
+  print("=" * 20 + "\n")
+
+  torch_policy.load_state_dict(state_dict)
+  torch_policy.eval()
+
+  dummy_input = th.zeros((1, *obs_dim))
+  # Use normal Torch export
+  th.onnx.export(
+      torch_policy,
+      (dummy_input,),
+      "my_ppo_actor.onnx",
+      opset_version=18,
+      input_names=["input"],
+      output_names=["action"],
+  )
+
+
+  ##### Load and test with onnx
+
+  import onnxruntime as ort
+
+  onnx_path = "my_ppo_actor.onnx"
+  ort_sess = ort.InferenceSession(onnx_path)
+
+  observation = np.random.random((1, *obs_dim)).astype(np.float32)
+  action = ort_sess.run(None, {"input": observation})[0]
+
+  print(action)
+  sbx_action, _ = model.predict(observation, deterministic=True)
+  with th.no_grad():
+      torch_action = torch_policy(th.as_tensor(observation))
+
+  # Check that the predictions are the same
+  assert np.allclose(sbx_action, action)
+  assert np.allclose(sbx_action, torch_action.numpy())
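
Note on the ``export.rst`` example above: the state-dict mapping relies on two conventions. A Flax ``Dense`` layer stores its kernel with shape ``(in_features, out_features)``, whereas ``torch.nn.Linear`` stores its weight with shape ``(out_features, in_features)``, which is why the kernel is transposed with ``.T``. The ``i * 2`` index skips the ``Tanh`` modules that sit between the ``Linear`` layers inside ``th.nn.Sequential``. The following is a minimal plain-Python sketch of both points, written so that it runs without JAX or PyTorch installed. It is not part of the commit, and the helper names (``flax_dense``, ``torch_linear``, ``transpose``, ``torch_keys``) are hypothetical, chosen only for illustration.

```python
# Illustrative sketch only: plain Python lists stand in for JAX/PyTorch arrays.
# All function names below are hypothetical and do not exist in SBX or PyTorch.


def flax_dense(x, kernel, bias):
    """Flax-style dense layer: kernel has shape (in_features, out_features)."""
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def torch_linear(x, weight, bias):
    """PyTorch-style linear layer: weight has shape (out_features, in_features)."""
    return [
        sum(weight[j][i] * x[i] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(row) for row in zip(*matrix)]


def torch_keys(num_layers):
    """Map Dense_i to its Sequential index, skipping the activation after each layer."""
    return {
        f"Dense_{i}": (f"net.{i * 2}.weight", f"net.{i * 2}.bias")
        for i in range(num_layers)
    }


kernel = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # Flax layout: (in=2, out=3)
bias = [0.5, -0.5, 0.0]
x = [1.0, -1.0]

flax_out = flax_dense(x, kernel, bias)
# The transposed kernel is what gets loaded as the PyTorch weight.
torch_out = torch_linear(x, transpose(kernel), bias)

print(flax_out, torch_out)
print(torch_keys(3))
```

Both layers produce the same output once the kernel is transposed. For a network with a different number of hidden layers or a different activation layout, the ``i * 2`` rule would need to be adapted to match the actual ``Sequential`` indices.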

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
@@ -47,7 +47,7 @@ Others:
 Documentation:
 ^^^^^^^^^^^^^^
 - Added a note on MultiDiscrete spaces with multi-dimensional arrays and a wrapper to fix the issue (@unexploredtest)
-
+- Added an example of manual export of SBX (SB3 + Jax) model to ONNX (@m-abr)
 
 Release 2.7.1 (2025-12-05)
 --------------------------
@@ -1949,3 +1949,4 @@ And all the contributors:
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
 @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
 @brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti @unexploredtest
+@m-abr

setup.py

Lines changed: 2 additions & 2 deletions
@@ -101,9 +101,9 @@
             "black>=26.1.0,<27",
         ],
         "docs": [
-            "sphinx>=5,<9",
+            "sphinx>=5,<10",
             "sphinx-autobuild",
-            "sphinx-rtd-theme>=1.3.0",
+            "sphinx-rtd-theme>=3.0.0",
             # For spelling
             "sphinxcontrib.spelling",
             # Copy button for code snippets
