Skip to content

Commit 50b073c

Browse files
committed
Merge branch 'ko3n1g/docs/fixes' into 'main'
docs: Fixes to allow building docs again See merge request ADLR/megatron-lm!1962
2 parents 34e607e + 46736de commit 50b073c

File tree

13 files changed

+190
-100
lines changed

13 files changed

+190
-100
lines changed

.gitlab/stages/00.pre.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ label_merge_request:
3838
source labels
3939
curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" --data-urlencode "add_labels=$LABELS" -X PUT
4040
41+
clean_docker_node:
42+
stage: .pre
43+
image: docker:26.1.4-dind
44+
tags: [mcore-docker-node]
45+
script:
46+
- docker system prune -a --filter "until=48h" -f
47+
4148
check_milestone:
4249
rules:
4350
- if: $CI_PIPELINE_SOURCE == "merge_request_event"

.gitlab/stages/01.tests.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,15 @@ unit_tests:
104104
- coverage
105105

106106
docs_build_test:
107-
image: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/python-format:0.0.1
107+
image: ${CI_MCORE_IMAGE}:${CI_PIPELINE_ID}
108108
tags: [mcore-docker-node-small]
109+
needs: [build_image]
109110
script:
110111
- cd ..
111112
- rm -rf documentation && git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/nemo-megatron-core-tme/documentation.git
112113
- mv megatron-lm/ documentation/
113114
- cd documentation/
114115
- ./repo docs
115-
allow_failure: true
116-
except:
117-
- main
118116

119117
formatting:
120118
extends: [.tests_common]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ The figures below illustrate the grad buffer's sharding scheme, and the key step
2828

2929
## Data flow
3030

31-
![Data flow](images/distrib_optimizer/data_flow.png)
31+
![Data flow](../images/distrib_optimizer/data_flow.png)
3232

3333
## Sharding scheme
3434

35-
![Sharding scheme](images/distrib_optimizer/sharding_scheme.png)
35+
![Sharding scheme](../images/distrib_optimizer/sharding_scheme.png)
3636

3737
## Key steps
3838

docs/source/api-guide/fusions.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fusions.fused\_cross\_entropy\_loss module
5858

5959
This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches communication calls.
6060

61-
.. automodule:: core.fusions.fused_softmax
61+
.. automodule:: core.fusions.fused_cross_entropy
6262
:members:
6363
:undoc-members:
6464
:show-inheritance:

docs/source/api-guide/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ API Guide
1212
transformer
1313
moe
1414
dist_checkpointing
15+
dist_optimizer
1516
distributed
1617
datasets
1718
num_microbatches_calculator

docs/source/api-guide/num_microbatches_calculator.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Microbatches Calculator
2-
==============
2+
=======================
33
This api is used to calculate the number of microbatches required to fit a given model on a given batch size.
44

55

megatron/core/dist_checkpointing/strategies/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@
22

33
""" Various loading and saving strategies """
44

5-
from .common import _import_trigger
5+
# We mock imports to populate the `default_strategies` objects.
6+
# Since they are defined in base but populated in common, we have to mock
7+
# import both modules.
8+
from megatron.core.dist_checkpointing.strategies.base import _import_trigger
9+
from megatron.core.dist_checkpointing.strategies.common import _import_trigger

megatron/core/dist_checkpointing/strategies/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import defaultdict
77
from enum import Enum
88
from pathlib import Path
9+
from typing import Any, DefaultDict
910

1011
from ..mapping import CheckpointingException, ShardedStateDict, StateDict
1112
from .async_utils import AsyncCallsQueue, AsyncRequest
@@ -18,7 +19,8 @@ class StrategyAction(Enum):
1819
SAVE_SHARDED = 'save_sharded'
1920

2021

21-
default_strategies = defaultdict(dict)
22+
_import_trigger = None
23+
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
2224

2325
async_calls = AsyncCallsQueue()
2426

@@ -35,7 +37,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
3537
from .torch import _import_trigger
3638
except ImportError as e:
3739
raise CheckpointingException(
38-
f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}'
40+
f'Cannot import a default strategy for: {(action.value, backend, version)}. '
41+
f'Error: {e}. Hint: {error_hint}'
3942
) from e
4043
try:
4144
return default_strategies[action.value][(backend, version)]
@@ -46,7 +49,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
4649

4750

4851
class LoadStrategyBase(ABC):
49-
"""Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version."""
52+
"""Base class for a load strategy. Requires implementing checks for compatibility with a
53+
given checkpoint version."""
5054

5155
@abstractmethod
5256
def check_backend_compatibility(self, loaded_version):
@@ -63,7 +67,8 @@ def can_handle_sharded_objects(self):
6367

6468

6569
class SaveStrategyBase(ABC):
66-
"""Base class for a save strategy. Requires defining a backend type and version of the saved format."""
70+
"""Base class for a save strategy. Requires defining a backend type and
71+
version of the saved format."""
6772

6873
def __init__(self, backend: str, version: int):
6974
self.backend = backend

megatron/core/dist_checkpointing/strategies/common.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import logging
66
import os
7-
from itertools import product
87
from pathlib import Path
98

109
import torch
@@ -68,10 +67,12 @@ def load_common(self, checkpoint_dir: Path):
6867
def load_sharded_objects(
6968
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
7069
):
71-
"""Replaces all ShardedObject from a given state dict with values loaded from the checkpoint.
70+
"""Replaces all ShardedObject from a given state dict with values loaded from the
71+
checkpoint.
7272
7373
Args:
74-
sharded_objects_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
74+
sharded_objects_state_dict (ShardedStateDict):
75+
sharded state dict defining what objects should be loaded.
7576
checkpoint_dir (Path): checkpoint directory
7677
7778
Returns:
@@ -99,7 +100,8 @@ def load_sharded_object(sh_obj: ShardedObject):
99100
else:
100101
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
101102
logger.debug(
102-
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}'
103+
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint'
104+
f' directory content: {ckpt_files}'
103105
)
104106
raise CheckpointingException(err_msg) from e
105107
return loaded_obj
@@ -119,7 +121,8 @@ def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
119121
full_key = f'{subdir.name}/{shard_file.stem}'
120122
sh_objs.append(ShardedObject.empty_from_unique_key(full_key))
121123

122-
# This is a backward-compatibility fix, where the last global shape is missing in the name
124+
# This is a backward-compatibility fix, where the last global shape is missing in the
125+
# name
123126
if sh_objs[0].global_shape[-1] < 0:
124127
max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs))
125128
for sh_obj in sh_objs:

megatron/core/fusions/fused_bias_gelu.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from megatron.core.jit import jit_fuser
66

7-
###### BIAS GELU FUSION/ NO AUTOGRAD ################
7+
# BIAS GELU FUSION/ NO AUTOGRAD ################
88
# 1/sqrt(2*pi)-> 0.3989423
99
# 1/sqrt(2) -> 0.70710678
1010
# sqrt(2/pi) -> 0.79788456
@@ -46,5 +46,10 @@ def backward(ctx, grad_output):
4646
tmp = bias_gelu_back(grad_output, bias, input)
4747
return tmp, tmp
4848

49+
# This is required to make Sphinx happy :-(
50+
@classmethod
51+
def apply(cls, *args, **kwargs):
52+
super().apply(*args, **kwargs)
53+
4954

5055
bias_gelu_impl = GeLUFunction.apply

0 commit comments

Comments
 (0)