Skip to content

Commit 690a55d

Browse files
cherry pick from 3008 to release/2.4 (#3035)
1 parent 9cced93 commit 690a55d

File tree

3 files changed

+506
-0
lines changed

3 files changed

+506
-0
lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

+5
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def trunc_div(
6565
prod_output,
6666
)
6767

68+
# cast the sign_output back to int32 for trunc div
69+
# This is required for scatter_reduce_.two(reduce='mean' where trunc_div casts it to float32 and TRTInterpreter expects int32)
70+
if (isinstance(sign_output, TRTTensor)) and (sign_output.dtype == trt.float32):
71+
sign_output = cast_trt_tensor(ctx, sign_output, trt.int32, name)
72+
6873
# Convert constant input into ITensor for UnaryOperation
6974
if not isinstance(input, trt.tensorrt.ITensor):
7075
input = get_trt_tensor(ctx, input, f"{name}_input")

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+94
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from enum import Enum, auto
23
from typing import Any, Callable, Dict, List, Optional
34

45
import torch
@@ -238,6 +239,99 @@ def empty_strided_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
238239
return torch.as_strided(torch.empty(empty_size), empty_size, empty_stride)
239240

240241

242+
# enum class for reduce operation of scatter_reduce
243+
class ReduceOperation(Enum):
244+
SUM = ("Sum reduce operation", lambda x, y: torch.add(x, y))
245+
PROD = ("Product reduce operation", lambda x, y: torch.mul(x, y))
246+
MEAN = ("Mean reduce operation", lambda x, y: torch.add(x, y))
247+
AMAX = ("Amax reduce operation", lambda x, y: torch.max(x, y))
248+
AMIN = ("Amin reduce operation", lambda x, y: torch.min(x, y))
249+
250+
def __new__(cls, description, func):
251+
obj = object.__new__(cls)
252+
obj._value_ = auto()
253+
obj.description = description
254+
obj.func = func
255+
return obj
256+
257+
def reduce_operation_with_scatter(
258+
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
259+
):
260+
scatter_tensor = None
261+
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
262+
scatter_tensor = torch.zeros_like(initial_tensor)
263+
elif self == ReduceOperation.PROD:
264+
scatter_tensor = torch.ones_like(initial_tensor)
265+
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
266+
scatter_tensor = initial_tensor
267+
else:
268+
# This case would not be encountered from torch itself
269+
print("Invalid Operation for Reduce op!!")
270+
271+
operation_rhs = torch.scatter(scatter_tensor, dim, index_tensor, src_tensor)
272+
device = to_torch_device(default_device())
273+
operation_lhs = operation_lhs.to(device)
274+
operation_rhs = operation_rhs.to(device)
275+
return self.func(operation_lhs, operation_rhs)
276+
277+
278+
@register_torch_trt_decomposition(
279+
torch.ops.aten.scatter_reduce.two, registry=TORCH_TRT_DECOMPOSITIONS
280+
)
281+
def scatter_reduce_decomposition(
282+
input_tensor: torch.Tensor,
283+
dim: int,
284+
index: torch.Tensor,
285+
src_tensor: torch.Tensor,
286+
reduce: str,
287+
) -> torch.Tensor:
288+
scatter_loop_tensor = input_tensor
289+
# required for mean reduce operation
290+
scatter_count_tensor = torch.zeros_like(input_tensor)
291+
src_shape = list(src_tensor.shape)
292+
src_dim = src_shape[dim]
293+
294+
for i in range(0, src_dim):
295+
src_slice = torch.select(src_tensor, dim, i)
296+
index_slice = torch.select(index, dim, i)
297+
# unsqueeze src and index in dim
298+
src_slice = torch.unsqueeze(src_slice, dim)
299+
index_slice = torch.unsqueeze(index_slice, dim)
300+
device = to_torch_device(default_device())
301+
302+
# moving tensor to default device
303+
scatter_loop_tensor = scatter_loop_tensor.to(device)
304+
index_slice = index_slice.to(device)
305+
src_slice = src_slice.to(device)
306+
if reduce == "sum":
307+
reduceOp = ReduceOperation.SUM
308+
elif reduce == "prod":
309+
reduceOp = ReduceOperation.PROD
310+
elif reduce == "mean":
311+
reduceOp = ReduceOperation.MEAN
312+
scatter_count_tensor = reduceOp.reduce_operation_with_scatter(
313+
scatter_count_tensor,
314+
input_tensor,
315+
dim,
316+
index_slice,
317+
torch.ones_like(src_slice),
318+
)
319+
elif reduce == "amax":
320+
reduceOp = ReduceOperation.AMAX
321+
elif reduce == "amin":
322+
reduceOp = ReduceOperation.AMIN
323+
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
324+
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
325+
)
326+
if reduce == "mean":
327+
scatter_loop_tensor = torch.div(
328+
scatter_loop_tensor,
329+
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
330+
rounding_mode="trunc",
331+
)
332+
return scatter_loop_tensor
333+
334+
241335
@register_torch_trt_decomposition(
242336
torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
243337
)

0 commit comments

Comments
 (0)