-
Notifications
You must be signed in to change notification settings - Fork 188
Expand file tree
/
Copy pathworld_model_multitask.py
More file actions
2075 lines (1772 loc) · 112 KB
/
world_model_multitask.py
File metadata and controls
2075 lines (1772 loc) · 112 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import logging
import math
import os
from typing import Any, Dict, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from ding.utils import get_rank
from einops import rearrange
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.patches import Patch
from sklearn.manifold import TSNE
from lzero.model.common import SimNorm
from lzero.model.unizero_world_models.world_model import WorldModel
from lzero.model.utils import (
calculate_dormant_ratio,
calculate_effective_rank,
compute_average_weight_magnitude,
)
from .slicer import Head
from .tokenizer import Tokenizer
from .transformer import Transformer, TransformerConfig
from .utils import LossWithIntermediateLosses, WorldModelOutput, hash_state, init_weights
# Raise the root logger's verbosity to DEBUG as soon as this module is imported.
logging.root.setLevel(logging.DEBUG)
class WorldModelMT(WorldModel):
"""
Overview:
The WorldModel class for the multi-task UniZero model. It is responsible for
predicting the next latent state, reward, policy, and value based on the
current latent state and action. This model is a scalable latent world model
composed of three main parts: a tokenizer, a transformer, and prediction heads.
"""
def __init__(self, config: TransformerConfig, tokenizer: Tokenizer) -> None:
    """
    Overview:
        Initializes the multi-task WorldModel: positional/task embeddings, the transformer
        backbone, action embeddings, (shared or per-task) prediction heads, and the KV-cache
        pools used during inference.
    Arguments:
        - config (:obj:`TransformerConfig`): The configuration object for the transformer and world model.
        - tokenizer (:obj:`Tokenizer`): The tokenizer for encoding observations.
    """
    super().__init__(config, tokenizer)
    self.tokenizer = tokenizer
    self.config = config
    self.continuous_action_space = self.config.continuous_action_space
    self.task_num = config.task_num
    self.env_num = self.config.env_num
    # Whether to share prediction heads across tasks.
    self.share_head = config.share_head

    self.device = torch.device('cuda' if torch.cuda.is_available() and self.config.device != 'cpu' else 'cpu')
    print(f"self.device: {self.device}")

    # Positional embedding layer.
    self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device)
    print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")

    # Task embedding setup.
    self.use_task_embed = config.use_task_embed
    self.task_embed_option = self.config.task_embed_option
    self.task_embed_dim = getattr(config, "task_embed_dim", 96)
    self.register_token_num = getattr(config, "register_token_num", 4)

    if self.task_embed_option == "register_task_embed":
        # When using "register_task_embed", the positional encoding is not adjusted.
        # Use a non-trainable, zero-initialized nn.Embedding for positional embeddings.
        self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device)
        nn.init.constant_(self.pos_emb.weight, 0.0)  # Initialize with all zeros.
        self.pos_emb.weight.requires_grad = False  # Disable updates.
        # Precompute positional embedding differences for efficient inference.
        # NOTE(review): this runs before ``self.transformer`` is rebuilt below, so it uses the
        # transformer / context_length / num_heads set up by ``super().__init__``. Since every row
        # of the zeroed ``pos_emb`` is identical, the differences are presumably zero either way
        # (assuming the attention key/value projections act row-wise) — confirm if this changes.
        self.precompute_pos_emb_diff_kv()

    self.sim_norm = SimNorm(simnorm_dim=self.config.group_size)

    # Configure embedding dimensions based on the task embedding strategy.
    if self.task_embed_option == "concat_task_embed":
        # TODO: Currently, with "concat_task_embed", self.pos_emb needs to be fixed at 0.
        self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1)  # TDMPC2 suggests max_norm=1.
        self.obs_act_embed_dim = config.embed_dim - self.task_embed_dim
        self.register_token_num = 0
    elif self.task_embed_option == "register_task_embed":
        self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1)
        self.obs_act_embed_dim = config.embed_dim
    elif self.task_embed_option == "add_task_embed":
        self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1)
        self.obs_act_embed_dim = config.embed_dim
    else:
        self.task_emb = None
        self.obs_act_embed_dim = config.embed_dim
        self.register_token_num = 0

    self.transformer = Transformer(self.config, self.task_emb)

    # --- Analysis and Logging Setup ---
    self.analysis_dormant_ratio_interval = self.config.get('analysis_dormant_ratio_interval', 100)
    self._analysis_step_counter = 0
    self.do_analysis = self.config.analysis_dormant_ratio_weight_rank

    self.analysis_tsne = self.config.get('analysis_tsne', False)
    if self.analysis_tsne:
        self.env_id_list = self.config.env_id_list
        # Automatically generate short names for environments.
        self.env_short_names = {
            env_id: env_id.replace('NoFrameskip-v4', '')
            for env_id in self.config.env_id_list
        }
        # Color mapping to ensure each task has a fixed color.
        self.num_tasks = len(self.env_id_list)
        self.colors = self._generate_colors(self.num_tasks)

    # --- Prediction Head Initialization ---
    self.head_policy_multi_task = nn.ModuleList()
    self.head_value_multi_task = nn.ModuleList()
    self.head_rewards_multi_task = nn.ModuleList()
    self.head_observations_multi_task = nn.ModuleList()

    self.num_experts_in_moe_head = config.num_experts_in_moe_head
    self.use_normal_head = config.use_normal_head
    self.use_moe_head = config.use_moe_head
    self.use_softmoe_head = config.use_softmoe_head

    # Move the modules built so far (pos_emb, task_emb, transformer, ...) to the target device.
    self.to(self.device)

    # Initialize configuration parameters from the config object.
    self._initialize_config_parameters()
    self._initialize_patterns()

    self.hidden_size = config.embed_dim // config.num_heads

    # Initialize action embedding table based on action space type.
    if self.continuous_action_space:
        self.act_embedding_table = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False),
                SimNorm(simnorm_dim=self.group_size)
            ) for task_id in range(self.task_num)
        ])
    else:
        # For discrete action space.
        self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device)
        print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}")

    print(f'=' * 20)
    print(f"self.obs_act_embed_dim: {self.obs_act_embed_dim}")
    print(f'=' * 20)

    # ==================== [NEW] Policy Stability Fix Options ====================
    # Load fix options from config (with defaults for backward compatibility)
    self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', False)
    self.policy_logits_clip_method = getattr(self.config, 'policy_logits_clip_method', 'normalize_max')
    self.policy_logits_clip_min = getattr(self.config, 'policy_logits_clip_min', -10.0)
    self.policy_logits_clip_max = getattr(self.config, 'policy_logits_clip_max', 10.0)
    self.policy_logits_soft_beta = getattr(self.config, 'policy_logits_soft_beta', 1.0)
    self.policy_logits_adaptive_percentile = getattr(self.config, 'policy_logits_adaptive_percentile', 95)

    assert self.num_experts_in_moe_head > 0
    if self.use_normal_head:
        self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm')
        print('We use normal head')
        for task_id in range(self.task_num):
            if self.continuous_action_space:
                self.sigma_type = self.config.sigma_type
                self.bound_type = self.config.bound_type
                head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id])
            else:
                head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size)

            if not self.share_head or task_id == 0:
                self.head_policy_multi_task.append(head_policy)

            head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size)
            if not self.share_head or task_id == 0:
                self.head_value_multi_task.append(head_value)

            head_rewards = self._create_head(self.act_tokens_pattern, self.support_size)
            if not self.share_head or task_id == 0:
                self.head_rewards_multi_task.append(head_rewards)

            head_observations = self._create_head(
                self.all_but_last_latent_state_pattern,
                self.config.embed_dim,
                self._get_final_norm(self.final_norm_option_in_obs_head)  # Use the specified normalization method.
            )
            if not self.share_head or task_id == 0:
                self.head_observations_multi_task.append(head_observations)

    elif self.use_softmoe_head:
        print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}')
        self.soft_moe_instances = {}
        self.create_head_modules_softmoe()
        self.head_policy_multi_task.append(self.head_policy)
        self.head_value_multi_task.append(self.head_value)
        self.head_rewards_multi_task.append(self.head_rewards)
        self.head_observations_multi_task.append(self.head_observations)
    elif self.use_moe_head:
        print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}')
        self.moe_instances = {}
        self.create_head_modules_moe()
        self.head_policy_multi_task.append(self.head_policy)
        self.head_value_multi_task.append(self.head_value)
        self.head_rewards_multi_task.append(self.head_rewards)
        self.head_observations_multi_task.append(self.head_observations)

    # Group all head modules into a ModuleDict for easier management.
    self.head_dict = nn.ModuleDict({
        name: module for name, module in self.named_children()
        if name.startswith("head_") and name.endswith("_multi_task")
    })
    print("=" * 20)
    print(f"self.head_dict:{self.head_dict}")

    # Apply weight initialization. The order of initialization is important.
    self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type))
    self._initialize_last_layer_mt()

    # The prediction heads were created after the ``self.to(self.device)`` call above, and
    # ``_create_head`` builds its layers without an explicit device, so move them now.
    # Parameters already on the target device are unaffected.
    self.to(self.device)

    # --- Shared KV-cache pools ---
    # Defined BEFORE ``_initialize_cache_structures()``, which sizes its pool-index -> key
    # reverse maps from ``shared_pool_size_recur`` and ``shared_pool_size_init``.
    # Initially set to game_segment_length to ensure all KVs in self.shared_pool_init_infer are valid.
    # TODO: Critical. This should be changed to match segment_length.
    self.shared_pool_size_init = int(self.config.game_segment_length)  # NOTE: Will having too many cause incorrect retrieval of the kv cache?
    self.shared_pool_size_recur = int(self.num_simulations * self.env_num)
    self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur
    self.shared_pool_index = 0

    # For init_infer, it only needs to retain the results of the most recent step.
    # NOTE: A large pool size might cause incorrect retrieval of the kv cache.
    self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)]
    self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)]

    # For wm (world model) forward passes during training.
    self.shared_pool_size_wm = int(self.env_num)
    self.shared_pool_wm = [None] * self.shared_pool_size_wm
    self.shared_pool_index_wm = 0

    # --- Cache and State Initialization ---
    self._initialize_cache_structures()
    self._initialize_projection_input_dim()
    self._initialize_statistics()
    self._initialize_transformer_keys_values()

    self.latent_recon_loss = torch.tensor(0., device=self.device)
    self.perceptual_loss = torch.tensor(0., device=self.device)

    self.reanalyze_phase = False
    self._rank = get_rank()
