|
22 | 22 | import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; |
23 | 23 |
|
24 | 24 | import com.google.common.collect.ImmutableMap; |
| 25 | +import com.google.common.primitives.UnsignedInts; |
25 | 26 | import com.google.errorprone.annotations.CheckReturnValue; |
26 | 27 | import io.grpc.Attributes; |
27 | 28 | import io.grpc.EquivalentAddressGroup; |
|
33 | 34 | import io.grpc.NameResolver; |
34 | 35 | import io.grpc.Status; |
35 | 36 | import io.grpc.StatusOr; |
| 37 | +import io.grpc.internal.GrpcUtil; |
36 | 38 | import io.grpc.util.GracefulSwitchLoadBalancer; |
37 | 39 | import io.grpc.util.OutlierDetectionLoadBalancer.OutlierDetectionLoadBalancerConfig; |
38 | 40 | import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; |
|
74 | 76 | * by a group of sub-clusters in a tree hierarchy. |
75 | 77 | */ |
76 | 78 | final class CdsLoadBalancer2 extends LoadBalancer { |
| 79 | + static boolean pickFirstWeightedShuffling = |
| 80 | + GrpcUtil.getFlag("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true); |
| 81 | + |
77 | 82 | private final XdsLogger logger; |
78 | 83 | private final Helper helper; |
79 | 84 | private final LoadBalancerRegistry lbRegistry; |
@@ -222,6 +227,26 @@ private String errorPrefix() { |
222 | 227 | return "CdsLb for " + clusterName + ": "; |
223 | 228 | } |
224 | 229 |
|
| 230 | + /** |
| 231 | + * The number of bits assigned to the fractional part of fixed-point values. We normalize weights |
| 232 | + * to a fixed-point number between 0 and 1, representing that item's proportion of traffic (1 == |
| 233 | + * 100% of traffic). We reserve at least one bit for the whole number so that we don't need to |
| 234 | + * special case a single item, and so that we can round up very low values without risking uint32 |
| 235 | + * overflow of the sum of weights. |
| 236 | + */ |
| 237 | + private static final int FIXED_POINT_FRACTIONAL_BITS = 31; |
| 238 | + |
| 239 | + /** Divide two uint32s and produce a fixed-point uint32 result. */ |
| 240 | + private static long fractionToFixedPoint(long numerator, long denominator) { |
| 241 | + long one = 1L << FIXED_POINT_FRACTIONAL_BITS; |
| 242 | + return numerator * one / denominator; |
| 243 | + } |
| 244 | + |
| 245 | + /** Multiply two uint32 fixed-point numbers, returning a uint32 fixed-point. */ |
| 246 | + private static long fixedPointMultiply(long a, long b) { |
| 247 | + return (a * b) >> FIXED_POINT_FRACTIONAL_BITS; |
| 248 | + } |
| 249 | + |
225 | 250 | private static StatusOr<EdsUpdate> getEdsUpdate(XdsConfig xdsConfig, String cluster) { |
226 | 251 | StatusOr<XdsClusterConfig> clusterConfig = xdsConfig.getClusters().get(cluster); |
227 | 252 | if (clusterConfig == null) { |
@@ -286,17 +311,61 @@ StatusOr<ClusterResolutionResult> edsUpdateToResult( |
286 | 311 | Map<String, Map<Locality, Integer>> prioritizedLocalityWeights = new HashMap<>(); |
287 | 312 | List<String> sortedPriorityNames = |
288 | 313 | generatePriorityNames(clusterName, localityLbEndpoints); |
| 314 | + Map<String, Long> priorityLocalityWeightSums; |
| 315 | + if (pickFirstWeightedShuffling) { |
| 316 | + priorityLocalityWeightSums = new HashMap<>(sortedPriorityNames.size() * 2); |
| 317 | + for (Locality locality : localityLbEndpoints.keySet()) { |
| 318 | + LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); |
| 319 | + String priorityName = localityPriorityNames.get(locality); |
| 320 | + Long sum = priorityLocalityWeightSums.get(priorityName); |
| 321 | + if (sum == null) { |
| 322 | + sum = 0L; |
| 323 | + } |
| 324 | + long weight = UnsignedInts.toLong(localityLbInfo.localityWeight()); |
| 325 | + priorityLocalityWeightSums.put(priorityName, sum + weight); |
| 326 | + } |
| 327 | + } else { |
| 328 | + priorityLocalityWeightSums = null; |
| 329 | + } |
| 330 | + |
289 | 331 | for (Locality locality : localityLbEndpoints.keySet()) { |
290 | 332 | LocalityLbEndpoints localityLbInfo = localityLbEndpoints.get(locality); |
291 | 333 | String priorityName = localityPriorityNames.get(locality); |
292 | 334 | boolean discard = true; |
| 335 | + // These sums _should_ fit in uint32, but XdsEndpointResource isn't actually verifying that |
| 336 | + // is true today. Since we are using long to avoid signedness trouble, the math happens to |
| 337 | + // still work if it turns out the sums exceed uint32. |
| 338 | + long localityWeightSum = 0; |
| 339 | + long endpointWeightSum = 0; |
| 340 | + if (pickFirstWeightedShuffling) { |
| 341 | + localityWeightSum = priorityLocalityWeightSums.get(priorityName); |
| 342 | + for (LbEndpoint endpoint : localityLbInfo.endpoints()) { |
| 343 | + if (endpoint.isHealthy()) { |
| 344 | + endpointWeightSum += UnsignedInts.toLong(endpoint.loadBalancingWeight()); |
| 345 | + } |
| 346 | + } |
| 347 | + } |
293 | 348 | for (LbEndpoint endpoint : localityLbInfo.endpoints()) { |
294 | 349 | if (endpoint.isHealthy()) { |
295 | 350 | discard = false; |
296 | | - long weight = localityLbInfo.localityWeight(); |
297 | | - if (endpoint.loadBalancingWeight() != 0) { |
298 | | - weight *= endpoint.loadBalancingWeight(); |
| 351 | + long weight; |
| 352 | + if (pickFirstWeightedShuffling) { |
| 353 | + // Combine locality and endpoint weights as defined by gRFC A113 |
| 354 | + long localityWeight = fractionToFixedPoint( |
| 355 | + UnsignedInts.toLong(localityLbInfo.localityWeight()), localityWeightSum); |
| 356 | + long endpointWeight = fractionToFixedPoint( |
| 357 | + UnsignedInts.toLong(endpoint.loadBalancingWeight()), endpointWeightSum); |
| 358 | + weight = fixedPointMultiply(localityWeight, endpointWeight); |
| 359 | + if (weight == 0) { |
| 360 | + weight = 1; |
| 361 | + } |
| 362 | + } else { |
| 363 | + weight = localityLbInfo.localityWeight(); |
| 364 | + if (endpoint.loadBalancingWeight() != 0) { |
| 365 | + weight *= endpoint.loadBalancingWeight(); |
| 366 | + } |
299 | 367 | } |
| 368 | + |
300 | 369 | String localityName = localityName(locality); |
301 | 370 | Attributes attr = |
302 | 371 | endpoint.eag().getAttributes().toBuilder() |
|
0 commit comments