Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import static io.trino.client.ProtocolHeaders.detectProtocol;
import static io.trino.server.ServletSecurityUtils.authenticatedIdentity;
import static io.trino.spi.security.AccessDeniedException.denySetRole;
import static io.trino.spi.security.ExtraCredentials.isInternalExtraCredential;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -259,11 +260,15 @@ private Identity buildSessionIdentity(Optional<Identity> authenticatedIdentity,
if (systemRole.getType() == Type.ROLE) {
systemEnabledRoles.add(systemRole.getRole().orElseThrow());
}
// Authenticated credentials (placed by the server-side authenticator under internal$*
// keys) take precedence over client-supplied credentials with the same name.
Map<String, String> extraCredentials = new HashMap<>(parseExtraCredentials(protocolHeaders, headers));
authenticatedIdentity.map(Identity::getExtraCredentials).ifPresent(extraCredentials::putAll);
Identity newIdentity = authenticatedIdentity
.map(identity -> Identity.from(identity).withUser(user))
.orElseGet(() -> Identity.forUser(user))
.withAdditionalConnectorRoles(parseConnectorRoleHeaders(protocolHeaders, headers))
.withAdditionalExtraCredentials(parseExtraCredentials(protocolHeaders, headers))
.withExtraCredentials(extraCredentials)
.withAdditionalGroups(groupProvider.getGroups(user))
.withEnabledRoles(systemEnabledRoles.build())
.build();
Expand Down Expand Up @@ -343,7 +348,7 @@ private static Map<String, String> parseExtraCredentials(ProtocolHeaders protoco
{
Map<String, String> credentials = parseProperty(headers, protocolHeaders.requestExtraCredential());
for (String name : credentials.keySet()) {
assertRequest(!name.startsWith("internal$"), "Invalid extra credential name: %s", name);
assertRequest(!isInternalExtraCredential(name), "Invalid extra credential name");
}
return credentials;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.net.MediaType;
Expand Down Expand Up @@ -126,6 +127,7 @@
import static io.trino.spi.HostAddress.fromUri;
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static io.trino.spi.security.ExtraCredentials.isInternalExtraCredential;
import static io.trino.util.Failures.toFailure;
import static java.lang.Math.addExact;
import static java.lang.Math.clamp;
Expand All @@ -142,6 +144,7 @@ public final class HttpRemoteTask
private final TaskId taskId;

private final Session session;
private final Map<String, String> workerExtraCredentials;
private final Span stageSpan;
private final String nodeId;
private final AtomicBoolean speculative;
Expand Down Expand Up @@ -266,6 +269,7 @@ public HttpRemoteTask(
try (SetThreadName _ = new SetThreadName("HttpRemoteTask-" + taskId)) {
this.taskId = taskId;
this.session = session;
this.workerExtraCredentials = extraCredentialsForWorker(session.getIdentity().getExtraCredentials());
this.stageSpan = stageSpan;
this.nodeId = node.getNodeIdentifier();
this.speculative = new AtomicBoolean(speculative);
Expand Down Expand Up @@ -777,7 +781,7 @@ private void sendUpdateInternal()
Optional<PlanFragment> fragment = sendPlan.get() ? Optional.of(planFragment.withoutEmbeddedJsonRepresentation()) : Optional.empty();
TaskUpdateRequest updateRequest = new TaskUpdateRequest(
session.toSessionRepresentation(),
session.getIdentity().getExtraCredentials(),
workerExtraCredentials,
stageSpan,
fragment,
tableCredentials,
Expand Down Expand Up @@ -820,6 +824,11 @@ private void sendUpdateInternal()
executor);
}

private static Map<String, String> extraCredentialsForWorker(Map<String, String> extraCredentials)
{
return ImmutableMap.copyOf(Maps.filterKeys(extraCredentials, key -> !isInternalExtraCredential(key)));
}

private synchronized List<SplitAssignment> getSplitAssignments(int currentSplitBatchSize)
{
return Stream.concat(planFragment.getPartitionedSourceNodes().stream(), planFragment.getRemoteSourceNodes().stream())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.server.security.AbstractBearerAuthenticator;
Expand All @@ -34,6 +35,7 @@
import static io.trino.server.security.UserMapping.createUserMapping;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getInitiateUri;
import static io.trino.server.security.oauth2.OAuth2TokenExchangeResource.getTokenUri;
import static io.trino.spi.security.ExtraCredentials.authenticatedExtraCredentialName;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

Expand All @@ -46,14 +48,17 @@ public class OAuth2Authenticator
private final UserMapping userMapping;
private final TokenPairSerializer tokenPairSerializer;
private final TokenRefresher tokenRefresher;
private final Optional<String> accessTokenExtraCredentialName;

@Inject
public OAuth2Authenticator(OAuth2Client client, OAuth2Config config, TokenRefresher tokenRefresher, TokenPairSerializer tokenPairSerializer)
{
requireNonNull(config, "config is null");
this.client = requireNonNull(client, "service is null");
this.principalField = config.getPrincipalField();
this.tokenRefresher = requireNonNull(tokenRefresher, "tokenRefresher is null");
this.tokenPairSerializer = requireNonNull(tokenPairSerializer, "tokenPairSerializer is null");
this.accessTokenExtraCredentialName = config.getAccessTokenExtraCredentialName();
userMapping = createUserMapping(config.getUserMappingPattern(), config.getUserMappingFile());
}

Expand All @@ -80,6 +85,9 @@ protected Optional<Identity> createIdentity(String token)
}
Identity.Builder builder = Identity.forUser(userMapping.mapUser(principal.get()));
builder.withPrincipal(new BasicPrincipal(principal.get()));
accessTokenExtraCredentialName.ifPresent(name -> builder.withAdditionalExtraCredentials(ImmutableMap.of(
name, tokenPair.accessToken(),
authenticatedExtraCredentialName(name), tokenPair.accessToken())));
return Optional.of(builder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.airlift.configuration.validation.FileExists;
import io.airlift.units.Duration;
import io.airlift.units.MinDuration;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;

import java.io.File;
Expand All @@ -31,7 +32,9 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static com.google.common.base.Strings.emptyToNull;
import static io.trino.server.security.oauth2.OAuth2Service.OPENID_SCOPE;
import static io.trino.spi.security.ExtraCredentials.isInternalExtraCredential;

public class OAuth2Config
{
Expand All @@ -47,6 +50,7 @@ public class OAuth2Config
private Optional<String> jwtType = Optional.empty();
private Optional<String> userMappingPattern = Optional.empty();
private Optional<File> userMappingFile = Optional.empty();
private Optional<String> accessTokenExtraCredentialName = Optional.empty();
private boolean enableRefreshTokens;
private boolean enableDiscovery = true;

Expand Down Expand Up @@ -218,6 +222,36 @@ public OAuth2Config setUserMappingFile(File userMappingFile)
return this;
}

public Optional<String> getAccessTokenExtraCredentialName()
{
return accessTokenExtraCredentialName;
}

@Config("http-server.authentication.oauth2.access-token-extra-credential-name")
@ConfigDescription("Extra credential name for storing the authenticated OAuth2 access token")
public OAuth2Config setAccessTokenExtraCredentialName(String accessTokenExtraCredentialName)
{
this.accessTokenExtraCredentialName = Optional.ofNullable(emptyToNull(accessTokenExtraCredentialName));
return this;
}

@AssertTrue(message = "OAuth2 access token extra credential name must not start with internal$")
public boolean isAccessTokenExtraCredentialNameNotInternal()
{
return accessTokenExtraCredentialName
.map(name -> !isInternalExtraCredential(name))
.orElse(true);
}

@AssertTrue(message = "OAuth2 access token extra credential name must not contain whitespace, comma, or equals")
public boolean isAccessTokenExtraCredentialNameValid()
{
return accessTokenExtraCredentialName
.map(name -> name.chars().noneMatch(character ->
Character.isWhitespace(character) || character == ',' || character == '='))
.orElse(true);
}

public boolean isEnableRefreshTokens()
{
return enableRefreshTokens;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,28 @@ private static void assertSessionContext(ProtocolHeaders protocolHeaders)
assertThat(context.getIdentity().getExtraCredentials()).isEqualTo(ImmutableMap.of("test.token.foo", "bar", "test.token.abc", "xyz"));
}

@Test
public void testAuthenticatedExtraCredentialsOverrideRequestExtraCredentials()
{
MultivaluedMap<String, String> headers = new GuavaMultivaluedMap<>(ImmutableListMultimap.<String, String>builder()
.put(TRINO_HEADERS.requestUser(), "testUser")
.put(TRINO_HEADERS.requestExtraCredential(), "token=request-token")
.put(TRINO_HEADERS.requestExtraCredential(), "request-only=request-value")
.build());

SessionContext context = sessionContextFactory(TRINO_HEADERS).createSessionContext(
headers,
Optional.of("testRemote"),
Optional.of(Identity.forUser("testUser")
.withExtraCredentials(ImmutableMap.of("token", "authenticated-token", "authenticated-only", "authenticated-value"))
.build()));

assertThat(context.getIdentity().getExtraCredentials()).isEqualTo(ImmutableMap.of(
"token", "authenticated-token",
"request-only", "request-value",
"authenticated-only", "authenticated-value"));
}

@Test
public void testMappedUser()
{
Expand Down Expand Up @@ -186,7 +208,7 @@ public void testInternalExtraCredentialName()
.build());

assertInvalidSession(TRINO_HEADERS, headers)
.hasMessage("Invalid extra credential name: internal$abc");
.hasMessage("Invalid extra credential name");
}

private static AbstractThrowableAssert<?, ? extends Throwable> assertInvalidSession(ProtocolHeaders protocolHeaders, MultivaluedMap<String, String> headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import io.trino.spi.connector.DynamicFilter;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.security.Identity;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeOperators;
Expand Down Expand Up @@ -139,6 +140,7 @@
import static io.trino.server.InternalHeaders.TRINO_MAX_WAIT;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_ERROR;
import static io.trino.spi.StandardErrorCode.REMOTE_TASK_MISMATCH;
import static io.trino.spi.security.ExtraCredentials.authenticatedExtraCredentialName;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE;
Expand Down Expand Up @@ -220,6 +222,38 @@ public void testRegular()
httpRemoteTaskFactory.stop();
}

@Test
@Timeout(30)
public void testInternalExtraCredentialsAreNotSentToWorkers()
throws Exception
{
AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime());
TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE);
HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource);
Session session = Session.builder(TEST_SESSION)
.setIdentity(Identity.forUser(TEST_SESSION.getUser())
.withExtraCredentials(ImmutableMap.of(
"token", "access-token",
authenticatedExtraCredentialName("token"), "access-token"))
.build())
.build();
RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory, ImmutableSet.of(), session);

testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo());
remoteTask.start();
remoteTask.addSplits(ImmutableMultimap.of(TABLE_SCAN_NODE_ID, new Split(TEST_CATALOG_HANDLE, TestingSplit.createLocalSplit())));

