Skip to content

Commit 7da1b90

Browse files
committed
refactor: move quantile utilities into internal/quantiles
1 parent bceed01 commit 7da1b90

8 files changed

Lines changed: 237 additions & 24 deletions

File tree

File renamed without changes.
File renamed without changes.

internal/quantiles/sorted_view.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package quantiles
19+
20+
// MinMaxResult holds a pair of adjusted sorted-view arrays.
21+
type MinMaxResult[T Number] struct {
22+
Quantiles []T
23+
CumWeights []int64
24+
}
25+
26+
// IncludeMinMax reinserts the true min and max items into sorted-view
27+
// arrays when they are missing from the retained quantiles.
28+
//
29+
// The returned cumulative weights remain in cumulative form. When no
30+
// adjustment is required, the input slices are returned unchanged.
31+
func IncludeMinMax[T Number](
32+
quantiles []T,
33+
cumWeights []int64,
34+
maxItem, minItem T,
35+
) MinMaxResult[T] {
36+
lenIn := len(cumWeights)
37+
adjLow := quantiles[0] != minItem
38+
adjHigh := quantiles[lenIn-1] != maxItem
39+
adjLen := lenIn
40+
if adjLow {
41+
adjLen++
42+
}
43+
if adjHigh {
44+
adjLen++
45+
}
46+
47+
if adjLen == lenIn {
48+
return MinMaxResult[T]{
49+
Quantiles: quantiles,
50+
CumWeights: cumWeights,
51+
}
52+
}
53+
54+
adjQuantiles := make([]T, adjLen)
55+
adjCumWeights := make([]int64, adjLen)
56+
offset := 0
57+
if adjLow {
58+
offset = 1
59+
}
60+
copy(adjQuantiles[offset:], quantiles[:lenIn])
61+
copy(adjCumWeights[offset:], cumWeights[:lenIn])
62+
63+
if adjLow {
64+
adjQuantiles[0] = minItem
65+
adjCumWeights[0] = 1
66+
}
67+
68+
if adjHigh {
69+
adjQuantiles[adjLen-1] = maxItem
70+
adjCumWeights[adjLen-1] = cumWeights[lenIn-1]
71+
adjCumWeights[adjLen-2] = cumWeights[lenIn-1] - 1
72+
}
73+
74+
return MinMaxResult[T]{
75+
Quantiles: adjQuantiles,
76+
CumWeights: adjCumWeights,
77+
}
78+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package quantiles
19+
20+
import (
21+
"testing"
22+
23+
"github.com/stretchr/testify/assert"
24+
)
25+
26+
func TestIncludeMinMax(t *testing.T) {
27+
t.Run("Adjust Both Ends", func(t *testing.T) {
28+
quantiles := []float32{2, 4, 6, 7}
29+
cumWeights := []int64{2, 4, 6, 8}
30+
31+
got := IncludeMinMax(quantiles, cumWeights, 8, 1)
32+
33+
wantQuantiles := []float32{1, 2, 4, 6, 7, 8}
34+
wantCumWeights := []int64{1, 2, 4, 6, 7, 8}
35+
36+
assert.Equal(t, wantQuantiles, got.Quantiles)
37+
assert.Equal(t, wantCumWeights, got.CumWeights)
38+
})
39+
40+
t.Run("Return original slices", func(t *testing.T) {
41+
quantiles := []float32{2, 4, 6, 7}
42+
cumWeights := []int64{2, 4, 6, 8}
43+
44+
got := IncludeMinMax(quantiles, cumWeights, 7, 2)
45+
46+
assert.Equal(t, quantiles, got.Quantiles)
47+
assert.Equal(t, cumWeights, got.CumWeights)
48+
})
49+
50+
t.Run("Adjust Low End Only", func(t *testing.T) {
51+
quantiles := []float32{2, 4, 6, 8}
52+
cumWeights := []int64{2, 4, 6, 8}
53+
54+
got := IncludeMinMax(quantiles, cumWeights, 8, 1)
55+
56+
wantQuantiles := []float32{1, 2, 4, 6, 8}
57+
wantCumWeights := []int64{1, 2, 4, 6, 8}
58+
59+
assert.Equal(t, wantQuantiles, got.Quantiles)
60+
assert.Equal(t, wantCumWeights, got.CumWeights)
61+
})
62+
63+
t.Run("Adjust High End Only", func(t *testing.T) {
64+
quantiles := []float32{1, 2, 4, 6}
65+
cumWeights := []int64{1, 2, 4, 8}
66+
67+
got := IncludeMinMax(quantiles, cumWeights, 8, 1)
68+
69+
wantQuantiles := []float32{1, 2, 4, 6, 8}
70+
wantCumWeights := []int64{1, 2, 4, 7, 8}
71+
72+
assert.Equal(t, wantQuantiles, got.Quantiles)
73+
assert.Equal(t, wantCumWeights, got.CumWeights)
74+
})
75+
}

internal/quantiles/utils.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ import (
2222
"math"
2323
)
2424

25-
var ErrNanInSplitPoints = errors.New("NaN in split points")
25+
const tailRoundingFactor = 1e7
2626

27-
var ErrInvalidSplitPoints = errors.New("values must be unique and monotonically increasing")
27+
// Sentinel errors reported for invalid split points.
28+
var (
29+
ErrNanInSplitPoints = errors.New("NaN in split points")
30+
ErrInvalidSplitPoints = errors.New("values must be unique and monotonically increasing")
31+
)
2832

2933
type Number interface {
3034
float32 | float64 | int64
@@ -41,3 +45,22 @@ func ValidateSplitPoints[N Number](values []N) error {
4145
}
4246
return nil
4347
}
48+
49+
// ValidateNormalizedRankBounds returns an error unless rank lies in the
// closed interval [0, 1].
//
// NaN is rejected explicitly: every ordered comparison against NaN is false,
// so the range comparisons alone would let it through as a valid rank.
func ValidateNormalizedRankBounds(rank float64) error {
	if math.IsNaN(rank) || rank < 0 || rank > 1 {
		return errors.New("rank must be between 0 and 1 inclusive")
	}
	return nil
}
55+
56+
// ComputeNaturalRank Computes the closest Natural Rank from a given Normalized Rank
57+
func ComputeNaturalRank(normalizedRank float64, totalN uint64, inclusive bool) int64 {
58+
naturalRank := normalizedRank * float64(totalN)
59+
if totalN <= tailRoundingFactor {
60+
naturalRank = math.Round(naturalRank*tailRoundingFactor) / tailRoundingFactor
61+
}
62+
if inclusive {
63+
return int64(math.Ceil(naturalRank))
64+
}
65+
return int64(math.Floor(naturalRank))
66+
}

internal/quantiles/utils_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,59 @@ func TestValidateSplitPoints(t *testing.T) {
4949
})
5050
}
5151
}
52+
53+
func TestValidateNormalizedRankBounds(t *testing.T) {
54+
tests := []struct {
55+
name string
56+
rank float64
57+
wantErr string
58+
}{
59+
{"below zero", -0.1, "rank must be between 0 and 1 inclusive"},
60+
{"zero", 0, ""},
61+
{"middle", 0.5, ""},
62+
{"one", 1, ""},
63+
{"above one", 1.1, "rank must be between 0 and 1 inclusive"},
64+
}
65+
for _, tt := range tests {
66+
t.Run(tt.name, func(t *testing.T) {
67+
err := ValidateNormalizedRankBounds(tt.rank)
68+
if tt.wantErr != "" {
69+
assert.EqualError(t, err, tt.wantErr)
70+
} else {
71+
assert.NoError(t, err)
72+
}
73+
})
74+
}
75+
}
76+
77+
func TestComputeNaturalRank(t *testing.T) {
78+
rankJustBelowThreeNoRounding := math.Nextafter(3.0/10000001.0, 0)
79+
80+
tests := []struct {
81+
name string
82+
normalizedRank float64
83+
totalN uint64
84+
inclusive bool
85+
want int64
86+
}{
87+
{"zero exclusive", 0, 10, false, 0},
88+
{"zero inclusive", 0, 10, true, 0},
89+
{"one exclusive", 1, 10, false, 10},
90+
{"one inclusive", 1, 10, true, 10},
91+
{"exact integer exclusive", 0.5, 10, false, 5},
92+
{"exact integer inclusive", 0.5, 10, true, 5},
93+
{"fractional exclusive floors", 0.21, 10, false, 2},
94+
{"fractional inclusive ceils", 0.21, 10, true, 3},
95+
{"rounding enabled exclusive", 0.299999996, 10, false, 3},
96+
{"rounding enabled inclusive", 0.299999996, 10, true, 3},
97+
{"rounding disabled exclusive", rankJustBelowThreeNoRounding, 10000001, false, 2},
98+
{"rounding disabled inclusive", rankJustBelowThreeNoRounding, 10000001, true, 3},
99+
}
100+
101+
for _, tt := range tests {
102+
t.Run(tt.name, func(t *testing.T) {
103+
got := ComputeNaturalRank(tt.normalizedRank, tt.totalN, tt.inclusive)
104+
assert.Equal(t, tt.want, got)
105+
})
106+
}
107+
}

