99from torch .distributions import Categorical , Independent , Normal , TransformedDistribution , TanhTransform
1010
1111from lzero .model .common import SimNorm
12- from lzero .model .utils import cal_dormant_ratio , compute_average_weight_magnitude , cal_effective_rank
12+ from lzero .model .utils import calculate_dormant_ratio , compute_average_weight_magnitude , compute_effective_rank
1313from .kv_caching import KeysValues
1414from .slicer import Head , PolicyHeadCont
1515from .tokenizer import Tokenizer
@@ -45,6 +45,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
4545
4646 self .transformer = Transformer (self .config )
4747 self .task_num = 1
48+ self .env_num = self .config .env_num
4849 if self .config .device == 'cpu' :
4950 self .device = torch .device ('cpu' )
5051 else :
@@ -70,7 +71,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
7071 print (f"self.pos_emb.weight.device: { self .pos_emb .weight .device } " )
7172
7273 self .register_token_num = config .register_token_num if hasattr (config , "register_token_num" ) else 4
73-
74+ if self .task_embed_option == "concat_task_embed" :
75+ self .obs_per_embdding_dim = self .config .embed_dim - self .task_embed_dim
76+ else :
77+ self .obs_per_embdding_dim = self .config .embed_dim
7478 self .continuous_action_space = self .config .continuous_action_space
7579
7680 # Initialize action embedding table
@@ -1352,7 +1356,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
13521356 # E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64)
13531357 inputs = batch ['observations' ].contiguous ().view (- 1 , * shape [- 3 :])
13541358
1355- dormant_ratio_encoder_dict = cal_dormant_ratio (
1359+ dormant_ratio_encoder_dict = calculate_dormant_ratio (
13561360 self .tokenizer .encoder , inputs .detach (), dormant_threshold = self .dormant_threshold
13571361 )
13581362 dormant_ratio_encoder = dormant_ratio_encoder_dict ['global' ]
@@ -1370,11 +1374,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
13701374 # The 'representation_layer_name' argument specifies the target layer within the model's named modules.
13711375
13721376 # Effective rank for the final linear layer of the encoder.
1373- e_rank_last_linear = cal_effective_rank (
1377+ e_rank_last_linear = compute_effective_rank (
13741378 self .tokenizer .encoder , inputs , representation_layer_name = "last_linear"
13751379 )
13761380 # Effective rank for the SimNorm layer of the encoder.
1377- e_rank_sim_norm = cal_effective_rank (
1381+ e_rank_sim_norm = compute_effective_rank (
13781382 self .tokenizer .encoder , inputs , representation_layer_name = "sim_norm"
13791383 )
13801384
@@ -1485,7 +1489,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
14851489 # ========= logging for analysis =========
14861490 if self .analysis_dormant_ratio_weight_rank :
14871491 # Calculate dormant ratio of the world model
1488- dormant_ratio_world_model = cal_dormant_ratio (self , {
1492+ dormant_ratio_world_model = calculate_dormant_ratio (self , {
14891493 'obs_embeddings_and_act_tokens' : (obs_embeddings .detach (), act_tokens .detach ())},
14901494 dormant_threshold = self .dormant_threshold )
14911495 dormant_ratio_transformer = dormant_ratio_world_model ['transformer' ]
0 commit comments