
Commit

Update
[ghstack-poisoned]
wconstab committed May 20, 2024
2 parents 666e0d3 + 9604920 commit abe0a2b
Showing 3 changed files with 43 additions and 14 deletions.
9 changes: 1 addition & 8 deletions torchtitan/checkpoint.py
@@ -26,13 +26,6 @@
from torchtitan.logging_utils import init_logger, logger


-DTYPE_MAP = {
-    "float16": torch.float16,
-    "float32": torch.float32,
-    "bfloat16": torch.bfloat16,
-}


class IntervalType(enum.Enum):
SECONDS = enum.auto()
STEPS = enum.auto()
@@ -141,7 +134,7 @@ def __init__(
self.pg = dist.new_group(backend="gloo")

self.model_weights_only = ckpt_config.model_weights_only
-        self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype]
+        self.export_dtype = ckpt_config.export_dtype

self.mp = None
async_mode = ckpt_config.async_mode.lower()
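With this change, ckpt_config.export_dtype is already a torch.dtype by the time it reaches CheckpointManager, because the string-to-dtype conversion now happens in the config layer. Below is a minimal sketch of the cast that export_dtype controls when model_weights_only=true; export_weights is a hypothetical helper for illustration, not torchtitan API:

import torch
import torch.nn as nn

def export_weights(model: nn.Module, export_dtype: torch.dtype) -> dict:
    # Cast every tensor in the state dict to the configured export dtype
    # (e.g. torch.bfloat16) before writing the final weights-only checkpoint.
    return {name: t.to(export_dtype) for name, t in model.state_dict().items()}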
35 changes: 34 additions & 1 deletion torchtitan/config_manager.py
@@ -9,13 +9,25 @@
from collections import defaultdict
from typing import Tuple, Union

+import torch

try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib

from torchtitan.logging_utils import logger

+DTYPE_MAP = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}


+def torch_dtype(dtype_str: str) -> torch.dtype:
+    return DTYPE_MAP[dtype_str]


def string_list(raw_arg):
return raw_arg.split(",")
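Because torch_dtype is passed to argparse as a type callable, dtype flags are converted from strings to torch.dtype objects at parse time, and argparse also runs string defaults through the callable. A standalone sketch of that behavior; the --export_dtype flag below is illustrative, not a torchtitan flag:

import argparse
import torch

DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

def torch_dtype(dtype_str: str) -> torch.dtype:
    return DTYPE_MAP[dtype_str]

parser = argparse.ArgumentParser()
parser.add_argument("--export_dtype", type=torch_dtype, default="float32")

args = parser.parse_args(["--export_dtype", "bfloat16"])
assert args.export_dtype is torch.bfloat16  # already a torch.dtype, not a string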
@@ -257,6 +269,26 @@ def __init__(self):
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
+        self.parser.add_argument(
+            "--training.mixed_precision_param",
+            type=torch_dtype,
+            default="bfloat16",
+            choices=["bfloat16", "float32"],
+            help="""
+                torch dtype to use for parameters when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+                """,
+        )
+        self.parser.add_argument(
+            "--training.mixed_precision_reduce",
+            type=torch_dtype,
+            default="float32",
+            choices=["float32"],
+            help="""
+                torch dtype to use for reductions when applying mixed precision via FSDP.
+                This feature only takes effect when data_parallel_degree > 1
+                """,
+        )
self.parser.add_argument(
"--training.compile",
action="store_true",
@@ -323,8 +355,9 @@ def __init__(self):
)
self.parser.add_argument(
"--checkpoint.export_dtype",
-            type=str,
+            type=torch_dtype,
             default="float32",
+            choices=["float16", "bfloat16", "float32"],
help="""
Converts to the specified precision when training completes and model_weights_only=true.
Currently supports float32, float16, and bfloat16.
13 changes: 8 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -231,7 +231,9 @@ def pipeline_llama_manual(
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
-            dtype=torch.bfloat16 if parallel_dims.dp_enabled else torch.float32,
+            dtype=job_config.training.mixed_precision_param
+            if parallel_dims.dp_enabled
+            else torch.float32,
device=device,
)

@@ -254,7 +256,9 @@ def pipeline_llama_manual(
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
-            dtype=torch.bfloat16 if parallel_dims.dp_enabled else torch.float32,
+            dtype=job_config.training.mixed_precision_param
+            if parallel_dims.dp_enabled
+            else torch.float32,
device=device,
)
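The two hunks above make the manually created pipeline-stage input and output buffers follow the configured parameter dtype: with FSDP mixed precision enabled, activations crossing stage boundaries arrive in mixed_precision_param rather than a hard-coded bfloat16. A small illustration of the dtype selection, with made-up shapes and stand-in values for job_config.training.* and parallel_dims:

import torch

# Illustrative stand-ins; not read from a real JobConfig.
mixed_precision_param = torch.bfloat16
dp_enabled = True
batch_size, seq_len, dim = 8, 2048, 4096

# Same selection logic as the diff: match the FSDP param/activation dtype
# when data parallelism (and thus mixed precision) is active.
stage_input = torch.empty(
    (batch_size, seq_len, dim),
    dtype=mixed_precision_param if dp_enabled else torch.float32,
    device="cpu",  # the real code uses the training device
)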

@@ -392,10 +396,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
-        # TODO: Expose `reduce_dtype` as a config option.
        mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16,
-            reduce_dtype=torch.float32,
+            param_dtype=job_config.training.mixed_precision_param,
+            reduce_dtype=job_config.training.mixed_precision_reduce,
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
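This hunk replaces the hard-coded bfloat16/float32 policy with the two new config knobs, so parameter and reduction dtypes can be chosen independently. A minimal sketch of constructing the policy from those values; the import path is the private one torchtitan used around this commit (newer PyTorch releases expose it elsewhere), and the dtype variables stand in for job_config.training.*:

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Stand-ins for job_config.training.mixed_precision_param / mixed_precision_reduce.
mixed_precision_param = torch.bfloat16   # dtype for sharded params and activations
mixed_precision_reduce = torch.float32   # dtype for gradient reduce-scatter

mp_policy = MixedPrecisionPolicy(
    param_dtype=mixed_precision_param,
    reduce_dtype=mixed_precision_reduce,
)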
