Skip to content

Commit d0c69e4

Browse files
lchen2331CuiYifeng
authored andcommitted
Fix bug
1 parent 64ca2fa commit d0c69e4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ATen/Dispatch.h>
1313
#include <ATen/core/TensorAccessor.h>
1414
#include <ATen/native/StridedRandomAccessor.h>
15+
#include <ATen/native/nested/NestedTensorUtils.h>
1516
#include <ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h>
1617
#include <comm/SYCLContext.h>
1718

@@ -953,13 +954,14 @@ at::Tensor _fbgemm_jagged_to_padded_dense_forward_kernel(
953954
values.scalar_type(),
954955
"jagged_to_padded_dense_xpu",
955956
[&] {
957+
scalar_t fill_value = at::native::_get_padding_value<scalar_t>(padding_value, values.is_floating_point());
956958
jagged_dense_elementwise_dense_template<scalar_t>(
957959
values_canonicalized,
958960
offsets.vec(),
959961
padded_values_view, // dummy not used in the lambda function
960962
padded_values_view,
961963
PaddingValueFuncutor<scalar_t>(),
962-
static_cast<scalar_t>(padding_value));
964+
fill_value);
963965
});
964966

965967
return padded_values;

0 commit comments

Comments
 (0)