Skip to content
17 changes: 12 additions & 5 deletions cpp/src/arrow/array/statistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>
#include <variant>

#include "arrow/compare.h"
#include "arrow/type.h"
#include "arrow/util/visibility.h"

Expand Down Expand Up @@ -127,11 +128,17 @@ struct ARROW_EXPORT ArrayStatistics {
/// \brief Whether the maximum value is exact or not
bool is_max_exact = false;

/// \brief Check two statistics for equality
bool Equals(const ArrayStatistics& other) const {
return null_count == other.null_count && distinct_count == other.distinct_count &&
min == other.min && is_min_exact == other.is_min_exact && max == other.max &&
is_max_exact == other.is_max_exact;
/// \brief Check two \ref arrow::ArrayStatistics for equality
///
/// \param other The \ref arrow::ArrayStatistics instance to compare against.
///
/// \param equal_options Options used to compare double values for equality.
///
/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise,
/// false.
bool Equals(const ArrayStatistics& other,
const EqualOptions& equal_options = EqualOptions::Defaults()) const {
return ArrayStatisticsEquals(*this, other, equal_options);
}

/// \brief Check two statistics for equality
Expand Down
64 changes: 59 additions & 5 deletions cpp/src/arrow/array/statistics_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,33 @@
// specific language governing permissions and limitations
// under the License.

#include <limits>
#include <variant>

#include <gtest/gtest.h>

#include "arrow/array/statistics.h"
#include "arrow/compare.h"

namespace arrow {

TEST(ArrayStatisticsTest, TestNullCount) {
TEST(TestArrayStatistics, NullCount) {
ArrayStatistics statistics;
ASSERT_FALSE(statistics.null_count.has_value());
statistics.null_count = 29;
ASSERT_TRUE(statistics.null_count.has_value());
ASSERT_EQ(29, statistics.null_count.value());
}

TEST(ArrayStatisticsTest, TestDistinctCount) {
TEST(TestArrayStatistics, DistinctCount) {
ArrayStatistics statistics;
ASSERT_FALSE(statistics.distinct_count.has_value());
statistics.distinct_count = 29;
ASSERT_TRUE(statistics.distinct_count.has_value());
ASSERT_EQ(29, statistics.distinct_count.value());
}

TEST(ArrayStatisticsTest, TestMin) {
TEST(TestArrayStatistics, Min) {
ArrayStatistics statistics;
ASSERT_FALSE(statistics.min.has_value());
ASSERT_FALSE(statistics.is_min_exact);
Expand All @@ -49,7 +53,7 @@ TEST(ArrayStatisticsTest, TestMin) {
ASSERT_TRUE(statistics.is_min_exact);
}

TEST(ArrayStatisticsTest, TestMax) {
TEST(TestArrayStatistics, Max) {
ArrayStatistics statistics;
ASSERT_FALSE(statistics.max.has_value());
ASSERT_FALSE(statistics.is_max_exact);
Expand All @@ -61,7 +65,7 @@ TEST(ArrayStatisticsTest, TestMax) {
ASSERT_FALSE(statistics.is_max_exact);
}

TEST(ArrayStatisticsTest, TestEquality) {
TEST(TestArrayStatistics, EqualityNonDoulbeValue) {
ArrayStatistics statistics1;
ArrayStatistics statistics2;

Expand Down Expand Up @@ -96,6 +100,56 @@ TEST(ArrayStatisticsTest, TestEquality) {
ASSERT_NE(statistics1, statistics2);
statistics2.is_max_exact = true;
ASSERT_EQ(statistics1, statistics2);

// Test different ArrayStatistics::ValueType
statistics1.max = static_cast<uint64_t>(29);
statistics1.max = static_cast<int64_t>(29);
ASSERT_NE(statistics1, statistics2);
}

class TestArrayStatisticsEqualityDoubleValue : public ::testing::Test {
protected:
ArrayStatistics statistics1_;
ArrayStatistics statistics2_;
EqualOptions options_ = EqualOptions::Defaults();
};

TEST_F(TestArrayStatisticsEqualityDoubleValue, ExactValue) {
statistics2_.min = 29.0;
statistics1_.min = 29.0;
ASSERT_EQ(statistics1_, statistics2_);
statistics2_.min = 30.0;
ASSERT_NE(statistics1_, statistics2_);
}

TEST_F(TestArrayStatisticsEqualityDoubleValue, SignedZero) {
statistics1_.min = +0.0;
statistics2_.min = -0.0;
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(true)));
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(false)));
}

TEST_F(TestArrayStatisticsEqualityDoubleValue, Infinity) {
auto infinity = std::numeric_limits<double>::infinity();
statistics1_.min = infinity;
statistics2_.min = infinity;
ASSERT_EQ(statistics1_, statistics2_);
statistics1_.min = -infinity;
ASSERT_NE(statistics1_, statistics2_);
}

TEST_F(TestArrayStatisticsEqualityDoubleValue, NaN) {
statistics1_.min = std::numeric_limits<double>::quiet_NaN();
statistics2_.min = std::numeric_limits<double>::quiet_NaN();
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.nans_equal(true)));
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.nans_equal(false)));
}

TEST_F(TestArrayStatisticsEqualityDoubleValue, ApproximateEquals) {
statistics1_.max = 0.5001f;
statistics2_.max = 0.5;
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(false)));
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3)));
}

} // namespace arrow
54 changes: 54 additions & 0 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "arrow/array.h"
#include "arrow/array/diff.h"
#include "arrow/array/statistics.h"
#include "arrow/buffer.h"
#include "arrow/scalar.h"
#include "arrow/sparse_tensor.h"
Expand Down Expand Up @@ -1523,4 +1526,55 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata
}
}

