@@ -92,7 +92,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):
92
92
93
93
94
94
class Operator (BenchmarkOperator ):
95
-
96
95
DEFAULT_METRICS = ["latency" , "accuracy" ]
97
96
DEFAULT_PRECISION = "fp32"
98
97
@@ -104,8 +103,8 @@ def __init__(
104
103
self , tb_args : argparse .Namespace , extra_args : Optional [List [str ]] = None
105
104
):
106
105
super ().__init__ (tb_args , extra_args )
107
- self .sizes = list ( range ( 2 , 12 , 4 )) + list (
108
- range (12 , 23 , 3 )
106
+ self .sizes = (
107
+ list ( range (2 , 12 , 4 )) + list ( range ( 12 , 23 , 3 ) )
109
108
) # bias towards larger sizes, which are more representative of real-world shapes
110
109
111
110
args = parse_op_args (self .extra_args )
@@ -130,28 +129,37 @@ def torch_jagged_mean_unbind_torch_mean(
130
129
def torch_jagged_mean_torch_nanmean (
131
130
self , x : torch .Tensor , B : int , M : int , seqlen : int , sparsity : float
132
131
):
133
- return lambda : torch .nanmean (
134
- torch .ops .aten ._jagged_to_padded_dense_forward (
135
- x .values (),
136
- [x .offsets ()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
137
- max_lengths = [seqlen ], # max length of ragged dimension
138
- padding_value = float ("nan" ),
139
- ),
140
- dim = 1 ,
132
+ return (
133
+ lambda : torch .nanmean (
134
+ torch .ops .aten ._jagged_to_padded_dense_forward (
135
+ x .values (),
136
+ [
137
+ x .offsets ()
138
+ ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
139
+ max_lengths = [seqlen ], # max length of ragged dimension
140
+ padding_value = float ("nan" ),
141
+ ),
142
+ dim = 1 ,
143
+ )
141
144
)
142
145
143
146
@register_benchmark ()
144
147
def torch_jagged_mean_torch_sum (
145
148
self , x : torch .Tensor , B : int , M : int , seqlen : int , sparsity : float
146
149
):
147
- return lambda : torch .sum (
148
- torch .ops .aten ._jagged_to_padded_dense_forward (
149
- x .values (),
150
- [x .offsets ()], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
151
- max_lengths = [seqlen ], # max length of ragged dimension
152
- ),
153
- dim = 1 ,
154
- ) / x .offsets ().diff ().unsqueeze (1 )
150
+ return (
151
+ lambda : torch .sum (
152
+ torch .ops .aten ._jagged_to_padded_dense_forward (
153
+ x .values (),
154
+ [
155
+ x .offsets ()
156
+ ], # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
157
+ max_lengths = [seqlen ], # max length of ragged dimension
158
+ ),
159
+ dim = 1 ,
160
+ )
161
+ / x .offsets ().diff ().unsqueeze (1 )
162
+ )
155
163
156
164
@register_benchmark ()
157
165
def triton_jagged_mean_simple_fused (
@@ -176,7 +184,9 @@ def torch_compile_nested_tensor_integration(
176
184
self , x : torch .Tensor , B : int , M : int , seqlen : int , sparsity : float
177
185
):
178
186
def _inner (x : torch .Tensor ): # mean along ragged dimension (dim == 1)
179
- return torch .mean (x , dim = x ._ragged_idx , keepdim = True ) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
187
+ return torch .mean (
188
+ x , dim = x ._ragged_idx , keepdim = True
189
+ ) # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
180
190
181
191
torch_compile_func = torch .compile (_inner )
182
192
return lambda : torch_compile_func (x )
0 commit comments