kll/items_sketch_sorted_view.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/apache/datasketches-go/common"
2525
"github.com/apache/datasketches-go/internal"
26+
"github.com/apache/datasketches-go/internal/quantiles"
2627
)
2728

2829
type ItemsSketchSortedView[C comparable] struct {
@@ -98,7 +99,7 @@ func (s *ItemsSketchSortedView[C]) GetQuantile(rank float64, inclusive bool) (C,
9899
var zero C
99100
return zero, errors.New("empty sketch")
100101
}
101-
err := checkNormalizedRankBounds(rank)
102+
err := quantiles.ValidateNormalizedRankBounds(rank)
102103
if err != nil {
103104
var zero C
104105
return zero, err
@@ -155,7 +156,7 @@ func (s *ItemsSketchSortedView[C]) Iterator() *ItemsSketchSortedViewIterator[C]
155156

156157
func (s *ItemsSketchSortedView[C]) getQuantileIndex(rank float64, inclusive bool) (int, error) {
157158
length := len(s.quantiles)
158-
naturalRank := getNaturalRank(rank, s.totalN, inclusive)
159+
naturalRank := quantiles.ComputeNaturalRank(rank, s.totalN, inclusive)
159160
crit := internal.InequalityGT
160161
if inclusive {
161162
crit = internal.InequalityGE

kll/utils.go

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ import (
2828
)
2929

3030
const (
31-
tailRoundingFactor = 1e7
32-
3331
_PMF_COEF = 2.446
3432
_PMF_EXP = 0.9433
3533
_CDF_COEF = 2.296
@@ -45,17 +43,6 @@ func convertToCumulative(array []int64) int64 {
4543
return subtotal
4644
}
4745

48-
func getNaturalRank(normalizedRank float64, totalN uint64, inclusive bool) int64 {
49-
naturalRank := normalizedRank * float64(totalN)
50-
if totalN <= tailRoundingFactor {
51-
naturalRank = math.Round(naturalRank*tailRoundingFactor) / tailRoundingFactor
52-
}
53-
if inclusive {
54-
return int64(math.Ceil(naturalRank))
55-
}
56-
return int64(math.Floor(naturalRank))
57-
}
58-
5946
func checkK(k uint16, m uint8) error {
6047
if k < uint16(m) || k > _MAX_K {
6148
return errors.New("K must be >= " + strconv.Itoa(int(m)) + " and <= " + strconv.Itoa(_MAX_K) + ": " + strconv.Itoa(int(k)))
@@ -70,13 +57,6 @@ func checkM(m uint8) error {
7057
return nil
7158
}
7259

73-
func checkNormalizedRankBounds(rank float64) error {
74-
if rank < 0 || rank > 1 {
75-
return errors.New("rank must be between 0 and 1 inclusive")
76-
}
77-
return nil
78-
}
79-
8060
func checkItems[C comparable](items []C, compareFn common.CompareFn[C]) error {
8161
if len(items) == 1 && internal.IsNil(items[0]) {
8262
return errors.New("items must be unique, monotonically increasing and not nil")

0 commit comments

Comments
 (0)