Skip to content

Commit 483c42c

Browse files
authored
flowmatching condition shape fix (#1584)
* flowmatching condition shape fix * self.condition shape not condition shape for fix
1 parent b25cc17 commit 483c42c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

sbi/neural_nets/estimators/flowmatching_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def forward(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
119119
# and remember the original shape
120120
target_shape = input.shape
121121
input = input.reshape(-1, input.shape[-1])
122-
condition = condition.reshape(-1, condition.shape[-1])
122+
condition = condition.reshape(-1, *self.condition_shape)
123123
t = t.reshape(-1, t.shape[-1])
124124

125125
# embed the input and condition

0 commit comments

Comments
 (0)