Support Data Dependent Shape (DDS) and NonZero op #3364

Closed · wants to merge 4 commits
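For context (not part of the diff): aten.nonzero is a data-dependent-shape (DDS) op, so its output size is only known after the engine has run. This PR therefore replaces the runtime's pre-allocated output buffers with a TensorRT output allocator that lets the engine request memory on demand. Below is a minimal, illustrative sketch of the user-facing flow this enables; the module, input values, and compile settings are placeholders chosen for this example, not code from the PR.

```python
# Illustrative sketch only -- not part of this PR's diff.
import torch
import torch_tensorrt


class NonZeroModule(torch.nn.Module):
    def forward(self, x):
        # torch.nonzero returns an [N, rank] index tensor where N depends on
        # the input values, i.e. the output shape is data dependent.
        return torch.nonzero(x)


model = NonZeroModule().eval().cuda()
inputs = [torch.randint(0, 2, (8, 8)).float().cuda()]

# The dynamo frontend lowers aten.nonzero through the new converter; at run
# time the engine sizes this output via the dynamic output allocator.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # e.g. torch.Size([N, 2]), N data dependent
```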
23 changes: 21 additions & 2 deletions core/runtime/TRTEngine.cpp
@@ -30,6 +30,27 @@ std::vector<std::string> split(const std::string& str, char delim) {
return strings;
}

DynamicOutputAllocator::DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes)
: dtypes(output_dtypes) {}

void* DynamicOutputAllocator::reallocateOutputAsync(
char const* tensorName,
void* currentMemory,
uint64_t size,
uint64_t alignment,
cudaStream_t stream) {
std::vector<int64_t> shape = {static_cast<int64_t>(size)};
auto it = buffers.find(tensorName);
if (it == buffers.end() || it->second.sizes() != shape) {
buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA));
Review comment (Collaborator): How do you know what device to use?

Review comment (Collaborator): As soon as this happens, the recapture flag needs to be set.
}
return buffers[tensorName].data_ptr();
}

void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept {
shapes[tensorName] = dims;
}

TRTEngine::TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
@@ -137,7 +158,6 @@ TRTEngine::TRTEngine(
in_binding_names.resize(inputs);
input_buffers.resize(inputs);
out_binding_names.resize(outputs);
output_buffers.resize(outputs);
for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
std::string bind_name = cuda_engine->getIOTensorName(x);
if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
@@ -179,7 +199,6 @@ TRTEngine::TRTEngine(

uint64_t outputs = _out_binding_names.size();
out_binding_names.resize(outputs);
output_buffers.resize(outputs);
for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
auto binding_name = _out_binding_names[pyt_idx];
// Check if the binding name provided is in the list of engine's bindings
29 changes: 28 additions & 1 deletion core/runtime/TRTEngine.h
@@ -69,11 +69,39 @@ struct TorchTRTRuntimeStates {
}
};

class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
public:
DynamicOutputAllocator(const std::unordered_map<std::string, at::ScalarType>& output_dtypes);

void* reallocateOutputAsync(
char const* tensorName,
void* currentMemory,
uint64_t size,
uint64_t alignment,
cudaStream_t stream) override;

void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;

const std::unordered_map<std::string, at::Tensor>& getBuffers() const {
return buffers;
}

const std::unordered_map<std::string, nvinfer1::Dims>& getShapes() const {
return shapes;
}

private:
std::unordered_map<std::string, at::ScalarType> dtypes;
std::unordered_map<std::string, at::Tensor> buffers;
std::unordered_map<std::string, nvinfer1::Dims> shapes;
};

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::shared_ptr<DynamicOutputAllocator> output_allocator;
std::pair<uint64_t, uint64_t> num_io;
std::string name;
RTDevice device_info;
@@ -141,7 +169,6 @@ struct TRTEngine : torch::CustomClassHolder {
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key = "None";
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;
95 changes: 41 additions & 54 deletions core/runtime/execute_engine.cpp
@@ -163,22 +163,23 @@ void setup_input_tensors(
}
}
}
std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
for (auto output_indices : compiled_engine->out_binding_map) {
// out_binding_map stores TRT_IDX: PYT_IDX
auto pyt_idx = output_indices.second;

std::string name = compiled_engine->out_binding_names[pyt_idx];
auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);

auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());

void setup_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
if (compiled_engine->output_allocator == nullptr) {
std::unordered_map<std::string, at::ScalarType> output_dtypes_dict;
for (size_t o = 0; o < compiled_engine->out_binding_names.size(); ++o) {
auto name = compiled_engine->out_binding_names[o];
output_dtypes_dict[name] =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
}
compiled_engine->output_allocator = std::make_shared<DynamicOutputAllocator>(output_dtypes_dict);
}

return outputs;
for (const auto& output_name : compiled_engine->out_binding_names) {
if (!compiled_engine->exec_ctx->setOutputAllocator(output_name.c_str(), compiled_engine->output_allocator.get())) {
throw std::runtime_error("Failed to set output allocator for " + output_name);
}
}
}

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -218,7 +219,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
}

// Initialize inputs and outputs to be available throughout the succeeding scopes
std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

if (MULTI_DEVICE_SAFE_MODE) {
std::unique_ptr<torch::autograd::profiler::RecordProfile> device_profiler_guard;
@@ -287,44 +287,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
<< " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly");
}

{ // Output Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
{ // OutputAllocator Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_allocator_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
output_allocator_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
if (can_use_pre_allocated_outputs) {
outputs = compiled_engine->pre_allocated_outputs;
} else {
outputs = create_output_tensors(compiled_engine);
}

for (auto output_indices : compiled_engine->out_binding_map) {
auto pyt_idx = output_indices.second;
std::string name = compiled_engine->out_binding_names[pyt_idx];
if (need_cudagraphs_record) {
// If we are recording the cuda graph then we need to update the persistent output buffer
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
} else {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
}
}
setup_output_allocator(compiled_engine);
Review comment (Collaborator): We need to do this every time the engine is going to run?

}

auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else if (outputs.size() > 0) {
current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
} else {
current_device_id = c10::cuda::current_device();
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
@@ -368,21 +344,32 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
}
} // End engine execution (resets to caller stream)

// Create output buffer for next execution of graph or trt context.
if (compiled_engine->use_pre_allocated_outputs) {
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
outputs[o].copy_(compiled_engine->output_buffers[o], false);
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
}
std::vector<at::Tensor> outputs;
for (size_t i = 0; i < compiled_engine->out_binding_names.size(); i++) {
auto name = compiled_engine->out_binding_names[i];
auto dims = compiled_engine->output_allocator->getShapes().at(name);
auto dtype = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
at::Tensor output = compiled_engine->output_allocator->getBuffers().at(name).clone().detach();
Review comment (Collaborator): Can we not directly copy to the result buffers like before instead of allocating new memory here? Or does it not make a difference?

int64_t prod = 1;
for (int i = 0; i < dims.nbDims; ++i) {
prod *= dims.d[i];
}
std::vector<int64_t> dims_vec(dims.nbDims);
for (int i = 0; i < dims.nbDims; ++i) {
dims_vec[i] = dims.d[i];
}
output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(dims_vec);
Review comment (Collaborator): Why do we need this?

outputs.push_back(output);
}

if (compiled_engine->profile_execution) {
Expand Down
113 changes: 0 additions & 113 deletions examples/dynamo/pre_allocated_output_example.py

This file was deleted.

17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3582,3 +3582,20 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
def aten_ops_nonzero(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.nonzero(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -624,3 +624,18 @@ def native_dropout(
mask = np.ones(input_val.shape, dtype=bool)
mask = get_trt_tensor(ctx, mask, f"{name}_mask")
return identity_layer.get_output(0), mask


def nonzero(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
non_zero_layer = ctx.net.add_non_zero(input_val)
set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
# TensorRT's INonZeroLayer emits indices with shape [rank, num_nonzero];
# transpose to [num_nonzero, rank] to match torch.nonzero's layout.
shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
shuffle_layer.first_transpose = trt.Permutation([1, 0])
set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
return shuffle_layer.get_output(0)