Skip to content

Commit d4a569c

Browse files
authored
Token based authentication integration with core extension (#4011)
* tba draft * - stop authxmanager on pool close - swith to long dates * drop use of authxmanager and authenticatedconnection from core * -update submodule ref -change exception message * - remove submodule - update dependency * back to current version * - move autxhmanager creation to user space - introduce authenticationeventlisteners - clenaup in connectionpool - add entraidtestcontext - add redisintegrationtests - fix failing tokenbasedauthentication unit&integ tests * - prevent use of pubsub with TBA+RESP2 combination - fix flaky test * - support tba with clusters - add cluster+tba tests * - remove onerror from authxmanager - fix flaky tests * - fix flaky test * fix renewalDuringOperationsTest * -reviews from @sazzad16 * - fix config for managedIdentity - set audiences with scopes - managed identity tests * review from @ggivo - use getuser instead oid from Token * handle and propogate from unsuccessful AUTH response * adding reauth support for both pubsub and shardedpubsub * fix ping issue with pubsub * - review from @sazzad16 : make JedisSafeAuthenticator protected - fix failing unit tests * update authx version * - remove workaround for standalone endpoint
1 parent 90583d0 commit d4a569c

21 files changed

+1902
-86
lines changed

pom.xml

+13
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@
7575
<version>2.11.0</version>
7676
</dependency>
7777

78+
<dependency>
79+
<groupId>redis.clients.authentication</groupId>
80+
<artifactId>redis-authx-core</artifactId>
81+
<version>0.1.1-beta1</version>
82+
</dependency>
83+
7884
<!-- Optional dependencies -->
7985

8086
<!-- UNIX socket connection support -->
@@ -150,6 +156,13 @@
150156
<scope>test</scope>
151157
</dependency>
152158

159+
<dependency>
160+
<groupId>redis.clients.authentication</groupId>
161+
<artifactId>redis-authx-entraid</artifactId>
162+
<version>0.1.1-beta1</version>
163+
<scope>test</scope>
164+
</dependency>
165+
153166
<!-- circuit breaker / failover -->
154167
<dependency>
155168
<groupId>io.github.resilience4j</groupId>

src/main/java/redis/clients/jedis/Connection.java

+40-11
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
import java.util.List;
1515
import java.util.Map;
1616
import java.util.function.Supplier;
17+
import java.util.concurrent.atomic.AtomicReference;
1718

1819
import redis.clients.jedis.Protocol.Command;
1920
import redis.clients.jedis.Protocol.Keyword;
2021
import redis.clients.jedis.annots.Experimental;
2122
import redis.clients.jedis.args.ClientAttributeOption;
2223
import redis.clients.jedis.args.Rawable;
24+
import redis.clients.jedis.authentication.AuthXManager;
2325
import redis.clients.jedis.commands.ProtocolCommand;
2426
import redis.clients.jedis.exceptions.JedisConnectionException;
2527
import redis.clients.jedis.exceptions.JedisDataException;
@@ -44,6 +46,8 @@ public class Connection implements Closeable {
4446
private String strVal;
4547
protected String server;
4648
protected String version;
49+
private AtomicReference<RedisCredentials> currentCredentials = new AtomicReference<>(null);
50+
private AuthXManager authXManager;
4751

4852
public Connection() {
4953
this(Protocol.DEFAULT_HOST, Protocol.DEFAULT_PORT);
@@ -63,6 +67,7 @@ public Connection(final HostAndPort hostAndPort, final JedisClientConfig clientC
6367

6468
public Connection(final JedisSocketFactory socketFactory) {
6569
this.socketFactory = socketFactory;
70+
this.authXManager = null;
6671
}
6772

6873
public Connection(final JedisSocketFactory socketFactory, JedisClientConfig clientConfig) {
@@ -93,8 +98,8 @@ public String toIdentityString() {
9398
SocketAddress remoteAddr = socket.getRemoteSocketAddress();
9499
SocketAddress localAddr = socket.getLocalSocketAddress();
95100
if (remoteAddr != null) {
96-
strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id,
97-
localAddr, (broken ? '!' : '-'), remoteAddr);
101+
strVal = String.format("%s{id: 0x%X, L:%s %c R:%s}", className, id, localAddr,
102+
(broken ? '!' : '-'), remoteAddr);
98103
} else if (localAddr != null) {
99104
strVal = String.format("%s{id: 0x%X, L:%s}", className, id, localAddr);
100105
} else {
@@ -438,8 +443,8 @@ private static boolean validateClientInfo(String info) {
438443
for (int i = 0; i < info.length(); i++) {
439444
char c = info.charAt(i);
440445
if (c < '!' || c > '~') {
441-
throw new JedisValidationException("client info cannot contain spaces, "
442-
+ "newlines or special characters.");
446+
throw new JedisValidationException(
447+
"client info cannot contain spaces, " + "newlines or special characters.");
443448
}
444449
}
445450
return true;
@@ -451,7 +456,13 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {
451456

452457
protocol = config.getRedisProtocol();
453458

454-
final Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
459+
Supplier<RedisCredentials> credentialsProvider = config.getCredentialsProvider();
460+
461+
authXManager = config.getAuthXManager();
462+
if (authXManager != null) {
463+
credentialsProvider = authXManager;
464+
}
465+
455466
if (credentialsProvider instanceof RedisCredentialsProvider) {
456467
final RedisCredentialsProvider redisCredentialsProvider = (RedisCredentialsProvider) credentialsProvider;
457468
try {
@@ -469,7 +480,8 @@ protected void initializeFromClientConfig(final JedisClientConfig config) {
469480

470481
String clientName = config.getClientName();
471482
if (clientName != null && validateClientInfo(clientName)) {
472-
fireAndForgetMsg.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
483+
fireAndForgetMsg
484+
.add(new CommandArguments(Command.CLIENT).add(Keyword.SETNAME).add(clientName));
473485
}
474486

475487
ClientSetInfoConfig setInfoConfig = config.getClientSetInfoConfig();
@@ -525,12 +537,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
525537
if (protocol != null && credentials != null && credentials.getUser() != null) {
526538
byte[] rawPass = encodeToBytes(credentials.getPassword());
527539
try {
528-
helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(), encode(credentials.getUser()), rawPass);
540+
helloResult = hello(encode(protocol.version()), Keyword.AUTH.getRaw(),
541+
encode(credentials.getUser()), rawPass);
529542
} finally {
530543
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
531544
}
532545
} else {
533-
auth(credentials);
546+
authenticate(credentials);
534547
helloResult = protocol == null ? null : hello(encode(protocol.version()));
535548
}
536549
if (helloResult != null) {
@@ -542,9 +555,13 @@ private void helloAndAuth(final RedisProtocol protocol, final RedisCredentials c
542555
// handled in RedisCredentialsProvider.cleanUp()
543556
}
544557

545-
private void auth(RedisCredentials credentials) {
558+
public void setCredentials(RedisCredentials credentials) {
559+
currentCredentials.set(credentials);
560+
}
561+
562+
private String authenticate(RedisCredentials credentials) {
546563
if (credentials == null || credentials.getPassword() == null) {
547-
return;
564+
return null;
548565
}
549566
byte[] rawPass = encodeToBytes(credentials.getPassword());
550567
try {
@@ -556,7 +573,11 @@ private void auth(RedisCredentials credentials) {
556573
} finally {
557574
Arrays.fill(rawPass, (byte) 0); // clear sensitive data
558575
}
559-
getStatusCodeReply();
576+
return getStatusCodeReply();
577+
}
578+
579+
public String reAuthenticate() {
580+
return authenticate(currentCredentials.getAndSet(null));
560581
}
561582

562583
protected Map<String, Object> hello(byte[]... args) {
@@ -585,4 +606,12 @@ public boolean ping() {
585606
}
586607
return true;
587608
}
609+
610+
protected boolean isTokenBasedAuthenticationEnabled() {
611+
return authXManager != null;
612+
}
613+
614+
protected AuthXManager getAuthXManager() {
615+
return authXManager;
616+
}
588617
}

src/main/java/redis/clients/jedis/ConnectionFactory.java

+65-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
import org.slf4j.Logger;
77
import org.slf4j.LoggerFactory;
88

9+
import java.util.function.Supplier;
10+
911
import redis.clients.jedis.annots.Experimental;
12+
import redis.clients.jedis.authentication.AuthXManager;
13+
import redis.clients.jedis.authentication.JedisAuthenticationException;
14+
import redis.clients.jedis.authentication.AuthXEventListener;
1015
import redis.clients.jedis.csc.Cache;
1116
import redis.clients.jedis.csc.CacheConnection;
1217
import redis.clients.jedis.exceptions.JedisException;
@@ -20,28 +25,52 @@ public class ConnectionFactory implements PooledObjectFactory<Connection> {
2025

2126
private final JedisSocketFactory jedisSocketFactory;
2227
private final JedisClientConfig clientConfig;
23-
private Cache clientSideCache = null;
28+
private final Cache clientSideCache;
29+
private final Supplier<Connection> objectMaker;
30+
31+
private final AuthXEventListener authXEventListener;
2432

2533
public ConnectionFactory(final HostAndPort hostAndPort) {
26-
this.clientConfig = DefaultJedisClientConfig.builder().build();
27-
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort);
34+
this(hostAndPort, DefaultJedisClientConfig.builder().build(), null);
2835
}
2936

3037
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig) {
31-
this.clientConfig = clientConfig;
32-
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
38+
this(hostAndPort, clientConfig, null);
3339
}
3440

3541
@Experimental
36-
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig, Cache csCache) {
37-
this.clientConfig = clientConfig;
38-
this.jedisSocketFactory = new DefaultJedisSocketFactory(hostAndPort, this.clientConfig);
39-
this.clientSideCache = csCache;
42+
public ConnectionFactory(final HostAndPort hostAndPort, final JedisClientConfig clientConfig,
43+
Cache csCache) {
44+
this(new DefaultJedisSocketFactory(hostAndPort, clientConfig), clientConfig, csCache);
4045
}
4146

42-
public ConnectionFactory(final JedisSocketFactory jedisSocketFactory, final JedisClientConfig clientConfig) {
43-
this.clientConfig = clientConfig;
47+
public ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
48+
final JedisClientConfig clientConfig) {
49+
this(jedisSocketFactory, clientConfig, null);
50+
}
51+
52+
private ConnectionFactory(final JedisSocketFactory jedisSocketFactory,
53+
final JedisClientConfig clientConfig, Cache csCache) {
54+
4455
this.jedisSocketFactory = jedisSocketFactory;
56+
this.clientSideCache = csCache;
57+
this.clientConfig = clientConfig;
58+
59+
AuthXManager authXManager = clientConfig.getAuthXManager();
60+
if (authXManager == null) {
61+
this.objectMaker = connectionSupplier();
62+
this.authXEventListener = AuthXEventListener.NOOP_LISTENER;
63+
} else {
64+
Supplier<Connection> supplier = connectionSupplier();
65+
this.objectMaker = () -> (Connection) authXManager.addConnection(supplier.get());
66+
this.authXEventListener = authXManager.getListener();
67+
authXManager.start();
68+
}
69+
}
70+
71+
private Supplier<Connection> connectionSupplier() {
72+
return clientSideCache == null ? () -> new Connection(jedisSocketFactory, clientConfig)
73+
: () -> new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
4574
}
4675

4776
@Override
@@ -64,8 +93,7 @@ public void destroyObject(PooledObject<Connection> pooledConnection) throws Exce
6493
@Override
6594
public PooledObject<Connection> makeObject() throws Exception {
6695
try {
67-
Connection jedis = clientSideCache == null ? new Connection(jedisSocketFactory, clientConfig)
68-
: new CacheConnection(jedisSocketFactory, clientConfig, clientSideCache);
96+
Connection jedis = objectMaker.get();
6997
return new DefaultPooledObject<>(jedis);
7098
} catch (JedisException je) {
7199
logger.debug("Error while makeObject", je);
@@ -76,17 +104,40 @@ public PooledObject<Connection> makeObject() throws Exception {
76104
@Override
77105
public void passivateObject(PooledObject<Connection> pooledConnection) throws Exception {
78106
// TODO maybe should select db 0? Not sure right now.
107+
Connection jedis = pooledConnection.getObject();
108+
reAuthenticate(jedis);
79109
}
80110

81111
@Override
82112
public boolean validateObject(PooledObject<Connection> pooledConnection) {
83113
final Connection jedis = pooledConnection.getObject();
84114
try {
85115
// check HostAndPort ??
86-
return jedis.isConnected() && jedis.ping();
116+
if (!jedis.isConnected()) {
117+
return false;
118+
}
119+
reAuthenticate(jedis);
120+
return jedis.ping();
87121
} catch (final Exception e) {
88122
logger.warn("Error while validating pooled Connection object.", e);
89123
return false;
90124
}
91125
}
126+
127+
private void reAuthenticate(Connection jedis) throws Exception {
128+
try {
129+
String result = jedis.reAuthenticate();
130+
if (result != null && !result.equals("OK")) {
131+
String msg = "Re-authentication failed with server response: " + result;
132+
Exception failedAuth = new JedisAuthenticationException(msg);
133+
logger.error(failedAuth.getMessage(), failedAuth);
134+
authXEventListener.onConnectionAuthenticationError(failedAuth);
135+
return;
136+
}
137+
} catch (Exception e) {
138+
logger.error("Error while re-authenticating connection", e);
139+
authXEventListener.onConnectionAuthenticationError(e);
140+
throw e;
141+
}
142+
}
92143
}

src/main/java/redis/clients/jedis/ConnectionPool.java

+38-3
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,27 @@
22

33
import org.apache.commons.pool2.PooledObjectFactory;
44
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
5+
56
import redis.clients.jedis.annots.Experimental;
7+
import redis.clients.jedis.authentication.AuthXManager;
68
import redis.clients.jedis.csc.Cache;
9+
import redis.clients.jedis.exceptions.JedisException;
710
import redis.clients.jedis.util.Pool;
811

912
public class ConnectionPool extends Pool<Connection> {
1013

14+
private AuthXManager authXManager;
15+
1116
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
1217
this(new ConnectionFactory(hostAndPort, clientConfig));
18+
attachAuthenticationListener(clientConfig.getAuthXManager());
1319
}
1420

1521
@Experimental
16-
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache) {
22+
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
23+
Cache clientSideCache) {
1724
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache));
25+
attachAuthenticationListener(clientConfig.getAuthXManager());
1826
}
1927

2028
public ConnectionPool(PooledObjectFactory<Connection> factory) {
@@ -24,12 +32,14 @@ public ConnectionPool(PooledObjectFactory<Connection> factory) {
2432
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
2533
GenericObjectPoolConfig<Connection> poolConfig) {
2634
this(new ConnectionFactory(hostAndPort, clientConfig), poolConfig);
35+
attachAuthenticationListener(clientConfig.getAuthXManager());
2736
}
2837

2938
@Experimental
30-
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig, Cache clientSideCache,
31-
GenericObjectPoolConfig<Connection> poolConfig) {
39+
public ConnectionPool(HostAndPort hostAndPort, JedisClientConfig clientConfig,
40+
Cache clientSideCache, GenericObjectPoolConfig<Connection> poolConfig) {
3241
this(new ConnectionFactory(hostAndPort, clientConfig, clientSideCache), poolConfig);
42+
attachAuthenticationListener(clientConfig.getAuthXManager());
3343
}
3444

3545
public ConnectionPool(PooledObjectFactory<Connection> factory,
@@ -43,4 +53,29 @@ public Connection getResource() {
4353
conn.setHandlingPool(this);
4454
return conn;
4555
}
56+
57+
@Override
58+
public void close() {
59+
try {
60+
if (authXManager != null) {
61+
authXManager.stop();
62+
}
63+
} finally {
64+
super.close();
65+
}
66+
}
67+
68+
private void attachAuthenticationListener(AuthXManager authXManager) {
69+
this.authXManager = authXManager;
70+
if (authXManager != null) {
71+
authXManager.addPostAuthenticationHook(token -> {
72+
try {
73+
// this is to trigger validations on each connection via ConnectionFactory
74+
evict();
75+
} catch (Exception e) {
76+
throw new JedisException("Failed to evict connections from pool", e);
77+
}
78+
});
79+
}
80+
}
4681
}

0 commit comments

Comments
 (0)