diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index aaba90eeec..3e2a21617e 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -15,6 +15,7 @@
     MonolithicCheckpointSaver
 from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
                                                        LayerFreezing)
+from llmfoundry.callbacks.wandb_loss_monitor import WAndBLossMonitor
 from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
 from llmfoundry.registry import callbacks, callbacks_with_config
 
@@ -33,6 +34,7 @@
 callbacks.register('mono_checkpoint_saver', func=MonolithicCheckpointSaver)
 callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
 callbacks.register('oom_observer', func=OOMObserver)
+callbacks.register('wandb_loss_monitor', func=WAndBLossMonitor)
 
 callbacks_with_config.register('async_eval', func=AsyncEval)
 callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -47,4 +49,5 @@
     'HuggingFaceCheckpointer',
     'AsyncEval',
     'CurriculumLearning',
+    'WAndBLossMonitor',
 ]
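Note (not part of the patch): a minimal sketch of how the newly registered callback could be looked up and constructed. It assumes the catalogue-backed registry exposes a get() accessor and that the instance is ultimately handed to the Composer Trainer's callbacks list; in practice the callback would normally be enabled through the callbacks section of a training config.

    from llmfoundry.callbacks import WAndBLossMonitor
    from llmfoundry.registry import callbacks as callback_registry

    # The name used in callbacks.register(...) above resolves back to the class.
    monitor_cls = callback_registry.get('wandb_loss_monitor')
    assert monitor_cls is WAndBLossMonitor

    # Constructor arguments mirror the defaults in the new module; in a real run
    # they would come from the training config rather than being hard-coded here.
    loss_monitor = monitor_cls(
        window_size=100,
        frequency_threshold=0.6,
        magnitude_threshold=0.05,
        slope_threshold=0.1,
        alert_frequency=300,
        report_ok=False,
    )
    # loss_monitor would then be passed as composer.Trainer(..., callbacks=[loss_monitor]).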
diff --git a/llmfoundry/callbacks/wandb_loss_monitor.py b/llmfoundry/callbacks/wandb_loss_monitor.py
new file mode 100644
index 0000000000..4eea1b9aba
--- /dev/null
+++ b/llmfoundry/callbacks/wandb_loss_monitor.py
@@ -0,0 +1,141 @@
+import logging
+import os
+from collections import deque
+from datetime import datetime
+from typing import Deque, Optional, Tuple
+
+import numpy as np
+import torch
+import wandb
+from composer.core import State
+from composer.loggers import Logger
+from composer.utils import dist
+from scipy.stats import linregress
+
+from llmfoundry.interfaces import CallbackWithConfig
+
+log = logging.getLogger(__name__)
+
+
+class WAndBLossMonitor(CallbackWithConfig):
+
+    def __init__(
+        self,
+        window_size: int = 100,
+        frequency_threshold: float = 0.6,
+        magnitude_threshold: float = 0.05,
+        slope_threshold: float = 0.1,
+        alert_frequency: int = 300,
+        report_ok: bool = False,
+    ) -> None:
+        self.alert_frequency = alert_frequency
+        self.report_ok = report_ok
+
+        # Ensure we can check as soon as we have sampled enough data.
+        self.last_alert = -self.alert_frequency
+        self.checker = LossDivergenceChecker(window_size, frequency_threshold,
+                                             magnitude_threshold,
+                                             slope_threshold)
+
+    def batch_end(self, state: State, logger: Logger) -> None:
+        # Only needs to run on the rank-0 process.
+        if dist.get_global_rank() != 0:
+            return
+
+        if not isinstance(state.loss, torch.Tensor):
+            raise NotImplementedError('Multiple losses not supported.')
+
+        step = state.timestamp.batch.value
+        loss = state.loss.item()
+        # Total wallclock time in seconds since the start of training.
+        now = int(state.timestamp.total_wct.total_seconds())
+
+        # NOTE: synthetic divergence injection used to exercise the detector;
+        # it scales the loss after `div_start` steps and should be removed for
+        # real training runs.
+        div_start = 500
+        div_dur = 200
+        if step > div_start:
+            loss *= ((step - div_start) % div_dur) / div_dur * 5.0
+
+        self.checker.sample(step, loss)
+
+        if self._check(now):
+            message, alert = self.checker.check(loss)
+            if alert or (self.report_ok and message is not None):
+                self._alert(message, state.run_name)
+                self.last_alert = now
+
+            return alert
+
+    def _alert(self, message: str, run_name: str) -> None:
+        prefix = f'[{datetime.now()}][{run_name}][node_rank={dist.get_node_rank()}]'
+        node_name = os.environ.get('NODENAME', None)
+        if node_name is not None:
+            prefix += f'[node={node_name}]'
+
+        wandb.alert(title=prefix,
+                    text=message,
+                    level=wandb.AlertLevel.WARN,
+                    wait_duration=300)
+
+    def _check(self, wallclock_time: int) -> bool:
+        return wallclock_time - self.last_alert >= self.alert_frequency
+
+
+class LossDivergenceChecker:
+
+    def __init__(
+        self,
+        window_size: int,
+        frequency_threshold: float,
+        magnitude_threshold: float,
+        slope_threshold: float,
+    ) -> None:
+        self.window_size = window_size
+        self.frequency_threshold = frequency_threshold
+        self.magnitude_threshold = magnitude_threshold
+        self.slope_threshold = slope_threshold
+
+        self.steps: Deque[int] = deque(maxlen=self.window_size)
+        self.losses: Deque[float] = deque(maxlen=self.window_size)
+
+    def check(self, loss: float) -> Tuple[Optional[str], bool]:
+        # Skip if we have not sampled enough data points yet.
+        if len(self.losses) != self.losses.maxlen:
+            return None, False
+
+        message = 'No divergence'
+        min_loss = min(self.losses)
+        loss_range = max(self.losses) - min_loss
+
+        orig_loss = self.losses[0]
+        # Skip if the loss is converging at the end of the window.
+        if loss - orig_loss <= 0:
+            return message, False
+
+        # Count the window entries whose normalized loss exceeds both the
+        # magnitude threshold and the normalized loss at the window start.
+        norm_orig_loss = (orig_loss - min_loss) / loss_range
+        num_positives = 0
+        for i in range(1, len(self.losses)):
+            norm_loss = (self.losses[i] - min_loss) / loss_range
+            if norm_loss > self.magnitude_threshold and norm_loss > norm_orig_loss:
+                num_positives += 1
+
+        if num_positives >= self.window_size * self.frequency_threshold:
+            # Fit a line to the normalized (step, loss) window; a steep
+            # positive slope is treated as divergence.
+            min_step = min(self.steps)
+            step_range = max(self.steps) - min_step
+            steps = (np.array(self.steps) - min_step) / step_range
+            losses = (np.array(self.losses) - min_loss) / loss_range
+            result = linregress(steps, losses)
+            if result.slope > self.slope_threshold:
+                message = 'Divergence detected'
+                return message, True
+        return message, False
+
+    def sample(self, step: int, loss: float) -> None:
+        self.steps.append(step)
+        self.losses.append(loss)
\ No newline at end of file
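As a quick, self-contained illustration of the detection logic in LossDivergenceChecker (not part of the patch): once the window is full, a steadily increasing loss curve trips the detector, while a decreasing curve returns early as converging and never alerts.

    from llmfoundry.callbacks.wandb_loss_monitor import LossDivergenceChecker

    checker = LossDivergenceChecker(window_size=10,
                                    frequency_threshold=0.6,
                                    magnitude_threshold=0.05,
                                    slope_threshold=0.1)

    # Fill the window with a linearly increasing loss curve.
    for step in range(10):
        checker.sample(step, 1.0 + 0.1 * step)
    print(checker.check(1.9))   # ('Divergence detected', True)

    # A decreasing loss curve is treated as converging.
    checker = LossDivergenceChecker(10, 0.6, 0.05, 0.1)
    for step in range(10):
        checker.sample(step, 2.0 - 0.1 * step)
    print(checker.check(1.1))   # ('No divergence', False)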
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 38ed7a7e70..2e6f903447 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -88,7 +88,7 @@ def __init__(self, om_model_config: DictConfig,
         # Resolve "mixed" init device to either "cpu" or "meta"
         resolved_init_device = hf_get_init_device(init_device)
         requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
-
+        allow_embedding_resizing = om_model_config.get('allow_embedding_resizing', False)
         if use_flash_attention_2 and not is_flash_v2_installed():
             raise ValueError(
                 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
@@ -256,8 +256,9 @@ def _autoset_attn_implementation_monkeypatch(
             eval_metrics=eval_metrics,
             init_device=init_device,
             peft_config=peft_config,
+            allow_embedding_resizing=allow_embedding_resizing,
         )
-
+        self.n_active_params = sum(p.numel() for p in self.parameters())
     @staticmethod
     def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
         if peft_installed:
@@ -277,3 +278,20 @@ def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
             raise ValueError(
                 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.'
             )
+
+    def flops_per_batch(self, batch: Mapping) -> int:
+        # Note: this computation does not take padding into account, and assumes
+        # that the dataset has been constructed without padding. Additionally, we
+        # assume the backward pass is approximately 2x the forward pass.
+
+        bs, msl = batch['input_ids'].shape[0:2]
+        params = self.n_active_params
+        if not self.model.config.tie_word_embeddings:
+            # Embedding layers are lookup tables and are therefore not counted in the FLOP computation.
+            params -= self.model.lm_head.weight.numel()
+        params_flops_per_token = 2 * params
+        params_flops_per_seq = params_flops_per_token * msl
+        attn_flops_per_seq = (self.model.config.num_hidden_layers * 2 * 2 *
+                              (self.model.config.hidden_size * (msl**2)))
+
+        return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index 2ba88d390c..ced1a4efb6 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -36,7 +36,8 @@ def __init__(self,
                  eval_metrics: Optional[List[Metric]] = None,
                  shift_labels: bool = False,
                  init_device: Optional[str] = None,
-                 peft_config: Optional['PeftConfig'] = None):
+                 peft_config: Optional['PeftConfig'] = None,
+                 allow_embedding_resizing: bool = False):
         super().__init__(
             model,
             tokenizer,
@@ -46,6 +47,7 @@ def __init__(self,
             shift_labels=shift_labels,
             peft_config=peft_config,
             should_save_peft_only=True,
+            allow_embedding_resizing=allow_embedding_resizing,
         )
 
         # Note: We need to add the FSDP related attributes to the model AFTER the super init,
diff --git a/setup.py b/setup.py
index 22b7cb17ca..3e06efdff2 100644
--- a/setup.py
+++ b/setup.py
@@ -6,6 +6,7 @@
 import copy
 import os
 import re
+import warnings
 
 import setuptools
 from setuptools import setup
@@ -51,10 +52,10 @@
 ]
 
 install_requires = [
-    'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22',
+    # 'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22',
     'mlflow>=2.10,<3',
     'accelerate>=0.25,<0.26',  # for HF inference `device_map`
-    'transformers>=4.38.2,<4.39',
+    'transformers>=4.39.0',
     'mosaicml-streaming>=0.7.4,<0.8',
     'torch>=2.2.1,<2.3',
     'datasets>=2.16,<2.17',
@@ -72,6 +73,7 @@
     'tenacity>=8.2.3,<9',
     'catalogue>=2,<3',
     'typer[all]<1',
+    'loguru',
 ]
 
 extra_deps = {}
@@ -147,3 +149,6 @@
         'console_scripts': ['llmfoundry = llmfoundry.cli.cli:app'],
     },
 )
+
+warnings.warn("The required package 'composer' has been removed in this fork. "
+              'Please install it separately.', UserWarning)
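For reference, a worked example of the flops_per_batch arithmetic added to hf_causal_lm.py above. All dimensions below are assumed for illustration (roughly a 7B model with an untied 32k-vocabulary LM head, 32 layers, hidden size 4096, batch size 4, sequence length 2048) and do not come from any real config.

    # Mirrors flops_per_batch: parameter FLOPs scale with parameter count and
    # sequence length, attention FLOPs scale with the square of the sequence
    # length, and the backward pass is approximated as 2x the forward (hence * 3).
    bs, msl = 4, 2048
    n_active_params = 7_000_000_000
    lm_head_params = 32_000 * 4096          # vocab_size * hidden_size (untied head)
    params = n_active_params - lm_head_params

    params_flops_per_token = 2 * params
    params_flops_per_seq = params_flops_per_token * msl
    attn_flops_per_seq = 32 * 2 * 2 * (4096 * msl**2)

    total = (params_flops_per_seq + attn_flops_per_seq) * 3 * bs
    print(f'{total:.3e} FLOPs per batch')   # ~3.6e+14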