@@ -1,6 +1,7 @@
 import argparse
 import os
 
+import dataloader
 import jax
 import matplotlib.pyplot as plt
 import numpy as np
@@ -12,7 +13,6 @@
 from jax.experimental import mesh_utils
 
 import blaxbird
-import dataloader
 from blaxbird import get_default_checkpointer, train_fn
 from blaxbird.experimental import rfm
 
@@ -38,19 +38,19 @@ def visualize_hook(sample_fn, val_iter, hook_every_n_steps, log_to_wandb):
 
     def convert_batch_to_image_grid(image_batch):
         reshaped = (
-            image_batch.reshape(n_row, n_col, *img_size)
-            .transpose([0, 2, 1, 3, 4])
-            .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
+            image_batch.reshape(n_row, n_col, *img_size)
+            .transpose([0, 2, 1, 3, 4])
+            .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
         )
         return (reshaped + 1.0) / 2.0
 
     def plot(images):
         fig = plt.figure(figsize=(16, 6))
         ax = fig.add_subplot(1, 1, 1)
         ax.imshow(
-            images,
-            interpolation="nearest",
-            cmap="gray",
+            images,
+            interpolation="nearest",
+            cmap="gray",
         )
         plt.axis("off")
         plt.tight_layout()
@@ -61,26 +61,31 @@ def fn(step, *, model, **kwargs):
             return
         all_samples = []
         for i, batch in enumerate(val_iter):
-            samples = sample_fn(model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape)
-            all_samples.append(samples)
-            if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
-                break
-        all_samples = np.concatenate(all_samples, axis=0)[:(n_row * n_col)]
+            samples = sample_fn(
+                model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape
+            )
+            all_samples.append(samples)
+            if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
+                break
+        all_samples = np.concatenate(all_samples, axis=0)[: (n_row * n_col)]
         all_samples = convert_batch_to_image_grid(all_samples)
         fig = plot(all_samples)
         if jax.process_index() == 0 and log_to_wandb:
-            wandb.log({"images": wandb.Image(fig)}, step=step)
+            wandb.log({"images": wandb.Image(fig)}, step=step)
 
     return fn
 
 
-def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb):
+def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb):
     return [visualize_hook(sample_fn, val_itr, hook_every_n_steps, log_to_wandb)]
 
 
 def get_train_and_val_itrs(rng_key, outfolder):
     return dataloader.data_loaders(
-        rng_key, outfolder, split=["train[:90%]", "train[90%:]"], shuffle=[True, False],
+        rng_key,
+        outfolder,
+        split=["train[:90%]", "train[90%:]"],
+        shuffle=[True, False],
     )
 
 
@@ -92,14 +97,19 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
         jr.key(0), os.path.join(outfolder, "data")
     )
 
-    model = getattr(blaxbird.experimental, dit_type)(image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1)))
+    model = getattr(blaxbird.experimental, dit_type)(
+        image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1))
+    )
     train_step, val_step, sample_fn = rfm()
     optimizer = get_optimizer(model)
 
     save_fn, _, restore_last_fn = get_default_checkpointer(
-        os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
+        os.path.join(outfolder, "checkpoints"),
+        save_every_n_steps=eval_every_n_steps,
     )
-    hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [save_fn]
+    hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [
+        save_fn
+    ]
 
     model_sharding, data_sharding = get_sharding()
     model, optimizer = restore_last_fn(model, optimizer)
@@ -121,7 +131,15 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
     parser.add_argument("--n-steps", type=int, default=1_000)
     parser.add_argument("--eval-every-n-steps", type=int, default=50)
     parser.add_argument("--n-eval-batches", type=int, default=10)
-    parser.add_argument("--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT")
+    parser.add_argument(
+        "--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT"
+    )
     parser.add_argument("--log-to-wandb", action="store_true")
     args = parser.parse_args()
-    run(args.n_steps, args.eval_every_n_steps, args.n_eval_batches, args.dit, args.log_to_wandb)
+    run(
+        args.n_steps,
+        args.eval_every_n_steps,
+        args.n_eval_batches,
+        args.dit,
+        args.log_to_wandb,
+    )
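
For reference, the reformatted `__main__` block above simply forwards the CLI flags to `run()`. A minimal sketch of the equivalent direct call, using only the names and default values visible in this diff (i.e. what the script does when invoked with no flags):

    run(
        n_steps=1_000,
        eval_every_n_steps=50,
        n_eval_batches=10,
        dit_type="SmallDiT",
        log_to_wandb=False,
    )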