Skip to content

Commit bdb6688

Browse files
committed
update vis
1 parent 400df7a commit bdb6688

File tree

3 files changed

+42
-41
lines changed

3 files changed

+42
-41
lines changed

scripts/download_dataset.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,6 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
107107

108108
# Download train episodes
109109
if train_episodes > 0:
110-
print(f"Downloading train episodes for {task_name}...")
111-
for episode in range(train_episodes):
112-
local_file = os.path.join(train_dir, f"episode{episode}", "low_dim_obs.pkl")
113-
if os.path.exists(local_file):
114-
print(f" Skipping existing train episode {episode}")
115-
continue
116-
try:
117-
hf_hub_download(
118-
repo_id=repo_id,
119-
repo_type="dataset",
120-
filename=f"train/{task_name}/variation0/episodes/episode{episode}/low_dim_obs.pkl",
121-
local_dir=save_dir,
122-
local_dir_use_symlinks=False,
123-
resume_download=True
124-
)
125-
except Exception as e:
126-
print(f" Warning: Failed to download train episode {episode}: {e}")
127-
success = False
128110
# Download variation_descriptions.pkl for train
129111
variation_file = os.path.join(save_dir, "train", task_name, "variation0", "variation_descriptions.pkl")
130112
if os.path.exists(variation_file):
@@ -143,27 +125,29 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
143125
except Exception as e:
144126
print(f" Warning: Failed to download train variation_descriptions.pkl: {e}")
145127
success = False
146-
147-
# Download eval episodes
148-
if eval_episodes > 0:
149-
print(f"Downloading eval episodes for {task_name}...")
150-
for episode in range(eval_episodes):
151-
local_file = os.path.join(eval_dir, f"episode{episode}", "low_dim_obs.pkl")
128+
# Download train episodes
129+
print(f"Downloading train episodes for {task_name}...")
130+
for episode in range(train_episodes):
131+
local_file = os.path.join(train_dir, f"episode{episode}", "low_dim_obs.pkl")
152132
if os.path.exists(local_file):
153-
print(f" Skipping existing eval episode {episode}")
133+
print(f" Skipping existing train episode {episode}")
154134
continue
155135
try:
156136
hf_hub_download(
157137
repo_id=repo_id,
158138
repo_type="dataset",
159-
filename=f"eval/{task_name}/variation0/episodes/episode{episode}/low_dim_obs.pkl",
139+
filename=f"train/{task_name}/variation0/episodes/episode{episode}/low_dim_obs.pkl",
160140
local_dir=save_dir,
161141
local_dir_use_symlinks=False,
162142
resume_download=True
163143
)
164144
except Exception as e:
165-
print(f" Warning: Failed to download eval episode {episode}: {e}")
145+
print(f" Warning: Failed to download train episode {episode}: {e}")
166146
success = False
147+
148+
149+
# Download eval episodes
150+
if eval_episodes > 0:
167151
# Download variation_descriptions.pkl for eval
168152
variation_file = os.path.join(save_dir, "eval", task_name, "variation0", "variation_descriptions.pkl")
169153
if os.path.exists(variation_file):
@@ -182,6 +166,27 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
182166
except Exception as e:
183167
print(f" Warning: Failed to download eval variation_descriptions.pkl: {e}")
184168
success = False
169+
170+
# Download eval episodes
171+
print(f"Downloading eval episodes for {task_name}...")
172+
for episode in range(eval_episodes):
173+
local_file = os.path.join(eval_dir, f"episode{episode}", "low_dim_obs.pkl")
174+
if os.path.exists(local_file):
175+
print(f" Skipping existing eval episode {episode}")
176+
continue
177+
try:
178+
hf_hub_download(
179+
repo_id=repo_id,
180+
repo_type="dataset",
181+
filename=f"eval/{task_name}/variation0/episodes/episode{episode}/low_dim_obs.pkl",
182+
local_dir=save_dir,
183+
local_dir_use_symlinks=False,
184+
resume_download=True
185+
)
186+
except Exception as e:
187+
print(f" Warning: Failed to download eval episode {episode}: {e}")
188+
success = False
189+
185190

186191
print(f"Completed downloading task: {task_name}")
187192
return success

src/envs/rlbench/rlbench_env.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -578,18 +578,7 @@ def get_action_space(self, cfg):
578578

579579
def _load_demos(self, cfg,training=True):
580580
self.training = training
581-
582-
'''
583-
# The "evaluation dataset" in Huggingface omits point cloud observations to reduce storage requirements.
584-
# and the point cloud observation is neccessary for the visualization.
585-
# As a result, we use the "training dataset" here for quick project startup.
586-
# TODO: upload full eval dataset to huggingface
587-
'''
588581
dataset_root_dir = cfg.dataset_root_train
589-
# if training or cfg.debug:
590-
# dataset_root_dir = cfg.dataset_root_train
591-
# else:
592-
# dataset_root_dir = cfg.dataset_root_eval
593582

594583
obs_config = _make_obs_config(cfg)
595584
obs_config_demo = copy.deepcopy(obs_config)

src/workspace.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,16 @@ def loop(self):
155155
self.logger.log_metrics(updated_metrics, step=self._current_step)
156156
print(f"Step {self._current_step}/{num_train_steps}, Loss: {updated_metrics.get('total_loss', 0):.4f}")
157157

158-
if self._current_step % self.cfg.vis_every_steps == 0:
159-
updated_metrics = self.vis()
160-
self.logger.log_metrics(updated_metrics, step=self._current_step)
158+
'''
159+
# The dataset in Huggingface omits point cloud observations to reduce storage requirements.
160+
# and the point cloud observation is neccessary for the visualization.
161+
# As a result, the visualization is not available now or you can generate the full eval dataset manually.
162+
# TODO: upload full eval dataset to huggingface
163+
'''
164+
165+
# if self._current_step % self.cfg.vis_every_steps == 0:
166+
# updated_metrics = self.vis()
167+
# self.logger.log_metrics(updated_metrics, step=self._current_step)
161168

162169
if self._current_step % self.cfg.eval_every_steps == 0:
163170
updated_metrics = self.eval()

0 commit comments

Comments
 (0)