Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions examples/tutorial/rl/hilserl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from queue import Empty, Full

import torch
import torch.optim as optim

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
Expand Down Expand Up @@ -40,8 +40,9 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)

# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers()

print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
Expand Down Expand Up @@ -83,24 +84,26 @@ def run_learner(
else:
batch[key] = online_batch[key]

loss, _ = policy_learner.forward(batch)
def batch_iter(b=batch):
while True:
yield b

optimizer.zero_grad()
loss.backward()
optimizer.step()
stats = algorithm.update(batch_iter())
training_step += 1

if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)

# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
Expand Down Expand Up @@ -144,15 +147,15 @@ def run_actor(

while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass

# Get action from policy
# Get action from policy (returns full action: continuous + discrete)
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action_tensor = policy_actor.select_action(policy_obs)
action = action_tensor.squeeze(0).cpu().numpy()

# Step environment
Expand Down
12 changes: 12 additions & 0 deletions src/lerobot/configs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,15 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional

# Algorithm name registered in RLAlgorithmConfig registry
algorithm: str = "sac"

# Data mixer strategy name. Currently supports "online_offline"
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer
online_ratio: float = 0.5

# RL trainer iterator
async_prefetch: bool = True
queue_size: int = 2
Loading
Loading