namespace {

bool DoubleEquals(const double& left, const double& right, const EqualOptions& options) {
bool result;
auto visitor = [&](auto&& compare_func) { result = compare_func(left, right); };
VisitFloatingEquality<double>(options, options.use_atol(), std::move(visitor));
return result;
}

bool ArrayStatisticsValueTypeEquals(
const std::optional<ArrayStatistics::ValueType>& left,
const std::optional<ArrayStatistics::ValueType>& right, const EqualOptions& options) {
if (!left.has_value() || !right.has_value()) {
return left.has_value() == right.has_value();
} else if (left->index() != right->index()) {
return false;
} else {
auto EqualsVisitor = [&](const auto& v1, const auto& v2) {
using type_1 = std::decay_t<decltype(v1)>;
using type_2 = std::decay_t<decltype(v2)>;
if constexpr (std::conjunction_v<std::is_same<type_1, double>,
std::is_same<type_2, double>>) {
return DoubleEquals(v1, v2, options);
} else if constexpr (std::is_same_v<type_1, type_2>) {
return v1 == v2;
}
// It is unreachable
DCHECK(false);
return false;
};
return std::visit(EqualsVisitor, left.value(), right.value());
}
}

bool ArrayStatisticsEqualsImpl(const ArrayStatistics& left, const ArrayStatistics& right,
const EqualOptions& equal_options) {
return left.null_count == right.null_count &&
left.distinct_count == right.distinct_count &&
left.is_min_exact == right.is_min_exact &&
left.is_max_exact == right.is_max_exact &&
ArrayStatisticsValueTypeEquals(left.min, right.min, equal_options) &&
ArrayStatisticsValueTypeEquals(left.max, right.max, equal_options);
}

} // namespace

bool ArrayStatisticsEquals(const ArrayStatistics& left, const ArrayStatistics& right,
const EqualOptions& options) {
return ArrayStatisticsEqualsImpl(left, right, options);
}

} // namespace arrow
23 changes: 23 additions & 0 deletions cpp/src/arrow/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

namespace arrow {

struct ArrayStatistics;
class Array;
class DataType;
class Tensor;
Expand Down Expand Up @@ -58,7 +59,18 @@ class EqualOptions {
return res;
}

/// Whether the "atol" property is used in the comparison.
bool use_atol() const { return use_atol_; }

/// Return a new EqualOptions object with the "use_atol" property changed.
EqualOptions use_atol(bool v) const {
auto res = EqualOptions(*this);
res.use_atol_ = v;
return res;
}

/// The absolute tolerance for approximate comparisons of floating-point values.
/// Note that this option is ignored if "use_atol" is set to false.
double atol() const { return atol_; }

/// Return a new EqualOptions object with the "atol" property changed.
Expand Down Expand Up @@ -87,6 +99,7 @@ class EqualOptions {
double atol_ = kDefaultAbsoluteTolerance;
bool nans_equal_ = false;
bool signed_zeros_equal_ = true;
bool use_atol_ = true;

std::ostream* diff_sink_ = NULLPTR;
};
Expand Down Expand Up @@ -135,6 +148,16 @@ ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTenso
ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right,
bool check_metadata = true);

/// \brief Check two \ref arrow::ArrayStatistics for equality
/// \param[in] left an \ref arrow::ArrayStatistics
/// \param[in] right an \ref arrow::ArrayStatistics
/// \param[in] options Options used to compare double values for equality.
/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise,
/// false.
ARROW_EXPORT bool ArrayStatisticsEquals(
const ArrayStatistics& left, const ArrayStatistics& right,
const EqualOptions& options = EqualOptions::Defaults());

/// Returns true if scalars are equal
/// \param[in] left a Scalar
/// \param[in] right a Scalar
Expand Down
Loading