diff --git a/docs/development/extensions-core/druid-pac4j.md b/docs/development/extensions-core/druid-pac4j.md index 243350cc51f3..967e1725fdb2 100644 --- a/docs/development/extensions-core/druid-pac4j.md +++ b/docs/development/extensions-core/druid-pac4j.md @@ -55,6 +55,7 @@ druid.auth.authenticator.jwt.type=jwt |`druid.auth.pac4j.oidc.discoveryURI`|discovery URI for fetching OP metadata [see this](http://openid.net/specs/openid-connect-discovery-1_0.html).|none|Yes| |`druid.auth.pac4j.oidc.oidcClaim`|[claim](https://openid.net/specs/openid-connect-core-1_0.html#Claims) that will be extracted from the ID Token after validation.|name|No| |`druid.auth.pac4j.oidc.scope`| scope is used by an application during authentication to authorize access to a user's details.|`openid profile email`|No| +|`druid.auth.pac4j.oidc.roleClaimPath`| Dot-separated path to the claim containing user roles|none|No| :::info Users must set a strong passphrase to ensure that an attacker is not able to guess it simply by brute force. diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/LDAPRoleProvider.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/LDAPRoleProvider.java index 95ffa229cf3e..b817e641a5b6 100644 --- a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/LDAPRoleProvider.java +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/LDAPRoleProvider.java @@ -106,14 +106,24 @@ public Set getRoles(String authorizerPrefix, AuthenticationResult authen } } + Set claims = RoleProviderUtil.claimValuesFromCtx(authenticationResult.getContext()); + // Get the roles assigned to LDAP user from the metastore. - // This allow us to authorize LDAP users regardless of whether they belong to any groups or not in LDAP. - BasicAuthorizerUser user = userMap.get(authenticationResult.getIdentity()); - if (user != null) { - roleNames.addAll(user.getRoles()); + // This allow us to authorize LDAP users regardless of whether they belong to any groups or not in LDAP. + if (claims != null) { + return RoleProviderUtil.getRolesByClaimValue( + authorizerPrefix, + claims, + roleNames, + cacheManager + ); + } else { + return RoleProviderUtil.getRolesByIdentity( + userMap, + authenticationResult.getIdentity(), + roleNames + ); } - - return roleNames; } @Override diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/MetadataStoreRoleProvider.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/MetadataStoreRoleProvider.java index 7dc05ccc8c80..a5361638b146 100644 --- a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/MetadataStoreRoleProvider.java +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/MetadataStoreRoleProvider.java @@ -57,11 +57,22 @@ public Set getRoles(String authorizerPrefix, AuthenticationResult authen throw new IAE("Could not load userMap for authorizer [%s]", authorizerPrefix); } - BasicAuthorizerUser user = userMap.get(authenticationResult.getIdentity()); - if (user != null) { - roleNames.addAll(user.getRoles()); + Set claims = RoleProviderUtil.claimValuesFromCtx(authenticationResult.getContext()); + + if (claims != null) { + return RoleProviderUtil.getRolesByClaimValue( + authorizerPrefix, + claims, + roleNames, + cacheManager + ); + } else { + return RoleProviderUtil.getRolesByIdentity( + userMap, + authenticationResult.getIdentity(), + roleNames + ); } - return roleNames; } @Override diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/RoleProviderUtil.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/RoleProviderUtil.java new file mode 100644 index 000000000000..1d984244704b --- /dev/null +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/RoleProviderUtil.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.basic.authorization; + +import org.apache.druid.security.basic.authorization.db.cache.BasicAuthorizerCacheManager; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerRole; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerUser; + +import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class RoleProviderUtil +{ + public static final String ROLE_CLAIM_CONTEXT_KEY = "druidRoles"; + + public static Set getRolesByIdentity( + Map userMap, + String identity, + Set roleNames + ) + { + BasicAuthorizerUser user = userMap.get(identity); + if (user != null) { + roleNames.addAll(user.getRoles()); + } + return roleNames; + } + + public static Set getRolesByClaimValue( + String authorizerPrefix, + Set claimValue, + Set roleNames, + BasicAuthorizerCacheManager cacheManager + ) + { + Map roleMap = cacheManager.getRoleMap(authorizerPrefix); + + if (roleMap == null) { + return Set.of(); + } + + roleMap.keySet() + .stream() + .filter(claimValue::contains) + .forEach(roleNames::add); + + return roleNames; + } + + @Nullable + protected static Set claimValuesFromCtx(Map ctx) + { + Object value = (ctx == null) ? null : ctx.get(RoleProviderUtil.ROLE_CLAIM_CONTEXT_KEY); + if (!(value instanceof Set)) { + return null; + } + Set rawClaimValues = (Set) value; + + Set result = new HashSet<>(); + for (Object claimValue : rawClaimValues) { + if (!(claimValue instanceof String)) { + return null; + } + String str = ((String) claimValue).trim(); + if (!str.isEmpty()) { + result.add(str); + } + } + return result.isEmpty() ? null : result; + } +} diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/MetadataStoreRoleProviderGetRolesTest.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/MetadataStoreRoleProviderGetRolesTest.java new file mode 100644 index 000000000000..99f67aa73666 --- /dev/null +++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/MetadataStoreRoleProviderGetRolesTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.authorization; + +import org.apache.druid.security.basic.authorization.MetadataStoreRoleProvider; +import org.apache.druid.security.basic.authorization.RoleProviderUtil; +import org.apache.druid.security.basic.authorization.db.cache.BasicAuthorizerCacheManager; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerRole; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerUser; +import org.apache.druid.server.security.AuthenticationResult; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class MetadataStoreRoleProviderGetRolesTest +{ + + @Test + public void returnsRolesByClaimValuesWhenPresent() + { + Map roles = new HashMap<>(); + roles.put("admin", null); + roles.put("viewer", null); + + Set viewerRole = Set.of("viewer"); + + BasicAuthorizerUser user = new BasicAuthorizerUser("alice", viewerRole); + + Map users = Map.of("alice", user); + + BasicAuthorizerCacheManager cache = new StubCacheManager(users, roles); + MetadataStoreRoleProvider provider = new MetadataStoreRoleProvider(cache); + + Set claims = Set.of("admin", "extraneous"); + + Map ctx = Map.of(RoleProviderUtil.ROLE_CLAIM_CONTEXT_KEY, claims); + + AuthenticationResult ar = new AuthenticationResult("alice", "basic", "pac4j", ctx); + + Set out = provider.getRoles("basic", ar); + Set expected = Set.of("admin"); + assertEquals(expected, out); + } + + @Test + public void fallsBackToIdentityWhenNoClaimContext() + { + Set viewerRole = Set.of("viewer"); + BasicAuthorizerUser user = new BasicAuthorizerUser("alice", viewerRole); + + Map users = Map.of("alice", user); + + Map roles = new HashMap<>(); + roles.put("admin", null); + + BasicAuthorizerCacheManager cache = new StubCacheManager(users, roles); + MetadataStoreRoleProvider provider = new MetadataStoreRoleProvider(cache); + + AuthenticationResult ar = new AuthenticationResult("alice", "basic", "pac4j", Collections.emptyMap()); + + Set out = provider.getRoles("basic", ar); + Set expected = Set.of("viewer"); + assertEquals(expected, out); + } +} diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/RoleProviderUtilTest.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/RoleProviderUtilTest.java new file mode 100644 index 000000000000..8bf2b9ce7906 --- /dev/null +++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/RoleProviderUtilTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.authorization; + +import org.apache.druid.security.basic.authorization.RoleProviderUtil; +import org.apache.druid.security.basic.authorization.db.cache.BasicAuthorizerCacheManager; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerRole; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerUser; +import org.junit.Test; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + + +public class RoleProviderUtilTest +{ + + @Test + public void getRolesByIdentityAddsRolesWhenUserFound() + { + Set roles = Set.of("r1", "r2"); + BasicAuthorizerUser user = new BasicAuthorizerUser("id", roles); + + Map userMap = Map.of("id", user); + + Set out = RoleProviderUtil.getRolesByIdentity(userMap, "id", new HashSet<>()); + assertEquals(roles, out); + } + + @Test + public void getRolesByIdentityNoopWhenUserMissing() + { + Map userMap = Map.of(); + Set out = RoleProviderUtil.getRolesByIdentity(userMap, "missing", new HashSet<>()); + assertTrue(out.isEmpty()); + } + + @Test + public void getRolesByClaimValuesFiltersByRoleNames() + { + Map roles = new HashMap<>(); + roles.put("r1", null); + roles.put("r2", null); + + BasicAuthorizerCacheManager cache = new StubCacheManager(Map.of(), roles); + + Set claims = Set.of("r2", "nope"); + Set out = RoleProviderUtil.getRolesByClaimValue("authz", claims, new HashSet<>(), cache); + assertEquals(Set.of("r2"), out); + } + + @Test + public void getRolesByClaimValuesThrowsWhenRoleMapNull() + { + BasicAuthorizerCacheManager cache = new StubCacheManager(Map.of(), null); + assertTrue(RoleProviderUtil.getRolesByClaimValue("authz", Set.of("r2"), + new HashSet<>(), cache + ).isEmpty()); + } +} diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/StubCacheManager.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/StubCacheManager.java new file mode 100644 index 000000000000..bce9faa6b342 --- /dev/null +++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/StubCacheManager.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.authorization; + +import org.apache.druid.security.basic.authorization.db.cache.BasicAuthorizerCacheManager; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerGroupMapping; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerRole; +import org.apache.druid.security.basic.authorization.entity.BasicAuthorizerUser; + +import java.util.Map; + +public class StubCacheManager implements BasicAuthorizerCacheManager +{ + private final Map userMap; + private final Map roleMap; + + StubCacheManager( + Map userMap, + Map roleMap + ) + { + this.userMap = userMap; + this.roleMap = roleMap; + } + + @Override + public void handleAuthorizerUserUpdate(String authorizerPrefix, byte[] serializedUserAndRoleMap) + { + // No-op + } + + @Override + public void handleAuthorizerGroupMappingUpdate(String authorizerPrefix, byte[] serializedGroupMappingAndRoleMap) + { + // No-op + } + + @Override + public Map getUserMap(String authorizerPrefix) + { + return userMap; + } + + @Override + public Map getRoleMap(String authorizerPrefix) + { + return roleMap; + } + + @Override + public Map getGroupMappingMap(String authorizerPrefix) + { + return null; + } + + @Override + public Map getGroupMappingRoleMap(String authorizerPrefix) + { + return null; + } +} diff --git a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/OIDCConfig.java b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/OIDCConfig.java index 50b04455dbc5..ed5351d39977 100644 --- a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/OIDCConfig.java +++ b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/OIDCConfig.java @@ -41,6 +41,9 @@ public class OIDCConfig @JsonProperty private final String oidcClaim; + @JsonProperty + private final String roleClaimPath; + @JsonProperty private final String scope; @@ -50,6 +53,7 @@ public OIDCConfig( @JsonProperty("clientSecret") PasswordProvider clientSecret, @JsonProperty("discoveryURI") String discoveryURI, @JsonProperty("oidcClaim") String oidcClaim, + @JsonProperty("roleClaimPath") String roleClaimPath, @JsonProperty("scope") @Nullable String scope ) { @@ -57,6 +61,7 @@ public OIDCConfig( this.clientSecret = Preconditions.checkNotNull(clientSecret, "null clientSecret"); this.discoveryURI = Preconditions.checkNotNull(discoveryURI, "null discoveryURI"); this.oidcClaim = oidcClaim == null ? DEFAULT_SCOPE : oidcClaim; + this.roleClaimPath = roleClaimPath; this.scope = scope; } @@ -84,6 +89,12 @@ public String getOidcClaim() return oidcClaim; } + @JsonProperty + public String getRoleClaimPath() + { + return roleClaimPath; + } + @JsonProperty public String getScope() { diff --git a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jAuthenticator.java b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jAuthenticator.java index ef30f4c7e69d..63a3309369f1 100644 --- a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jAuthenticator.java +++ b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jAuthenticator.java @@ -52,6 +52,7 @@ public class Pac4jAuthenticator implements Authenticator private final Supplier pac4jConfigSupplier; private final Pac4jCommonConfig pac4jCommonConfig; private final SSLSocketFactory sslSocketFactory; + private final String roleClaimPath; @JsonCreator public Pac4jAuthenticator( @@ -73,6 +74,7 @@ public Pac4jAuthenticator( } this.pac4jConfigSupplier = Suppliers.memoize(() -> createPac4jConfig(oidcConfig)); + this.roleClaimPath = oidcConfig.getRoleClaimPath(); } @Override @@ -145,6 +147,10 @@ private Config createPac4jConfig(OIDCConfig oidcConfig) oidcClient.setUrlResolver(new DefaultUrlResolver(true)); oidcClient.setCallbackUrlResolver(new NoParameterCallbackUrlResolver()); + if (roleClaimPath != null && !roleClaimPath.isBlank()) { + oidcClient.addAuthorizationGenerator(new RoleBasedAuthGen(roleClaimPath)); + } + // This is used by OidcClient in various places to make HTTPrequests. if (sslSocketFactory != null) { HTTPRequest.setDefaultSSLSocketFactory(sslSocketFactory); diff --git a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jFilter.java b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jFilter.java index 4bc1ccf6f920..fb9ff3c3a230 100644 --- a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jFilter.java +++ b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/Pac4jFilter.java @@ -19,6 +19,8 @@ package org.apache.druid.security.pac4j; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthenticationResult; @@ -26,6 +28,7 @@ import org.pac4j.core.engine.DefaultCallbackLogic; import org.pac4j.core.engine.DefaultSecurityLogic; import org.pac4j.core.exception.http.HttpAction; +import org.pac4j.core.profile.UserProfile; import org.pac4j.jee.context.JEEContext; import org.pac4j.jee.http.adapter.JEEHttpActionAdapter; @@ -38,16 +41,21 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.Set; +import java.util.function.Supplier; public class Pac4jFilter implements Filter { private static final Logger LOGGER = new Logger(Pac4jFilter.class); + public static final String ROLE_CLAIM_CONTEXT_KEY = "druidRoles"; private final Config pac4jConfig; private final Pac4jSessionStore sessionStore; private final String callbackPath; private final String name; private final String authorizerName; + private final Supplier securityLogicFactory; + private final Supplier callbackLogicFactory; public Pac4jFilter( String name, @@ -62,6 +70,8 @@ public Pac4jFilter( this.name = name; this.authorizerName = authorizerName; this.sessionStore = new Pac4jSessionStore(cookiePassphrase); + this.securityLogicFactory = Suppliers.memoize(DefaultSecurityLogic::new); + this.callbackLogicFactory = Suppliers.memoize(DefaultCallbackLogic::new); } @Override @@ -85,7 +95,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo JEEContext context = new JEEContext(request, response); if (request.getRequestURI().equals(callbackPath)) { - DefaultCallbackLogic callbackLogic = new DefaultCallbackLogic(); + DefaultCallbackLogic callbackLogic = callbackLogicFactory.get(); String originalUrl = (String) request.getSession().getAttribute("pac4j.originalUrl"); String redirectUrl = originalUrl != null ? originalUrl : "/"; @@ -99,7 +109,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo null ); } else { - DefaultSecurityLogic securityLogic = new DefaultSecurityLogic(); + DefaultSecurityLogic securityLogic = securityLogicFactory.get(); try { securityLogic.perform( context, @@ -109,9 +119,17 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo try { // Extract user ID from pac4j profiles and create AuthenticationResult if (profiles != null && !profiles.isEmpty()) { - String uid = profiles.iterator().next().getId(); + UserProfile profile = profiles.iterator().next(); + String uid = profile.getId(); if (uid != null) { - AuthenticationResult authenticationResult = new AuthenticationResult(uid, authorizerName, name, null); + final Set roles = profile.getRoles(); + String identity = profile.getId(); + LOGGER.debug("Collected identity: %s with roles: %s", identity, roles); + final ImmutableMap.Builder authResultContext = ImmutableMap.builder(); + if (roles != null && !roles.isEmpty()) { + authResultContext.put(ROLE_CLAIM_CONTEXT_KEY, roles); + } + AuthenticationResult authenticationResult = new AuthenticationResult(uid, authorizerName, name, authResultContext.build()); servletRequest.setAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); filterChain.doFilter(servletRequest, servletResponse); } diff --git a/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/RoleBasedAuthGen.java b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/RoleBasedAuthGen.java new file mode 100644 index 000000000000..c4397a892011 --- /dev/null +++ b/extensions-core/druid-pac4j/src/main/java/org/apache/druid/security/pac4j/RoleBasedAuthGen.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.pac4j; + +import com.nimbusds.jwt.JWT; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.JWTParser; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import org.apache.druid.java.util.common.logger.Logger; +import org.pac4j.core.authorization.generator.AuthorizationGenerator; +import org.pac4j.core.context.WebContext; +import org.pac4j.core.context.session.SessionStore; +import org.pac4j.core.profile.UserProfile; +import org.pac4j.oidc.profile.OidcProfile; + +import java.lang.reflect.Array; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +public class RoleBasedAuthGen implements AuthorizationGenerator +{ + + private static final Logger LOG = new Logger(RoleBasedAuthGen.class); + + private final String roleClaimPath; // dot separated path to roles claim in the ID token + + public RoleBasedAuthGen(String roleClaimPath) + { + this.roleClaimPath = roleClaimPath; + } + + @Override + public Optional generate(WebContext context, SessionStore sessionStore, UserProfile profile) + { + if (profile == null) { + return Optional.empty(); + } + + if (!(profile instanceof OidcProfile)) { + return Optional.of(profile); + } + + final AccessToken accessToken = ((OidcProfile) profile).getAccessToken(); + + if (accessToken == null) { + LOG.debug("No access token; skip role extraction"); + return Optional.of(profile); + } + + final String tokenValue = accessToken.getValue(); + if (tokenValue == null || tokenValue.isBlank()) { + LOG.debug("Empty access token, skip role extraction"); + return Optional.of(profile); + } + + try { + final JWT jwt = JWTParser.parse(tokenValue); + JWTClaimsSet set = jwt.getJWTClaimsSet(); + if (set != null) { + Map claims = set.getClaims(); + Set roles = claimAtPath(claims, roleClaimPath); + ((OidcProfile) profile).setRoles(roles); + LOG.debug( + "Extracted %,d roles from claim path [%s]: %s", + roles.size(), + roleClaimPath, + roles + ); + } + } + catch (Throwable t) { + LOG.debug("No usable ID token on profile; skip extraction"); + } + + return Optional.of(profile); + } + + + private static Set claimAtPath(final Map root, final String path) + { + if (root == null) { + LOG.warn("No claims found in token"); + return Collections.emptySet(); + } + + final String trimmed = path.trim(); + if (trimmed.isEmpty()) { + LOG.warn("Empty roles claim path"); + return Collections.emptySet(); + } + + Object cur = root; + final String[] parts = trimmed.split("\\."); + for (String key : parts) { + if (!(cur instanceof Map)) { + return Collections.emptySet(); + } + final Map m = (Map) cur; + if (!m.containsKey(key)) { + return Collections.emptySet(); + } + cur = m.get(key); + if (cur == null) { + return Collections.emptySet(); + } + } + return normalizeClaimValues(cur); + } + + + private static Set normalizeClaimValues(Object claim) + { + if (claim == null) { + return Set.of(); + } + + Stream stream; + if (claim instanceof Collection) { + stream = ((Collection) claim).stream(); + } else if (claim.getClass().isArray()) { + int len = Array.getLength(claim); + stream = IntStream.range(0, len).mapToObj(i -> Array.get(claim, i)); + } else { + stream = Stream.of(claim); + } + + return stream + .filter(Objects::nonNull) + .map(o -> o.toString().trim()) + .filter(s -> !s.isEmpty()) + .collect(Collectors.toSet()); + } +} diff --git a/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/Pac4jFilterClaimsTest.java b/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/Pac4jFilterClaimsTest.java new file mode 100644 index 000000000000..78177de5b7a2 --- /dev/null +++ b/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/Pac4jFilterClaimsTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.pac4j; + +import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.server.security.AuthenticationResult; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.pac4j.core.config.Config; +import org.pac4j.core.context.WebContext; +import org.pac4j.core.context.session.SessionStore; +import org.pac4j.core.engine.DefaultCallbackLogic; +import org.pac4j.core.engine.DefaultSecurityLogic; +import org.pac4j.core.engine.SecurityGrantedAccessAdapter; +import org.pac4j.core.http.adapter.HttpActionAdapter; +import org.pac4j.core.profile.UserProfile; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +@RunWith(MockitoJUnitRunner.class) +public class Pac4jFilterClaimsTest +{ + @Mock + private Config pac4jConfig; + + @Mock + private HttpServletRequest req; + @Mock + private HttpServletResponse resp; + @Mock + private FilterChain chain; + @Mock + private UserProfile profile; + + @Mock + private DefaultSecurityLogic securityLogic; + @Mock + private DefaultCallbackLogic callbackLogic; + + @Mock + private Supplier securityLogicFactory; + @Mock + private Supplier callbackLogicFactory; + + private Pac4jFilter filter; + + @Before + public void setUp() + { + filter = new Pac4jFilter( + "testPac4j", + "basic", + pac4jConfig, + "/callback", + "cookiePassphrase" + ); + + when(securityLogicFactory.get()).thenReturn(securityLogic); + + setField(filter, "securityLogicFactory", securityLogicFactory); + setField(filter, "callbackLogicFactory", callbackLogicFactory); + + when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).thenReturn(null); + when(req.getRequestURI()).thenReturn("/some/api"); + } + + @Test + public void skipWhenAuthResultAlreadyPresent() throws IOException, ServletException + { + when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(new AuthenticationResult("id", "a", "n", Map.of())); + + filter.doFilter(req, resp, chain); + + verify(chain).doFilter(req, resp); + verifyNoInteractions(securityLogicFactory, callbackLogicFactory, securityLogic, callbackLogic); + } + + @Test + public void setsRolesInContextWhenPresent() throws IOException, ServletException + { + when(profile.getId()).thenReturn("user1"); + when(profile.getRoles()).thenReturn(Set.of("admin", "dev")); + + doAnswer(inv -> { + WebContext ctx = inv.getArgument(0); + SessionStore store = inv.getArgument(1); + SecurityGrantedAccessAdapter adapter = inv.getArgument(3); + + adapter.adapt(ctx, store, Set.of(profile), Collections.emptyMap()); + return null; + }).when(securityLogic).perform( + any(WebContext.class), + any(SessionStore.class), + eq(pac4jConfig), + any(), + any(HttpActionAdapter.class), + isNull(), + eq("none"), + isNull() + ); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + + filter.doFilter(req, resp, chain); + + verify(req).setAttribute(eq(AuthConfig.DRUID_AUTHENTICATION_RESULT), captor.capture()); + verify(chain).doFilter(req, resp); + + AuthenticationResult ar = (AuthenticationResult) captor.getValue(); + assertEquals("user1", ar.getIdentity()); + assertNotNull(ar.getContext()); + assertTrue(ar.getContext().containsKey(Pac4jFilter.ROLE_CLAIM_CONTEXT_KEY)); + + Set roles = (Set) ar.getContext().get(Pac4jFilter.ROLE_CLAIM_CONTEXT_KEY); + assertEquals(Set.of("admin", "dev"), roles); + } + + @Test + public void noRolesDoesNotSetRolesKey() throws IOException, ServletException + { + when(profile.getId()).thenReturn("user2"); + when(profile.getRoles()).thenReturn(null); + + doAnswer(inv -> { + WebContext ctx = inv.getArgument(0); + SessionStore store = inv.getArgument(1); + SecurityGrantedAccessAdapter adapter = inv.getArgument(3); + + adapter.adapt(ctx, store, Set.of(profile), Collections.emptyMap()); + return null; + }).when(securityLogic).perform( + any(WebContext.class), + any(SessionStore.class), + eq(pac4jConfig), + any(), + any(HttpActionAdapter.class), + isNull(), + eq("none"), + isNull() + ); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + + filter.doFilter(req, resp, chain); + + verify(req).setAttribute(eq(AuthConfig.DRUID_AUTHENTICATION_RESULT), captor.capture()); + verify(chain).doFilter(req, resp); + + AuthenticationResult ar = (AuthenticationResult) captor.getValue(); + assertEquals("user2", ar.getIdentity()); + assertNotNull(ar.getContext()); + assertFalse(ar.getContext().containsKey(Pac4jFilter.ROLE_CLAIM_CONTEXT_KEY)); + } + + @Test + public void emptyProfilesDoesNotSetAuthResultAndDoesNotContinueChain() throws IOException, ServletException + { + doAnswer(inv -> { + WebContext ctx = inv.getArgument(0); + SessionStore store = inv.getArgument(1); + SecurityGrantedAccessAdapter adapter = inv.getArgument(3); + + adapter.adapt(ctx, store, Collections.emptySet(), Collections.emptyMap()); + return null; + }).when(securityLogic).perform( + any(WebContext.class), + any(SessionStore.class), + eq(pac4jConfig), + any(), + any(HttpActionAdapter.class), + isNull(), + eq("none"), + isNull() + ); + + filter.doFilter(req, resp, chain); + + verify(req, never()).setAttribute(eq(AuthConfig.DRUID_AUTHENTICATION_RESULT), any()); + verify(chain, never()).doFilter(any(), any()); + } + + private static void setField(Object target, String name, Object value) + { + try { + Field f = target.getClass().getDeclaredField(name); + f.setAccessible(true); + f.set(target, value); + } + catch (ReflectiveOperationException e) { + throw new AssertionError("Failed setting field: " + name, e); + } + } +} diff --git a/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/RoleBasedAuthGenTest.java b/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/RoleBasedAuthGenTest.java new file mode 100644 index 000000000000..23aece73046d --- /dev/null +++ b/extensions-core/druid-pac4j/src/test/java/org/apache/druid/security/pac4j/RoleBasedAuthGenTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.security.pac4j; + +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.PlainJWT; +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; +import org.junit.Test; +import org.pac4j.oidc.profile.OidcProfile; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class RoleBasedAuthGenTest +{ + + @Test + public void extractsRolesFromNestedArray() + { + Map druidms = new HashMap<>(); + druidms.put("roles", Arrays.asList("admin", "user", "")); + + Map resourceAccess = new HashMap<>(); + resourceAccess.put("druidms", druidms); + + String at = new PlainJWT(new JWTClaimsSet.Builder() + .claim("resource_access", resourceAccess) + .build()).serialize(); + + OidcProfile profile = new OidcProfile(); + profile.setAccessToken(new BearerAccessToken(at, 3600L, null)); + + RoleBasedAuthGen gen = new RoleBasedAuthGen("resource_access.druidms.roles"); + gen.generate(null, null, profile); + + assertEquals(new HashSet<>(Arrays.asList("admin", "user")), profile.getRoles()); + } + + @Test + public void supportsSingleStringLeaf() + { + Map realmAccess = new HashMap<>(); + realmAccess.put("roles", "admin"); + + String at = new PlainJWT(new JWTClaimsSet.Builder() + .claim("realm_access", realmAccess) + .build()).serialize(); + + OidcProfile profile = new OidcProfile(); + profile.setAccessToken(new BearerAccessToken(at, 3600L, null)); + + RoleBasedAuthGen gen = new RoleBasedAuthGen("realm_access.roles"); + gen.generate(null, null, profile); + + assertEquals(new HashSet<>(Collections.singletonList("admin")), profile.getRoles()); + } + + @Test + public void noAccessTokenLeavesRolesEmpty() + { + OidcProfile profile = new OidcProfile(); + + RoleBasedAuthGen gen = new RoleBasedAuthGen("resource_access.druidms.roles"); + gen.generate(null, null, profile); + + assertTrue(profile.getRoles().isEmpty()); + } +}