def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor:
"""
Overview:
Scales the gradient. This hook is registered to encoder parameters
to stabilize multi-task training.
Arguments:
- grad (:obj:`torch.Tensor`): The original gradient.
Returns:
- (:obj:`torch.Tensor`): The scaled gradient.
"""
# Scale by 1/sqrt(k) for a conservative approach, where k is the number of tasks.
return grad / math.sqrt(self.task_num)
def _generate_colors(self, num_colors: int) -> list:
    """
    Overview:
        Generates a list of distinct colors for visualization, suitable for a large number
        of categories. The first 60 colors come from Matplotlib's qualitative
        'tab20' / 'tab20b' / 'tab20c' colormaps; any further colors are evenly spaced hues
        taken from the cyclic 'hsv' colormap.
    Arguments:
        - num_colors (:obj:`int`): The desired number of colors. Values <= 0 yield an empty list.
    Returns:
        - (:obj:`list`): A list of ``num_colors`` RGBA tuples.
    """
    if num_colors <= 0:
        # ``colors[:num_colors]`` with a negative bound would otherwise return a truncated palette.
        return []

    # Concatenate multiple discrete colormaps from matplotlib to get more colors.
    color_maps = ['tab20', 'tab20b', 'tab20c']
    colors = []
    for cmap_name in color_maps:
        cmap = plt.get_cmap(cmap_name)
        colors.extend(cmap(i) for i in range(cmap.N))
        if len(colors) >= num_colors:
            break

    # Generate additional colors if needed.
    num_missing = num_colors - len(colors)
    if num_missing > 0:
        # ``plt.get_cmap`` replaces ``plt.cm.get_cmap``, which was removed in Matplotlib 3.9.
        # 'hsv' is cyclic (its first and last entries are both red), so build a LUT with one
        # extra entry and drop the last one to keep every generated hue distinct.
        hsv = plt.get_cmap('hsv', num_missing + 1)
        colors.extend(hsv(i) for i in range(num_missing))

    return colors[:num_colors]
