Skip to content
This repository was archived by the owner on Nov 7, 2025. It is now read-only.

Commit 5b7ed7f

Browse files
authored
[sample_flights] Fix total in top_hits response (#1337)
Makes the last skipped `sample_flights` test pass. In other aggregations we have `parent_count` for that. Here I don't add a new column to select, just extract from the existing ones, as it's always there and I find it much simpler to fix it this way. We can change that later if needed.
1 parent 97eec91 commit 5b7ed7f

File tree

5 files changed

+126
-23
lines changed

5 files changed

+126
-23
lines changed

platform/model/metrics_aggregations/top_hits.go

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"context"
77
"github.com/QuesmaOrg/quesma/platform/logger"
88
"github.com/QuesmaOrg/quesma/platform/model"
9+
"github.com/QuesmaOrg/quesma/platform/util"
10+
"strings"
911
)
1012

1113
type TopHits struct {
@@ -23,7 +25,6 @@ func (query *TopHits) AggregationType() model.AggregationType {
2325
return model.MetricsAggregation
2426
}
2527

26-
// TODO: implement correct
2728
func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) model.JsonMap {
2829
var topElems []any
2930
if len(rows) > 0 && 0 >= len(rows[0].Cols) {
@@ -39,7 +40,13 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
3940
continue
4041
}
4142

42-
valuesForHits := row.Cols
43+
var valuesForHits []model.QueryResultCol
44+
if query.isCount(row.Cols[0]) {
45+
valuesForHits = row.Cols[1:]
46+
} else {
47+
valuesForHits = row.Cols
48+
}
49+
4350
sourceMap := model.JsonMap{}
4451

4552
for _, col := range valuesForHits {
@@ -63,13 +70,19 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
6370
if len(topElems) == 0 {
6471
maxScore = nil
6572
}
73+
74+
var total int
75+
if len(rows) > 0 {
76+
total = query.getCount(&rows[0])
77+
}
78+
6679
return model.JsonMap{
6780
"hits": model.JsonMap{
6881
"hits": topElems,
6982
"max_score": maxScore, // placeholder
7083
"total": model.JsonMap{ // could be better
7184
"relation": "eq", // TODO: wrong, but let's pass test, it should ge geq
72-
"value": len(topElems),
85+
"value": total,
7386
},
7487
},
7588
}
@@ -78,3 +91,19 @@ func (query *TopHits) TranslateSqlResponseToJson(rows []model.QueryResultRow) mo
7891
func (query *TopHits) String() string {
7992
return "top_hits"
8093
}
94+
95+
func (query *TopHits) getCount(row *model.QueryResultRow) int {
96+
if len(row.Cols) == 0 {
97+
return 0
98+
}
99+
if asInt, ok := util.ExtractInt64Maybe(row.Cols[0].ExtractValue()); ok {
100+
return int(asInt)
101+
} else {
102+
logger.WarnWithCtxAndThrottling(query.ctx, "top_hits", "count", "could not extract count from top_hits, row: %v", row)
103+
return 0
104+
}
105+
}
106+
107+
func (query *TopHits) isCount(col model.QueryResultCol) bool {
108+
return strings.HasSuffix(col.ColName, "count")
109+
}

platform/parsers/elastic_query_dsl/pancake_json_rendering.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ func (p *pancakeJSONRenderer) selectTopHitsRows(topAggr *pancakeModelMetricAggre
6363
}
6464
newCols = append(newCols, col)
6565
}
66+
} else if topAggr.isColumnParentCount(col.ColName) {
67+
// top_hits needs parent count, when it's available
68+
newCols = append(newCols, col)
6669
}
6770
}
6871
result = append(result, model.QueryResultRow{Index: row.Index, Cols: newCols})

platform/parsers/elastic_query_dsl/pancake_model.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,25 @@ func (p pancakeModelMetricAggregation) InternalNameForCol(id int) string {
118118
return fmt.Sprintf("%s%d", p.InternalNamePrefix(), id)
119119
}
120120

