@@ -37,6 +37,7 @@ class LogDiffusionImages(Callback):
37
37
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
38
38
Default: ``1138``.
39
39
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
40
+ use_mask (bool): Whether or not to use the mask for the encoded text. Default: ``True``.
40
41
t5_encoder (str, optional): path to the T5 encoder to use as a second text encoder.
41
42
clip_encoder (str, optional): path to the CLIP encoder as the first text encoder.
42
43
t5_latent_key (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
@@ -56,6 +57,7 @@ def __init__(self,
56
57
rescaled_guidance : Optional [float ] = None ,
57
58
seed : Optional [int ] = 1138 ,
58
59
use_table : bool = False ,
60
+ use_mask : bool = True ,
59
61
t5_encoder : Optional [str ] = None ,
60
62
clip_encoder : Optional [str ] = None ,
61
63
t5_latent_key : str = 'T5_LATENTS' ,
@@ -71,6 +73,7 @@ def __init__(self,
71
73
self .rescaled_guidance = rescaled_guidance
72
74
self .seed = seed
73
75
self .use_table = use_table
76
+ self .use_mask = use_mask
74
77
self .t5_latent_key = t5_latent_key
75
78
self .t5_mask_key = t5_mask_key
76
79
self .clip_latent_key = clip_latent_key
@@ -100,47 +103,47 @@ def __init__(self,
100
103
local_files_only = True )
101
104
102
105
t5_model = AutoModel .from_pretrained (t5_encoder ,
103
- torch_dtype = torch .float16 ,
106
+ torch_dtype = torch .bfloat16 ,
104
107
cache_dir = self .cache_dir ,
105
108
local_files_only = True ).encoder .cuda ().eval ()
106
109
clip_model = CLIPTextModel .from_pretrained (clip_encoder ,
107
110
subfolder = 'text_encoder' ,
108
- torch_dtype = torch .float16 ,
111
+ torch_dtype = torch .bfloat16 ,
109
112
cache_dir = self .cache_dir ,
110
113
local_files_only = True ).cuda ().eval ()
111
-
112
- for batch in self .batched_prompts :
113
- latent_batch = {}
114
- tokenized_t5 = t5_tokenizer (batch ,
115
- padding = 'max_length' ,
116
- max_length = t5_tokenizer .model_max_length ,
117
- truncation = True ,
118
- return_tensors = 'pt' )
119
- t5_attention_mask = tokenized_t5 ['attention_mask' ].to (torch .bool ).cuda ()
120
- t5_ids = tokenized_t5 ['input_ids' ].cuda ()
121
- t5_latents = t5_model (input_ids = t5_ids , attention_mask = t5_attention_mask )[0 ].cpu ()
122
- t5_attention_mask = t5_attention_mask .cpu ().to (torch .long )
123
-
124
- tokenized_clip = clip_tokenizer (batch ,
114
+ with torch .no_grad ():
115
+ for batch in self .batched_prompts :
116
+ latent_batch = {}
117
+ tokenized_t5 = t5_tokenizer (batch ,
125
118
padding = 'max_length' ,
126
- max_length = clip_tokenizer .model_max_length ,
119
+ max_length = t5_tokenizer .model_max_length ,
127
120
truncation = True ,
128
121
return_tensors = 'pt' )
129
- clip_attention_mask = tokenized_clip ['attention_mask' ].cuda ()
130
- clip_ids = tokenized_clip ['input_ids' ].cuda ()
131
- clip_outputs = clip_model (input_ids = clip_ids ,
132
- attention_mask = clip_attention_mask ,
133
- output_hidden_states = True )
134
- clip_latents = clip_outputs .hidden_states [- 2 ].cpu ()
135
- clip_pooled = clip_outputs [1 ].cpu ()
136
- clip_attention_mask = clip_attention_mask .cpu ().to (torch .long )
137
-
138
- latent_batch [self .t5_latent_key ] = t5_latents
139
- latent_batch [self .t5_mask_key ] = t5_attention_mask
140
- latent_batch [self .clip_latent_key ] = clip_latents
141
- latent_batch [self .clip_mask_key ] = clip_attention_mask
142
- latent_batch [self .clip_pooled_key ] = clip_pooled
143
- self .batched_latents .append (latent_batch )
122
+ t5_attention_mask = tokenized_t5 ['attention_mask' ].to (torch .bool ).cuda ()
123
+ t5_ids = tokenized_t5 ['input_ids' ].cuda ()
124
+ t5_latents = t5_model (input_ids = t5_ids , attention_mask = t5_attention_mask )[0 ].cpu ()
125
+ t5_attention_mask = t5_attention_mask .cpu ().to (torch .long )
126
+
127
+ tokenized_clip = clip_tokenizer (batch ,
128
+ padding = 'max_length' ,
129
+ max_length = clip_tokenizer .model_max_length ,
130
+ truncation = True ,
131
+ return_tensors = 'pt' )
132
+ clip_attention_mask = tokenized_clip ['attention_mask' ].cuda ()
133
+ clip_ids = tokenized_clip ['input_ids' ].cuda ()
134
+ clip_outputs = clip_model (input_ids = clip_ids ,
135
+ attention_mask = clip_attention_mask ,
136
+ output_hidden_states = True )
137
+ clip_latents = clip_outputs .hidden_states [- 2 ].cpu ()
138
+ clip_pooled = clip_outputs [1 ].cpu ()
139
+ clip_attention_mask = clip_attention_mask .cpu ().to (torch .long )
140
+
141
+ latent_batch [self .t5_latent_key ] = t5_latents
142
+ latent_batch [self .t5_mask_key ] = t5_attention_mask
143
+ latent_batch [self .clip_latent_key ] = clip_latents
144
+ latent_batch [self .clip_mask_key ] = clip_attention_mask
145
+ latent_batch [self .clip_pooled_key ] = clip_pooled
146
+ self .batched_latents .append (latent_batch )
144
147
145
148
del t5_model
146
149
del clip_model
@@ -160,21 +163,40 @@ def eval_start(self, state: State, logger: Logger):
160
163
if self .precomputed_latents :
161
164
for batch in self .batched_latents :
162
165
pooled_prompt = batch [self .clip_pooled_key ].cuda ()
163
- prompt_embeds , prompt_mask = model .prepare_text_embeddings (batch [self .t5_latent_key ].cuda (),
164
- batch [self .clip_latent_key ].cuda (),
165
- batch [self .t5_mask_key ].cuda (),
166
- batch [self .clip_mask_key ].cuda ())
167
- gen_images = model .generate (prompt_embeds = prompt_embeds ,
168
- pooled_prompt = pooled_prompt ,
169
- prompt_mask = prompt_mask ,
170
- height = self .size [0 ],
171
- width = self .size [1 ],
172
- guidance_scale = self .guidance_scale ,
173
- rescaled_guidance = self .rescaled_guidance ,
174
- progress_bar = False ,
175
- num_inference_steps = self .num_inference_steps ,
176
- seed = self .seed )
166
+ if self .use_mask :
167
+ prompt_embeds , prompt_mask = model .prepare_text_embeddings (batch [self .t5_latent_key ].cuda (),
168
+ batch [self .clip_latent_key ].cuda (),
169
+ batch [self .t5_mask_key ].cuda (),
170
+ batch [self .clip_mask_key ].cuda ())
171
+ gen_images = model .generate (prompt_embeds = prompt_embeds ,
172
+ pooled_prompt = pooled_prompt ,
173
+ prompt_mask = prompt_mask ,
174
+ height = self .size [0 ],
175
+ width = self .size [1 ],
176
+ guidance_scale = self .guidance_scale ,
177
+ rescaled_guidance = self .rescaled_guidance ,
178
+ progress_bar = False ,
179
+ num_inference_steps = self .num_inference_steps ,
180
+ seed = self .seed )
181
+ else :
182
+ prompt_embeds = model .prepare_text_embeddings (batch [self .t5_latent_key ].cuda (),
183
+ batch [self .clip_latent_key ].cuda ())
184
+ gen_images = model .generate (prompt_embeds = prompt_embeds ,
185
+ pooled_prompt = pooled_prompt ,
186
+ height = self .size [0 ],
187
+ width = self .size [1 ],
188
+ guidance_scale = self .guidance_scale ,
189
+ rescaled_guidance = self .rescaled_guidance ,
190
+ progress_bar = False ,
191
+ num_inference_steps = self .num_inference_steps ,
192
+ seed = self .seed )
177
193
all_gen_images .append (gen_images )
194
+ # Clear up GPU tensors
195
+ del pooled_prompt
196
+ del prompt_embeds
197
+ if self .use_mask :
198
+ del prompt_mask
199
+ torch .cuda .empty_cache ()
178
200
else :
179
201
for batch in self .batched_prompts :
180
202
gen_images = model .generate (
0 commit comments