Skip to content

Commit 70d6ac0

Browse files
Merge branch 'main' into use-cross-proj-mode-decider-in-rcs
2 parents 3a10c99 + bad616e commit 70d6ac0

File tree

5 files changed

+190
-0
lines changed

5 files changed

+190
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.analysis;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Alias;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.Literal;
13+
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
14+
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
15+
import org.elasticsearch.xpack.esql.core.type.DataType;
16+
import org.elasticsearch.xpack.esql.core.util.Holder;
17+
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
18+
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
19+
import org.elasticsearch.xpack.esql.expression.function.aggregate.HistogramMergeOverTime;
20+
import org.elasticsearch.xpack.esql.expression.function.aggregate.LastOverTime;
21+
import org.elasticsearch.xpack.esql.expression.function.aggregate.TimeSeriesAggregateFunction;
22+
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
23+
import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
24+
import org.elasticsearch.xpack.esql.rule.Rule;
25+
26+
import java.util.List;
27+
28+
/**
29+
* Ensures that {@link TypedAttribute}s used inside a {@link TimeSeriesAggregate} are wrapped in a
30+
* {@link TimeSeriesAggregateFunction}.
31+
* Examples:
32+
* <pre>
33+
* foo + bar ->
34+
* LAST_OVER_TIME(foo) + LAST_OVER_TIME(bar)
35+
*
36+
* SUM(foo + LAST_OVER_TIME(bar)) ->
37+
* SUM(LAST_OVER_TIME(foo) + LAST_OVER_TIME(bar))
38+
*
39+
* foo / 2 + bar * 2 ->
40+
* LAST_OVER_TIME(foo) / 2 + LAST_OVER_TIME(bar) * 2
41+
* </pre>
42+
*/
43+
public class InsertDefaultInnerTimeSeriesAggregate extends Rule<LogicalPlan, LogicalPlan> {
44+
@Override
45+
public LogicalPlan apply(LogicalPlan logicalPlan) {
46+
return logicalPlan.transformUp(node -> node instanceof TimeSeriesAggregate, this::rule);
47+
}
48+
49+
public LogicalPlan rule(TimeSeriesAggregate aggregate) {
50+
Holder<Boolean> changed = new Holder<>(false);
51+
List<NamedExpression> newAggregates = aggregate.aggregates().stream().map(agg -> {
52+
if (agg instanceof Alias alias) {
53+
return alias.replaceChild(addDefaultInnerAggs(alias.child(), aggregate.timestamp(), changed));
54+
} else {
55+
return agg;
56+
}
57+
}).toList();
58+
if (changed.get() == false) {
59+
return aggregate;
60+
}
61+
return aggregate.with(aggregate.groupings(), newAggregates);
62+
}
63+
64+
private static Expression addDefaultInnerAggs(Expression expression, Expression timestamp, Holder<Boolean> changed) {
65+
return expression.transformDownSkipBranch((expr, skipBranch) -> {
66+
// the default is to end the traversal here as we're either done or a recursive call will handle it
67+
skipBranch.set(true);
68+
return switch (expr) {
69+
// this is already a time series aggregation, no need to go deeper
70+
case TimeSeriesAggregateFunction ts -> ts;
71+
// only transform field, not all children (such as inline filter or window)
72+
case AggregateFunction af -> af.withField(addDefaultInnerAggs(af.field(), timestamp, changed));
73+
// avoid modifying filter conditions, just the delegate
74+
case FilteredExpression filtered -> filtered.withDelegate(addDefaultInnerAggs(filtered.delegate(), timestamp, changed));
75+
// if we reach a TypedAttribute, it hasn't been wrapped in a TimeSeriesAggregateFunction yet
76+
// (otherwise the traversal would have stopped earlier)
77+
// so we wrap it with a default one
78+
case TypedAttribute ta -> insertDefaultInnerAggregation(ta, timestamp, changed);
79+
default -> {
80+
// for other expressions, continue the traversal
81+
skipBranch.set(false);
82+
yield expr;
83+
}
84+
};
85+
});
86+
}
87+
88+
private static TimeSeriesAggregateFunction insertDefaultInnerAggregation(
89+
TypedAttribute attr,
90+
Expression timestamp,
91+
Holder<Boolean> changed
92+
) {
93+
changed.set(true);
94+
if (attr.dataType() == DataType.EXPONENTIAL_HISTOGRAM || attr.dataType() == DataType.TDIGEST) {
95+
return new HistogramMergeOverTime(attr.source(), attr, Literal.TRUE, AggregateFunction.NO_WINDOW);
96+
} else {
97+
return new LastOverTime(attr.source(), attr, AggregateFunction.NO_WINDOW, timestamp);
98+
}
99+
}
100+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/core/tree/Node.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.List;
1717
import java.util.Objects;
1818
import java.util.function.BiConsumer;
19+
import java.util.function.BiFunction;
1920
import java.util.function.Consumer;
2021
import java.util.function.Function;
2122
import java.util.function.Predicate;
@@ -233,6 +234,23 @@ public T transformDown(Function<? super T, ? extends T> rule) {
233234
return node.transformChildren(child -> child.transformDown(rule));
234235
}
235236

237+
@SuppressWarnings("unchecked")
238+
public T transformDownSkipBranch(BiFunction<? super T, Holder<Boolean>, ? extends T> rule) {
239+
Holder<Boolean> skipBranch = new Holder<>(Boolean.FALSE);
240+
return transformDownSkipBranch(skipBranch, rule);
241+
}
242+
243+
@SuppressWarnings("unchecked")
244+
T transformDownSkipBranch(Holder<Boolean> skipBranch, BiFunction<? super T, Holder<Boolean>, ? extends T> rule) {
245+
T root = rule.apply((T) this, skipBranch);
246+
Node<T> node = this.equals(root) ? this : root;
247+
if (skipBranch.get()) {
248+
skipBranch.set(false);
249+
return (T) node;
250+
}
251+
return node.transformChildren(child -> child.transformDownSkipBranch(skipBranch, rule));
252+
}
253+
236254
@SuppressWarnings("unchecked")
237255
public <E extends T> T transformDown(Class<E> typeToken, Function<E, ? extends T> rule) {
238256
return transformDown((t) -> (typeToken.isInstance(t) ? rule.apply((E) t) : t));

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,11 @@ public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
210210
}
211211
};
212212
}
213+
214+
public AggregateFunction withField(Expression newField) {
215+
if (newField == this.field) {
216+
return this;
217+
}
218+
return (AggregateFunction) replaceChildren(CollectionUtils.combine(asList(newField, filter, window), parameters));
219+
}
213220
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FilteredExpression.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,11 @@ protected NodeInfo<FilteredExpression> info() {
9292
public Expression replaceChildren(List<Expression> newChildren) {
9393
return new FilteredExpression(source(), newChildren.get(0), newChildren.get(1));
9494
}
95+
96+
public FilteredExpression withDelegate(Expression newDelegate) {
97+
if (newDelegate == delegate) {
98+
return this;
99+
}
100+
return new FilteredExpression(source(), newDelegate, filter());
101+
}
95102
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.analysis;
9+
10+
import org.elasticsearch.xpack.esql.core.expression.Alias;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.optimizer.AbstractLogicalPlanOptimizerTests;
13+
14+
import java.util.Locale;
15+
import java.util.function.Function;
16+
17+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalToIgnoringIds;
18+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.ignoreIds;
19+
20+
public class InsertDefaultInnerTimeSeriesAggregateTests extends AbstractLogicalPlanOptimizerTests {
21+
22+
private final InsertDefaultInnerTimeSeriesAggregate rule = new InsertDefaultInnerTimeSeriesAggregate();
23+
24+
public void testSimpleImplicitOverTime() {
25+
assertStatsEqual("sum(network.bytes_in)", "sum(last_over_time(network.bytes_in))");
26+
}
27+
28+
public void testBinaryWithImplicitAndExplicitOverTime() {
29+
var expected = "sum(last_over_time(network.eth0.tx) + last_over_time(network.eth0.rx))";
30+
assertStatsEqual("sum(network.eth0.tx + network.eth0.rx)", expected);
31+
assertStatsEqual("sum(network.eth0.tx + last_over_time(network.eth0.rx))", expected);
32+
assertStatsEqual("sum(last_over_time(network.eth0.tx) + network.eth0.rx)", expected);
33+
assertStatsEqual(expected, expected);
34+
}
35+
36+
public void testComplexArithmetic() {
37+
var expected = "sum(last_over_time(network.eth0.tx) / 2 + last_over_time(network.eth0.rx) * 2)";
38+
assertStatsEqual("sum(network.eth0.tx / 2 + network.eth0.rx * 2)", expected);
39+
assertStatsEqual("sum(last_over_time(network.eth0.tx) / 2 + network.eth0.rx * 2)", expected);
40+
assertStatsEqual("sum(network.eth0.tx / 2 + last_over_time(network.eth0.rx) * 2)", expected);
41+
assertStatsEqual(expected, expected);
42+
}
43+
44+
private void assertStatsEqual(String stats1, String stats2) {
45+
var baseQuery = """
46+
TS k8s
47+
| STATS %s BY bucket(@timestamp, 1 minute)
48+
| LIMIT 10
49+
""";
50+
var plan1 = rule.apply(metricsAnalyzer.analyze(parser.parseQuery(String.format(Locale.ROOT, baseQuery, stats1))));
51+
Function<Alias, Expression> ignoreAliasName = (Alias a) -> new Alias(a.source(), "dummy", a.child(), a.id());
52+
plan1 = plan1.transformExpressionsDown(Alias.class, ignoreAliasName);
53+
var plan2 = metricsAnalyzer.analyze(parser.parseQuery(String.format(Locale.ROOT, baseQuery, stats2)));
54+
plan2 = plan2.transformExpressionsDown(Alias.class, ignoreAliasName);
55+
assertThat(ignoreIds(plan1), equalToIgnoringIds(plan2));
56+
}
57+
58+
}

0 commit comments

Comments
 (0)