@@ -383,30 +383,34 @@ def __init__(self, embedding_dim: int):
383383 self .out_dim = self .INVENTORY_SIZE * embedding_dim
384384
385385 self .glyphs_embedding = nn .Embedding (
386- nethack .NUM_OBJECTS + 1 ,
386+ nethack .MAX_GLYPH + 1 ,
387387 embedding_dim ,
388- padding_idx = nethack .NUM_OBJECTS ,
388+ # padding_idx=nethack.MAX_GLYPH,
389+ )
390+ self .keys_embedding = nn .Embedding (
391+ 256 ,
392+ embedding_dim ,
393+ # padding_idx=0.
389394 )
390- self .keys_embedding = nn .Embedding (96 , embedding_dim , padding_idx = 0 )
391395 self .classes_embedding = nn .Embedding (
392396 nethack .MAXOCLASSES + 1 ,
393397 embedding_dim ,
394- padding_idx = nethack .MAXOCLASSES ,
398+ # padding_idx=nethack.MAXOCLASSES,
395399 )
396400
397401 def forward (self , inv_glyps , inv_keys , inv_oclasses ):
398- normalized_glyph_ids = inv_glyps - nethack .GLYPH_OBJ_OFF
399- normalized_glyph_ids = torch .where (
400- inv_glyps < nethack .MAX_GLYPH , inv_glyps - nethack .GLYPH_OBJ_OFF , nethack .NUM_OBJECTS
401- )
402- normalized_inv_keys = torch .where (inv_keys > 0 , inv_keys - 32 , 0 )
403-
404- self ._check_embedding_indices (normalized_glyph_ids , self .glyphs_embedding )
405- self ._check_embedding_indices (normalized_inv_keys , self .keys_embedding )
406- self ._check_embedding_indices (inv_oclasses , self .classes_embedding )
407-
408- embedded_glyphs = selectt (self .glyphs_embedding , normalized_glyph_ids .long (), True )
409- embedded_inv_keys = selectt (self .keys_embedding , normalized_inv_keys .long (), True )
402+ # normalized_glyph_ids = inv_glyps - nethack.GLYPH_OBJ_OFF
403+ # normalized_glyph_ids = torch.where(
404+ # inv_glyps < nethack.MAX_GLYPH, inv_glyps - nethack.GLYPH_OBJ_OFF, nethack.NUM_OBJECTS
405+ # )
406+ # normalized_inv_keys = torch.where(inv_keys > 0, inv_keys - 32, 0)
407+
408+ # self._check_embedding_indices(normalized_glyph_ids, self.glyphs_embedding)
409+ # self._check_embedding_indices(normalized_inv_keys, self.keys_embedding)
410+ # self._check_embedding_indices(inv_oclasses, self.classes_embedding)
411+
412+ embedded_glyphs = selectt (self .glyphs_embedding , inv_glyps .long (), True )
413+ embedded_inv_keys = selectt (self .keys_embedding , inv_keys .long (), True )
410414 embedded_classes = selectt (self .classes_embedding , inv_oclasses .long (), True )
411415
412416 encoding = embedded_glyphs + embedded_inv_keys + embedded_classes
0 commit comments