Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmarks and improve performance of ServletIo#requestBody #171

Open
wants to merge 10 commits into
base: series/0.23
Choose a base branch
from
171 changes: 171 additions & 0 deletions benchmarks/src/main/scala/org/http4s/servlet/ServletIoBenchmarks.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package org.http4s.servlet

import cats.effect.IO
import cats.effect.std.Dispatcher
import cats.effect.unsafe.implicits.global

import org.openjdk.jmh.annotations._
import org.http4s.servlet.NonBlockingServletIo

import java.io.ByteArrayInputStream
import java.util.concurrent.TimeUnit
import javax.servlet.{ServletInputStream, ReadListener}
import javax.servlet.http.HttpServletRequest
import scala.util.Random

/** To do comparative benchmarks between versions:
*
* benchmarks/run-benchmark AsyncBenchmark
*
* This will generate results in `benchmarks/results`.
*
* Or to run the benchmark from within sbt:
*
* Jmh / run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.AsyncBenchmark
*
* Which means "10 iterations", "10 warm-up iterations", "2 forks", "1 thread". Please note that
* benchmarks should be usually executed at least in 10 iterations (as a rule of thumb), but
* more is better.
*/
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
class ServletIoBenchmarks {

@Param(Array("100000"))
var size: Int = _

@Param(Array("1000"))
var iters: Int = _

def servletRequest: HttpServletRequest = new HttpServletRequestStub(
new TestServletInputStream(Random.nextBytes(size))
)

@Benchmark
def reader() = {
val req = servletRequest
val servletIo = NonBlockingServletIo[IO](4096)

def loop(i: Int): IO[Unit] =
if (i == iters) IO.unit else servletIo.reader(req).compile.drain >> loop(i + 1)

loop(0).unsafeRunSync()
}

@Benchmark
def requestBody() = {
val req = servletRequest
val servletIo = NonBlockingServletIo[IO](4096)

Dispatcher
.sequential[IO]
rossabaker marked this conversation as resolved.
Show resolved Hide resolved
.use { disp =>
def loop(i: Int): IO[Unit] =
if (i == iters) IO.unit else servletIo.requestBody(req, disp).compile.drain >> loop(i + 1)

loop(0)
}
.unsafeRunSync()
}

class TestServletInputStream(body: Array[Byte]) extends ServletInputStream {
private var readListener: ReadListener = null
private val in = new ByteArrayInputStream(body)

override def isReady: Boolean = true

override def isFinished: Boolean = in.available() == 0

override def setReadListener(readListener: ReadListener): Unit = {
this.readListener = readListener
readListener.onDataAvailable()
}

override def read(): Int = {
val result = in.read()
if (in.available() == 0)
readListener.onAllDataRead()
result
}
}

case class HttpServletRequestStub(
inputStream: ServletInputStream
) extends HttpServletRequest {
def getInputStream(): ServletInputStream = inputStream

def authenticate(x$1: javax.servlet.http.HttpServletResponse): Boolean = ???
def changeSessionId(): String = ???
def getAuthType(): String = ???
def getContextPath(): String = ???
def getCookies(): Array[javax.servlet.http.Cookie] = ???
def getDateHeader(x$1: String): Long = ???
def getHeader(x$1: String): String = ???
def getHeaderNames(): java.util.Enumeration[String] = ???
def getHeaders(x$1: String): java.util.Enumeration[String] = ???
def getIntHeader(x$1: String): Int = ???
def getMethod(): String = ???
def getPart(x$1: String): javax.servlet.http.Part = ???
def getParts(): java.util.Collection[javax.servlet.http.Part] = ???
def getPathInfo(): String = ???
def getPathTranslated(): String = ???
def getQueryString(): String = ???
def getRemoteUser(): String = ???
def getRequestURI(): String = ???
def getRequestURL(): StringBuffer = ???
def getRequestedSessionId(): String = ???
def getServletPath(): String = ???
def getSession(): javax.servlet.http.HttpSession = ???
def getSession(x$1: Boolean): javax.servlet.http.HttpSession = ???
def getUserPrincipal(): java.security.Principal = ???
def isRequestedSessionIdFromCookie(): Boolean = ???
def isRequestedSessionIdFromURL(): Boolean = ???
def isRequestedSessionIdFromUrl(): Boolean = ???
def isRequestedSessionIdValid(): Boolean = ???
def isUserInRole(x$1: String): Boolean = ???
def login(x$1: String, x$2: String): Unit = ???
def logout(): Unit = ???
def upgrade[T <: javax.servlet.http.HttpUpgradeHandler](x$1: Class[T]): T = ???
def getAsyncContext(): javax.servlet.AsyncContext = ???
def getAttribute(x$1: String): Object = ???
def getAttributeNames(): java.util.Enumeration[String] = ???
def getCharacterEncoding(): String = ???
def getContentLength(): Int = ???
def getContentLengthLong(): Long = ???
def getContentType(): String = ???
def getDispatcherType(): javax.servlet.DispatcherType = ???
def getLocalAddr(): String = ???
def getLocalName(): String = ???
def getLocalPort(): Int = ???
def getLocale(): java.util.Locale = ???
def getLocales(): java.util.Enumeration[java.util.Locale] = ???
def getParameter(x$1: String): String = ???
def getParameterMap(): java.util.Map[String, Array[String]] = ???
def getParameterNames(): java.util.Enumeration[String] = ???
def getParameterValues(x$1: String): Array[String] = ???
def getProtocol(): String = ???
def getReader(): java.io.BufferedReader = ???
def getRealPath(x$1: String): String = ???
def getRemoteAddr(): String = ???
def getRemoteHost(): String = ???
def getRemotePort(): Int = ???
def getRequestDispatcher(x$1: String): javax.servlet.RequestDispatcher = ???
def getScheme(): String = ???
def getServerName(): String = ???
def getServerPort(): Int = ???
def getServletContext(): javax.servlet.ServletContext = ???
def isAsyncStarted(): Boolean = ???
def isAsyncSupported(): Boolean = ???
def isSecure(): Boolean = ???
def removeAttribute(x$1: String): Unit = ???
def setAttribute(x$1: String, x$2: Object): Unit = ???
def setCharacterEncoding(x$1: String): Unit = ???
def startAsync(
x$1: javax.servlet.ServletRequest,
x$2: javax.servlet.ServletResponse,
): javax.servlet.AsyncContext = ???
def startAsync(): javax.servlet.AsyncContext = ???
}

}
14 changes: 14 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ lazy val servlet = project
"org.eclipse.jetty" % "jetty-servlet" % jettyVersion % Test,
"org.http4s" %% "http4s-dsl" % http4sVersion % Test,
"org.http4s" %% "http4s-server" % http4sVersion,
"org.typelevel" %% "cats-effect" % "3.4.5",
"org.typelevel" %% "munit-cats-effect-3" % munitCatsEffectVersion % Test,
),
)
Expand All @@ -64,3 +65,16 @@ lazy val examples = project
.dependsOn(servlet)

