@@ -47,6 +47,10 @@ def is_ray_available():
4747 return importlib .util .find_spec ("ray.air" ) is not None
4848
4949
def is_swanlab_available():
    """Return whether the optional `swanlab` package can be found by the import system."""
    swanlab_spec = importlib.util.find_spec("swanlab")
    return swanlab_spec is not None
52+
53+
5054def get_available_reporting_integrations ():
5155 integrations = []
5256 if is_visualdl_available ():
@@ -55,7 +59,8 @@ def get_available_reporting_integrations():
5559 integrations .append ("wandb" )
5660 if is_tensorboardX_available ():
5761 integrations .append ("tensorboard" )
58-
62+ if is_swanlab_available ():
63+ integrations .append ("swanlab" )
5964 return integrations
6065
6166
@@ -410,11 +415,145 @@ def on_evaluate(self, args, state, control, **kwargs):
410415 self .session .report (metrics )
411416
412417
class SwanLabCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs the training configuration and metrics to [SwanLab](https://swanlab.cn/).

    Model checkpoint logging is not supported yet: when `SWANLAB_LOG_MODEL` is set, the callback only
    emits a warning.
    """

    # Summary metrics that are reported once per run; they are logged under the "single_value/" prefix.
    _SINGLE_VALUE_SCALARS = (
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "train_loss",
        "total_flos",
    )

    def __init__(self):
        """
        Import `swanlab` lazily and read the `SWANLAB_LOG_MODEL` environment variable.

        Raises:
            RuntimeError: if the `swanlab` package is not installed.
        """
        if not is_swanlab_available():
            raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
        import swanlab

        self._swanlab = swanlab
        self._initialized = False
        self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)

    def setup(self, args, state, model=None, **kwargs):
        """
        Setup the optional SwanLab (*swanlab*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).

        You can also override the following environment variables. Find more information about environment
        variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)

        Environment:
        - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
            Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
            checks if the user is already logged in. If not, the login process is initiated.

                - If a string is passed to the login interface, this environment variable is ignored.
                - If the user is already logged in, this environment variable takes precedence over locally stored
                login information.

        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
            Set this to a custom string to store results in a different project. If not specified, the project
            name `PaddleNLP` is used.

        - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
            This environment variable specifies the storage path for log files when running in local mode.
            By default, logs are saved in a folder named swanlog under the working directory.

        - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
            SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
            local, cloud, and disabled. Note: Case-sensitive. Find more information
            [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)

        - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
            SwanLab does not currently support the save mode functionality. This feature will be available in a
            future release. When this variable is set, the callback only logs a warning.

        - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
            Web address for the SwanLab cloud environment for private version (it's free)

        - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
            API address for the SwanLab cloud environment for private version (it's free)

        Args:
            args: training arguments; must provide `to_dict()` and `run_name`.
            state: trainer state; `is_world_process_zero` and `trial_name` are read.
            model: optional model whose `config` / `lora_config` (dicts or objects exposing `to_dict()`)
                are merged into the logged configuration. May be `None`.
            **kwargs: ignored; accepted so callers can forward callback keyword arguments.
        """
        self._initialized = True

        # Only the main process talks to SwanLab.
        if not state.is_world_process_zero:
            return

        logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
        combined_dict = {**args.to_dict()}

        # Training arguments take precedence over model/LoRA config entries with the same key.
        if hasattr(model, "config") and model.config is not None:
            model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
            combined_dict = {**model_config, **combined_dict}
        if hasattr(model, "lora_config") and model.lora_config is not None:
            lora_config = model.lora_config if isinstance(model.lora_config, dict) else model.lora_config.to_dict()
            combined_dict = {**{"lora_config": lora_config}, **combined_dict}

        trial_name = state.trial_name
        init_args = {}
        if trial_name is not None and args.run_name is not None:
            init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
        elif args.run_name is not None:
            init_args["experiment_name"] = args.run_name
        elif trial_name is not None:
            init_args["experiment_name"] = trial_name
        init_args["project"] = os.getenv("SWANLAB_PROJECT", "PaddleNLP")

        # Reuse a run that was created manually before training started.
        if self._swanlab.get_run() is None:
            self._swanlab.init(
                **init_args,
            )
        # record which framework produced the run
        self._swanlab.config["FRAMEWORK"] = "paddlenlp"
        # add config parameters (run may have been created manually)
        self._swanlab.config.update(combined_dict)

    def _warn_log_model_unsupported(self, state):
        """Warn on the main process that model logging was requested but is not supported by SwanLab."""
        if self._log_model is not None and self._initialized and state.is_world_process_zero:
            logger.warning(
                "SwanLab does not currently support the save mode functionality. "
                "This feature will be available in a future release."
            )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Initialize the SwanLab run at the start of training if it has not been set up yet."""
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        """Warn at the end of training when model logging was requested via `SWANLAB_LOG_MODEL`."""
        self._warn_log_model_unsupported(state)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """
        Forward trainer logs to SwanLab at the current global step.

        Keys listed in `_SINGLE_VALUE_SCALARS` are logged under the `single_value/` prefix; all other
        entries go through `rewrite_logs` and are logged together with `train/global_step`.
        """
        if not self._initialized:
            self.setup(args, state, model)
        if not state.is_world_process_zero:
            return

        # `logs` defaults to None in the signature; treat that as "nothing to report".
        if logs is None:
            logs = {}

        step = state.global_step
        for key, value in logs.items():
            if key in self._SINGLE_VALUE_SCALARS:
                self._swanlab.log({f"single_value/{key}": value}, step=step)

        non_scalar_logs = {key: value for key, value in logs.items() if key not in self._SINGLE_VALUE_SCALARS}
        non_scalar_logs = rewrite_logs(non_scalar_logs)
        self._swanlab.log({**non_scalar_logs, "train/global_step": step}, step=step)

    def on_save(self, args, state, control, **kwargs):
        """Warn on checkpoint save when model logging was requested via `SWANLAB_LOG_MODEL`."""
        self._warn_log_model_unsupported(state)

    def on_predict(self, args, state, control, metrics, **kwargs):
        """Log prediction metrics to SwanLab, setting up the run first if needed."""
        if not self._initialized:
            # `model` may or may not be present in kwargs; `setup` defaults it to None.
            self.setup(args, state, **kwargs)
        if state.is_world_process_zero:
            metrics = rewrite_logs(metrics)
            self._swanlab.log(metrics)
549+
550+
# Maps each reporting integration name to the callback class that implements it.
INTEGRATION_TO_CALLBACK = {
    "visualdl": VisualDLCallback,
    "autonlp": AutoNLPCallback,
    "wandb": WandbCallback,
    "tensorboard": TensorBoardCallback,
    "swanlab": SwanLabCallback,
}
419558
420559
0 commit comments