Skip to content

Commit 61d569d

Browse files
authored
*: Implementation of weighted random shuffling (A113) (#8864)
This PR implements the currently in-review gRFC A113: grpc/proposal#535. I've split the PR into logically separate commits to help with the review process. Summary of changes: - Commit 1: simplify the implementation of `groupLocalitiesByPriority` - Change the implementation to use newly added methods in the stdlib `maps` and `slices` package to significantly simplify the implementation (and get rid of an unnecessary test) - Commit 2: Remove code that handles localities and endpoints of weight 0 - Remove unnecessary checks for locality and endpoint weights of `0` in `cluster_resolver`. The xDS client already guarantees that these weights will never be set to `0`. - Commit 3: add the env var GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING - Commit 4: Weight computation changes in cluster_resolver LB policy - This performs the weight normalization and fixed-point arithmetic specified in A113 - The change here is guarded by the above env var - Ended up duplicating the tests that verify the weight computation behavior. This will make it easier to delete the old tests when the env var is removed. - Commit 5: Fix a broken test in ring_hash due to the new weight computation - Commit 6: Weighted shuffling in pick_first - Contains the changes specified in A113 for the pick_first LB policy - Changes are guarded by the env var RELEASE NOTES: - pickfirst: Add support for weighted random shuffling of endpoints, as described in gRFC A113
1 parent 7985bb4 commit 61d569d

File tree

7 files changed

+591
-141
lines changed

7 files changed

+591
-141
lines changed

balancer/pickfirst/internal/internal.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import (
2626
var (
2727
// RandShuffle pseudo-randomizes the order of addresses.
2828
RandShuffle = rand.Shuffle
29+
// RandFloat64 returns, as a float64, a pseudo-random number in [0.0,1.0).
30+
RandFloat64 = rand.Float64
2931
// TimeAfterFunc allows mocking the timer for testing connection delay
3032
// related functionality.
3133
TimeAfterFunc = func(d time.Duration, f func()) func() {

balancer/pickfirst/pickfirst.go

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
package pickfirst
2222

2323
import (
24+
"cmp"
2425
"encoding/json"
2526
"errors"
2627
"fmt"
28+
"math"
2729
"net"
2830
"net/netip"
31+
"slices"
2932
"sync"
3033
"time"
3134

@@ -34,6 +37,8 @@ import (
3437
"google.golang.org/grpc/connectivity"
3538
expstats "google.golang.org/grpc/experimental/stats"
3639
"google.golang.org/grpc/grpclog"
40+
"google.golang.org/grpc/internal/balancer/weight"
41+
"google.golang.org/grpc/internal/envconfig"
3742
internalgrpclog "google.golang.org/grpc/internal/grpclog"
3843
"google.golang.org/grpc/internal/pretty"
3944
"google.golang.org/grpc/resolver"
@@ -258,8 +263,42 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState
258263
// will change the order of endpoints but not touch the order of the
259264
// addresses within each endpoint. - A61
260265
if cfg.ShuffleAddressList {
261-
endpoints = append([]resolver.Endpoint{}, endpoints...)
262-
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
266+
if envconfig.PickFirstWeightedShuffling {
267+
type weightedEndpoint struct {
268+
endpoint resolver.Endpoint
269+
weight float64
270+
}
271+
272+
// For each endpoint, compute a key as described in A113 and
273+
// https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf:
274+
var weightedEndpoints []weightedEndpoint
275+
for _, endpoint := range endpoints {
276+
u := internal.RandFloat64() // Random number in [0.0, 1.0)
277+
weight := weightAttribute(endpoint)
278+
weightedEndpoints = append(weightedEndpoints, weightedEndpoint{
279+
endpoint: endpoint,
280+
weight: math.Pow(u, 1.0/float64(weight)),
281+
})
282+
}
283+
// Sort endpoints by key in descending order and reconstruct the
284+
// endpoints slice.
285+
slices.SortFunc(weightedEndpoints, func(a, b weightedEndpoint) int {
286+
return cmp.Compare(b.weight, a.weight)
287+
})
288+
289+
// Here, and in the "else" block below, we clone the endpoints
290+
// slice to avoid mutating the resolver state. Doing the latter
291+
// would lead to data races if the caller is accessing the same
292+
// slice concurrently.
293+
sortedEndpoints := make([]resolver.Endpoint, len(endpoints))
294+
for i, we := range weightedEndpoints {
295+
sortedEndpoints[i] = we.endpoint
296+
}
297+
endpoints = sortedEndpoints
298+
} else {
299+
endpoints = slices.Clone(endpoints)
300+
internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] })
301+
}
263302
}
264303

265304
// "Flatten the list by concatenating the ordered list of addresses for
@@ -906,3 +945,17 @@ func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool {
906945
return a.Addr == b.Addr && a.ServerName == b.ServerName &&
907946
a.Attributes.Equal(b.Attributes)
908947
}
948+
949+
// weightAttribute is a convenience function which returns the value of the
950+
// weight endpoint Attribute.
951+
//
952+
// When used in the xDS context, the weight attribute is guaranteed to be
953+
// non-zero. But, when used in a non-xDS context, the weight attribute could be
954+
// unset. A Default of 1 is used in the latter case.
955+
func weightAttribute(e resolver.Endpoint) uint32 {
956+
w := weight.FromEndpoint(e).Weight
957+
if w == 0 {
958+
return 1
959+
}
960+
return w
961+
}

balancer/pickfirst/pickfirst_ext_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ import (
3939
"google.golang.org/grpc/credentials/insecure"
4040
"google.golang.org/grpc/internal"
4141
"google.golang.org/grpc/internal/balancer/stub"
42+
"google.golang.org/grpc/internal/balancer/weight"
4243
"google.golang.org/grpc/internal/channelz"
44+
"google.golang.org/grpc/internal/envconfig"
4345
"google.golang.org/grpc/internal/grpcsync"
4446
"google.golang.org/grpc/internal/grpctest"
4547
"google.golang.org/grpc/internal/stubserver"
@@ -425,6 +427,8 @@ func (s) TestPickFirst_StickyTransientFailure(t *testing.T) {
425427

426428
// Tests the PF LB policy with shuffling enabled.
427429
func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
430+
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)
431+
428432
const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`
429433

430434
// Install a shuffler that always reverses two entries.
@@ -485,6 +489,8 @@ func (s) TestPickFirst_ShuffleAddressList(t *testing.T) {
485489
// Endpoints field in the resolver update to test the shuffling of the
486490
// Addresses.
487491
func (s) TestPickFirst_ShuffleAddressListNoEndpoints(t *testing.T) {
492+
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)
493+
488494
// Install a shuffler that always reverses two entries.
489495
origShuf := pfinternal.RandShuffle
490496
defer func() { pfinternal.RandShuffle = origShuf }()
@@ -560,8 +566,73 @@ func (s) TestPickFirst_ShuffleAddressListNoEndpoints(t *testing.T) {
560566
}
561567
}
562568

569+
// Tests the PF LB policy with weighted shuffling enabled.
570+
func (s) TestPickFirst_ShuffleAddressList_WeightedShuffling(t *testing.T) {
571+
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, true)
572+
573+
const serviceConfig = `{"loadBalancingConfig": [{"pick_first":{ "shuffleAddressList": true }}]}`
574+
575+
// Install a rand func that returns a constant value. The test sets up three
576+
// endpoints with increasing weights. This means that in the weighted
577+
// shuffling algorithm, the endpoints will end up with increasing values for
578+
// their keys. And since the algorithm sorts in descending order, the last
579+
// endpoint should be the one that would get picked.
580+
origRand := pfinternal.RandFloat64
581+
defer func() { pfinternal.RandFloat64 = origRand }()
582+
pfinternal.RandFloat64 = func() float64 {
583+
return 0.5
584+
}
585+
586+
// Set up our backends.
587+
cc, r, backends := setupPickFirst(t, 3)
588+
addrs := stubBackendsToResolverAddrs(backends)
589+
590+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
591+
defer cancel()
592+
593+
// Create endpoints for the above backends with increasing weights.
594+
ep1 := resolver.Endpoint{Addresses: []resolver.Address{addrs[0]}}
595+
ep1 = weight.Set(ep1, weight.EndpointInfo{Weight: 357913941}) // Normalized weight of 1/6
596+
ep2 := resolver.Endpoint{Addresses: []resolver.Address{addrs[1]}}
597+
ep2 = weight.Set(ep2, weight.EndpointInfo{Weight: 715827882}) // Normalized weight of 2/6
598+
ep3 := resolver.Endpoint{Addresses: []resolver.Address{addrs[2]}}
599+
ep3 = weight.Set(ep3, weight.EndpointInfo{Weight: 1073741824}) // Normalized weight of 3/6
600+
601+
// Push an update with all addresses and shuffling disabled. We should
602+
// connect to backend 0.
603+
r.UpdateState(resolver.State{Endpoints: []resolver.Endpoint{ep1, ep2, ep3}})
604+
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
605+
t.Fatal(err)
606+
}
607+
608+
// Send a config with shuffling enabled. This will reverse the addresses,
609+
// but the channel should still be connected to backend 0.
610+
shufState := resolver.State{
611+
ServiceConfig: parseServiceConfig(t, r, serviceConfig),
612+
Endpoints: []resolver.Endpoint{ep1, ep2, ep3},
613+
}
614+
r.UpdateState(shufState)
615+
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[0]); err != nil {
616+
t.Fatal(err)
617+
}
618+
619+
// Send a resolver update with no addresses. This should push the channel
620+
// into TransientFailure.
621+
r.UpdateState(resolver.State{})
622+
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
623+
624+
// Send the same config as last time with shuffling enabled. Since we are
625+
// not connected to backend 0, we should connect to backend 2.
626+
r.UpdateState(shufState)
627+
if err := pickfirst.CheckRPCsToBackend(ctx, cc, addrs[2]); err != nil {
628+
t.Fatal(err)
629+
}
630+
}
631+
563632
// Test config parsing with the env var turned on and off for various scenarios.
564633
func (s) TestPickFirst_ParseConfig_Success(t *testing.T) {
634+
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, false)
635+
565636
// Install a shuffler that always reverses two entries.
566637
origShuf := pfinternal.RandShuffle
567638
defer func() { pfinternal.RandShuffle = origShuf }()

balancer/ringhash/ringhash_e2e_test.go

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,29 +1240,30 @@ func (s) TestRingHash_UnsupportedHashPolicyUntilChannelIdHashing(t *testing.T) {
12401240
// Tests that ring hash policy that hashes using a random value can spread RPCs
12411241
// across all the backends according to locality weight.
12421242
func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWeight(t *testing.T) {
1243+
testutils.SetEnvConfig(t, &envconfig.PickFirstWeightedShuffling, true)
12431244
backends := backendAddrs(startTestServiceBackends(t, 2))
12441245

12451246
const clusterName = "cluster"
1246-
const locality1Weight = uint32(1)
1247-
const endpoint1Weight = uint32(1)
1248-
const locality2Weight = uint32(2)
1249-
const endpoint2Weight = uint32(2)
1247+
const locality0Weight = uint32(1)
1248+
const endpoint0Weight = uint32(1)
1249+
const locality1Weight = uint32(2)
1250+
const endpoint1Weight = uint32(2)
12501251
endpoints := e2e.EndpointResourceWithOptions(e2e.EndpointOptions{
12511252
ClusterName: clusterName,
12521253
Localities: []e2e.LocalityOptions{
12531254
{
12541255
Backends: []e2e.BackendOptions{{
12551256
Ports: []uint32{testutils.ParsePort(t, backends[0])},
1256-
Weight: endpoint1Weight,
1257+
Weight: endpoint0Weight,
12571258
}},
1258-
Weight: locality1Weight,
1259+
Weight: locality0Weight,
12591260
},
12601261
{
12611262
Backends: []e2e.BackendOptions{{
12621263
Ports: []uint32{testutils.ParsePort(t, backends[1])},
1263-
Weight: endpoint2Weight,
1264+
Weight: endpoint1Weight,
12641265
}},
1265-
Weight: locality2Weight,
1266+
Weight: locality1Weight,
12661267
},
12671268
},
12681269
})
@@ -1289,21 +1290,26 @@ func (s) TestRingHash_RandomHashingDistributionAccordingToLocalityAndEndpointWei
12891290
defer conn.Close()
12901291
client := testgrpc.NewTestServiceClient(conn)
12911292

1292-
const weight1 = endpoint1Weight * locality1Weight
1293-
const weight2 = endpoint2Weight * locality2Weight
1294-
const wantRPCs1 = float64(weight1) / float64(weight1+weight2)
1295-
const wantRPCs2 = float64(weight2) / float64(weight1+weight2)
1296-
numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs1, wantRPCs2), errorTolerance)
1293+
// The target fraction of RPCs to each backend is computed as the product of
1294+
// the probability of selecting the locality and the probability of
1295+
// selecting the endpoint within the locality. The probability of selecting
1296+
// locality0 is 1/3 and locality1 is 2/3. Since there is only one endpoint
1297+
// in each locality, the probability of selecting the endpoint within the
1298+
// locality is 1. Therefore, the target fractions end up as 1/3 and 2/3
1299+
// respectively.
1300+
const wantRPCs0 = float64(1) / float64(3)
1301+
const wantRPCs1 = float64(2) / float64(3)
1302+
numRPCs := computeIdealNumberOfRPCs(t, math.Min(wantRPCs0, wantRPCs1), errorTolerance)
12971303

12981304
// Send a large number of RPCs and check that they are distributed randomly.
12991305
gotPerBackend := checkRPCSendOK(ctx, t, client, numRPCs)
13001306
got := float64(gotPerBackend[backends[0]]) / float64(numRPCs)
1301-
if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) {
1302-
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs1, errorTolerance)
1307+
if !cmp.Equal(got, wantRPCs0, cmpopts.EquateApprox(0, errorTolerance)) {
1308+
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[0], got, wantRPCs0, errorTolerance)
13031309
}
13041310
got = float64(gotPerBackend[backends[1]]) / float64(numRPCs)
1305-
if !cmp.Equal(got, wantRPCs2, cmpopts.EquateApprox(0, errorTolerance)) {
1306-
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[2], got, wantRPCs2, errorTolerance)
1311+
if !cmp.Equal(got, wantRPCs1, cmpopts.EquateApprox(0, errorTolerance)) {
1312+
t.Errorf("Fraction of RPCs to backend %s: got %v, want %v (margin: +-%v)", backends[1], got, wantRPCs1, errorTolerance)
13071313
}
13081314
}
13091315

internal/envconfig/envconfig.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ var (
9090
// This feature is defined in gRFC A81 and is enabled by setting the
9191
// environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true".
9292
XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)
93+
94+
// PickFirstWeightedShuffling indicates whether weighted endpoint shuffling
95+
// is enabled in the pick_first LB policy, as defined in gRFC A113. This
96+
// feature can be disabled by setting the environment variable
97+
// GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false".
98+
PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true)
9399
)
94100

95101
func boolFromEnv(envVar string, def bool) bool {

0 commit comments

Comments
 (0)