Skip to content

Commit 27181b6

Browse files
authored
Add weight dtype validation to fix CI issue (#2770)
# Motivation pytorch/pytorch#172018 adds weight dtype validation to CUDA and MPS, then introduces CI issue that we have skipped in pytorch/pytorch#173335. This PR aims to add weight dtype to XPU so that it resolves pytorch/pytorch#173335
1 parent 45e4ded commit 27181b6

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/ATen/native/xpu/sycl/LossNLLKernel.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,15 @@ void nll_loss_forward_kernel(
406406

407407
auto weight_ = weight.defined() ? weight.contiguous() : weight;
408408

409+
if (weight_.defined()) {
410+
TORCH_CHECK(
411+
input.scalar_type() == weight_.scalar_type(),
412+
"expected scalar type ",
413+
input.scalar_type(),
414+
" but found ",
415+
weight_.scalar_type());
416+
}
417+
409418
if (reduction == at::Reduction::None && n_dims == 2) {
410419
at::native::resize_output(output, {batch_size});
411420
total_weight.zero_();

0 commit comments

Comments
 (0)