Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a sequential dispatcher for each request and response #170

Open
wants to merge 1 commit into
base: series/0.23
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0.
val ctx = servletRequest.startAsync()
ctx.setTimeout(asyncTimeoutMillis)
// Must be done on the container thread for Tomcat's sake when using async I/O.
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _
val bodyWriter = servletIo.bodyWriter(servletResponse) _
val result = F
.attempt(
toRequest(servletRequest).fold(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BlockingHttp4sServlet[F[_]] private (
): Unit = {
val result = F
.defer {
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _
val bodyWriter = servletIo.bodyWriter(servletResponse) _

val render = toRequest(servletRequest).fold(
onParseFailure(_, servletResponse, bodyWriter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ abstract class Http4sServlet[F[_]](
uri = uri,
httpVersion = version,
headers = toHeaders(req),
body = servletIo.requestBody(req, dispatcher),
body = servletIo.requestBody(req),
attributes = attributes,
)

Expand Down
208 changes: 115 additions & 93 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,41 @@ sealed abstract class ServletIo[F[_]: Async] {
protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F]

@nowarn("cat=deprecation")
def requestBody(
servletRequest: HttpServletRequest
): Stream[F, Byte] =
reader(servletRequest)

@deprecated(
"Prefer requestBody(HttpServletRequest), which creates its own request-scoped dispatcher",
"0.23.14",
)
def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
): Stream[F, Byte] = {
val _ = dispatcher // unused
reader(servletRequest)
requestBody(servletRequest)
}

/** May install a listener on the servlet response. */
@deprecated("Prefer bodyWriter, which has access to a Dispatcher", "0.23.12")
protected[servlet] def initWriter(servletResponse: HttpServletResponse): BodyWriter[F]

@nowarn("cat=deprecation")
def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])(
def bodyWriter(servletResponse: HttpServletResponse)(
response: Response[F]
): F[Unit] = {
val _ = dispatcher
): F[Unit] =
initWriter(servletResponse)(response)
}

@deprecated(
"Prefer bodyWriter(HttpServletResponse), which creates its own response-scoped dispatcher",
"0.23.14",
)
def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])(
response: Response[F]
): F[Unit] =
bodyWriter(servletResponse)(response)
}

