diff --git a/src/main/java/net/qihoo/xlearning/AM/ApplicationContainerListener.java b/src/main/java/net/qihoo/xlearning/AM/ApplicationContainerListener.java index dcb3c37..1598591 100644 --- a/src/main/java/net/qihoo/xlearning/AM/ApplicationContainerListener.java +++ b/src/main/java/net/qihoo/xlearning/AM/ApplicationContainerListener.java @@ -6,6 +6,7 @@ import net.qihoo.xlearning.common.*; import net.qihoo.xlearning.conf.XLearningConfiguration; import net.qihoo.xlearning.container.XLearningContainerId; +import net.qihoo.xlearning.security.XTokenSecretManager; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -14,6 +15,7 @@ import org.apache.hadoop.ipc.ProtocolSignature; import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.ipc.Server; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.service.AbstractService; import org.apache.hadoop.yarn.util.Clock; import org.apache.hadoop.yarn.util.SystemClock; @@ -120,11 +122,21 @@ public ApplicationContainerListener(ApplicationContext applicationContext, Confi @Override public void start() { LOG.info("Starting application containers handler server"); - RPC.Builder builder = new RPC.Builder(getConfig()); + + Configuration conf = getConfig(); + conf.setBoolean("hadoop.security.authorization",false); + + RPC.Builder builder = new RPC.Builder(conf); builder.setProtocol(ApplicationContainerProtocol.class); builder.setInstance(this); builder.setBindAddress("0.0.0.0"); builder.setPort(0); + + if(UserGroupInformation.isSecurityEnabled()) { + XTokenSecretManager secretManager = new XTokenSecretManager(); + builder.setSecretManager(secretManager); + } + try { server = builder.build(); } catch (Exception e) { diff --git a/src/main/java/net/qihoo/xlearning/AM/ApplicationMaster.java b/src/main/java/net/qihoo/xlearning/AM/ApplicationMaster.java index 0342123..11453fb 100644 --- a/src/main/java/net/qihoo/xlearning/AM/ApplicationMaster.java +++ b/src/main/java/net/qihoo/xlearning/AM/ApplicationMaster.java @@ -15,7 +15,11 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.*; import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer; import org.apache.hadoop.mapred.*; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; import org.apache.hadoop.util.ReflectionUtils; import org.apache.hadoop.service.CompositeService; import org.apache.hadoop.yarn.api.ApplicationConstants; @@ -24,10 +28,12 @@ import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; import org.apache.hadoop.yarn.client.api.async.NMClientAsync; import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; import org.apache.hadoop.yarn.util.ConverterUtils; import org.apache.hadoop.yarn.util.Records; import java.io.BufferedReader; +import java.io.File; import java.io.IOException; import java.io.InputStreamReader; import java.lang.reflect.InvocationTargetException; @@ -36,6 +42,7 @@ import java.net.InetSocketAddress; import java.net.Socket; import java.net.URI; +import java.nio.ByteBuffer; import java.security.NoSuchAlgorithmException; import java.text.DecimalFormat; import java.util.*; @@ -106,6 +113,10 @@ public class ApplicationMaster extends CompositeService { private Thread cleanApplication; + + private ByteBuffer allTokens; + + /** * Constructor, connect to Resource Manager * @@ -808,7 +819,7 @@ private void launchContainer(Map containerLocalResource, containerEnv.put(XLearningConstants.Environment.XLEARNING_TF_INDEX.toString(), String.valueOf(index)); ContainerLaunchContext ctx = ContainerLaunchContext.newInstance( - containerLocalResource, containerEnv, containerLaunchcommands, null, null, null); + containerLocalResource, containerEnv, containerLaunchcommands, null,allTokens==null?null:allTokens.duplicate(), null); try { nmAsync.startContainerAsync(container, ctx); @@ -851,6 +862,26 @@ public Configuration getConf() { private boolean run() throws IOException, NoSuchAlgorithmException { LOG.info("ApplicationMaster Starting ..."); + if(UserGroupInformation.isSecurityEnabled()) { + // Note: Credentials, Token, UserGroupInformation, DataOutputBuffer class + // are marked as LimitedPrivate + Credentials credentials = + UserGroupInformation.getCurrentUser().getCredentials(); + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + // Now remove the AM->RM token so that containers cannot access it. + Iterator> iter = credentials.getAllTokens().iterator(); + LOG.info("Executing with tokens:"); + while (iter.hasNext()) { + Token token = iter.next(); + LOG.info(token); + if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) { + iter.remove(); + } + } + allTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + } + registerApplicationMaster(); if (conf.get(XLearningConfiguration.XLEARNING_INPUT_STRATEGY, XLearningConfiguration.DEFAULT_XLEARNING_INPUT_STRATEGY).equals("STREAM")) { buildInputStreamFileStatus(); diff --git a/src/main/java/net/qihoo/xlearning/AM/ApplicationMessageService.java b/src/main/java/net/qihoo/xlearning/AM/ApplicationMessageService.java index 15837b4..359244c 100644 --- a/src/main/java/net/qihoo/xlearning/AM/ApplicationMessageService.java +++ b/src/main/java/net/qihoo/xlearning/AM/ApplicationMessageService.java @@ -3,6 +3,7 @@ import net.qihoo.xlearning.api.ApplicationContext; import net.qihoo.xlearning.api.ApplicationMessageProtocol; import net.qihoo.xlearning.common.Message; +import net.qihoo.xlearning.security.XTokenSecretManager; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -10,6 +11,7 @@ import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.ipc.Server; import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.service.AbstractService; import java.io.IOException; @@ -35,11 +37,20 @@ public ApplicationMessageService(ApplicationContext applicationContext, Configur @Override public void start() { LOG.info("Starting application message server"); - RPC.Builder builder = new RPC.Builder(getConfig()); + Configuration conf = getConfig(); + conf.setBoolean("hadoop.security.authorization",false); + + RPC.Builder builder = new RPC.Builder(conf); builder.setProtocol(ApplicationMessageProtocol.class); builder.setInstance(this); builder.setBindAddress("0.0.0.0"); builder.setPort(0); + + if(UserGroupInformation.isSecurityEnabled()) { + XTokenSecretManager secretManager = new XTokenSecretManager(); + builder.setSecretManager(secretManager); + } + Server server; try { server = builder.build(); diff --git a/src/main/java/net/qihoo/xlearning/api/ApplicationContainerProtocol.java b/src/main/java/net/qihoo/xlearning/api/ApplicationContainerProtocol.java index d311cc0..3cec7da 100644 --- a/src/main/java/net/qihoo/xlearning/api/ApplicationContainerProtocol.java +++ b/src/main/java/net/qihoo/xlearning/api/ApplicationContainerProtocol.java @@ -2,9 +2,19 @@ import net.qihoo.xlearning.common.*; import net.qihoo.xlearning.container.XLearningContainerId; +import net.qihoo.xlearning.security.Utils; +import net.qihoo.xlearning.security.XTokenSelector; +import org.apache.hadoop.ipc.ProtocolInfo; import org.apache.hadoop.ipc.VersionedProtocol; import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.security.KerberosInfo; +import org.apache.hadoop.security.token.TokenInfo; + +@KerberosInfo(serverPrincipal = Utils.SERVER_PRINCIPAL_KEY) +@TokenInfo(XTokenSelector.class) +@ProtocolInfo(protocolName = "ApplicationContainerProtocol", + protocolVersion = 1) public interface ApplicationContainerProtocol extends VersionedProtocol { public static final long versionID = 1L; diff --git a/src/main/java/net/qihoo/xlearning/api/ApplicationMessageProtocol.java b/src/main/java/net/qihoo/xlearning/api/ApplicationMessageProtocol.java index 0d6590f..de4417f 100644 --- a/src/main/java/net/qihoo/xlearning/api/ApplicationMessageProtocol.java +++ b/src/main/java/net/qihoo/xlearning/api/ApplicationMessageProtocol.java @@ -1,11 +1,21 @@ package net.qihoo.xlearning.api; +import net.qihoo.xlearning.security.Utils; +import net.qihoo.xlearning.security.XTokenSelector; +import org.apache.hadoop.ipc.ProtocolInfo; import org.apache.hadoop.ipc.VersionedProtocol; import net.qihoo.xlearning.common.Message; +import org.apache.hadoop.security.KerberosInfo; +import org.apache.hadoop.security.token.TokenInfo; /** * The Protocal between clients and ApplicationMaster to fetch Application Messages. */ + +@KerberosInfo(serverPrincipal = Utils.SERVER_PRINCIPAL_KEY) +@TokenInfo(XTokenSelector.class) +@ProtocolInfo(protocolName = "ApplicationMessageProtocol", + protocolVersion = 1) public interface ApplicationMessageProtocol extends VersionedProtocol { public static final long versionID = 1L; diff --git a/src/main/java/net/qihoo/xlearning/api/ContainerListener.java b/src/main/java/net/qihoo/xlearning/api/ContainerListener.java index 696d41a..4e2d170 100644 --- a/src/main/java/net/qihoo/xlearning/api/ContainerListener.java +++ b/src/main/java/net/qihoo/xlearning/api/ContainerListener.java @@ -2,6 +2,8 @@ import net.qihoo.xlearning.container.XLearningContainerId; + + public interface ContainerListener { void registerContainer(XLearningContainerId xlearningContainerId, String role); diff --git a/src/main/java/net/qihoo/xlearning/client/Client.java b/src/main/java/net/qihoo/xlearning/client/Client.java index 991f5cf..db4d193 100644 --- a/src/main/java/net/qihoo/xlearning/client/Client.java +++ b/src/main/java/net/qihoo/xlearning/client/Client.java @@ -7,6 +7,8 @@ import net.qihoo.xlearning.common.Message; import net.qihoo.xlearning.common.exceptions.RequestOverLimitException; import net.qihoo.xlearning.conf.XLearningConfiguration; +import net.qihoo.xlearning.security.XTokenIdentifier; +import net.qihoo.xlearning.security.XTokenSecretManager; import net.qihoo.xlearning.util.Utilities; import org.apache.commons.cli.ParseException; import org.apache.commons.lang.StringUtils; @@ -16,10 +18,15 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.Text; import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.mapred.InputFormat; import org.apache.hadoop.mapred.OutputFormat; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; import org.apache.hadoop.yarn.api.ApplicationConstants; import org.apache.hadoop.yarn.api.protocolrecords.*; import org.apache.hadoop.yarn.api.records.*; @@ -33,6 +40,7 @@ import java.lang.reflect.UndeclaredThrowableException; import java.net.InetSocketAddress; import java.net.URI; +import java.nio.ByteBuffer; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; @@ -216,6 +224,15 @@ private static ApplicationMessageProtocol getAppMessageHandler( ApplicationMessageProtocol appMessageHandler = null; if (!StringUtils.isBlank(appMasterAddress) && !appMasterAddress.equalsIgnoreCase("N/A")) { InetSocketAddress addr = new InetSocketAddress(appMasterAddress, appMasterPort); + + if(UserGroupInformation.isSecurityEnabled()) { + XTokenSecretManager secretManager = new XTokenSecretManager(); + UserGroupInformation current = UserGroupInformation.getCurrentUser(); + XTokenIdentifier tokenID = new XTokenIdentifier(new Text(current.getUserName())); + Token token = new Token<>(tokenID, secretManager); + SecurityUtil.setTokenService(token, addr); + current.addToken(token); + } appMessageHandler = RPC.getProxy(ApplicationMessageProtocol.class, ApplicationMessageProtocol.versionID, addr, conf); } return appMessageHandler; @@ -539,7 +556,11 @@ private boolean submitAndMonitor() throws IOException, YarnException { if (clientArguments.xlearningCacheArchives != null && !clientArguments.xlearningCacheArchives.equals("")) { appMasterEnv.put(XLearningConstants.Environment.XLEARNING_CACHE_ARCHIVE_LOCATION.toString(), clientArguments.xlearningCacheArchives); - if (clientArguments.appType.equals("MXNET") && !conf.getBoolean(XLearningConfiguration.XLEARNING_MXNET_MODE_SINGLE, XLearningConfiguration.DEFAULT_XLEARNING_MXNET_MODE_SINGLE)) { + if (clientArguments.appType.equals("MXNET") && !conf.getBoolean(XLearningConfiguration.XLEARNING_MXNET_MODE_SINGLE, XLearningConfiguration.DEFAULT_XLEARNING_MXNET_MODE_SINGLE) + || (clientArguments.appType.equals("DISTXGBOOST")) + + ) { + URI defaultUri = new Path(conf.get("fs.defaultFS")).toUri(); String appCacheArchivesRemoteLocation = appMasterEnv.get(XLearningConstants.Environment.XLEARNING_CACHE_ARCHIVE_LOCATION.toString()); String[] cacheArchives = StringUtils.split(appCacheArchivesRemoteLocation, ","); @@ -632,6 +653,30 @@ private boolean submitAndMonitor() throws IOException, YarnException { ContainerLaunchContext amContainer = ContainerLaunchContext.newInstance( localResources, appMasterEnv, appMasterLaunchcommands, null, null, null); + + if (UserGroupInformation.isSecurityEnabled()) { + // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce + Credentials credentials = new Credentials(); + String tokenRenewer = conf.get(YarnConfiguration.RM_PRINCIPAL); + if (tokenRenewer == null || tokenRenewer.length() == 0) { + throw new IOException( + "Can't get Master Kerberos principal for the RM to use as renewer"); + } + + // For now, only getting tokens for the default file-system. + final Token tokens[] = + dfs.addDelegationTokens(tokenRenewer, credentials); + if (tokens != null) { + for (Token token : tokens) { + LOG.info("Got dt for " + dfs.getUri() + "; " + token); + } + } + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + amContainer.setTokens(fsTokens); + } + applicationContext.setAMContainerSpec(amContainer); Priority priority = Records.newRecord(Priority.class); diff --git a/src/main/java/net/qihoo/xlearning/container/ContainerReporter.java b/src/main/java/net/qihoo/xlearning/container/ContainerReporter.java index 573b523..6f48448 100644 --- a/src/main/java/net/qihoo/xlearning/container/ContainerReporter.java +++ b/src/main/java/net/qihoo/xlearning/container/ContainerReporter.java @@ -145,7 +145,7 @@ private void produceCpuMetrics(String xlearningCmdProcessId) throws IOException } catch (IOException e) { e.printStackTrace(); } - LOG.info("containerProcessId is:" + this.containerProcessId); + LOG.info("ps is:" + this.containerProcessId); ProcessTreeInfo processTreeInfo = new ProcessTreeInfo(this.containerId.getContainerId(), null, null, 0, 0, 0); diff --git a/src/main/java/net/qihoo/xlearning/container/XLearningContainer.java b/src/main/java/net/qihoo/xlearning/container/XLearningContainer.java index 39a469c..7579035 100644 --- a/src/main/java/net/qihoo/xlearning/container/XLearningContainer.java +++ b/src/main/java/net/qihoo/xlearning/container/XLearningContainer.java @@ -10,6 +10,8 @@ import net.qihoo.xlearning.common.XLearningContainerStatus; import net.qihoo.xlearning.common.TextMultiOutputFormat; import net.qihoo.xlearning.conf.XLearningConfiguration; +import net.qihoo.xlearning.security.XTokenIdentifier; +import net.qihoo.xlearning.security.XTokenSecretManager; import net.qihoo.xlearning.util.Utilities; import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; @@ -19,6 +21,9 @@ import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.security.SecurityUtil; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; import org.apache.hadoop.yarn.api.ApplicationConstants; import org.apache.hadoop.yarn.util.ConverterUtils; import org.apache.hadoop.io.Text; @@ -124,11 +129,22 @@ private XLearningContainer() { reservedSocket = new Socket(); } - private void init() { + private void init() throws IOException { LOG.info("XLearningContainer initializing"); String appMasterHost = System.getenv(XLearningConstants.Environment.APPMASTER_HOST.toString()); int appMasterPort = Integer.valueOf(System.getenv(XLearningConstants.Environment.APPMASTER_PORT.toString())); InetSocketAddress addr = new InetSocketAddress(appMasterHost, appMasterPort); + + if(UserGroupInformation.isSecurityEnabled()) { + //add for kerberos + conf.setBoolean("hadoop.security.authorization", false); + XTokenSecretManager secretManager = new XTokenSecretManager(); + UserGroupInformation current = UserGroupInformation.getCurrentUser(); + XTokenIdentifier tokenID = new XTokenIdentifier(new Text(current.getUserName())); + Token token = new Token<>(tokenID, secretManager); + SecurityUtil.setTokenService(token, addr); + current.addToken(token); + } try { this.amClient = RPC.getProxy(ApplicationContainerProtocol.class, ApplicationContainerProtocol.versionID, addr, conf); diff --git a/src/main/java/net/qihoo/xlearning/security/Utils.java b/src/main/java/net/qihoo/xlearning/security/Utils.java new file mode 100644 index 0000000..eba35f4 --- /dev/null +++ b/src/main/java/net/qihoo/xlearning/security/Utils.java @@ -0,0 +1,6 @@ +package net.qihoo.xlearning.security; + +public class Utils { + public static final String SERVER_PRINCIPAL_KEY="xlearning.ipc.server.principal"; + +} diff --git a/src/main/java/net/qihoo/xlearning/security/XTokenIdentifier.java b/src/main/java/net/qihoo/xlearning/security/XTokenIdentifier.java new file mode 100644 index 0000000..806d066 --- /dev/null +++ b/src/main/java/net/qihoo/xlearning/security/XTokenIdentifier.java @@ -0,0 +1,59 @@ +package net.qihoo.xlearning.security; + +import org.apache.hadoop.io.Text; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.TokenIdentifier; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +public class XTokenIdentifier extends TokenIdentifier { + private Text tokenid; + private Text realUser; + final static Text KIND_NAME = new Text("xlearning.token"); + + + public XTokenIdentifier(){ + this(new Text(),new Text()); + } + + public XTokenIdentifier(Text tokenid){ + this(tokenid,new Text()); + } + + public XTokenIdentifier(Text tokenid, Text realUser){ + this.tokenid = tokenid == null ? new Text() : tokenid; + this.realUser = realUser == null ? new Text() : realUser; + } + + + @Override + public Text getKind() { + return KIND_NAME; + } + + @Override + public UserGroupInformation getUser() { + if (realUser.toString().isEmpty()) { + return UserGroupInformation.createRemoteUser(tokenid.toString()); + } else { + UserGroupInformation realUgi = UserGroupInformation + .createRemoteUser(realUser.toString()); + return UserGroupInformation + .createProxyUser(tokenid.toString(), realUgi); + } + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + tokenid.write(dataOutput); + realUser.write(dataOutput); + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + tokenid.readFields(dataInput); + realUser.readFields(dataInput); + } +} diff --git a/src/main/java/net/qihoo/xlearning/security/XTokenSecretManager.java b/src/main/java/net/qihoo/xlearning/security/XTokenSecretManager.java new file mode 100644 index 0000000..0097f24 --- /dev/null +++ b/src/main/java/net/qihoo/xlearning/security/XTokenSecretManager.java @@ -0,0 +1,21 @@ +package net.qihoo.xlearning.security; + +import org.apache.hadoop.security.token.SecretManager; + +public class XTokenSecretManager extends + SecretManager { + @Override + protected byte[] createPassword(XTokenIdentifier xTokenIdentifier) { + return xTokenIdentifier.getBytes(); + } + + @Override + public byte[] retrievePassword(XTokenIdentifier xTokenIdentifier) throws InvalidToken { + return xTokenIdentifier.getBytes(); + } + + @Override + public XTokenIdentifier createIdentifier() { + return new XTokenIdentifier(); + } +} diff --git a/src/main/java/net/qihoo/xlearning/security/XTokenSelector.java b/src/main/java/net/qihoo/xlearning/security/XTokenSelector.java new file mode 100644 index 0000000..474433f --- /dev/null +++ b/src/main/java/net/qihoo/xlearning/security/XTokenSelector.java @@ -0,0 +1,28 @@ +package net.qihoo.xlearning.security; + +import org.apache.hadoop.io.Text; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hadoop.security.token.TokenSelector; + +import java.util.Collection; + +public class XTokenSelector implements + TokenSelector { + + @Override + public Token selectToken(Text service, + Collection> tokens) { + + if (service == null) { + return null; + } + for (Token token : tokens) { + if (XTokenIdentifier.KIND_NAME.equals(token.getKind()) + && service.equals(token.getService())) { + return (Token) token; + } + } + return null; + } +}