@@ -250,6 +250,17 @@ def add_general_parsers(parser):
250250 default = None ,
251251 help = "Pretrained LoRA path. Required if the training is resumed." ,
252252 )
253+ parser .add_argument (
254+ "--use_swanlab" ,
255+ default = False ,
256+ action = "store_true" ,
257+ help = "Whether to use SwanLab logger." ,
258+ )
259+ parser .add_argument (
260+ "--swanlab_mode" ,
261+ default = None ,
262+ help = "SwanLab mode (cloud or local)." ,
263+ )
253264 return parser
254265
255266
@@ -270,6 +281,20 @@ def launch_training_task(model, args):
270281 num_workers = args .dataloader_num_workers
271282 )
272283 # train
284+ if args .use_swanlab :
285+ from swanlab .integration .pytorch_lightning import SwanLabLogger
286+ swanlab_config = {"UPPERFRAMEWORK" : "DiffSynth-Studio" }
287+ swanlab_config .update (vars (args ))
288+ swanlab_logger = SwanLabLogger (
289+ project = "diffsynth_studio" ,
290+ name = "diffsynth_studio" ,
291+ config = swanlab_config ,
292+ mode = args .swanlab_mode ,
293+ logdir = args .output_path ,
294+ )
295+ logger = [swanlab_logger ]
296+ else :
297+ logger = None
273298 trainer = pl .Trainer (
274299 max_epochs = args .max_epochs ,
275300 accelerator = "gpu" ,
@@ -278,7 +303,8 @@ def launch_training_task(model, args):
278303 strategy = args .training_strategy ,
279304 default_root_dir = args .output_path ,
280305 accumulate_grad_batches = args .accumulate_grad_batches ,
281- callbacks = [pl .pytorch .callbacks .ModelCheckpoint (save_top_k = - 1 )]
306+ callbacks = [pl .pytorch .callbacks .ModelCheckpoint (save_top_k = - 1 )],
307+ logger = logger ,
282308 )
283309 trainer .fit (model = model , train_dataloaders = train_loader )
284310
0 commit comments