poll(() -> testingTaskResource.getLatestExtraCredentials().containsKey("token"));
assertThat(testingTaskResource.getLatestExtraCredentials())
.containsEntry("token", "access-token")
.doesNotContainKey(authenticatedExtraCredentialName("token"));

remoteTask.cancel();
poll(() -> remoteTask.getTaskStatus().state().isDone());

httpRemoteTaskFactory.stop();
}

@Test
@Timeout(30)
public void testDynamicFilterFetcherFailure()
Expand Down Expand Up @@ -781,6 +815,7 @@ public static class TestingTaskResource
private TaskState taskState;
private long taskInstanceId = INITIAL_TASK_INSTANCE_ID;
private Map<DynamicFilterId, Domain> latestDynamicFilterFromCoordinator = ImmutableMap.of();
private Map<String, String> latestExtraCredentials = ImmutableMap.of();

private long statusFetchCounter;
private long createOrUpdateCounter;
Expand Down Expand Up @@ -830,6 +865,7 @@ public synchronized TaskInfo createOrUpdateTask(
dynamicFiltersSentCounter++;
latestDynamicFilterFromCoordinator = taskUpdateRequest.dynamicFilterDomains();
}
latestExtraCredentials = taskUpdateRequest.extraCredentials();
createOrUpdateCounter++;
lastActivityNanos.set(System.nanoTime());
return buildTaskInfo();
Expand Down Expand Up @@ -944,6 +980,11 @@ public synchronized long getCreateOrUpdateCounter()
return createOrUpdateCounter;
}

