@@ -82,7 +82,7 @@ def parallelize_deepseekv3(
# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(model, job_config)

# Create compilers with specified passes (defaults to no passes)
# Create compilers with specified passes
fw_compiler, bw_compiler = make_compiler_with_passes(
compiler_passes, dump_folder=job_config.job.dump_folder
)
@@ -94,6 +94,7 @@ def parallelize_deepseekv3(
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
job_config=job_config,
)

# TODO: CompiledModule should take sample input as well, so that we can
139 changes: 123 additions & 16 deletions torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -39,6 +39,15 @@ def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> No
def export_joint(
model, args, kwargs=None, dump_folder: str | None = None
) -> tuple[JointWithDescriptors, TracingContext]:
"""
Export joint forward-backward graph with AOT Autograd.

Args:
model: The model to export
args: Tuple of input arguments
kwargs: Dict of keyword arguments for the model
dump_folder: Optional folder to dump the graph to
"""
if kwargs is None:
kwargs = {}
assert isinstance(args, tuple)
@@ -68,6 +77,14 @@ def export_joint(


def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
"""
Export joint forward-backward graph with AOT Autograd.

Args:
model: The model to export
args: Tuple of input arguments
kwargs: Dict of keyword arguments for the model
"""
if kwargs is None:
kwargs = {}
assert isinstance(args, tuple)
@@ -79,6 +96,7 @@ def aot_export_joint_with_descriptors_alone(
args,
kwargs,
)

return joint_with_descriptors


@@ -90,6 +108,7 @@ def joint_graph_builder(
bw_compiler: Optional[Callable] = None,
joint_custom_passes: Optional[List[Callable]] = None,
dump_folder: str | None = None,
job_config: Optional["JobConfig"] = None,
):
"""
Build a joint forward-backward graph for the model with optional custom compilers.
@@ -102,16 +121,41 @@ def joint_graph_builder(
bw_compiler: Optional custom backward compiler function
joint_custom_passes: list of custom passes to run on the joint graph
dump_folder: Optional folder to dump the graph to
job_config: Job configuration
"""
assert isinstance(model_args, tuple)
for idx, arg in enumerate(model_args):
assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"

# get joint graph
(
joint_with_descriptors,
tracing_context,
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
(joint_with_descriptors, tracing_context,) = export_joint(
model,
model_args,
model_kwargs,
dump_folder=dump_folder,
)

# Check if inductor_decomposition is configured and create the pass with proper context
if job_config is not None:
joint_pass_names = getattr(job_config.compile, "joint_passes", [])
if "inductor_decomposition" in joint_pass_names:
from torchtitan.experiments.compiler_toolkit.passes import (
inductor_decomposition_pass,
)

# Create the decomposition pass with context
decomp_pass = functools.partial(
inductor_decomposition_pass,
model=model,
joint_with_descriptors=joint_with_descriptors,
forward_inputs=model_args,
tracing_context=tracing_context,
)

# Prepend to joint_custom_passes
if joint_custom_passes is None:
joint_custom_passes = []
joint_custom_passes = [decomp_pass] + joint_custom_passes

# run custom passes on joint-graph before partitioner
if joint_custom_passes is not None:
@@ -259,28 +303,36 @@ def compiler(
logger.info(f"Applying pass: {pass_name}")
gm = pass_fn(gm, example_inputs)

logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(print_output=False, include_stride=True, include_device=True)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
# Only try to print/dump if gm is still a GraphModule
# (compile_fx_inner returns a CompiledFxGraph which doesn't have print_readable)
if hasattr(gm, "print_readable"):
logger.debug(f"{name} after compiler:")
logger.debug(
gm.print_readable(
print_output=False, include_stride=True, include_device=True
)
)
_dump_gm(dump_folder, gm, f"{name}_after_compiler")

return gm


def make_compiler_with_passes(
passes: List[Callable] = None, dump_folder: str | None = None
passes: List[Callable] = None,
dump_folder: str | None = None,
):
"""
Create forward and backward compilers with specified passes.

Args:
passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES.
dump_folder: Optional folder to dump graphs

Returns:
Tuple of (fw_compiler, bw_compiler) functions
"""

def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
return compiler(
"fwd_gm",
gm,
@@ -290,7 +342,7 @@ def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
is_forward=True,
)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
return compiler(
"bwd_gm",
gm,
@@ -303,7 +355,17 @@ def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
return fw_compiler, bw_compiler


def validate_pass_names(pass_names: list[str]) -> None:
def validate_pass_names(pass_names: list[str], joint_pass_names: list[str]) -> None:
"""
Validate compiler and joint pass names and their dependencies.

Args:
pass_names: List of compiler pass names
joint_pass_names: List of joint custom pass names

Raises:
ValueError: If pass configuration is invalid
"""
if "cudagraph" in pass_names:
assert (
pass_names[-1] == "cudagraph"
@@ -317,13 +379,22 @@ def validate_pass_names(pass_names: list[str]) -> None:
"Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!"
)

# Validate that full_inductor_compilation requires inductor_decomposition
if "full_inductor_compilation" in pass_names:
if "inductor_decomposition" not in joint_pass_names:
raise ValueError(
"full_inductor_compilation pass requires inductor_decomposition to be "
"specified in joint_passes. Please add --compile.joint_passes inductor_decomposition"
)


def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig):
"""
Extract and validate compiler passes from job config.

Args:
job_config: Job configuration containing compile.passes
model: The model being compiled
job_config: Job configuration containing compile.passes and compile.joint_passes

Returns:
List of compiler pass functions
@@ -334,9 +405,18 @@ def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfi
)

pass_names = getattr(job_config.compile, "passes", [])
validate_pass_names(pass_names)
joint_pass_names = getattr(job_config.compile, "joint_passes", [])

validate_pass_names(pass_names, joint_pass_names)
compiler_passes = []

# Warn if full Inductor compilation is enabled
if "full_inductor_compilation" in pass_names:
logger.warning(
"Full Inductor compilation is enabled. Note that Inductor may change numerics "
"and does not guarantee bitwise equivalent results compared to eager mode."
)

for pass_name in pass_names:
if pass_name not in AVAILABLE_COMPILER_PASSES:
raise ValueError(
@@ -360,25 +440,52 @@ def get_compiler_passes_from_config(


def get_joint_custom_passes_from_config(
parallel_dims: ParallelDims, job_config: JobConfig
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Extract and validate joint custom passes from job config.

Note: The inductor_decomposition pass is handled separately in joint_graph_builder
because it requires context (model, joint_with_descriptors, etc.) that's only
available at graph capture time.

Args:
parallel_dims: Parallelism dimensions
job_config: Job configuration containing parallelism.fsdp_reshard_after_forward
and compile.joint_passes

Returns:
List of joint custom pass functions
"""
from torchtitan.experiments.compiler_toolkit.passes import (
AVAILABLE_JOINT_PASSES,
fsdp_reshard_after_fwd_pass,
validate_flex_attn_annotation_pass,
)

joint_custom_passes = []
joint_custom_passes.append(validate_flex_attn_annotation_pass)

# Handle joint passes from config (excluding inductor_decomposition)
joint_pass_names = getattr(job_config.compile, "joint_passes", [])
for pass_name in joint_pass_names:
if pass_name not in AVAILABLE_JOINT_PASSES:
raise ValueError(
f"Unknown joint pass: {pass_name}. "
f"Available joint passes: {list(AVAILABLE_JOINT_PASSES.keys())}"
)

# Skip inductor_decomposition - it's handled in joint_graph_builder
if pass_name == "inductor_decomposition":
continue

joint_custom_passes.append(AVAILABLE_JOINT_PASSES[pass_name])

if joint_pass_names:
logger.info(f"Using joint passes from config: {joint_pass_names}")

# Handle FSDP reshard after forward
match job_config.parallelism.fsdp_reshard_after_forward:
case "always":
fsdp_reshard_after_forward = True
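A minimal sketch (illustrative, not part of this diff) of the new dependency check in validate_pass_names; the import path mirrors the modified file above and the pass names come from its error message:

    from torchtitan.experiments.compiler_toolkit.graph_utils import validate_pass_names

    # Accepted: full Inductor compilation together with the joint decomposition pass.
    validate_pass_names(
        pass_names=["full_inductor_compilation"],
        joint_pass_names=["inductor_decomposition"],
    )

    # Rejected: full_inductor_compilation without inductor_decomposition raises ValueError.
    validate_pass_names(
        pass_names=["full_inductor_compilation"],
        joint_pass_names=[],
    )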
17 changes: 14 additions & 3 deletions torchtitan/experiments/compiler_toolkit/job_config.py
@@ -10,11 +10,22 @@
@dataclass
class Compile:
"""
List of compiler pass names to apply in the compiler toolkit workflow.
By default, no passes are applied.
Example: --compile.passes autobucketing_reordering,regional_inductor
Compiler configuration for the compiler toolkit workflow.

- joint_passes: List of joint graph pass names to apply on the joint forward-backward
graph before partitioning.

Example: --compile.joint_passes inductor_decomposition

- passes: List of compiler pass names to apply to the partitioned forward/backward graphs.

Example: --compile.passes full_inductor_compilation

Note: If "full_inductor_compilation" is specified, "inductor_decomposition" must
be included in joint_passes.
"""

joint_passes: list[str] = field(default_factory=list)
passes: list[str] = field(default_factory=list)


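A short usage sketch (assumed, not in the diff) of how the two Compile fields combine; it mirrors the CLI examples in the docstring above, with the import path inferred from the file name:

    from torchtitan.experiments.compiler_toolkit.job_config import Compile

    # Equivalent to passing:
    #   --compile.joint_passes inductor_decomposition --compile.passes full_inductor_compilation
    compile_cfg = Compile(
        joint_passes=["inductor_decomposition"],
        passes=["full_inductor_compilation"],
    )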
@@ -69,7 +69,7 @@ def parallelize_llama(
# Get compiler passes from config
compiler_passes = get_compiler_passes_from_config(model, job_config)

# Create compilers with specified passes (defaults to no passes)
# Create compilers with specified passes
fw_compiler, bw_compiler = make_compiler_with_passes(
compiler_passes, dump_folder=job_config.job.dump_folder
)
@@ -81,6 +81,7 @@ def parallelize_llama(
bw_compiler=bw_compiler,
joint_custom_passes=joint_custom_passes,
dump_folder=job_config.job.dump_folder,
job_config=job_config,
)

# TODO: CompiledModule should take sample input as well, so that we can