Skip to content

Commit fc425c4

Browse files
committed
add_SwanLabCallback
1 parent 286ae73 commit fc425c4

4 files changed

Lines changed: 146 additions & 3 deletions

File tree

paddlenlp/trainer/integrations.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def is_ray_available():
4747
return importlib.util.find_spec("ray.air") is not None
4848

4949

50+
def is_swanlab_available():
51+
return importlib.util.find_spec("swanlab") is not None
52+
53+
5054
def get_available_reporting_integrations():
5155
integrations = []
5256
if is_visualdl_available():
@@ -55,7 +59,8 @@ def get_available_reporting_integrations():
5559
integrations.append("wandb")
5660
if is_tensorboardX_available():
5761
integrations.append("tensorboard")
58-
62+
if is_swanlab_available():
63+
integrations.append("swanlab")
5964
return integrations
6065

6166

@@ -410,11 +415,147 @@ def on_evaluate(self, args, state, control, **kwargs):
410415
self.session.report(metrics)
411416

412417

418+
class SwanLabCallback(TrainerCallback):
419+
"""
420+
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
421+
"""
422+
423+
def __init__(self):
424+
if not is_swanlab_available():
425+
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
426+
import swanlab
427+
428+
self._swanlab = swanlab
429+
self._initialized = False
430+
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
431+
432+
def setup(self, args, state, model, **kwargs):
433+
"""
434+
Setup the optional SwanLab (*swanlab*) integration.
435+
436+
One can subclass and override this method to customize the setup if needed. Find more information
437+
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
438+
439+
You can also override the following environment variables. Find more information about environment
440+
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
441+
442+
Environment:
443+
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
444+
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
445+
checks if the user is already logged in. If not, the login process is initiated.
446+
447+
- If a string is passed to the login interface, this environment variable is ignored.
448+
- If the user is already logged in, this environment variable takes precedence over locally stored
449+
login information.
450+
451+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
452+
Set this to a custom string to store results in a different project. If not specified, the name of the current
453+
running directory is used.
454+
455+
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
456+
This environment variable specifies the storage path for log files when running in local mode.
457+
By default, logs are saved in a folder named swanlog under the working directory.
458+
459+
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
460+
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
461+
local, cloud, and disabled. Note: Case-sensitive. Find more information
462+
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
463+
464+
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
465+
SwanLab does not currently support the save mode functionality.This feature will be available in a future
466+
release
467+
468+
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
469+
Web address for the SwanLab cloud environment for private version (its free)
470+
471+
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
472+
API address for the SwanLab cloud environment for private version (its free)
473+
474+
"""
475+
self._initialized = True
476+
477+
if state.is_world_process_zero:
478+
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
479+
combined_dict = {**args.to_dict()}
480+
481+
if hasattr(model, "config") and model.config is not None:
482+
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
483+
combined_dict = {**model_config, **combined_dict}
484+
if hasattr(model, "lora_config") and model.lora_config is not None:
485+
lora_config = model.lora_config if isinstance(model.lora_config, dict) else model.lora_config.to_dict()
486+
combined_dict = {**{"lora_config": lora_config}, **combined_dict}
487+
trial_name = state.trial_name
488+
init_args = {}
489+
if trial_name is not None and args.run_name is not None:
490+
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
491+
elif args.run_name is not None:
492+
init_args["experiment_name"] = args.run_name
493+
elif trial_name is not None:
494+
init_args["experiment_name"] = trial_name
495+
init_args["project"] = os.getenv("SWANLAB_PROJECT", "PaddleNLP")
496+
if args.logging_dir is not None:
497+
init_args["logdir"] = os.getenv("SWANLAB_LOG_DIR", args.logging_dir)
498+
499+
if self._swanlab.get_run() is None:
500+
self._swanlab.init(
501+
**init_args,
502+
)
503+
# show paddlenlp logo!
504+
self._swanlab.config["FRAMEWORK"] = "paddlenlp"
505+
# add config parameters (run may have been created manually)
506+
self._swanlab.config.update(combined_dict)
507+
508+
def on_train_begin(self, args, state, control, model=None, **kwargs):
509+
if not self._initialized:
510+
self.setup(args, state, model, **kwargs)
511+
512+
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
513+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
514+
logger.warning(
515+
"SwanLab does not currently support the save mode functionality. "
516+
"This feature will be available in a future release."
517+
)
518+
519+
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
520+
single_value_scalars = [
521+
"train_runtime",
522+
"train_samples_per_second",
523+
"train_steps_per_second",
524+
"train_loss",
525+
"total_flos",
526+
]
527+
528+
if not self._initialized:
529+
self.setup(args, state, model)
530+
if state.is_world_process_zero:
531+
for k, v in logs.items():
532+
if k in single_value_scalars:
533+
self._swanlab.log({f"single_value/{k}": v}, step=state.global_step)
534+
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
535+
non_scalar_logs = rewrite_logs(non_scalar_logs)
536+
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step}, step=state.global_step)
537+
538+
def on_save(self, args, state, control, **kwargs):
539+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
540+
logger.warning(
541+
"SwanLab does not currently support the save mode functionality. "
542+
"This feature will be available in a future release."
543+
)
544+
545+
def on_predict(self, args, state, control, metrics, **kwargs):
546+
if not self._initialized:
547+
self.setup(args, state, **kwargs)
548+
if state.is_world_process_zero:
549+
metrics = rewrite_logs(metrics)
550+
self._swanlab.log(metrics)
551+
552+
413553
INTEGRATION_TO_CALLBACK = {
414554
"visualdl": VisualDLCallback,
415555
"autonlp": AutoNLPCallback,
416556
"wandb": WandbCallback,
417557
"tensorboard": TensorBoardCallback,
558+
"swanlab": SwanLabCallback,
418559
}
419560

420561

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ class TrainingArguments:
382382
instance of `Dataset`.
383383
report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
384384
The list of integrations to report the results and logs to.
385-
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
385+
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`/`"swanlab".
386386
`"none"` for no integrations.
387387
ddp_find_unused_parameters (`bool`, *optional*):
388388
When using distributed training, the value of the flag `find_unused_parameters` passed to

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ wget
3131
huggingface_hub>=0.19.2
3232
tiktoken
3333
tokenizers<=0.20.3; python_version<="3.8"
34-
tokenizers>=0.21,<0.22; python_version>"3.8"
34+
tokenizers>=0.21,<0.22; python_version>"3.8"
35+
swanlab[dashboard]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ ml_dtypes
3030
tokenizers<=0.20.3; python_version<="3.8"
3131
tokenizers>=0.21,<0.22; python_version>"3.8"
3232
omegaconf
33+
swanlab[dashboard]

0 commit comments

Comments
 (0)