@@ -303,16 +303,30 @@ def __new__(cls, description, func):
303
303
obj .func = func
304
304
return obj
305
305
306
- def reduce_operation_with_scatter (
307
- self , operation_lhs , initial_tensor , dim , index_tensor , src_tensor
306
+ def reduce_operation_with_scatter_include_self (
307
+ self ,
308
+ operation_lhs ,
309
+ initial_tensor ,
310
+ dim ,
311
+ index_tensor ,
312
+ src_tensor ,
313
+ min_ele = float ("-inf" ),
314
+ max_ele = float ("inf" ),
315
+ include_self = True ,
308
316
):
309
317
scatter_tensor = None
310
318
if self == ReduceOperation .SUM or self == ReduceOperation .MEAN :
311
319
scatter_tensor = torch .zeros_like (initial_tensor )
312
320
elif self == ReduceOperation .PROD :
313
321
scatter_tensor = torch .ones_like (initial_tensor )
314
- elif self == ReduceOperation .AMIN or self == ReduceOperation . AMAX :
322
+ elif self == ReduceOperation .AMAX :
315
323
scatter_tensor = initial_tensor
324
+ if not (include_self ):
325
+ scatter_tensor = torch .full_like (initial_tensor , min_ele )
326
+ elif self == ReduceOperation .AMIN :
327
+ scatter_tensor = initial_tensor
328
+ if not (include_self ):
329
+ scatter_tensor = torch .full_like (initial_tensor , max_ele )
316
330
else :
317
331
# This case would not be encountered from torch itself
318
332
print ("Invalid Operation for Reduce op!!" )
@@ -336,13 +350,39 @@ def scatter_reduce_decomposition(
336
350
include_self : bool = True ,
337
351
) -> torch .Tensor :
338
352
scatter_loop_tensor = input_tensor
353
+ MAX_ELE = 0
354
+ MIN_ELE = 0
355
+ if src_tensor .dtype == torch .int32 or input_tensor .dtype == torch .int32 :
356
+ MAX_ELE = 2147483647
357
+ MIN_ELE = - 2147483648
358
+ else :
359
+ MAX_ELE = float ("inf" )
360
+ MIN_ELE = float ("-inf" )
361
+ if not (include_self ):
362
+ if reduce == "sum" or reduce == "mean" :
363
+ scatter_loop_tensor = torch .scatter (
364
+ scatter_loop_tensor , dim , index , torch .zeros_like (src_tensor )
365
+ )
366
+ if reduce == "prod" :
367
+ scatter_loop_tensor = torch .scatter (
368
+ scatter_loop_tensor , dim , index , torch .ones_like (src_tensor )
369
+ )
370
+ if reduce == "amax" :
371
+ src_red_tensor = torch .full_like (src_tensor , MIN_ELE )
372
+ scatter_loop_tensor = torch .scatter (
373
+ scatter_loop_tensor , dim , index , src_red_tensor
374
+ )
375
+ if reduce == "amin" :
376
+ src_red_tensor = torch .full_like (src_tensor , MAX_ELE )
377
+ scatter_loop_tensor = torch .scatter (
378
+ scatter_loop_tensor , dim , index , src_red_tensor
379
+ )
380
+
339
381
device_input_tensor = input_tensor .device
340
382
# required for mean reduce operation
341
383
scatter_count_tensor = torch .zeros_like (input_tensor )
342
384
src_shape = list (src_tensor .shape )
343
385
src_dim = src_shape [dim ]
344
- if include_self == False :
345
- raise AssertionError ("include_self False for scatter reduce not yet supported" )
346
386
for i in range (0 , src_dim ):
347
387
src_slice = torch .select (src_tensor , dim , i )
348
388
index_slice = torch .select (index , dim , i )
@@ -366,20 +406,53 @@ def scatter_reduce_decomposition(
366
406
dim ,
367
407
index_slice ,
368
408
torch .ones_like (src_slice ),
409
+ MIN_ELE ,
410
+ MAX_ELE ,
411
+ include_self ,
369
412
)
370
413
elif reduce == "amax" :
371
414
reduceOp = ReduceOperation .AMAX
372
415
elif reduce == "amin" :
373
416
reduceOp = ReduceOperation .AMIN
374
- scatter_loop_tensor = reduceOp .reduce_operation_with_scatter (
375
- scatter_loop_tensor , input_tensor , dim , index_slice , src_slice
417
+ scatter_loop_tensor = reduceOp .reduce_operation_with_scatter_include_self (
418
+ scatter_loop_tensor ,
419
+ input_tensor ,
420
+ dim ,
421
+ index_slice ,
422
+ src_slice ,
423
+ MIN_ELE ,
424
+ MAX_ELE ,
425
+ include_self ,
376
426
)
377
427
if reduce == "mean" :
378
428
scatter_loop_tensor = torch .div (
379
429
scatter_loop_tensor ,
380
- torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor )),
430
+ (
431
+ torch .add (scatter_count_tensor , torch .ones_like (scatter_count_tensor ))
432
+ if include_self
433
+ else scatter_count_tensor
434
+ ),
381
435
rounding_mode = "trunc" ,
382
436
)
437
+ # for include_self cases for amax and amin additional processing is required
438
+ # except for the max elements in amax, rest are -inf or INT_MIN
439
+ # except for the min elements in amin, rest are +inf or INT_MAX
440
+ if reduce == "amax" and not (include_self ):
441
+ # the relevant should be min, rest original
442
+ return torch .max (
443
+ scatter_loop_tensor ,
444
+ torch .scatter (
445
+ input_tensor , dim , index , torch .full_like (src_tensor , MIN_ELE )
446
+ ),
447
+ )
448
+ if reduce == "amin" and not (include_self ):
449
+ # the relevant should be min, rest original
450
+ return torch .min (
451
+ scatter_loop_tensor ,
452
+ torch .scatter (
453
+ input_tensor , dim , index , torch .full_like (src_tensor , MAX_ELE )
454
+ ),
455
+ )
383
456
return scatter_loop_tensor
384
457
385
458
0 commit comments