@@ -107,24 +107,6 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
107107
108108 # Download train episodes
109109 if train_episodes > 0 :
110- print (f"Downloading train episodes for { task_name } ..." )
111- for episode in range (train_episodes ):
112- local_file = os .path .join (train_dir , f"episode{ episode } " , "low_dim_obs.pkl" )
113- if os .path .exists (local_file ):
114- print (f" Skipping existing train episode { episode } " )
115- continue
116- try :
117- hf_hub_download (
118- repo_id = repo_id ,
119- repo_type = "dataset" ,
120- filename = f"train/{ task_name } /variation0/episodes/episode{ episode } /low_dim_obs.pkl" ,
121- local_dir = save_dir ,
122- local_dir_use_symlinks = False ,
123- resume_download = True
124- )
125- except Exception as e :
126- print (f" Warning: Failed to download train episode { episode } : { e } " )
127- success = False
128110 # Download variation_descriptions.pkl for train
129111 variation_file = os .path .join (save_dir , "train" , task_name , "variation0" , "variation_descriptions.pkl" )
130112 if os .path .exists (variation_file ):
@@ -143,27 +125,29 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
143125 except Exception as e :
144126 print (f" Warning: Failed to download train variation_descriptions.pkl: { e } " )
145127 success = False
146-
147- # Download eval episodes
148- if eval_episodes > 0 :
149- print (f"Downloading eval episodes for { task_name } ..." )
150- for episode in range (eval_episodes ):
151- local_file = os .path .join (eval_dir , f"episode{ episode } " , "low_dim_obs.pkl" )
128+ # Download train episodes
129+ print (f"Downloading train episodes for { task_name } ..." )
130+ for episode in range (train_episodes ):
131+ local_file = os .path .join (train_dir , f"episode{ episode } " , "low_dim_obs.pkl" )
152132 if os .path .exists (local_file ):
153- print (f" Skipping existing eval episode { episode } " )
133+ print (f" Skipping existing train episode { episode } " )
154134 continue
155135 try :
156136 hf_hub_download (
157137 repo_id = repo_id ,
158138 repo_type = "dataset" ,
159- filename = f"eval /{ task_name } /variation0/episodes/episode{ episode } /low_dim_obs.pkl" ,
139+ filename = f"train /{ task_name } /variation0/episodes/episode{ episode } /low_dim_obs.pkl" ,
160140 local_dir = save_dir ,
161141 local_dir_use_symlinks = False ,
162142 resume_download = True
163143 )
164144 except Exception as e :
165- print (f" Warning: Failed to download eval episode { episode } : { e } " )
145+ print (f" Warning: Failed to download train episode { episode } : { e } " )
166146 success = False
147+
148+
149+ # Download eval episodes
150+ if eval_episodes > 0 :
167151 # Download variation_descriptions.pkl for eval
168152 variation_file = os .path .join (save_dir , "eval" , task_name , "variation0" , "variation_descriptions.pkl" )
169153 if os .path .exists (variation_file ):
@@ -182,6 +166,27 @@ def download_task_files(task_name: str, save_dir: str, train_episodes: int, eval
182166 except Exception as e :
183167 print (f" Warning: Failed to download eval variation_descriptions.pkl: { e } " )
184168 success = False
169+
170+ # Download eval episodes
171+ print (f"Downloading eval episodes for { task_name } ..." )
172+ for episode in range (eval_episodes ):
173+ local_file = os .path .join (eval_dir , f"episode{ episode } " , "low_dim_obs.pkl" )
174+ if os .path .exists (local_file ):
175+ print (f" Skipping existing eval episode { episode } " )
176+ continue
177+ try :
178+ hf_hub_download (
179+ repo_id = repo_id ,
180+ repo_type = "dataset" ,
181+ filename = f"eval/{ task_name } /variation0/episodes/episode{ episode } /low_dim_obs.pkl" ,
182+ local_dir = save_dir ,
183+ local_dir_use_symlinks = False ,
184+ resume_download = True
185+ )
186+ except Exception as e :
187+ print (f" Warning: Failed to download eval episode { episode } : { e } " )
188+ success = False
189+
185190
186191 print (f"Completed downloading task: { task_name } " )
187192 return success
0 commit comments