Skip to content

Commit 72eb150

Browse files
committed
several minor improvements
1 parent bae7981 commit 72eb150

File tree

3 files changed

+57
-62
lines changed

3 files changed

+57
-62
lines changed

examples/manipulation/behavior_cloning.py

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
from collections import deque
9+
from collections.abc import Iterator
910

1011
from torch.utils.tensorboard import SummaryWriter
1112

@@ -37,7 +38,7 @@ def __init__(self, env, cfg: dict, teacher: nn.Module, device: str = "cpu"):
3738
img_shape=rgb_shape,
3839
state_dim=self._cfg["policy"]["action_head"]["state_obs_dim"],
3940
action_dim=action_dim,
40-
device=self._device,
41+
device=device,
4142
)
4243

4344
# Training state
@@ -63,18 +64,18 @@ def learn(self, num_learning_iterations: int, log_dir: str) -> None:
6364
num_batches = 0
6465

6566
start_time = time.time()
66-
generator = self._buffer.get_batches(self._cfg.get("mini_batches_size", 32), self._cfg["num_epochs"])
67+
generator = self._buffer.get_batches(self._cfg.get("num_mini_batches", 4), self._cfg["num_epochs"])
6768
for batch in generator:
6869
# Forward pass for both action and pose prediction
69-
pred_action = self._policy(batch["rgb_obs"].float(), batch["robot_pose"].float())
70-
pred_left_pose, pred_right_pose = self._policy.predict_pose(batch["rgb_obs"].float())
70+
pred_action = self._policy(batch["rgb_obs"], batch["robot_pose"])
71+
pred_left_pose, pred_right_pose = self._policy.predict_pose(batch["rgb_obs"])
7172

7273
# Compute action prediction loss
73-
action_loss = F.mse_loss(pred_action, batch["actions"].float())
74+
action_loss = F.mse_loss(pred_action, batch["actions"])
7475

7576
# Compute pose estimation loss (position + orientation)
76-
pose_left_loss = self._compute_pose_loss(pred_left_pose, batch["object_poses"].float())
77-
pose_right_loss = self._compute_pose_loss(pred_right_pose, batch["object_poses"].float())
77+
pose_left_loss = self._compute_pose_loss(pred_left_pose, batch["object_poses"])
78+
pose_right_loss = self._compute_pose_loss(pred_right_pose, batch["object_poses"])
7879
pose_loss = pose_left_loss + pose_right_loss
7980

8081
# Combined loss with weights
@@ -227,7 +228,7 @@ def load_finetuned_model(self, path: str) -> None:
227228

228229

229230
class ExperienceBuffer:
230-
"""Experience buffer."""
231+
"""A first-in-first-out buffer for experience replay."""
231232

