Skip to content

Commit 25c51e4

Browse files
committed
core: Add pick_first weighted shuffle
The prior uniform shuffle in pick_first causes clients to spread load uniformly across endpoints. When endpoints have weights, we want endpoints to be selected in proportion to their weight. The server weight attribute has to move out of xDS so that pick_first can see it, but it is kept internal for now. Since xDS is the only thing that sets weights, the behavior change is only visible to xDS users. See gRFC A113.
1 parent c589bef commit 25c51e4

File tree

7 files changed

+435
-17
lines changed

7 files changed

+435
-17
lines changed

api/src/main/java/io/grpc/EquivalentAddressGroup.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ public final class EquivalentAddressGroup {
5555
*/
5656
public static final Attributes.Key<String> ATTR_LOCALITY_NAME =
5757
Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY");
58+
/**
59+
* Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32.
60+
* Must not be zero. The weight is relative to those of the other endpoints; if an endpoint's weight is
61+
* twice that of another endpoint, it is intended to receive twice the load.
62+
*/
63+
@Attr
64+
static final Attributes.Key<Long> ATTR_WEIGHT =
65+
Attributes.Key.create("io.grpc.EquivalentAddressGroup.ATTR_WEIGHT");
66+
5867
private final List<SocketAddress> addrs;
5968
private final Attributes attrs;
6069

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2026 The gRPC Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.grpc;
18+
19+
/**
 * Internal accessor that re-exports the non-public {@code EquivalentAddressGroup.ATTR_WEIGHT}
 * key as a public constant. The class is not instantiable (private constructor).
 */
@Internal
20+
public final class InternalEquivalentAddressGroup {
21+
private InternalEquivalentAddressGroup() {}
22+
23+
/**
24+
* Endpoint weight for load balancing purposes. While the type is Long, it must be a valid uint32.
25+
* Must not be zero. The weight is proportional to the other endpoints; if an endpoint's weight is
26+
* twice that of another endpoint, it is intended to receive twice the load.
27+
*/
28+
public static final Attributes.Key<Long> ATTR_WEIGHT = EquivalentAddressGroup.ATTR_WEIGHT;
29+
}

core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
import com.google.common.annotations.VisibleForTesting;
2727
import com.google.common.collect.ImmutableList;
2828
import com.google.common.collect.Lists;
29+
import com.google.errorprone.annotations.CheckReturnValue;
2930
import io.grpc.Attributes;
3031
import io.grpc.ConnectivityState;
3132
import io.grpc.ConnectivityStateInfo;
3233
import io.grpc.EquivalentAddressGroup;
34+
import io.grpc.InternalEquivalentAddressGroup;
3335
import io.grpc.LoadBalancer;
3436
import io.grpc.Status;
3537
import io.grpc.SynchronizationContext.ScheduledHandle;
@@ -61,6 +63,8 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer {
6163
static final int CONNECTION_DELAY_INTERVAL_MS = 250;
6264
private final boolean enableHappyEyeballs = !isSerializingRetries()
6365
&& PickFirstLoadBalancerProvider.isEnabledHappyEyeballs();
66+
static boolean weightedShuffling =
67+
GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true);
6468
private final Helper helper;
6569
private final Map<SocketAddress, SubchannelData> subchannels = new HashMap<>();
6670
private final Index addressIndex = new Index(ImmutableList.of(), this.enableHappyEyeballs);
@@ -128,13 +132,13 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
128132
PickFirstLeafLoadBalancerConfig config
129133
= (PickFirstLeafLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
130134
if (config.shuffleAddressList != null && config.shuffleAddressList) {
131-
Collections.shuffle(cleanServers,
132-
config.randomSeed != null ? new Random(config.randomSeed) : new Random());
135+
cleanServers = shuffle(
136+
cleanServers, config.randomSeed != null ? new Random(config.randomSeed) : new Random());
133137
}
134138
}
135139

136140
final ImmutableList<EquivalentAddressGroup> newImmutableAddressGroups =
137-
ImmutableList.<EquivalentAddressGroup>builder().addAll(cleanServers).build();
141+
ImmutableList.copyOf(cleanServers);
138142

139143
if (rawConnectivityState == READY
140144
|| (rawConnectivityState == CONNECTING
@@ -224,6 +228,46 @@ private static List<EquivalentAddressGroup> deDupAddresses(List<EquivalentAddres
224228
return newGroups;
225229
}
226230

231+
// Also used by PickFirstLoadBalancer
/**
 * Returns a shuffled ordering of {@code eags} without modifying the input list.
 *
 * <p>When {@code weightedShuffling} is enabled, each group is assigned a random key by
 * {@code eagToWeight} and the groups are returned in descending key order, so groups with a
 * larger weight tend to appear earlier. Otherwise a copy of the list is shuffled uniformly
 * with {@link Collections#shuffle(List, Random)}.
 *
 * @param eags the address groups to order
 * @param random source of randomness; a seeded instance gives a reproducible order
 * @return the reordered groups; in the weighted case this is a {@code Lists.transform} view
 *     over the sorted entries
 */
232+
@CheckReturnValue
233+
static List<EquivalentAddressGroup> shuffle(List<EquivalentAddressGroup> eags, Random random) {
234+
if (weightedShuffling) {
235+
List<WeightEntry> weightedEntries = new ArrayList<>(eags.size());
236+
for (EquivalentAddressGroup eag : eags) {
237+
// One random draw per group, consumed in input order.
weightedEntries.add(new WeightEntry(eag, eagToWeight(eag, random)));
238+
}
239+
Collections.sort(weightedEntries, Collections.reverseOrder() /* descending */);
240+
return Lists.transform(weightedEntries, entry -> entry.eag);
241+
} else {
242+
// Copy first so the caller's list is left untouched.
List<EquivalentAddressGroup> eagsCopy = new ArrayList<>(eags);
243+
Collections.shuffle(eagsCopy, random);
244+
return eagsCopy;
245+
}
246+
}
247+
248+
/**
 * Computes a random sort key for {@code eag}: {@code u^(1/weight)}, where {@code u} is the
 * next double from {@code random} and {@code weight} is the group's ATTR_WEIGHT attribute.
 * A group without the attribute is treated as having weight 1. For the same {@code u}, a
 * larger weight gives a key closer to 1, so heavier groups tend to sort first when keys are
 * ordered descending. This is the key form used by Efraimidis-Spirakis weighted random
 * sampling.
 *
 * <p>NOTE(review): the weight is not validated here; the attribute's documentation says it
 * must not be zero.
 */
private static double eagToWeight(EquivalentAddressGroup eag, Random random) {
249+
Long weight = eag.getAttributes().get(InternalEquivalentAddressGroup.ATTR_WEIGHT);
250+
if (weight == null) {
251+
// No weight attribute set: fall back to weight 1.
weight = 1L;
252+
}
253+
return Math.pow(random.nextDouble(), 1.0 / weight);
254+
}
255+
256+
/**
 * Pairs an address group with the value used to order it during weighted shuffling.
 * Natural ordering is ascending by that value.
 */
private static final class WeightEntry implements Comparable<WeightEntry> {
257+
final EquivalentAddressGroup eag;
258+
// As populated by shuffle(), this is the random key from eagToWeight, not the raw
// endpoint weight attribute.
final double weight;
259+
260+
public WeightEntry(EquivalentAddressGroup eag, double weight) {
261+
this.eag = eag;
262+
this.weight = weight;
263+
}
264+
265+
@Override
266+
public int compareTo(WeightEntry entry) {
267+
return Double.compare(this.weight, entry.weight);
268+
}
269+
}
270+
227271
@Override
228272
public void handleNameResolutionError(Status error) {
229273
if (rawConnectivityState == SHUTDOWN) {

core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
import io.grpc.EquivalentAddressGroup;
2828
import io.grpc.LoadBalancer;
2929
import io.grpc.Status;
30-
import java.util.ArrayList;
31-
import java.util.Collections;
3230
import java.util.List;
3331
import java.util.Random;
3432
import java.util.concurrent.atomic.AtomicBoolean;
@@ -65,9 +63,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
6563
PickFirstLoadBalancerConfig config
6664
= (PickFirstLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
6765
if (config.shuffleAddressList != null && config.shuffleAddressList) {
68-
servers = new ArrayList<EquivalentAddressGroup>(servers);
69-
Collections.shuffle(servers,
70-
config.randomSeed != null ? new Random(config.randomSeed) : new Random());
66+
servers = PickFirstLeafLoadBalancer.shuffle(
67+
servers, config.randomSeed != null ? new Random(config.randomSeed) : new Random());
7168
}
7269
}
7370

core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java

Lines changed: 139 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import static io.grpc.ConnectivityState.READY;
2424
import static io.grpc.ConnectivityState.SHUTDOWN;
2525
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
26+
import static io.grpc.InternalEquivalentAddressGroup.ATTR_WEIGHT;
2627
import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY;
2728
import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY;
2829
import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY;
@@ -70,10 +71,13 @@
7071
import io.grpc.internal.PickFirstLeafLoadBalancer.PickFirstLeafLoadBalancerConfig;
7172
import java.net.InetSocketAddress;
7273
import java.net.SocketAddress;
74+
import java.util.ArrayDeque;
7375
import java.util.ArrayList;
7476
import java.util.Arrays;
7577
import java.util.Collections;
7678
import java.util.List;
79+
import java.util.Queue;
80+
import java.util.Random;
7781
import java.util.concurrent.ScheduledExecutorService;
7882
import java.util.concurrent.TimeUnit;
7983
import org.junit.After;
@@ -149,6 +153,7 @@ public void uncaughtException(Thread t, Throwable e) {
149153

150154
private String originalHappyEyeballsEnabledValue;
151155
private String originalSerializeRetriesValue;
156+
private boolean originalWeightedShuffling;
152157

153158
private long backoffMillis;
154159

@@ -165,6 +170,8 @@ public void setUp() {
165170
System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS,
166171
Boolean.toString(enableHappyEyeballs));
167172

173+
originalWeightedShuffling = PickFirstLeafLoadBalancer.weightedShuffling;
174+
168175
for (int i = 1; i <= 5; i++) {
169176
SocketAddress addr = new FakeSocketAddress("server" + i);
170177
servers.add(new EquivalentAddressGroup(addr));
@@ -207,6 +214,7 @@ public void tearDown() {
207214
System.setProperty(PickFirstLoadBalancerProvider.GRPC_PF_USE_HAPPY_EYEBALLS,
208215
originalHappyEyeballsEnabledValue);
209216
}
217+
PickFirstLeafLoadBalancer.weightedShuffling = originalWeightedShuffling;
210218

211219
loadBalancer.shutdown();
212220
verifyNoMoreInteractions(mockArgs);
@@ -242,6 +250,12 @@ public void pickAfterResolved() {
242250
verifyNoMoreInteractions(mockHelper);
243251
}
244252

253+
@Test
254+
public void pickAfterResolved_shuffle_oppositeWeightedShuffling() {
255+
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
256+
pickAfterResolved_shuffle();
257+
}
258+
245259
@Test
246260
public void pickAfterResolved_shuffle() {
247261
servers.remove(4);
@@ -305,6 +319,103 @@ public void pickAfterResolved_noShuffle() {
305319
assertNotNull(pickerCaptor.getValue().pickSubchannel(mockArgs));
306320
}
307321

322+
@Test
323+
public void pickAfterResolved_shuffleImplicitUniform_oppositeWeightedShuffling() {
324+
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
325+
pickAfterResolved_shuffleImplicitUniform();
326+
}
327+
328+
@Test
329+
public void pickAfterResolved_shuffleImplicitUniform() {
330+
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(new FakeSocketAddress("server1"));
331+
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(new FakeSocketAddress("server2"));
332+
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(new FakeSocketAddress("server3"));
333+
334+
int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3));
335+
assertThat(counts[0]).isWithin(7).of(33);
336+
assertThat(counts[1]).isWithin(7).of(33);
337+
assertThat(counts[2]).isWithin(7).of(33);
338+
}
339+
340+
@Test
341+
public void pickAfterResolved_shuffleExplicitUniform_oppositeWeightedShuffling() {
342+
PickFirstLeafLoadBalancer.weightedShuffling = !PickFirstLeafLoadBalancer.weightedShuffling;
343+
pickAfterResolved_shuffleExplicitUniform();
344+
}
345+
346+
@Test
347+
public void pickAfterResolved_shuffleExplicitUniform() {
348+
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
349+
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());
350+
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
351+
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());
352+
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
353+
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 111L).build());
354+
355+
int[] counts = countAddressSelections(99, Arrays.asList(eag1, eag2, eag3));
356+
assertThat(counts[0]).isWithin(7).of(33);
357+
assertThat(counts[1]).isWithin(7).of(33);
358+
assertThat(counts[2]).isWithin(7).of(33);
359+
}
360+
361+
@Test
362+
public void pickAfterResolved_shuffleWeighted_noWeightedShuffling() {
363+
PickFirstLeafLoadBalancer.weightedShuffling = false;
364+
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
365+
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build());
366+
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
367+
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build());
368+
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
369+
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build());
370+
371+
int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3));
372+
assertThat(counts[0]).isWithin(7).of(33);
373+
assertThat(counts[1]).isWithin(7).of(33);
374+
assertThat(counts[2]).isWithin(7).of(33);
375+
}
376+
377+
@Test
378+
public void pickAfterResolved_shuffleWeighted_weightedShuffling() {
379+
PickFirstLeafLoadBalancer.weightedShuffling = true;
380+
EquivalentAddressGroup eag1 = new EquivalentAddressGroup(
381+
new FakeSocketAddress("server1"), Attributes.newBuilder().set(ATTR_WEIGHT, 12L).build());
382+
EquivalentAddressGroup eag2 = new EquivalentAddressGroup(
383+
new FakeSocketAddress("server2"), Attributes.newBuilder().set(ATTR_WEIGHT, 3L).build());
384+
EquivalentAddressGroup eag3 = new EquivalentAddressGroup(
385+
new FakeSocketAddress("server3"), Attributes.newBuilder().set(ATTR_WEIGHT, 1L).build());
386+
387+
int[] counts = countAddressSelections(100, Arrays.asList(eag1, eag2, eag3));
388+
assertThat(counts[0]).isWithin(7).of(75); // 100*12/16
389+
assertThat(counts[1]).isWithin(7).of(19); // 100*3/16
390+
assertThat(counts[2]).isWithin(7).of(6); // 100*1/16
391+
}
392+
393+
/** Returns int[index_of_eag] array with number of times each eag was selected. */
// Runs `trials` independent trials. Each trial builds a fresh PickFirstLeafLoadBalancer,
// feeds it `eags` with a config built from `true` and a fresh seed, marks the first
// subchannel it created as READY, and records which eag the resulting picker selects.
394+
private int[] countAddressSelections(int trials, List<EquivalentAddressGroup> eags) {
395+
int[] counts = new int[eags.size()];
396+
// Fixed seed so the sequence of per-trial seeds is the same on every run.
Random random = new Random(1);
397+
for (int i = 0; i < trials; i++) {
398+
RecordingHelper helper = new RecordingHelper();
399+
LoadBalancer lb = new PickFirstLeafLoadBalancer(helper);
400+
assertThat(lb.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
401+
.setAddresses(eags)
402+
.setAttributes(affinity)
403+
.setLoadBalancingPolicyConfig(
404+
// presumably (shuffleAddressList, randomSeed) -- confirm against the config constructor
new PickFirstLeafLoadBalancerConfig(true, random.nextLong()))
405+
.build()))
406+
.isSameInstanceAs(Status.OK);
407+
// Head of the queue is the first subchannel the balancer created in this trial.
helper.subchannels.remove().listener.onSubchannelState(
408+
ConnectivityStateInfo.forNonError(READY));
409+
410+
assertThat(helper.state).isEqualTo(READY);
411+
Subchannel subchannel = helper.picker.pickSubchannel(mockArgs).getSubchannel();
412+
// Map the picked subchannel back to its position in the input list.
counts[eags.indexOf(subchannel.getAddresses())]++;
413+
414+
lb.shutdown();
415+
}
416+
return counts;
417+
}
418+
308419
@Test
309420
public void requestConnectionPicker() {
310421
// Set up
@@ -2945,13 +3056,7 @@ public String toString() {
29453056
}
29463057
}
29473058

