Skip to content

Commit 91dac6e

Browse files
committed
add license
1 parent faba229 commit 91dac6e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ml4h/models/train_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def regress_on_batch(diffuser, regressor, controls, tm_out, batch_size):
131131
diffusion_steps=50,
132132
)
133133
control_predictions = regressor.predict(generated_images, verbose=0)
134-
return control_predictions[:, 0] if tm_out.is_continuous() else control_predictions
134+
return control_predictions[tm_out.output_name()][:, 0] if tm_out.is_continuous() else control_predictions
135135

136136

137137
def regress_on_controlled_generations(diffuser, regressor, tm_out, batches, batch_size, mean, std, prefix):

0 commit comments

Comments
 (0)