[Not for land] Settings to make Llama3-8B on 8 GPUs faster #615

Draft · wants to merge 4 commits into base: gh/awgu/17/base
torchtitan/models/llama/model.py (7 additions, 9 deletions)
@@ -198,16 +198,14 @@ def forward(

         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use casual mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=True, enable_gqa=self.n_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
@@ -373,7 +371,7 @@ def __init__(self, model_args: ModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
 
         self.norm = build_norm(
-            model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
+            "fused_rmsnorm", dim=model_args.dim, eps=model_args.norm_eps
         )
 
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)

@@ -438,7 +436,7 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h) if self.output else h
         return output
 
     @classmethod
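The attention change above removes the explicit repeat_kv expansion and lets SDPA handle grouped-query attention itself; the other two hunks hard-code the fused RMSNorm build and drop the fp32 upcast from the output projection (the upcast moves into the loss in train.py below). A minimal standalone sketch of the new SDPA call, assuming a PyTorch build that exposes the enable_gqa flag; the head counts are illustrative Llama3-8B-style values rather than anything read from this diff:

import torch
import torch.nn.functional as F

bs, seqlen, head_dim = 2, 128, 128
n_heads, n_kv_heads = 32, 8  # grouped-query attention, n_rep = 4

xq = torch.randn(bs, n_heads, seqlen, head_dim)     # (bs, n_heads, seqlen, head_dim)
xk = torch.randn(bs, n_kv_heads, seqlen, head_dim)  # (bs, n_kv_heads, seqlen, head_dim)
xv = torch.randn(bs, n_kv_heads, seqlen, head_dim)

# enable_gqa=True expands the k/v heads inside the kernel, so the repeat_kv
# copies (and their extra memory traffic) are never materialized.
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 128, 128])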
torchtitan/parallelisms/parallelize_llama.py (5 additions, 8 deletions)
@@ -313,19 +313,16 @@ def apply_fsdp(
         check_strided_sharding_enabled()
 
     for layer_id, transformer_block in model.layers.items():
-        if pp_enabled:
-            # For PP, do not reshard after forward to avoid per-microbatch
-            # all-gathers, which can be expensive and non-overlapped
-            reshard_after_forward = False
-        else:
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+        reshard_after_forward = False
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
+    fully_shard(model.tok_embeddings, **fsdp_config)
+    # Embedding weight is not needed for embedding backward
+    model.tok_embeddings.set_unshard_in_backward(False)
+    fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
     fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


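A condensed sketch of the wrapping that results from this hunk and the trade-off it makes, assuming FSDP2's fully_shard (with the list-of-modules and set_unshard_in_backward APIs used in the diff), a distributed launch such as torchrun with NCCL, and pipeline parallelism disabled:

from torch.distributed._composable.fsdp import fully_shard

def apply_fsdp_fast(model, **fsdp_config):
    for transformer_block in model.layers.values():
        # Keep every block's parameters gathered after forward: backward needs
        # no per-block all-gathers, at the cost of holding unsharded weights longer.
        fully_shard(transformer_block, **fsdp_config, reshard_after_forward=False)

    fully_shard(model.tok_embeddings, **fsdp_config)
    # The embedding backward only scatters into its weight gradient, so the
    # unsharded weight is not needed again; skip re-gathering it in backward.
    model.tok_embeddings.set_unshard_in_backward(False)

    # Group the final norm and output projection into one FSDP unit and keep it
    # unsharded after forward, since backward uses it immediately anyway.
    fully_shard([model.output, model.norm], **fsdp_config, reshard_after_forward=False)
    fully_shard(model, **fsdp_config)
    return model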
train.py (10 additions, 6 deletions)
@@ -10,6 +10,7 @@
 from datetime import timedelta
 
 import torch
+import torch._inductor.config as inductor_config
 from torch.distributed.elastic.multiprocessing.errors import record
 
 from torchtitan import utils
@@ -142,11 +143,13 @@ def main(job_config: JobConfig):
         f"{color.blue}Model {model_name} {job_config.model.flavor} "
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
+    if job_config.training.compile:
+        inductor_config.coordinate_descent_tuning = True
 
     # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
+            pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )
 
     # apply parallelisms and initialization
@@ -271,6 +274,8 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(time.perf_counter() - data_load_start)
 
+            model.tok_embeddings.unshard(async_op=True)
+
             input_ids = input_ids.cuda()
             labels = labels.cuda()
             optimizers.zero_grad()
@@ -297,11 +302,10 @@ def loss_fn(pred, labels):
             else:
                 # Non-PP forward / backward
                 with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
+                    if job_config.training.compile:
+                        loss = torch.compile(loss_fn)(model(input_ids), labels)
+                    else:
+                        loss = loss_fn(model(input_ids), labels)
                     loss.backward()
 
             # clip gradients
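The train.py changes pair with the FSDP changes above: the embedding all-gather is launched asynchronously while the batch is still being copied to the GPU, and the loss (which now performs the fp32 upcast removed from the model's output projection) runs under torch.compile so the cast can fuse with the cross-entropy. A condensed sketch of the same idea, assuming model, input_ids, labels, and job_config as in the diff:

import torch
import torch._inductor.config as inductor_config
import torch.nn.functional as F

# Extra Inductor autotuning for the compiled kernels, at the cost of longer
# compile time.
inductor_config.coordinate_descent_tuning = True

def loss_fn(pred, labels):
    # Upcast here instead of in the model so the cast fuses with cross-entropy
    # inside the compiled region.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

compiled_loss_fn = torch.compile(loss_fn)

# Inside the training step (sketch):
#   model.tok_embeddings.unshard(async_op=True)   # start the embedding all-gather
#                                                 # early, overlapped with H2D copies
#   input_ids, labels = input_ids.cuda(), labels.cuda()
#   loss = compiled_loss_fn(model(input_ids), labels)
#   loss.backward()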
train_configs/llama3_8b.toml (11 additions, 9 deletions)
@@ -6,13 +6,15 @@ dump_folder = "./outputs"
 description = "Llama 3 8B training"
 
 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = false
 
 [metrics]
-log_freq = 10
-enable_tensorboard = true
+log_freq = 1
+enable_color_printing = true
+enable_tensorboard = false
 save_tb_folder = "tb"
 
 [model]
@@ -24,18 +26,18 @@ tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 [optimizer]
 name = "AdamW"
 lr = 3e-4
+fused = true
 
 [training]
 batch_size = 1
 seq_len = 8192
 warmup_steps = 200  # lr scheduler warm up
 max_norm = 1.0  # grad norm clipping
 steps = 1000
-data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_degree = -1
 tensor_parallel_degree = 1
-compile = false
-dataset = "c4"
+compile = true
+dataset = "c4_test"
 
 [experimental]
 pipeline_parallel_degree = 1
@@ -50,7 +52,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = 'selective'  # ['none', 'selective', 'full']
+mode = 'none'  # ['none', 'selective', 'full']
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]