
Commit 9604920

Update (base update)
[ghstack-poisoned]
1 parent 513dd94 commit 9604920

3 files changed: 43 additions & 14 deletions

torchtitan/checkpoint.py

Lines changed: 1 addition & 8 deletions

@@ -26,13 +26,6 @@
 from torchtitan.logging_utils import init_logger, logger


-DTYPE_MAP = {
-    "float16": torch.float16,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-}
-
-
 class IntervalType(enum.Enum):
     SECONDS = enum.auto()
     STEPS = enum.auto()

@@ -141,7 +134,7 @@ def __init__(
         self.pg = dist.new_group(backend="gloo")

         self.model_weights_only = ckpt_config.model_weights_only
-        self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]
+        self.export_dtype = ckpt_config.export_dtype

         self.mp = None
         async_mode = ckpt_config.async_mode.lower()
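
With the DTYPE_MAP lookup moved into config parsing, ckpt_config.export_dtype now arrives here as a torch.dtype, so the checkpoint manager can use it directly when exporting weights. A minimal sketch of that downstream usage pattern, assuming a helper of roughly this shape (illustrative only, not torchtitan code):

import torch


def cast_state_dict_for_export(state_dict: dict, export_dtype: torch.dtype) -> dict:
    # Cast floating-point tensors to the export dtype; leave everything else untouched.
    return {
        k: v.to(export_dtype) if torch.is_tensor(v) and v.is_floating_point() else v
        for k, v in state_dict.items()
    }


# Example: shrink weights to bfloat16 before writing a model-weights-only checkpoint.
sd = {"w": torch.randn(4, 4), "step": torch.tensor(100)}
sd_bf16 = cast_state_dict_for_export(sd, torch.bfloat16)
assert sd_bf16["w"].dtype == torch.bfloat16
assert sd_bf16["step"].dtype == torch.int64  # integer tensors are not cast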

torchtitan/config_manager.py

Lines changed: 34 additions & 1 deletion

@@ -9,13 +9,25 @@
 from collections import defaultdict
 from typing import Tuple, Union

+import torch
+
 try:
     import tomllib
 except ModuleNotFoundError:
     import tomli as tomllib

 from torchtitan.logging_utils import logger

+DTYPE_MAP = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+def torch_dtype(dtype_str: str) -> torch.dtype:
+    return DTYPE_MAP[dtype_str]
+

 def string_list(raw_arg):
     return raw_arg.split(",")

@@ -257,6 +269,26 @@ def __init__(self):
                 split via the provided split points, unflattened into an nn.Module,
                 and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
         )
+        self.parser.add_argument(
+            "--training.mixed_precision_param",
+            type=torch_dtype,
+            default="bfloat16",
+            choices=["bfloat16", "float32"],
+            help="""
+                torch dtype to use for parameters when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+            """,
+        )
+        self.parser.add_argument(
+            "--training.mixed_precision_reduce",
+            type=torch_dtype,
+            default="float32",
+            choices=["float32"],
+            help="""
+                torch dtype to use for reductions when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+            """,
+        )
         self.parser.add_argument(
             "--training.compile",
             action="store_true",

@@ -323,8 +355,9 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--checkpoint.export_dtype",
-            type=str,
+            type=torch_dtype,
             default="float32",
+            choices=["float16", "bfloat16", "float32"],
             help="""
                 Converts to the specified precision when training completes and model_weights_only=true.
                 Currently supports float32, float16, and bfloat16.
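
The new torch_dtype helper plugs into argparse as a type converter, so dtype options reach JobConfig as real torch.dtype objects rather than strings. A small self-contained sketch of the pattern (the --export_dtype flag name below is only for illustration):

import argparse

import torch

# Same mapping as in the diff above: command-line spellings to torch dtypes.
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def torch_dtype(dtype_str: str) -> torch.dtype:
    # argparse calls this on the raw string and stores the converted value.
    return DTYPE_MAP[dtype_str]


parser = argparse.ArgumentParser()
# String defaults are also run through the type converter when the flag is omitted.
parser.add_argument("--export_dtype", type=torch_dtype, default="float32")

args = parser.parse_args(["--export_dtype", "bfloat16"])
assert args.export_dtype is torch.bfloat16  # already a dtype, no later lookup needed

args_default = parser.parse_args([])
assert args_default.export_dtype is torch.float32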

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 8 additions & 5 deletions

@@ -225,7 +225,9 @@ def pipeline_llama_manual(
                 int(job_config.training.seq_len // parallel_dims.tp),
                 model_config.dim,
             ),
-            dtype=torch.bfloat16 if parallel_dims.dp_enabled else torch.float32,
+            dtype=job_config.training.mixed_precision_param
+            if parallel_dims.dp_enabled
+            else torch.float32,
             device=device,
         )

@@ -248,7 +250,9 @@ def pipeline_llama_manual(
                 int(job_config.training.seq_len // parallel_dims.tp),
                 model_config.dim,
             ),
-            dtype=torch.bfloat16 if parallel_dims.dp_enabled else torch.float32,
+            dtype=job_config.training.mixed_precision_param
+            if parallel_dims.dp_enabled
+            else torch.float32,
             device=device,
         )

@@ -386,10 +390,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.float32,
+            param_dtype=job_config.training.mixed_precision_param,
+            reduce_dtype=job_config.training.mixed_precision_param,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
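
MixedPrecisionPolicy here is the policy object of the per-parameter FSDP (fully_shard) API; with the parser now converting the two new training options to torch.dtype, they can feed the policy directly. A minimal sketch, under the assumption that the prototype lives at torch.distributed._composable.fsdp as in PyTorch nightlies of this period:

import torch

# Assumed import path for the FSDP2 / fully_shard prototype of this era.
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Stand-ins for the parsed job_config.training.mixed_precision_param and
# job_config.training.mixed_precision_reduce values, which are torch.dtype objects.
param_dtype = torch.bfloat16
reduce_dtype = torch.float32

# Parameters are cast to param_dtype for compute; gradient reductions
# (reduce-scatter / all-reduce) run in reduce_dtype.
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)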
