From e746f3d9b34dd5648b5c1dd36234922e4fc394b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Art=C5=ABras=20=C5=A0lajus?= Date: Mon, 15 Jan 2024 18:36:43 +0200 Subject: [PATCH] Provide a custom way to handle SSL exceptions Same thing that was provided for ember (https://github.com/http4s/http4s/pull/7093). --- .gitignore | 8 ++ .../blaze/server/BlazeServerBuilder.scala | 84 +++++++++++++++---- .../blaze/pipeline/stages/SSLStage.scala | 38 +++++++-- 3 files changed, 106 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 275fcc490..8e3c61b6f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,14 @@ target/ *~ /.bsp/ +# Ignore Scala Metals configuration and logs. +/.metals/ +# Ignore the Bloop build server directories. +.bloop/ +# Ignore the auto-generated metals configuration file +**/metals.sbt +# Ignore settings for Visual Studio Code +/.vscode /.ensime /.idea/ /.idea_modules/ diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala index 11f7d9831..8defc37ff 100644 --- a/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala @@ -31,7 +31,7 @@ import com.comcast.ip4s.SocketAddress import org.http4s.blaze.channel._ import org.http4s.blaze.channel.nio1.NIO1SocketServerGroup import org.http4s.blaze.pipeline.LeafBuilder -import org.http4s.blaze.pipeline.stages.SSLStage +import org.http4s.blaze.pipeline.stages.{SSLStage, SSLStageDefaults} import org.http4s.blaze.server.BlazeServerBuilder._ import org.http4s.blaze.util.TickWheelExecutor import org.http4s.blaze.{BuildInfo => BlazeBuildInfo} @@ -194,15 +194,37 @@ class BlazeServerBuilder[F[_]] private ( ): Self = copy(sslConfig = new ContextWithClientAuth[F](sslContext, clientAuth)) + /** Configures the server with TLS, using the provided `SSLContext` and its + * default `SSLParameters` + * + * @param sslErrorHandler function that runs if an error occurs during the TLS handshake. Default behavior is to log the error. + */ + def withSslContext( + sslContext: SSLContext, + sslErrorHandler: PartialFunction[Throwable, Unit] = PartialFunction.empty, + ): Self = + copy(sslConfig = new ContextOnly[F](sslContext, sslErrorHandler)) + /** Configures the server with TLS, using the provided `SSLContext` and its * default `SSLParameters` */ def withSslContext(sslContext: SSLContext): Self = - copy(sslConfig = new ContextOnly[F](sslContext)) + withSslContext(sslContext, PartialFunction.empty) + + /** Configures the server with TLS, using the provided `SSLContext` and `SSLParameters`. + * + * @param sslErrorHandler function that runs if an error occurs during the TLS handshake. Default behavior is to log the error. + */ + def withSslContextAndParameters( + sslContext: SSLContext, + sslParameters: SSLParameters, + sslErrorHandler: PartialFunction[Throwable, Unit] = PartialFunction.empty, + ): Self = + copy(sslConfig = new ContextWithParameters[F](sslContext, sslParameters, sslErrorHandler)) /** Configures the server with TLS, using the provided `SSLContext` and `SSLParameters`. */ def withSslContextAndParameters(sslContext: SSLContext, sslParameters: SSLParameters): Self = - copy(sslConfig = new ContextWithParameters[F](sslContext, sslParameters)) + withSslContextAndParameters(sslContext, sslParameters, PartialFunction.empty) def withoutSsl: Self = copy(sslConfig = new NoSsl[F]()) @@ -277,7 +299,7 @@ class BlazeServerBuilder[F[_]] private ( private def pipelineFactory( scheduler: TickWheelExecutor, - engineConfig: Option[(SSLContext, SSLEngine => Unit)], + engineConfig: Option[(SSLContextWithExtras, SSLEngine => Unit)], dispatcher: Dispatcher[F], )(conn: SocketConnection): Future[LeafBuilder[ByteBuffer]] = { def requestAttributes(secure: Boolean, optionalSslEngine: Option[SSLEngine]): () => Vault = @@ -365,7 +387,7 @@ class BlazeServerBuilder[F[_]] private ( executionContextConfig.getExecutionContext[F].flatMap { executionContext => engineConfig match { case Some((ctx, configure)) => - val engine = ctx.createSSLEngine() + val engine = ctx.context.createSSLEngine() engine.setUseClientMode(false) configure(engine) @@ -373,7 +395,8 @@ class BlazeServerBuilder[F[_]] private ( if (isHttp2Enabled) http2Stage(executionContext, engine).map(LeafBuilder(_)) else http1Stage(executionContext, secure = true, engine.some).map(LeafBuilder(_)) - leafBuilder.map(_.prepend(new SSLStage(engine))) + leafBuilder + .map(_.prepend(new SSLStage(engine, SSLStageDefaults.MaxWrite, ctx.errorHandler))) case None => if (isHttp2Enabled) @@ -497,8 +520,17 @@ object BlazeServerBuilder { private def defaultThreadSelectorFactory: ThreadFactory = threadFactory(name = n => s"blaze-selector-${n}", daemon = false) + private case class SSLContextWithExtras( + context: SSLContext, + errorHandler: PartialFunction[Throwable, Unit], + ) + private object SSLContextWithExtras { + def onlyContext(context: SSLContext): SSLContextWithExtras = + apply(context, PartialFunction.empty) + } + private sealed trait SslConfig[F[_]] { - def makeContext: F[Option[SSLContext]] + def makeContext: F[Option[SSLContextWithExtras]] def configureEngine(sslEngine: SSLEngine): Unit def isSecure: Boolean } @@ -511,7 +543,7 @@ object BlazeServerBuilder { clientAuth: SSLClientAuthMode, )(implicit F: Sync[F]) extends SslConfig[F] { - def makeContext: F[Option[SSLContext]] = + def makeContext: F[Option[SSLContextWithExtras]] = F.delay { val ksStream = new FileInputStream(keyStore.path) val ks = KeyStore.getInstance("JKS") @@ -540,24 +572,43 @@ object BlazeServerBuilder { val context = SSLContext.getInstance(protocol) context.init(kmf.getKeyManagers, tmf.orNull, null) - context.some + SSLContextWithExtras(context, PartialFunction.empty).some } def configureEngine(engine: SSLEngine): Unit = configureEngineFromSslClientAuthMode(engine, clientAuth) def isSecure: Boolean = true } - private class ContextOnly[F[_]](sslContext: SSLContext)(implicit F: Applicative[F]) + private class ContextOnly[F[_]]( + sslContext: SSLContext, + sslErrorHandler: PartialFunction[Throwable, Unit], + )(implicit F: Applicative[F]) extends SslConfig[F] { - def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + + /** Constructor for backwards compatibility */ + def this(sslContext: SSLContext)(implicit F: Applicative[F]) = + this(sslContext, PartialFunction.empty) + + def makeContext: F[Option[SSLContextWithExtras]] = + F.pure(SSLContextWithExtras(sslContext, sslErrorHandler).some) def configureEngine(engine: SSLEngine): Unit = () def isSecure: Boolean = true } - private class ContextWithParameters[F[_]](sslContext: SSLContext, sslParameters: SSLParameters)( - implicit F: Applicative[F] + private class ContextWithParameters[F[_]]( + sslContext: SSLContext, + sslParameters: SSLParameters, + sslErrorHandler: PartialFunction[Throwable, Unit], + )(implicit + F: Applicative[F] ) extends SslConfig[F] { - def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + + /** Constructor for backwards compatibility */ + def this(sslContext: SSLContext, sslParameters: SSLParameters)(implicit F: Applicative[F]) = + this(sslContext, sslParameters, PartialFunction.empty) + + def makeContext: F[Option[SSLContextWithExtras]] = + F.pure(SSLContextWithExtras(sslContext, sslErrorHandler).some) def configureEngine(engine: SSLEngine): Unit = engine.setSSLParameters(sslParameters) def isSecure: Boolean = true } @@ -565,14 +616,15 @@ object BlazeServerBuilder { private class ContextWithClientAuth[F[_]](sslContext: SSLContext, clientAuth: SSLClientAuthMode)( implicit F: Applicative[F] ) extends SslConfig[F] { - def makeContext: F[Option[SSLContext]] = F.pure(sslContext.some) + def makeContext: F[Option[SSLContextWithExtras]] = + F.pure(SSLContextWithExtras.onlyContext(sslContext).some) def configureEngine(engine: SSLEngine): Unit = configureEngineFromSslClientAuthMode(engine, clientAuth) def isSecure: Boolean = true } private class NoSsl[F[_]]()(implicit F: Applicative[F]) extends SslConfig[F] { - def makeContext: F[Option[SSLContext]] = F.pure(None) + def makeContext: F[Option[SSLContextWithExtras]] = F.pure(None) def configureEngine(engine: SSLEngine): Unit = () def isSecure: Boolean = false } diff --git a/core/src/main/scala/org/http4s/blaze/pipeline/stages/SSLStage.scala b/core/src/main/scala/org/http4s/blaze/pipeline/stages/SSLStage.scala index 1acf851b6..1ff927da4 100644 --- a/core/src/main/scala/org/http4s/blaze/pipeline/stages/SSLStage.scala +++ b/core/src/main/scala/org/http4s/blaze/pipeline/stages/SSLStage.scala @@ -54,10 +54,30 @@ private object SSLStage { private case class SSLFailure(t: Throwable) extends SSLResult } -final class SSLStage(engine: SSLEngine, maxWrite: Int = 1024 * 1024) - extends MidStage[ByteBuffer, ByteBuffer] { +/** Default values for the [[SSLStage]] constructors. + * + * A separate `object` because the [[SSLStage]] `object` is `private`. + */ +object SSLStageDefaults { + + /** Default value for [[SSLStage.maxWrite]] */ + final val MaxWrite = 1024 * 1024 +} + +/** @param maxWrite + * \@see [[SSLStageDefaults.MaxWrite]]. + */ +final class SSLStage( + engine: SSLEngine, + maxWrite: Int, + sslHandshakeExceptionHandler: PartialFunction[Throwable, Unit] +) extends MidStage[ByteBuffer, ByteBuffer] { import SSLStage._ + /** Constructor to keep backwards compatibility with old versions. */ + def this(engine: SSLEngine, maxWrite: Int = SSLStageDefaults.MaxWrite) = + this(engine, maxWrite, PartialFunction.empty) + def name: String = "SSLStage" // We use a serial executor to ensure single threaded behavior. This makes @@ -248,13 +268,15 @@ final class SSLStage(engine: SSLEngine, maxWrite: Int = 1024 * 1024) val start = System.nanoTime try sslHandshakeLoop(data, r) catch { - case t: SSLException => - logger.warn(t)("SSLException in SSL handshake") - handshakeFailure(t) - case NonFatal(t) => - logger.error(t)("Error in SSL handshake") - handshakeFailure(t) + try + if (sslHandshakeExceptionHandler.isDefinedAt(t)) sslHandshakeExceptionHandler(t) + else + t match { + case t: SSLException => logger.warn(t)("SSLException in SSL handshake") + case _ => logger.error(t)("Error in SSL handshake") + } + finally handshakeFailure(t) } logger.trace(s"${engine.##}: sslHandshake completed in ${System.nanoTime - start}ns") }