diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index 3c1c3ea8322..4de7ca48059 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -472,8 +472,16 @@ func (ivt *intervalTree) Insert(ivl Interval, val any) { x := ivt.root for x != ivt.sentinel { y = x - if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 { + // Split on left endpoint. If left endpoints match, instead split on right endpoint. + beginCompare := z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) + if beginCompare < 0 { x = x.left + } else if beginCompare == 0 { + if z.iv.Ivl.End.Compare(x.iv.Ivl.End) < 0 { + x = x.left + } else { + x = x.right + } } else { x = x.right } @@ -483,8 +491,15 @@ func (ivt *intervalTree) Insert(ivl Interval, val any) { if y == ivt.sentinel { ivt.root = z } else { - if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 { + beginCompare := z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) + if beginCompare < 0 { y.left = z + } else if beginCompare == 0 { + if z.iv.Ivl.End.Compare(y.iv.Ivl.End) < 0 { + y.left = z + } else { + y.right = z + } } else { y.right = z } @@ -701,16 +716,29 @@ func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) { // find the exact node for a given interval func (ivt *intervalTree) find(ivl Interval) *intervalNode { - ret := ivt.sentinel - f := func(n *intervalNode) bool { - if n.iv.Ivl != ivl { - return true + x := ivt.root + // Search until hit sentinel or exact match. + for x != ivt.sentinel { + beginCompare := ivl.Begin.Compare(x.iv.Ivl.Begin) + endCompare := ivl.End.Compare(x.iv.Ivl.End) + if beginCompare == 0 && endCompare == 0 { + return x + } + // Split on left endpoint. If left endpoints match, + // instead split on right endpoints. + if beginCompare < 0 { + x = x.left + } else if beginCompare == 0 { + if endCompare < 0 { + x = x.left + } else { + x = x.right + } + } else { + x = x.right } - ret = n - return false } - ivt.root.visit(&ivl, ivt.sentinel, f) - return ret + return x } // Find gets the IntervalValue for the node matching the given interval