def _initialize_config_parameters(self) -> None:
"""Initializes model attributes from the configuration object."""
self.policy_entropy_weight = self.config.policy_entropy_weight
self.predict_latent_loss_type = self.config.predict_latent_loss_type
self.group_size = self.config.group_size
self.num_groups = self.config.embed_dim // self.group_size
self.obs_type = self.config.obs_type
self.embed_dim = self.config.embed_dim
self.num_heads = self.config.num_heads
self.gamma = self.config.gamma
self.context_length = self.config.context_length
self.dormant_threshold = self.config.dormant_threshold
self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank
self.num_observations_tokens = self.config.tokens_per_block - 1
self.latent_recon_loss_weight = self.config.latent_recon_loss_weight
self.perceptual_loss_weight = self.config.perceptual_loss_weight
self.support_size = self.config.support_size
self.action_space_size = self.config.action_space_size
self.max_cache_size = self.config.max_cache_size
self.num_layers = self.config.num_layers
def _initialize_patterns(self) -> None:
"""Initializes patterns (masks) for selecting specific tokens for prediction heads."""
self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block)
self.all_but_last_latent_state_pattern[-2] = 0
self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block)
self.act_tokens_pattern[-1] = 1
self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block)
self.value_policy_tokens_pattern[-2] = 1
def _get_final_norm(self, norm_option: str) -> nn.Module:
"""Returns the specified normalization module."""
if norm_option == 'LayerNorm':
return nn.LayerNorm(self.config.embed_dim, eps=1e-5)
elif norm_option == 'SimNorm':
return SimNorm(simnorm_dim=self.config.group_size)
else:
raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}")
def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None) -> Head:
    """
    Overview:
        Creates a standard prediction head:
        LayerNorm -> Linear -> LayerNorm -> GELU(tanh) -> Linear, optionally followed by ``norm_layer``.
    Arguments:
        - block_mask (:obj:`torch.Tensor`): Token-selection mask passed to ``Head``.
        - output_dim (:obj:`int`): Output dimension of the final linear layer.
        - norm_layer (:obj:`Optional[nn.Module]`): Optional module appended after the final linear layer.
    Returns:
        - (:obj:`Head`): The assembled prediction head.
    """
    modules = [
        nn.LayerNorm(self.config.embed_dim),  # TODO
        nn.Linear(self.config.embed_dim, self.config.embed_dim),
        nn.LayerNorm(self.config.embed_dim),
        nn.GELU(approximate='tanh'),
        nn.Linear(self.config.embed_dim, output_dim)
    ]
    # Explicit None check: container modules (e.g. nn.Sequential) define __len__, so an empty
    # container would be falsy and a plain truthiness test would silently drop it.
    if norm_layer is not None:
        modules.append(norm_layer)
    return Head(
        max_blocks=self.config.max_blocks,
        block_mask=block_mask,
        head_module=nn.Sequential(*modules)
    )
def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, moe: Optional[nn.Module] = None) -> Head:
    """
    Overview:
        Creates a prediction head with a Mixture-of-Experts (MoE) layer:
        LayerNorm -> ``moe`` -> Linear, optionally followed by ``norm_layer``.
    Arguments:
        - block_mask (:obj:`torch.Tensor`): Token-selection mask passed to ``Head``.
        - output_dim (:obj:`int`): Output dimension of the final linear layer.
        - norm_layer (:obj:`Optional[nn.Module]`): Optional module appended after the final linear layer.
        - moe (:obj:`Optional[nn.Module]`): The MoE layer. Required despite the ``None`` default,
          which is kept only for signature compatibility.
    Returns:
        - (:obj:`Head`): The assembled prediction head.
    Raises:
        - ValueError: If ``moe`` is None.
    """
    if moe is None:
        # nn.Sequential accepts None entries, so without this check the head would only fail
        # later with "'NoneType' object is not callable" on its first forward pass.
        raise ValueError("_create_head_moe requires an MoE module via the `moe` argument.")
    modules = [
        nn.LayerNorm(self.config.embed_dim),  # TODO
        moe,
        nn.Linear(self.config.embed_dim, output_dim)
    ]
    # Explicit None check: an empty container module would be falsy under a truthiness test.
    if norm_layer is not None:
        modules.append(norm_layer)
    return Head(
        max_blocks=self.config.max_blocks,
        block_mask=block_mask,
        head_module=nn.Sequential(*modules)
    )
def get_moe(self, name: str) -> nn.Module:
    """
    Overview:
        Returns the MoE layer cached under ``name`` in ``self.moe_instances``, building and
        caching it on the first request.
    Arguments:
        - name (:obj:`str`): Cache key identifying which head the MoE layer belongs to.
    Returns:
        - (:obj:`nn.Module`): The (possibly newly created) ``MoELayer``.
    """
    from .moe import MoELayer, MultiplicationFeedForward

    if name in self.moe_instances:
        return self.moe_instances[name]

    # NOTE(review): the expert count comes from ``num_experts_of_moe_in_transformer`` rather than
    # ``num_experts_in_moe_head`` (the value __init__ prints for the MoE head) — confirm intended.
    expert_count = self.config.num_experts_of_moe_in_transformer
    # Multiplication-based feed-forward experts, routed top-1 by a bias-free linear gate.
    expert_list = nn.ModuleList([
        MultiplicationFeedForward(self.config) for _ in range(expert_count)
    ])
    router = nn.Linear(self.config.embed_dim, expert_count, bias=False)
    layer = MoELayer(experts=expert_list, gate=router, num_experts_per_tok=1)
    self.moe_instances[name] = layer
    return layer
