Skip to content

Commit 34c92eb

Browse files
committed
[Bugfix] Fix AttributeError when val is a Tensor in Minari
_extract_nontensor_fields expects a TensorDict but was being called on plain Tensors. Added check to only call it when val is a tensor collection.
1 parent 3e384c6 commit 34c92eb

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

torchrl/data/datasets/minari_data.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,14 +416,18 @@ def _download_and_preproc(self):
416416
val_next = val[1:].clone()
417417
val_copy = val[:-1].clone()
418418

419-
non_tensors_next = _extract_nontensor_fields(val_next)
420-
non_tensors_now = _extract_nontensor_fields(val_copy)
421-
422419
data_view["next", match].copy_(val_next)
423420
data_view[match].copy_(val_copy)
424421

425-
data_view["next", match].update_(non_tensors_next)
426-
data_view[match].update_(non_tensors_now)
422+
if is_tensor_collection(val_next):
423+
non_tensors_next = _extract_nontensor_fields(
424+
val_next
425+
)
426+
non_tensors_now = _extract_nontensor_fields(
427+
val_copy
428+
)
429+
data_view["next", match].update_(non_tensors_next)
430+
data_view[match].update_(non_tensors_now)
427431

428432
elif key not in ("terminations", "truncations", "rewards"):
429433
if steps is None:

0 commit comments

Comments
 (0)