public synchronized Map<String, String> getLatestExtraCredentials()
{
return latestExtraCredentials;
}

public synchronized long getDynamicFiltersFetchCounter()
{
return dynamicFiltersFetchCounter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.trino.server.protocol.PreparedStatementEncoder;
import io.trino.server.protocol.spooling.QueryDataEncoder;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.OAuth2Authenticator;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.TokenPairSerializer;
import io.trino.server.security.oauth2.TokenPairSerializer.TokenPair;
Expand Down Expand Up @@ -106,6 +107,7 @@
import static io.trino.server.ui.OAuthWebUiCookie.OAUTH2_COOKIE;
import static io.trino.spi.security.AccessDeniedException.denyImpersonateUser;
import static io.trino.spi.security.AccessDeniedException.denyReadSystemInformationAccess;
import static io.trino.spi.security.ExtraCredentials.authenticatedExtraCredentialName;
import static jakarta.servlet.http.HttpServletResponse.SC_FORBIDDEN;
import static jakarta.servlet.http.HttpServletResponse.SC_OK;
import static jakarta.servlet.http.HttpServletResponse.SC_SEE_OTHER;
Expand Down Expand Up @@ -632,6 +634,57 @@ public void testOAuth2Authenticator()
verifyOAuth2Authenticator(false, true, Optional.empty());
}

