11/* clang-format off */
22/*
3- * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
44 * SPDX-License-Identifier: Apache-2.0
55 */
66/* clang-format on */
@@ -44,6 +44,48 @@ namespace cuopt::linear_programming::dual_simplex {
4444
4545auto constexpr use_gpu = true ;
4646
47+ // non-template wrappers to work around clang compiler bug
// Element-wise product of two float arrays: out[i] = a[i] * b[i] for the first
// `size` elements. The operation is handed to cub::DeviceTransform and issued
// on `stream`; this wrapper itself performs no synchronization.
// Presumably a, b and out are device-accessible buffers of at least `size`
// elements -- confirm against callers.
// NOTE(review): the value returned by Transform is not inspected here.
[[maybe_unused]] static void pairwise_multiply(
  float* a, float* b, float* out, int size, rmm::cuda_stream_view stream)
{
  // Bundle both input sequences so Transform walks them in lock-step.
  auto operands = cuda::std::make_tuple(a, b);
  cub::DeviceTransform::Transform(operands, out, size, cuda::std::multiplies<>{}, stream);
}
54+
// Double-precision overload: out[i] = a[i] * b[i] for the first `size`
// elements, dispatched through cub::DeviceTransform on `stream`. No
// synchronization is performed by this wrapper.
// Presumably a, b and out are device-accessible buffers of at least `size`
// elements -- confirm against callers.
// NOTE(review): the value returned by Transform is not inspected here.
[[maybe_unused]] static void pairwise_multiply(
  double* a, double* b, double* out, int size, rmm::cuda_stream_view stream)
{
  // Bundle both input sequences so Transform walks them in lock-step.
  auto operands = cuda::std::make_tuple(a, b);
  cub::DeviceTransform::Transform(operands, out, size, cuda::std::multiplies<>{}, stream);
}
61+
// Scaled vector sum in single precision:
//   out[i] = alpha * x[i] + beta * y[i]   for the first `size` elements.
// The work is dispatched through cub::DeviceTransform on `stream`; this wrapper
// performs no synchronization.
// Presumably x, y and out are device-accessible buffers of at least `size`
// elements -- confirm against callers.
// NOTE(review): the value returned by Transform is not inspected here.
[[maybe_unused]] static void axpy(
  float alpha, float* x, float beta, float* y, float* out, int size, rmm::cuda_stream_view stream)
{
  // The scalars are captured by value so the functor is self-contained when
  // it is copied for device execution.
  auto combine = [alpha, beta] __host__ __device__(float a, float b) {
    return alpha * a + beta * b;
  };
  auto operands = cuda::std::make_tuple(x, y);
  cub::DeviceTransform::Transform(operands, out, size, combine, stream);
}
72+
// Scaled vector sum in double precision:
//   out[i] = alpha * x[i] + beta * y[i]   for the first `size` elements.
// The work is dispatched through cub::DeviceTransform on `stream`; this wrapper
// performs no synchronization.
// Presumably x, y and out are device-accessible buffers of at least `size`
// elements -- confirm against callers.
// NOTE(review): the value returned by Transform is not inspected here.
[[maybe_unused]] static void axpy(double alpha,
                                  double* x,
                                  double beta,
                                  double* y,
                                  double* out,
                                  int size,
                                  rmm::cuda_stream_view stream)
{
  // The scalars are captured by value so the functor is self-contained when
  // it is copied for device execution.
  auto combine = [alpha, beta] __host__ __device__(double a, double b) {
    return alpha * a + beta * b;
  };
  auto operands = cuda::std::make_tuple(x, y);
  cub::DeviceTransform::Transform(operands, out, size, combine, stream);
}
88+
4789template <typename i_t , typename f_t >
4890class iteration_data_t {
4991 public:
@@ -1404,12 +1446,7 @@ class iteration_data_t {
14041446
14051447 // diag.pairwise_product(x1, r1);
14061448 // r1 <- D * x_1
1407- thrust::transform (handle_ptr->get_thrust_policy (),
1408- d_x1.data (),
1409- d_x1.data () + n,
1410- d_diag_.data (),
1411- d_r1.data (),
1412- thrust::multiplies<f_t >());
1449+ pairwise_multiply (d_x1.data (), d_diag_.data (), d_r1.data (), n, stream_view_);
14131450
14141451 // r1 <- Q x1 + D x1
14151452 if (Q.n > 0 ) {
@@ -1419,12 +1456,7 @@ class iteration_data_t {
14191456
14201457 // y1 <- - alpha * r1 + beta * y1
14211458 // y1.axpy(-alpha, r1, beta);
1422- thrust::transform (handle_ptr->get_thrust_policy (),
1423- d_r1.data (),
1424- d_r1.data () + n,
1425- d_y1.data (),
1426- d_y1.data (),
1427- axpy_op<f_t >{-alpha, beta});
1459+ axpy (-alpha, d_r1.data (), beta, d_y1.data (), d_y1.data (), n, stream_view_);
14281460
14291461 // matrix_transpose_vector_multiply(A, alpha, x2, 1.0, y1);
14301462 cusparse_view_.transpose_spmv (alpha, d_x2, 1.0 , d_y1);
0 commit comments