Skip to content

Commit 88d28d7

Browse files
committed
fix collate to account for attr: list[list]
1 parent a4935de commit 88d28d7

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

mattergen/common/data/collate.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import warnings
55
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, overload
6+
from itertools import chain
67

78
from torch import Tensor
89
from torch_geometric.data import Batch, Data
@@ -236,8 +237,20 @@ def _merge(xs: list[PyTree[T]], structure: PyTree[int]) -> PyTree[T]:
236237
)
237238
del x[attr] # type: ignore
238239

240+
# Batch.from_data_list will concat attr: list[list] to list[list[list]], we need to handle separately
241+
attr_is_twod_list = []
242+
243+
for attr in attrs:
244+
if all(isinstance(x[attr], list) for x in xs) and all(isinstance(_x,list) for x in xs for _x in x[attr]):
245+
attr_is_twod_list.append(attr)
246+
239247
try:
240248
batch = Batch.from_data_list(xs)
249+
250+
# handle attr: list[list] as a special case
251+
for attr in attr_is_twod_list:
252+
# convert batch.attr: list[list[list]] to list[list]
253+
batch[attr] = list(chain(*[x[attr] for x in xs]))
241254
except Exception as e:
242255
# Check if dtypes do not match:
243256
for attr in attrs:

0 commit comments

Comments
 (0)