@@ -246,15 +246,16 @@ def __init__(
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
+        logits = self.forward(x, *args, **kwargs)
+
         if cond_scale == 1:
-            return self.forward(x, **kwargs)
+            return logits

-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
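For context: the last line of this method is the classifier-free guidance blend. The network is evaluated once with conditioning and once with conditioning dropped (`cond_drop_prob = 1.`), and the conditioned prediction is pushed away from the unconditioned one by `cond_scale`. A minimal, self-contained sketch of just that blend (function and variable names here are illustrative, not from the repo):

import torch

def classifier_free_guidance(cond_out, null_out, cond_scale):
    # cond_scale = 1. returns the conditioned output unchanged;
    # values > 1. move the prediction further from the unconditioned
    # output, amplifying the effect of the conditioning signal
    return null_out + (cond_out - null_out) * cond_scale

cond_out = torch.tensor([1.0, 2.0])
null_out = torch.tensor([0.5, 0.5])
print(classifier_free_guidance(cond_out, null_out, 3.))  # tensor([2., 5.])

Since the blend is exactly the conditioned output when `cond_scale = 1.`, the early return above is a pure optimization that skips the second forward pass.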
@@ -635,15 +636,16 @@ def __init__(
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
+        logits = self.forward(x, *args, **kwargs)
+
         if cond_scale == 1:
-            return self.forward(x, **kwargs)
+            return logits

-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
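The identical change lands in the second conditioned network. The motivation for swapping the bare keyword-only marker `*` for `*args` shows up in the next hunk: `p_mean_variance` now calls `forward_with_cond_scale(x, t, ...)` with the timestep passed positionally, which the old keyword-only signature would have rejected. A toy reproduction (hypothetical function names):

# illustrative only: why `*` had to become `*args`
def keyword_only(x, *, cond_scale = 1.):
    return x

def positional_ok(x, *args, cond_scale = 1.):
    return x

x, t = 1, 2
positional_ok(x, t)    # fine: t is captured by *args and forwarded along
# keyword_only(x, t)   # TypeError: keyword_only() takes 1 positional argument but 2 were given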
@@ -774,8 +776,8 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
+    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))

         if clip_denoised:
             x_recon.clamp_(-1., 1.)
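`predict_start_from_noise` itself is outside this diff; for reference, under the standard DDPM parameterization it recovers x_0 from the predicted noise roughly as in this sketch (an assumption about code not shown here, with illustrative names):

import torch

def predict_start_from_noise_sketch(x_t, noise, alphas_cumprod, t):
    # x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise
    alpha_bar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (b, c, h, w)
    return torch.sqrt(1. / alpha_bar) * x_t - torch.sqrt(1. / alpha_bar - 1.) * noise

The only behavioral change in this hunk is that the noise estimate now comes from the guided forward pass instead of a plain `self.net(...)` call.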
@@ -784,31 +786,31 @@ def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed):
+    def p_sample_loop(self, shape, image_embed, cond_scale = 1.):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device = device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
         return img

     @torch.no_grad()
-    def sample(self, image_embed):
+    def sample(self, image_embed, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)

     def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
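With `cond_scale` now threaded through `p_mean_variance` → `p_sample` → `p_sample_loop` → `sample`, guidance strength becomes a sampling-time knob rather than a training-time decision. A hypothetical call, assuming a constructed decoder and precomputed embeddings:

# hypothetical usage; `decoder` and `image_embed` are built elsewhere in the repo
images = decoder.sample(image_embed, cond_scale = 2.)  # cond_scale = 1. matches the old behavior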
@@ -869,7 +871,8 @@ def __init__(
     @torch.no_grad()
     def forward(
         self,
-        text
+        text,
+        cond_scale = 1.
     ):
         device = next(self.parameters()).device

@@ -878,5 +881,5 @@ def forward(
         text = tokenizer.tokenize(text).to(device)

         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed)
+        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
         return images
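Finally the knob is exposed at the top level, so the whole text-to-image pipeline can be driven with one extra argument. A hypothetical end-to-end call, assuming a fully wired `DALLE2` instance:

# hypothetical usage; assumes `dalle2` is a trained DALLE2 instance
images = dalle2(
    ['a shiba inu wearing a beret'],  # raw text is tokenized inside forward
    cond_scale = 2.                   # > 1. strengthens conditioning via classifier-free guidance
)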