Skip to content

Commit b6cea64

Browse files
authored
Merge pull request ClickHouse#107185 from nihalzp/crosstab-deserialization-validation
Properly validate CrossTab aggregate function states on deserialization
2 parents 91dab24 + f0aa636 commit b6cea64

4 files changed

Lines changed: 167 additions & 20 deletions

File tree

src/AggregateFunctions/AggregateFunctionTheilsU.h

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,7 @@ struct TheilsUWindowData : CrossTabCountsState
153153

154154
void deserialize(ReadBuffer & buf)
155155
{
156-
clear();
157-
158-
readBinary(count, buf);
159-
count_a.read(buf);
160-
count_b.read(buf);
161-
count_ab.read(buf);
156+
CrossTabCountsState::deserialize(buf);
162157

163158
/// Restore cached Σ n logn sums
164159
sum_a_nlogn = recomputeNLogNSum(count_a);
@@ -218,17 +213,6 @@ struct TheilsUWindowData : CrossTabCountsState
218213
return xf * std::log(xf);
219214
}
220215

221-
void clear()
222-
{
223-
count = 0;
224-
count_a.clear();
225-
count_b.clear();
226-
count_ab.clear();
227-
sum_a_nlogn = 0.0;
228-
sum_b_nlogn = 0.0;
229-
sum_ab_nlogn = 0.0;
230-
}
231-
232216
template <typename Map, typename Key>
233217
static void addToCountAndSum(Map & map, const Key & key, UInt64 add_value, Float64 & sum_xlogx)
234218
{

src/AggregateFunctions/CrossTab.h

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <AggregateFunctions/IAggregateFunction.h>
44
#include <AggregateFunctions/UniqVariadicHash.h>
55
#include <DataTypes/DataTypesNumber.h>
6+
#include <base/arithmeticOverflow.h>
67
#include <Common/Exception.h>
78
#include <Common/HashTable/HashMap.h>
89
#include <Common/VectorWithMemoryTracking.h>
@@ -82,6 +83,84 @@ struct CrossTabCountsState
8283
count_a.read(buf);
8384
count_b.read(buf);
8485
count_ab.read(buf);
86+
87+
validateDeserialized(count, count_a, count_b, count_ab);
88+
}
89+
90+
/// Aggregate function states can be constructed from untrusted data, e.g. by `CAST` from `String`,
91+
/// so deserialization has to check all invariants that the calculations rely on
92+
/// (a genuine state produced by `add` and `merge` satisfies them by construction):
93+
/// - all joint counts in `count_ab` are positive;
94+
/// - `count_a` and `count_b` are exactly the marginal sums of `count_ab`;
95+
/// - the joint counts sum up to `count`.
96+
/// These invariants guarantee that the state describes a valid contingency table, which in turn
97+
/// guarantees the theoretical bounds asserted during finalization (e.g. 0 <= φ² <= min(|A|, |B|) - 1).
98+
/// Throws an exception with code `CORRUPTED_DATA` on the first failed check and returns normally otherwise.
static void validateDeserialized(UInt64 count, const auto & count_a, const auto & count_b, const auto & count_ab)
99+
{
100+
/// Running sum of the joint counts; compared against `count` after the loop.
UInt64 total = 0;
101+
/// Marginal sums recomputed from `count_ab`, keyed by item `little(0)` of the pair key.
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> marginal_a;
102+
/// Same as `marginal_a`, but keyed by item `little(1)` of the pair key.
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> marginal_b;
103+
104+
for (const auto & [key, value] : count_ab)
105+
{
106+
/// Check: every stored joint count is positive.
if (value == 0)
107+
throw Exception(
108+
ErrorCodes::CORRUPTED_DATA,
109+
"Corrupted aggregate function state: the joint count of the value pair with hashes ({}, {}) is zero",
110+
key.items[UInt128::_impl::little(0)],
111+
key.items[UInt128::_impl::little(1)]);
112+
113+
/// Accumulate into `total`, rejecting the state if the sum does not fit into UInt64.
if (common::addOverflow(total, value, total))
114+
throw Exception(ErrorCodes::CORRUPTED_DATA, "Corrupted aggregate function state: the sum of joint counts overflows UInt64");
115+
116+
/// The marginal sums cannot overflow if the total sum does not: each of them is bounded by `total`.
marginal_a[key.items[UInt128::_impl::little(0)]] += value;
117+
marginal_b[key.items[UInt128::_impl::little(1)]] += value;
118+
119+
}
120+
121+
/// Check: the joint counts sum up to the stored total count.
if (total != count)
122+
throw Exception(
123+
ErrorCodes::CORRUPTED_DATA,
124+
"Corrupted aggregate function state: the joint counts sum up to {}, while the total count is {}",
125+
total,
126+
count);
127+
128+
/// Compares the marginal sums recomputed from the joint counts (`expected`) with a deserialized
/// marginal map (`stored`); `side` only labels the argument in the error messages.
auto check_marginals = [](const auto & expected, const auto & stored, const char * side)
129+
{
130+
/// With equal sizes, finding every `expected` key in `stored` below also rules out
/// extra keys in `stored` (assuming both maps hold unique keys).
if (expected.size() != stored.size())
131+
throw Exception(
132+
ErrorCodes::CORRUPTED_DATA,
133+
"Corrupted aggregate function state: there are {} distinct values of the {} argument, "
134+
"while the joint counts imply {}",
135+
stored.size(),
136+
side,
137+
expected.size());
138+
139+
for (const auto & [key, value] : expected)
140+
{
141+
const auto * it = stored.find(key);
142+
/// A key seen in the joint counts must have a stored marginal count.
if (it == stored.end())
143+
throw Exception(
144+
ErrorCodes::CORRUPTED_DATA,
145+
"Corrupted aggregate function state: the value with hash {} of the {} argument is present "
146+
"in the joint counts but has no marginal count",
147+
key,
148+
side);
149+
150+
/// The stored marginal count must equal the sum recomputed from the joint counts.
if (it->getMapped() != value)
151+
throw Exception(
152+
ErrorCodes::CORRUPTED_DATA,
153+
"Corrupted aggregate function state: the marginal count of the value with hash {} "
154+
"of the {} argument is {}, while its joint counts sum up to {}",
155+
key,
156+
side,
157+
it->getMapped(),
158+
value);
159+
}
160+
};
161+
162+
/// Check both stored marginal maps against the sums recomputed above.
check_marginals(marginal_a, count_a, "first");
163+
check_marginals(marginal_b, count_b, "second");
85164
}
86165
};
87166