def create_head_modules_moe(self) -> None:
    """
    Overview:
        Builds the reward, observation, policy and value heads, each wrapping its own named
        MoE layer, and stores them as ``self.head_rewards`` / ``self.head_observations`` /
        ``self.head_policy`` / ``self.head_value`` (in that order).
    """
    # (attribute, token mask, output dim, optional trailing norm, MoE cache key)
    head_specs = (
        ("head_rewards", self.act_tokens_pattern, self.support_size, None, "rewards_moe"),
        ("head_observations", self.all_but_last_latent_state_pattern, self.embed_dim, self.sim_norm, "observations_moe"),
        ("head_policy", self.value_policy_tokens_pattern, self.action_space_size, None, "policy_moe"),
        ("head_value", self.value_policy_tokens_pattern, self.support_size, None, "value_moe"),
    )
    for attr_name, mask, out_dim, trailing_norm, moe_key in head_specs:
        # Only the observation head receives an explicit ``norm_layer``.
        extra_kwargs = {} if trailing_norm is None else {"norm_layer": trailing_norm}
        head = self._create_head_moe(mask, out_dim, moe=self.get_moe(moe_key), **extra_kwargs)
        setattr(self, attr_name, head)
def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, soft_moe: Optional[nn.Module] = None) -> Head:
    """
    Overview:
        Creates a prediction head with a Soft-MoE layer:
        ``soft_moe`` -> Linear, optionally followed by ``norm_layer``.
    Arguments:
        - block_mask (:obj:`torch.Tensor`): Token-selection mask passed to ``Head``.
        - output_dim (:obj:`int`): Output dimension of the final linear layer.
        - norm_layer (:obj:`Optional[nn.Module]`): Optional module appended after the final linear layer.
        - soft_moe (:obj:`Optional[nn.Module]`): The Soft-MoE layer. Required despite the ``None``
          default, which is kept only for signature compatibility.
    Returns:
        - (:obj:`Head`): The assembled prediction head.
    Raises:
        - ValueError: If ``soft_moe`` is None.
    """
    if soft_moe is None:
        # nn.Sequential accepts None entries, so without this check the head would only fail
        # later with "'NoneType' object is not callable" on its first forward pass.
        raise ValueError("_create_head_softmoe requires a Soft-MoE module via the `soft_moe` argument.")
    modules = [
        soft_moe,
        nn.Linear(self.config.embed_dim, output_dim)
    ]
    # Explicit None check: an empty container module would be falsy under a truthiness test.
    if norm_layer is not None:
        modules.append(norm_layer)
    return Head(
        max_blocks=self.config.max_blocks,
        block_mask=block_mask,
        head_module=nn.Sequential(*modules)
    )
def get_soft_moe(self, name: str) -> nn.Module:
    """
    Overview:
        Returns the Soft-MoE layer cached under ``name`` in ``self.soft_moe_instances``,
        creating it on first use.
    Arguments:
        - name (:obj:`str`): Cache key identifying which head the Soft-MoE layer belongs to.
    Returns:
        - (:obj:`nn.Module`): The cached ``DynamicSlotsSoftMoE`` instance.
    """
    from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE

    cache = self.soft_moe_instances
    if name in cache:
        return cache[name]
    # GEGLU experts; the expert count is the head-level setting.
    cache[name] = SoftMoE(
        dim=self.embed_dim,
        num_experts=self.num_experts_in_moe_head,
        geglu=True
    )
    return cache[name]
def create_head_modules_softmoe(self) -> None:
    """
    Overview:
        Builds the reward, observation, policy and value heads, each wrapping its own named
        Soft-MoE layer, and stores them as ``self.head_rewards`` / ``self.head_observations`` /
        ``self.head_policy`` / ``self.head_value`` (in that order).
    """
    # (attribute, token mask, output dim, optional trailing norm, Soft-MoE cache key)
    head_specs = (
        ("head_rewards", self.act_tokens_pattern, self.support_size, None, "rewards_soft_moe"),
        ("head_observations", self.all_but_last_latent_state_pattern, self.config.embed_dim, self.sim_norm, "observations_soft_moe"),
        ("head_policy", self.value_policy_tokens_pattern, self.action_space_size, None, "policy_soft_moe"),
        ("head_value", self.value_policy_tokens_pattern, self.support_size, None, "value_soft_moe"),
    )
    for attr_name, mask, out_dim, trailing_norm, moe_key in head_specs:
        # Only the observation head receives an explicit ``norm_layer``.
        extra_kwargs = {} if trailing_norm is None else {"norm_layer": trailing_norm}
        head = self._create_head_softmoe(mask, out_dim, soft_moe=self.get_soft_moe(moe_key), **extra_kwargs)
        setattr(self, attr_name, head)
def _initialize_last_layer_mt(self) -> None:
"""Initializes the last linear layer of prediction heads to zero for training stability."""
last_linear_layer_init_zero = True
print(f'world_model_mt.py:self.task_num:{self.task_num}')
if last_linear_layer_init_zero:
if self.continuous_action_space:
# For continuous actions, policy head might have a different initialization strategy.
module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task
else:
module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task
for head in module_to_initialize:
for layer in reversed(head.head_module):
if isinstance(layer, nn.Linear):
nn.init.zeros_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
break
def _initialize_cache_structures(self) -> None:
"""Initializes cache structures for storing past keys and values during inference."""
self.past_kv_cache_recurrent_infer = {}
self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur
self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)]
# Auxiliary data structure for reverse lookup: pool_index -> key
self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)]
self.keys_values_wm_list = []
self.keys_values_wm_size_list = []
def _initialize_projection_input_dim(self) -> None:
"""Initializes the input dimension for the projection based on observation tokenization."""
if self.num_observations_tokens == 16:
self.projection_input_dim = 128
elif self.num_observations_tokens == 1:
if self.task_embed_option in ["concat_task_embed", "register_task_embed", "add_task_embed"]:
self.projection_input_dim = self.config.embed_dim
if self.task_embed_option == "concat_task_embed":
self.projection_input_dim -= self.task_embed_dim
else:
self.projection_input_dim = self.config.embed_dim
def _initialize_statistics(self) -> None:
"""Initializes counters for cache hit rates and other statistics."""
self.hit_count = 0
self.total_query_count = 0
self.length_largethan_maxminus5_context_cnt = 0
self.length_largethan_maxminus7_context_cnt = 0
self.root_hit_cnt = 0
self.root_total_query_cnt = 0
def _initialize_transformer_keys_values(self) -> None:
"""Initializes empty key-value cache structures for the transformer."""
self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length)
self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length)
def precompute_pos_emb_diff_kv(self) -> None:
    """
    Overview:
        Caches, per layer, the key/value projections of the positional embeddings and the
        difference between the first ``end - start`` positions and positions ``start:end``,
        with ``start = 2`` and ``end = context_length - 1``. The differences are stored as
        ``self.pos_emb_diff_k[layer][(start, end)]`` / ``self.pos_emb_diff_v[layer][(start, end)]``
        so that a shifted KV cache can be re-positioned without recomputing the embeddings.
        Returns early, setting nothing, when ``context_length <= 2``.
    """
    if self.context_length <= 2:
        return  # No context to precompute for.

    n_layers = self.config.num_layers
    # Positional embedding projections for every layer (all keys first, then all values).
    self.positional_embedding_k = [self._get_positional_embedding(idx, 'key') for idx in range(n_layers)]
    self.positional_embedding_v = [self._get_positional_embedding(idx, 'value') for idx in range(n_layers)]

    # Only the window-shift case is precomputed.
    # TODO: Generalize for different start/end points if necessary.
    start, end = 2, self.context_length - 1
    width = end - start

    def shift_delta(emb: torch.Tensor) -> torch.Tensor:
        # New positions (0..width-1) minus the positions they replace (start..end-1).
        return emb[:, :, :width, :] - emb[:, :, start:end, :]

    self.pos_emb_diff_k = [{(start, end): shift_delta(emb)} for emb in self.positional_embedding_k]
    self.pos_emb_diff_v = [{(start, end): shift_delta(emb)} for emb in self.positional_embedding_v]
