run_train.sh: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ set -ex
# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU
# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode
NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,2}
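# LOG_RANK accepts a comma-separated list of ranks; "0,2" surfaces logs from
# ranks 0 and 2 (presumably to watch more than one pipeline stage at once).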
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training
torchtitan/distributed/dual_pipe_v.py: 71 additions & 4 deletions
@@ -43,10 +43,11 @@ def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool:
    )

    if dual_pipe_v and job_config.activation_checkpoint.mode != "none":
-        raise NotImplementedError(
-            "Expert Parallel with DualPipeV and Activation Checkpointing "
-            "cannot be used together. Please disable one of them."
-        )
+        # Previously this raised NotImplementedError. Forward recomputation from
+        # the backward thread now skips the SyncHook barriers (see SyncHook below),
+        # so activation checkpointing can be combined with DualPipeV.
+        pass

    return dual_pipe_v

@@ -98,6 +99,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        )


# Thread-local flag to track if we're in the backward thread
# Any SyncHook.forward call from the backward thread is checkpoint recomputation
_backward_thread_flag = threading.local()
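# Each thread sees its own `value` attribute on this object; threads that never
# call _set_backward_thread_flag below have no attribute at all, which is why
# _is_in_backward_thread falls back to False via getattr.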


class HookCoordinator:
    def __init__(self):
        # Barrier for 2 threads (forward and backward) to synchronize
@@ -141,6 +147,16 @@ def is_coordination_enabled(self):
        return self._coordination_enabled


def _is_in_backward_thread() -> bool:
    """Check if current thread is the backward thread."""
    return getattr(_backward_thread_flag, "value", False)


def _set_backward_thread_flag(value: bool):
    """Set the backward thread flag for current thread."""
    _backward_thread_flag.value = value


# Global coordinator
_hook_coordinator = HookCoordinator()
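# A single module-level coordinator shared by the forward and backward threads
# spawned in overlap_callback below; SyncHook consults it at every
# forward/backward boundary.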

@@ -150,6 +166,16 @@ class SyncHook(torch.autograd.Function):
    # pyrefly: ignore [bad-override]
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name

        # Skip the barrier if we're in the backward thread: any forward call from
        # there is checkpoint recomputation (the forward thread never sets this flag).
        if _is_in_backward_thread():
            print("skipping barrier for checkpoint recompute forward", flush=True)
            ctx.skip_backward_barrier = True
            return x

        ctx.skip_backward_barrier = False

        # Handle edge case for the transformer-level boundary
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
@@ -165,6 +191,13 @@ def forward(ctx, x, hook_name=""):
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Skip the barrier if this backward corresponds to a checkpoint recompute
        # forward. These are "extra" backward nodes created by checkpoint that have
        # no corresponding partners in the other thread.
        if ctx.skip_backward_barrier:
            print("skipping backward barrier", flush=True)
            return grad_output, None

        # Edge case: skip the initial barrier; all subsequent backward hooks will acquire it.
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None
@@ -184,6 +217,9 @@ def _count_moe_modules(model):
    return moe_count


device_type, device_module = get_device_info()


@@ -264,6 +300,10 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):

    # Shared container for exception from backward thread
    def run_backward():
        # Mark this thread as the backward thread so SyncHook.forward
        # can detect checkpoint recomputation (forward called from backward thread)
        _set_backward_thread_flag(True)

        # pyrefly: ignore [missing-attribute]
        schedule._assert_unsharded(backward_stage)
        # Set the backward thread to use the same stream as forward
@@ -294,6 +334,24 @@ def run_backward():
                # pyrefly: ignore [bad-argument-type]
                backward_mb_index,
            )
            backward_stage.backward_one_chunk(
                backward_mb_index,
                loss=loss,
                full_backward=True,
                last_backward=last_backward,
            )

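            # If the previous stage also lives on this rank, hand this stage's
            # gradient output to it directly as its backward input.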
            if backward_is_prev_stage_on_this_rank:
                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
                    backward_stage.get_local_bwd_output(backward_mb_index),
                    backward_mb_index,
                )
        except BaseException as e:
            backward_exception.append(e)
            # Abort barrier to unblock forward thread if it's waiting
            _hook_coordinator.disable_coordination()
        finally:
            _set_backward_thread_flag(False)

    def run_forward():
        # pyrefly: ignore [missing-attribute]
@@ -306,6 +364,11 @@ def run_forward():
            # pyrefly: ignore [bad-index, unsupported-operation]
            kwarg_mbs[forward_mb_index],
        )
        # # TODO: it's error-prone to have this logic scattered inside and outside the runtime file..
        # # this goes along with the patch to pytorch: https://github.com/pytorch/pytorch/pull/167002/
        # key = f"{forward_stage.stage_index}_{forward_mb_index}"
        # assert key not in schedule.ownership_tokens
        # schedule.ownership_tokens[key] = output.view_as(output).grad_fn
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
@@ -323,3 +386,7 @@ def run_forward():
    thread.join()

    _hook_coordinator.disable_coordination()

    # Re-raise exception from backward thread with full traceback
    if backward_exception:
        raise backward_exception[0]
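For reviewers, a minimal self-contained sketch of the thread-detection pattern this PR relies on (the names `worker` and `shared_section` are illustrative, not part of the diff): a `threading.local` flag set only by the backward thread lets code shared by both threads tell which one is calling it, which is how `SyncHook.forward` recognizes checkpoint recomputation.

```python
import threading

_backward_thread_flag = threading.local()


def in_backward_thread() -> bool:
    # Threads that never set the flag have no `value` attribute, so default to False.
    return getattr(_backward_thread_flag, "value", False)


def shared_section(tag: str) -> None:
    # The same code path behaves differently depending on the calling thread,
    # mirroring how SyncHook.forward skips its barrier during recomputation.
    action = "skip barrier" if in_backward_thread() else "enter barrier"
    print(f"{tag}: {action}")


def worker() -> None:
    _backward_thread_flag.value = True
    try:
        shared_section("backward thread")   # -> skip barrier
    finally:
        _backward_thread_flag.value = False


shared_section("forward thread")            # -> enter barrier
t = threading.Thread(target=worker)
t.start()
t.join()
```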