1616
1717#include " round_float.hpp"
1818
19- #include < cuda/std/type_traits>
20-
2119#include < cudf/column/column_factories.hpp>
2220#include < cudf/copying.hpp>
2321#include < cudf/detail/null_mask.hpp>
3028#include < rmm/cuda_stream_view.hpp>
3129#include < rmm/exec_policy.hpp>
3230
31+ #include < cuda/std/type_traits>
3332#include < thrust/transform.h>
3433
3534#include < cmath>
3635#include < type_traits>
3736
3837namespace spark_rapids_jni {
3938
40-
4139inline float __device__ generic_round (float f) { return roundf (f); }
4240inline double __device__ generic_round (double d) { return ::round (d); }
4341
@@ -47,14 +45,10 @@ inline double __device__ generic_round_half_even(double d) { return rint(d); }
4745inline __device__ float generic_modf (float a, float * b) { return modff (a, b); }
4846inline __device__ double generic_modf (double a, double * b) { return modf (a, b); }
4947
50-
5148template <typename T>
5249struct half_up_zero {
5350 T n; // unused in the decimal_places = 0 case
54- __device__ T operator ()(T e)
55- {
56- return generic_round (e);
57- }
51+ __device__ T operator ()(T e) { return generic_round (e); }
5852};
5953
6054template <typename T>
@@ -71,19 +65,13 @@ struct half_up_positive {
7165template <typename T>
7266struct half_up_negative {
7367 T n;
74- __device__ T operator ()(T e)
75- {
76- return generic_round (e / n) * n;
77- }
68+ __device__ T operator ()(T e) { return generic_round (e / n) * n; }
7869};
7970
8071template <typename T>
8172struct half_even_zero {
8273 T n; // unused in the decimal_places = 0 case
83- __device__ T operator ()(T e)
84- {
85- return generic_round_half_even (e);
86- }
74+ __device__ T operator ()(T e) { return generic_round_half_even (e); }
8775};
8876
8977template <typename T>
@@ -100,18 +88,15 @@ struct half_even_positive {
10088template <typename T>
10189struct half_even_negative {
10290 T n;
103- __device__ T operator ()(T e)
104- {
105- return generic_round_half_even (e / n) * n;
106- }
91+ __device__ T operator ()(T e) { return generic_round_half_even (e / n) * n; }
10792};
10893
10994template <typename T, template <typename > typename RoundFunctor>
11095std::unique_ptr<cudf::column> round_with (cudf::column_view const & input,
11196 int32_t decimal_places,
11297 rmm::cuda_stream_view stream,
11398 rmm::device_async_resource_ref mr)
114- requires(std::is_floating_point_v<T>)
99+ requires(std::is_floating_point_v<T>)
115100{
116101 using Functor = RoundFunctor<T>;
117102
@@ -140,7 +125,7 @@ struct round_type_dispatcher {
140125 {
141126 CUDF_FAIL (" Type not supported for spark_rapids_jni::round" );
142127 }
143-
128+
144129 template <typename T>
145130 std::unique_ptr<cudf::column> operator ()(cudf::column_view const & input,
146131 int32_t decimal_places,
@@ -175,15 +160,13 @@ std::unique_ptr<cudf::column> round(cudf::column_view const& input,
175160 CUDF_EXPECTS (cudf::is_numeric (input.type ()) || cudf::is_fixed_point (input.type ()),
176161 " Only integral/floating point/fixed point currently supported." );
177162
178- if (!cudf::is_floating_point (input.type ())) {
163+ if (!cudf::is_floating_point (input.type ())) {
179164 return cudf::round_decimal (input, decimal_places, method, stream, mr);
180165 }
181- if (input.is_empty ()) {
182- return cudf::empty_like (input);
183- }
166+ if (input.is_empty ()) { return cudf::empty_like (input); }
184167
185168 return cudf::type_dispatcher (
186169 input.type (), round_type_dispatcher{}, input, decimal_places, method, stream, mr);
187170}
188171
189- } // namespace spark_rapids_jni
172+ } // namespace spark_rapids_jni
0 commit comments