Skip to content

Commit ccd17c0

Browse files
committed
Add compressor, implement compression in HttpClient
1 parent b2a981e commit ccd17c0

File tree

9 files changed

+362
-27
lines changed

9 files changed

+362
-27
lines changed

build.sbt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ val pekkoStreams = "org.apache.pekko" %% "pekko-stream" % pekkoStreamVersion
153153
val scalaTest = libraryDependencies ++= Seq("freespec", "funsuite", "flatspec", "wordspec", "shouldmatchers").map(m =>
154154
"org.scalatest" %%% s"scalatest-$m" % "3.2.19" % Test
155155
)
156+
val scalaTestPlusScalaCheck = libraryDependencies += "org.scalatestplus" %% "scalacheck-1-18" % "3.2.19.0" % Test
156157

157158
val zio1Version = "1.0.18"
158159
val zio2Version = "2.1.14"
@@ -318,7 +319,8 @@ lazy val core = (projectMatrix in file("core"))
318319
"com.softwaremill.sttp.shared" %%% "core" % sttpSharedVersion,
319320
"com.softwaremill.sttp.shared" %%% "ws" % sttpSharedVersion
320321
),
321-
scalaTest
322+
scalaTest,
323+
scalaTestPlusScalaCheck
322324
)
323325
.settings(testServerSettings)
324326
.jvmPlatform(
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package sttp.client4.internal.compression
2+
3+
import sttp.client4._
4+
import sttp.model.Encodings
5+
6+
import Compressor._
7+
import java.io.FileInputStream
8+
import java.nio.ByteBuffer
9+
import java.util.zip.DeflaterInputStream
10+
import java.util.zip.Deflater
11+
import java.io.ByteArrayOutputStream
12+
13+
private[client4] trait Compressor {
14+
def encoding: String
15+
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R]
16+
}
17+
18+
private[client4] object GZipDefaultCompressor extends Compressor {
19+
val encoding: String = Encodings.Gzip
20+
21+
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
22+
body match {
23+
case NoBody => NoBody
24+
case StringBody(s, encoding, defaultContentType) =>
25+
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
26+
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
27+
case ByteBufferBody(b, defaultContentType) =>
28+
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
29+
case InputStreamBody(b, defaultContentType) =>
30+
InputStreamBody(GZIPCompressingInputStream(b), defaultContentType)
31+
case StreamBody(b) => streamsNotSupported
32+
case FileBody(f, defaultContentType) =>
33+
InputStreamBody(GZIPCompressingInputStream(new FileInputStream(f.toFile)), defaultContentType)
34+
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
35+
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
36+
}
37+
38+
private def byteArray(bytes: Array[Byte]): Array[Byte] = {
39+
val bos = new java.io.ByteArrayOutputStream()
40+
val gzip = new java.util.zip.GZIPOutputStream(bos)
41+
gzip.write(bytes)
42+
gzip.close()
43+
bos.toByteArray()
44+
}
45+
}
46+
47+
private[client4] object DeflateDefaultCompressor extends Compressor {
48+
val encoding: String = Encodings.Deflate
49+
50+
def apply[R](body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
51+
body match {
52+
case NoBody => NoBody
53+
case StringBody(s, encoding, defaultContentType) =>
54+
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
55+
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
56+
case ByteBufferBody(b, defaultContentType) =>
57+
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
58+
case InputStreamBody(b, defaultContentType) =>
59+
InputStreamBody(DeflaterInputStream(b), defaultContentType)
60+
case StreamBody(b) => streamsNotSupported
61+
case FileBody(f, defaultContentType) =>
62+
InputStreamBody(DeflaterInputStream(new FileInputStream(f.toFile)), defaultContentType)
63+
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
64+
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
65+
}
66+
67+
private def byteArray(bytes: Array[Byte]): Array[Byte] = {
68+
val deflater = new Deflater()
69+
try {
70+
deflater.setInput(bytes)
71+
deflater.finish()
72+
val byteArrayOutputStream = new ByteArrayOutputStream()
73+
val readBuffer = new Array[Byte](1024)
74+
75+
while (!deflater.finished()) {
76+
val readCount = deflater.deflate(readBuffer)
77+
if (readCount > 0) {
78+
byteArrayOutputStream.write(readBuffer, 0, readCount)
79+
}
80+
}
81+
82+
byteArrayOutputStream.toByteArray
83+
} finally deflater.end()
84+
}
85+
}
86+
87+
private[client4] object Compressor {
88+
def compressIfNeeded[T, R](
89+
request: GenericRequest[T, R],
90+
compressors: List[Compressor]
91+
): (GenericRequestBody[R], Option[Long]) =
92+
request.options.compressRequestBody match {
93+
case Some(encoding) =>
94+
val compressedBody = compressors.find(_.encoding.equalsIgnoreCase(encoding)) match {
95+
case Some(compressor) => compressor(request.body, encoding)
96+
case None => throw new IllegalArgumentException(s"Unsupported encoding: $encoding")
97+
}
98+
99+
val contentLength = calculateContentLength(compressedBody)
100+
(compressedBody, contentLength)
101+
102+
case None => (request.body, request.contentLength)
103+
}
104+
105+
private def calculateContentLength[R](body: GenericRequestBody[R]): Option[Long] = body match {
106+
case NoBody => None
107+
case StringBody(b, e, _) => Some(b.getBytes(e).length.toLong)
108+
case ByteArrayBody(b, _) => Some(b.length.toLong)
109+
case ByteBufferBody(b, _) => None
110+
case InputStreamBody(b, _) => None
111+
case FileBody(f, _) => Some(f.toFile.length())
112+
case StreamBody(_) => None
113+
case MultipartStreamBody(parts) => None
114+
case BasicMultipartBody(parts) => None
115+
}
116+
117+
private[compression] def compressingMultipartBodiesNotSupported: Nothing =
118+
throw new IllegalArgumentException("Multipart bodies cannot be compressed")
119+
120+
private[compression] def streamsNotSupported: Nothing =
121+
throw new IllegalArgumentException("Streams are not supported")
122+
123+
private[compression] def byteBufferToArray(inputBuffer: ByteBuffer): Array[Byte] =
124+
if (inputBuffer.hasArray()) {
125+
inputBuffer.array()
126+
} else {
127+
val inputBytes = new Array[Byte](inputBuffer.remaining())
128+
inputBuffer.get(inputBytes)
129+
inputBytes
130+
}
131+
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package sttp.client4.internal.compression
2+
3+
import java.io.{ByteArrayInputStream, IOException, InputStream}
4+
import java.util.zip.{CRC32, Deflater}
5+
6+
// based on:
7+
// https://github.com/http4k/http4k/blob/master/core/core/src/main/kotlin/org/http4k/filter/Gzip.kt#L124
8+
// https://stackoverflow.com/questions/11036280/compress-an-inputstream-with-gzip
9+
private[client4] class GZIPCompressingInputStream(
10+
source: InputStream,
11+
compressionLevel: Int = java.util.zip.Deflater.DEFAULT_COMPRESSION
12+
) extends InputStream {
13+
14+
private object State extends Enumeration {
15+
type State = Value
16+
val HEADER, DATA, FINALISE, TRAILER, DONE = Value
17+
}
18+
19+
import State._
20+
21+
private val GZIP_MAGIC = 0x8b1f
22+
private val HEADER_DATA: Array[Byte] = Array(
23+
GZIP_MAGIC.toByte,
24+
(GZIP_MAGIC >> 8).toByte,
25+
Deflater.DEFLATED.toByte,
26+
0,
27+
0,
28+
0,
29+
0,
30+
0,
31+
0,
32+
0
33+
)
34+
private val INITIAL_BUFFER_SIZE = 8192
35+
36+
private val deflater = new Deflater(Deflater.DEFLATED, true)
37+
deflater.setLevel(compressionLevel)
38+
39+
private val crc = new CRC32()
40+
private var trailer: ByteArrayInputStream = _
41+
private val header = new ByteArrayInputStream(HEADER_DATA)
42+
43+
private var deflationBuffer: Array[Byte] = new Array[Byte](INITIAL_BUFFER_SIZE)
44+
private var stage: State = HEADER
45+
46+
override def read(): Int = {
47+
val readBytes = new Array[Byte](1)
48+
var bytesRead = 0
49+
while (bytesRead == 0)
50+
bytesRead = read(readBytes, 0, 1)
51+
if (bytesRead != -1) readBytes(0) & 0xff else -1
52+
}
53+
54+
@throws[IOException]
55+
override def read(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = stage match {
56+
case HEADER =>
57+
val bytesRead = header.read(readBuffer, readOffset, readLength)
58+
if (header.available() == 0) stage = DATA
59+
bytesRead
60+
61+
case DATA =>
62+
if (!deflater.needsInput) {
63+
deflatePendingInput(readBuffer, readOffset, readLength)
64+
} else {
65+
if (deflationBuffer.length < readLength) {
66+
deflationBuffer = new Array[Byte](readLength)
67+
}
68+
69+
val bytesRead = source.read(deflationBuffer, 0, readLength)
70+
if (bytesRead <= 0) {
71+
stage = FINALISE
72+
deflater.finish()
73+
0
74+
} else {
75+
crc.update(deflationBuffer, 0, bytesRead)
76+
deflater.setInput(deflationBuffer, 0, bytesRead)
77+
deflatePendingInput(readBuffer, readOffset, readLength)
78+
}
79+
}
80+
81+
case FINALISE =>
82+
if (deflater.finished()) {
83+
stage = TRAILER
84+
val crcValue = crc.getValue.toInt
85+
val totalIn = deflater.getTotalIn
86+
trailer = createTrailer(crcValue, totalIn)
87+
0
88+
} else {
89+
deflater.deflate(readBuffer, readOffset, readLength, Deflater.FULL_FLUSH)
90+
}
91+
92+
case TRAILER =>
93+
val bytesRead = trailer.read(readBuffer, readOffset, readLength)
94+
if (trailer.available() == 0) stage = DONE
95+
bytesRead
96+
97+
case DONE => -1
98+
}
99+
100+
private def deflatePendingInput(readBuffer: Array[Byte], readOffset: Int, readLength: Int): Int = {
101+
var bytesCompressed = 0
102+
while (!deflater.needsInput && readLength - bytesCompressed > 0)
103+
bytesCompressed += deflater.deflate(
104+
readBuffer,
105+
readOffset + bytesCompressed,
106+
readLength - bytesCompressed,
107+
Deflater.FULL_FLUSH
108+
)
109+
bytesCompressed
110+
}
111+
112+
private def createTrailer(crcValue: Int, totalIn: Int): ByteArrayInputStream =
113+
new ByteArrayInputStream(
114+
Array(
115+
(crcValue >> 0).toByte,
116+
(crcValue >> 8).toByte,
117+
(crcValue >> 16).toByte,
118+
(crcValue >> 24).toByte,
119+
(totalIn >> 0).toByte,
120+
(totalIn >> 8).toByte,
121+
(totalIn >> 16).toByte,
122+
(totalIn >> 24).toByte
123+
)
124+
)
125+
126+
override def available(): Int = if (stage == DONE) 0 else 1
127+
128+
@throws[IOException]
129+
override def close(): Unit = {
130+
source.close()
131+
deflater.end()
132+
if (trailer != null) trailer.close()
133+
header.close()
134+
}
135+
136+
crc.reset()
137+
}

core/src/main/scala/sttp/client4/request.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import sttp.attributes.AttributeMap
2525
* ability to send and receive streaming bodies) or [[sttp.capabilities.WebSockets]] (the ability to handle websocket
2626
* requests).
2727
*/
28-
trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata {
28+
sealed trait GenericRequest[+T, -R] extends RequestBuilder[GenericRequest[T, R]] with RequestMetadata {
2929
def body: GenericRequestBody[R]
3030
def response: ResponseAsDelegate[T, R]
3131

core/src/main/scalajvm/sttp/client4/httpclient/HttpClientBackend.scala

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
package sttp.client4.httpclient
22

3-
import sttp.capabilities.{Effect, Streams}
3+
import sttp.capabilities.Effect
4+
import sttp.capabilities.Streams
5+
import sttp.client4.Backend
6+
import sttp.client4.BackendOptions
47
import sttp.client4.BackendOptions.Proxy
8+
import sttp.client4.GenericBackend
9+
import sttp.client4.GenericRequest
10+
import sttp.client4.MultipartBody
11+
import sttp.client4.Response
12+
import sttp.client4.SttpClientException
513
import sttp.client4.httpclient.HttpClientBackend.EncodingHandler
614
import sttp.client4.internal.SttpToJavaConverters.toJavaFunction
7-
import sttp.client4.internal.httpclient.{BodyFromHttpClient, BodyToHttpClient, Sequencer}
8-
import sttp.client4.internal.ws.SimpleQueue
9-
import sttp.client4.{
10-
Backend,
11-
BackendOptions,
12-
GenericBackend,
13-
GenericRequest,
14-
MultipartBody,
15-
Response,
16-
SttpClientException
17-
}
18-
import sttp.model.HttpVersion.{HTTP_1_1, HTTP_2}
15+
import sttp.client4.internal.httpclient.BodyFromHttpClient
16+
import sttp.client4.internal.httpclient.BodyToHttpClient
1917
import sttp.model._
18+
import sttp.model.HttpVersion.HTTP_1_1
19+
import sttp.model.HttpVersion.HTTP_2
2020
import sttp.monad.MonadError
2121
import sttp.monad.syntax._
2222
import sttp.ws.WebSocket
2323

24+
import java.net.Authenticator
2425
import java.net.Authenticator.RequestorType
25-
import java.net.http.{HttpClient, HttpRequest, HttpResponse, WebSocket => JWebSocket}
26-
import java.net.{Authenticator, PasswordAuthentication}
26+
import java.net.PasswordAuthentication
27+
import java.net.http.HttpClient
28+
import java.net.http.HttpRequest
29+
import java.net.http.HttpResponse
30+
import java.net.http.{WebSocket => JWebSocket}
2731
import java.time.{Duration => JDuration}
28-
import java.util.concurrent.{Executor, ThreadPoolExecutor}
32+
import java.util.concurrent.Executor
33+
import java.util.concurrent.ThreadPoolExecutor
2934
import java.util.function
3035
import scala.collection.JavaConverters._
3136

@@ -117,7 +122,7 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B](
117122
resBody.left
118123
.map { is =>
119124
encoding
120-
.filterNot(e => code.equals(StatusCode.NoContent) || request.autoDecompressionDisabled || e.isEmpty)
125+
.filterNot(e => code.equals(StatusCode.NoContent) || !request.autoDecompressionEnabled || e.isEmpty)
121126
.map(e => customEncodingHandler.applyOrElse((is, e), standardEncoding.tupled))
122127
.getOrElse(is)
123128
}
@@ -166,16 +171,16 @@ abstract class HttpClientBackend[F[_], S <: Streams[S], P, B](
166171
}
167172
override def close(): F[Unit] =
168173
if (closeClient) {
169-
monad.eval(
170-
client
174+
monad.eval {
175+
val _ = client
171176
.executor()
172177
.map[Unit](new function.Function[Executor, Unit] {
173178
override def apply(t: Executor): Unit = t match {
174179
case tpe: ThreadPoolExecutor => tpe.shutdown()
175180
case _ => ()
176181
}
177182
})
178-
)
183+
}
179184
} else {
180185
monad.unit(())
181186
}

core/src/main/scalajvm/sttp/client4/httpclient/HttpClientSyncBackend.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ class HttpClientSyncBackend private (
6363
val isOpen: AtomicBoolean = new AtomicBoolean(false)
6464
val responseCell = new ArrayBlockingQueue[Either[Throwable, () => Response[T]]](1)
6565

66-
def fillCellError(t: Throwable): Unit = responseCell.add(Left(t)): Unit
67-
def fillCell(wr: () => Response[T]): Unit = responseCell.add(Right(wr)): Unit
66+
def fillCellError(t: Throwable): Unit = { val _ = responseCell.add(Left(t)) }
67+
def fillCell(wr: () => Response[T]): Unit = { val _ = responseCell.add(Right(wr)) }
6868

6969
val listener = new DelegatingWebSocketListener(
7070
new AddToQueueListener(queue, isOpen),
7171
ws => {
72-
val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, _.get(): Unit)
72+
val webSocket = new WebSocketImpl[Identity](ws, queue, isOpen, sequencer, monad, cf => { val _ = cf.get() })
7373
val baseResponse = Response((), StatusCode.SwitchingProtocols, "", Nil, Nil, request.onlyMetadata)
7474
val body = () => bodyFromHttpClient(Right(webSocket), request.response, baseResponse)
7575
fillCell(() => baseResponse.copy(body = body()))

0 commit comments

Comments
 (0)