1111import amago
1212from amago .envs import AMAGOEnv
1313from amago import cli_utils
14- from amago .loading import RLData , RLDataset
14+ from amago .loading import RLData , RLDataset , DiskTrajDataset , MixtureOfDatasets
1515from amago .nets .policy_dists import TanhGaussian , GMM , Beta
1616from amago .nets .actor_critic import ResidualActor , Actor
17+ from amago .agent import binary_filter , exp_filter
1718
1819
1920def add_cli (parser ):
@@ -26,7 +27,7 @@ def add_cli(parser):
2627 parser .add_argument (
2728 "--policy_dist" ,
2829 type = str ,
29- default = "TanhGaussian " ,
30+ default = "Beta " ,
3031 help = "Policy distribution type" ,
3132 choices = ["TanhGaussian" , "GMM" , "Beta" ],
3233 )
@@ -37,6 +38,12 @@ def add_cli(parser):
3738 help = "Actor head type" ,
3839 choices = ["ResidualActor" , "Actor" ],
3940 )
41+ parser .add_argument (
42+ "--online_after_epoch" ,
43+ type = int ,
44+ default = float ("inf" ),
45+ help = "Number of epochs after which to start collecting online data" ,
46+ )
4047 parser .add_argument (
4148 "--eval_timesteps" ,
4249 type = int ,
@@ -83,7 +90,6 @@ def _sample_trajectory(self, episode_idx: int):
8390 rewards = torch .from_numpy (rewards_np ).float ().unsqueeze (- 1 )
8491 time_idxs = torch .arange (traj_len ).unsqueeze (- 1 ).long ()
8592 dones = torch .from_numpy (terminals_np ).bool ().unsqueeze (- 1 )
86-
8793 return RLData (
8894 obs = obs ,
8995 actions = actions ,
@@ -128,12 +134,14 @@ def reset(self, *args, **kwargs):
128134
129135 def step (self , action ):
130136 s , r , d , i = self .env .step (action )
137+ truncated = i .get ("TimeLimit.truncated" , False )
138+ terminated = d and not truncated
131139 self .episode_return += r
132- if d :
140+ if terminated or truncated :
133141 i [f"{ AMAGO_ENV_LOG_PREFIX } D4RL Normalized Return" ] = (
134142 d4rl .get_normalized_score (self .env_name , self .episode_return )
135143 )
136- return s , r , d , d , i
144+ return s , r , terminated , truncated , i
137145
138146
139147if __name__ == "__main__" :
@@ -148,13 +156,9 @@ def step(self, action):
148156 assert isinstance (
149157 example_env .action_space , gym .spaces .Box
150158 ), "Only supports continuous action spaces"
151- if args .timesteps_per_epoch > 0 :
152- print ("WARNING: timesteps_per_epoch is not supported for D4RL, setting to 0" )
153- args .timesteps_per_epoch = 0
154159
155- # create dataset
156- dataset = D4RLDataset (d4rl_dset = example_env .dset )
157160 args .eval_timesteps = example_env .time_limit + 1
161+ args .timesteps_per_epoch = example_env .time_limit
158162
159163 # setup environment
160164 make_train_env = lambda : AMAGOEnv (
@@ -166,8 +170,8 @@ def step(self, action):
166170 # agent architecture: drop everything down to standard small sizes
167171 config = {
168172 "amago.nets.actor_critic.NCritics.d_hidden" : 128 ,
169- "amago.nets.actor_critic.NCriticsTwoHot.d_hidden" : 256 ,
170- "amago.nets.actor_critic.NCriticsTwoHot.output_bins" : 128 ,
173+ "amago.nets.actor_critic.NCriticsTwoHot.d_hidden" : 128 ,
174+ "amago.nets.actor_critic.NCriticsTwoHot.output_bins" : 64 ,
171175 "amago.nets.actor_critic.Actor.d_hidden" : 128 ,
172176 "amago.nets.actor_critic.Actor.continuous_dist_type" : eval (args .policy_dist ),
173177 "amago.nets.actor_critic.ResidualActor.feature_dim" : 128 ,
@@ -184,6 +188,13 @@ def step(self, action):
184188 d_output = 128 ,
185189 n_layers = 1 ,
186190 )
191+ exploration_wrapper_type = cli_utils .switch_exploration (
192+ config ,
193+ strategy = "egreedy" ,
194+ eps_start = 0.05 ,
195+ eps_end = 0.01 ,
196+ steps_anneal = 15_000 ,
197+ )
187198 traj_encoder_type = cli_utils .switch_traj_encoder (
188199 config ,
189200 arch = args .traj_encoder ,
@@ -195,18 +206,37 @@ def step(self, action):
195206 args .agent_type ,
196207 online_coeff = 0.0 ,
197208 offline_coeff = 1.0 ,
198- gamma = 0.995 ,
209+ gamma = 0.997 ,
199210 reward_multiplier = 100.0 if example_env .max_return <= 10.0 else 1 ,
200- num_actions_for_value_in_critic_loss = 2 ,
201- num_actions_for_value_in_actor_loss = 4 ,
202- num_critics = 4 ,
211+ num_actions_for_value_in_critic_loss = 3 ,
212+ num_actions_for_value_in_actor_loss = 5 ,
213+ num_critics = 5 ,
203214 actor_type = eval (args .actor_type ),
215+ fbc_filter_func = exp_filter ,
204216 )
205217 cli_utils .use_config (config , args .configs )
206218
207219 group_name = f"{ args .run_name } _{ env_name } "
208220 for trial in range (args .trials ):
209221 run_name = group_name + f"_trial_{ trial } "
222+
223+ # create dataset
224+ d4rl_dataset = D4RLDataset (d4rl_dset = example_env .dset )
225+ online_dset = DiskTrajDataset (
226+ dset_root = args .buffer_dir ,
227+ dset_name = run_name ,
228+ dset_min_size = 250 ,
229+ dset_max_size = args .dset_max_size ,
230+ )
231+ combined_dset = MixtureOfDatasets (
232+ datasets = [d4rl_dataset , online_dset ],
233+ # skew sampling towards the demos 80/20
234+ sampling_weights = [0.8 , 0.2 ],
235+ # gradually increase the weight of the online dset
236+ # over the first 50 epochs *after online collection starts*
237+ smooth_sudden_starts = 50 ,
238+ )
239+
210240 experiment = cli_utils .create_experiment_from_cli (
211241 args ,
212242 make_train_env = make_train_env ,
@@ -219,10 +249,15 @@ def step(self, action):
219249 group_name = group_name ,
220250 val_timesteps_per_epoch = args .eval_timesteps ,
221251 learning_rate = 1e-4 ,
222- dataset = dataset ,
252+ dataset = combined_dset ,
223253 padded_sampling = "right" ,
254+ start_collecting_at_epoch = args .online_after_epoch ,
255+ stagger_traj_file_lengths = False ,
256+ traj_save_len = args .eval_timesteps + 1 ,
224257 sample_actions = False ,
258+ exploration_wrapper_type = exploration_wrapper_type ,
225259 )
260+ # apply the async mode selected by args.mode, then start the experiment
226261 experiment = cli_utils .switch_async_mode (experiment , args .mode )
227262 experiment .start ()
228263 if args .ckpt is not None :
0 commit comments