def _get_positional_embedding(self, layer: int, attn_type: str) -> torch.Tensor:
"""
Overview:
Helper function to get positional embedding for a given layer and attention type.
Arguments:
- layer (:obj:`int`): The layer index.
- attn_type (:obj:`str`): The attention type, either 'key' or 'value'.
Returns:
- (:obj:`torch.Tensor`): The positional embedding tensor, detached from the graph.
"""
# TODO: Review the use of detach(). It's used here to prevent gradients from flowing back
# through the positional embeddings during this pre-computation phase.
attn_func = getattr(self.transformer.blocks[layer].attn, attn_type)
pos_emb = attn_func(self.pos_emb.weight).view(
1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads
).transpose(1, 2)
return pos_emb.to(self.device).detach()
def forward(
    self,
    obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]],
    past_keys_values: Optional[torch.Tensor] = None,
    kvcache_independent: bool = False,
    is_init_infer: bool = True,
    valid_context_lengths: Optional[torch.Tensor] = None,
    task_id: int = 0
) -> WorldModelOutput:
    """
    Overview:
        Main forward pass for the world model. It processes either observation embeddings,
        action tokens, or a combination of both, and passes them through the transformer
        to generate predictions.
    Arguments:
        - obs_embeddings_or_act_tokens (:obj:`Dict`): A dictionary containing input tensors.
          Can be 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'.
        - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps.
        - kvcache_independent (:obj:`bool`): Whether to use independent KV caching per item in the batch.
        - is_init_infer (:obj:`bool`): Flag indicating if this is an initial inference step.
        - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths for each item.
        - task_id (:obj:`int`): The ID of the current task.
    Returns:
        - (:obj:`WorldModelOutput`): An object containing the transformer output and logits for
          observations, rewards, policy, and value.
    Notes:
        Policy-logit control is applied unless ``config.use_policy_logits_clip`` is explicitly
        False; it defaults to on when the config does not define the key.
    """
    if self.use_task_embed:
        self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device))
        self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1)
    else:
        # Use a zero tensor if task embeddings are disabled.
        self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device)

    prev_steps = 0 if past_keys_values is None else past_keys_values.size
    if kvcache_independent:
        # NOTE(review): ``past_keys_values`` is iterated here, so it must not be None on this path.
        prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], device=self.device)

    if is_init_infer:
        valid_context_lengths = None

    # --- Branch 1: Inference Phase (Collect/Eval) - Process observation embeddings ---
    if 'obs_embeddings' in obs_embeddings_or_act_tokens:
        obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings']
        if len(obs_embeddings.shape) == 2:
            obs_embeddings = obs_embeddings.unsqueeze(1)

        # Apply task embeddings based on the chosen strategy.
        if self.task_embed_option == "add_task_embed":
            obs_embeddings = obs_embeddings + self.task_embeddings
        elif self.task_embed_option == "concat_task_embed":
            if is_init_infer and not self.reanalyze_phase:
                # Concatenate task embeddings only during initial inference.
                task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1)
                obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1)

        num_steps = obs_embeddings.size(1)
        sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths)

    # --- Branch 2: Inference Phase (Collect/Eval) - Process action tokens ---
    elif 'act_tokens' in obs_embeddings_or_act_tokens:
        act_tokens = obs_embeddings_or_act_tokens['act_tokens']
        if self.continuous_action_space:
            num_steps = 1
            act_tokens = act_tokens.float()
            if len(act_tokens.shape) == 2:
                act_tokens = act_tokens.unsqueeze(1)
        else:
            if len(act_tokens.shape) == 3:
                act_tokens = act_tokens.squeeze(1)
            num_steps = act_tokens.size(1)

        # Get action embeddings from the task-specific or shared table.
        if self.task_num >= 1 and self.continuous_action_space:
            act_embeddings = self.act_embedding_table[task_id](act_tokens)
        else:
            act_embeddings = self.act_embedding_table(act_tokens)

        # Apply task embeddings.
        if self.task_embed_option == "concat_task_embed":
            task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1)
            act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1)

        sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths)

    # --- Branch 3: Training Phase - Process combined observation embeddings and action tokens ---
    else:
        if self.continuous_action_space:
            sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id)
        else:
            sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps)

    # Pass sequences through the transformer.
    x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=task_id)

    # Generate logits using shared, task-specific, or MoE heads.
    head_index = 0 if self.share_head else task_id
    if self.use_moe_head or self.use_softmoe_head:
        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
        logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)
    else:
        logits_observations = self.head_observations_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps)
        logits_rewards = self.head_rewards_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps)
        logits_policy = self.head_policy_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps)
        logits_value = self.head_value_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps)

    # ==================== Advanced Policy Logits Control ====================
    # Bound the policy logits to prevent explosion (see ``_apply_policy_logits_control``).
    # Previously the flag was unconditionally overwritten with True here, so an explicit
    # ``use_policy_logits_clip=False`` in the config had no effect. The config is now honoured;
    # when it does not define the key, the former always-on behaviour is kept. The attribute is
    # refreshed so it reflects what is actually applied.
    self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', True)
    if self.use_policy_logits_clip:
        logits_policy = self._apply_policy_logits_control(logits_policy)
    # ========================================================================

    return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value)
