Skip to content

add kerberos support #75

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import net.qihoo.xlearning.common.*;
import net.qihoo.xlearning.conf.XLearningConfiguration;
import net.qihoo.xlearning.container.XLearningContainerId;
import net.qihoo.xlearning.security.XTokenSecretManager;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
Expand All @@ -14,6 +15,7 @@
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.SystemClock;
Expand Down Expand Up @@ -120,11 +122,21 @@ public ApplicationContainerListener(ApplicationContext applicationContext, Confi
@Override
public void start() {
LOG.info("Starting application containers handler server");
RPC.Builder builder = new RPC.Builder(getConfig());

Configuration conf = getConfig();
conf.setBoolean("hadoop.security.authorization",false);

RPC.Builder builder = new RPC.Builder(conf);
builder.setProtocol(ApplicationContainerProtocol.class);
builder.setInstance(this);
builder.setBindAddress("0.0.0.0");
builder.setPort(0);

if(UserGroupInformation.isSecurityEnabled()) {
XTokenSecretManager secretManager = new XTokenSecretManager();
builder.setSecretManager(secretManager);
}

try {
server = builder.build();
} catch (Exception e) {
Expand Down
33 changes: 32 additions & 1 deletion src/main/java/net/qihoo/xlearning/AM/ApplicationMaster.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.mapred.*;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.service.CompositeService;
import org.apache.hadoop.yarn.api.ApplicationConstants;
Expand All @@ -24,10 +28,12 @@
import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -36,6 +42,7 @@
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.text.DecimalFormat;
import java.util.*;
Expand Down Expand Up @@ -106,6 +113,10 @@ public class ApplicationMaster extends CompositeService {

private Thread cleanApplication;


private ByteBuffer allTokens;


/**
* Constructor, connect to Resource Manager
*
Expand Down Expand Up @@ -808,7 +819,7 @@ private void launchContainer(Map<String, LocalResource> containerLocalResource,

containerEnv.put(XLearningConstants.Environment.XLEARNING_TF_INDEX.toString(), String.valueOf(index));
ContainerLaunchContext ctx = ContainerLaunchContext.newInstance(
containerLocalResource, containerEnv, containerLaunchcommands, null, null, null);
containerLocalResource, containerEnv, containerLaunchcommands, null,allTokens==null?null:allTokens.duplicate(), null);

try {
nmAsync.startContainerAsync(container, ctx);
Expand Down Expand Up @@ -851,6 +862,26 @@ public Configuration getConf() {
private boolean run() throws IOException, NoSuchAlgorithmException {
LOG.info("ApplicationMaster Starting ...");

if(UserGroupInformation.isSecurityEnabled()) {
// Note: Credentials, Token, UserGroupInformation, DataOutputBuffer class
// are marked as LimitedPrivate
Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
DataOutputBuffer dob = new DataOutputBuffer();
credentials.writeTokenStorageToStream(dob);
// Now remove the AM->RM token so that containers cannot access it.
Iterator<org.apache.hadoop.security.token.Token<?>> iter = credentials.getAllTokens().iterator();
LOG.info("Executing with tokens:");
while (iter.hasNext()) {
Token<?> token = iter.next();
LOG.info(token);
if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
iter.remove();
}
}
allTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
}

registerApplicationMaster();
if (conf.get(XLearningConfiguration.XLEARNING_INPUT_STRATEGY, XLearningConfiguration.DEFAULT_XLEARNING_INPUT_STRATEGY).equals("STREAM")) {
buildInputStreamFileStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import net.qihoo.xlearning.api.ApplicationContext;
import net.qihoo.xlearning.api.ApplicationMessageProtocol;
import net.qihoo.xlearning.common.Message;
import net.qihoo.xlearning.security.XTokenSecretManager;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.service.AbstractService;

import java.io.IOException;
Expand All @@ -35,11 +37,20 @@ public ApplicationMessageService(ApplicationContext applicationContext, Configur
@Override
public void start() {
LOG.info("Starting application message server");
RPC.Builder builder = new RPC.Builder(getConfig());
Configuration conf = getConfig();
conf.setBoolean("hadoop.security.authorization",false);

RPC.Builder builder = new RPC.Builder(conf);
builder.setProtocol(ApplicationMessageProtocol.class);
builder.setInstance(this);
builder.setBindAddress("0.0.0.0");
builder.setPort(0);

if(UserGroupInformation.isSecurityEnabled()) {
XTokenSecretManager secretManager = new XTokenSecretManager();
builder.setSecretManager(secretManager);
}

Server server;
try {
server = builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,19 @@

import net.qihoo.xlearning.common.*;
import net.qihoo.xlearning.container.XLearningContainerId;
import net.qihoo.xlearning.security.Utils;
import net.qihoo.xlearning.security.XTokenSelector;
import org.apache.hadoop.ipc.ProtocolInfo;
import org.apache.hadoop.ipc.VersionedProtocol;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.TokenInfo;


@KerberosInfo(serverPrincipal = Utils.SERVER_PRINCIPAL_KEY)
@TokenInfo(XTokenSelector.class)
@ProtocolInfo(protocolName = "ApplicationContainerProtocol",
protocolVersion = 1)
public interface ApplicationContainerProtocol extends VersionedProtocol {

public static final long versionID = 1L;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
package net.qihoo.xlearning.api;

import net.qihoo.xlearning.security.Utils;
import net.qihoo.xlearning.security.XTokenSelector;
import org.apache.hadoop.ipc.ProtocolInfo;
import org.apache.hadoop.ipc.VersionedProtocol;
import net.qihoo.xlearning.common.Message;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.token.TokenInfo;

/**
* The Protocal between clients and ApplicationMaster to fetch Application Messages.
*/

@KerberosInfo(serverPrincipal = Utils.SERVER_PRINCIPAL_KEY)
@TokenInfo(XTokenSelector.class)
@ProtocolInfo(protocolName = "ApplicationMessageProtocol",
protocolVersion = 1)
public interface ApplicationMessageProtocol extends VersionedProtocol {

public static final long versionID = 1L;
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/net/qihoo/xlearning/api/ContainerListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import net.qihoo.xlearning.container.XLearningContainerId;



public interface ContainerListener {

void registerContainer(XLearningContainerId xlearningContainerId, String role);
Expand Down
47 changes: 46 additions & 1 deletion src/main/java/net/qihoo/xlearning/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import net.qihoo.xlearning.common.Message;
import net.qihoo.xlearning.common.exceptions.RequestOverLimitException;
import net.qihoo.xlearning.conf.XLearningConfiguration;
import net.qihoo.xlearning.security.XTokenIdentifier;
import net.qihoo.xlearning.security.XTokenSecretManager;
import net.qihoo.xlearning.util.Utilities;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang.StringUtils;
Expand All @@ -16,10 +18,15 @@
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.OutputFormat;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.api.protocolrecords.*;
import org.apache.hadoop.yarn.api.records.*;
Expand All @@ -33,6 +40,7 @@
import java.lang.reflect.UndeclaredThrowableException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -216,6 +224,15 @@ private static ApplicationMessageProtocol getAppMessageHandler(
ApplicationMessageProtocol appMessageHandler = null;
if (!StringUtils.isBlank(appMasterAddress) && !appMasterAddress.equalsIgnoreCase("N/A")) {
InetSocketAddress addr = new InetSocketAddress(appMasterAddress, appMasterPort);

if(UserGroupInformation.isSecurityEnabled()) {
XTokenSecretManager secretManager = new XTokenSecretManager();
UserGroupInformation current = UserGroupInformation.getCurrentUser();
XTokenIdentifier tokenID = new XTokenIdentifier(new Text(current.getUserName()));
Token<XTokenIdentifier> token = new Token<>(tokenID, secretManager);
SecurityUtil.setTokenService(token, addr);
current.addToken(token);
}
appMessageHandler = RPC.getProxy(ApplicationMessageProtocol.class, ApplicationMessageProtocol.versionID, addr, conf);
}
return appMessageHandler;
Expand Down Expand Up @@ -539,7 +556,11 @@ private boolean submitAndMonitor() throws IOException, YarnException {

if (clientArguments.xlearningCacheArchives != null && !clientArguments.xlearningCacheArchives.equals("")) {
appMasterEnv.put(XLearningConstants.Environment.XLEARNING_CACHE_ARCHIVE_LOCATION.toString(), clientArguments.xlearningCacheArchives);
if (clientArguments.appType.equals("MXNET") && !conf.getBoolean(XLearningConfiguration.XLEARNING_MXNET_MODE_SINGLE, XLearningConfiguration.DEFAULT_XLEARNING_MXNET_MODE_SINGLE)) {
if (clientArguments.appType.equals("MXNET") && !conf.getBoolean(XLearningConfiguration.XLEARNING_MXNET_MODE_SINGLE, XLearningConfiguration.DEFAULT_XLEARNING_MXNET_MODE_SINGLE)
|| (clientArguments.appType.equals("DISTXGBOOST"))

) {

URI defaultUri = new Path(conf.get("fs.defaultFS")).toUri();
String appCacheArchivesRemoteLocation = appMasterEnv.get(XLearningConstants.Environment.XLEARNING_CACHE_ARCHIVE_LOCATION.toString());
String[] cacheArchives = StringUtils.split(appCacheArchivesRemoteLocation, ",");
Expand Down Expand Up @@ -632,6 +653,30 @@ private boolean submitAndMonitor() throws IOException, YarnException {
ContainerLaunchContext amContainer = ContainerLaunchContext.newInstance(
localResources, appMasterEnv, appMasterLaunchcommands, null, null, null);


if (UserGroupInformation.isSecurityEnabled()) {
// Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce
Credentials credentials = new Credentials();
String tokenRenewer = conf.get(YarnConfiguration.RM_PRINCIPAL);
if (tokenRenewer == null || tokenRenewer.length() == 0) {
throw new IOException(
"Can't get Master Kerberos principal for the RM to use as renewer");
}

// For now, only getting tokens for the default file-system.
final Token<?> tokens[] =
dfs.addDelegationTokens(tokenRenewer, credentials);
if (tokens != null) {
for (Token<?> token : tokens) {
LOG.info("Got dt for " + dfs.getUri() + "; " + token);
}
}
DataOutputBuffer dob = new DataOutputBuffer();
credentials.writeTokenStorageToStream(dob);
ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
amContainer.setTokens(fsTokens);
}

applicationContext.setAMContainerSpec(amContainer);

Priority priority = Records.newRecord(Priority.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private void produceCpuMetrics(String xlearningCmdProcessId) throws IOException
} catch (IOException e) {
e.printStackTrace();
}
LOG.info("containerProcessId is:" + this.containerProcessId);
LOG.info("ps is:" + this.containerProcessId);
ProcessTreeInfo processTreeInfo =
new ProcessTreeInfo(this.containerId.getContainerId(),
null, null, 0, 0, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import net.qihoo.xlearning.common.XLearningContainerStatus;
import net.qihoo.xlearning.common.TextMultiOutputFormat;
import net.qihoo.xlearning.conf.XLearningConfiguration;
import net.qihoo.xlearning.security.XTokenIdentifier;
import net.qihoo.xlearning.security.XTokenSecretManager;
import net.qihoo.xlearning.util.Utilities;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
Expand All @@ -19,6 +21,9 @@
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.ApplicationConstants;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.io.Text;
Expand Down Expand Up @@ -124,11 +129,22 @@ private XLearningContainer() {
reservedSocket = new Socket();
}

private void init() {
private void init() throws IOException {
LOG.info("XLearningContainer initializing");
String appMasterHost = System.getenv(XLearningConstants.Environment.APPMASTER_HOST.toString());
int appMasterPort = Integer.valueOf(System.getenv(XLearningConstants.Environment.APPMASTER_PORT.toString()));
InetSocketAddress addr = new InetSocketAddress(appMasterHost, appMasterPort);

if(UserGroupInformation.isSecurityEnabled()) {
//add for kerberos
conf.setBoolean("hadoop.security.authorization", false);
XTokenSecretManager secretManager = new XTokenSecretManager();
UserGroupInformation current = UserGroupInformation.getCurrentUser();
XTokenIdentifier tokenID = new XTokenIdentifier(new Text(current.getUserName()));
Token<XTokenIdentifier> token = new Token<>(tokenID, secretManager);
SecurityUtil.setTokenService(token, addr);
current.addToken(token);
}
try {
this.amClient = RPC.getProxy(ApplicationContainerProtocol.class,
ApplicationContainerProtocol.versionID, addr, conf);
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/net/qihoo/xlearning/security/Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package net.qihoo.xlearning.security;

public class Utils {
public static final String SERVER_PRINCIPAL_KEY="xlearning.ipc.server.principal";

}
Loading