Skip to content

Commit 3cdf552

Browse files
committed
account for both cases when chemsys: str | list[str]
1 parent 88d28d7 commit 3cdf552

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

mattergen/common/data/collate.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,20 +237,8 @@ def _merge(xs: list[PyTree[T]], structure: PyTree[int]) -> PyTree[T]:
237237
)
238238
del x[attr] # type: ignore
239239

240-
# Batch.from_data_list will concat attr: list[list] to list[list[list]], we need to handle separately
241-
attr_is_twod_list = []
242-
243-
for attr in attrs:
244-
if all(isinstance(x[attr], list) for x in xs) and all(isinstance(_x,list) for x in xs for _x in x[attr]):
245-
attr_is_twod_list.append(attr)
246-
247240
try:
248241
batch = Batch.from_data_list(xs)
249-
250-
# handle attr: list[list] as a special case
251-
for attr in attr_is_twod_list:
252-
# convert batch.attr: list[list[list]] to list[list]
253-
batch[attr] = list(chain(*[x[attr] for x in xs]))
254242
except Exception as e:
255243
# Check if dtypes do not match:
256244
for attr in attrs:

mattergen/common/data/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mattergen.common.utils.globals import PROPERTY_SOURCE_IDS
1212

1313
PropertySourceId = str
14-
TargetProperty = dict[PropertySourceId, int | float | Sequence[str]]
14+
TargetProperty = dict[PropertySourceId, int | float | str | Sequence[str]]
1515

1616

1717
@dataclass(frozen=True)

mattergen/diffusion/sampling/classifier_free_guidance.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,18 @@ def get_conditional_score():
7575
batch_no_condition = self._remove_conditioning_fn(x)
7676
batch_with_condition = self._keep_conditioning_fn(x)
7777
joint_batch = collate([batch_no_condition, batch_with_condition])
78+
79+
for attr,value in batch_no_condition.items():
80+
if isinstance(value, list):
81+
joint_batch[attr] = batch_no_condition[attr]+batch_with_condition[attr]
82+
83+
7884
combined_score = super(GuidedPredictorCorrector, self)._score_fn(
7985
x=joint_batch, t=torch.cat([t, t], dim=0),
8086
)
8187
# Split the combined score back into unconditional and conditional parts.
88+
# Any batch.attr: list fields will be wrong here because of the manual concatenation above
89+
# this should be ok as self._multi_corruption.corrupted_fields are always torch.Tensor
8290
unconditional_score = combined_score[0]
8391
conditional_score = combined_score[1]
8492

0 commit comments

Comments
 (0)