-
Notifications
You must be signed in to change notification settings - Fork 356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
scatter reduce lowering with include_self=False #3153
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
9a8c124
to
645a725
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
645a725
to
0575e51
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2024-09-11 04:39:39.095424+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py 2024-09-11 04:39:59.142801+00:00
@@ -302,24 +302,32 @@
obj.description = description
obj.func = func
return obj
def reduce_operation_with_scatter_include_self(
- self, operation_lhs, initial_tensor, dim, index_tensor, src_tensor, min_ele = float('-inf'), max_ele = float('inf'), include_self=True
+ self,
+ operation_lhs,
+ initial_tensor,
+ dim,
+ index_tensor,
+ src_tensor,
+ min_ele=float("-inf"),
+ max_ele=float("inf"),
+ include_self=True,
):
scatter_tensor = None
if self == ReduceOperation.SUM or self == ReduceOperation.MEAN:
scatter_tensor = torch.zeros_like(initial_tensor)
elif self == ReduceOperation.PROD:
scatter_tensor = torch.ones_like(initial_tensor)
elif self == ReduceOperation.AMAX:
scatter_tensor = initial_tensor
- if(not(include_self)):
+ if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, min_ele)
elif self == ReduceOperation.AMIN:
scatter_tensor = initial_tensor
- if(not(include_self)):
+ if not (include_self):
scatter_tensor = torch.full_like(initial_tensor, max_ele)
else:
# This case would not be encountered from torch itself
print("Invalid Operation for Reduce op!!")
@@ -342,27 +350,35 @@
include_self: bool = True,
) -> torch.Tensor:
scatter_loop_tensor = input_tensor
MAX_ELE = 0
MIN_ELE = 0
- if(src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32):
+ if src_tensor.dtype == torch.int32 or input_tensor.dtype == torch.int32:
MAX_ELE = 2147483647
MIN_ELE = -2147483648
else:
- MAX_ELE = float('inf')
- MIN_ELE = float('-inf')
- if(not(include_self)):
- if (reduce == "sum" or reduce == "mean"):
- scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor))
- if (reduce == "prod"):
- scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, torch.ones_like(src_tensor))
- if (reduce == "amax"):
+ MAX_ELE = float("inf")
+ MIN_ELE = float("-inf")
+ if not (include_self):
+ if reduce == "sum" or reduce == "mean":
+ scatter_loop_tensor = torch.scatter(
+ scatter_loop_tensor, dim, index, torch.zeros_like(src_tensor)
+ )
+ if reduce == "prod":
+ scatter_loop_tensor = torch.scatter(
+ scatter_loop_tensor, dim, index, torch.ones_like(src_tensor)
+ )
+ if reduce == "amax":
src_red_tensor = torch.full_like(src_tensor, MIN_ELE)
- scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
- if (reduce == "amin"):
+ scatter_loop_tensor = torch.scatter(
+ scatter_loop_tensor, dim, index, src_red_tensor
+ )
+ if reduce == "amin":
src_red_tensor = torch.full_like(src_tensor, MAX_ELE)
- scatter_loop_tensor = torch.scatter(scatter_loop_tensor, dim, index, src_red_tensor)
+ scatter_loop_tensor = torch.scatter(
+ scatter_loop_tensor, dim, index, src_red_tensor
+ )
device_input_tensor = input_tensor.device
# required for mean reduce operation
scatter_count_tensor = torch.zeros_like(input_tensor)
src_shape = list(src_tensor.shape)
@@ -390,34 +406,55 @@
dim,
index_slice,
torch.ones_like(src_slice),
MIN_ELE,
MAX_ELE,
- include_self
+ include_self,
)
elif reduce == "amax":
reduceOp = ReduceOperation.AMAX
elif reduce == "amin":
reduceOp = ReduceOperation.AMIN
scatter_loop_tensor = reduceOp.reduce_operation_with_scatter_include_self(
- scatter_loop_tensor, input_tensor, dim, index_slice, src_slice, MIN_ELE, MAX_ELE, include_self
+ scatter_loop_tensor,
+ input_tensor,
+ dim,
+ index_slice,
+ src_slice,
+ MIN_ELE,
+ MAX_ELE,
+ include_self,
)
if reduce == "mean":
scatter_loop_tensor = torch.div(
scatter_loop_tensor,
- torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor)) if include_self else scatter_count_tensor,
+ (
+ torch.add(scatter_count_tensor, torch.ones_like(scatter_count_tensor))
+ if include_self
+ else scatter_count_tensor
+ ),
rounding_mode="trunc",
)
- #for include_self cases for amax and amin additional processing is required
- #except for the max elements in amax, rest are -inf or INT_MIN
- #except for the min elements in amin, rest are +inf or INT_MAX
- if reduce == "amax" and not(include_self):
- #the relevant should be min, rest original
- return torch.max(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)))
- if reduce == "amin" and not(include_self):
- #the relevant should be min, rest original
- return torch.min(scatter_loop_tensor, torch.scatter(input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)))
+ # for include_self cases for amax and amin additional processing is required
+ # except for the max elements in amax, rest are -inf or INT_MIN
+ # except for the min elements in amin, rest are +inf or INT_MAX
+ if reduce == "amax" and not (include_self):
+ # the relevant should be min, rest original
+ return torch.max(
+ scatter_loop_tensor,
+ torch.scatter(
+ input_tensor, dim, index, torch.full_like(src_tensor, MIN_ELE)
+ ),
+ )
+ if reduce == "amin" and not (include_self):
+ # the relevant should be min, rest original
+ return torch.min(
+ scatter_loop_tensor,
+ torch.scatter(
+ input_tensor, dim, index, torch.full_like(src_tensor, MAX_ELE)
+ ),
+ )
return scatter_loop_tensor
def get_decompositions(
enable_experimental_decompositions: bool = False,
0575e51
to
47fec01
Compare
47fec01
to
8900f67
Compare
This is for scatter_reduce decomposition where include_self=False