Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions examples/baselines/diffusion_policy/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,17 @@ def __getitem__(self, index):

obs_seq = self.trajectories['observations'][traj_idx][max(0, start):start+self.obs_horizon]
# start+self.obs_horizon is at least 1
act_seq = self.trajectories['actions'][traj_idx][max(0, start):end]
start_act = start + self.obs_horizon - 1 # start first action with the last obs
end_act = start_act + self.pred_horizon
act_seq = self.trajectories['actions'][traj_idx][max(0, start_act):min(end_act, L)]
if start < 0: # pad before the trajectory
obs_seq = torch.cat([obs_seq[0].repeat(-start, 1), obs_seq], dim=0)
act_seq = torch.cat([act_seq[0].repeat(-start, 1), act_seq], dim=0)
if end > L: # pad after the trajectory
if start_act < 0:
act_seq = torch.cat([act_seq[0].repeat(-start_act, 1), act_seq], dim=0)
if end_act > L: # pad after the trajectory
gripper_action = act_seq[-1, -1]
pad_action = torch.cat((self.pad_action_arm, gripper_action[None]), dim=0)
act_seq = torch.cat([act_seq, pad_action.repeat(end-L, 1)], dim=0)
act_seq = torch.cat([act_seq, pad_action.repeat(end_act-L, 1)], dim=0)
# making the robot (arm and gripper) stay still
assert obs_seq.shape[0] == self.obs_horizon and act_seq.shape[0] == self.pred_horizon
return {
Expand Down
15 changes: 8 additions & 7 deletions examples/baselines/diffusion_policy/train_rgbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,15 @@ def __getitem__(self, index):
pad_obs_seq = torch.stack([obs_seq[k][0]] * abs(start), dim=0)
obs_seq[k] = torch.cat((pad_obs_seq, obs_seq[k]), dim=0)
# don't need to pad obs after the trajectory, see the above char drawing

act_seq = self.trajectories["actions"][traj_idx][max(0, start) : end]
if start < 0: # pad before the trajectory
act_seq = torch.cat([act_seq[0].repeat(-start, 1), act_seq], dim=0)
if end > L: # pad after the trajectory
gripper_action = act_seq[-1, -1] # assume gripper is with pos controller
start_act = start + self.obs_horizon - 1 # start first action with the last obs
end_act = start_act + self.pred_horizon
act_seq = self.trajectories['actions'][traj_idx][max(0, start_act):min(end_act, L)]
if start_act < 0:
act_seq = torch.cat([act_seq[0].repeat(-start_act, 1), act_seq], dim=0)
if end_act > L: # pad after the trajectory
gripper_action = act_seq[-1, -1]
pad_action = torch.cat((self.pad_action_arm, gripper_action[None]), dim=0)
act_seq = torch.cat([act_seq, pad_action.repeat(end - L, 1)], dim=0)
act_seq = torch.cat([act_seq, pad_action.repeat(end_act-L, 1)], dim=0)
# making the robot (arm and gripper) stay still
assert (
obs_seq["state"].shape[0] == self.obs_horizon
Expand Down