@@ -316,7 +316,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
316316 const int threads = 1024 ;
317317 const int blocks = (batch + threads - 1 ) / threads;
318318
319- AT_DISPATCH_FLOATING_TYPES (y.scalar_type (), " isotonic_l2" , ([&] {
319+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (y.scalar_type (), " isotonic_l2" , ([&] {
320320 isotonic_l2_kernel<scalar_t ><<<blocks, threads>>> (
321321 y.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
322322 sol.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
@@ -342,7 +342,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
342342 const int threads = 1024 ;
343343 const int blocks = (batch + threads - 1 ) / threads;
344344
345- AT_DISPATCH_FLOATING_TYPES (y.scalar_type (), " isotonic_kl" , ([&] {
345+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (y.scalar_type (), " isotonic_kl" , ([&] {
346346 isotonic_kl_kernel<scalar_t ><<<blocks, threads>>> (
347347 y.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
348348 w.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
@@ -365,7 +365,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
365365 const int threads = 1024 ;
366366 const int blocks = (batch + threads - 1 ) / threads;
367367
368- AT_DISPATCH_FLOATING_TYPES (sol.scalar_type (), " isotonic_l2_backward" , ([&] {
368+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (sol.scalar_type (), " isotonic_l2_backward" , ([&] {
369369 isotonic_l2_backward_kernel<scalar_t ><<<blocks, threads>>> (
370370 s.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
371371 sol.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
@@ -387,7 +387,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
387387 const int threads = 1024 ;
388388 const int blocks = (batch + threads - 1 ) / threads;
389389
390- AT_DISPATCH_FLOATING_TYPES (sol.scalar_type (), " isotonic_kl_backward" , ([&] {
390+ AT_DISPATCH_FLOATING_TYPES_AND_HALF (sol.scalar_type (), " isotonic_kl_backward" , ([&] {
391391 isotonic_kl_backward_kernel<scalar_t ><<<blocks, threads>>> (
392392 s.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
393393 sol.packed_accessor32 <scalar_t , 2 , torch::RestrictPtrTraits>(),
0 commit comments