@@ -31,6 +31,7 @@ def __init__(
3131 n_steps_output : int = 4 ,
3232 n_channels_out : int = 1 ,
3333 sampler_steps : int = 50 ,
34+ sampler : str = "euler" ,
3435 ):
3536 super ().__init__ ()
3637 self .teacher_forcing_ratio = teacher_forcing_ratio
@@ -40,6 +41,7 @@ def __init__(
4041 self .n_steps_output = n_steps_output
4142 self .n_channels_out = n_channels_out
4243 self .sampler_steps = sampler_steps
44+ self .sampler = sampler
4345
4446 # Create Azula denoiser with chosen preconditioning
4547 if denoiser_type == "simple" :
@@ -54,13 +56,15 @@ def __init__(
5456
5557 def map (self , x : Tensor ) -> Tensor :
5658 """Map input window of states/times to output window using denoiser."""
57- # if we start from zero at every autoregressive step, the model is asked to
58- # denoise using t=0, which is a point it has never been trained on.
59- # self.inference_t = 1e-5
6059 dtype = x .dtype
6160 device = x .device
6261 sampler = self ._get_sampler (self .sampler_steps , dtype = dtype , device = device )
6362 B , _ , W , H , _ = x .shape
63+
64+ # if we start from zero at every autoregressive step, the model is asked to
65+ # denoise using t=0, which is a point it has never been trained on.
66+ # self.inference_t = 1e-5
67+ # using azula sampler init to create noise at t=1
6468 x_1 = sampler .init (
6569 (B , self .n_steps_output , W , H , self .n_channels_out ),
6670 dtype = dtype ,
@@ -69,23 +73,6 @@ def map(self, x: Tensor) -> Tensor:
6973 return sampler (x_1 , cond = x )
7074
7175 def forward (self , x : Tensor ) -> Tensor :
72- # # Training mode: sample random time and denoise
73- # if self.train:
74- # B, _, W, H, _ = x.shape
75-
76- # # Sample a random time
77- # t = torch.rand(B, device=x.device) * 0.999 + 0.001 # Avoid t=0 or t=1
78-
79- # # Create noisy input
80- # x_noisy = torch.randn(
81- # B, self.n_steps_output, W, H, self.n_channels_out, device=x.device
82- # )
83-
84- # # Denoise (this preserves gradients)
85- # posterior = self.denoiser(x_noisy, t, cond=x)
86- # return posterior.mean
87-
88- # Evaluation mode: use the map function
8976 return self .map (x )
9077
9178 def _denoise (self , x : Tensor , t : Tensor , cond : Tensor ) -> Tensor :
@@ -107,6 +94,7 @@ def loss(self, batch: EncodedBatch) -> Tensor:
10794 # Cannot use Azula's built-in weighted loss since ligntning calls forward
10895 loss = self .denoiser .loss (x_0 , t = t , cond = x_cond )
10996
97+ # TODO: consider an API for looking at alternative losses
11098 # # Compute weighted loss
11199 # alpha_t, sigma_t = self.schedule(t)
112100 # alpha_t = alpha_t.view(-1, 1, 1, 1, 1) # (B, 1, 1, 1, 1)
@@ -124,11 +112,11 @@ def loss(self, batch: EncodedBatch) -> Tensor:
124112 def _get_sampler (
125113 self ,
126114 num_steps : int = 100 ,
127- sampler : str = "euler" ,
128115 eta : float = 0.0 ,
129116 silent : bool = True ,
130117 ** sampler_kwargs ,
131118 ) -> Sampler :
119+ sampler = self .sampler
132120 # Create appropriate Azula sampler
133121 if sampler == "euler" :
134122 azula_sampler = EulerSampler (
0 commit comments