In most cases, we recommend using PyTorch methods ``state_dict()`` and ``load_state_dict()`` from the policy,
unless you need to access the optimizers' state dict too. In that case, you need to call ``get_parameters()``.


SBX (SB3 + Jax) Export
----------------------

As an example of manual export, :ref:`Stable Baselines Jax (SBX) <sbx>` policies can be exported to ONNX
by using an intermediate PyTorch representation, as shown in the following example:
405+
.. code-block:: python

    import numpy as np
    import sbx
    import torch as th


    class TorchPolicy(th.nn.Module):
        def __init__(self, obs_dim: int, hidden_dim: int, act_dim: int):
            super().__init__()
            self.net = th.nn.Sequential(
                th.nn.Linear(obs_dim, hidden_dim),
                th.nn.Tanh(),
                th.nn.Linear(hidden_dim, hidden_dim),
                th.nn.Tanh(),
                th.nn.Linear(hidden_dim, act_dim),
            )

        def forward(self, x: th.Tensor) -> th.Tensor:
            return self.net(x)


    model = sbx.PPO("MlpPolicy", "Pendulum-v1")
    # Also possible: load a trained model
    # model = sbx.PPO.load("PathToTrainedModel.zip")

    params = model.policy.actor_state.params["params"]
    # For debug:
    print("=== SBX params ===")
    for key, value in params.items():
        if isinstance(value, dict):
            for name, val in value.items():
                print(f"{key}.{name}: {val.shape}", end=" ")
        else:
            print(f"{key}: {value.shape}", end=" ")
    print("\n" + "=" * 20 + "\n")

    obs_dim = model.observation_space.shape
    act_dim = model.action_space.shape

    # Number of units in the hidden layers (assume a network architecture like [64, 64])
    hidden_dim = params["Dense_0"]["kernel"].shape[1]

    # Map params to torch state_dict keys
    num_layers = len([k for k in params.keys() if k.startswith("Dense_")])
    state_dict = {}
    for i in range(num_layers):
        layer_name = f"Dense_{i}"
        state_dict[f"net.{i * 2}.bias"] = th.from_numpy(np.array(params[layer_name]["bias"]))
        state_dict[f"net.{i * 2}.weight"] = th.from_numpy(np.array(params[layer_name]["kernel"].T))

    torch_policy = TorchPolicy(obs_dim[0], hidden_dim, act_dim[0])
    print("=== Torch params ===")
    print(" ".join(f"{key}: {tuple(value.shape)}" for key, value in torch_policy.named_parameters()))
    print("=" * 20 + "\n")

    torch_policy.load_state_dict(state_dict)
    torch_policy.eval()

    dummy_input = th.zeros((1, *obs_dim))
    # Use normal Torch export
    th.onnx.export(
        torch_policy,
        (dummy_input,),
        "my_ppo_actor.onnx",
        opset_version=18,
        input_names=["input"],
        output_names=["action"],
    )


    # #### Load and test with onnx

    import onnxruntime as ort

    onnx_path = "my_ppo_actor.onnx"
    ort_sess = ort.InferenceSession(onnx_path)

    observation = np.random.random((1, *obs_dim)).astype(np.float32)
    action = ort_sess.run(None, {"input": observation})[0]

    print(action)
    sbx_action, _ = model.predict(observation, deterministic=True)
    with th.no_grad():
        torch_action = torch_policy(th.as_tensor(observation))

    # Check that the predictions are the same
    assert np.allclose(sbx_action, action)
    assert np.allclose(sbx_action, torch_action.numpy())