def _add_position_embeddings(
    self,
    embeddings: torch.Tensor,
    prev_steps: Union[int, torch.Tensor],
    num_steps: int,
    kvcache_independent: bool,
    is_init_infer: bool,
    valid_context_lengths: Optional[torch.Tensor]
) -> torch.Tensor:
    """
    Overview:
        Returns ``embeddings`` plus the positional embeddings of the ``num_steps`` new tokens.
        The position offset depends on the KV-cache mode:
          - independent caches: a per-sequence offset taken from ``prev_steps`` (a tensor);
          - shared cache, initial inference: one shared offset ``prev_steps``;
          - shared cache, recurrent inference: per-sequence offsets read from
            ``self.keys_values_wm_size_list_current`` (the ``valid_context_lengths`` argument is
            NOT consulted on this path).
    Arguments:
        - embeddings (:obj:`torch.Tensor`): Input embeddings for the new tokens.
        - prev_steps (:obj:`Union[int, torch.Tensor]`): Tokens already in the cache. It must be a
          1-D tensor when ``kvcache_independent`` is True, because it is unsqueezed there.
        - num_steps (:obj:`int`): Number of new tokens.
        - kvcache_independent (:obj:`bool`): Whether each sequence has its own KV cache.
        - is_init_infer (:obj:`bool`): Whether this is the initial-inference call.
        - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Accepted for interface
          compatibility; unused by every branch.
    Returns:
        - (:obj:`torch.Tensor`): ``embeddings`` with positional information added.
    """
    if kvcache_independent:
        step_offsets = torch.arange(num_steps, device=embeddings.device)
        return embeddings + self.pos_emb(prev_steps.unsqueeze(1) + step_offsets)

    step_offsets = torch.arange(num_steps, device=self.device)
    if is_init_infer:
        # Positions continue sequentially after the shared previous step count.
        return embeddings + self.pos_emb(prev_steps + step_offsets)

    # Recurrent step: each sequence continues from its own current cache length.
    cached_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device)
    return embeddings + self.pos_emb(cached_lengths.unsqueeze(1) + step_offsets)
def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]:
    """
    Overview:
        Processes and combines observation embeddings and continuous action tokens for training.
        For each of the ``L`` time steps, the ``K`` observation tokens are followed by one action token.
        This gives an interleaved sequence of length ``L * (K + 1)``. Positional embeddings for
        indices ``prev_steps .. prev_steps + L * (K + 1) - 1`` are then added.
    Arguments:
        - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary whose 'obs_embeddings_and_act_tokens'
          entry is an ``(obs_embeddings, act_tokens)`` pair.
            - ``obs_embeddings`` is either ``(B, L, K, E)`` or flattened ``(B * L, K, E)``. In the flattened
              case, ``B`` and ``L`` are taken from ``act_tokens``.
            - ``act_tokens`` is ``(B, L_act)`` or ``(B, L_act, A)``. Only its first ``L`` steps are used.
        - prev_steps (:obj:`int`): Number of previous steps (offset of the first position index).
        - task_id (:obj:`int`): The current task ID; selects the task-specific action embedding table.
    Returns:
        - (:obj:`Tuple[torch.Tensor, int]`): The combined sequence of shape ``(B, L * (K + 1), D)`` and
          the number of steps ``L * (K + 1)``.
    """
    obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens']
    if obs_embeddings.dim() == 3:
        # Flattened (B * L, K, E) input: recover the time axis from the action tensor.
        obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1)
    B, L, K, _ = obs_embeddings.shape
    num_steps = L * (K + 1)

    act_tokens = act_tokens.float()
    if act_tokens.dim() == 2:
        act_tokens = act_tokens.unsqueeze(-1)
    # Keep the first L action steps (callers may append an extra one) and add a token axis: (B, L, 1, E_act).
    act_embeddings = self.act_embedding_table[task_id](act_tokens)[:, :L].unsqueeze(2)

    if self.task_embed_option == "add_task_embed":
        # The task embedding is added to observation tokens only.
        obs_embeddings = obs_embeddings + self.task_embeddings
    elif self.task_embed_option == "concat_task_embed":
        task_emb = self.task_embeddings.view(1, 1, 1, -1)
        obs_embeddings = torch.cat([obs_embeddings, task_emb.expand(B, L, K, -1)], dim=-1)
        act_embeddings = torch.cat([act_embeddings, task_emb.expand(B, L, 1, -1)], dim=-1)

    # Interleave per step (K obs tokens, then 1 action token) and flatten the time/token axes in one shot.
    obs_act_embeddings = torch.cat([obs_embeddings, act_embeddings], dim=2).reshape(B, num_steps, -1)
    pos_indices = prev_steps + torch.arange(num_steps, device=self.device)
    return obs_act_embeddings + self.pos_emb(pos_indices), num_steps
