Skip to content

Commit 2e55897

Browse files
author
Daniel Zuegner
committed
batch unconditional and conditional scores
1 parent adf86af commit 2e55897

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

mattergen/diffusion/sampling/classifier_free_guidance.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77

88
from mattergen.diffusion.sampling.pc_sampler import Diffusable, PredictorCorrector
9+
from mattergen.common.data.collate import collate
910

1011
BatchTransform = Callable[[Diffusable], Diffusable]
1112

@@ -71,9 +72,16 @@ def get_conditional_score():
7172
return get_unconditional_score()
7273
else:
7374
# guided_score = guidance_factor * conditional_score + (1-guidance_factor) * unconditional_score
75+
batch_no_condition = self._remove_conditioning_fn(x)
76+
batch_with_condition = self._keep_conditioning_fn(x)
77+
joint_batch = collate([batch_no_condition, batch_with_condition])
78+
combined_score = super(GuidedPredictorCorrector, self)._score_fn(
79+
x=joint_batch, t=torch.cat([t, t], dim=0),
80+
)
81+
# Split the combined score back into unconditional and conditional parts.
82+
unconditional_score = combined_score[0]
83+
conditional_score = combined_score[1]
7484

75-
conditional_score = get_conditional_score()
76-
unconditional_score = get_unconditional_score()
7785
return unconditional_score.replace(
7886
**{
7987
k: torch.lerp(

0 commit comments

Comments
 (0)