@@ -27,14 +27,14 @@ def t_to_sigma(self, t):
27
27
return (t * math .pi / 2 ).tan ()
28
28
29
29
def loss (self , input , noise , sigma , ** kwargs ):
30
- c_skip , c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
30
+ c_skip , c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
31
31
noised_input = input + noise * utils .append_dims (sigma , input .ndim )
32
32
model_output = self .inner_model (noised_input * c_in , self .sigma_to_t (sigma ), ** kwargs )
33
33
target = (input - c_skip * noised_input ) / c_out
34
34
return (model_output - target ).pow (2 ).flatten (1 ).mean (1 )
35
35
36
36
def forward (self , input , sigma , ** kwargs ):
37
- c_skip , c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
37
+ c_skip , c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
38
38
return self .inner_model (input * c_in , self .sigma_to_t (sigma ), ** kwargs ) * c_out + input * c_skip
39
39
40
40
@@ -102,13 +102,13 @@ def get_eps(self, *args, **kwargs):
102
102
return self .inner_model (* args , ** kwargs )
103
103
104
104
def loss (self , input , noise , sigma , ** kwargs ):
105
- c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
105
+ c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
106
106
noised_input = input + noise * utils .append_dims (sigma , input .ndim )
107
107
eps = self .get_eps (noised_input * c_in , self .sigma_to_t (sigma ), ** kwargs )
108
108
return (eps - noise ).pow (2 ).flatten (1 ).mean (1 )
109
109
110
110
def forward (self , input , sigma , ** kwargs ):
111
- c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
111
+ c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
112
112
eps = self .get_eps (input * c_in , self .sigma_to_t (sigma ), ** kwargs )
113
113
return input + eps * c_out
114
114
@@ -156,14 +156,14 @@ def get_v(self, *args, **kwargs):
156
156
return self .inner_model (* args , ** kwargs )
157
157
158
158
def loss (self , input , noise , sigma , ** kwargs ):
159
- c_skip , c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
159
+ c_skip , c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
160
160
noised_input = input + noise * utils .append_dims (sigma , input .ndim )
161
161
model_output = self .get_v (noised_input * c_in , self .sigma_to_t (sigma ), ** kwargs )
162
162
target = (input - c_skip * noised_input ) / c_out
163
163
return (model_output - target ).pow (2 ).flatten (1 ).mean (1 )
164
164
165
165
def forward (self , input , sigma , ** kwargs ):
166
- c_skip , c_out , c_in = [ utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma )]
166
+ c_skip , c_out , c_in = ( utils .append_dims (x , input .ndim ) for x in self .get_scalings (sigma ))
167
167
return self .get_v (input * c_in , self .sigma_to_t (sigma ), ** kwargs ) * c_out + input * c_skip
168
168
169
169
0 commit comments