Skip to content

Commit d862b68

Browse files
committed
Fixed the comments
1 parent a9a27b1 commit d862b68

File tree

7 files changed

+33
-22
lines changed

7 files changed

+33
-22
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,12 @@ void TRTEngine::enable_profiling() {
289289
exec_ctx->setProfiler(trt_engine_profiler.get());
290290
}
291291

292-
void TRTEngine::set_requires_new_output_tensor(bool enable) {
293-
this->requires_new_output_tensor = enable;
292+
void TRTEngine::set_unowned_output_tensor(bool enable) {
293+
this->unowned_output_tensor = enable;
294294
}
295295

296-
bool TRTEngine::get_requires_new_output_tensor() {
297-
return this->requires_new_output_tensor;
296+
bool TRTEngine::is_unowned_output_tensor() {
297+
return this->unowned_output_tensor;
298298
}
299299

300300
void TRTEngine::set_profile_format(std::string format) {

core/runtime/TRTEngine.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct TRTEngine : torch::CustomClassHolder {
105105
std::pair<uint64_t, uint64_t> num_io;
106106
uint64_t io_size;
107107
std::map<std::string, bool> isShapeInferenceIO;
108-
bool requires_new_output_tensor = false;
108+
bool unowned_output_tensor = false;
109109
std::string name;
110110
RTDevice device_info;
111111

@@ -162,8 +162,8 @@ struct TRTEngine : torch::CustomClassHolder {
162162
int64_t get_automatic_device_memory_budget();
163163
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
164164
void set_pre_allocated_outputs(bool enable);
165-
void set_requires_new_output_tensor(bool enable);
166-
bool get_requires_new_output_tensor();
165+
void set_unowned_output_tensor(bool enable);
166+
bool is_unowned_output_tensor();
167167
TorchTRTRuntimeStates runtime_states;
168168
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
169169
static const char BINDING_DELIM = '%';

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
249249
if (can_use_pre_allocated_outputs) {
250250
outputs = compiled_engine->pre_allocated_outputs;
251251
} else {
252-
if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->requires_new_output_tensor or
253-
shape_changed) {
252+
if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->unowned_output_tensor or shape_changed) {
254253
compiled_engine->allocated_outputs = create_output_tensors(compiled_engine);
255254
new_outputs = true;
256255
}

core/runtime/register_jit_hooks.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
9090
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
9191
.def("infer_outputs", &TRTEngine::infer_outputs)
9292
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
93-
.def("set_requires_new_output_tensor", &TRTEngine::set_requires_new_output_tensor)
94-
.def("get_requires_new_output_tensor", &TRTEngine::get_requires_new_output_tensor)
93+
.def("set_unowned_output_tensor", &TRTEngine::set_unowned_output_tensor)
94+
.def("is_unowned_output_tensor", &TRTEngine::is_unowned_output_tensor)
9595
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
9696
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
9797
.def_property(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ def preserve_module_specs(
996996

997997
# Only set the unowned_output_tensor flag for the last TRT Module when user has access to the output tensor
998998
if trt_module:
999-
trt_module.set_requires_new_output_tensor(True)
999+
trt_module.set_unowned_output_tensor(True)
10001000

10011001
# Parse the graph I/O and store it in dryrun tracker
10021002
parse_graph_io(gm, dryrun_tracker)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,16 +221,28 @@ def __init__(
221221
self.use_output_allocator_outputs = False
222222
self.device = torch.cuda.current_device()
223223
self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
224-
self.requires_new_output_tensor = False
224+
# If the output tensor is not owned by the engine (unowned_output_tensor=True), we need to create a new output tensor in each forward pass
225+
self.unowned_output_tensor = False
225226
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
226227
self.setup_engine()
227228
self.is_shape_inference_io = {
228229
input_name: self.engine.is_shape_inference_io(input_name)
229230
for input_name in self.input_names
230231
}
231232

232-
def set_requires_new_output_tensor(self, enabled: bool) -> None:
233-
self.requires_new_output_tensor = enabled
233+
def set_unowned_output_tensor(self, enabled: bool) -> None:
234+
"""
235+
Set the flag to indicate if the output tensor is unowned by the engine.
236+
If self.unowned_output_tensor=True, the engine will create a new output tensor in each forward pass.
237+
This would be slower but is required when users need to manipulate the output tensor after each forward pass.
238+
Therefore, this should be set to True only for the last module in a graph and left as False for intermediate modules,
239+
which users don't have access to.
240+
Args:
241+
enabled: bool
242+
Whether the output tensor is unowned by the engine (True forces a fresh output tensor on every forward pass).
243+
244+
"""
245+
self.unowned_output_tensor = enabled
234246

235247
def get_streamable_device_memory_budget(self) -> Any:
236248
return self.engine.streamable_weights_size
@@ -520,7 +532,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
520532
)
521533
if (
522534
self.output_tensors is None
523-
or self.requires_new_output_tensor
535+
or self.unowned_output_tensor
524536
or shape_changed
525537
):
526538
self.output_tensors = self.create_output_tensors()

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
289289
metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
290290
self.settings = metadata["settings"]
291291
self.weight_name_map = metadata["weight_name_map"]
292-
self.requires_new_output_tensor = metadata["requires_new_output_tensor"]
293-
self.engine.set_requires_new_output_tensor(self.requires_new_output_tensor)
292+
self.unowned_output_tensor = metadata["unowned_output_tensor"]
293+
self.engine.set_unowned_output_tensor(self.unowned_output_tensor)
294294

295295
else:
296296
self.engine = None
@@ -362,11 +362,11 @@ def enable_profiling(
362362
self.engine.enable_profiling()
363363
self.engine.set_profile_format(profile_format)
364364

365-
def set_requires_new_output_tensor(self, enabled: bool) -> None:
366-
self.engine.set_requires_new_output_tensor(enabled)
365+
def set_unowned_output_tensor(self, enabled: bool) -> None:
366+
self.engine.set_unowned_output_tensor(enabled)
367367

368-
def get_requires_new_output_tensor(self) -> bool:
369-
return self.engine.get_requires_new_output_tensor()
368+
def is_unowned_output_tensor(self) -> bool:
369+
return self.engine.is_unowned_output_tensor()
370370

371371
def disable_profiling(self) -> None:
372372
"""Disable the profiler"""

0 commit comments

Comments
 (0)