Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ import org.apache.commons.io.IOUtils
import org.apache.commons.io.input.BoundedInputStream
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.util.VersionInfo
import org.apache.http.{HttpHeaders, HttpHost, HttpStatus}
import org.apache.http.{HttpHeaders, HttpHost, HttpRequest, HttpStatus}
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.methods.{HttpGet, HttpPost, HttpRequestBase}
import org.apache.http.client.protocol.HttpClientContext
import org.apache.http.conn.routing.HttpRoute
import org.apache.http.conn.ssl.{SSLConnectionSocketFactory, SSLContextBuilder, TrustSelfSignedStrategy}
import org.apache.http.entity.StringEntity
import org.apache.http.impl.client.{HttpClientBuilder, HttpClients}
import org.apache.http.impl.conn.{DefaultRoutePlanner, DefaultSchemePortResolver}
import org.apache.http.protocol.HttpContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

import io.delta.sharing.client.auth.{AuthConfig, AuthCredentialProviderFactory}
import io.delta.sharing.client.model._
import io.delta.sharing.client.util.{ConfUtils, JsonUtils, RetryUtils, UnexpectedHttpStatus}
import io.delta.sharing.client.util.ConfUtils.ProxyConfig
import io.delta.sharing.spark.MissingEndStreamActionException

