diff --git a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala index b56577b8c5..85e61c257b 100644 --- a/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala +++ b/server/akka-http-server/src/test/scala/sttp/tapir/server/akkahttp/AkkaHttpServerTest.scala @@ -206,7 +206,9 @@ class AkkaHttpServerTest extends TestSuite with EitherValues { def drainAkka(stream: AkkaStreams.BinaryStream): Future[Unit] = stream.runWith(Sink.ignore).map(_ => ()) - new AllServerTests(createServerTest, interpreter, backend).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ + new ServerMultipartTests(createServerTest, chunkingSupport = false) + .tests() ++ // chunking disabled, akka-http rejects content-length with transfer-encoding new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new ServerWebSocketTests( createServerTest, diff --git a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala index 4181406515..0eada6fffc 100644 --- a/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala +++ b/server/armeria-server/cats/src/test/scala/sttp/tapir/server/armeria/cats/ArmeriaCatsServerTest.scala @@ -16,7 +16,10 @@ class ArmeriaCatsServerTest extends TestSuite { def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = stream.compile.drain.void - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false) + .tests() ++ + new ServerMultipartTests(createServerTest, chunkingSupport = false) + .tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ new ServerStreamingTests(createServerTest).tests(Fs2Streams[IO])(drainFs2) } diff --git a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala index 3c7d99de62..f6c43cbfa5 100644 --- a/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala +++ b/server/armeria-server/src/test/scala/sttp/tapir/server/armeria/ArmeriaFutureServerTest.scala @@ -15,7 +15,10 @@ class ArmeriaFutureServerTest extends TestSuite { val interpreter = new ArmeriaTestFutureServerInterpreter() val createServerTest = new DefaultCreateServerTest(backend, interpreter) - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false) + .tests() ++ + new ServerMultipartTests(createServerTest, chunkingSupport = false) + .tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ new ServerStreamingTests(createServerTest, maxLengthSupported = false).tests(ArmeriaStreams)(_ => Future.unit) } diff --git a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala index f18e3dd5b9..e2df96856a 100644 --- a/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala +++ b/server/armeria-server/zio/src/test/scala/sttp/tapir/server/armeria/zio/ArmeriaZioServerTest.scala @@ -19,7 +19,10 @@ class ArmeriaZioServerTest extends TestSuite { def drainZStream(zStream: ZioStreams.BinaryStream): Task[Unit] = zStream.run(ZSink.drain) - new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, basic = false, options = false, maxContentLength = false, multipart = false) + .tests() ++ + new ServerMultipartTests(createServerTest, chunkingSupport = false) + .tests() ++ // chunking disabled, Armeria rejects content-length with transfer-encoding new ServerBasicTests(createServerTest, interpreter, supportsUrlEncodedPathSegments = false, maxContentLength = false).tests() ++ new ServerStreamingTests(createServerTest).tests(ZioStreams)(drainZStream) } diff --git a/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala index f3e24bf61f..1af54cc60c 100644 --- a/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala +++ b/server/jdkhttp-server/src/test/scala/sttp/tapir/server/jdkhttp/JdkHttpServerTest.scala @@ -36,7 +36,9 @@ class JdkHttpServerTest extends TestSuite with EitherValues { val createServerTest = new DefaultCreateServerTest(backend, interpreter) new ServerBasicTests(createServerTest, interpreter, invulnerableToUnsanitizedHeaders = false).tests() ++ - new AllServerTests(createServerTest, interpreter, backend, basic = false).tests() + new ServerMultipartTests(createServerTest, chunkingSupport = false) + .tests() ++ // chunking disabled, backend rejects content-length with transfer-encoding + new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false).tests() }) } } diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala index 98b0742b74..e9d177bfff 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/NettyCatsRequestBody.scala @@ -1,18 +1,27 @@ package sttp.tapir.server.netty.cats.internal import cats.effect.Async +import cats.effect.kernel.{Resource, Sync} import cats.syntax.all._ import fs2.Chunk +import fs2.interop.reactivestreams.StreamSubscriber import fs2.io.file.{Files, Path} import io.netty.handler.codec.http.HttpContent +import io.netty.handler.codec.http.multipart.{DefaultHttpDataFactory, HttpData, HttpPostRequestDecoder} +import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher +import sttp.capabilities.StreamMaxLengthExceededException import sttp.capabilities.fs2.Fs2Streams +import sttp.model.Part import sttp.monad.MonadError -import sttp.tapir.TapirFile import sttp.tapir.integ.cats.effect.CatsMonadError import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.RawValue import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} -import sttp.capabilities.WebSockets +import sttp.tapir.{RawBodyType, RawPart, TapirFile} + +import java.io.File + private[cats] class NettyCatsRequestBody[F[_]: Async]( val createFile: ServerRequest => F[TapirFile], @@ -21,6 +30,63 @@ private[cats] class NettyCatsRequestBody[F[_]: Async]( override implicit val monad: MonadError[F] = new CatsMonadError() + def publisherToMultipart( + nettyRequest: StreamedHttpRequest, + serverRequest: ServerRequest, + m: RawBodyType.MultipartBody, + maxBytes: Option[Long] + ): F[RawValue[Seq[RawPart]]] = { + fs2.Stream + .resource( + Resource.make(Sync[F].delay(new HttpPostRequestDecoder(NettyCatsRequestBody.multiPartDataFactory, nettyRequest)))(d => + Sync[F].blocking(d.destroy()) // after the stream finishes or fails, decoder data has to be cleaned up + ) + ) + .flatMap { decoder => + fs2.Stream + .eval(StreamSubscriber[F, HttpContent](bufferSize = 1)) + .flatMap(s => s.sub.stream(Sync[F].delay(nettyRequest.subscribe(s)))) + .evalMapAccumulate({ + (decoder, 0L) + })({ case ((decoder, processedBytesNum), httpContent) => + monad + .blocking { + val newProcessedBytes = if (httpContent.content() != null) { + val processedBytesAndContentBytes = processedBytesNum + httpContent.content().readableBytes() + maxBytes.foreach { max => + if (max < processedBytesAndContentBytes) { + throw new StreamMaxLengthExceededException(max) + } + } + processedBytesAndContentBytes + } else processedBytesNum + + // this operation is the one that does potential I/O (writing files) + decoder.offer(httpContent) + val parts = Stream + .continually(if (decoder.hasNext) { + val next = decoder.next() + next + } else null) + .takeWhile(_ != null) + .toVector + + ( + (decoder, newProcessedBytes), + parts + ) + } + }) + .map(_._2) + .map(_.flatMap(p => m.partType(p.getName()).map((p, _)).toList)) + .evalMap(_.traverse { case (data, partType) => toRawPart(serverRequest, data, partType).map(_.asInstanceOf[Part[Any]]) }) + } + .compile + .toVector + .map(_.flatten) + .map(RawValue.fromParts(_)) + } + override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] = streamCompatible.fromPublisher(publisher, maxBytes).compile.to(Chunk).map(_.toArray[Byte]) @@ -32,4 +98,13 @@ private[cats] class NettyCatsRequestBody[F[_]: Async]( ) .compile .drain + + override def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit] = + fs2.Stream.emits(bytes).through(Files.forAsync[F].writeAll(Path.fromNioPath(file.toPath))).compile.drain + +} + +private[cats] object NettyCatsRequestBody { + val multiPartDataFactory = + new DefaultHttpDataFactory() // writes to memory, then switches to disk if exceeds MINSIZE (16kB), check other constructors. } diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index d02cb96a3e..7eff7aa5bb 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -12,8 +12,9 @@ import sttp.tapir.tests.{Test, TestSuite} import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration +import org.scalatest.matchers.should.Matchers -class NettyCatsServerTest extends TestSuite with EitherValues { +class NettyCatsServerTest extends TestSuite with EitherValues with Matchers { override def tests: Resource[IO, List[Test]] = backendResource.flatMap { backend => @@ -41,6 +42,12 @@ class NettyCatsServerTest extends TestSuite with EitherValues { new ServerCancellationTests(createServerTest)(m, IO.asyncForIO).tests() ++ new NettyFs2StreamingCancellationTest(createServerTest).tests() ++ new ServerGracefulShutdownTests(createServerTest, ioSleeper).tests() ++ + new ServerMultipartTests( + createServerTest, + partContentTypeHeaderSupport = false, + partOtherHeaderSupport = false, + multipartResponsesSupport = false + ).tests() ++ new ServerWebSocketTests( createServerTest, Fs2Streams[IO], diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala index 2316217275..a6c2e69dd7 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyFutureRequestBody.scala @@ -5,16 +5,27 @@ import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher import sttp.capabilities import sttp.monad.{FutureMonad, MonadError} -import sttp.tapir.TapirFile import sttp.tapir.capabilities.NoStreams import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.RawValue import sttp.tapir.server.netty.internal.reactivestreams._ +import sttp.tapir.{RawBodyType, RawPart, TapirFile} +import java.io.File import scala.concurrent.{ExecutionContext, Future} private[netty] class NettyFutureRequestBody(val createFile: ServerRequest => Future[TapirFile])(implicit ec: ExecutionContext) extends NettyRequestBody[Future, NoStreams] { + override def publisherToMultipart( + nettyRequest: StreamedHttpRequest, + serverRequest: ServerRequest, + m: RawBodyType.MultipartBody, + maxBytes: Option[Long] + ): Future[RawValue[Seq[RawPart]]] = Future.failed(new UnsupportedOperationException("Multipart requests not supported.")) + + override def writeBytesToFile(bytes: Array[Byte], file: File): Future[Unit] = Future.failed(new UnsupportedOperationException) + override val streams: capabilities.Streams[NoStreams] = NoStreams override implicit val monad: MonadError[Future] = new FutureMonad() diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala index bcae1a9e39..0a19395c67 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyRequestBody.scala @@ -16,6 +16,15 @@ import sttp.tapir.{FileRange, InputStreamRange, RawBodyType, TapirFile} import java.io.InputStream import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import sttp.tapir.RawPart +import io.netty.handler.codec.http.multipart.InterfaceHttpData +import sttp.model.Part +import io.netty.handler.codec.http.multipart.HttpData +import io.netty.handler.codec.http.multipart.FileUpload +import java.io.ByteArrayInputStream +import java.io.File + /** Common logic for processing request body in all Netty backends. It requires particular backends to implement a few operations. */ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody[F, S] { @@ -37,6 +46,16 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody */ def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): F[Array[Byte]] + /** Reads the reactive stream emitting HttpData into a vector of parts. Implementation-specific, as file manipulations and stream + * processing logic can be different for different backends. + */ + def publisherToMultipart( + nettyRequest: StreamedHttpRequest, + serverRequest: ServerRequest, + m: RawBodyType.MultipartBody, + maxBytes: Option[Long] + ): F[RawValue[Seq[RawPart]]] + /** Backend-specific way to process all elements emitted by a Publisher[HttpContent] and write their bytes into a file. * * @param serverRequest @@ -50,6 +69,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody */ def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): F[Unit] + def writeBytesToFile(bytes: Array[Byte], file: File): F[Unit] + override def toRaw[RAW]( serverRequest: ServerRequest, bodyType: RawBodyType[RAW], @@ -70,8 +91,8 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody file <- createFile(serverRequest) _ <- writeToFile(serverRequest, file, maxBytes) } yield RawValue(FileRange(file), Seq(FileRange(file))) - case _: RawBodyType.MultipartBody => - monad.error(new UnsupportedOperationException) + case m: RawBodyType.MultipartBody => + publisherToMultipart(serverRequest.underlying.asInstanceOf[StreamedHttpRequest], serverRequest, m, maxBytes) } private def readAllBytes(serverRequest: ServerRequest, maxBytes: Option[Long]): F[Array[Byte]] = @@ -96,4 +117,72 @@ private[netty] trait NettyRequestBody[F[_], S <: Streams[S]] extends RequestBody throw new UnsupportedOperationException(s"Unexpected Netty request of type ${other.getClass.getName}") } } + + protected def toRawPart[R]( + serverRequest: ServerRequest, + data: InterfaceHttpData, + partType: RawBodyType[R] + ): F[Part[R]] = { + val partName = data.getName() + data match { + case httpData: HttpData => + // TODO filename* attribute is not used by netty. Non-ascii filenames like https://github.com/http4s/http4s/issues/5809 are unsupported. + toRawPartHttpData(partName, serverRequest, httpData, partType) + case unsupportedDataType => + monad.error(new UnsupportedOperationException(s"Unsupported multipart data type: $unsupportedDataType in part $partName")) + } + } + + private def toRawPartHttpData[R]( + partName: String, + serverRequest: ServerRequest, + httpData: HttpData, + partType: RawBodyType[R] + ): F[Part[R]] = { + val fileName = httpData match { + case fileUpload: FileUpload => Option(fileUpload.getFilename()) + case _ => None + } + partType match { + case RawBodyType.StringBody(defaultCharset) => + // TODO otherDispositionParams not supported. They are normally a part of the content-disposition part header, but this header is not directly accessible, they are extracted internally by the decoder. + val charset = if (httpData.getCharset() != null) httpData.getCharset() else defaultCharset + readHttpData(httpData, _.getString(charset)).map(body => Part(partName, body, fileName = fileName)) + case RawBodyType.ByteArrayBody => + readHttpData(httpData, _.get()).map(body => Part(partName, body, fileName = fileName)) + case RawBodyType.ByteBufferBody => + readHttpData(httpData, _.get()).map(body => Part(partName, ByteBuffer.wrap(body), fileName = fileName)) + case RawBodyType.InputStreamBody => + (if (httpData.isInMemory()) + monad.unit(new ByteArrayInputStream(httpData.get())) + else { + monad.blocking(java.nio.file.Files.newInputStream(httpData.getFile().toPath())) + }).map(body => Part(partName, body, fileName = fileName)) + case RawBodyType.InputStreamRangeBody => + val body = () => { + if (httpData.isInMemory()) + new ByteArrayInputStream(httpData.get()) + else + java.nio.file.Files.newInputStream(httpData.getFile().toPath()) + } + monad.unit(Part(partName, InputStreamRange(body), fileName = fileName)) + case RawBodyType.FileBody => + val fileF: F[File] = + if (httpData.isInMemory()) + (for { + file <- createFile(serverRequest) + _ <- writeBytesToFile(httpData.get(), file) + } yield file) + else monad.unit(httpData.getFile()) + fileF.map(file => Part(partName, FileRange(file), fileName = fileName)) + case _: RawBodyType.MultipartBody => + monad.error(new UnsupportedOperationException(s"Nested multipart not supported, part name = $partName")) + } + } + + private def readHttpData[T](httpData: HttpData, f: HttpData => T): F[T] = + if (httpData.isInMemory()) + monad.unit(f(httpData)) + else + monad.blocking(f(httpData)) } diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/NettySyncRequestBody.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/NettySyncRequestBody.scala index d6cc30a3f3..0f0d542ab9 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/NettySyncRequestBody.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/NettySyncRequestBody.scala @@ -11,6 +11,10 @@ import sttp.tapir.model.ServerRequest import sttp.tapir.server.netty.internal.NettyRequestBody import sttp.tapir.server.netty.internal.reactivestreams.{FileWriterSubscriber, SimpleSubscriber} import sttp.tapir.server.netty.sync.* +import sttp.tapir.RawBodyType +import sttp.tapir.server.interpreter.RawValue +import sttp.tapir.RawPart +import java.io.File private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirFile) extends NettyRequestBody[Identity, OxStreams]: @@ -20,6 +24,14 @@ private[sync] class NettySyncRequestBody(val createFile: ServerRequest => TapirF override def publisherToBytes(publisher: Publisher[HttpContent], contentLength: Option[Long], maxBytes: Option[Long]): Array[Byte] = SimpleSubscriber.processAllBlocking(publisher, contentLength, maxBytes) + override def publisherToMultipart( + nettyRequest: StreamedHttpRequest, + serverRequest: ServerRequest, + m: RawBodyType.MultipartBody, + maxBytes: Option[Long] + ): RawValue[Seq[RawPart]] = throw new UnsupportedOperationException("Multipart requests not supported.") + override def writeBytesToFile(bytes: Array[Byte], file: File) = throw new UnsupportedOperationException() + override def writeToFile(serverRequest: ServerRequest, file: TapirFile, maxBytes: Option[Long]): Unit = serverRequest.underlying match case r: StreamedHttpRequest => FileWriterSubscriber.processAllBlocking(r, file.toPath, maxBytes) diff --git a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala index 3cb9b9ab21..d32853f0f7 100644 --- a/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala +++ b/server/netty-server/zio/src/main/scala/sttp/tapir/server/netty/zio/internal/NettyZioRequestBody.scala @@ -1,15 +1,19 @@ package sttp.tapir.server.netty.zio.internal import io.netty.handler.codec.http.HttpContent +import org.playframework.netty.http.StreamedHttpRequest import org.reactivestreams.Publisher import sttp.capabilities.zio.ZioStreams import sttp.monad.MonadError -import sttp.tapir.TapirFile import sttp.tapir.model.ServerRequest +import sttp.tapir.server.interpreter.RawValue import sttp.tapir.server.netty.internal.{NettyStreamingRequestBody, StreamCompatible} import sttp.tapir.ztapir.RIOMonadError -import zio.RIO +import sttp.tapir.{RawBodyType, RawPart, TapirFile} import zio.stream._ +import zio.{RIO, ZIO} + +import java.io.File private[zio] class NettyZioRequestBody[Env]( val createFile: ServerRequest => RIO[Env, TapirFile], @@ -19,6 +23,14 @@ private[zio] class NettyZioRequestBody[Env]( override val streams: ZioStreams = ZioStreams override implicit val monad: MonadError[RIO[Env, *]] = new RIOMonadError[Env] + override def publisherToMultipart( + nettyRequest: StreamedHttpRequest, + serverRequest: ServerRequest, + m: RawBodyType.MultipartBody, + maxBytes: Option[Long] + ): RIO[Env, RawValue[Seq[RawPart]]] = ZIO.die(new UnsupportedOperationException("Multipart requests not supported.")) + + override def writeBytesToFile(bytes: Array[Byte], file: File): RIO[Env, Unit] = ZIO.die(new UnsupportedOperationException) override def publisherToBytes( publisher: Publisher[HttpContent], contentLength: Option[Long], diff --git a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala index a5f5b81886..8500485966 100644 --- a/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala +++ b/server/pekko-http-server/src/test/scala/sttp/tapir/server/pekkohttp/PekkoHttpServerTest.scala @@ -166,7 +166,8 @@ class PekkoHttpServerTest extends TestSuite with EitherValues { def drainPekko(stream: PekkoStreams.BinaryStream): Future[Unit] = stream.runWith(Sink.ignore).map(_ => ()) - new AllServerTests(createServerTest, interpreter, backend).tests() ++ + new AllServerTests(createServerTest, interpreter, backend, multipart = false).tests() ++ + new ServerMultipartTests(createServerTest, chunkingSupport = false).tests() ++ // chunking disabled, pekko-http rejects content-length with transfer-encoding new ServerStreamingTests(createServerTest).tests(PekkoStreams)(drainPekko) ++ new ServerWebSocketTests( createServerTest, diff --git a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index 5ee450b051..a4c6064152 100644 --- a/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -111,7 +111,8 @@ class PlayServerTest extends TestSuite { inputStreamSupport = false, invulnerableToUnsanitizedHeaders = false ).tests() ++ - new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ + // chunking disabled, Play rejects content-length with transfer-encoding + new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false, chunkingSupport = false).tests() ++ new AllServerTests( createServerTest, interpreter, diff --git a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala index aa87e4eb22..2aa2b78c93 100644 --- a/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala +++ b/server/play29-server/src/test/scala/sttp/tapir/server/play/PlayServerTest.scala @@ -110,7 +110,8 @@ class PlayServerTest extends TestSuite { inputStreamSupport = false, invulnerableToUnsanitizedHeaders = false ).tests() ++ - new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false).tests() ++ + // chunking disabled, Play rejects content-length with transfer-encoding + new ServerMultipartTests(createServerTest, partOtherHeaderSupport = false, chunkingSupport = false).tests() ++ new AllServerTests(createServerTest, interpreter, backend, basic = false, multipart = false, options = false).tests() ++ new ServerStreamingTests(createServerTest).tests(AkkaStreams)(drainAkka) ++ new PlayServerWithContextTest(backend).tests() ++ diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala index 157b27cce3..f2a15133b8 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerMultipartTests.scala @@ -7,13 +7,7 @@ import sttp.model.{Part, StatusCode} import sttp.monad.MonadError import sttp.tapir._ import sttp.tapir.generic.auto._ -import sttp.tapir.tests.Multipart.{ - in_file_list_multipart_out_multipart, - in_file_multipart_out_multipart, - in_raw_multipart_out_string, - in_simple_multipart_out_multipart, - in_simple_multipart_out_string -} +import sttp.tapir.tests.Multipart._ import sttp.tapir.tests.TestUtil.{readFromFile, writeToFile} import sttp.tapir.tests.data.{DoubleFruit, FruitAmount, FruitData} import sttp.tapir.tests.{MultipleFileUpload, Test, data} @@ -21,18 +15,26 @@ import sttp.tapir.server.model.EndpointExtensions._ import scala.concurrent.Await import scala.concurrent.duration.DurationInt +import java.io.File +import fs2.io.file.Files +import cats.effect.IO +import fs2.io.file.Path class ServerMultipartTests[F[_], OPTIONS, ROUTE]( createServerTest: CreateServerTest[F, Any, OPTIONS, ROUTE], partContentTypeHeaderSupport: Boolean = true, partOtherHeaderSupport: Boolean = true, - maxContentLengthSupport: Boolean = true + maxContentLengthSupport: Boolean = true, + chunkingSupport: Boolean = true, + multipartResponsesSupport: Boolean = true )(implicit m: MonadError[F]) { import createServerTest._ def tests(): List[Test] = basicTests() ++ (if (partContentTypeHeaderSupport) contentTypeHeaderTests() else Nil) ++ - (if (maxContentLengthSupport) maxContentLengthTests() else Nil) + (if (maxContentLengthSupport) maxContentLengthTests() else Nil) ++ (if (multipartResponsesSupport) multipartResponsesTests() + else + Nil) ++ (if (chunkingSupport) chunkedMultipartTests() else Nil) def maxContentLengthTests(): List[Test] = List( testServer( @@ -61,18 +63,6 @@ class ServerMultipartTests[F[_], OPTIONS, ROUTE]( def basicTests(): List[Test] = { List( - testServer(in_simple_multipart_out_multipart)((fa: FruitAmount) => - pureResult(FruitAmount(fa.fruit + " apple", fa.amount * 2).asRight[Unit]) - ) { (backend, baseUri) => - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody(multipart("fruit", "pineapple"), multipart("amount", "120")) - .send(backend) - .map { r => - r.body should include regex "name=\"fruit\"[\\s\\S]*pineapple apple" - r.body should include regex "name=\"amount\"[\\s\\S]*240" - } - }, testServer(in_simple_multipart_out_string, "discard unknown parts")((fa: FruitAmount) => pureResult(fa.toString.asRight[Unit])) { (backend, baseUri) => basicStringRequest @@ -83,61 +73,6 @@ class ServerMultipartTests[F[_], OPTIONS, ROUTE]( r.body shouldBe "FruitAmount(pineapple,120)" } }, - testServer(in_file_multipart_out_multipart)((fd: FruitData) => - pureResult( - data - .FruitData( - Part("", writeToFile(Await.result(readFromFile(fd.data.body), 3.seconds).reverse), fd.data.otherDispositionParams, Nil) - .header("X-Auth", fd.data.headers.find(_.is("X-Auth")).map(_.value).toString) - ) - .asRight[Unit] - ) - ) { (backend, baseUri) => - val file = writeToFile("peach mario") - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody(multipartFile("data", file).fileName("fruit-data.txt").header("X-Auth", "12Aa")) - .send(backend) - .map { r => - r.code shouldBe StatusCode.Ok - if (partOtherHeaderSupport) r.body should include regex "((?i)X-Auth):[ ]?Some\\(12Aa\\)" - r.body should include regex "name=\"data\"[\\s\\S]*oiram hcaep" - } - }, - testServer(in_file_list_multipart_out_multipart) { (mfu: MultipleFileUpload) => - val files = mfu.files.map { part => - Part( - part.name, - writeToFile(Await.result(readFromFile(part.body), 3.seconds) + " result"), - part.otherDispositionParams, - Nil - ).header("X-Auth", part.headers.find(_.is("X-Auth")).map(_.value + "x").getOrElse("")) - } - pureResult(MultipleFileUpload(files).asRight[Unit]) - } { (backend, baseUri) => - val file1 = writeToFile("peach mario 1") - val file2 = writeToFile("peach mario 2") - val file3 = writeToFile("peach mario 3") - basicStringRequest - .post(uri"$baseUri/api/echo/multipart") - .multipartBody( - multipartFile("files", file1).fileName("fruit-data-1.txt").header("X-Auth", "12Aa"), - multipartFile("files", file2).fileName("fruit-data-2.txt").header("X-Auth", "12Ab"), - multipartFile("files", file3).fileName("fruit-data-3.txt").header("X-Auth", "12Ac") - ) - .send(backend) - .map { r => - r.code shouldBe StatusCode.Ok - if (partOtherHeaderSupport) { - r.body should include regex "((?i)X-Auth):[ ]?12Aax" - r.body should include regex "((?i)X-Auth):[ ]?12Abx" - r.body should include regex "((?i)X-Auth):[ ]?12Acx" - } - r.body should include("peach mario 1 result") - r.body should include("peach mario 2 result") - r.body should include("peach mario 3 result") - } - }, testServer(in_raw_multipart_out_string)((parts: Seq[Part[Array[Byte]]]) => pureResult( parts.map(part => s"${part.name}:${new String(part.body)}").mkString("\n").asRight[Unit] @@ -184,10 +119,157 @@ class ServerMultipartTests[F[_], OPTIONS, ROUTE]( r.code shouldBe StatusCode.Ok r.body should be("firstPart:BODYONE\r\n--AA\n__\nsecondPart:BODYTWO") } + }, + testServer(in_file_multipart_out_string, "simple file multipart body")((fd: FruitData) => { + val content = Await.result(readFromFile(fd.data.body), 3.seconds) + pureResult(content.reverse.asRight[Unit]) + }) { (backend, baseUri) => + val file = writeToFile("peach2 mario2") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipartFile("data", file).fileName("fruit-data7.txt").header("X-Auth", "12Aa")) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body shouldBe "2oiram 2hcaep" + } + }, + testServer(in_file_multipart_out_string, "large file multipart body")((fd: FruitData) => { + val fileSize = fd.data.body.length() // FIXME is 0, because decoder.destroy() removes the file + pureResult(fileSize.toString.asRight[Unit]) + }) { (backend, baseUri) => + val file = File.createTempFile("test", "tapir") + file.deleteOnExit() + fs2.Stream + .constant[IO, Byte]('x') + .take(5 * 1024 * 1024) + .through(Files.forAsync[IO].writeAll(Path.fromNioPath(file.toPath))) + .compile + .drain >> + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipartFile("data", file).fileName("fruit-data8.txt")) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body shouldBe "5242880" + } + }, + testServer(in_file_multipart_out_string, "file from a multipart attribute")((fd: FruitData) => { + val content = Await.result(readFromFile(fd.data.body), 3.seconds) + pureResult(content.reverse.asRight[Unit]) + }) { (backend, baseUri) => + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipart("data", "peach3 mario3")) + .send(backend) + .map { r => + // r.code shouldBe StatusCode.Ok + r.body shouldBe "3oiram 3hcaep" + } } ) } + def chunkedMultipartTests() = List( + testServer(in_raw_multipart_out_string, "chunked multipart attribute")((parts: Seq[Part[Array[Byte]]]) => + pureResult( + parts.map(part => s"${part.name}:${new String(part.body)}").mkString("\n__\n").asRight[Unit] + ) + ) { (backend, baseUri) => + val testBody = "61\r\n--boundary123\r\n" + + "Content-Disposition: form-data; name=\"attr1\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\nValue1\r\n" + + "\r\n47\r\n--boundary123\r\n" + + "Content-Disposition: form-data; name=\"attr2\"\r\n" + + "\r\nPart1 of\r\n" + + "1E\r\n Attr2 Value\r\n" + + "--boundary123--\r\n\r\n" + + "0\r\n\r\n" + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .header("Content-Type", "multipart/form-data; boundary=boundary123") + .header("Transfer-Encoding", "chunked") + .body(testBody) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + r.body should be("attr1:Value1\n__\nattr2:Part1 of Attr2 Value") + } + } + ) + + def multipartResponsesTests() = List( + testServer(in_simple_multipart_out_multipart)((fa: FruitAmount) => + pureResult(FruitAmount(fa.fruit + " apple", fa.amount * 2).asRight[Unit]) + ) { (backend, baseUri) => + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipart("fruit", "pineapple"), multipart("amount", "120")) + .send(backend) + .map { r => + r.body should include regex "name=\"fruit\"[\\s\\S]*pineapple apple" + r.body should include regex "name=\"amount\"[\\s\\S]*240" + } + }, + testServer(in_file_multipart_out_multipart)((fd: FruitData) => + pureResult( + data + .FruitData( + Part("", writeToFile(Await.result(readFromFile(fd.data.body), 3.seconds).reverse), fd.data.otherDispositionParams, Nil) + .header("X-Auth", fd.data.headers.find(_.is("X-Auth")).map(_.value).toString) + ) + .asRight[Unit] + ) + ) { (backend, baseUri) => + val file = writeToFile("peach mario") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody(multipartFile("data", file).fileName("fruit-data.txt").header("X-Auth", "12Aa")) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + if (partOtherHeaderSupport) r.body should include regex "((?i)X-Auth):[ ]?Some\\(12Aa\\)" + r.body should include regex "name=\"data\"[\\s\\S]*oiram hcaep" + } + }, + testServer(in_file_list_multipart_out_multipart) { (mfu: MultipleFileUpload) => + val files = mfu.files.map { part => + Part( + part.name, + writeToFile(Await.result(readFromFile(part.body), 3.seconds) + " result"), + part.otherDispositionParams, + Nil + ).header("X-Auth", part.headers.find(_.is("X-Auth")).map(_.value + "x").getOrElse("")) + } + pureResult(MultipleFileUpload(files).asRight[Unit]) + } { (backend, baseUri) => + val file1 = writeToFile("peach mario 1") + val file2 = writeToFile("peach mario 2") + val file3 = writeToFile("peach mario 3") + basicStringRequest + .post(uri"$baseUri/api/echo/multipart") + .multipartBody( + multipartFile("files", file1).fileName("fruit-data-1.txt").header("X-Auth", "12Aa"), + multipartFile("files", file2).fileName("fruit-data-2.txt").header("X-Auth", "12Ab"), + multipartFile("files", file3).fileName("fruit-data-3.txt").header("X-Auth", "12Ac") + ) + .send(backend) + .map { r => + r.code shouldBe StatusCode.Ok + if (partOtherHeaderSupport) { + r.body should include regex "((?i)X-Auth):[ ]?12Aax" + r.body should include regex "((?i)X-Auth):[ ]?12Abx" + r.body should include regex "((?i)X-Auth):[ ]?12Acx" + } + r.body should include("peach mario 1 result") + r.body should include("peach mario 2 result") + r.body should include("peach mario 3 result") + } + } + ) + def contentTypeHeaderTests(): List[Test] = List( testServer(in_file_multipart_out_multipart, "with part content type header")((fd: FruitData) => pureResult( diff --git a/tests/src/main/scala/sttp/tapir/tests/Multipart.scala b/tests/src/main/scala/sttp/tapir/tests/Multipart.scala index 7d41320cd2..38cd060379 100644 --- a/tests/src/main/scala/sttp/tapir/tests/Multipart.scala +++ b/tests/src/main/scala/sttp/tapir/tests/Multipart.scala @@ -23,6 +23,9 @@ object Multipart { val in_file_multipart_out_multipart: PublicEndpoint[FruitData, Unit, FruitData, Any] = endpoint.post.in("api" / "echo" / "multipart").in(multipartBody[FruitData]).out(multipartBody[FruitData]).name("echo file") + val in_file_multipart_out_string: PublicEndpoint[FruitData, Unit, String, Any] = + endpoint.post.in("api" / "echo" / "multipart").in(multipartBody[FruitData]).out(stringBody) + val in_file_list_multipart_out_multipart: PublicEndpoint[MultipleFileUpload, Unit, MultipleFileUpload, Any] = endpoint.post .in("api" / "echo" / "multipart")