Feat flops monitor #1

Open
wants to merge 3 commits into base: dev
3 changes: 3 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -15,6 +15,7 @@
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
LayerFreezing)
from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
from llmfoundry.callbacks.wandb_loss_monitor import WAndBLossMonitor
from llmfoundry.registry import callbacks, callbacks_with_config

@@ -33,6 +34,7 @@
callbacks.register('mono_checkpoint_saver', func=MonolithicCheckpointSaver)
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('wandb_loss_monitor', func=WAndBLossMonitor)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -47,4 +49,5 @@
'HuggingFaceCheckpointer',
'AsyncEval',
'CurriculumLearning',
    'WAndBLossMonitor',
]
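
Note: for context, a minimal sketch (not part of this diff) of how the newly registered callback could be attached to a Composer Trainer directly. The `model` and `train_dataloader` names are placeholders, and the keyword values simply restate the callback's defaults.

```python
from composer import Trainer

from llmfoundry.callbacks import WAndBLossMonitor

# Hypothetical wiring; `model` and `train_dataloader` are assumed to already exist.
monitor = WAndBLossMonitor(window_size=100, alert_frequency=300)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1000ba',
    callbacks=[monitor],
)
trainer.fit()
```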
131 changes: 131 additions & 0 deletions llmfoundry/callbacks/wandb_loss_monitor.py
@@ -0,0 +1,131 @@
import logging
import os
from collections import deque
from datetime import datetime
from typing import Deque, Optional, Tuple

import numpy as np
import torch
import wandb
from composer.core import State
from composer.loggers import Logger
from composer.utils import dist
from scipy.stats import linregress

from llmfoundry.interfaces import CallbackWithConfig


log = logging.getLogger(__name__)


class WAndBLossMonitor(CallbackWithConfig):
    """Monitors the training loss and raises a W&B alert when it appears to diverge."""

    def __init__(
self,
window_size: int = 100,
frequency_threshold: float = 0.6,
magnitude_threshold: float = 0.05,
slope_threshold: float = 0.1,
alert_frequency: int = 300,
report_ok: bool = False,
) -> None:
self.alert_frequency = alert_frequency
self.report_ok = report_ok

# Ensure we can check as soon as we sample enough data
self.last_alert = -self.alert_frequency
self.checker = LossDivergenceChecker(
window_size, frequency_threshold, magnitude_threshold, slope_threshold
        )

    def batch_end(self, state: State, logger: Logger) -> None:
# Only need to run on master process
if dist.get_global_rank() != 0:
return

if not isinstance(state.loss, torch.Tensor):
raise NotImplementedError("Multiple losses not supported.")
step = state.timestamp.batch.value
loss = state.loss.item()
now = state.timestamp.total_wct.seconds

        # NOTE: synthetic perturbation that fakes a divergence after step 500,
        # presumably left in for testing; it should be removed before merging.
        div_start = 500
        div_dur = 200
        if step > div_start:
            loss *= ((step - div_start) % div_dur) / div_dur * 5.0

self.checker.sample(step, loss)

        if self._check(now):
            message, alert = self.checker.check(loss)
            if alert or (self.report_ok and message is not None):
                self._alert(message, state.run_name)
                self.last_alert = now

    def _alert(self, message: str, run_name: str) -> None:
        prefix = f"[{datetime.now()}][{run_name}][node_rank={dist.get_node_rank()}]"
        node_name = os.environ.get("NODENAME", None)
        if node_name is not None:
            prefix += f"[node={node_name}]"

        wandb.alert(title=prefix,
                    text=message,
                    level=wandb.AlertLevel.WARN,
                    wait_duration=300)

def _check(self, wallclock_time: int) -> bool:
return wallclock_time - self.last_alert >= self.alert_frequency


class LossDivergenceChecker:
    """Sliding-window heuristic that flags a steadily rising training loss."""

    def __init__(
self,
window_size: int,
frequency_threshold: float,
magnitude_threshold: float,
slope_threshold: float,
) -> None:
self.window_size = window_size
self.frequency_threshold = frequency_threshold
self.magnitude_threshold = magnitude_threshold
self.slope_threshold = slope_threshold

self.steps: Deque[int] = deque(maxlen=self.window_size)
self.losses: Deque[float] = deque(maxlen=self.window_size)

def check(self, loss: float) -> Tuple[Optional[str], bool]:
# Skip if we have not sampled enough data points
if len(self.losses) != self.losses.maxlen:
return None, False

message = "No divergence"
min_loss = min(self.losses)
loss_range = max(self.losses) - min_loss

orig_loss = self.losses[0]
        # Skip if the latest loss has not risen above the start of the window
if loss - orig_loss <= 0:
return message, False

        # Normalize each loss to [0, 1] over the window and count how many points
        # sit above both the magnitude threshold and the window's first sample.
        norm_orig_loss = (orig_loss - min_loss) / loss_range
        num_positives = 0
        for i in range(1, len(self.losses)):
            norm_loss = (self.losses[i] - min_loss) / loss_range
            if norm_loss > self.magnitude_threshold and norm_loss > norm_orig_loss:
                num_positives += 1

        # If enough of the window is elevated, fit a line to the normalized
        # (step, loss) points and report a divergence when the slope is steep.
        if num_positives >= self.window_size * self.frequency_threshold:
min_step = min(self.steps)
step_range = max(self.steps) - min_step
steps = (np.array(self.steps) - min_step) / step_range
losses = (np.array(self.losses) - min_loss) / loss_range
result = linregress(steps, losses)
if result.slope > self.slope_threshold:
message = "Divergence detected"
return message, True
return message, False

def sample(self, step: int, loss: float) -> None:
self.steps.append(step)
self.losses.append(loss)
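
Note: as a quick sanity check of the heuristic above, a small self-contained sketch (not part of this diff) that drives `LossDivergenceChecker` with synthetic losses. It assumes the module is importable as added in this PR; a flat window keeps the checker quiet, while a steadily rising window should trip both the frequency and slope thresholds.

```python
from llmfoundry.callbacks.wandb_loss_monitor import LossDivergenceChecker

checker = LossDivergenceChecker(window_size=100,
                                frequency_threshold=0.6,
                                magnitude_threshold=0.05,
                                slope_threshold=0.1)

# Fill the window with a flat loss: no divergence should be reported.
for step in range(100):
    checker.sample(step, 2.0)
print(checker.check(2.0))  # expected: ('No divergence', False)

# Replace the window with a steadily rising loss: most points end up above
# the (normalized) first sample and the fitted slope is clearly positive.
loss = 2.0
for step in range(100, 200):
    loss += 0.05
    checker.sample(step, loss)
print(checker.check(loss))  # expected: ('Divergence detected', True)
```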
22 changes: 20 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -88,7 +88,7 @@ def __init__(self, om_model_config: DictConfig,
# Resolve "mixed" init device to either "cpu" or "meta"
resolved_init_device = hf_get_init_device(init_device)
requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'

allow_embedding_resizing = om_model_config.get('allow_embedding_resizing', False)
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
@@ -256,8 +256,9 @@ def _autoset_attn_implementation_monkeypatch(
eval_metrics=eval_metrics,
init_device=init_device,
peft_config=peft_config,
            allow_embedding_resizing=allow_embedding_resizing,
)

        self.n_active_params = sum(p.numel() for p in self.parameters())

    @staticmethod
def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
if peft_installed:
@@ -277,3 +278,20 @@ def _get_peft_config(peft_config_dict: Dict[str, Any]) -> 'PeftConfig':
raise ValueError(
'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.'
)

def flops_per_batch(self, batch: Mapping) -> int:
# Note: this computation does not take into account padding, and assumes
# that the dataset has been constructed without padding. Additionally, we
# assume the backward pass is approximately 2x the forward pass

bs, msl = batch['input_ids'].shape[0:2]
params = self.n_active_params
if not self.model.config.tie_word_embeddings:
            # Embedding layers are lookup tables and therefore are not counted in the FLOP computation
params -= self.model.lm_head.weight.numel()
params_flops_per_token = 2 * params
params_flops_per_seq = params_flops_per_token * msl
attn_flops_per_seq = (self.model.config.num_hidden_layers * 2 * 2 *
(self.model.config.hidden_size * (msl**2)))

return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs
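
Note: a minimal standalone sketch (not part of this diff) that reproduces the estimate above, useful for eyeballing the numbers the callback will report. The function name and the example configuration values are made up for illustration.

```python
def estimate_flops_per_batch(n_active_params: int, n_layers: int,
                             hidden_size: int, batch_size: int,
                             seq_len: int) -> int:
    # Weight FLOPs: ~2 FLOPs per active parameter per token (multiply + add).
    params_flops_per_seq = 2 * n_active_params * seq_len
    # Attention score FLOPs: two s x s matmuls (QK^T and attn @ V) per layer,
    # each costing roughly 2 * hidden_size * seq_len**2 FLOPs.
    attn_flops_per_seq = n_layers * 2 * 2 * hidden_size * seq_len**2
    # Forward plus ~2x backward => factor of 3.
    return (params_flops_per_seq + attn_flops_per_seq) * 3 * batch_size

# e.g. a hypothetical 7B-parameter model, 32 layers, hidden size 4096,
# batch size 8, sequence length 4096:
print(f"{estimate_flops_per_batch(7_000_000_000, 32, 4096, 8, 4096):.3e}")
```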
4 changes: 3 additions & 1 deletion llmfoundry/models/hf/model_wrapper.py
@@ -36,7 +36,8 @@ def __init__(self,
eval_metrics: Optional[List[Metric]] = None,
shift_labels: bool = False,
init_device: Optional[str] = None,
peft_config: Optional['PeftConfig'] = None):
peft_config: Optional['PeftConfig'] = None,
allow_embedding_resizing: bool = False):
super().__init__(
model,
tokenizer,
Expand All @@ -46,6 +47,7 @@ def __init__(self,
shift_labels=shift_labels,
peft_config=peft_config,
should_save_peft_only=True,
            allow_embedding_resizing=allow_embedding_resizing,
)

# Note: We need to add the FSDP related attributes to the model AFTER the super init,
9 changes: 7 additions & 2 deletions setup.py
@@ -6,6 +6,7 @@
import copy
import os
import re
import warnings

import setuptools
from setuptools import setup
@@ -51,10 +52,10 @@
]

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22',
# 'mosaicml[libcloud,wandb,oci,gcs]>=0.21.1,<0.22',
'mlflow>=2.10,<3',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.38.2,<4.39',
'transformers>=4.39.0',
'mosaicml-streaming>=0.7.4,<0.8',
'torch>=2.2.1,<2.3',
'datasets>=2.16,<2.17',
@@ -72,6 +73,7 @@
'tenacity>=8.2.3,<9',
'catalogue>=2,<3',
'typer[all]<1',
    'loguru',
]

extra_deps = {}
@@ -147,3 +149,6 @@
'console_scripts': ['llmfoundry = llmfoundry.cli.cli:app'],
},
)

warnings.warn(
    "The required package 'composer' has been removed in this fork. "
    "Please install it separately.", UserWarning)