1
- import argparse , os , sys , glob
1
+ import argparse , os , sys , glob , uuid
2
2
import torch
3
3
import numpy as np
4
4
from omegaconf import OmegaConf
@@ -50,23 +50,23 @@ def main():
50
50
type = str ,
51
51
nargs = "?" ,
52
52
default = "a painting of a virus monster playing guitar" ,
53
- help = "the prompt to render"
53
+ help = "the prompt to render" ,
54
54
)
55
55
parser .add_argument (
56
56
"--outdir" ,
57
57
type = str ,
58
58
nargs = "?" ,
59
59
help = "dir to write results to" ,
60
- default = "outputs/txt2img-samples"
60
+ default = "outputs/txt2img-samples" ,
61
61
)
62
62
parser .add_argument (
63
63
"--skip_grid" ,
64
- action = ' store_true' ,
64
+ action = " store_true" ,
65
65
help = "do not save a grid, only individual samples. Helpful when evaluating lots of samples" ,
66
66
)
67
67
parser .add_argument (
68
68
"--skip_save" ,
69
- action = ' store_true' ,
69
+ action = " store_true" ,
70
70
help = "do not save individual samples. For speed measurements." ,
71
71
)
72
72
parser .add_argument (
@@ -77,17 +77,17 @@ def main():
77
77
)
78
78
parser .add_argument (
79
79
"--plms" ,
80
- action = ' store_true' ,
80
+ action = " store_true" ,
81
81
help = "use plms sampling" ,
82
82
)
83
83
parser .add_argument (
84
84
"--laion400m" ,
85
- action = ' store_true' ,
85
+ action = " store_true" ,
86
86
help = "uses the LAION400M model" ,
87
87
)
88
88
parser .add_argument (
89
89
"--fixed_code" ,
90
- action = ' store_true' ,
90
+ action = " store_true" ,
91
91
help = "if enabled, uses the same starting code across samples " ,
92
92
)
93
93
parser .add_argument (
@@ -160,7 +160,7 @@ def main():
160
160
type = str ,
161
161
default = "models/ldm/stable-diffusion-v1/model.ckpt" ,
162
162
help = "path to checkpoint of model" ,
163
- )
163
+ )
164
164
parser .add_argument (
165
165
"--seed" ,
166
166
type = int ,
@@ -172,14 +172,14 @@ def main():
172
172
type = str ,
173
173
help = "evaluate at this precision" ,
174
174
choices = ["full" , "autocast" ],
175
- default = "autocast"
175
+ default = "autocast" ,
176
176
)
177
177
178
-
179
178
parser .add_argument (
180
- "--embedding_path" ,
181
- type = str ,
182
- help = "Path to a pre-trained embedding manager checkpoint" )
179
+ "--embedding_path" ,
180
+ type = str ,
181
+ help = "Path to a pre-trained embedding manager checkpoint" ,
182
+ )
183
183
184
184
opt = parser .parse_args ()
185
185
@@ -193,7 +193,7 @@ def main():
193
193
194
194
config = OmegaConf .load (f"{ opt .config } " )
195
195
model = load_model_from_config (config , f"{ opt .ckpt } " )
196
- #model.embedding_manager.load(opt.embedding_path)
196
+ # model.embedding_manager.load(opt.embedding_path)
197
197
198
198
device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
199
199
model = model .to (device )
@@ -226,9 +226,11 @@ def main():
226
226
227
227
start_code = None
228
228
if opt .fixed_code :
229
- start_code = torch .randn ([opt .n_samples , opt .C , opt .H // opt .f , opt .W // opt .f ], device = device )
229
+ start_code = torch .randn (
230
+ [opt .n_samples , opt .C , opt .H // opt .f , opt .W // opt .f ], device = device
231
+ )
230
232
231
- precision_scope = autocast if opt .precision == "autocast" else nullcontext
233
+ precision_scope = autocast if opt .precision == "autocast" else nullcontext
232
234
with torch .no_grad ():
233
235
with precision_scope ("cuda" ):
234
236
with model .ema_scope ():
@@ -243,24 +245,31 @@ def main():
243
245
prompts = list (prompts )
244
246
c = model .get_learned_conditioning (prompts )
245
247
shape = [opt .C , opt .H // opt .f , opt .W // opt .f ]
246
- samples_ddim , _ = sampler .sample (S = opt .ddim_steps ,
247
- conditioning = c ,
248
- batch_size = opt .n_samples ,
249
- shape = shape ,
250
- verbose = False ,
251
- unconditional_guidance_scale = opt .scale ,
252
- unconditional_conditioning = uc ,
253
- eta = opt .ddim_eta ,
254
- x_T = start_code )
248
+ samples_ddim , _ = sampler .sample (
249
+ S = opt .ddim_steps ,
250
+ conditioning = c ,
251
+ batch_size = opt .n_samples ,
252
+ shape = shape ,
253
+ verbose = False ,
254
+ unconditional_guidance_scale = opt .scale ,
255
+ unconditional_conditioning = uc ,
256
+ eta = opt .ddim_eta ,
257
+ x_T = start_code ,
258
+ )
255
259
256
260
x_samples_ddim = model .decode_first_stage (samples_ddim )
257
- x_samples_ddim = torch .clamp ((x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
261
+ x_samples_ddim = torch .clamp (
262
+ (x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0
263
+ )
258
264
259
265
if not opt .skip_save :
260
266
for x_sample in x_samples_ddim :
261
- x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
267
+ x_sample = 255.0 * rearrange (
268
+ x_sample .cpu ().numpy (), "c h w -> h w c"
269
+ )
262
270
Image .fromarray (x_sample .astype (np .uint8 )).save (
263
- os .path .join (sample_path , f"{ base_count :05} .jpg" ))
271
+ os .path .join (sample_path , f"{ base_count :05} .jpg" )
272
+ )
264
273
base_count += 1
265
274
266
275
if not opt .skip_grid :
@@ -269,23 +278,29 @@ def main():
269
278
if not opt .skip_grid :
270
279
# additionally, save as grid
271
280
grid = torch .stack (all_samples , 0 )
272
- grid = rearrange (grid , 'n b c h w -> (n b) c h w' )
273
-
281
+ grid = rearrange (grid , "n b c h w -> (n b) c h w" )
282
+
283
+ batch_uuid = uuid .uuid4 ()
274
284
for i in range (grid .size (0 )):
275
- save_image (grid [i , :, :, :], os .path .join (outpath ,opt .prompt [:30 ]+ '_{}.png' .format (i )))
285
+ file_name = f"{ batch_uuid .hex [:10 ]} _{ opt .prompt [:30 ]} _{ i } .png"
286
+ save_image (grid [i , :, :, :], os .path .join (outpath , file_name ))
276
287
grid = make_grid (grid , nrow = n_rows )
277
288
278
289
# to image
279
- grid = 255. * rearrange (grid , 'c h w -> h w c' ).cpu ().numpy ()
280
- Image .fromarray (grid .astype (np .uint8 )).save (os .path .join (outpath , f'{ prompt .replace (" " , "-" )[:30 ]} -{ grid_count :04} .jpg' ))
290
+ grid = 255.0 * rearrange (grid , "c h w -> h w c" ).cpu ().numpy ()
291
+ Image .fromarray (grid .astype (np .uint8 )).save (
292
+ os .path .join (
293
+ outpath ,
294
+ f'{ batch_uuid .hex [:10 ]} -{ prompt .replace (" " , "-" )[:30 ]} -{ grid_count :04} .jpg' ,
295
+ )
296
+ )
281
297
grid_count += 1
282
-
283
-
284
298
285
299
toc = time .time ()
286
300
287
- print (f"Your samples are ready and waiting for you here: \n { outpath } \n "
288
- f" \n Enjoy." )
301
+ print (
302
+ f"Your samples are ready and waiting for you here: \n { outpath } \n " f" \n Enjoy."
303
+ )
289
304
290
305
291
306
if __name__ == "__main__" :
0 commit comments