def _process_obs_act_combined(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]:
    """
    Overview:
        Processes and combines observation embeddings and discrete action tokens for training.
        For each of the ``L`` time steps, the ``K`` observation tokens are followed by one action token.
        This gives an interleaved sequence of length ``L * (K + 1)``. Positional embeddings for
        indices ``prev_steps .. prev_steps + L * (K + 1) - 1`` are then added.
    Arguments:
        - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary whose 'obs_embeddings_and_act_tokens'
          entry is an ``(obs_embeddings, act_tokens)`` pair.
            - ``obs_embeddings`` is either ``(B, L, K, E)`` or flattened ``(B * L, K, E)``. In the flattened
              case, ``B`` and ``L`` are taken from ``act_tokens``.
            - ``act_tokens`` is expected to be an integer tensor ``(B, L_act, 1)``; its embedding is
              indexed as rank 4. Only the first ``L`` steps and token 0 are used.
        - prev_steps (:obj:`int`): Number of previous steps (offset of the first position index).
        - task_id (:obj:`int`): Unused; discrete actions share one embedding table across tasks.
    Returns:
        - (:obj:`Tuple[torch.Tensor, int]`): The combined sequence of shape ``(B, L * (K + 1), D)`` and
          the number of steps ``L * (K + 1)``.
    """
    obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens']
    if obs_embeddings.dim() == 3:
        # Flattened (B * L, K, E) input: recover the time axis from the action tensor.
        obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1)
    B, L, K, _ = obs_embeddings.shape
    num_steps = L * (K + 1)

    # Shared embedding table; keep the first L steps (callers may append one extra) and token 0 -> (B, L, 1, E).
    act_embeddings = self.act_embedding_table(act_tokens)[:, :L, :1]

    if self.task_embed_option == "add_task_embed":
        # The task embedding is added to observation tokens only.
        obs_embeddings = obs_embeddings + self.task_embeddings
    elif self.task_embed_option == "concat_task_embed":
        task_emb = self.task_embeddings.view(1, 1, 1, -1)
        obs_embeddings = torch.cat([obs_embeddings, task_emb.expand(B, L, K, -1)], dim=-1)
        act_embeddings = torch.cat([act_embeddings, task_emb.expand(B, L, 1, -1)], dim=-1)

    # Interleave per step (K obs tokens, then 1 action token) and flatten the time/token axes in one shot.
    obs_act_embeddings = torch.cat([obs_embeddings, act_embeddings], dim=2).reshape(B, num_steps, -1)
    pos_indices = prev_steps + torch.arange(num_steps, device=self.device)
    return obs_act_embeddings + self.pos_emb(pos_indices), num_steps
def _transformer_pass(
    self,
    sequences: torch.Tensor,
    past_keys_values: Optional[torch.Tensor],
    kvcache_independent: bool,
    valid_context_lengths: Optional[torch.Tensor],
    task_id: int = 0
) -> torch.Tensor:
    """
    Overview:
        Passes sequences through the transformer, handling different KV cache modes.
        - Shared cache: one transformer call on the whole batch.
        - Independent caches: one transformer call per sample ``k``, each with its own cache entry
          ``past_keys_values[k]``. The per-sample outputs are concatenated along dim 0.
    Arguments:
        - sequences (:obj:`torch.Tensor`): Input sequences, indexed by sample along dim 0.
        - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache. In independent mode it is
          iterated, one entry per sample.
        - kvcache_independent (:obj:`bool`): Flag for independent KV caching.
        - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Per-sample valid context lengths, or None.
          In independent mode, None is forwarded to every per-sample call, as in the shared path.
        - task_id (:obj:`int`): Unused here; kept for interface compatibility.
    Returns:
        - (:obj:`torch.Tensor`): The output from the transformer.
    """
    if not kvcache_independent:
        return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths)

    outputs = []
    for k, past_kv in enumerate(past_keys_values):
        # Previously a None here crashed on subscripting; pass it through instead.
        lengths_k = None if valid_context_lengths is None else valid_context_lengths[k].unsqueeze(0)
        outputs.append(self.transformer(sequences[k].unsqueeze(0), past_kv, valid_context_lengths=lengths_k))
    return torch.cat(outputs, dim=0)
@torch.no_grad()
def reset_for_initial_inference(self, obs_act_dict: dict, task_id: int = 0) -> Tuple[WorldModelOutput, torch.Tensor]:
    """
    Overview:
        Prepares the model for a new inference sequence.
        - Refreshes ``self.task_embeddings`` for ``task_id``.
        - Encodes the observations.
        - Builds ``self.latent_state``.
        - Delegates to ``wm_forward_for_initial_inference``.
        The latent state comes from ``current_obs`` during collect/evaluation. In the training
        (target-value) phase it comes from ``obs``.
    Arguments:
        - obs_act_dict (:obj:`dict`): Must provide 'obs', 'action' and 'current_obs'. A ``None``
          'current_obs' selects the training phase.
        - task_id (:obj:`int`): The ID of the current task.
    Returns:
        - (:obj:`Tuple[WorldModelOutput, torch.Tensor]`): The world model output and the latent state.
    """
    if not self.use_task_embed:
        self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device)
    else:
        raw_task_embedding = self.task_emb(torch.tensor(task_id, device=self.device))
        self.task_embeddings = self.sim_norm(raw_task_embedding.view(1, -1)).view(-1)

    obs_batch = obs_act_dict['obs']
    action_batch = obs_act_dict['action']
    current_obs_batch = obs_act_dict['current_obs']

    obs_embeddings = self.tokenizer.encode_to_obs_embeddings(obs_batch, task_id=task_id)
    if current_obs_batch is None:
        # Training phase (target computation): the latent comes from the stacked observations.
        current_obs_embeddings = None
        base_embeddings = obs_embeddings
    else:
        # Collect / evaluation phase: the latent comes from the newest observation.
        current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs_batch, task_id=task_id)
        base_embeddings = current_obs_embeddings

    # Combine with the task embedding only when task embeddings are enabled.
    embed_mode = self.task_embed_option if self.use_task_embed else None
    if embed_mode == "add_task_embed":
        self.latent_state = base_embeddings + self.task_embeddings
    elif embed_mode == "concat_task_embed":
        tiled_task = self.task_embeddings.view(1, 1, -1).expand(base_embeddings.shape[0], base_embeddings.shape[1], -1)
        self.latent_state = torch.cat([base_embeddings, tiled_task], dim=-1)
    else:
        # "register_task_embed", other options, or task embeddings disabled.
        self.latent_state = base_embeddings

    outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, action_batch, current_obs_embeddings, task_id=task_id)
    return outputs_wm, self.latent_state
