Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ func nrand(n int) []int {
return i
}

func BenchmarkNewSet(b *testing.B) {
nums := nrand(1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = NewSet(nums...)
}
}

func BenchmarkNewThreadUnsafeSet(b *testing.B) {
nums := nrand(1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = NewThreadUnsafeSet(nums...)
}
}

func benchAdd(b *testing.B, n int, newSet func(...int) Set[int]) {
nums := nrand(n)
b.ResetTimer()
Expand All @@ -57,6 +73,23 @@ func BenchmarkAddUnsafe(b *testing.B) {
benchAdd(b, 1000, NewThreadUnsafeSet[int])
}

func benchAppend(b *testing.B, n int, newSet func(...int) Set[int]) {
nums := nrand(n)
b.ResetTimer()
for i := 0; i < b.N; i++ {
s := newSet()
s.Append(nums...)
}
}

func BenchmarkAppendSafe(b *testing.B) {
benchAppend(b, 1000, NewSet[int])
}

func BenchmarkAppendUnsafe(b *testing.B) {
benchAppend(b, 1000, NewThreadUnsafeSet[int])
}

func benchRemove(b *testing.B, s Set[int]) {
nums := nrand(b.N)
for _, v := range nums {
Expand Down
30 changes: 12 additions & 18 deletions set.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,45 +217,39 @@ type Set[T comparable] interface {

// NewSet creates and returns a new set with the given elements.
// Operations on the resulting set are thread-safe.
func NewSet[T comparable](vals ...T) Set[T] {
s := newThreadSafeSetWithSize[T](len(vals))
for _, item := range vals {
s.Add(item)
}
func NewSet[T comparable](vs ...T) Set[T] {
s := newThreadSafeSetWithSize[T](len(vs))
s.uss.append(vs...)
return s
}

// NewSetWithSize creates and returns a reference to an empty set with a specified
// capacity. Operations on the resulting set are thread-safe.
func NewSetWithSize[T comparable](cardinality int) Set[T] {
s := newThreadSafeSetWithSize[T](cardinality)
return s
return newThreadSafeSetWithSize[T](cardinality)
}

// NewThreadUnsafeSet creates and returns a new set with the given elements.
// Operations on the resulting set are not thread-safe.
func NewThreadUnsafeSet[T comparable](vals ...T) Set[T] {
s := newThreadUnsafeSetWithSize[T](len(vals))
for _, item := range vals {
s.Add(item)
}
func NewThreadUnsafeSet[T comparable](vs ...T) Set[T] {
s := newThreadUnsafeSetWithSize[T](len(vs))
s.append(vs...)
return s
}

// NewThreadUnsafeSetWithSize creates and returns a reference to an empty set with
// a specified capacity. Operations on the resulting set are not thread-safe.
func NewThreadUnsafeSetWithSize[T comparable](cardinality int) Set[T] {
s := newThreadUnsafeSetWithSize[T](cardinality)
return s
return newThreadUnsafeSetWithSize[T](cardinality)
}

// NewSetFromMapKeys creates and returns a new set with the given keys of the map.
// Operations on the resulting set are thread-safe.
func NewSetFromMapKeys[T comparable, V any](val map[T]V) Set[T] {
s := NewSetWithSize[T](len(val))
s := newThreadSafeSetWithSize[T](len(val))

for k := range val {
s.Add(k)
s.uss.add(k)
}

return s
Expand All @@ -264,10 +258,10 @@ func NewSetFromMapKeys[T comparable, V any](val map[T]V) Set[T] {
// NewThreadUnsafeSetFromMapKeys creates and returns a new set with the given keys of the map.
// Operations on the resulting set are not thread-safe.
func NewThreadUnsafeSetFromMapKeys[T comparable, V any](val map[T]V) Set[T] {
s := NewThreadUnsafeSetWithSize[T](len(val))
s := newThreadUnsafeSetWithSize[T](len(val))

for k := range val {
s.Add(k)
s.add(k)
}

return s
Expand Down
44 changes: 24 additions & 20 deletions threadunsafe.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,30 @@ func newThreadUnsafeSetWithSize[T comparable](cardinality int) *threadUnsafeSet[
return &t
}

func (s threadUnsafeSet[T]) Add(v T) bool {
prevLen := len(s)
s[v] = struct{}{}
return prevLen != len(s)
}

func (s *threadUnsafeSet[T]) Append(v ...T) int {
prevLen := len(*s)
for _, val := range v {
(*s)[val] = struct{}{}
}
return len(*s) - prevLen
func (s *threadUnsafeSet[T]) Add(v T) bool {
prevLen := s.Cardinality()
s.add(v)
return prevLen != s.Cardinality()
}

// private version of Add which doesn't return a value
func (s *threadUnsafeSet[T]) add(v T) {
(*s)[v] = struct{}{}
}

func (s *threadUnsafeSet[T]) Append(vs ...T) int {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like an append() is needed as a counterpart to add(), since the optimized logic is used in multiple places:

func  (s *threadUnsafeSet[T]) Append(vs ...T) {
    prevLen := s.Cardinality()
	s.append(v)
	return prevLen != s.Cardinality()
}

func (s *threadUnsafeSet[T]) append(vs ...T) {
    for i := range vs {
        s.add(vs[i])
    }
}
// s.append can then be used in JSON/BSON unmarshaling

prevLen := s.Cardinality()
s.append(vs...)
return s.Cardinality() - prevLen
}

// private version of Append which doesn't return a value
func (s *threadUnsafeSet[T]) append(vs ...T) {
for i := range vs {
s.add(vs[i])
}
}

func (s *threadUnsafeSet[T]) Cardinality() int {
return len(*s)
}
Expand All @@ -91,21 +96,20 @@ func (s *threadUnsafeSet[T]) Clone() Set[T] {

func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
for _, val := range v {
if _, ok := (*s)[val]; !ok {
if !s.contains(val) {
return false
}
}
return true
}

func (s *threadUnsafeSet[T]) ContainsOne(v T) bool {
_, ok := (*s)[v]
return ok
return s.contains(v)
}

func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool {
for _, val := range v {
if _, ok := (*s)[val]; ok {
if s.contains(val) {
return true
}
}
Expand Down Expand Up @@ -134,8 +138,8 @@ func (s *threadUnsafeSet[T]) ContainsAnyElement(other Set[T]) bool {

// private version of Contains for a single element v
func (s *threadUnsafeSet[T]) contains(v T) (ok bool) {
_, ok = (*s)[v]
return ok
_, found := (*s)[v]
return found
}

func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
Expand Down Expand Up @@ -372,7 +376,7 @@ func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
s.Append(i...)
s.append(i...)

return nil
}
Expand All @@ -393,7 +397,7 @@ func (s threadUnsafeSet[T]) UnmarshalBSONValue(bt bsontype.Type, b []byte) error
if err != nil {
return err
}
s.Append(i...)
s.append(i...)

return nil
}
Loading