Skip to content

Commit 8466df7

Browse files
sundy-liTCeason
andauthored
fix(query): keep remaining_predicates when filtering grouping sets (#16971)
* fix(query): keep remaining_predicates when filtering grouping sets * update * update * update * update * Update 03_0003_select_group_by.test * update --------- Co-authored-by: TCeason <[email protected]>
1 parent 37de572 commit 8466df7

File tree

12 files changed

+863
-92
lines changed

12 files changed

+863
-92
lines changed

โ€Žsrc/query/ast/src/ast/format/syntax/query.rs

+9
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,15 @@ fn pretty_group_by(group_by: Option<GroupBy>) -> RcDoc<'static> {
272272
)
273273
.append(RcDoc::line())
274274
.append(RcDoc::text(")")),
275+
276+
GroupBy::Combined(sets) => RcDoc::line()
277+
.append(RcDoc::text("GROUP BY ").append(RcDoc::line().nest(NEST_FACTOR)))
278+
.append(
279+
interweave_comma(sets.into_iter().map(|s| RcDoc::text(s.to_string())))
280+
.nest(NEST_FACTOR)
281+
.group(),
282+
)
283+
.append(RcDoc::line()),
275284
}
276285
} else {
277286
RcDoc::nil()

โ€Žsrc/query/ast/src/ast/query.rs

+55-31
Original file line numberDiff line numberDiff line change
@@ -189,38 +189,8 @@ impl Display for SelectStmt {
189189
// GROUP BY clause
190190
if self.group_by.is_some() {
191191
write!(f, " GROUP BY ")?;
192-
match self.group_by.as_ref().unwrap() {
193-
GroupBy::Normal(exprs) => {
194-
write_comma_separated_list(f, exprs)?;
195-
}
196-
GroupBy::All => {
197-
write!(f, "ALL")?;
198-
}
199-
GroupBy::GroupingSets(sets) => {
200-
write!(f, "GROUPING SETS (")?;
201-
for (i, set) in sets.iter().enumerate() {
202-
if i > 0 {
203-
write!(f, ", ")?;
204-
}
205-
write!(f, "(")?;
206-
write_comma_separated_list(f, set)?;
207-
write!(f, ")")?;
208-
}
209-
write!(f, ")")?;
210-
}
211-
GroupBy::Cube(exprs) => {
212-
write!(f, "CUBE (")?;
213-
write_comma_separated_list(f, exprs)?;
214-
write!(f, ")")?;
215-
}
216-
GroupBy::Rollup(exprs) => {
217-
write!(f, "ROLLUP (")?;
218-
write_comma_separated_list(f, exprs)?;
219-
write!(f, ")")?;
220-
}
221-
}
192+
write!(f, "{}", self.group_by.as_ref().unwrap())?;
222193
}
223-
224194
// HAVING clause
225195
if let Some(having) = &self.having {
226196
write!(f, " HAVING {having}")?;
@@ -254,6 +224,60 @@ pub enum GroupBy {
254224
Cube(Vec<Expr>),
255225
/// GROUP BY ROLLUP ( expr [, expr]* )
256226
Rollup(Vec<Expr>),
227+
Combined(Vec<GroupBy>),
228+
}
229+
230+
impl GroupBy {
231+
pub fn normal_items(&self) -> Vec<Expr> {
232+
match self {
233+
GroupBy::Normal(exprs) => exprs.clone(),
234+
_ => Vec::new(),
235+
}
236+
}
237+
}
238+
239+
impl Display for GroupBy {
240+
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
241+
match self {
242+
GroupBy::Normal(exprs) => {
243+
write_comma_separated_list(f, exprs)?;
244+
}
245+
GroupBy::All => {
246+
write!(f, "ALL")?;
247+
}
248+
GroupBy::GroupingSets(sets) => {
249+
write!(f, "GROUPING SETS (")?;
250+
for (i, set) in sets.iter().enumerate() {
251+
if i > 0 {
252+
write!(f, ", ")?;
253+
}
254+
write!(f, "(")?;
255+
write_comma_separated_list(f, set)?;
256+
write!(f, ")")?;
257+
}
258+
write!(f, ")")?;
259+
}
260+
GroupBy::Cube(exprs) => {
261+
write!(f, "CUBE (")?;
262+
write_comma_separated_list(f, exprs)?;
263+
write!(f, ")")?;
264+
}
265+
GroupBy::Rollup(exprs) => {
266+
write!(f, "ROLLUP (")?;
267+
write_comma_separated_list(f, exprs)?;
268+
write!(f, ")")?;
269+
}
270+
GroupBy::Combined(group_bys) => {
271+
for (i, group_by) in group_bys.iter().enumerate() {
272+
if i > 0 {
273+
write!(f, ", ")?;
274+
}
275+
write!(f, "{}", group_by)?;
276+
}
277+
}
278+
}
279+
Ok(())
280+
}
257281
}
258282

259283
/// A relational set expression, like `SELECT ... FROM ... {UNION|EXCEPT|INTERSECT} SELECT ... FROM ...`

โ€Žsrc/query/ast/src/parser/query.rs

+23-6
Original file line numberDiff line numberDiff line change
@@ -1073,10 +1073,6 @@ impl<'a, I: Iterator<Item = WithSpan<'a, TableReferenceElement>>> PrattParser<I>
10731073
}
10741074

10751075
pub fn group_by_items(i: Input) -> IResult<GroupBy> {
1076-
let normal = map(rule! { ^#comma_separated_list1(expr) }, |groups| {
1077-
GroupBy::Normal(groups)
1078-
});
1079-
10801076
let all = map(rule! { ALL }, |_| GroupBy::All);
10811077

10821078
let cube = map(
@@ -1096,10 +1092,31 @@ pub fn group_by_items(i: Input) -> IResult<GroupBy> {
10961092
map(rule! { #expr }, |e| vec![e]),
10971093
));
10981094
let group_sets = map(
1099-
rule! { GROUPING ~ SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" },
1095+
rule! { GROUPING ~ ^SETS ~ "(" ~ ^#comma_separated_list1(group_set) ~ ")" },
11001096
|(_, _, _, sets, _)| GroupBy::GroupingSets(sets),
11011097
);
1102-
rule!(#all | #group_sets | #cube | #rollup | #normal)(i)
1098+
1099+
// New rule to handle multiple GroupBy items
1100+
let single_normal = map(rule! { #expr }, |group| GroupBy::Normal(vec![group]));
1101+
let group_by_item = alt((all, group_sets, cube, rollup, single_normal));
1102+
map(rule! { ^#comma_separated_list1(group_by_item) }, |items| {
1103+
if items.len() > 1 {
1104+
if items.iter().all(|item| matches!(item, GroupBy::Normal(_))) {
1105+
let items = items
1106+
.into_iter()
1107+
.flat_map(|item| match item {
1108+
GroupBy::Normal(exprs) => exprs,
1109+
_ => unreachable!(),
1110+
})
1111+
.collect();
1112+
GroupBy::Normal(items)
1113+
} else {
1114+
GroupBy::Combined(items)
1115+
}
1116+
} else {
1117+
items.into_iter().next().unwrap()
1118+
}
1119+
})(i)
11031120
}
11041121

11051122
pub fn window_frame_bound(i: Input) -> IResult<WindowFrameBound> {

โ€Žsrc/query/ast/tests/it/parser.rs

+4
Original file line numberDiff line numberDiff line change
@@ -591,12 +591,16 @@ fn test_statement() {
591591
"#,
592592
r#"SHOW FILE FORMATS"#,
593593
r#"DROP FILE FORMAT my_csv"#,
594+
r#"SELECT * FROM t GROUP BY all"#,
595+
r#"SELECT * FROM t GROUP BY a, b, c, d"#,
594596
r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, c, d)"#,
595597
r#"SELECT * FROM t GROUP BY GROUPING SETS (a, b, (c, d))"#,
596598
r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (c), (d, e))"#,
597599
r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b), (), (d, e))"#,
598600
r#"SELECT * FROM t GROUP BY CUBE (a, b, c)"#,
599601
r#"SELECT * FROM t GROUP BY ROLLUP (a, b, c)"#,
602+
r#"SELECT * FROM t GROUP BY a, ROLLUP (b, c)"#,
603+
r#"SELECT * FROM t GROUP BY GROUPING SETS ((a, b)), a, ROLLUP (b, c)"#,
600604
r#"CREATE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#,
601605
r#"CREATE OR REPLACE MASKING POLICY email_mask AS (val STRING) RETURNS STRING -> CASE WHEN current_role() IN ('ANALYST') THEN VAL ELSE '*********'END comment = 'this is a masking policy'"#,
602606
r#"DESC MASKING POLICY email_mask"#,

0 commit comments

Comments
ย (0)