Skip to content

Commit 0575e51

Browse files
committed
scatter reduce lowering with include_self=False
1 parent 501a1e1 commit 0575e51

File tree

2 files changed

+492
-39
lines changed

2 files changed

+492
-39
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+44-8
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,22 @@ def __new__(cls, description, func):
303303
obj.func = func
304304
return obj
305305

306-
def reduce_operation_with_scatter(
307-
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
306+
def reduce_operation_with_scatter_include_self(
307+
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor, min_ele = float('-inf'), max_ele = float('inf'), include_self=True
308308
):
309309
scatter_tensor = None
310310
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
311311
scatter_tensor = torch.zeros_like(initial_tensor)
312312
elif self == ReduceOperation.PROD:
313313
scatter_tensor = torch.ones_like(initial_tensor)
314-
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
314+
elif self == ReduceOperation.AMAX:
315315
scatter_tensor = initial_tensor
316+
if(not(include_self)):
317+
scatter_tensor = torch.full_like(initial_tensor, min_ele)
318+
elif self == ReduceOperation.AMIN:
319+
scatter_tensor = initial_tensor
320+
if(not(include_self)):
321+
scatter_tensor = torch.full_like(initial_tensor, max_ele)
316322
else:
317323
# This case would not be encountered from torch itself
318324
print("Invalid Operation for Reduce op!!")
@@ -336,13 +342,31 @@ def scatter_reduce_decomposition(
336342
include_self: bool = True,
337343
) -> torch.Tensor:
338344
scatter_loop_tensor = input_tensor
345+
MAX_ELE = 0
346+
MIN_ELE = 0
347+
if(src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32):
348+
MAX_ELE = 2147483647
349+
MIN_ELE = -2147483648
350+
else:
351+
MAX_ELE = float('inf')
352+
MIN_ELE = float('-inf')
353+
if(not(include_self)):
354+
if (reduce == "sum" or reduce == "mean"):
355+
scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor))
356+
if (reduce == "prod"):
357+
scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.ones_like(src_tensor))
358+
if (reduce == "amax"):
359+
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
360+
scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
361+
if (reduce == "amin"):
362+
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
363+
scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
364+
339365
device_input_tensor = input_tensor.device
340366
# required for mean reduce operation
341367
scatter_count_tensor = torch.zeros_like(input_tensor)
342368
src_shape = list(src_tensor.shape)
343369
src_dim = src_shape[dim]
344-
if include_self == False:
345-
raise AssertionError("include_self False for scatter reduce not yet supported")
346370
for i in range(0, src_dim):
347371
src_slice = torch.select(src_tensor, dim, i)
348372
index_slice = torch.select(index, dim, i)
@@ -366,20 +390,32 @@ def scatter_reduce_decomposition(
366390
dim,
367391
index_slice,
368392
torch.ones_like(src_slice),
393+
MIN_ELE,
394+
MAX_ELE,
395+
include_self
369396
)
370397
elif reduce == "amax":
371398
reduceOp = ReduceOperation.AMAX
372399
elif reduce == "amin":
373400
reduceOp = ReduceOperation.AMIN
374-
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
375-
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
401+
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
402+
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice, MIN_ELE, MAX_ELE, include_self
376403
)
377404
if reduce == "mean":
378405
scatter_loop_tensor = torch.div(
379406
scatter_loop_tensor,
380-
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
407+
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)) if include_self else scatter_count_tensor,
381408
rounding_mode="trunc",
382409
)
410+
# for include_self=False, amax and amin require additional processing
411+
#except for the max elements in amax, rest are -inf or INT_MIN
412+
#except for the min elements in amin, rest are +inf or INT_MAX
413+
if reduce == "amax" and not(include_self):
414+
# positions selected by index are set to MIN_ELE; the rest keep the original input values
415+
return torch.max(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)))
416+
if reduce == "amin" and not(include_self):
417+
# positions selected by index are set to MAX_ELE; the rest keep the original input values
418+
return torch.min(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)))
383419
return scatter_loop_tensor
384420

385421

0 commit comments

Comments
 (0)