Skip to content

KeyError occurs when loading custom dataset #131

Open
@JustinS6626

Description

@JustinS6626

I am trying to load a dataset created from a Minigrid-type environment using the following code:

import os

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from stable_baselines3 import PPO
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import minari
from minari import DataCollectorV0

def collate_fn(batch):
    """Collate a list of episodes into a dict of batched tensors.

    Intended as the ``collate_fn`` of a ``torch.utils.data.DataLoader`` that
    iterates a Minari dataset. Each element of ``batch`` must expose the
    attributes ``id``, ``seed``, ``total_timesteps``, ``observations``,
    ``actions``, ``rewards``, ``terminations`` and ``truncations``.

    Args:
        batch: Sequence of episode objects (presumably Minari episode data).

    Returns:
        A dict with these entries:

        - ``"id"``, ``"seed"``, ``"total_timesteps"``: 1-D int64 tensors with
          one value per episode.
        - The per-step fields (``"observations"``, ``"actions"``,
          ``"rewards"``, ``"terminations"``, ``"truncations"``): tensors of
          shape ``(batch, max_len, ...)``. Shorter episodes are right-padded
          with zeros (``False`` for boolean fields).

    Notes:
        Each per-step field must be convertible by ``torch.as_tensor``.
        Dict or tuple observation spaces are not handled here; this is an
        assumption about the dataset and should be confirmed per dataset.
        A ``None`` seed raises ``TypeError``.
    """

    def _pad(field):
        # Episodes differ in length, so stack each per-step field along a new
        # leading batch axis, zero-padding the time axis to the longest one.
        return torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(getattr(episode, field)) for episode in batch],
            batch_first=True,
        )

    def _ints(field):
        # Keep integer metadata as int64. torch.Tensor(...) would cast to
        # float32, which rounds values above 2**24 (typical for seeds).
        return torch.tensor(
            [getattr(episode, field) for episode in batch], dtype=torch.int64
        )

    return {
        "id": _ints("id"),
        "seed": _ints("seed"),
        "total_timesteps": _ints("total_timesteps"),
        "observations": _pad("observations"),
        "actions": _pad("actions"),
        "rewards": _pad("rewards"),
        "terminations": _pad("terminations"),
        "truncations": _pad("truncations"),
    }


torch.manual_seed(42)

# NOTE(review): the reported KeyError ("Can't open attribute") is raised
# inside minari.load_dataset. It happens while MinariStorage reads the
# "flatten_observation" HDF5 attribute, before any code below runs.
# Presumably the dataset file was written by a different Minari version than
# the one loading it; recreating the dataset with the installed version
# should confirm this.
minari_testset = minari.load_dataset("MinigridRandomWall-6Spots-v0")

# Pass the dataset that was actually loaded. The name `minari_dataset` is not
# defined anywhere in this script and would raise NameError.
dataloader = DataLoader(
    minari_testset, batch_size=64, shuffle=True, collate_fn=collate_fn
)

for batch in dataloader:
    print("Observation shape: " + str(batch['observations'].shape))
    print("Action shape: " + str(batch['actions'].shape))
    print("Reward shape: " + str(batch['rewards'].shape))
    # collate_fn does not collate episode infos, so batch has no "infos" key.
    # Report the per-episode timestep counts that it does provide.
    print("Total timesteps shape: " + str(batch["total_timesteps"].shape))

When I run the code, I get this error:

    minari_testset = minari.load_dataset("MinigridRandomWall-6Spots-v0")
  File "/home/justin/Minari/minari/storage/local.py", line 22, in load_dataset
    return MinariDataset(data_path)
  File "/home/justin/Minari/minari/dataset/minari_dataset.py", line 133, in __init__
    self._data = MinariStorage(data)
  File "/home/justin/Minari/minari/dataset/minari_storage.py", line 22, in __init__
    flatten_observations = f.attrs["flatten_observation"].item()
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/usr/local/lib/python3.8/dist-packages/h5py/_hl/attrs.py", line 56, in __getitem__
    attr = h5a.open(self._id, self._e(name))
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5a.pyx", line 80, in h5py.h5a.open
KeyError: "Can't open attribute (can't locate attribute in name index)"

When I created the dataset, I used an ImgObsWrapper for the environment. Could that be the source of the problem?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions