Skip to content

Commit 73467e8

Browse files
client_routes: add sticky route tracking
When a host has multiple routes (one per connection), TranslateHost now records the connectionID of the route it used on each successful call and prefers that route on subsequent calls via findPreferredRoute. If the preferred route has been removed (e.g. the connection was pruned), the lookup falls back to FindByHostID, and the route chosen there becomes the new preferred route once the call succeeds. This avoids unnecessary connection churn when multiple PrivateLink endpoints serve the same host.
1 parent 9b9e7ac commit 73467e8

2 files changed

Lines changed: 88 additions & 13 deletions

File tree

client_routes.go

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,16 @@ func (l *UnresolvedClientRouteList) FindByHostID(hostID string) *UnresolvedClien
187187
}
188188

189189
type ClientRoutesHandler struct {
190-
log StdLogger
191-
c controlConnection
192-
resolver DNSResolver
193-
sub *eventbus.Subscriber[events.Event]
194-
routes UnresolvedClientRouteList
195-
updateTasks chan updateTask
196-
closeChan chan struct{}
197-
cfg ClientRoutesConfig
198-
mu sync.RWMutex
190+
log StdLogger
191+
c controlConnection
192+
resolver DNSResolver
193+
sub *eventbus.Subscriber[events.Event]
194+
routes UnresolvedClientRouteList
195+
stickyRoute map[string]string // hostID → preferred connectionID
196+
updateTasks chan updateTask
197+
closeChan chan struct{}
198+
cfg ClientRoutesConfig
199+
mu sync.RWMutex
199200
pickTLSPorts bool
200201
initialized bool
201202
}
@@ -215,6 +216,20 @@ func pickProperPort(pickTLSPorts bool, rec *UnresolvedClientRoute) uint16 {
215216
return rec.CQLPort
216217
}
217218

219+
// findPreferredRoute returns the route for hostID that matches the sticky
220+
// connectionID, falling back to the first route for that host.
221+
// Must be called with p.mu held (at least RLock).
222+
func (p *ClientRoutesHandler) findPreferredRoute(hostID string) *UnresolvedClientRoute {
223+
if preferred, ok := p.stickyRoute[hostID]; ok {
224+
for i := range p.routes {
225+
if p.routes[i].HostID == hostID && p.routes[i].ConnectionID == preferred {
226+
return &p.routes[i]
227+
}
228+
}
229+
}
230+
return p.routes.FindByHostID(hostID)
231+
}
232+
218233
// TranslateHost implements AddressTranslatorV2 interface.
219234
// It resolves DNS on every call rather than caching resolved addresses.
220235
func (p *ClientRoutesHandler) TranslateHost(host AddressTranslatorHostInfo, addr AddressPort) (AddressPort, error) {
@@ -224,7 +239,7 @@ func (p *ClientRoutesHandler) TranslateHost(host AddressTranslatorHostInfo, addr
224239
}
225240

226241
p.mu.RLock()
227-
rec := p.routes.FindByHostID(hostID)
242+
rec := p.findPreferredRoute(hostID)
228243
var route UnresolvedClientRoute
229244
found := rec != nil
230245
if found {
@@ -249,6 +264,10 @@ func (p *ClientRoutesHandler) TranslateHost(host AddressTranslatorHostInfo, addr
249264
return addr, fmt.Errorf("record %s/%s has target port empty", route.HostID, route.ConnectionID)
250265
}
251266

267+
p.mu.Lock()
268+
p.stickyRoute[hostID] = route.ConnectionID
269+
p.mu.Unlock()
270+
252271
return AddressPort{Address: ips[0], Port: port}, nil
253272
}
254273

@@ -400,6 +419,7 @@ func NewClientRoutesAddressTranslator(
400419
updateTasks: make(chan updateTask, 1024),
401420
resolver: resolver,
402421
routes: make(UnresolvedClientRouteList, 0),
422+
stickyRoute: make(map[string]string),
403423
}
404424
}
405425

client_routes_unit_test.go

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ func TestClientRoutesHandlerTranslateHost(t *testing.T) {
119119
})
120120