/** An interface to fetch Delta metadata from remote server. */
Expand Down Expand Up @@ -198,7 +202,8 @@ class DeltaSharingRestClient(
asyncQueryMaxDuration: Long = 600000L,
tokenExchangeMaxRetries: Int = 5,
tokenExchangeMaxRetryDurationInSeconds: Int = 60,
tokenRenewalThresholdInSeconds: Int = 600
tokenRenewalThresholdInSeconds: Int = 600,
proxyConfigOpt: Option[ProxyConfig] = None
) extends DeltaSharingClient with Logging {

logInfo(s"DeltaSharingRestClient with endStreamActionEnabled: $endStreamActionEnabled, " +
Expand All @@ -211,7 +216,7 @@ class DeltaSharingRestClient(
// Convert the comma-separated responseFormat to a Set to be used later.
private val responseFormatSet = responseFormat.split(",").toSet

private lazy val client = {
private[sharing] lazy val client = {
val clientBuilder: HttpClientBuilder = if (sslTrustAll) {
val sslBuilder = new SSLContextBuilder()
.loadTrustMaterial(null, new TrustSelfSignedStrategy())
Expand All @@ -227,6 +232,31 @@ class DeltaSharingRestClient(
.setConnectTimeout(timeoutInSeconds * 1000)
.setConnectionRequestTimeout(timeoutInSeconds * 1000)
.setSocketTimeout(timeoutInSeconds * 1000).build()
proxyConfigOpt.foreach { proxyConfig =>
if (sslTrustAll) {
throw new IllegalStateException(
"Proxy configuration is not supported when sslTrustAll is enabled.")
}
val proxy = new HttpHost(proxyConfig.host, proxyConfig.port)
clientBuilder.setProxy(proxy)

if (proxyConfig.noProxyHosts.nonEmpty) {
val routePlanner = new DefaultRoutePlanner(DefaultSchemePortResolver.INSTANCE) {
override def determineRoute(target: HttpHost,
request: HttpRequest,
context: HttpContext): HttpRoute = {
if (proxyConfig.noProxyHosts.contains(target.getHostName)) {
// Direct route (no proxy)
new HttpRoute(target)
} else {
// Route via proxy
new HttpRoute(target, proxy)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor always builds an insecure (HTTP) route through the proxy, which does not work when an HTTP proxy is used to make an HTTPS call to the server. Instead, consider using the alternate constructor and setting the secure flag based on the target's scheme:

Suggested change
new HttpRoute(target, proxy)
val isSecure = target.getSchemeName == "https"
new HttpRoute(target, null, proxy, isSecure)

}
}
}
clientBuilder.setRoutePlanner(routePlanner)
}
}
val client = clientBuilder
// Disable the default retry behavior because we have our own retry logic.
// See `RetryUtils.runWithExponentialBackoff`.
Expand Down Expand Up @@ -1401,6 +1431,7 @@ object DeltaSharingRestClient extends Logging {
val endStreamActionEnabled = ConfUtils.includeEndStreamAction(sqlConf)
val asyncQueryMaxDurationMillis = ConfUtils.asyncQueryTimeout(sqlConf)
val asyncQueryPollDurationMillis = ConfUtils.asyncQueryPollIntervalMillis(sqlConf)
val proxyConfig = ConfUtils.getClientProxyConfig(sqlConf)

val tokenExchangeMaxRetries = ConfUtils.tokenExchangeMaxRetries(sqlConf)
val tokenExchangeMaxRetryDurationInSeconds =
Expand All @@ -1427,7 +1458,8 @@ object DeltaSharingRestClient extends Logging {
classOf[Long],
classOf[Int],
classOf[Int],
classOf[Int]
classOf[Int],
classOf[Option[ProxyConfig]]
).newInstance(profileProvider,
java.lang.Integer.valueOf(timeoutInSeconds),
java.lang.Integer.valueOf(numRetries),
Expand All @@ -1445,7 +1477,8 @@ object DeltaSharingRestClient extends Logging {
java.lang.Long.valueOf(asyncQueryMaxDurationMillis),
java.lang.Integer.valueOf(tokenExchangeMaxRetries),
java.lang.Integer.valueOf(tokenExchangeMaxRetryDurationInSeconds),
java.lang.Integer.valueOf(tokenRenewalThresholdInSeconds)
java.lang.Integer.valueOf(tokenRenewalThresholdInSeconds),
proxyConfig
).asInstanceOf[DeltaSharingClient]
}
}
21 changes: 21 additions & 0 deletions client/src/main/scala/io/delta/sharing/client/util/ConfUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ object ConfUtils {
val PROXY_PORT = "spark.delta.sharing.network.proxyPort"
val NO_PROXY_HOSTS = "spark.delta.sharing.network.noProxyHosts"

val CLIENT_PROXY_HOST = "spark.delta.sharing.client.network.proxyHost"
val CLIENT_PROXY_PORT = "spark.delta.sharing.client.network.proxyPort"
val CLIENT_NO_PROXY_HOSTS = "spark.delta.sharing.client.network.noProxyHosts"

val OAUTH_RETRIES_CONF = "spark.delta.sharing.oauth.tokenExchangeMaxRetries"
val OAUTH_RETRIES_DEFAULT = 5

Expand Down Expand Up @@ -118,6 +122,23 @@ object ConfUtils {
Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList))
}

/**
 * Builds the proxy configuration for the Delta Sharing client from `conf`.
 *
 * Reads CLIENT_PROXY_HOST, CLIENT_PROXY_PORT and CLIENT_NO_PROXY_HOSTS.
 *
 * @param conf the SQL configuration to read the client proxy settings from
 * @return `None` when neither the proxy host nor the proxy port is set; otherwise a
 *         `ProxyConfig` holding the host, the port and the list of hosts that bypass the proxy.
 *         The bypass list is split on commas, trimmed, and never contains blank entries.
 * @throws NumberFormatException if the configured port is not an integer
 */
def getClientProxyConfig(conf: SQLConf): Option[ProxyConfig] = {
  val proxyHost = conf.getConfString(CLIENT_PROXY_HOST, null)
  val proxyPortAsString = conf.getConfString(CLIENT_PROXY_PORT, null)

  if (proxyHost == null && proxyPortAsString == null) {
    // Nothing configured at all: no proxy.
    None
  } else {
    // At least one of the two is set, so both are checked.
    // NOTE(review): validateNonEmpty / validatePortNumber are defined elsewhere in this object;
    // presumably they throw on invalid input - confirm against their definitions.
    validateNonEmpty(proxyHost, CLIENT_PROXY_HOST)
    validateNonEmpty(proxyPortAsString, CLIENT_PROXY_PORT)
    val proxyPort = proxyPortAsString.toInt
    validatePortNumber(proxyPort, CLIENT_PROXY_PORT)

    // "".split(",") yields Array(""), so without the filter an unset (or blank) setting would
    // produce Seq("") - a non-empty list that callers checking `noProxyHosts.nonEmpty` would
    // mistake for a real bypass list. Blank entries such as "a,,b" are dropped as well.
    val noProxyList = conf.getConfString(CLIENT_NO_PROXY_HOSTS, "")
      .split(",")
      .map(_.trim)
      .filter(_.nonEmpty)
      .toSeq
    Some(ProxyConfig(proxyHost, proxyPort, noProxyHosts = noProxyList))
  }
}

def getNeverUseHttps(conf: Configuration): Boolean = {
conf.getBoolean(NEVER_USE_HTTPS, NEVER_USE_HTTPS_DEFAULT.toBoolean)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,17 @@
package io.delta.sharing.client

import java.sql.Timestamp
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import org.apache.http.HttpHeaders
import org.apache.http.client.methods.{HttpGet, HttpRequestBase}
import org.apache.http.util.EntityUtils
import org.sparkproject.jetty.server.Server
import org.sparkproject.jetty.servlet.{ServletHandler, ServletHolder}

import io.delta.sharing.client.model.{
AddCDCFile,
AddFile,
AddFileForCDF,
DeltaTableFiles,
EndStreamAction,
Format,
Metadata,
Protocol,
RemoveFile,
Table
}
import io.delta.sharing.client.util.JsonUtils
import io.delta.sharing.client.util.UnexpectedHttpStatus
import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, EndStreamAction, Format, Metadata, Protocol, RemoveFile, Table}
import io.delta.sharing.client.util.{JsonUtils, ProxyServer, UnexpectedHttpStatus}
import io.delta.sharing.client.util.ConfUtils.ProxyConfig
import io.delta.sharing.spark.MissingEndStreamActionException

// scalastyle:off maxLineLength
Expand Down Expand Up @@ -1330,4 +1323,178 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest {
}
checkErrorMessage(e, s"and 0 lines, and last line as [Empty_Seq_in_checkEndStreamAction].")
}

integrationTest("traffic goes through a proxy when a proxy configured") {
  // Create a local HTTP server that answers every GET with a fixed plain-text body.
  val server = new Server(0)
  val handler = new ServletHandler()
  server.setHandler(handler)
  handler.addServletWithMapping(new ServletHolder(new HttpServlet {
    override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
      resp.setContentType("text/plain")
      resp.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      resp.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }), "/*")
  server.start()
  do {
    Thread.sleep(100)
  } while (!server.isStarted())

  // Create a local HTTP proxy server.
  val proxyServer = new ProxyServer(0)
  proxyServer.initialize()

  try {
    val dsClient = new DeltaSharingRestClient(
      testProfileProvider,
      sslTrustAll = false,
      proxyConfigOpt = Some(
        ProxyConfig(
          host = proxyServer.getHost(),
          port = proxyServer.getPort(),
          // No bypass list: listing the destination host here would make the client connect
          // directly and contradict the "captured exactly one request" assertion below.
          noProxyHosts = Seq()
        )
      )
    )

    // Send a request to the local server through the httpClient.
    val response = dsClient.client.execute(new HttpGet(server.getURI.toString))

    // Assert that the request is successful.
    assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
    val content = EntityUtils.toString(response.getEntity)
    assert(content.trim == "Hello, World!")

    // Assert that the request is passed through proxy.
    assert(proxyServer.getCapturedRequests().size == 1)
  } finally {
    // Stop the proxy even if stopping the destination server throws.
    try {
      server.stop()
    } finally {
      proxyServer.stop()
    }
  }
}

integrationTest("traffic skips the proxy when a noProxyHosts configured") {
  // Servlet that answers every GET with a 200 and a fixed plain-text greeting.
  val helloServlet = new HttpServlet {
    override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
      response.setContentType("text/plain")
      response.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      response.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }

  // Stand up the destination HTTP server and map the servlet to every path.
  val destination = new Server(0)
  val servlets = new ServletHandler()
  destination.setHandler(servlets)
  servlets.addServletWithMapping(new ServletHolder(helloServlet), "/*")
  destination.start()
  // Pause once unconditionally, then keep pausing until the server reports it has started.
  Thread.sleep(100)
  while (!destination.isStarted()) {
    Thread.sleep(100)
  }

  // Stand up the local HTTP proxy.
  val proxy = new ProxyServer(0)
  proxy.initialize()
  try {
    // The destination host is on the bypass list, so the client should connect to it directly.
    val bypassDestination = ProxyConfig(
      host = proxy.getHost(),
      port = proxy.getPort(),
      noProxyHosts = Seq(destination.getURI.getHost)
    )
    val restClient = new DeltaSharingRestClient(
      testProfileProvider,
      sslTrustAll = false,
      proxyConfigOpt = Some(bypassDestination)
    )

    // Issue a GET against the destination using the client's underlying HTTP client.
    val reply = restClient.client.execute(new HttpGet(destination.getURI.toString))

    // The destination must have answered with the greeting.
    assert(reply.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
    val body = EntityUtils.toString(reply.getEntity)
    assert(body.trim == "Hello, World!")

    // The proxy must not have seen the request.
    assert(proxy.getCapturedRequests().isEmpty)
  } finally {
    destination.stop()
    proxy.stop()
  }
}

integrationTest("traffic goes through the proxy when noProxyHosts does not include destination") {
  // Create a local HTTP server that answers every GET with a fixed plain-text body.
  val server = new Server(0)
  val handler = new ServletHandler()
  server.setHandler(handler)
  handler.addServletWithMapping(new ServletHolder(new HttpServlet {
    override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
      resp.setContentType("text/plain")
      resp.setStatus(HttpServletResponse.SC_OK)

      // scalastyle:off println
      resp.getWriter.println("Hello, World!")
      // scalastyle:on println
    }
  }), "/*")
  server.start()
  do {
    Thread.sleep(100)
  } while (!server.isStarted())

  // Create a local HTTP proxy server.
  val proxyServer = new ProxyServer(0)
  proxyServer.initialize()
  try {
    val dsClient = new DeltaSharingRestClient(
      testProfileProvider,
      sslTrustAll = false,
      proxyConfigOpt = Some(
        ProxyConfig(
          host = proxyServer.getHost(),
          port = proxyServer.getPort(),
          // A non-empty bypass list that deliberately does NOT contain the destination host,
          // so the custom route planner is exercised and must still route via the proxy.
          noProxyHosts = Seq("no-proxy.invalid")
        )
      )
    )

    // Send a request to the local server through the httpClient.
    val response = dsClient.client.execute(new HttpGet(server.getURI.toString))

    // Assert that the request is successful.
    assert(response.getStatusLine.getStatusCode == HttpServletResponse.SC_OK)
    val content = EntityUtils.toString(response.getEntity)
    assert(content.trim == "Hello, World!")

    // Assert that the request is passed through proxy.
    assert(proxyServer.getCapturedRequests().size == 1)
  } finally {
    // Stop the proxy even if stopping the destination server throws.
    try {
      server.stop()
    } finally {
      proxyServer.stop()
    }
  }
}

integrationTest("sslTrustAll cannot be true if proxy configured") {
  // Any proxy settings will do; the combination with sslTrustAll = true is what gets rejected.
  val proxySettings = ProxyConfig(host = "localhost", port = 8080, noProxyHosts = Seq())

  val thrown = intercept[IllegalStateException] {
    val restClient = new DeltaSharingRestClient(
      testProfileProvider,
      sslTrustAll = true,
      proxyConfigOpt = Some(proxySettings)
    )
    // `client` is a lazy val, so it has to be accessed for the HTTP client to be built.
    restClient.client
  }

  val expectedMessage = "Proxy configuration is not supported when sslTrustAll is enabled."
  assert(thrown.getMessage.contains(expectedMessage))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,9 @@

package io.delta.sharing.spark

import io.delta.sharing.client.{
DeltaSharingClient,
DeltaSharingProfile,
DeltaSharingProfileProvider
}
import io.delta.sharing.client.model.{
AddCDCFile,
AddFile,
AddFileForCDF,
DeltaTableFiles,
DeltaTableMetadata,
Metadata,
Protocol,
RemoveFile,
SingleAction,
Table
}
import io.delta.sharing.client.{DeltaSharingClient, DeltaSharingProfile, DeltaSharingProfileProvider}
import io.delta.sharing.client.model.{AddCDCFile, AddFile, AddFileForCDF, DeltaTableFiles, DeltaTableMetadata, Metadata, Protocol, RemoveFile, SingleAction, Table}
import io.delta.sharing.client.util.ConfUtils.ProxyConfig
import io.delta.sharing.client.util.JsonUtils
import io.delta.sharing.spark.TestDeltaSharingClient.TESTING_TIMESTAMP

Expand All @@ -54,7 +40,8 @@ class TestDeltaSharingClient(
asyncQueryMaxDuration: Long = Long.MaxValue,
tokenExchangeMaxRetries: Int = 5,
tokenExchangeMaxRetryDurationInSeconds: Int = 60,
tokenRenewalThresholdInSeconds: Int = 600
tokenRenewalThresholdInSeconds: Int = 600,
proxyConfigOpt: Option[ProxyConfig] = None
) extends DeltaSharingClient {

import DeltaSharingOptions.RESPONSE_FORMAT_PARQUET
Expand Down