232233
def __init__(
233234
self,
@@ -238,20 +239,20 @@ def __init__(
238239
action_dim: int,
239240
device: str = "cpu",
240241
):
242+
self._num_envs = num_envs
243+
self._max_size = max_size
241244
self._img_shape = img_shape
242245
self._state_dim = state_dim
243246
self._action_dim = action_dim
244-
self._num_envs = num_envs
245-
self._max_size = max_size
246247
self._device = device
248+
self._ptr = 0
247249
self._size = 0
248-
self._ptr = 0 # pointer to the next free slot in the buffer
249250

250-
# Initialize buffers
251-
self._rgb_obs = torch.zeros(max_size, num_envs, *img_shape, device=device)
252-
self._robot_pose = torch.zeros(max_size, num_envs, state_dim, device=device)
253-
self._object_poses = torch.zeros(max_size, num_envs, 7, device=device)
254-
self._actions = torch.zeros(max_size, num_envs, action_dim, device=device)
251+
# Buffers for data
252+
self._rgb_obs = torch.empty(max_size, num_envs, *img_shape, dtype=torch.float32, device=device)
253+
self._robot_pose = torch.empty(max_size, num_envs, state_dim, dtype=torch.float32, device=device)
254+
self._object_poses = torch.empty(max_size, num_envs, 7, dtype=torch.float32, device=device)
255+
self._actions = torch.empty(max_size, num_envs, action_dim, dtype=torch.float32, device=device)
255256

256257
def add(
257258
self,
@@ -261,42 +262,38 @@ def add(
261262
actions: torch.Tensor,
262263
) -> None:
263264
"""Add experience to buffer."""
264-
ptr = self._ptr % self._max_size
265-
self._rgb_obs[ptr].copy_(rgb_obs)
266-
self._robot_pose[ptr].copy_(robot_pose)
267-
self._object_poses[ptr].copy_(object_poses)
268-
self._actions[ptr].copy_(actions)
269-
self._ptr = self._ptr + 1
265+
self._ptr = (self._ptr + 1) % self._max_size
266+
self._rgb_obs[self._ptr] = rgb_obs
267+
self._robot_pose[self._ptr] = robot_pose
268+
self._object_poses[self._ptr] = object_poses
269+
self._actions[self._ptr] = actions
270270
self._size = min(self._size + 1, self._max_size)
271271

272-
def get_batches(self, mini_batches_size: int, num_epochs: int):
272+
def get_batches(self, num_mini_batches: int, num_epochs: int) -> Iterator[dict[str, torch.Tensor]]:
273273
"""Generate batches for training."""
274-
buffer_size = self._size * self._num_envs
275-
indices = torch.randperm(buffer_size, device=self._device)
276274
# calculate the size of each mini-batch
277-
num_batches = min(buffer_size // mini_batches_size, 10)
275+
batch_size = self._size // num_mini_batches
278276
for _ in range(num_epochs):
279-
for batch_idx in range(num_batches):
280-
start = batch_idx * mini_batches_size
281-
end = start + mini_batches_size
282-
mb_indices = indices[start:end]
277+
indices = torch.randperm(self._size)
278+
for batch_idx in range(0, self._size, batch_size):
279+
batch_indices = indices[batch_idx : batch_idx + batch_size]
283280

284281
# Yield a mini-batch of data
285-
batch = {
286-
"rgb_obs": self._rgb_obs.view(-1, *self._img_shape)[mb_indices],
287-
"robot_pose": self._robot_pose.view(-1, self._state_dim)[mb_indices],
288-
"object_poses": self._object_poses.view(-1, 7)[mb_indices],
289-
"actions": self._actions.view(-1, self._action_dim)[mb_indices],
282+
yield {
283+
"rgb_obs": self._rgb_obs[batch_indices].view(-1, *self._img_shape),
284+
"robot_pose": self._robot_pose[batch_indices].view(-1, self._state_dim),
285+
"object_poses": self._object_poses[batch_indices].view(-1, 7),
286+
"actions": self._actions[batch_indices].view(-1, self._action_dim),
290287
}
291-
yield batch
292288

293289
def clear(self) -> None:
294290
"""Clear the buffer."""
295291
self._rgb_obs.zero_()
296292
self._robot_pose.zero_()
293+
self._object_poses.zero_()
297294
self._actions.zero_()
298-
self._size = 0
299295
self._ptr = 0
296+
self._size = 0
300297

301298
def is_full(self) -> bool:
302299
"""Check if buffer is full."""
@@ -343,9 +340,7 @@ def __init__(self, config: dict, action_dim: int):
343340
pose_mlp_cfg["output_dim"] = 7
344341
self.pose_mlp = self._build_mlp(pose_mlp_cfg)
345342

346-
# Force float32 for better performance
347-
self.float()
348-
343+
@staticmethod
349344
def _build_cnn(self, config: dict) -> nn.Sequential:
350345
"""Build CNN encoder for grayscale images."""
351346
layers = []
@@ -372,7 +367,8 @@ def _build_cnn(self, config: dict) -> nn.Sequential:
372367

373368
return nn.Sequential(*layers)
374369

375-
def _build_mlp(self, config: dict) -> nn.Sequential:
370+
@staticmethod
371+
def _build_mlp(config: dict) -> nn.Sequential:
376372
mlp_input_dim = config["input_dim"]
377373
layers = []
378374
for hidden_dim in config["hidden_dims"]:
@@ -393,11 +389,6 @@ def get_features(self, rgb_obs: torch.Tensor) -> torch.Tensor:
393389

394390
def forward(self, rgb_obs: torch.Tensor, state_obs: torch.Tensor | None = None) -> dict:
395391
"""Forward pass with shared stereo encoder for rgb images."""
396-
# Ensure float32 for better performance
397-
rgb_obs = rgb_obs.float()
398-
if state_obs is not None:
399-
state_obs = state_obs.float()
400-
401392
# Get features
402393
left_features, right_features = self.get_features(rgb_obs)
403394

@@ -417,8 +408,6 @@ def forward(self, rgb_obs: torch.Tensor, state_obs: torch.Tensor | None = None)
417408

418409
def predict_pose(self, rgb_obs: torch.Tensor) -> torch.Tensor:
419410
"""Predict pose from rgb images and state observations."""
420-
# Ensure float32 for better performance
421-
rgb_obs = rgb_obs.float()
422411
left_features, right_features = self.get_features(rgb_obs)
423412
left_pose = self.pose_mlp(left_features)
424413
right_pose = self.pose_mlp(right_features)

examples/manipulation/grasp_env.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import torch
22
import math
33
from typing import Literal
4+
45
import genesis as gs
56
from genesis.utils.geom import (
67
xyz_to_quat,
78
transform_quat_by_quat,
89
transform_by_quat,
910
)
1011

12+
MAX_DEPTH = 10.0
13+
1114

1215
class GraspEnv:
1316
def __init__(
@@ -223,9 +226,9 @@ def get_observations(self) -> tuple[torch.Tensor, dict]:
223226
#
224227
obs_components = [
225228
finger_pos - obj_pos, # 3D position difference
226-
finger_quat, # current orientation (4D quaternion)
229+
finger_quat, # current orientation (w, x, y, z)
227230
obj_pos, # goal position
228-
obj_quat, # goal orientation (4D quaternion)
231+
obj_quat, # goal orientation (w, x, y, z)
229232
]
230233
obs_tensor = torch.cat(obs_components, dim=-1)
231234
self.extras["observations"]["critic"] = obs_tensor
@@ -237,27 +240,25 @@ def rescale_action(self, action: torch.Tensor) -> torch.Tensor:
237240

238241
def get_depth_image(self, normalize: bool = True) -> torch.Tensor:
239242
# Render depth image from the camera
240-
_, depth, _, _ = self.batch_cam.render(rgb=False, depth=True)
243+
_, depth, _, _ = self.batch_cam.render(rgb=False, depth=True, segmentation=False, normal=False)
241244
depth = depth.permute(0, 3, 1, 2) # shape (B, 1, H, W)
242245
if normalize:
243-
depth = torch.clamp(depth, min=0.0, max=10)
244-
depth = (depth - 0.0) / (10.0 - 0.0) # normalize to [0, 1]
246+
depth = torch.clamp(depth, min=0.0, max=MAX_DEPTH)
247+
depth = (depth - 0.0) / (MAX_DEPTH - 0.0) # normalize to [0, 1]
245248
return depth
246249

247250
def get_stereo_rgb_images(self, normalize: bool = True) -> torch.Tensor:
248-
rgb_left, _, _, _ = self.left_cam.render(rgb=True, depth=False)
249-
rgb_right, _, _, _ = self.right_cam.render(rgb=True, depth=False)
251+
rgb_left, _, _, _ = self.left_cam.render(rgb=True, depth=False, segmentation=False, normal=False)
252+
rgb_right, _, _, _ = self.right_cam.render(rgb=True, depth=False, segmentation=False, normal=False)
250253

251254
# Convert to proper format
252255
rgb_left = rgb_left.permute(0, 3, 1, 2)[:, :3] # shape (B, 3, H, W)
253256
rgb_right = rgb_right.permute(0, 3, 1, 2)[:, :3] # shape (B, 3, H, W)
254257

255258
# Normalize if requested
256259
if normalize:
257-
rgb_left = torch.clamp(rgb_left, min=0.0, max=255.0)
258-
rgb_left = (rgb_left - 0.0) / (255.0 - 0.0)
259-
rgb_right = torch.clamp(rgb_right, min=0.0, max=255.0)
260-
rgb_right = (rgb_right - 0.0) / (255.0 - 0.0)
260+
rgb_left = torch.clamp(rgb_left, min=0.0, max=255.0) / 255.0
261+
rgb_right = torch.clamp(rgb_right, min=0.0, max=255.0) / 255.0
261262

262263
# Concatenate left and right rgb images along channel dimension
263264
# Result: [B, 6, H, W] where channel 0 is left rgb, channel 1 is right rgb

examples/manipulation/grasp_train.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_train_cfg(exp_name, max_iterations):
2929
"class_name": "PPO",
3030
"clip_param": 0.2,
3131
"desired_kl": 0.01,
32-
"entropy_coef": 0.00,
32+
"entropy_coef": 0.0,
3333
"gamma": 0.99,
3434
"lam": 0.95,
3535
"learning_rate": 0.0003,
@@ -72,7 +72,7 @@ def get_train_cfg(exp_name, max_iterations):
7272
"num_steps_per_env": 24,
7373
"learning_rate": 0.001,
7474
"num_epochs": 5,
75-
"mini_batches_size": 512,
75+
"num_mini_batches": 10,
7676
"max_grad_norm": 1.0,
7777
# Network architecture
7878
"policy": {
@@ -210,5 +210,10 @@ def main():
210210

211211
"""
212212
# training
213-
python examples/manipulation/grasp_train.py
213+
214+
# to train the RL policy
215+
python examples/manipulation/grasp_train.py --stage=rl
216+
217+
# to train the BC policy (requires RL policy to be trained first)
218+
python examples/manipulation/grasp_train.py --stage=bc
214219
"""

0 commit comments

Comments (0)