@Test
public void testOAuth2AccessTokenExtraCredential()
throws Exception
{
assertOAuth2AccessTokenExtraCredential("token");
assertOAuth2AccessTokenExtraCredential("credential");
}

private void assertOAuth2AccessTokenExtraCredential(String credentialName)
throws Exception
{
CookieManager cookieManager = new CookieManager();
OkHttpClient client = this.client.newBuilder()
.cookieJar(new JavaNetCookieJar(cookieManager))
.build();

try (TokenServer tokenServer = new TokenServer(Optional.empty());
TestingTrinoServer server = TestingTrinoServer.builder()
.setProperties(ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("web-ui.enabled", "false")
.put("http-server.authentication.type", "oauth2")
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.access-token-extra-credential-name", credentialName)
.buildOrThrow())
.setAdditionalModule(oauth2Module(tokenServer))
.setSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION)
.build()) {
HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class));
URI baseUri = httpServerInfo.getHttpsUri();

OAuthBearer bearer = assertAuthenticateOAuth2Bearer(client, getAuthorizedUserLocation(baseUri), "http://example.com/authorize");
assertOk(
client,
uriBuilderFrom(baseUri)
.replacePath("/oauth2/callback/")
.addParameter("code", "TEST_CODE")
.addParameter("state", bearer.state())
.toString());

String oauthToken = getOauthToken(client, bearer.tokenServer());
List<Authenticator> authenticators = server.getInstance(new Key<>() {});
assertThat(authenticators).hasSize(1);
assertThat(authenticators.get(0)).isInstanceOf(OAuth2Authenticator.class);
Identity identity = ((OAuth2Authenticator) authenticators.get(0)).authenticate(null, oauthToken);
assertThat(identity.getExtraCredentials())
.containsEntry(credentialName, tokenServer.getAccessToken())
.containsEntry(authenticatedExtraCredentialName(credentialName), tokenServer.getAccessToken());
}
}

private void verifyOAuth2Authenticator(boolean webUiEnabled, boolean refreshTokensEnabled, Optional<String> principalField)
throws Exception
{
Expand Down
Loading
Loading