lazy val docs = project.in(file("site")).enablePlugins(TypelevelSitePlugin)

lazy val benchmarks = project
.in(file("benchmarks"))
.dependsOn(servlet)
.settings(
name := "servlet-benchmarks",
libraryDependencies ++= Seq(
"javax.servlet" % "javax.servlet-api" % servletApiVersion,
),
javaOptions ++= Seq(
"-Dcats.effect.tracing.mode=none",
"-Dcats.effect.tracing.exceptions.enhanced=false"))
.enablePlugins(NoPublishPlugin, JmhPlugin)
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
addSbtPlugin("com.earldouglas" % "xsbt-web-plugin" % "4.2.4")
addSbtPlugin("org.http4s" % "sbt-http4s-org" % "0.14.9")
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.3")
24 changes: 11 additions & 13 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,11 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
): Stream[F, Byte] = {
sealed trait Read
final case class Bytes(chunk: Chunk[Byte]) extends Read
case object End extends Read
final case class Error(t: Throwable) extends Read
case object End

Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in =>
Stream.eval(Queue.bounded[F, Read](4)).flatMap { q =>
val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener {
Stream.eval(Queue.bounded[F, Any](4)).flatMap { q =>
val readBody = Stream.eval(F.delay(in.setReadListener(new ReadListener {
var buf: Array[Byte] = _
unsafeReplaceBuffer()

Expand All @@ -238,10 +235,11 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
F.delay(in.read(buf)).flatMap {
case len if len == chunkSize =>
// We used the whole buffer. Replace it new before next read.
q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len >= 0 =>
q.offer(Chunk.array(buf)) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len > 0 =>
// Got a partial chunk. Copy it, and reuse the current buffer.
q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady
q.offer(Chunk.array(Arrays.copyOf(buf, len))) >> loopIfReady
case len if len == 0 => loopIfReady
case _ =>
F.unit
}
Expand All @@ -253,7 +251,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
unsafeRunAndForget(q.offer(End))

def onError(t: Throwable): Unit =
unsafeRunAndForget(q.offer(Error(t)))
unsafeRunAndForget(q.offer(t))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
Expand All @@ -263,12 +261,12 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl

def pullBody: Pull[F, Byte, Unit] =
Pull.eval(q.take).flatMap {
case Bytes(chunk) => Pull.output(chunk) >> pullBody
case chunk: Chunk[Byte] @ unchecked => Pull.output(chunk) >> pullBody
case End => Pull.done
case Error(t) => Pull.raiseError[F](t)
case t: Throwable => Pull.raiseError[F](t)
}

pullBody.stream.concurrently(readBody)
readBody.flatMap(_ => pullBody.stream)
}
}
}
Expand Down