2948-
private class MockHelperImpl extends LoadBalancer.Helper {
2949-
private final List<Subchannel> subchannels;
2950-
2951-
public MockHelperImpl(List<? extends Subchannel> subchannels) {
2952-
this.subchannels = new ArrayList<Subchannel>(subchannels);
2953-
}
2954-
3059+
private class BaseHelper extends LoadBalancer.Helper {
29553060
@Override
29563061
public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String authority) {
29573062
return null;
@@ -2981,6 +3086,14 @@ public ScheduledExecutorService getScheduledExecutorService() {
29813086
public void refreshNameResolution() {
29823087
// noop
29833088
}
3089+
}
3090+
3091+
private class MockHelperImpl extends BaseHelper {
3092+
private final List<Subchannel> subchannels;
3093+
3094+
public MockHelperImpl(List<? extends Subchannel> subchannels) {
3095+
this.subchannels = new ArrayList<Subchannel>(subchannels);
3096+
}
29843097

29853098
@Override
29863099
public Subchannel createSubchannel(CreateSubchannelArgs args) {
@@ -2997,4 +3110,23 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
29973110
throw new IllegalArgumentException("Unexpected addresses: " + args.getAddresses());
29983111
}
29993112
}
3113+
3114+
/**
 * Helper that records what the load balancer does: the most recent state and picker passed
 * to {@code updateBalancingState}, and every subchannel created, in creation order.
 */
class RecordingHelper extends BaseHelper {
3115+
// Last values passed to updateBalancingState; unset until the first call.
ConnectivityState state;
3116+
SubchannelPicker picker;
3117+
// Subchannels in the order createSubchannel was called.
final Queue<FakeSubchannel> subchannels = new ArrayDeque<>();
3118+
3119+
@Override
3120+
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
3121+
this.state = newState;
3122+
this.picker = newPicker;
3123+
}
3124+
3125+
@Override
3126+
public Subchannel createSubchannel(CreateSubchannelArgs args) {
3127+
FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes());
3128+
subchannels.add(subchannel);
3129+
return subchannel;
3130+
}
3131+
}
30003132
}

0 commit comments

Comments
 (0)