@@ -303,21 +303,30 @@ def __new__(cls, description: Any, func: Any) -> Any:
303
303
obj .func = func
304
304
return obj
305
305
306
- def reduce_operation_with_scatter (
306
+ def reduce_operation_with_scatter_include_self (
307
307
self ,
308
- operation_lhs : Any ,
309
- initial_tensor : torch .Tensor ,
310
- dim : int ,
311
- index_tensor : torch .Tensor ,
312
- src_tensor : torch .Tensor ,
313
- ) -> Any :
308
+ operation_lhs ,
309
+ initial_tensor ,
310
+ dim ,
311
+ index_tensor ,
312
+ src_tensor ,
313
+ min_ele = float ("-inf" ),
314
+ max_ele = float ("inf" ),
315
+ include_self = True ,
316
+ ):
314
317
scatter_tensor = None
315
318
if self == ReduceOperation .SUM or self == ReduceOperation .MEAN :
316
319
scatter_tensor = torch .zeros_like (initial_tensor )
317
320
elif self == ReduceOperation .PROD :
318
321
scatter_tensor = torch .ones_like (initial_tensor )
319
- elif self == ReduceOperation .AMIN or self == ReduceOperation . AMAX :
322
+ elif self == ReduceOperation .AMAX :
320
323
scatter_tensor = initial_tensor
324
+ if not (include_self ):
325
+ scatter_tensor = torch .full_like (initial_tensor , min_ele )
326
+ elif self == ReduceOperation .AMIN :
327
+ scatter_tensor = initial_tensor
328
+ if not (include_self ):
329
+ scatter_tensor = torch .full_like (initial_tensor , max_ele )
321
330
else :
322
331
# This case would not be encountered from torch itself
323
332
print ("Invalid Operation for Reduce op!!" )
@@ -341,13 +350,39 @@ def scatter_reduce_decomposition(
341
350
include_self : bool = True ,
342
351
) -> torch .Tensor :
343
352
scatter_loop_tensor = input_tensor
353
+ MAX_ELE = 0
354
+ MIN_ELE = 0
355
+ if src_tensor .dtype == torch .int32 or input_tensor .dtype == torch .int32 :
356
+ MAX_ELE = 2147483647
357
+ MIN_ELE = - 2147483648
358
+ else :
359
+ MAX_ELE = float ("inf" )
360
+ MIN_ELE = float ("-inf" )
361
+ if not (include_self ):
362
+ if reduce == "sum" or reduce == "mean" :
363
+ scatter_loop_tensor = torch .scatter (
364
+ scatter_loop_tensor , dim , index , torch .zeros_like (src_tensor )
365
+ )
366
+ if reduce == "prod" :
367
+ scatter_loop_tensor = torch .scatter (
368
+ scatter_loop_tensor , dim , index , torch .ones_like (src_tensor )
369
+ )
370
+ if reduce == "amax" :
371
+ src_red_tensor = torch .full_like (src_tensor , MIN_ELE )
372
+ scatter_loop_tensor = torch .scatter (
373
+ scatter_loop_tensor , dim , index , src_red_tensor
374
+ )
375
+ if reduce == "amin" :
376
+ src_red_tensor = torch .full_like (src_tensor , MAX_ELE )
377
+ scatter_loop_tensor = torch .scatter (
378
+ scatter_loop_tensor , dim , index , src_red_tensor
379
+ )
380
+
344
381
device_input_tensor = input_tensor .device
345
382
# required for mean reduce operation
346
383
scatter_count_tensor = torch .zeros_like (input_tensor )
347
384
src_shape = list (src_tensor .shape )
348
385
src_dim = src_shape [dim ]
349
- if not include_self :
350
- raise AssertionError ("include_self False for scatter reduce not yet supported" )
351
386
for i in range (0 , src_dim ):
352
387
src_slice = torch .select (src_tensor , dim , i )
353
388
index_slice = torch .select (index , dim , i )
@@ -371,20 +406,53 @@ def scatter_reduce_decomposition(
371
406
dim ,
372
407
index_slice ,
373
408
torch .ones_like (src_slice ),
409
+ MIN_ELE ,
410
+ MAX_ELE ,
411
+ include_self ,
374
412
)
375
413
elif reduce == "amax" :
376
414
reduceOp = ReduceOperation .AMAX
377
415
elif reduce == "amin" :
378
416
reduceOp = ReduceOperation .AMIN
379
- scatter_loop_tensor = reduceOp .reduce_operation_with_scatter (
380
- scatter_loop_tensor , input_tensor , dim , index_slice , src_slice
417
+ scatter_loop_tensor = reduceOp .reduce_operation_with_scatter_include_self (
418
+ scatter_loop_tensor ,
419
+ input_tensor ,
420
+ dim ,
421
+ index_slice ,
422
+ src_slice ,
423
+ MIN_ELE ,
424
+ MAX_ELE ,
425
+ include_self ,
381
426
)
382
427
if reduce == "mean" :
383
428
scatter_loop_tensor = torch .div (
384
429
scatter_loop_tensor ,
385
- torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor )),
430
+ (
431
+ torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor ))
432
+ if include_self
433
+ else scatter_count_tensor
434
+ ),
386
435
rounding_mode = "trunc" ,
387
436
)
437
+ # for include_self=False cases for amax and amin additional processing is required
438
+ # except for the max elements in amax, rest are -inf or INT_MIN
439
+ # except for the min elements in amin, rest are +inf or INT_MAX
440
+ if reduce == "amax" and not (include_self ):
441
+ # the relevant (indexed) entries should take the max-reduced values, rest keep the original
442
+ return torch .max (
443
+ scatter_loop_tensor ,
444
+ torch .scatter (
445
+ input_tensor , dim , index , torch .full_like (src_tensor , MIN_ELE )
446
+ ),
447
+ )
448
+ if reduce == "amin" and not (include_self ):
449
+ # the relevant should be min, rest original
450
+ return torch .min (
451
+ scatter_loop_tensor ,
452
+ torch .scatter (
453
+ input_tensor , dim , index , torch .full_like (src_tensor , MAX_ELE )
454
+ ),
455
+ )
388
456
return scatter_loop_tensor
389
457
390
458
0 commit comments