diff --git a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala index 32cd3a97..1840c0fa 100644 --- a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala @@ -61,7 +61,7 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. val ctx = servletRequest.startAsync() ctx.setTimeout(asyncTimeoutMillis) // Must be done on the container thread for Tomcat's sake when using async I/O. - val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _ + val bodyWriter = servletIo.bodyWriter(servletResponse) _ val result = F .attempt( toRequest(servletRequest).fold( diff --git a/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala index 266d4ee0..a6f21155 100644 --- a/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala @@ -58,7 +58,7 @@ class BlockingHttp4sServlet[F[_]] private ( ): Unit = { val result = F .defer { - val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _ + val bodyWriter = servletIo.bodyWriter(servletResponse) _ val render = toRequest(servletRequest).fold( onParseFailure(_, servletResponse, bodyWriter), diff --git a/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala index c463c63c..737d3ae0 100644 --- a/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala @@ -161,7 +161,7 @@ abstract class Http4sServlet[F[_]]( uri = uri, httpVersion = version, headers = toHeaders(req), - body = servletIo.requestBody(req, dispatcher), + body = servletIo.requestBody(req), attributes = attributes, ) diff --git a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala index 6d179e35..8ee602c9 100644 --- a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala +++ b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala @@ -43,12 +43,21 @@ sealed abstract class ServletIo[F[_]: Async] { protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F] @nowarn("cat=deprecation") + def requestBody( + servletRequest: HttpServletRequest + ): Stream[F, Byte] = + reader(servletRequest) + + @deprecated( + "Prefer requestBody(HttpServletRequest), which creates its own request-scoped dispatcher", + "0.23.14", + ) def requestBody( servletRequest: HttpServletRequest, dispatcher: Dispatcher[F], ): Stream[F, Byte] = { val _ = dispatcher // unused - reader(servletRequest) + requestBody(servletRequest) } /** May install a listener on the servlet response. */ @@ -56,12 +65,19 @@ sealed abstract class ServletIo[F[_]: Async] { protected[servlet] def initWriter(servletResponse: HttpServletResponse): BodyWriter[F] @nowarn("cat=deprecation") - def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])( + def bodyWriter(servletResponse: HttpServletResponse)( response: Response[F] - ): F[Unit] = { - val _ = dispatcher + ): F[Unit] = initWriter(servletResponse)(response) - } + + @deprecated( + "Prefer bodyWriter(HttpServletResponse), which creates its own response-scoped dispatcher", + "0.23.14", + ) + def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])( + response: Response[F] + ): F[Unit] = + bodyWriter(servletResponse)(response) } /** Use standard blocking reads and writes. @@ -210,8 +226,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl * https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala */ override def requestBody( - servletRequest: HttpServletRequest, - dispatcher: Dispatcher[F], + servletRequest: HttpServletRequest ): Stream[F, Byte] = { sealed trait Read final case class Bytes(chunk: Chunk[Byte]) extends Read @@ -220,55 +235,58 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in => Stream.eval(Queue.bounded[F, Read](4)).flatMap { q => - val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener { - var buf: Array[Byte] = _ - unsafeReplaceBuffer() - - def unsafeReplaceBuffer() = - buf = new Array[Byte](chunkSize) - - def onDataAvailable(): Unit = { - def loopIfReady = - F.delay(in.isReady()).flatMap { - case true => go - case false => F.unit - } + Stream.resource(Dispatcher.sequential[F]).flatMap { dispatcher => + val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener { + var buf: Array[Byte] = _ + unsafeReplaceBuffer() + + def unsafeReplaceBuffer() = + buf = new Array[Byte](chunkSize) + + def onDataAvailable(): Unit = { + def loopIfReady = + F.delay(in.isReady()).flatMap { + case true => go + case false => F.unit + } - def go: F[Unit] = - F.delay(in.read(buf)).flatMap { - case len if len == chunkSize => - // We used the whole buffer. Replace it new before next read. - q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady - case len if len >= 0 => - // Got a partial chunk. Copy it, and reuse the current buffer. - q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady - case _ => - F.unit - } + def go: F[Unit] = + F.delay(in.read(buf)).flatMap { + case len if len == chunkSize => + // We used the whole buffer. Replace it new before next read. + q.offer(Bytes(Chunk.array(buf))) >> F + .delay(unsafeReplaceBuffer()) >> loopIfReady + case len if len >= 0 => + // Got a partial chunk. Copy it, and reuse the current buffer. + q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady + case _ => + F.unit + } - unsafeRunAndForget(go) - } + unsafeRunAndForget(go) + } - def onAllDataRead(): Unit = - unsafeRunAndForget(q.offer(End)) + def onAllDataRead(): Unit = + unsafeRunAndForget(q.offer(End)) - def onError(t: Throwable): Unit = - unsafeRunAndForget(q.offer(Error(t))) + def onError(t: Throwable): Unit = + unsafeRunAndForget(q.offer(Error(t))) - def unsafeRunAndForget[A](fa: F[A]): Unit = - dispatcher.unsafeRunAndForget( - fa.onError { case t => F.delay(logger.error(t)("Error in servlet read listener")) } - ) - }))) + def unsafeRunAndForget[A](fa: F[A]): Unit = + dispatcher.unsafeRunAndForget( + fa.onError { case t => F.delay(logger.error(t)("Error in servlet read listener")) } + ) + }))) - def pullBody: Pull[F, Byte, Unit] = - Pull.eval(q.take).flatMap { - case Bytes(chunk) => Pull.output(chunk) >> pullBody - case End => Pull.done - case Error(t) => Pull.raiseError[F](t) - } + def pullBody: Pull[F, Byte, Unit] = + Pull.eval(q.take).flatMap { + case Bytes(chunk) => Pull.output(chunk) >> pullBody + case End => Pull.done + case Error(t) => Pull.raiseError[F](t) + } - pullBody.stream.concurrently(readBody) + pullBody.stream.concurrently(readBody) + } } } } @@ -372,8 +390,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl * https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala */ override def bodyWriter( - servletResponse: HttpServletResponse, - dispatcher: Dispatcher[F], + servletResponse: HttpServletResponse )(response: Response[F]): F[Unit] = { sealed trait Write final case class Bytes(chunk: Chunk[Byte]) extends Write @@ -385,54 +402,59 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl F.delay(servletResponse.getOutputStream).flatMap { out => Queue.bounded[F, Write](4).flatMap { q => Deferred[F, Either[Throwable, Unit]].flatMap { done => - val writeBody = F.delay(out.setWriteListener(new WriteListener { - def onWritePossible(): Unit = { - def loopIfReady = F.delay(out.isReady()).flatMap { - case true => go - case false => F.unit - } + Dispatcher.sequential[F].use { dispatcher => + val writeBody = F.delay(out.setWriteListener(new WriteListener { + def onWritePossible(): Unit = { + def loopIfReady = F.delay(out.isReady()).flatMap { + case true => go + case false => F.unit + } - def flush = - if (autoFlush) { - F.delay(out.isReady()).flatMap { - case true => F.delay(out.flush()) >> loopIfReady - case false => F.unit + def flush = + if (autoFlush) { + F.delay(out.isReady()).flatMap { + case true => F.delay(out.flush()) >> loopIfReady + case false => F.unit + } + } else + loopIfReady + + def go: F[Unit] = + q.take.flatMap { + case Bytes(slice: Chunk.ArraySlice[_]) => + F.delay( + out + .write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length) + ) >> flush + case Bytes(chunk) => + F.delay(out.write(chunk.toArray)) >> flush + case End => + F.delay(out.flush()) >> done.complete(Either.unit).attempt.void + case Init => + if (autoFlush) flush else go } - } else - loopIfReady - - def go: F[Unit] = - q.take.flatMap { - case Bytes(slice: Chunk.ArraySlice[_]) => - F.delay( - out.write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length) - ) >> flush - case Bytes(chunk) => - F.delay(out.write(chunk.toArray)) >> flush - case End => - F.delay(out.flush()) >> done.complete(Either.unit).attempt.void - case Init => - if (autoFlush) flush else go - } - unsafeRunAndForget(go) - } - def onError(t: Throwable): Unit = - unsafeRunAndForget(done.complete(Left(t))) + unsafeRunAndForget(go) + } + def onError(t: Throwable): Unit = + unsafeRunAndForget(done.complete(Left(t))) - def unsafeRunAndForget[A](fa: F[A]): Unit = - dispatcher.unsafeRunAndForget( - fa.onError { case t => F.delay(logger.error(t)("Error in servlet write listener")) } - ) - })) + def unsafeRunAndForget[A](fa: F[A]): Unit = + dispatcher.unsafeRunAndForget( + fa.onError { case t => + F.delay(logger.error(t)("Error in servlet write listener")) + } + ) + })) - val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End) + val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End) - Stream - .eval(writeBody >> done.get.rethrow) - .mergeHaltL(writes.foreach(q.offer)) - .compile - .drain + Stream + .eval(writeBody >> done.get.rethrow) + .mergeHaltL(writes.foreach(q.offer)) + .compile + .drain + } } } }