Skip to content

Commit 031267c

Browse files
committed
Fixed the comments, changed dump function
1 parent 1261983 commit 031267c

File tree

7 files changed

+17
-19
lines changed

7 files changed

+17
-19
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,12 @@ void TRTEngine::enable_profiling() {
283283

284284
void TRTEngine::set_profile_format(std::string format) {
285285
if (format == "trex") {
286-
profile_format = TraceFormat::kTREX;
286+
this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
287287
} else if (format == "perfetto") {
288-
profile_format = TraceFormat::kPERFETTO;
288+
this->trt_engine_profiler->set_profile_format(TraceFormat::kPERFETTO);
289289
} else {
290290
TORCHTRT_THROW_ERROR("Invalid profile format: " + format);
291291
}
292-
293-
profile_format = profile_format;
294292
}
295293

296294
std::string TRTEngine::get_engine_layer_info() {

core/runtime/TRTEngine.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ struct TRTEngine : torch::CustomClassHolder {
192192
#else
193193
bool profile_execution = false;
194194
#endif
195-
TraceFormat profile_format = TraceFormat::kPERFETTO;
196195
std::string device_profile_path;
197196
std::string input_profile_path;
198197
std::string output_profile_path;

core/runtime/TRTEngineProfiler.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ TRTEngineProfiler::TRTEngineProfiler(const std::string& name, const std::vector<
3232
}
3333
}
3434

35-
void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format) {
35+
void TRTEngineProfiler::set_profile_format(TraceFormat format) {
36+
this->profile_format = format;
37+
}
38+
39+
void dump_trace(const std::string& path, const TRTEngineProfiler& value) {
3640
std::stringstream out;
3741
out << "[" << std::endl;
3842
double ts = 0.0;
@@ -48,17 +52,17 @@ void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFo
4852

4953
out << " {" << std::endl;
5054
out << " \"name\": \"" << layer_name << "\"," << std::endl;
51-
if (format == kPERFETTO) {
55+
if (value.profile_format == TraceFormat::kPERFETTO) {
5256
out << " \"ph\": \"X\"," << std::endl;
5357
out << " \"ts\": " << running_time * 1000 << "," << std::endl;
5458
out << " \"dur\": " << elem.time * 1000 << "," << std::endl;
5559
out << " \"tid\": 1," << std::endl;
5660
out << " \"pid\": \"" << value.name << " Engine Execution\"," << std::endl;
61+
out << " \"args\": {}" << std::endl;
5762
} else { // kTREX
5863
out << " \"timeMs\": " << elem.time << "," << std::endl;
5964
out << " \"averageMs\": " << elem.time / elem.count << "," << std::endl;
6065
out << " \"percentage\": " << (elem.time * 100.0 / ts) << "," << std::endl;
61-
out << " \"args\": {}" << std::endl;
6266
}
6367
out << " }," << std::endl;
6468
running_time += elem.time;

core/runtime/TRTEngineProfiler.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,19 @@ struct TRTEngineProfiler : public nvinfer1::IProfiler {
1919
float time{0};
2020
int count{0};
2121
};
22-
22+
void set_profile_format(TraceFormat format);
2323
virtual void reportLayerTime(const char* layerName, float ms) noexcept;
2424
TRTEngineProfiler(
2525
const std::string& name,
2626
const std::vector<TRTEngineProfiler>& srcProfilers = std::vector<TRTEngineProfiler>());
2727
friend std::ostream& operator<<(std::ostream& out, const TRTEngineProfiler& value);
28-
friend void dump_trace(const std::string& path, const TRTEngineProfiler& value, TraceFormat format);
28+
friend void dump_trace(const std::string& path, const TRTEngineProfiler& value);
2929

3030
private:
3131
std::string name;
3232
std::vector<std::string> layer_names;
3333
std::map<std::string, Record> profile;
34+
TraceFormat profile_format = TraceFormat::kPERFETTO;
3435
};
3536

3637
} // namespace runtime

core/runtime/execute_engine.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
339339

340340
if (compiled_engine->profile_execution) {
341341
LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
342-
dump_trace(
343-
compiled_engine->trt_engine_profile_path,
344-
*compiled_engine->trt_engine_profiler,
345-
compiled_engine->profile_format);
342+
dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
346343
compiled_engine->dump_engine_layer_info();
347344
}
348345

@@ -443,10 +440,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
443440

444441
if (compiled_engine->profile_execution) {
445442
LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
446-
dump_trace(
447-
compiled_engine->trt_engine_profile_path,
448-
*compiled_engine->trt_engine_profiler,
449-
compiled_engine->profile_format);
443+
dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
450444
compiled_engine->dump_engine_layer_info();
451445
}
452446

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
335335
return tuple(outputs)
336336

337337
def enable_profiling(
338-
self, profiling_results_dir: Optional[str] = None, profile_format: str = "trex"
338+
self,
339+
profiling_results_dir: Optional[str] = None,
340+
profile_format: str = "perfetto",
339341
) -> None:
340342
"""Enable the profiler to collect latency information about the execution of the engine
341343

0 commit comments

Comments
 (0)