Skip to content

Commit 6e82505

Browse files
committed
🚀 [RofuncRL] Update trainers
1 parent cbc4309 commit 6e82505

16 files changed

+43
-31
lines changed

doc/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
author = 'Junjia Liu'
2828

2929
# The full version, including alpha/beta/rc tags
30-
release = '0.0.2.1'
30+
release = '0.0.2.3'
3131

3232
# -- General configuration ---------------------------------------------------
3333

examples/learning_rl/example_Ant_RofuncRL.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def train(custom_args):
4242
# Instantiate the RL trainer
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
45-
device=cfg.rl_device)
45+
device=cfg.rl_device,
46+
env_name=custom_args.task)
4647

4748
# Start training
4849
trainer.train()
@@ -74,7 +75,8 @@ def inference(custom_args):
7475
# Instantiate the RL trainer
7576
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7677
env=infer_env,
77-
device=cfg.rl_device)
78+
device=cfg.rl_device,
79+
env_name=custom_args.task)
7880
# load checkpoint
7981
if custom_args.ckpt_path is None:
8082
custom_args.ckpt_path = model_zoo(name="AntRofuncRLPPO.pt") # TODO: check

examples/learning_rl/example_CURICabinet_RofuncRL.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def train(custom_args):
4242
# Instantiate the RL trainer
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
45-
device=cfg.rl_device)
45+
device=cfg.rl_device,
46+
env_name=custom_args.task)
4647

4748
# Start training
4849
trainer.train()
@@ -74,7 +75,8 @@ def inference(custom_args):
7475
# Instantiate the RL trainer
7576
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7677
env=infer_env,
77-
device=cfg.rl_device)
78+
device=cfg.rl_device,
79+
env_name=custom_args.task)
7880
# load checkpoint
7981
if custom_args.ckpt_path is None:
8082
custom_args.ckpt_path = model_zoo(name="CURICabinetRofuncRLPPO_left_arm.pt") # TODO: Check

examples/learning_rl/example_FrankaCabinet_RofuncRL.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def train(custom_args):
4242
# Instantiate the RL trainer
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
45-
device=cfg.rl_device)
45+
device=cfg.rl_device,
46+
env_name=custom_args.task)
4647

4748
# Start training
4849
trainer.train()
@@ -74,7 +75,8 @@ def inference(custom_args):
7475
# Instantiate the RL trainer
7576
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7677
env=infer_env,
77-
device=cfg.rl_device)
78+
device=cfg.rl_device,
79+
env_name=custom_args.task)
7880
# load checkpoint
7981
if custom_args.ckpt_path is None:
8082
custom_args.ckpt_path = model_zoo(name=f"{custom_args.task}.pth") # TODO: Check

examples/learning_rl/example_GymTasks_RofuncRL.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def train(custom_args):
3434
# Instantiate the RL trainer
3535
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
3636
env=env,
37-
device=cfg.rl_device)
37+
device=cfg.rl_device,
38+
env_name=custom_args.task)
3839

3940
# Start training
4041
trainer.train()

examples/learning_rl/example_HumanoidAMP_RofuncRL.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def train(custom_args):
4242
# Instantiate the RL trainer
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
45-
device=cfg.rl_device)
46-
45+
device=cfg.rl_device,
46+
env_name=custom_args.task)
4747
# Start training
4848
trainer.train()
4949

@@ -75,7 +75,9 @@ def inference(custom_args):
7575
# Instantiate the RL trainer
7676
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7777
env=infer_env,
78-
device=cfg.rl_device)
78+
device=cfg.rl_device,
79+
env_name=custom_args.task)
80+
7981
# load checkpoint
8082
if custom_args.ckpt_path is None:
8183
custom_args.ckpt_path = model_zoo(name=f"{custom_args.task}.pth")

examples/learning_rl/example_HumanoidASE_RofuncRL.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def train(custom_args):
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
4545
device=cfg.rl_device,
46+
env_name=custom_args.task,
4647
hrl=hrl)
4748

4849
# Start training
@@ -78,6 +79,7 @@ def inference(custom_args):
7879
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7980
env=infer_env,
8081
device=cfg.rl_device,
82+
env_name=custom_args.task,
8183
hrl=hrl)
8284
# load checkpoint
8385
if custom_args.ckpt_path is None:
@@ -89,7 +91,7 @@ def inference(custom_args):
8991

9092

9193
if __name__ == '__main__':
92-
gpu_id = 0
94+
gpu_id = 1
9395

9496
parser = argparse.ArgumentParser()
9597
# Available tasks and motion files:

examples/learning_rl/example_Humanoid_RofuncRL.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def train(custom_args):
4242
# Instantiate the RL trainer
4343
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
4444
env=env,
45-
device=cfg.rl_device)
46-
45+
device=cfg.rl_device,
46+
env_name=custom_args.task)
4747
# Start training
4848
trainer.train()
4949

@@ -74,7 +74,9 @@ def inference(custom_args):
7474
# Instantiate the RL trainer
7575
trainer = trainer_map[custom_args.agent](cfg=cfg.train,
7676
env=infer_env,
77-
device=cfg.rl_device)
77+
device=cfg.rl_device,
78+
env_name=custom_args.task)
79+
7880
# load checkpoint
7981
if custom_args.ckpt_path is None:
8082
custom_args.ckpt_path = model_zoo(name=f"{custom_args.task}.pth")

rofunc/learning/RofuncRL/trainers/a2c_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121

2222
class A2CTrainer(BaseTrainer):
23-
def __init__(self, cfg, env, device):
24-
super().__init__(cfg, env, device)
23+
def __init__(self, cfg, env, device, env_name):
24+
super().__init__(cfg, env, device, env_name)
2525
self.memory = RandomMemory(memory_size=cfg.Trainer.rollouts, num_envs=self.env.num_envs, device=device)
2626
self.agent = A2CAgent(cfg, self.env.observation_space, self.env.action_space, self.memory,
2727
device, self.exp_dir, self.rofunc_logger)

rofunc/learning/RofuncRL/trainers/amp_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121

2222
class AMPTrainer(BaseTrainer):
23-
def __init__(self, cfg, env, device):
24-
super().__init__(cfg, env, device)
23+
def __init__(self, cfg, env, device, env_name):
24+
super().__init__(cfg, env, device, env_name)
2525
self.memory = RandomMemory(memory_size=self.rollouts, num_envs=self.env.num_envs, device=device)
2626
self.motion_dataset = RandomMemory(memory_size=200000, device=device)
2727
self.replay_buffer = RandomMemory(memory_size=1000000, device=device)

0 commit comments

Comments (0)