#@profile
@torch.no_grad()
def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.Tensor,
                                     batch_action=None,
                                     current_obs_embeddings=None, task_id: int = 0) -> WorldModelOutput:
    """
    Overview:
        Refreshes the world-model KV cache for initial inference, or runs the batched forward pass
        used to compute training targets. The arguments decide which path runs:
          - ``n <= self.env_num`` and ``current_obs_embeddings`` given -> collect/evaluation path.
          - ``batch_action`` given and ``current_obs_embeddings`` is None -> training-target path.
    Arguments:
        - last_obs_embeddings (:obj:`torch.Tensor`): Observation embeddings, unpacked as ``(n, K, E)``.
          In the training path they are reshaped to ``(B, -1, K, self.obs_act_embed_dim)`` with
          ``B = batch_action.shape[0]``.
        - batch_action (optional): Actions.
            - Collect path: an episode's first step is detected when ``max(batch_action) == -1``
              (discrete) or when ``batch_action[0]`` is not an ``np.ndarray`` (continuous).
            - Training path: must be a numpy array, since it is passed to ``torch.from_numpy``.
        - current_obs_embeddings (optional): Embeddings of the newest observation (collect/eval path only).
        - task_id (:obj:`int`): Task ID forwarded to every ``self.forward`` call.
    Returns:
        - (:obj:`WorldModelOutput`): Output of the last ``self.forward`` call. In the training path, the
          value and policy logits get their last time step duplicated and are flattened to ``(B * T, ...)``.
    Side effects:
        - First step: sets ``self.keys_values_wm``.
        - Later steps: sets ``self.keys_values_wm_list``, ``self.keys_values_wm_size_list`` and
          ``self.keys_values_wm_size_list_current``, and increments ``self.root_total_query_cnt`` and
          ``self.root_hit_cnt``.
        - Collect path: calls ``self.update_cache_context``.
    NOTE(review): if neither path's condition holds (e.g. ``n > self.env_num`` with
    ``current_obs_embeddings`` given), ``outputs_wm`` is never assigned and the final ``return``
    raises ``UnboundLocalError``.
    """
    n, num_observations_tokens, _ = last_obs_embeddings.shape
    if n <= self.env_num and current_obs_embeddings is not None:
        # ================ Collect and Evaluation Phase ================
        # Always true here: the enclosing condition already requires it.
        if current_obs_embeddings is not None:
            if self.continuous_action_space:
                first_step_flag = not isinstance(batch_action[0], np.ndarray)
            else:
                first_step_flag = max(batch_action) == -1
            if first_step_flag:
                # First step in an episode: start from a fresh, empty KV cache for the whole batch.
                self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0],
                                                                                  max_tokens=self.context_length)
                # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
                outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings},
                                          past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id)
                if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]:
                    # Cache context is keyed on the task-augmented latent state.
                    self.update_cache_context(self.latent_state, is_init_infer=True)
                else:
                    # Cache context is keyed on the raw observation embeddings.
                    self.update_cache_context(current_obs_embeddings, is_init_infer=True)
            else:
                # Later step: rebuild each ready env's cache from the init-infer pool, looked up by a
                # hash of that env's previous observation embedding (last_obs_embeddings[i]).
                ready_env_num = current_obs_embeddings.shape[0]
                self.keys_values_wm_list = []
                self.keys_values_wm_size_list = []
                for i in range(ready_env_num):
                    # Retrieve latent state for a single environment
                    state_single_env = last_obs_embeddings[i]
                    # Compute hash value using latent state for a single environment
                    cache_key = hash_state(
                        state_single_env.view(-1).cpu().numpy())  # last_obs_embeddings[i] is torch.Tensor
                    # Per-env dict maps hash -> index into that env's shared pool.
                    cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key)
                    if cache_index is not None:
                        matched_value = self.shared_pool_init_infer[i][cache_index]
                    else:
                        matched_value = None
                    self.root_total_query_cnt += 1
                    if matched_value is not None:
                        # If a matching value is found, add it to the list
                        self.root_hit_cnt += 1
                        # Hand over a copy (custom_copy_kv_cache_to_shared_wm), not the pooled entry itself;
                        # per the original note, forward modifies the cache in place.
                        self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value))
                        self.keys_values_wm_size_list.append(matched_value.size)
                    else:
                        # Cache miss: build a fresh single-env cache by encoding this env's observation.
                        self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length)
                        outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)},
                                                  past_keys_values=self.keys_values_wm_single_env,
                                                  is_init_infer=True, task_id=task_id)
                        self.keys_values_wm_list.append(self.keys_values_wm_single_env)
                        # NOTE(review): size 1 is recorded here regardless of how many tokens forward
                        # wrote; confirm this matches what trim_and_pad_kv_cache expects.
                        self.keys_values_wm_size_list.append(1)
                # Presumably merges self.keys_values_wm_list into self.keys_values_wm (per the original
                # comment); it returns the per-env sizes used later for recurrent positions.
                self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True)
                batch_action = batch_action[:ready_env_num]
                # if ready_env_num < self.env_num:
                #     print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}')
                if self.continuous_action_space:
                    act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1)
                else:
                    act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1)
                # Append the previous action, then the new observation, to the rebuilt cache.
                outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm,
                                          is_init_infer=True, task_id=task_id)
                outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings},
                                          past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id)
                if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]:
                    # Cache context is keyed on the task-augmented latent state.
                    self.update_cache_context(self.latent_state, is_init_infer=True)
                else:
                    # Cache context is keyed on the raw observation embeddings.
                    self.update_cache_context(current_obs_embeddings, is_init_infer=True)
    elif batch_action is not None and current_obs_embeddings is None:
        # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None:
        # ================ calculate the target value in Train phase ================
        # [192, 16, 64] -> [32, 6, 16, 64]
        last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens,
                                                                    self.obs_act_embed_dim)  # (BL, K) for unroll_step=1
        # Drop the final observation step along the time axis.
        last_obs_embeddings = last_obs_embeddings[:, :-1, :]
        batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device)
        if self.continuous_action_space:
            act_tokens = batch_action
        else:
            act_tokens = rearrange(batch_action, 'b l -> b l 1')
        # Duplicate the last action along the time axis. The target policy/value at the final step
        # itself is not used.
        last_steps_act = act_tokens[:, -1:, :]
        act_tokens = torch.cat((act_tokens, last_steps_act), dim=1)
        outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id)
        # Duplicate the last time step of the value/policy logits so their length matches the targets.
        last_steps_value = outputs_wm.logits_value[:, -1:, :]
        outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1)
        last_steps_policy = outputs_wm.logits_policy[:, -1:, :]
        outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1)
        # Flatten batch and time: (B, T, E) -> (B * T, E), e.g. (B, H, 101) -> (B*H, 101).
        outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e')
        outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e')
    return outputs_wm
#@profile
@torch.no_grad()