forked from mosaicml/llm-foundry
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
53 lines (48 loc) · 2.44 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
MemoryMonitor, MemorySnapshot, OOMObserver,
OptimizerMonitor, RuntimeEstimator,
SpeedMonitor)
from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
LayerFreezing)
from llmfoundry.callbacks.wandb_loss_monitor import WAndBLossMonitor
from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
from llmfoundry.registry import callbacks, callbacks_with_config
# Register every callback class under its string key in the ``callbacks``
# registry. The pairs are listed in the order they are registered.
for _key, _callback_cls in (
    ('lr_monitor', LRMonitor),
    ('memory_monitor', MemoryMonitor),
    ('memory_snapshot', MemorySnapshot),
    ('speed_monitor', SpeedMonitor),
    ('runtime_estimator', RuntimeEstimator),
    ('optimizer_monitor', OptimizerMonitor),
    ('generate_callback', Generate),
    ('early_stopper', EarlyStopper),
    ('fdiff_metrics', FDiffMetrics),
    ('hf_checkpointer', HuggingFaceCheckpointer),
    ('global_lr_scaling', GlobalLRScaling),
    ('layer_freezing', LayerFreezing),
    ('mono_checkpoint_saver', MonolithicCheckpointSaver),
    ('scheduled_gc', ScheduledGarbageCollector),
    ('oom_observer', OOMObserver),
    ('wandb_loss_monitor', WAndBLossMonitor),
):
    callbacks.register(_key, func=_callback_cls)

# These two go into the separate ``callbacks_with_config`` registry.
for _key, _callback_cls in (
    ('async_eval', AsyncEval),
    ('curriculum_learning', CurriculumLearning),
):
    callbacks_with_config.register(_key, func=_callback_cls)

# Drop the loop temporaries so they do not linger as module attributes.
del _key, _callback_cls
# Names exported by ``from <this package> import *``. Only classes imported
# from ``llmfoundry.callbacks.*`` submodules are listed; the callbacks
# imported from ``composer.callbacks`` are registered by this module but are
# not re-exported here.
__all__ = [
    'FDiffMetrics', 'MonolithicCheckpointSaver',
    'GlobalLRScaling', 'LayerFreezing',
    'ScheduledGarbageCollector', 'EvalGauntlet',
    'HuggingFaceCheckpointer', 'AsyncEval',
    'CurriculumLearning', 'WAndBLossMonitor',
]