121121
handler := &ClientRoutesHandler{
122-
resolver: resolver,
123-
routes: make(UnresolvedClientRouteList, 0),
122+
stickyRoute: make(map[string]string),
123+
resolver: resolver,
124+
routes: make(UnresolvedClientRouteList, 0),
124125
}
125126

126127
res, err := handler.TranslateHost(noHost, addr)
@@ -162,14 +163,68 @@ func TestClientRoutesHandlerTranslateHost(t *testing.T) {
162163
resolver: dnsResolverFunc(func(host string) ([]net.IP, error) {
163164
return nil, errors.New("lookup failed")
164165
}),
165-
routes: UnresolvedClientRouteList{{ConnectionID: "c2", HostID: "h2", Address: "host", CQLPort: 9042}},
166+
stickyRoute: make(map[string]string),
167+
routes: UnresolvedClientRouteList{{ConnectionID: "c2", HostID: "h2", Address: "host", CQLPort: 9042}},
166168
}
167169
_, err = errorHandler.TranslateHost(testHostInfo{hostID: "h2"}, addr)
168170
if err == nil {
169171
t.Fatalf("expected resolver error to bubble up")
170172
}
171173
}
172174

175+
func TestTranslateHost_StickyRoute(t *testing.T) {
176+
addr := AddressPort{Address: net.ParseIP("1.1.1.1"), Port: 9042}
177+
resolvedIPs := map[string]net.IP{
178+
"addr-c1": net.ParseIP("10.0.0.1"),
179+
"addr-c2": net.ParseIP("10.0.0.2"),
180+
}
181+
handler := &ClientRoutesHandler{
182+
pickTLSPorts: false,
183+
stickyRoute: make(map[string]string),
184+
resolver: dnsResolverFunc(func(host string) ([]net.IP, error) {
185+
if ip, ok := resolvedIPs[host]; ok {
186+
return []net.IP{ip}, nil
187+
}
188+
return nil, fmt.Errorf("unknown host %s", host)
189+
}),
190+
routes: UnresolvedClientRouteList{
191+
{ConnectionID: "c1", HostID: "h1", Address: "addr-c1", CQLPort: 9042},
192+
{ConnectionID: "c2", HostID: "h1", Address: "addr-c2", CQLPort: 9042},
193+
},
194+
}
195+
196+
// First call picks the first route (c1) and records it as sticky.
197+
res, err := handler.TranslateHost(testHostInfo{hostID: "h1"}, addr)
198+
if err != nil {
199+
t.Fatalf("unexpected error: %v", err)
200+
}
201+
if !res.Address.Equal(net.ParseIP("10.0.0.1")) {
202+
t.Fatalf("expected first route IP 10.0.0.1, got %v", res.Address)
203+
}
204+
205+
// Second call should stick to c1 even though c2 also matches h1.
206+
res, err = handler.TranslateHost(testHostInfo{hostID: "h1"}, addr)
207+
if err != nil {
208+
t.Fatalf("unexpected error: %v", err)
209+
}
210+
if !res.Address.Equal(net.ParseIP("10.0.0.1")) {
211+
t.Fatalf("expected sticky route IP 10.0.0.1, got %v", res.Address)
212+
}
213+
214+
// Remove c1 route; sticky route should fall back to c2.
215+
handler.mu.Lock()
216+
handler.routes = handler.routes[1:]
217+
handler.mu.Unlock()
218+
219+
res, err = handler.TranslateHost(testHostInfo{hostID: "h1"}, addr)
220+
if err != nil {
221+
t.Fatalf("unexpected error: %v", err)
222+
}
223+
if !res.Address.Equal(net.ParseIP("10.0.0.2")) {
224+
t.Fatalf("expected fallback to c2 IP 10.0.0.2, got %v", res.Address)
225+
}
226+
}
227+
173228
func TestGetHostPortMappingFromClusterQuery(t *testing.T) {
174229
tcases := []struct {
175230
name string

0 commit comments

Comments
 (0)