@@ -303,16 +303,22 @@ def __new__(cls, description, func):
303
303
obj .func = func
304
304
return obj
305
305
306
- def reduce_operation_with_scatter (
307
- self , operation_lhs , initial_tensor , dim , index_tensor , src_tensor
306
+ def reduce_operation_with_scatter_include_self (
307
+ self , operation_lhs , initial_tensor , dim , index_tensor , src_tensor , min_ele = float ( '-inf' ), max_ele = float ( 'inf' ), include_self = True
308
308
):
309
309
scatter_tensor = None
310
310
if self == ReduceOperation .SUM or self == ReduceOperation .MEAN :
311
311
scatter_tensor = torch .zeros_like (initial_tensor )
312
312
elif self == ReduceOperation .PROD :
313
313
scatter_tensor = torch .ones_like (initial_tensor )
314
- elif self == ReduceOperation .AMIN or self == ReduceOperation . AMAX :
314
+ elif self == ReduceOperation .AMAX :
315
315
scatter_tensor = initial_tensor
316
+ if (not (include_self )):
317
+ scatter_tensor = torch .full_like (initial_tensor , min_ele )
318
+ elif self == ReduceOperation .AMIN :
319
+ scatter_tensor = initial_tensor
320
+ if (not (include_self )):
321
+ scatter_tensor = torch .full_like (initial_tensor , max_ele )
316
322
else :
317
323
# This case would not be encountered from torch itself
318
324
print ("Invalid Operation for Reduce op!!" )
@@ -336,13 +342,31 @@ def scatter_reduce_decomposition(
336
342
include_self : bool = True ,
337
343
) -> torch .Tensor :
338
344
scatter_loop_tensor = input_tensor
345
+ MAX_ELE = 0
346
+ MIN_ELE = 0
347
+ if (src_tensor .dtype == torch .int32 or input_tensor .dtype == torch .int32 ):
348
+ MAX_ELE = 2147483647
349
+ MIN_ELE = - 2147483648
350
+ else :
351
+ MAX_ELE = float ('inf' )
352
+ MIN_ELE = float ('-inf' )
353
+ if (not (include_self )):
354
+ if (reduce == "sum" or reduce == "mean" ):
355
+ scatter_loop_tensor = torch .scatter (scatter_loop_tensor , dim , index , torch .zeros_like (src_tensor ))
356
+ if (reduce == "prod" ):
357
+ scatter_loop_tensor = torch .scatter (scatter_loop_tensor , dim , index , torch .ones_like (src_tensor ))
358
+ if (reduce == "amax" ):
359
+ src_red_tensor = torch .full_like (src_tensor , MIN_ELE )
360
+ scatter_loop_tensor = torch .scatter (scatter_loop_tensor , dim , index , src_red_tensor )
361
+ if (reduce == "amin" ):
362
+ src_red_tensor = torch .full_like (src_tensor , MAX_ELE )
363
+ scatter_loop_tensor = torch .scatter (scatter_loop_tensor , dim , index , src_red_tensor )
364
+
339
365
device_input_tensor = input_tensor .device
340
366
# required for mean reduce operation
341
367
scatter_count_tensor = torch .zeros_like (input_tensor )
342
368
src_shape = list (src_tensor .shape )
343
369
src_dim = src_shape [dim ]
344
- if include_self == False :
345
- raise AssertionError ("include_self False for scatter reduce not yet supported" )
346
370
for i in range (0 , src_dim ):
347
371
src_slice = torch .select (src_tensor , dim , i )
348
372
index_slice = torch .select (index , dim , i )
@@ -366,20 +390,32 @@ def scatter_reduce_decomposition(
366
390
dim ,
367
391
index_slice ,
368
392
torch .ones_like (src_slice ),
393
+ MIN_ELE ,
394
+ MAX_ELE ,
395
+ include_self
369
396
)
370
397
elif reduce == "amax" :
371
398
reduceOp = ReduceOperation .AMAX
372
399
elif reduce == "amin" :
373
400
reduceOp = ReduceOperation .AMIN
374
- scatter_loop_tensor = reduceOp .reduce_operation_with_scatter (
375
- scatter_loop_tensor , input_tensor , dim , index_slice , src_slice
401
+ scatter_loop_tensor = reduceOp .reduce_operation_with_scatter_include_self (
402
+ scatter_loop_tensor , input_tensor , dim , index_slice , src_slice , MIN_ELE , MAX_ELE , include_self
376
403
)
377
404
if reduce == "mean" :
378
405
scatter_loop_tensor = torch .div (
379
406
scatter_loop_tensor ,
380
- torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor )),
407
+ torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor )) if include_self else scatter_count_tensor ,
381
408
rounding_mode = "trunc" ,
382
409
)
410
+ #for include_self cases for amax and amin additional processing is required
411
+ #except for the max elements in amax, rest are -inf or INT_MIN
412
+ #except for the min elements in amin, rest are +inf or INT_MAX
413
+ if reduce == "amax" and not (include_self ):
414
+ #the relevant should be min, rest original
415
+ return torch .max (scatter_loop_tensor , torch .scatter (input_tensor , dim , index , torch .full_like (src_tensor , MIN_ELE )))
416
+ if reduce == "amin" and not (include_self ):
417
+ #the relevant should be min, rest original
418
+ return torch .min (scatter_loop_tensor , torch .scatter (input_tensor , dim , index , torch .full_like (src_tensor , MAX_ELE )))
383
419
return scatter_loop_tensor
384
420
385
421
0 commit comments