Commit 352be40

FindHao authored and facebook-github-bot committed
Add internal jagged_dense_dense_sum op
Summary: This diff adds the internal jagged_dense_dense_elementwise_add_jagged_output_forward operator. It uses the internal input shapes, which is the first step toward utilizing durin data in tritonbench.

Scuba table for this op: https://fburl.com/scuba/gpu_kernel_stats/huxds0z5

Reviewed By: xuzhao9

Differential Revision: D71073294

fbshipit-source-id: 3ea4a9eca1d4cb4c8c43664b9d2d505b202def14
1 parent 73c9b75 commit 352be40
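The kernel and the production input metadata live in internal fb modules, so only the tritonbench wiring appears in the diff below. As rough orientation, here is a minimal pure-PyTorch sketch of the semantics the operator name suggests (jagged values plus the matching rows of two dense tensors, added elementwise, with a jagged output). The function name, the offsets layout, and the shapes are assumptions for illustration, not the internal implementation.

import torch

def jagged_dense_dense_add_reference(jagged_values, offsets, dense_0, dense_1):
    # Assumed layout, for illustration only: rows offsets[b]:offsets[b+1] of
    # jagged_values belong to batch b and line up with the first
    # (offsets[b+1] - offsets[b]) rows of dense_0[b] and dense_1[b].
    out = jagged_values.clone()
    for b in range(offsets.numel() - 1):
        start, end = offsets[b].item(), offsets[b + 1].item()
        out[start:end] += dense_0[b, : end - start] + dense_1[b, : end - start]
    return out

# Hypothetical shapes: B=2, max_seq_len=3, M=4, sequence lengths [2, 3].
offsets = torch.tensor([0, 2, 5])
jagged_values = torch.randn(5, 4)
dense_0 = torch.randn(2, 3, 4)
dense_1 = torch.randn(2, 3, 4)
out = jagged_dense_dense_add_reference(jagged_values, offsets, dense_0, dense_1)  # shape (5, 4)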

File tree

6 files changed: +144 −26 lines

tritonbench/data/__init__.py

Whitespace-only changes.

tritonbench/operators/jagged_softmax/operator.py

+43 −13

@@ -18,6 +18,7 @@
     get_styles,
     get_tensor_bytes_limit,
     GIGABYTES_PER_BYTE,
+    jagged_to_nested_tensor,
     RANDOM_CHOICE_MARGIN,
     RELATIVE_TOLERANCE,
 )
@@ -178,20 +179,49 @@ def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
+        if not self.prod_shapes:
+            B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+
+            for nt, B, M, max_seqlen, sparsity in generate_random_nested_tensors(
+                B_vals,
+                M_vals,
+                seqlen_vals,
+                sparsity_vals,
+                device=self.device,
+                dtype=self.dtype,
+                TENSOR_BYTES_LIMIT=self.tensor_bytes_limit,
+                RANDOM_CHOICE_MARGIN=RANDOM_CHOICE_MARGIN,
+            ):
+                yield (nt, B, M, max_seqlen, sparsity)
+        else:
+            from tritonbench.data.fb.jagged_dense_dense import (
+                generate_input_vals_fb,
+                get_prod_input_metadata,
+            )
 
-        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
-
-        for nt, B, M, max_seqlen, sparsity in generate_random_nested_tensors(
-            B_vals,
-            M_vals,
-            seqlen_vals,
-            sparsity_vals,
-            device=self.device,
-            dtype=self.dtype,
-            TENSOR_BYTES_LIMIT=self.tensor_bytes_limit,
-            RANDOM_CHOICE_MARGIN=RANDOM_CHOICE_MARGIN,
-        ):
-            yield (nt, B, M, max_seqlen, sparsity)
+            input_data = get_prod_input_metadata()
+            for (
+                jagged_values_shape,
+                dense_0_shape,
+                dense_1_shape,
+                jagged_values_dtype,
+                dense_0_dtype,
+                dense_1_dtype,
+            ) in input_data:
+                jagged_values, jagged_offsets, _, _ = generate_input_vals_fb(
+                    jagged_values_shape,
+                    dense_0_shape=dense_0_shape,
+                    dense_1_shape=dense_1_shape,
+                    jagged_values_dtype=jagged_values_dtype,
+                    dense_0_dtype=dense_0_dtype,
+                    dense_1_dtype=dense_1_dtype,
+                )
+                nested_tensor = jagged_to_nested_tensor(jagged_values, jagged_offsets)
+                # Yueming: in the future, if we integrate more input shapes for other jagged operators,
+                # the dense_0 may be None. In that case, we should use another way to obtain the batch size
+                # and max seq len.
+                batch_size, max_seq_len, _ = dense_0_shape
+                yield (nested_tensor, batch_size, 1, max_seq_len, 0.0)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
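For context on the production branch above: get_prod_input_metadata and generate_input_vals_fb live in the internal tritonbench.data.fb package, so the actual records are not shown here. A hypothetical metadata entry, following the tuple layout the loop destructures, might look like the sketch below (all numbers and dtypes are made up):

import torch

# Hypothetical entry of get_prod_input_metadata(); values are illustrative only.
entry = (
    (12345, 256),    # jagged_values_shape: (total_rows, M)
    (64, 200, 256),  # dense_0_shape: (batch_size, max_seq_len, M)
    (64, 200, 256),  # dense_1_shape
    torch.bfloat16,  # jagged_values_dtype
    torch.bfloat16,  # dense_0_dtype
    torch.bfloat16,  # dense_1_dtype
)
# For such an entry, get_input_iter would yield
# (nested_tensor, batch_size=64, 1, max_seq_len=200, 0.0).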

tritonbench/operators/jagged_sum/operator.py

+43 −13

@@ -18,6 +18,7 @@
     get_styles,
     get_tensor_bytes_limit,
     GIGABYTES_PER_BYTE,
+    jagged_to_nested_tensor,
     RANDOM_CHOICE_MARGIN,
     RELATIVE_TOLERANCE,
 )
@@ -192,20 +193,49 @@ def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
+        if not self.prod_shapes:
+            B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+
+            for nt, B, M, max_seqlen, sparsity in generate_random_nested_tensors(
+                B_vals,
+                M_vals,
+                seqlen_vals,
+                sparsity_vals,
+                device=self.device,
+                dtype=self.dtype,
+                TENSOR_BYTES_LIMIT=self.tensor_bytes_limit,
+                RANDOM_CHOICE_MARGIN=RANDOM_CHOICE_MARGIN,
+            ):
+                yield (nt, B, M, max_seqlen, sparsity)
+        else:
+            from tritonbench.data.fb.jagged_dense_dense import (
+                generate_input_vals_fb,
+                get_prod_input_metadata,
+            )
 
-        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
-
-        for nt, B, M, max_seqlen, sparsity in generate_random_nested_tensors(
-            B_vals,
-            M_vals,
-            seqlen_vals,
-            sparsity_vals,
-            device=self.device,
-            dtype=self.dtype,
-            TENSOR_BYTES_LIMIT=self.tensor_bytes_limit,
-            RANDOM_CHOICE_MARGIN=RANDOM_CHOICE_MARGIN,
-        ):
-            yield (nt, B, M, max_seqlen, sparsity)
+            input_data = get_prod_input_metadata()
+            for (
+                jagged_values_shape,
+                dense_0_shape,
+                dense_1_shape,
+                jagged_values_dtype,
+                dense_0_dtype,
+                dense_1_dtype,
+            ) in input_data:
+                jagged_values, jagged_offsets, _, _ = generate_input_vals_fb(
+                    jagged_values_shape,
+                    dense_0_shape=dense_0_shape,
+                    dense_1_shape=dense_1_shape,
+                    jagged_values_dtype=jagged_values_dtype,
+                    dense_0_dtype=dense_0_dtype,
+                    dense_1_dtype=dense_1_dtype,
+                )
+                nested_tensor = jagged_to_nested_tensor(jagged_values, jagged_offsets)
+                # Yueming: in the future, if we integrate more input shapes for other jagged operators,
+                # the dense_0 may be None. In that case, we should use another way to obtain the batch size
+                # and max seq len.
+                batch_size, max_seq_len, _ = dense_0_shape
+                yield (nested_tensor, batch_size, 1, max_seq_len, 0.0)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()

tritonbench/utils/jagged_utils.py

+51

@@ -133,6 +133,57 @@ def generate_input_vals(B, M, max_seqlen, sparsity, sizes):
     return B_vals, M_vals, seqlen_vals, sparsity_vals
 
 
+def jagged_to_nested_tensor(values: torch.Tensor, offsets: list[torch.Tensor]):
+    """
+    Convert jagged tensor (values + offsets) to torch.nested.nested_tensor
+
+    Args:
+        values: Compressed values tensor
+        offsets: List of offset tensors, indicating the starting position of each sequence
+
+    Returns:
+        Tensor in torch.nested.nested_tensor format
+    """
+    # Calculate the length of each sequence
+    lengths = []
+    for i in range(len(offsets)):
+        if i == 0:
+            # For the first layer, calculate the length of each batch
+            batch_size = offsets[i].size(0) - 1
+            batch_lengths = []
+            for b in range(batch_size):
+                batch_lengths.append(offsets[i][b + 1] - offsets[i][b])
+            lengths.append(batch_lengths)
+        else:
+            # For deeper levels of nesting
+            prev_lengths = lengths[i - 1]
+            curr_lengths = []
+            offset_idx = 0
+            for prev_len in prev_lengths:
+                seq_lengths = []
+                for _ in range(prev_len):
+                    seq_lengths.append(
+                        offsets[i][offset_idx + 1] - offsets[i][offset_idx]
+                    )
+                    offset_idx += 1
+                curr_lengths.append(seq_lengths)
+            lengths.append(curr_lengths)
+
+    # Build tensor list based on lengths and values
+    tensor_list = []
+    start_idx = 0
+    for b in range(len(lengths[0])):
+        length = lengths[0][b]
+        end_idx = start_idx + length
+        tensor_list.append(values[start_idx:end_idx])
+        start_idx = end_idx
+
+    # Create nested tensor
+    return torch.nested.nested_tensor(
+        tensor_list, layout=torch.jagged, device=values.device, dtype=values.dtype
+    )
+
+
 def get_size_in_bytes(shape, dtype) -> int:
     num_elements = math.prod(shape)
     element_size = dtype.itemsize
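A quick usage sketch of the new helper (the values and offsets below are made up; a single offsets tensor wrapped in a list matches how the operators above call it):

import torch

from tritonbench.utils.jagged_utils import jagged_to_nested_tensor

# Two sequences of lengths 2 and 3, feature width 4 (made-up data).
values = torch.randn(5, 4)
offsets = [torch.tensor([0, 2, 5])]

nt = jagged_to_nested_tensor(values, offsets)
# nt is a jagged-layout nested tensor holding values[0:2] and values[2:5].
print([t.shape for t in nt.unbind()])  # [torch.Size([2, 4]), torch.Size([3, 4])]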

tritonbench/utils/parser.py

+6

@@ -193,6 +193,12 @@ def get_parser(args=None):
         help="The directory to store input or output.",
     )
 
+    parser.add_argument(
+        "--prod-shapes",
+        action="store_true",
+        help="Only run with pre-defined production shapes.",
+    )
+
     if IS_FBCODE:
         parser.add_argument("--log-scuba", action="store_true", help="Log to scuba.")
         parser.add_argument(
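With the flag in place, the production-shape path is opt-in from the command line, e.g. something along the lines of python run.py --op jagged_sum --prod-shapes (entry point assumed). A minimal sketch of how the flag is parsed, assuming get_parser() returns the top-level argparse parser as the signature above suggests:

from tritonbench.utils.parser import get_parser

# --op is tritonbench's usual operator selector; --prod-shapes is the new flag.
args, _ = get_parser().parse_known_args(["--op", "jagged_sum", "--prod-shapes"])
print(args.prod_shapes)  # True -> exposed to operators as self.prod_shapes (see triton_op.py below)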

tritonbench/utils/triton_op.py

+1

@@ -724,6 +724,7 @@ def __init__(
         self._skip = _split_params_by_comma(self.tb_args.skip)
         self._input_id = self.tb_args.input_id
         self._num_inputs = self.tb_args.num_inputs
+        self.prod_shapes = self.tb_args.prod_shapes
 
     # Run the post initialization
     def __post__init__(self):
