Skip to content

FX graph visualization #3528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ void TRTEngine::enable_profiling() {
exec_ctx->setProfiler(trt_engine_profiler.get());
}

void TRTEngine::set_profile_format(std::string format) {
  // Select the serialization layout used when the engine profiling trace is
  // dumped: "trex" -> NVIDIA TREx layer-timing layout, "perfetto" ->
  // Chrome/Perfetto trace-event layout.
  //
  // NOTE(review): the profiler is only constructed once profiling is enabled
  // (see enable_profiling(), which installs trt_engine_profiler on the
  // execution context). Guard against a null profiler so a bad call order
  // reports a clear error instead of dereferencing a null pointer.
  if (!this->trt_engine_profiler) {
    TORCHTRT_THROW_ERROR("Profiler is not enabled. Call enable_profiling() before set_profile_format()");
  }
  if (format == "trex") {
    this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
  } else if (format == "perfetto") {
    this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO);
  } else {
    // Unknown format string: surface the bad value to the caller.
    TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
  }
}

std::string TRTEngine::get_engine_layer_info() {
auto inspector = cuda_engine->createEngineInspector();
return inspector->getEngineInformation(nvinfer1::LayerInformationFormat::kJSON);
Expand Down
1 change: 1 addition & 0 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
void enable_profiling();
void set_profile_format(std::string profile_format);
void disable_profiling();
std::string get_engine_layer_info();

Expand Down
31 changes: 23 additions & 8 deletions core/runtime/TRTEngineProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,40 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
}
}

// Choose which trace layout dump_trace() will emit for this profiler
// (Perfetto trace events vs. TREx layer timings).
void TRTEngineProfiler::set_profile_format(TraceFormat format) {
  profile_format = format;
}

void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
std::stringstream out;
out << "[" << std::endl;
double ts = 0.0;
double running_time = 0.0;
for (size_t i = 0; i < value.layer_names.size(); i++) {
auto layer_name = value.layer_names[i];
auto elem = value.profile.at(layer_name);
ts += elem.time;
}
for (size_t i = 0; i < value.layer_names.size(); i++) {
auto layer_name = value.layer_names[i];
auto elem = value.profile.at(layer_name);

out << " {" << std::endl;
out << " \"name\": \"" << layer_name << "\"," << std::endl;
out << " \"ph\": \"X\"," << std::endl;
out << " \"ts\": " << ts * 1000 << "," << std::endl;
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
out << " \"tid\": 1," << std::endl;
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
out << " \"args\": {}" << std::endl;
if (value.profile_format == TraceFormat::kPERFETTO) {
out << " \"ph\": \"X\"," << std::endl;
out << " \"ts\": " << running_time * 1000 << "," << std::endl;
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
out << " \"tid\": 1," << std::endl;
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
out << " \"args\": {}" << std::endl;
} else { // kTREX
out << " \"timeMs\": " << elem.time << "," << std::endl;
out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl;
out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl;
}
out << " }," << std::endl;

ts += elem.time;
running_time += elem.time;
}
out.seekp(-2, out.cur);
out << "\n]" << std::endl;
Expand Down
7 changes: 6 additions & 1 deletion core/runtime/TRTEngineProfiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

// Output layout for profiling traces: kPERFETTO = Chrome/Perfetto trace-event
// JSON, kTREX = NVIDIA TREx layer-timing JSON. Scoped enum so the enumerators
// do not leak into the runtime namespace and cannot implicitly convert to int;
// all call sites already use the qualified TraceFormat::k... spelling.
enum class TraceFormat : uint8_t { kPERFETTO, kTREX };

// Forward declare the function

struct TRTEngineProfiler : public nvinfer1::IProfiler {
struct Record {
float time{0};
int count{0};
};

void set_profile_format(TraceFormat format);
virtual void reportLayerTime(const char* layerName, float ms) noexcept;
TRTEngineProfiler(
const std::string& name,
Expand All @@ -27,6 +31,7 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
std::string name;
std::vector<std::string> layer_names;
std::map<std::string, Record> profile;
TraceFormat profile_format = TraceFormat::kPERFETTO;
};

} // namespace runtime
Expand Down
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("set_profile_format", &TRTEngine::set_profile_format)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
Expand Down
167 changes: 167 additions & 0 deletions py/torch_tensorrt/dynamo/_Debugger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import logging
import os
import tempfile
from logging.config import dictConfig
from typing import Any, List, Optional

import torch
from torch_tensorrt.dynamo.lowering import (
ATEN_POST_LOWERING_PASSES,
ATEN_PRE_LOWERING_PASSES,
)

# Module logger whose effective level the Debugger context manager snapshots
# on __enter__ and restores on __exit__.
_LOGGER = logging.getLogger("torch_tensorrt [TensorRT Conversion Context]")
# Custom verbosity level below logging.DEBUG (10) used for FX graph dumps;
# registered with the logging module under the name "GRAPHS".
GRAPH_LEVEL = 5
logging.addLevelName(GRAPH_LEVEL, "GRAPHS")