@@ -507,6 +586,8 @@ struct CrossTabPhiSquaredWindowData
507586
count_ab.size(),
508587
INVALID_EDGE_IDX - 1);
509588

589+
CrossTabCountsState::validateDeserialized(count, count_a, count_b, count_ab);
590+
510591
/// Update the internal states
511592
a_hash_by_index.reserve(count_a.size());
512593
a_marginal_count.reserve(count_a.size());
@@ -561,9 +642,7 @@ struct CrossTabPhiSquaredWindowData
561642
const Float64 a = static_cast<Float64>(a_marginal_count[a_idx]);
562643
const Float64 b = static_cast<Float64>(b_marginal_count[b_idx]);
563644

564-
if (unlikely(a <= 0 || b <= 0))
565-
throw Exception(
566-
ErrorCodes::CORRUPTED_DATA, "Corrupted aggregate function state: value frequency must be positive (a={}, b={})", a, b);
645+
chassert(a > 0 && b > 0 && "value frequencies are positive after validation");
567646

568647
phi_term_sum += phiTerm(cnt_ab, a, b);
569648
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
OK fuzzer query
2+
OK foreign state
3+
0
4+
OK total count mismatch
5+
OK marginal count mismatch
6+
OK zero pair count
7+
OK cramersV
8+
OK cramersVBiasCorrected
9+
OK theilsU
10+
1 1 1 1
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#!/usr/bin/env bash
2+
3+
# Deserialization of CrossTab-family aggregate function states (`contingency`, `cramersV`,
4+
# `cramersVBiasCorrected`, `theilsU`) must validate that the counts form a valid contingency table:
5+
# the pair counts are positive, the value counts are exactly the marginal sums of the pair counts,
6+
# and the pair counts sum up to the total count. Otherwise finalization of a forged state
7+
# (e.g. constructed by `CAST` from `String`) could fail an assertion or divide by zero.
8+
9+
10+
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
11+
# shellcheck source=../shell_config.sh
12+
. "$CUR_DIR"/../shell_config.sh
13+
14+
# Runs a query with $CLICKHOUSE_LOCAL and checks that its output mentions CORRUPTED_DATA.
#   $1 - label printed after OK / FAIL
#   $2 - query text passed to --query
# Prints "OK <label>" when the combined stdout+stderr contains the fixed string 'CORRUPTED_DATA',
# and "FAIL <label>" otherwise.
function expect_corrupted_data()
15+
{
16+
local name="$1"
17+
local query="$2"
18+
# stderr is merged into stdout (2>&1) so that grep sees the error text; -q suppresses grep's own
# output and -F matches the pattern as a fixed string. Because of the `A && B || C` form, the FAIL
# branch would also run if the `echo "OK ..."` command itself failed.
$CLICKHOUSE_LOCAL --query "$query" 2>&1 \
19+
| grep -q -F 'CORRUPTED_DATA' && echo "OK $name" || echo "FAIL $name"
20+
}
21+
22+
# The exact query found by the fuzzer: a window `argMin` state reinterpreted as a `contingency`
23+
# state. It previously failed the assertion `phi_squared > -1e-4`.
24+
expect_corrupted_data 'fuzzer query' "
25+
SELECT round(roundtrip, 2147483647) AS roundtrip, abs(direct - roundtrip) < 1e-9, round(direct) AS direct FROM (SELECT finalizeAggregation(st_win) AS direct, finalizeAggregation(CAST(CAST(st_win, 'String'), 'AggregateFunction(contingency, UInt8, UInt8)')) AS roundtrip FROM (SELECT argMinState(toUInt8(number % 10), *) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS st_win FROM numbers(10000) ORDER BY number DESC NULLS FIRST LIMIT 9223372036854775807))"
26+
27+
# A simplified version of the fuzzer query: the same foreign state built by a plain aggregation.
28+
expect_corrupted_data 'foreign state' "
29+
SELECT finalizeAggregation(CAST(CAST(argMinState(toUInt8(number % 10), number), 'String'), 'AggregateFunction(contingency, UInt8, UInt8)'))
30+
FROM numbers(10000)"
31+
32+
# The following byte strings are based on the serialization of `contingencyState(toUInt8(1), toUInt8(2))`
33+
# over two rows: total count 2, then three hash maps (count_a, count_b, count_ab), each serialized as
34+
# a varint size followed by (key, value) pairs.
35+
36+
# A genuine state: deserializes and finalizes normally.
37+
$CLICKHOUSE_LOCAL --query "
38+
SELECT finalizeAggregation(CAST(unhex('0200000000000000014769A5692A310962020000000000000001D2D33DBDEA4FD1AC0200000000000000014769A5692A310962D2D33DBDEA4FD1AC0200000000000000'), 'AggregateFunction(contingency, UInt8, UInt8)'))"
39+
40+
# The total count is changed from 2 to 3: it does not match the sum of the pair counts.
41+
expect_corrupted_data 'total count mismatch' "
42+
SELECT finalizeAggregation(CAST(unhex('0300000000000000014769A5692A310962020000000000000001D2D33DBDEA4FD1AC0200000000000000014769A5692A310962D2D33DBDEA4FD1AC0200000000000000'), 'AggregateFunction(contingency, UInt8, UInt8)'))"
43+
44+
# The count of the first value is changed from 2 to 3: it does not match the marginal sum of the pair counts.
45+
expect_corrupted_data 'marginal count mismatch' "
46+
SELECT finalizeAggregation(CAST(unhex('0200000000000000014769A5692A310962030000000000000001D2D33DBDEA4FD1AC0200000000000000014769A5692A310962D2D33DBDEA4FD1AC0200000000000000'), 'AggregateFunction(contingency, UInt8, UInt8)'))"
47+
48+
# The pair count is changed from 2 to 0: pair counts must be positive.
49+
expect_corrupted_data 'zero pair count' "
50+
SELECT finalizeAggregation(CAST(unhex('0200000000000000014769A5692A310962020000000000000001D2D33DBDEA4FD1AC0200000000000000014769A5692A310962D2D33DBDEA4FD1AC0000000000000000'), 'AggregateFunction(contingency, UInt8, UInt8)'))"
51+
52+
# All functions of the family share the state format and the validation.
53+
for func in cramersV cramersVBiasCorrected theilsU
54+
do
55+
expect_corrupted_data "$func" "
56+
SELECT finalizeAggregation(CAST(unhex('0300000000000000014769A5692A310962020000000000000001D2D33DBDEA4FD1AC0200000000000000014769A5692A310962D2D33DBDEA4FD1AC0200000000000000'), 'AggregateFunction($func, UInt8, UInt8)'))"
57+
done
58+
59+
# Genuine states still survive a roundtrip through `String` and finalize to the same results.
60+
$CLICKHOUSE_LOCAL --query "
61+
SELECT
62+
finalizeAggregation(st_c) = finalizeAggregation(CAST(CAST(st_c, 'String'), 'AggregateFunction(contingency, UInt16, UInt16)')),
63+
finalizeAggregation(st_v) = finalizeAggregation(CAST(CAST(st_v, 'String'), 'AggregateFunction(cramersV, UInt16, UInt16)')),
64+
finalizeAggregation(st_b) = finalizeAggregation(CAST(CAST(st_b, 'String'), 'AggregateFunction(cramersVBiasCorrected, UInt16, UInt16)')),
65+
finalizeAggregation(st_u) = finalizeAggregation(CAST(CAST(st_u, 'String'), 'AggregateFunction(theilsU, UInt16, UInt16)'))
66+
FROM
67+
(
68+
SELECT
69+
contingencyState(toUInt16(number % 17), toUInt16(number % 9)) AS st_c,
70+
cramersVState(toUInt16(number % 17), toUInt16(number % 9)) AS st_v,
71+
cramersVBiasCorrectedState(toUInt16(number % 17), toUInt16(number % 9)) AS st_b,
72+
theilsUState(toUInt16(number % 17), toUInt16(number % 9)) AS st_u
73+
FROM numbers(1000)
74+
)"

0 commit comments

Comments
 (0)