Skip to content

Commit 47fec01

Browse files
committed
scatter reduce lowering with include_self=False
1 parent 501a1e1 commit 47fec01

File tree

2 files changed

+529
-39
lines changed

2 files changed

+529
-39
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+81-8
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,30 @@ def __new__(cls, description, func):
303303
obj.func = func
304304
return obj
305305

306-
def reduce_operation_with_scatter(
307-
self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor
306+
def reduce_operation_with_scatter_include_self(
307+
self,
308+
operation_lhs,
309+
initial_tensor,
310+
dim,
311+
index_tensor,
312+
src_tensor,
313+
min_ele=float("-inf"),
314+
max_ele=float("inf"),
315+
include_self=True,
308316
):
309317
scatter_tensor = None
310318
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
311319
scatter_tensor = torch.zeros_like(initial_tensor)
312320
elif self == ReduceOperation.PROD:
313321
scatter_tensor = torch.ones_like(initial_tensor)
314-
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
322+
elif self == ReduceOperation.AMAX:
315323
scatter_tensor = initial_tensor
324+
if not (include_self):
325+
scatter_tensor = torch.full_like(initial_tensor, min_ele)
326+
elif self == ReduceOperation.AMIN:
327+
scatter_tensor = initial_tensor
328+
if not (include_self):
329+
scatter_tensor = torch.full_like(initial_tensor, max_ele)
316330
else:
317331
# This case would not be encountered from torch itself
318332
print("Invalid Operation for Reduce op!!")
@@ -336,13 +350,39 @@ def scatter_reduce_decomposition(
336350
include_self: bool = True,
337351
) -> torch.Tensor:
338352
scatter_loop_tensor = input_tensor
353+
MAX_ELE = 0
354+
MIN_ELE = 0
355+
if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
356+
MAX_ELE = 2147483647
357+
MIN_ELE = -2147483648
358+
else:
359+
MAX_ELE = float("inf")
360+
MIN_ELE = float("-inf")
361+
if not (include_self):
362+
if reduce == "sum" or reduce == "mean":
363+
scatter_loop_tensor = torch.scatter(
364+
scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
365+
)
366+
if reduce == "prod":
367+
scatter_loop_tensor = torch.scatter(
368+
scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
369+
)
370+
if reduce == "amax":
371+
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
372+
scatter_loop_tensor = torch.scatter(
373+
scatter_loop_tensor, dim, index, src_red_tensor
374+
)
375+
if reduce == "amin":
376+
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
377+
scatter_loop_tensor = torch.scatter(
378+
scatter_loop_tensor, dim, index, src_red_tensor
379+
)
380+
339381
device_input_tensor = input_tensor.device
340382
# required for mean reduce operation
341383
scatter_count_tensor = torch.zeros_like(input_tensor)
342384
src_shape = list(src_tensor.shape)
343385
src_dim = src_shape[dim]
344-
if include_self == False:
345-
raise AssertionError("include_self False for scatter reduce not yet supported")
346386
for i in range(0, src_dim):
347387
src_slice = torch.select(src_tensor, dim, i)
348388
index_slice = torch.select(index, dim, i)
@@ -366,20 +406,53 @@ def scatter_reduce_decomposition(
366406
dim,
367407
index_slice,
368408
torch.ones_like(src_slice),
409+
MIN_ELE,
410+
MAX_ELE,
411+
include_self,
369412
)
370413
elif reduce == "amax":
371414
reduceOp = ReduceOperation.AMAX
372415
elif reduce == "amin":
373416
reduceOp = ReduceOperation.AMIN
374-
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
375-
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
417+
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
418+
scatter_loop_tensor,
419+
input_tensor,
420+
dim,
421+
index_slice,
422+
src_slice,
423+
MIN_ELE,
424+
MAX_ELE,
425+
include_self,
376426
)
377427
if reduce == "mean":
378428
scatter_loop_tensor = torch.div(
379429
scatter_loop_tensor,
380-
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
430+
(
431+
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
432+
if include_self
433+
else scatter_count_tensor
434+
),
381435
rounding_mode="trunc",
382436
)
437+
# for include_self cases for amax and amin additional processing is required
438+
# except for the max elements in amax, rest are -inf or INT_MIN
439+
# except for the min elements in amin, rest are +inf or INT_MAX
440+
if reduce == "amax" and not (include_self):
441+
# the relevant should be min, rest original
442+
return torch.max(
443+
scatter_loop_tensor,
444+
torch.scatter(
445+
input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
446+
),
447+
)
448+
if reduce == "amin" and not (include_self):
449+
# the relevant should be min, rest original
450+
return torch.min(
451+
scatter_loop_tensor,
452+
torch.scatter(
453+
input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
454+
),
455+
)
383456
return scatter_loop_tensor
384457

385458

0 commit comments

Comments
 (0)