
Commit 69949ed

🚀 [RofuncRL] Fix the mistake in ASE latent space construction
1 parent ca3882d commit 69949ed

10 files changed (+59, -15 lines)

examples/learning_rl/example_HumanoidASE_RofuncRL.py

Lines changed: 2 additions & 2 deletions
@@ -101,9 +101,9 @@ def inference(custom_args):
     # HumanoidASEReachSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASELocationSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
     # HumanoidASEStrikeSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
-    parser.add_argument("--task", type=str, default="HumanoidASEReachSwordShield")
+    parser.add_argument("--task", type=str, default="HumanoidASEGetupSwordShield")
     parser.add_argument("--motion_file", type=str,
-                        default="reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy")
+                        default="reallusion_sword_shield/dataset_reallusion_sword_shield.yaml")
     parser.add_argument("--agent", type=str, default="ase")  # Available agent: ase
     parser.add_argument("--num_envs", type=int, default=4096)
     parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))

rofunc/config/learning/rl/train/HumanoidASEGetupSwordShieldASERofuncRL.yaml

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ Agent:
  discriminator_weight_decay_scale: 0.0001

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150
+
  enc_reward_scale: 1
  enc_weight_decay_scale: 0
  enc_gradient_penalty_scale: 0

rofunc/config/learning/rl/train/HumanoidASEHeadingSwordShieldASERofuncRL.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ Agent:
  use_gae: True # If true, use generalized advantage estimation.

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150

  task_reward_weight: 0.9
  style_reward_weight: 0.1

rofunc/config/learning/rl/train/HumanoidASELocationSwordShieldASERofuncRL.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ Agent:
  use_gae: True # If true, use generalized advantage estimation.

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150

  task_reward_weight: 0.9
  style_reward_weight: 0.1

rofunc/config/learning/rl/train/HumanoidASEPerturbSwordShieldASERofuncRL.yaml

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ Agent:
  discriminator_weight_decay_scale: 0.0001

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150
+
  enc_reward_scale: 1
  enc_weight_decay_scale: 0
  enc_gradient_penalty_scale: 0

rofunc/config/learning/rl/train/HumanoidASEReachSwordShieldASERofuncRL.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ Agent:
  use_gae: True # If true, use generalized advantage estimation.

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150

  task_reward_weight: 0.9
  style_reward_weight: 0.1

rofunc/config/learning/rl/train/HumanoidASEStrikeSwordShieldASERofuncRL.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ Agent:
  use_gae: True # If true, use generalized advantage estimation.

  ase_latent_dim: 64
+  ase_latent_steps_min: 1
+  ase_latent_steps_max: 150

  task_reward_weight: 0.9
  style_reward_weight: 0.1

rofunc/learning/RofuncRL/agents/mixline/ase_agent.py

Lines changed: 5 additions & 10 deletions
@@ -63,8 +63,6 @@ def __init__(self,
        """ASE specific parameters"""
        self._lr_e = cfg.Agent.lr_e
        self._ase_latent_dim = cfg.Agent.ase_latent_dim
-        # self._ase_latent_steps_min = self.cfg.Agent.ase_latent_steps_min
-        # self._ase_latent_steps_max = self.cfg.Agent.ase_latent_steps_max
        # self._amp_diversity_bonus = self.cfg.Agent.amp_diversity_bonus
        # self._amp_diversity_tar = self.cfg.Agent.amp_diversity_tar
        # self._enc_coef = self.cfg.Agent.enc_coef

@@ -92,24 +90,21 @@ def __init__(self,
        self.memory.create_tensor(name="ase_latents", size=self._ase_latent_dim, dtype=torch.float32)
        self._tensors_names.append("ase_latents")

+        self._ase_latents = torch.zeros((self.memory.num_envs, self._ase_latent_dim), dtype=torch.float32,
+                                        device=self.device)
+
    def _set_up(self):
        super()._set_up()
        self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
        if self._lr_scheduler is not None:
            self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
        self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc

-    def _update_latents(self, num_envs: int):
-        # Equ. 11, provide the model with a latent space
-        z_bar = torch.normal(torch.zeros([num_envs, self._ase_latent_dim]))
-        self._ase_latents = z = torch.nn.functional.normalize(z_bar, dim=-1).to(self.device)
-
    def act(self, states: torch.Tensor, deterministic: bool = False, ase_latents: torch.Tensor = None):
        if self._current_states is not None:
            states = self._current_states

        if ase_latents is None:
-            self._update_latents(states.shape[0])
            ase_latents = self._ase_latents

        if not deterministic:

@@ -171,10 +166,10 @@ def update_net(self):
        amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
        if self._least_square_discriminator:
            style_rewards = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
-                                                      torch.tensor(0.0001, device=self.device))
+                                          torch.tensor(0.0001, device=self.device))
        else:
            style_rewards = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
-                                                               torch.tensor(0.0001, device=self.device)))
+                                                     torch.tensor(0.0001, device=self.device)))
        style_rewards *= self._discriminator_reward_scale

        # Compute encoder reward
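The substantive change in this file is that act() no longer resamples a fresh latent on every call: it now reads self._ase_latents, a persistent per-env buffer that the trainer below owns and updates. The last hunk only realigns the continuation lines of the style-reward expressions; the values computed are unchanged. As a standalone sketch of what those two branches evaluate (plain torch, with the discriminator call and AMP state preprocessing omitted):

    import torch

    def style_reward(amp_logits: torch.Tensor, least_square_discriminator: bool, scale: float) -> torch.Tensor:
        eps = torch.tensor(0.0001, device=amp_logits.device)
        if least_square_discriminator:
            # Least-squares form: r = max(1 - 0.25 * (1 - D)^2, 1e-4)
            r = torch.maximum(1 - 0.25 * torch.square(1 - amp_logits), eps)
        else:
            # Standard GAN form: r = -log(max(1 - sigmoid(D), 1e-4))
            r = -torch.log(torch.maximum(1 - torch.sigmoid(amp_logits), eps))
        return r * scale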

rofunc/learning/RofuncRL/agents/mixline/ase_hrl_agent.py

Lines changed: 2 additions & 1 deletion
@@ -127,7 +127,8 @@ def __init__(self,
        self._task_reward_weight = self.cfg.Agent.task_reward_weight
        self._style_reward_weight = self.cfg.Agent.style_reward_weight
        self._kl_threshold = self.cfg.Agent.kl_threshold
-        self._rewards_shaper = self.cfg.get("Agent", {}).get("rewards_shaper", lambda rewards: rewards * 0.01)
+        self._rewards_shaper = None
+        # self._rewards_shaper = self.cfg.get("Agent", {}).get("rewards_shaper", lambda rewards: rewards * 0.01)
        self._state_preprocessor = RunningStandardScaler
        self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
                                                                        {"size": observation_space, "device": device})

rofunc/learning/RofuncRL/trainers/ase_trainer.py

Lines changed: 36 additions & 2 deletions
@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
+import torch

from rofunc.learning.RofuncRL.agents.mixline.ase_agent import ASEAgent
from rofunc.learning.RofuncRL.agents.mixline.ase_hrl_agent import ASEHRLAgent

@@ -46,9 +47,42 @@ def __init__(self, cfg, env, device, env_name, hrl=False):
                                              num_samples))
        self.setup_wandb()

+        '''Misc variables'''
+        self._latent_reset_steps = torch.zeros(self.env.num_envs, dtype=torch.int32).to(self.device)
+        self._latent_steps_min = self.cfg.Agent.ase_latent_steps_min
+        self._latent_steps_max = self.cfg.Agent.ase_latent_steps_max
+
+    def _reset_latents(self, env_ids):
+        # Equ. 11, provide the model with a latent space
+        z_bar = torch.normal(torch.zeros([len(env_ids), self.agent._ase_latent_dim]))
+        self.agent._ase_latents[env_ids] = torch.nn.functional.normalize(z_bar, dim=-1).to(self.device)
+
+    def _reset_latent_step_count(self, env_ids):
+        self._latent_reset_steps[env_ids] = torch.randint_like(self._latent_reset_steps[env_ids],
+                                                               low=self._latent_steps_min,
+                                                               high=self._latent_steps_max)
+
+    def _update_latents(self):
+        new_latent_envs = self._latent_reset_steps <= self.env.progress_buf
+
+        need_update = torch.any(new_latent_envs)
+        if need_update:
+            new_latent_env_ids = new_latent_envs.nonzero(as_tuple=False).flatten()
+            self._reset_latents(new_latent_env_ids)
+            self._latent_reset_steps[new_latent_env_ids] += torch.randint_like(
+                self._latent_reset_steps[new_latent_env_ids],
+                low=self._latent_steps_min,
+                high=self._latent_steps_max)
+
    def pre_interaction(self):
-        if self.collect_observation is not None:
-            self.agent._current_states = self.collect_observation()
+        if self.collect_observation is not None:  # Reset failed envs
+            obs_dict, done_env_ids = self.env.reset_done()
+            self.agent._current_states = obs_dict["obs"]
+            if len(done_env_ids) > 0:
+                self._reset_latents(done_env_ids)
+                self._reset_latent_step_count(done_env_ids)
+
+        self._update_latents()

    def post_interaction(self):
        self._rollout += 1
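Taken together, the latent schedule now lives in the trainer instead of being resampled inside the agent's act(): each environment holds its latent for a random number of steps drawn from [ase_latent_steps_min, ase_latent_steps_max), and latents are re-drawn as unit-normalized Gaussian vectors (Equ. 11) both on that schedule and whenever an environment finishes an episode. A self-contained sketch of the schedule with made-up sizes (no env or agent plumbing; all names are local to the sketch):

    import torch

    num_envs, latent_dim = 8, 64
    steps_min, steps_max = 1, 150          # ase_latent_steps_min / ase_latent_steps_max
    device = "cpu"

    ase_latents = torch.zeros(num_envs, latent_dim, device=device)
    latent_reset_steps = torch.zeros(num_envs, dtype=torch.int32, device=device)
    progress_buf = torch.zeros(num_envs, dtype=torch.int32, device=device)   # per-env step counters

    def reset_latents(env_ids):
        # Equ. 11: z = z_bar / ||z_bar||, with z_bar ~ N(0, I)
        z_bar = torch.randn(len(env_ids), latent_dim, device=device)
        ase_latents[env_ids] = torch.nn.functional.normalize(z_bar, dim=-1)

    def reset_latent_step_count(env_ids):
        latent_reset_steps[env_ids] = torch.randint(
            steps_min, steps_max, (len(env_ids),), dtype=torch.int32, device=device)

    def update_latents():
        # Resample for every env whose hold period has elapsed, then push its deadline forward.
        expired = (latent_reset_steps <= progress_buf).nonzero(as_tuple=False).flatten()
        if len(expired) > 0:
            reset_latents(expired)
            latent_reset_steps[expired] += torch.randint(
                steps_min, steps_max, (len(expired),), dtype=torch.int32, device=device)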
