Skip to content

Commit de37f06

Browse files
fix formatting
Signed-off-by: Paul Mattione <pmattione@nvidia.com>
1 parent b563f59 commit de37f06

3 files changed

Lines changed: 21 additions & 42 deletions

File tree

src/main/cpp/src/ArithmeticJni.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,8 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Arithmetic_multiply(JNI
6161
CATCH_EXCEPTION_WITH_ROW_INDEX(env, 0);
6262
}
6363

64-
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Arithmetic_round(JNIEnv* env,
65-
jclass,
66-
jlong input_ptr,
67-
jint decimal_places,
68-
jint rounding_method)
64+
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Arithmetic_round(
65+
JNIEnv* env, jclass, jlong input_ptr, jint decimal_places, jint rounding_method)
6966
{
7067
JNI_NULL_CHECK(env, input_ptr, "input is null", 0);
7168
try {
@@ -76,5 +73,4 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Arithmetic_round(JNIEnv
7673
}
7774
CATCH_STD(env, 0);
7875
}
79-
8076
}

src/main/cpp/src/round_float.cu

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
#include "round_float.hpp"
1818

19-
#include <cuda/std/type_traits>
20-
2119
#include <cudf/column/column_factories.hpp>
2220
#include <cudf/copying.hpp>
2321
#include <cudf/detail/null_mask.hpp>
@@ -30,14 +28,14 @@
3028
#include <rmm/cuda_stream_view.hpp>
3129
#include <rmm/exec_policy.hpp>
3230

31+
#include <cuda/std/type_traits>
3332
#include <thrust/transform.h>
3433

3534
#include <cmath>
3635
#include <type_traits>
3736

3837
namespace spark_rapids_jni {
3938

40-
4139
inline float __device__ generic_round(float f) { return roundf(f); }
4240
inline double __device__ generic_round(double d) { return ::round(d); }
4341

@@ -47,14 +45,10 @@ inline double __device__ generic_round_half_even(double d) { return rint(d); }
4745
inline __device__ float generic_modf(float a, float* b) { return modff(a, b); }
4846
inline __device__ double generic_modf(double a, double* b) { return modf(a, b); }
4947

50-
5148
template <typename T>
5249
struct half_up_zero {
5350
T n; // unused in the decimal_places = 0 case
54-
__device__ T operator()(T e)
55-
{
56-
return generic_round(e);
57-
}
51+
__device__ T operator()(T e) { return generic_round(e); }
5852
};
5953

6054
template <typename T>
@@ -71,19 +65,13 @@ struct half_up_positive {
7165
template <typename T>
7266
struct half_up_negative {
7367
T n;
74-
__device__ T operator()(T e)
75-
{
76-
return generic_round(e / n) * n;
77-
}
68+
__device__ T operator()(T e) { return generic_round(e / n) * n; }
7869
};
7970

8071
template <typename T>
8172
struct half_even_zero {
8273
T n; // unused in the decimal_places = 0 case
83-
__device__ T operator()(T e)
84-
{
85-
return generic_round_half_even(e);
86-
}
74+
__device__ T operator()(T e) { return generic_round_half_even(e); }
8775
};
8876

8977
template <typename T>
@@ -100,18 +88,15 @@ struct half_even_positive {
10088
template <typename T>
10189
struct half_even_negative {
10290
T n;
103-
__device__ T operator()(T e)
104-
{
105-
return generic_round_half_even(e / n) * n;
106-
}
91+
__device__ T operator()(T e) { return generic_round_half_even(e / n) * n; }
10792
};
10893

10994
template <typename T, template <typename> typename RoundFunctor>
11095
std::unique_ptr<cudf::column> round_with(cudf::column_view const& input,
11196
int32_t decimal_places,
11297
rmm::cuda_stream_view stream,
11398
rmm::device_async_resource_ref mr)
114-
requires(std::is_floating_point_v<T>)
99+
requires(std::is_floating_point_v<T>)
115100
{
116101
using Functor = RoundFunctor<T>;
117102

@@ -140,7 +125,7 @@ struct round_type_dispatcher {
140125
{
141126
CUDF_FAIL("Type not supported for spark_rapids_jni::round");
142127
}
143-
128+
144129
template <typename T>
145130
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
146131
int32_t decimal_places,
@@ -175,15 +160,13 @@ std::unique_ptr<cudf::column> round(cudf::column_view const& input,
175160
CUDF_EXPECTS(cudf::is_numeric(input.type()) || cudf::is_fixed_point(input.type()),
176161
"Only integral/floating point/fixed point currently supported.");
177162

178-
if(!cudf::is_floating_point(input.type())) {
163+
if (!cudf::is_floating_point(input.type())) {
179164
return cudf::round_decimal(input, decimal_places, method, stream, mr);
180165
}
181-
if (input.is_empty()) {
182-
return cudf::empty_like(input);
183-
}
166+
if (input.is_empty()) { return cudf::empty_like(input); }
184167

185168
return cudf::type_dispatcher(
186169
input.type(), round_type_dispatcher{}, input, decimal_places, method, stream, mr);
187170
}
188171

189-
} // namespace spark_rapids_jni
172+
} // namespace spark_rapids_jni

src/main/cpp/src/round_float.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
16+
1717
#pragma once
1818

1919
#include <cudf/column/column.hpp>
@@ -59,11 +59,11 @@ namespace spark_rapids_jni {
5959
*
6060
* @return Column with each of the values rounded
6161
*/
62-
std::unique_ptr<cudf::column> round(
63-
cudf::column_view const& input,
64-
int32_t decimal_places = 0,
65-
cudf::rounding_method method = cudf::rounding_method::HALF_UP,
66-
rmm::cuda_stream_view stream = cudf::get_default_stream(),
67-
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
68-
69-
} // namespace spark_rapids_jni
62+
std::unique_ptr<cudf::column> round(
63+
cudf::column_view const& input,
64+
int32_t decimal_places = 0,
65+
cudf::rounding_method method = cudf::rounding_method::HALF_UP,
66+
rmm::cuda_stream_view stream = cudf::get_default_stream(),
67+
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());
68+
69+
} // namespace spark_rapids_jni

0 commit comments

Comments
 (0)