Skip to content

Commit 65304f5

Browse files
committed
Pass dtype and device to sampler init
1 parent d3c48c8 commit 65304f5

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/auto_cast/processors/diffusion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,14 @@ def map(self, x: Tensor) -> Tensor:
5757
# if we start from zero at every autoregressive step, the model is asked to
5858
# denoise using t=0, which is a point it has never been trained on.
5959
# self.inference_t = 1e-5
60-
sampler = self._get_sampler(self.sampler_steps, dtype=x.dtype, device=x.device)
60+
dtype = x.dtype
61+
device = x.device
62+
sampler = self._get_sampler(self.sampler_steps, dtype=dtype, device=device)
6163
B, _, W, H, _ = x.shape
6264
x_1 = sampler.init(
63-
(B, self.n_steps_output, W, H, self.n_channels_out)
65+
(B, self.n_steps_output, W, H, self.n_channels_out),
66+
dtype=dtype,
67+
device=device,
6468
) # Fully noised
6569
return sampler(x_1, cond=x)
6670

0 commit comments

Comments
 (0)