class Debugger:
    """Context manager that configures torch_tensorrt debug logging and,
    optionally, FX-graph visualization around lowering passes.

    On entry it reconfigures the root logger (console + file handler writing
    into ``debug_file_dir``) and, if requested, wraps the ATEN pre/post
    lowering pass lists with graph-dumping debug passes. On exit it restores
    the original logging configuration and pass lists.

    Args:
        log_level: One of ``debug``, ``info``, ``warning``, ``error``,
            ``internal_errors``, ``graphs``.
        capture_fx_graph_before: Names of lowering passes to dump the FX
            graph *before*.
        capture_fx_graph_after: Names of lowering passes to dump the FX
            graph *after*.
        save_engine_profile: Whether engine profiling artifacts should be
            saved (stored for consumers to read).
        logging_dir: Directory for logs/visualizations; defaults to a fresh
            temporary directory.

    Raises:
        ValueError: If ``log_level`` is not one of the allowed names.
    """

    def __init__(
        self,
        log_level: str,
        capture_fx_graph_before: Optional[List[str]] = None,
        capture_fx_graph_after: Optional[List[str]] = None,
        save_engine_profile: bool = False,
        logging_dir: Optional[str] = None,
    ):
        # Use tempfile.mkdtemp() rather than TemporaryDirectory().name: the
        # original discarded the TemporaryDirectory wrapper immediately, so
        # the directory was removed when the wrapper was garbage-collected
        # and the FileHandler path could point at a deleted directory.
        self.debug_file_dir = tempfile.mkdtemp()

        if log_level == "debug":
            self.log_level = logging.DEBUG
        elif log_level == "info":
            self.log_level = logging.INFO
        elif log_level == "warning":
            self.log_level = logging.WARNING
        elif log_level == "error":
            self.log_level = logging.ERROR
        elif log_level == "internal_errors":
            self.log_level = logging.CRITICAL
        elif log_level == "graphs":
            self.log_level = GRAPH_LEVEL
        else:
            raise ValueError(
                f"Invalid level: {log_level}, allowed levels are: debug, info, warning, error, internal_errors, graphs"
            )

        self.capture_fx_graph_before = capture_fx_graph_before
        self.capture_fx_graph_after = capture_fx_graph_after
        # Previously accepted but silently dropped; keep it on the instance
        # so callers (and future engine-profile dumping) can read it.
        self.save_engine_profile = save_engine_profile

        if logging_dir is not None:
            # Explicit directory overrides the temporary default.
            self.debug_file_dir = logging_dir
            os.makedirs(self.debug_file_dir, exist_ok=True)

    def __enter__(self) -> None:
        # Snapshot current Python and TensorRT-runtime logging levels so
        # __exit__ can restore them.
        self.original_lvl = _LOGGER.getEffectiveLevel()
        self.rt_level = torch.ops.tensorrt.get_logging_level()
        # Fixed: was self.get_config(), which does not exist (AttributeError).
        dictConfig(self.get_customized_logging_config())

        if self.capture_fx_graph_before or self.capture_fx_graph_after:
            # Save the pristine pass lists so they can be restored on exit.
            self.old_pre_passes, self.old_post_passes = (
                ATEN_PRE_LOWERING_PASSES.passes,
                ATEN_POST_LOWERING_PASSES.passes,
            )
            pre_pass_names = [p.__name__ for p in self.old_pre_passes]
            post_pass_names = [p.__name__ for p in self.old_post_passes]
            path = os.path.join(self.debug_file_dir, "lowering_passes_visualization")
            if self.capture_fx_graph_before is not None:
                # Only instrument passes that actually exist in each list.
                pre_vis_passes = [
                    p for p in self.capture_fx_graph_before if p in pre_pass_names
                ]
                post_vis_passes = [
                    p for p in self.capture_fx_graph_before if p in post_pass_names
                ]
                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_before(pre_vis_passes, path)
                ATEN_POST_LOWERING_PASSES.insert_debug_pass_before(
                    post_vis_passes, path
                )
            if self.capture_fx_graph_after is not None:
                pre_vis_passes = [
                    p for p in self.capture_fx_graph_after if p in pre_pass_names
                ]
                post_vis_passes = [
                    p for p in self.capture_fx_graph_after if p in post_pass_names
                ]
                ATEN_PRE_LOWERING_PASSES.insert_debug_pass_after(pre_vis_passes, path)
                ATEN_POST_LOWERING_PASSES.insert_debug_pass_after(post_vis_passes, path)

    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
        # Restore logging configuration and the TensorRT runtime level.
        # Fixed: was self.get_default_config(), which does not exist.
        dictConfig(self.get_default_logging_config())
        torch.ops.tensorrt.set_logging_level(self.rt_level)
        if self.capture_fx_graph_before or self.capture_fx_graph_after:
            # Put back the un-instrumented lowering pass lists.
            ATEN_PRE_LOWERING_PASSES.passes, ATEN_POST_LOWERING_PASSES.passes = (
                self.old_pre_passes,
                self.old_post_passes,
            )
        # Point at a fresh (existing) temp dir; mkdtemp avoids the deleted
        # TemporaryDirectory().name path of the original implementation.
        self.debug_file_dir = tempfile.mkdtemp()

    def get_customized_logging_config(self) -> dict[str, Any]:
        """Root-logger config used while the Debugger is active: a brief
        console handler plus a standard-format file handler writing into
        ``debug_file_dir``, both at the selected level."""
        config = {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "brief": {
                    "format": "%(asctime)s - %(levelname)s - %(message)s",
                    "datefmt": "%H:%M:%S",
                },
                "standard": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                    "datefmt": "%Y-%m-%d %H:%M:%S",
                },
            },
            "handlers": {
                "file": {
                    "level": self.log_level,
                    "class": "logging.FileHandler",
                    "filename": f"{self.debug_file_dir}/torch_tensorrt_logging.log",
                    "formatter": "standard",
                },
                "console": {
                    "level": self.log_level,
                    "class": "logging.StreamHandler",
                    "formatter": "brief",
                },
            },
            "loggers": {
                "": {  # root logger
                    "handlers": ["file", "console"],
                    "level": self.log_level,
                    "propagate": True,
                },
            },
            "force": True,
        }
        return config

    def get_default_logging_config(self) -> dict[str, Any]:
        """Root-logger config restored on exit: console-only handler at the
        level that was in effect before __enter__."""
        config = {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "brief": {
                    "format": "%(asctime)s - %(levelname)s - %(message)s",
                    "datefmt": "%H:%M:%S",
                },
                "standard": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                    "datefmt": "%Y-%m-%d %H:%M:%S",
                },
            },
            "handlers": {
                "console": {
                    "level": self.original_lvl,
                    "class": "logging.StreamHandler",
                    "formatter": "brief",
                },
            },
            "loggers": {
                "": {  # root logger
                    "handlers": ["console"],
                    "level": self.original_lvl,
                    "propagate": True,
                },
            },
            "force": True,
        }
        return config
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
load_cross_compiled_exported_program,
save_cross_compiled_exported_program,
)
from ._Debugger import Debugger
from ._exporter import export
from ._refit import refit_module_weights
from ._settings import CompilationSettings
Expand Down
6 changes: 1 addition & 5 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,6 @@ def compile(
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
"""

if debug:
set_log_level(logger.parent, logging.DEBUG)
if "truncate_long_and_double" in kwargs.keys():
if truncate_double is not _defaults.TRUNCATE_DOUBLE:
raise ValueError(
Expand Down Expand Up @@ -725,7 +723,7 @@ def compile_module(

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
gm, settings.torch_executed_ops
)

dryrun_tracker.total_ops_in_graph = total_ops
Expand Down Expand Up @@ -777,7 +775,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
logger.info("Partitioning the graph via the fast partitioner")
partitioned_module, supported_ops = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
Expand All @@ -798,7 +795,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
logger.info("Partitioning the graph via the global partitioner")
partitioned_module, supported_ops = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
Expand Down
7 changes: 0 additions & 7 deletions py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
check_module_output,
get_model_device,
get_torch_inputs,
set_log_level,
to_torch_device,
to_torch_tensorrt_device,
)
Expand Down Expand Up @@ -72,7 +71,6 @@ def construct_refit_mapping(
interpreter = TRTInterpreter(
module,
inputs,
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
output_dtypes=output_dtypes,
compilation_settings=settings,
)
Expand Down Expand Up @@ -266,9 +264,6 @@ def refit_module_weights(
not settings.immutable_weights
), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

if settings.debug:
set_log_level(logger.parent, logging.DEBUG)

device = to_torch_tensorrt_device(settings.device)
if arg_inputs:
if not isinstance(arg_inputs, collections.abc.Sequence):
Expand Down Expand Up @@ -304,7 +299,6 @@ def refit_module_weights(
try:
new_partitioned_module, supported_ops = partitioning.fast_partition(
new_gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
Expand All @@ -320,7 +314,6 @@ def refit_module_weights(
if not settings.use_fast_partitioner:
new_partitioned_module, supported_ops = partitioning.global_partition(
new_gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(
self,
module: torch.fx.GraphModule,
input_specs: Sequence[Input],
logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
output_dtypes: Optional[Sequence[dtype]] = None,
compilation_settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
Expand Down
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
from typing import Any, List, Optional, Sequence

import tensorrt as trt
import torch
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
Expand Down Expand Up @@ -60,7 +59,6 @@ def interpret_module_to_result(
interpreter = TRTInterpreter(
module,
inputs,
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
output_dtypes=output_dtypes,
compilation_settings=settings,
engine_cache=engine_cache,
Expand Down
Loading
Loading