Skip to content

Commit b4afd7a

Browse files
committed
support swanlab
1 parent 286ae73 commit b4afd7a

4 files changed

Lines changed: 144 additions & 3 deletions

File tree

paddlenlp/trainer/integrations.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def is_ray_available():
4747
return importlib.util.find_spec("ray.air") is not None
4848

4949

50+
def is_swanlab_available():
51+
return importlib.util.find_spec("swanlab") is not None
52+
53+
5054
def get_available_reporting_integrations():
5155
integrations = []
5256
if is_visualdl_available():
@@ -55,7 +59,8 @@ def get_available_reporting_integrations():
5559
integrations.append("wandb")
5660
if is_tensorboardX_available():
5761
integrations.append("tensorboard")
58-
62+
if is_swanlab_available():
63+
integrations.append("swanlab")
5964
return integrations
6065

6166

@@ -410,11 +415,145 @@ def on_evaluate(self, args, state, control, **kwargs):
410415
self.session.report(metrics)
411416

412417

418+
class SwanLabCallback(TrainerCallback):
419+
"""
420+
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
421+
"""
422+
423+
def __init__(self):
424+
if not is_swanlab_available():
425+
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
426+
import swanlab
427+
428+
self._swanlab = swanlab
429+
self._initialized = False
430+
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
431+
432+
def setup(self, args, state, model, **kwargs):
433+
"""
434+
Setup the optional SwanLab (*swanlab*) integration.
435+
436+
One can subclass and override this method to customize the setup if needed. Find more information
437+
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
438+
439+
You can also override the following environment variables. Find more information about environment
440+
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
441+
442+
Environment:
443+
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
444+
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
445+
checks if the user is already logged in. If not, the login process is initiated.
446+
447+
- If a string is passed to the login interface, this environment variable is ignored.
448+
- If the user is already logged in, this environment variable takes precedence over locally stored
449+
login information.
450+
451+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
452+
Set this to a custom string to store results in a different project. If not specified, the name of the current
453+
running directory is used.
454+
455+
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
456+
This environment variable specifies the storage path for log files when running in local mode.
457+
By default, logs are saved in a folder named swanlog under the working directory.
458+
459+
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
460+
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
461+
local, cloud, and disabled. Note: Case-sensitive. Find more information
462+
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
463+
464+
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
465+
SwanLab does not currently support the save mode functionality.This feature will be available in a future
466+
release
467+
468+
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
469+
Web address for the SwanLab cloud environment for private version (its free)
470+
471+
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
472+
API address for the SwanLab cloud environment for private version (its free)
473+
474+
"""
475+
self._initialized = True
476+
477+
if state.is_world_process_zero:
478+
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
479+
combined_dict = {**args.to_dict()}
480+
481+
if hasattr(model, "config") and model.config is not None:
482+
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
483+
combined_dict = {**model_config, **combined_dict}
484+
if hasattr(model, "lora_config") and model.lora_config is not None:
485+
lora_config = model.lora_config if isinstance(model.lora_config, dict) else model.lora_config.to_dict()
486+
combined_dict = {**{"lora_config": lora_config}, **combined_dict}
487+
trial_name = state.trial_name
488+
init_args = {}
489+
if trial_name is not None and args.run_name is not None:
490+
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
491+
elif args.run_name is not None:
492+
init_args["experiment_name"] = args.run_name
493+
elif trial_name is not None:
494+
init_args["experiment_name"] = trial_name
495+
init_args["project"] = os.getenv("SWANLAB_PROJECT", "PaddleNLP")
496+
497+
if self._swanlab.get_run() is None:
498+
self._swanlab.init(
499+
**init_args,
500+
)
501+
# show transformers logo!
502+
self._swanlab.config["FRAMEWORK"] = "paddlenlp"
503+
# add config parameters (run may have been created manually)
504+
self._swanlab.config.update(combined_dict)
505+
506+
def on_train_begin(self, args, state, control, model=None, **kwargs):
507+
if not self._initialized:
508+
self.setup(args, state, model, **kwargs)
509+
510+
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
511+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
512+
logger.warning(
513+
"SwanLab does not currently support the save mode functionality. "
514+
"This feature will be available in a future release."
515+
)
516+
517+
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
518+
single_value_scalars = [
519+
"train_runtime",
520+
"train_samples_per_second",
521+
"train_steps_per_second",
522+
"train_loss",
523+
"total_flos",
524+
]
525+
526+
if not self._initialized:
527+
self.setup(args, state, model)
528+
if state.is_world_process_zero:
529+
for k, v in logs.items():
530+
if k in single_value_scalars:
531+
self._swanlab.log({f"single_value/{k}": v}, step=state.global_step)
532+
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
533+
non_scalar_logs = rewrite_logs(non_scalar_logs)
534+
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step}, step=state.global_step)
535+
536+
def on_save(self, args, state, control, **kwargs):
537+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
538+
logger.warning(
539+
"SwanLab does not currently support the save mode functionality. "
540+
"This feature will be available in a future release."
541+
)
542+
543+
def on_predict(self, args, state, control, metrics, **kwargs):
544+
if not self._initialized:
545+
self.setup(args, state, **kwargs)
546+
if state.is_world_process_zero:
547+
metrics = rewrite_logs(metrics)
548+
self._swanlab.log(metrics)
549+
550+
413551
INTEGRATION_TO_CALLBACK = {
414552
"visualdl": VisualDLCallback,
415553
"autonlp": AutoNLPCallback,
416554
"wandb": WandbCallback,
417555
"tensorboard": TensorBoardCallback,
556+
"swanlab": SwanLabCallback,
418557
}
419558

420559

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ class TrainingArguments:
382382
instance of `Dataset`.
383383
report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
384384
The list of integrations to report the results and logs to.
385-
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
385+
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`/`"swanlab".
386386
`"none"` for no integrations.
387387
ddp_find_unused_parameters (`bool`, *optional*):
388388
When using distributed training, the value of the flag `find_unused_parameters` passed to

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ wget
3131
huggingface_hub>=0.19.2
3232
tiktoken
3333
tokenizers<=0.20.3; python_version<="3.8"
34-
tokenizers>=0.21,<0.22; python_version>"3.8"
34+
tokenizers>=0.21,<0.22; python_version>"3.8"
35+
swanlab

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ ml_dtypes
3030
tokenizers<=0.20.3; python_version<="3.8"
3131
tokenizers>=0.21,<0.22; python_version>"3.8"
3232
omegaconf
33+
swanlab

0 commit comments

Comments
 (0)