Skip to content

Commit c6a48fa

Browse files
committed
Lock sorted fronts
1 parent c8d0683 commit c6a48fa

File tree

5 files changed

+40
-34
lines changed

5 files changed

+40
-34
lines changed

cache.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (f *fronted) prepopulateFronts(cacheFile string) {
4242
now := time.Now()
4343

4444
// update last succeeded status of masquerades based on cached values
45-
for _, fr := range f.fronts {
45+
for _, fr := range f.fronts.fronts {
4646
for _, cf := range cachedFronts {
4747
sameFront := cf.ProviderID == fr.getProviderID() && cf.Domain == fr.getDomain() && cf.IpAddress == fr.getIpAddress()
4848
cachedValueFresh := now.Sub(fr.lastSucceeded()) < f.maxAllowedCachedAge
@@ -81,10 +81,7 @@ func (f *fronted) maintainCache(cacheFile string) {
8181
func (f *fronted) updateCache(cacheFile string) {
8282
log.Debugf("Updating cache at %v", cacheFile)
8383
cache := f.fronts.sortedCopy()
84-
sizeToSave := len(cache)
85-
if f.maxCacheSize < sizeToSave {
86-
sizeToSave = f.maxCacheSize
87-
}
84+
sizeToSave := min(f.maxCacheSize, len(cache))
8885
b, err := json.Marshal(cache[:sizeToSave])
8986
if err != nil {
9087
log.Errorf("Unable to marshal cache to JSON: %v", err)

cache_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func TestCaching(t *testing.T) {
2929
log.Debug("Creating fronted")
3030
makeFronted := func() *fronted {
3131
f := &fronted{
32-
fronts: make(sortedFronts, 0, 1000),
32+
fronts: newSortedFronts(0),
3333
maxAllowedCachedAge: 250 * time.Millisecond,
3434
maxCacheSize: 4,
3535
cacheSaveInterval: 50 * time.Millisecond,
@@ -51,7 +51,7 @@ func TestCaching(t *testing.T) {
5151
f := makeFronted()
5252

5353
log.Debug("Adding fronts")
54-
f.fronts = append(f.fronts, mb, mc, md)
54+
f.fronts.fronts = append(f.fronts.fronts, mb, mc, md)
5555

5656
readCached := func() []*front {
5757
log.Debug("Reading cached fronts")

front.go

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -360,20 +360,28 @@ func NewStatusCodeValidator(reject []int) ResponseValidator {
360360
}
361361
}
362362

363-
// slice of masquerade sorted by last vetted time
364-
type sortedFronts []Front
363+
// fronts sorted by last vetted time
364+
type sortedFronts struct {
365+
fronts []Front
366+
mu sync.RWMutex
367+
}
365368

366-
var frontsMu sync.RWMutex
369+
func newSortedFronts(size int) *sortedFronts {
370+
return &sortedFronts{
371+
fronts: make([]Front, size),
372+
mu: sync.RWMutex{},
373+
}
374+
}
367375

368-
func (m sortedFronts) Len() int { return len(m) }
369-
func (m sortedFronts) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
370-
func (m sortedFronts) Less(i, j int) bool {
371-
if m[i].lastSucceeded().After(m[j].lastSucceeded()) {
376+
func (m *sortedFronts) Len() int { return len(m.fronts) }
377+
func (m *sortedFronts) Swap(i, j int) { m.fronts[i], m.fronts[j] = m.fronts[j], m.fronts[i] }
378+
func (m *sortedFronts) Less(i, j int) bool {
379+
if m.fronts[i].lastSucceeded().After(m.fronts[j].lastSucceeded()) {
372380
return true
373-
} else if m[j].lastSucceeded().After(m[i].lastSucceeded()) {
381+
} else if m.fronts[j].lastSucceeded().After(m.fronts[i].lastSucceeded()) {
374382
return false
375383
} else {
376-
return m[i].getIpAddress() < m[j].getIpAddress()
384+
return m.fronts[i].getIpAddress() < m.fronts[j].getIpAddress()
377385
}
378386
}
379387

@@ -382,26 +390,27 @@ func (m *sortedFronts) sortedCopy() []Front {
382390
defer m.mu.Unlock()
383391
c := make([]Front, len(m.fronts))
384392
copy(c, m.fronts)
385-
sort.Sort(sortedFronts{fronts: c})
393+
sf := sortedFronts{fronts: c}
394+
sort.Sort(&sf)
386395
return c
387396
}
388397

389-
func (m *sortedFronts) addFronts(fronts []Front) {
398+
func (m *sortedFronts) addFronts(fronts *sortedFronts) {
390399
// Add new masquerades to the existing masquerades slice, but add them at the beginning.
391400
m.mu.Lock()
392401
defer m.mu.Unlock()
393-
m.fronts = append(fronts, m.fronts...)
402+
m.fronts = append(fronts.fronts, m.fronts...)
394403
}
395404

396405
func (m *sortedFronts) size() int {
397-
m.mu.Lock()
398-
defer m.mu.Unlock()
406+
m.mu.RLock()
407+
defer m.mu.RUnlock()
399408
return len(m.fronts)
400409
}
401410

402411
func (m *sortedFronts) frontAt(i int) Front {
403-
m.mu.Lock()
404-
defer m.mu.Unlock()
412+
m.mu.RLock()
413+
defer m.mu.RUnlock()
405414
return m.fronts[i]
406415
}
407416

fronted.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ var (
4949
// an implementation of http.RoundTripper for the convenience of callers.
5050
type fronted struct {
5151
certPool atomic.Value
52-
fronts sortedFronts
52+
fronts *sortedFronts
5353
maxAllowedCachedAge time.Duration
5454
maxCacheSize int
5555
cacheFile string
@@ -102,7 +102,7 @@ func NewFronted(options ...Option) Fronted {
102102

103103
f := &fronted{
104104
certPool: atomic.Value{},
105-
fronts: make(sortedFronts, 0),
105+
fronts: newSortedFronts(1000),
106106
maxAllowedCachedAge: defaultMaxAllowedCachedAge,
107107
maxCacheSize: defaultMaxCacheSize,
108108
cacheSaveInterval: defaultCacheSaveInterval,
@@ -559,7 +559,7 @@ func copyProviders(providers map[string]*Provider, countryCode string) map[strin
559559
return providersCopy
560560
}
561561

562-
func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) sortedFronts {
562+
func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) *sortedFronts {
563563
log.Debugf("Loading candidates for %d providers", len(providers))
564564
defer log.Debug("Finished loading candidates")
565565

@@ -569,7 +569,7 @@ func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) sor
569569
size += len(p.Masquerades)
570570
}
571571

572-
fronts := make(sortedFronts, size)
572+
fronts := newSortedFronts(size)
573573

574574
// Note that map iteration order is random, so the order of the providers is automatically randomized.
575575
index := 0
@@ -588,7 +588,7 @@ func loadFronts(providers map[string]*Provider, cacheDirty chan interface{}) sor
588588
}
589589

590590
for _, c := range sh {
591-
fronts[index] = newFront(c, key, cacheDirty)
591+
fronts.fronts[index] = newFront(c, key, cacheDirty)
592592
index++
593593
}
594594
}

fronted_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func doTestDomainFronting(t *testing.T, cacheFile string, expectedMasqueradesAtE
155155
// Check the number of masquerades at the end, waiting until we get the right number
156156
masqueradesAtEnd := 0
157157
for i := 0; i < 1000; i++ {
158-
masqueradesAtEnd = len(d.fronts)
158+
masqueradesAtEnd = len(d.fronts.fronts)
159159
if masqueradesAtEnd == expectedMasqueradesAtEnd {
160160
break
161161
}
@@ -761,9 +761,9 @@ func TestFindWorkingMasquerades(t *testing.T) {
761761
}
762762
f.providers = make(map[string]*Provider)
763763
f.providers["testProviderId"] = NewProvider(nil, "", nil, nil, nil, nil, "")
764-
f.fronts = make(sortedFronts, len(tt.masquerades))
765-
for i, m := range tt.masquerades {
766-
f.fronts[i] = m
764+
f.fronts = newSortedFronts(0)
765+
for _, m := range tt.masquerades {
766+
f.fronts.fronts = append(f.fronts.fronts, m)
767767
}
768768

769769
f.tryAllFronts()
@@ -806,9 +806,9 @@ func TestLoadFronts(t *testing.T) {
806806
cacheDirty := make(chan interface{}, 10)
807807
masquerades := loadFronts(providers, cacheDirty)
808808

809-
assert.Equal(t, 4, len(masquerades), "Unexpected number of masquerades loaded")
809+
assert.Equal(t, 4, len(masquerades.fronts), "Unexpected number of masquerades loaded")
810810

811-
for _, m := range masquerades {
811+
for _, m := range masquerades.fronts {
812812
assert.True(t, expected[m.getDomain()], "Unexpected masquerade domain: %s", m.getDomain())
813813
}
814814
}

0 commit comments

Comments
 (0)