Skip to content

Commit d559687

Browse files
committed
turn off normalization and checking for invetory
1 parent 5b4b187 commit d559687

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

sf_examples/nethack/models/scaled.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -383,30 +383,34 @@ def __init__(self, embedding_dim: int):
383383
self.out_dim = self.INVENTORY_SIZE * embedding_dim
384384

385385
self.glyphs_embedding = nn.Embedding(
386-
nethack.NUM_OBJECTS + 1,
386+
nethack.MAX_GLYPH + 1,
387387
embedding_dim,
388-
padding_idx=nethack.NUM_OBJECTS,
388+
# padding_idx=nethack.MAX_GLYPH,
389+
)
390+
self.keys_embedding = nn.Embedding(
391+
256,
392+
embedding_dim,
393+
# padding_idx=0.
389394
)
390-
self.keys_embedding = nn.Embedding(96, embedding_dim, padding_idx=0)
391395
self.classes_embedding = nn.Embedding(
392396
nethack.MAXOCLASSES + 1,
393397
embedding_dim,
394-
padding_idx=nethack.MAXOCLASSES,
398+
# padding_idx=nethack.MAXOCLASSES,
395399
)
396400

397401
def forward(self, inv_glyps, inv_keys, inv_oclasses):
398-
normalized_glyph_ids = inv_glyps - nethack.GLYPH_OBJ_OFF
399-
normalized_glyph_ids = torch.where(
400-
inv_glyps < nethack.MAX_GLYPH, inv_glyps - nethack.GLYPH_OBJ_OFF, nethack.NUM_OBJECTS
401-
)
402-
normalized_inv_keys = torch.where(inv_keys > 0, inv_keys - 32, 0)
403-
404-
self._check_embedding_indices(normalized_glyph_ids, self.glyphs_embedding)
405-
self._check_embedding_indices(normalized_inv_keys, self.keys_embedding)
406-
self._check_embedding_indices(inv_oclasses, self.classes_embedding)
407-
408-
embedded_glyphs = selectt(self.glyphs_embedding, normalized_glyph_ids.long(), True)
409-
embedded_inv_keys = selectt(self.keys_embedding, normalized_inv_keys.long(), True)
402+
# normalized_glyph_ids = inv_glyps - nethack.GLYPH_OBJ_OFF
403+
# normalized_glyph_ids = torch.where(
404+
# inv_glyps < nethack.MAX_GLYPH, inv_glyps - nethack.GLYPH_OBJ_OFF, nethack.NUM_OBJECTS
405+
# )
406+
# normalized_inv_keys = torch.where(inv_keys > 0, inv_keys - 32, 0)
407+
408+
# self._check_embedding_indices(normalized_glyph_ids, self.glyphs_embedding)
409+
# self._check_embedding_indices(normalized_inv_keys, self.keys_embedding)
410+
# self._check_embedding_indices(inv_oclasses, self.classes_embedding)
411+
412+
embedded_glyphs = selectt(self.glyphs_embedding, inv_glyps.long(), True)
413+
embedded_inv_keys = selectt(self.keys_embedding, inv_keys.long(), True)
410414
embedded_classes = selectt(self.classes_embedding, inv_oclasses.long(), True)
411415

412416
encoding = embedded_glyphs + embedded_inv_keys + embedded_classes

0 commit comments

Comments
 (0)