/** Use standard blocking reads and writes.
Expand Down Expand Up @@ -210,8 +226,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
servletRequest: HttpServletRequest
): Stream[F, Byte] = {
sealed trait Read
final case class Bytes(chunk: Chunk[Byte]) extends Read
Expand All @@ -220,55 +235,58 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl

Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in =>
Stream.eval(Queue.bounded[F, Read](4)).flatMap { q =>
val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener {
var buf: Array[Byte] = _
unsafeReplaceBuffer()

def unsafeReplaceBuffer() =
buf = new Array[Byte](chunkSize)

def onDataAvailable(): Unit = {
def loopIfReady =
F.delay(in.isReady()).flatMap {
case true => go
case false => F.unit
}
Stream.resource(Dispatcher.sequential[F]).flatMap { dispatcher =>
val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener {
var buf: Array[Byte] = _
unsafeReplaceBuffer()

def unsafeReplaceBuffer() =
buf = new Array[Byte](chunkSize)

def onDataAvailable(): Unit = {
def loopIfReady =
F.delay(in.isReady()).flatMap {
case true => go
case false => F.unit
}

def go: F[Unit] =
F.delay(in.read(buf)).flatMap {
case len if len == chunkSize =>
// We used the whole buffer. Replace it new before next read.
q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len >= 0 =>
// Got a partial chunk. Copy it, and reuse the current buffer.
q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady
case _ =>
F.unit
}
def go: F[Unit] =
F.delay(in.read(buf)).flatMap {
case len if len == chunkSize =>
// We used the whole buffer. Replace it new before next read.
q.offer(Bytes(Chunk.array(buf))) >> F
.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len >= 0 =>
// Got a partial chunk. Copy it, and reuse the current buffer.
q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady
case _ =>
F.unit
}

unsafeRunAndForget(go)
}
unsafeRunAndForget(go)
}

def onAllDataRead(): Unit =
unsafeRunAndForget(q.offer(End))
def onAllDataRead(): Unit =
unsafeRunAndForget(q.offer(End))

def onError(t: Throwable): Unit =
unsafeRunAndForget(q.offer(Error(t)))
def onError(t: Throwable): Unit =
unsafeRunAndForget(q.offer(Error(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t => F.delay(logger.error(t)("Error in servlet read listener")) }
)
})))
def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t => F.delay(logger.error(t)("Error in servlet read listener")) }
)
})))

def pullBody: Pull[F, Byte, Unit] =
Pull.eval(q.take).flatMap {
case Bytes(chunk) => Pull.output(chunk) >> pullBody
case End => Pull.done
case Error(t) => Pull.raiseError[F](t)
}
def pullBody: Pull[F, Byte, Unit] =
Pull.eval(q.take).flatMap {
case Bytes(chunk) => Pull.output(chunk) >> pullBody
case End => Pull.done
case Error(t) => Pull.raiseError[F](t)
}

pullBody.stream.concurrently(readBody)
pullBody.stream.concurrently(readBody)
}
}
}
}
Expand Down Expand Up @@ -372,8 +390,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def bodyWriter(
servletResponse: HttpServletResponse,
dispatcher: Dispatcher[F],
servletResponse: HttpServletResponse
)(response: Response[F]): F[Unit] = {
sealed trait Write
final case class Bytes(chunk: Chunk[Byte]) extends Write
Expand All @@ -385,54 +402,59 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
F.delay(servletResponse.getOutputStream).flatMap { out =>
Queue.bounded[F, Write](4).flatMap { q =>
Deferred[F, Either[Throwable, Unit]].flatMap { done =>
val writeBody = F.delay(out.setWriteListener(new WriteListener {
def onWritePossible(): Unit = {
def loopIfReady = F.delay(out.isReady()).flatMap {
case true => go
case false => F.unit
}
Dispatcher.sequential[F].use { dispatcher =>
val writeBody = F.delay(out.setWriteListener(new WriteListener {
def onWritePossible(): Unit = {
def loopIfReady = F.delay(out.isReady()).flatMap {
case true => go
case false => F.unit
}

def flush =
if (autoFlush) {
F.delay(out.isReady()).flatMap {
case true => F.delay(out.flush()) >> loopIfReady
case false => F.unit
def flush =
if (autoFlush) {
F.delay(out.isReady()).flatMap {
case true => F.delay(out.flush()) >> loopIfReady
case false => F.unit
}
} else
loopIfReady

def go: F[Unit] =
q.take.flatMap {
case Bytes(slice: Chunk.ArraySlice[_]) =>
F.delay(
out
.write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length)
) >> flush
case Bytes(chunk) =>
F.delay(out.write(chunk.toArray)) >> flush
case End =>
F.delay(out.flush()) >> done.complete(Either.unit).attempt.void
case Init =>
if (autoFlush) flush else go
}
} else
loopIfReady

def go: F[Unit] =
q.take.flatMap {
case Bytes(slice: Chunk.ArraySlice[_]) =>
F.delay(
out.write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length)
) >> flush
case Bytes(chunk) =>
F.delay(out.write(chunk.toArray)) >> flush
case End =>
F.delay(out.flush()) >> done.complete(Either.unit).attempt.void
case Init =>
if (autoFlush) flush else go
}

unsafeRunAndForget(go)
}
def onError(t: Throwable): Unit =
unsafeRunAndForget(done.complete(Left(t)))
unsafeRunAndForget(go)
}
def onError(t: Throwable): Unit =
unsafeRunAndForget(done.complete(Left(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t => F.delay(logger.error(t)("Error in servlet write listener")) }
)
}))
def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t =>
F.delay(logger.error(t)("Error in servlet write listener"))
}
)
}))

val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End)
val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End)

Stream
.eval(writeBody >> done.get.rethrow)
.mergeHaltL(writes.foreach(q.offer))
.compile
.drain
Stream
.eval(writeBody >> done.get.rethrow)
.mergeHaltL(writes.foreach(q.offer))
.compile
.drain
}
}
}
}
Expand Down