Skip to content

Commit 8900f67

Browse files
committed
scatter reduce lowering with include_self=False
1 parent b031ef0 commit 8900f67

File tree

2 files changed

+529
-44
lines changed

2 files changed

+529
-44
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

+81-13
Original file line numberDiff line numberDiff line change
@@ -303,21 +303,30 @@ def __new__(cls, description: Any, func: Any) -> Any:
303303
obj.func = func
304304
return obj
305305

306-
def reduce_operation_with_scatter(
306+
def reduce_operation_with_scatter_include_self(
307307
self,
308-
operation_lhs: Any,
309-
initial_tensor: torch.Tensor,
310-
dim: int,
311-
index_tensor: torch.Tensor,
312-
src_tensor: torch.Tensor,
313-
) -> Any:
308+
operation_lhs,
309+
initial_tensor,
310+
dim,
311+
index_tensor,
312+
src_tensor,
313+
min_ele=float("-inf"),
314+
max_ele=float("inf"),
315+
include_self=True,
316+
):
314317
scatter_tensor = None
315318
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
316319
scatter_tensor = torch.zeros_like(initial_tensor)
317320
elif self == ReduceOperation.PROD:
318321
scatter_tensor = torch.ones_like(initial_tensor)
319-
elif self == ReduceOperation.AMIN or self == ReduceOperation.AMAX:
322+
elif self == ReduceOperation.AMAX:
320323
scatter_tensor = initial_tensor
324+
if not (include_self):
325+
scatter_tensor = torch.full_like(initial_tensor, min_ele)
326+
elif self == ReduceOperation.AMIN:
327+
scatter_tensor = initial_tensor
328+
if not (include_self):
329+
scatter_tensor = torch.full_like(initial_tensor, max_ele)
321330
else:
322331
# This case would not be encountered from torch itself
323332
print("Invalid Operation for Reduce op!!")
@@ -341,13 +350,39 @@ def scatter_reduce_decomposition(
341350
include_self: bool = True,
342351
) -> torch.Tensor:
343352
scatter_loop_tensor = input_tensor
353+
MAX_ELE = 0
354+
MIN_ELE = 0
355+
if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
356+
MAX_ELE = 2147483647
357+
MIN_ELE = -2147483648
358+
else:
359+
MAX_ELE = float("inf")
360+
MIN_ELE = float("-inf")
361+
if not (include_self):
362+
if reduce == "sum" or reduce == "mean":
363+
scatter_loop_tensor = torch.scatter(
364+
scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
365+
)
366+
if reduce == "prod":
367+
scatter_loop_tensor = torch.scatter(
368+
scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
369+
)
370+
if reduce == "amax":
371+
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
372+
scatter_loop_tensor = torch.scatter(
373+
scatter_loop_tensor, dim, index, src_red_tensor
374+
)
375+
if reduce == "amin":
376+
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
377+
scatter_loop_tensor = torch.scatter(
378+
scatter_loop_tensor, dim, index, src_red_tensor
379+
)
380+
344381
device_input_tensor = input_tensor.device
345382
# required for mean reduce operation
346383
scatter_count_tensor = torch.zeros_like(input_tensor)
347384
src_shape = list(src_tensor.shape)
348385
src_dim = src_shape[dim]
349-
if not include_self:
350-
raise AssertionError("include_self False for scatter reduce not yet supported")
351386
for i in range(0, src_dim):
352387
src_slice = torch.select(src_tensor, dim, i)
353388
index_slice = torch.select(index, dim, i)
@@ -371,20 +406,53 @@ def scatter_reduce_decomposition(
371406
dim,
372407
index_slice,
373408
torch.ones_like(src_slice),
409+
MIN_ELE,
410+
MAX_ELE,
411+
include_self,
374412
)
375413
elif reduce == "amax":
376414
reduceOp = ReduceOperation.AMAX
377415
elif reduce == "amin":
378416
reduceOp = ReduceOperation.AMIN
379-
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter(
380-
scatter_loop_tensor, input_tensor, dim, index_slice, src_slice
417+
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
418+
scatter_loop_tensor,
419+
input_tensor,
420+
dim,
421+
index_slice,
422+
src_slice,
423+
MIN_ELE,
424+
MAX_ELE,
425+
include_self,
381426
)
382427
if reduce == "mean":
383428
scatter_loop_tensor = torch.div(
384429
scatter_loop_tensor,
385-
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)),
430+
(
431+
torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
432+
if include_self
433+
else scatter_count_tensor
434+
),
386435
rounding_mode="trunc",
387436
)
437+
# for include_self=False cases of amax and amin, additional processing is required
438+
# except for the max elements in amax, rest are -inf or INT_MIN
439+
# except for the min elements in amin, rest are +inf or INT_MAX
440+
if reduce == "amax" and not (include_self):
441+
# the scattered (indexed) positions should hold MIN_ELE, the rest keep the original input
442+
return torch.max(
443+
scatter_loop_tensor,
444+
torch.scatter(
445+
input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
446+
),
447+
)
448+
if reduce == "amin" and not (include_self):
449+
# the scattered (indexed) positions should hold MAX_ELE, the rest keep the original input
450+
return torch.min(
451+
scatter_loop_tensor,
452+
torch.scatter(
453+
input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
454+
),
455+
)
388456
return scatter_loop_tensor
389457

390458

0 commit comments

Comments
 (0)