@@ -99,11 +99,14 @@ def __init__(
99
99
self .alphas = 1.0 - self .betas
100
100
self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
101
101
102
- self .sigmas = ((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5
102
+ sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
103
+ sigmas = np .concatenate ([sigmas [::- 1 ], [0.0 ]]).astype (np .float32 )
104
+ self .sigmas = torch .from_numpy (sigmas )
103
105
104
106
# setable values
105
107
self .num_inference_steps = None
106
- self .timesteps = np .arange (0 , num_train_timesteps )[::- 1 ] # to be consistent has to be smaller than sigmas by 1
108
+ timesteps = np .linspace (0 , num_train_timesteps - 1 , num_train_timesteps , dtype = float )[::- 1 ].copy ()
109
+ self .timesteps = torch .from_numpy (timesteps )
107
110
self .derivatives = []
108
111
109
112
def get_lms_coefficient (self , order , t , current_order ):
@@ -137,17 +140,14 @@ def set_timesteps(self, num_inference_steps: int):
137
140
the number of diffusion steps used when generating samples with a pre-trained model.
138
141
"""
139
142
self .num_inference_steps = num_inference_steps
140
- timesteps = np .linspace (self .config .num_train_timesteps - 1 , 0 , num_inference_steps , dtype = float )
141
143
142
- low_idx = np .floor (timesteps ).astype (int )
143
- high_idx = np .ceil (timesteps ).astype (int )
144
- frac = np .mod (timesteps , 1.0 )
144
+ timesteps = np .linspace (0 , self .config .num_train_timesteps - 1 , num_inference_steps , dtype = float )[::- 1 ].copy ()
145
145
sigmas = np .array (((1 - self .alphas_cumprod ) / self .alphas_cumprod ) ** 0.5 )
146
- sigmas = ( 1 - frac ) * sigmas [ low_idx ] + frac * sigmas [ high_idx ]
146
+ sigmas = np . interp ( timesteps , np . arange ( 0 , len ( sigmas )), sigmas )
147
147
sigmas = np .concatenate ([sigmas , [0.0 ]]).astype (np .float32 )
148
148
self .sigmas = torch .from_numpy (sigmas )
149
+ self .timesteps = torch .from_numpy (timesteps )
149
150
150
- self .timesteps = timesteps .astype (int )
151
151
self .derivatives = []
152
152
153
153
def step (
0 commit comments