66
77import numpy as np
88from collections import deque
9+ from collections .abc import Iterator
910
1011from torch .utils .tensorboard import SummaryWriter
1112
@@ -37,7 +38,7 @@ def __init__(self, env, cfg: dict, teacher: nn.Module, device: str = "cpu"):
3738 img_shape = rgb_shape ,
3839 state_dim = self ._cfg ["policy" ]["action_head" ]["state_obs_dim" ],
3940 action_dim = action_dim ,
40- device = self . _device ,
41+ device = device ,
4142 )
4243
4344 # Training state
@@ -63,18 +64,18 @@ def learn(self, num_learning_iterations: int, log_dir: str) -> None:
6364 num_batches = 0
6465
6566 start_time = time .time ()
66- generator = self ._buffer .get_batches (self ._cfg .get ("mini_batches_size " , 32 ), self ._cfg ["num_epochs" ])
67+ generator = self ._buffer .get_batches (self ._cfg .get ("num_mini_batches " , 4 ), self ._cfg ["num_epochs" ])
6768 for batch in generator :
6869 # Forward pass for both action and pose prediction
69- pred_action = self ._policy (batch ["rgb_obs" ]. float () , batch ["robot_pose" ]. float () )
70- pred_left_pose , pred_right_pose = self ._policy .predict_pose (batch ["rgb_obs" ]. float () )
70+ pred_action = self ._policy (batch ["rgb_obs" ], batch ["robot_pose" ])
71+ pred_left_pose , pred_right_pose = self ._policy .predict_pose (batch ["rgb_obs" ])
7172
7273 # Compute action prediction loss
73- action_loss = F .mse_loss (pred_action , batch ["actions" ]. float () )
74+ action_loss = F .mse_loss (pred_action , batch ["actions" ])
7475
7576 # Compute pose estimation loss (position + orientation)
76- pose_left_loss = self ._compute_pose_loss (pred_left_pose , batch ["object_poses" ]. float () )
77- pose_right_loss = self ._compute_pose_loss (pred_right_pose , batch ["object_poses" ]. float () )
77+ pose_left_loss = self ._compute_pose_loss (pred_left_pose , batch ["object_poses" ])
78+ pose_right_loss = self ._compute_pose_loss (pred_right_pose , batch ["object_poses" ])
7879 pose_loss = pose_left_loss + pose_right_loss
7980
8081 # Combined loss with weights
@@ -227,7 +228,7 @@ def load_finetuned_model(self, path: str) -> None:
227228
228229
229230class ExperienceBuffer :
230- """Experience buffer."""
231+ """A first-in-first-out buffer for experience replay ."""
231232
232233 def __init__ (
233234 self ,
@@ -238,20 +239,20 @@ def __init__(
238239 action_dim : int ,
239240 device : str = "cpu" ,
240241 ):
242+ self ._num_envs = num_envs
243+ self ._max_size = max_size
241244 self ._img_shape = img_shape
242245 self ._state_dim = state_dim
243246 self ._action_dim = action_dim
244- self ._num_envs = num_envs
245- self ._max_size = max_size
246247 self ._device = device
248+ self ._ptr = 0
247249 self ._size = 0
248- self ._ptr = 0 # pointer to the next free slot in the buffer
249250
250- # Initialize buffers
251- self ._rgb_obs = torch .zeros (max_size , num_envs , * img_shape , device = device )
252- self ._robot_pose = torch .zeros (max_size , num_envs , state_dim , device = device )
253- self ._object_poses = torch .zeros (max_size , num_envs , 7 , device = device )
254- self ._actions = torch .zeros (max_size , num_envs , action_dim , device = device )
251+ # Buffers for data
252+ self ._rgb_obs = torch .empty (max_size , num_envs , * img_shape , dtype = torch . float32 , device = device )
253+ self ._robot_pose = torch .empty (max_size , num_envs , state_dim , dtype = torch . float32 , device = device )
254+ self ._object_poses = torch .empty (max_size , num_envs , 7 , dtype = torch . float32 , device = device )
255+ self ._actions = torch .empty (max_size , num_envs , action_dim , dtype = torch . float32 , device = device )
255256
256257 def add (
257258 self ,
@@ -261,42 +262,38 @@ def add(
261262 actions : torch .Tensor ,
262263 ) -> None :
263264 """Add experience to buffer."""
264- ptr = self ._ptr % self ._max_size
265- self ._rgb_obs [ptr ].copy_ (rgb_obs )
266- self ._robot_pose [ptr ].copy_ (robot_pose )
267- self ._object_poses [ptr ].copy_ (object_poses )
268- self ._actions [ptr ].copy_ (actions )
269- self ._ptr = self ._ptr + 1
265+ self ._rgb_obs [self ._ptr ] = rgb_obs
266+ self ._robot_pose [self ._ptr ] = robot_pose
267+ self ._object_poses [self ._ptr ] = object_poses
268+ self ._actions [self ._ptr ] = actions
269+ self ._ptr = (self ._ptr + 1 ) % self ._max_size  # advance only after writing: _ptr is the next free slot, so slots [0, _size) always hold real data
270270 self ._size = min (self ._size + 1 , self ._max_size )
271271
272- def get_batches (self , mini_batches_size : int , num_epochs : int ):
272+ def get_batches (self , num_mini_batches : int , num_epochs : int ) -> Iterator [ dict [ str , torch . Tensor ]] :
273- """Generate batches for training."""
273+ """Yield shuffled mini-batches of stored steps, reshuffled every epoch (about num_mini_batches per epoch)."""
274- buffer_size = self ._size * self ._num_envs
275- indices = torch .randperm (buffer_size , device = self ._device )
276274 # calculate the size of each mini-batch
277- num_batches = min ( buffer_size // mini_batches_size , 10 )
275+ batch_size = max(1, self._size // num_mini_batches)  # never 0: range() raises ValueError on a zero step when _size < num_mini_batches
278276 for _ in range (num_epochs ):
279- for batch_idx in range (num_batches ):
280- start = batch_idx * mini_batches_size
281- end = start + mini_batches_size
282- mb_indices = indices [start :end ]
277+ indices = torch .randperm (self ._size )
278+ for batch_idx in range (0 , self ._size , batch_size ):
279+ batch_indices = indices [batch_idx : batch_idx + batch_size ]
283280
284281 # Yield a mini-batch of data
285- batch = {
286- "rgb_obs" : self ._rgb_obs .view (- 1 , * self ._img_shape )[ mb_indices ] ,
287- "robot_pose" : self ._robot_pose .view (- 1 , self ._state_dim )[ mb_indices ] ,
288- "object_poses" : self ._object_poses .view (- 1 , 7 )[ mb_indices ] ,
289- "actions" : self ._actions .view (- 1 , self ._action_dim )[ mb_indices ] ,
282+ yield {
283+ "rgb_obs" : self ._rgb_obs [ batch_indices ] .view (- 1 , * self ._img_shape ),
284+ "robot_pose" : self ._robot_pose [ batch_indices ] .view (- 1 , self ._state_dim ),
285+ "object_poses" : self ._object_poses [ batch_indices ] .view (- 1 , 7 ),
286+ "actions" : self ._actions [ batch_indices ] .view (- 1 , self ._action_dim ),
290287 }
291- yield batch
292288
293289 def clear (self ) -> None :
294290 """Clear the buffer."""
295291 self ._rgb_obs .zero_ ()
296292 self ._robot_pose .zero_ ()
293+ self ._object_poses .zero_ ()
297294 self ._actions .zero_ ()
298- self ._size = 0
299295 self ._ptr = 0
296+ self ._size = 0
300297
301298 def is_full (self ) -> bool :
302299 """Check if buffer is full."""
@@ -343,9 +340,7 @@ def __init__(self, config: dict, action_dim: int):
343340 pose_mlp_cfg ["output_dim" ] = 7
344341 self .pose_mlp = self ._build_mlp (pose_mlp_cfg )
345342
346- # Force float32 for better performance
347- self .float ()
348-
343+ @staticmethod
349- def _build_cnn (self , config : dict ) -> nn .Sequential :
344+ def _build_cnn(config: dict) -> nn.Sequential:  # no `self`: a staticmethod receives only the explicit arguments
350345 """Build CNN encoder for grayscale images."""
351346 layers = []
@@ -372,7 +367,8 @@ def _build_cnn(self, config: dict) -> nn.Sequential:
372367
373368 return nn .Sequential (* layers )
374369
375- def _build_mlp (self , config : dict ) -> nn .Sequential :
370+ @staticmethod
371+ def _build_mlp (config : dict ) -> nn .Sequential :
376372 mlp_input_dim = config ["input_dim" ]
377373 layers = []
378374 for hidden_dim in config ["hidden_dims" ]:
@@ -393,11 +389,6 @@ def get_features(self, rgb_obs: torch.Tensor) -> torch.Tensor:
393389
394390 def forward (self , rgb_obs : torch .Tensor , state_obs : torch .Tensor | None = None ) -> dict :
395391 """Forward pass with shared stereo encoder for rgb images."""
396- # Ensure float32 for better performance
397- rgb_obs = rgb_obs .float ()
398- if state_obs is not None :
399- state_obs = state_obs .float ()
400-
401392 # Get features
402393 left_features , right_features = self .get_features (rgb_obs )
403394
@@ -417,8 +408,6 @@ def forward(self, rgb_obs: torch.Tensor, state_obs: torch.Tensor | None = None)
417408
418409 def predict_pose (self , rgb_obs : torch .Tensor ) -> torch .Tensor :
419410 """Predict pose from rgb images and state observations."""
420- # Ensure float32 for better performance
421- rgb_obs = rgb_obs .float ()
422411 left_features , right_features = self .get_features (rgb_obs )
423412 left_pose = self .pose_mlp (left_features )
424413 right_pose = self .pose_mlp (right_features )
0 commit comments