121+
// isColumnParentCount checks if `internalName` is a parent count column for this metric aggregation
122+
// Only tested/works for `top_hits`, not needed anywhere else.
123+
func (p pancakeModelMetricAggregation) isColumnParentCount(internalNameMaybeParent string) bool {
124+
// We return true only when:
125+
// p.internalName ==."top_hits__[AGG_PATH]__[name]"
126+
// AND internalNameMaybeParent == "aggr__[AGG_PATH]__count"
127+
// (AGG_PATH must be the same)
128+
thisAggrRegex := regexp.MustCompile("top_hits__([a-zA-Z0-9_]+)__[a-zA-Z0-9_]+")
129+
maybeParentRegex := regexp.MustCompile("aggr__([a-zA-Z0-9_]+)__count")
130+
if !thisAggrRegex.MatchString(p.internalName) || !maybeParentRegex.MatchString(internalNameMaybeParent) {
131+
return false
132+
}
133+
134+
matchThisAggr := thisAggrRegex.FindStringSubmatch(p.InternalNamePrefix())
135+
matchMaybeParent := maybeParentRegex.FindStringSubmatch(internalNameMaybeParent)
136+
// [1] is the first capturing group in the regex (called AGG_PATH above). It's ([a-zA-Z0-9_]+) from the regex
137+
return len(matchThisAggr) == 2 && len(matchMaybeParent) == 2 && matchThisAggr[1] == matchMaybeParent[1]
138+
}
139+
121140
func (p pancakeModelBucketAggregation) ShallowClone() pancakeModelBucketAggregation {
122141
return pancakeModelBucketAggregation{
123142
name: p.name,

platform/parsers/elastic_query_dsl/pancake_sql_query_generation_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@ func TestPancakeQueryGeneration(t *testing.T) {
4848

4949
for i, test := range allAggregationTests() {
5050
t.Run(test.TestName+"("+strconv.Itoa(i)+")", func(t *testing.T) {
51-
// sample_flights
52-
if test.TestName == "TODO Airport Connections (Hover Over Airport)(file:kibana-sample-data-flights,nr:14)" {
53-
t.Skip("Fixing right now")
54-
}
5551
// sample_ecommerce
5652
if test.TestName == "TODO Top products this/last week(file:kibana-sample-data-ecommerce,nr:9)" {
5753
t.Skip("works IRL, need to update test's schema. It's already WIP https://github.com/QuesmaOrg/quesma/pull/1255. Let's wait for merge.")

platform/testdata/kibana_sample_data_flights.go

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,8 +2735,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
27352735
"_score": 1.0,
27362736
"_source": {
27372737
"DestLocation": {
2738-
"lat": "-34.8222",
2739-
"lon": "-58.5358"
2738+
"lat": -34.8222,
2739+
"lon": -58.5358
27402740
}
27412741
}
27422742
}
@@ -2761,8 +2761,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
27612761
"_score": 1.0,
27622762
"_source": {
27632763
"DestLocation": {
2764-
"lat": "-0.129166667",
2765-
"lon": "-78.3575"
2764+
"lat": -0.129166667,
2765+
"lon": -78.3575
27662766
}
27672767
}
27682768
}
@@ -2793,8 +2793,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
27932793
"_source": {
27942794
"Origin": "Mariscal Sucre International Airport",
27952795
"OriginLocation": {
2796-
"lat": "-0.129166667",
2797-
"lon": "-78.3575"
2796+
"lat": -0.129166667,
2797+
"lon": -78.3575
27982798
}
27992799
}
28002800
}
@@ -2820,8 +2820,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
28202820
"_score": 1.0,
28212821
"_source": {
28222822
"DestLocation": {
2823-
"lat": "45.47060013",
2824-
"lon": "-73.74079895"
2823+
"lat": 45.47060013,
2824+
"lon": -73.74079895
28252825
}
28262826
}
28272827
}
@@ -2846,8 +2846,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
28462846
"_score": 1.0,
28472847
"_source": {
28482848
"DestLocation": {
2849-
"lat": "-34.8222",
2850-
"lon": "-58.5358"
2849+
"lat": -34.8222,
2850+
"lon": -58.5358
28512851
}
28522852
}
28532853
}
@@ -2878,8 +2878,8 @@ var KibanaSampleDataFlights = []AggregationTestCase{
28782878
"_source": {
28792879
"Origin": "Ministro Pistarini International Airport",
28802880
"OriginLocation": {
2881-
"lat": "-34.8222",
2882-
"lon": "-58.5358"
2881+
"lat": -34.8222,
2882+
"lon": -58.5358
28832883
}
28842884
}
28852885
}
@@ -2894,15 +2894,15 @@ var KibanaSampleDataFlights = []AggregationTestCase{
28942894
}
28952895
],
28962896
"doc_count_error_upper_bound": 0,
2897-
"sum_other_doc_count": 12474
2897+
"sum_other_doc_count": 1460
28982898
}
28992899
},
29002900
"hits": {
29012901
"hits": [],
29022902
"max_score": null,
29032903
"total": {
29042904
"relation": "eq",
2905-
"value": 13014
2905+
"value": 2000
29062906
}
29072907
},
29082908
"timed_out": false,
@@ -2912,15 +2912,71 @@ var KibanaSampleDataFlights = []AggregationTestCase{
29122912
}`,
29132913
ExpectedPancakeResults: []model.QueryResultRow{
29142914
{Cols: []model.QueryResultCol{
2915-
model.NewQueryResultCol("aggr__origins__parent_count", int64(283)),
2915+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
29162916
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
29172917
model.NewQueryResultCol("aggr__origins__count", int64(283)),
29182918
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
29192919
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
29202920
model.NewQueryResultCol("aggr__origins__distinations__count", int64(21)),
2921-
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", "[-34.8222, -58.5358]"),
2921+
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -34.8222, "lon": -58.5358}),
29222922
model.NewQueryResultCol("top_hits_rank", int64(1)),
29232923
}},
2924+
{Cols: []model.QueryResultCol{
2925+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
2926+
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
2927+
model.NewQueryResultCol("aggr__origins__count", int64(283)),
2928+
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
2929+
model.NewQueryResultCol("aggr__origins__distinations__key_0", "UIO"),
2930+
model.NewQueryResultCol("aggr__origins__distinations__count", int64(12)),
2931+
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -0.129167, "lon": -78.3575}),
2932+
model.NewQueryResultCol("top_hits_rank", int64(1)),
2933+
}},
2934+
{Cols: []model.QueryResultCol{
2935+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
2936+
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
2937+
model.NewQueryResultCol("aggr__origins__count", int64(257)),
2938+
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
2939+
model.NewQueryResultCol("aggr__origins__distinations__key_0", "YUL"),
2940+
model.NewQueryResultCol("aggr__origins__distinations__count", int64(11)),
2941+
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": 45.470600, "lon": -73.740799}),
2942+
model.NewQueryResultCol("top_hits_rank", int64(1)),
2943+
}},
2944+
{Cols: []model.QueryResultCol{
2945+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
2946+
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
2947+
model.NewQueryResultCol("aggr__origins__count", int64(257)),
2948+
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
2949+
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
2950+
model.NewQueryResultCol("aggr__origins__distinations__count", int64(10)),
2951+
model.NewQueryResultCol("top_hits__origins__distinations__destLocation_col_0", model.JsonMap{"lat": -34.822200, "lon": -58.535800}),
2952+
model.NewQueryResultCol("top_hits_rank", int64(1)),
2953+
}},
2954+
},
2955+
ExpectedAdditionalPancakeResults: [][]model.QueryResultRow{
2956+
{
2957+
{Cols: []model.QueryResultCol{
2958+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
2959+
model.NewQueryResultCol("aggr__origins__key_0", "UIO"),
2960+
model.NewQueryResultCol("aggr__origins__count", int64(283)),
2961+
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(283)),
2962+
model.NewQueryResultCol("aggr__origins__distinations__key_0", "EZE"),
2963+
model.NewQueryResultCol("aggr__origins__distinations__count", int64(21)),
2964+
model.NewQueryResultCol("top_hits__origins__originLocation_col_0", model.JsonMap{"lat": -0.129167, "lon": -78.3575}),
2965+
model.NewQueryResultCol("top_hits__origins__originLocation_col_1", "Mariscal Sucre International Airport"),
2966+
model.NewQueryResultCol("top_hits_rank", int64(1)),
2967+
}},
2968+
{Cols: []model.QueryResultCol{
2969+
model.NewQueryResultCol("aggr__origins__parent_count", int64(2000)),
2970+
model.NewQueryResultCol("aggr__origins__key_0", "EZE"),
2971+
model.NewQueryResultCol("aggr__origins__count", int64(257)),
2972+
model.NewQueryResultCol("aggr__origins__distinations__parent_count", int64(257)),
2973+
model.NewQueryResultCol("aggr__origins__distinations__key_0", "YUL"),
2974+
model.NewQueryResultCol("aggr__origins__distinations__count", int64(11)),
2975+
model.NewQueryResultCol("top_hits__origins__originLocation_col_0", model.JsonMap{"lat": -34.822200, "lon": -58.535800}),
2976+
model.NewQueryResultCol("top_hits__origins__originLocation_col_1", "Ministro Pistarini International Airport"),
2977+
model.NewQueryResultCol("top_hits_rank", int64(1)),
2978+
}},
2979+
},
29242980
},
29252981
ExpectedPancakeSQL: `
29262982
WITH quesma_top_hits_group_table AS (

0 commit comments

Comments
 (0)