diff --git a/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/Metarpheus.scala b/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/Metarpheus.scala index 11d5b43db..086831204 100644 --- a/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/Metarpheus.scala +++ b/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/Metarpheus.scala @@ -9,13 +9,17 @@ import scala.meta.internal.io.PlatformFileIO object Metarpheus { def run(paths: List[String], config: Config): intermediate.API = { - val files = paths - .flatMap(path => PlatformFileIO.listAllFilesRecursively(AbsolutePath(path))) - .filter(_.toString.endsWith(".scala")) - val parsed = files.map(File(_).parse[Source].get) + val parsed = parseFiles(paths) extractors .extractFullAPI(parsed = parsed) .stripUnusedModels(config.modelsForciblyInUse, config.discardRouteErrorModels) } + def parseFiles(paths: List[String]) : List[Source] = { + val files = paths + .flatMap(path => PlatformFileIO.listAllFilesRecursively(AbsolutePath(path))) + .filter(_.toString.endsWith(".scala")) + files.map(File(_).parse[Source].get) +} + } diff --git a/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/extractors/package.scala b/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/extractors/package.scala index a167ea95e..b74ca037d 100644 --- a/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/extractors/package.scala +++ b/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/extractors/package.scala @@ -19,6 +19,11 @@ package object extractors { intermediate.API(models, routes) } + def extractImports(source: scala.meta.Source): List[scala.meta.Import] = + source.collect { + case imp: scala.meta.Import => imp + } + /** * Extract all terms from a sequence of applications of an infix operator * (which translates to nested `ApplyInfix`es). diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala index 9039b7cd3..2f8dc32d3 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/AkkaHttpMeta.scala @@ -7,7 +7,7 @@ import scala.meta._ object AkkaHttpMeta { val `class` = ( `package`: Term.Ref, - imports: Set[Term.Ref], + imports: Set[Import], controllerName: Type.Name, tapirEndpointsName: Term.Name, authTokenName: Type.Name, @@ -21,7 +21,7 @@ object AkkaHttpMeta { val tapirEndpoints = q"val endpoints = $tapirEndpointsName.create[$authTokenName](statusCodes)" q""" package ${`package`} { - ..${imports.toList.sortWith(_.toString < _.toString).map(i => q"import $i._")} + ..${imports.toList.sortWith(_.toString < _.toString)} import akka.http.scaladsl.server._ import akka.http.scaladsl.server.Directives._ import io.circe.{ Decoder, Encoder } diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala index 715a6fa8a..0cd298f1a 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Http4sMeta.scala @@ -7,7 +7,7 @@ import scala.meta._ object Http4sMeta { val `class` = ( `package`: Term.Ref, - imports: Set[Term.Ref], + imports: Set[Import], controllerName: Type.Name, tapirEndpointsName: Term.Name, authTokenName: Type.Name, @@ -21,7 +21,7 @@ object Http4sMeta { val tapirEndpoints = q"val endpoints = $tapirEndpointsName.create[$authTokenName](statusCodes)" q""" package ${`package`} { - ..${imports.toList.sortWith(_.toString < _.toString).map(i => q"import $i._")} + ..${imports.toList.sortWith(_.toString < _.toString)} import cats.effect._ import cats.implicits._ import cats.data.NonEmptyList diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala index ce7513c88..a8e15dfce 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Meta.scala @@ -8,6 +8,29 @@ import scala.meta.contrib._ import cats.data.NonEmptyList object Meta { + + val controllerImports =(routes: List[TapiroRoute],imports: Set[Importer],outputPackage:String) => { + val types = routes.flatMap(tr => tr.route.params.map(r=> r.tpe) ++ (if (tr.route.error.isDefined) List(tr.route.error.get,tr.route.returns) else List(tr.route.returns))) + val typeNameList = types.flatMap(typeNameExtractor(_)).toSet + val (wildcardImportList,importList) = imports.collect{ + case i : Importer if i.syntax.endsWith("._") => List(i) + case Importer(ref,importees) => importees.filter(p=> typeNameList.contains(p.syntax)).map(p=> Importer(ref,List(p))) + }.flatten.partition(_.syntax.endsWith("._")) + val controllerPackageStrings = routes.map(r=> s"import ${r.route.controllerPackage.mkString(".")}._") + val controllerPackage : List[Import] =controllerPackageStrings.flatMap(_.parse[Source].getOrElse(Source(List())).tree.children).collect{case i: Import => i} + val importers = (if (importList.flatMap(i=> typeNameList.diff(i.importees.map(_.syntax).toSet)).isEmpty) importList else { + (wildcardImportList ++ importList) + }).map(i=>Import(List(i))) + val result= if (controllerPackageStrings.filter(_.endsWith(outputPackage+"._")).isEmpty) controllerPackage ++ importers else importers + deduplicate(result.toList).toSet + } : Set[Import] + + def typeNameExtractor (tpe : MetarpheusType) : Set[String]= + tpe match { + case MetarpheusType.Name(name) => Set(name) + case MetarpheusType.Apply(head,args) => (head :: args.map(typeNameExtractor(_)).flatten.toList).toSet + } + val codecsImplicits = (routes: List[TapiroRoute], authTokenName: String) => { val notUnit = (t: MetarpheusType) => t != MetarpheusType.Name("Unit") val toDecoder = (t: Type) => t"Decoder[${extractListType(t)}]" @@ -44,10 +67,10 @@ object Meta { case _ => t } - private[this] val deduplicate: List[Type] => List[Type] = (ts: List[Type]) => + private[this] def deduplicate[A<:Tree](ts: List[A]): List[A] = ts match { case Nil => Nil - case head :: tail => head :: deduplicate(tail.filter(!_.isEqual(head))) + case head :: tail => head :: deduplicate(tail.filter(!_.syntax.equals(head.syntax))) } private[this] val isAuthToken = (t: MetarpheusType, authTokenName: String) => t == MetarpheusType.Name(authTokenName) diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala index ee0395a78..af07426d6 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/TapirMeta.scala @@ -14,7 +14,7 @@ object TapirMeta { val `class` = ( `package`: Term.Ref, - imports: Set[Term.Ref], + imports: Set[Import], tapirEndpointsName: Term.Name, authTokenName: Type.Name, implicits: List[Term.Param], @@ -26,7 +26,7 @@ object TapirMeta { Type.Param(List(), authTokenName, List(), Type.Bounds(None, None), List(), List()) q""" package ${`package`} { - ..${imports.toList.sortWith(_.toString < _.toString).map(i => q"import $i._")} + ..${imports.toList.sortWith(_.toString < _.toString)} import io.circe.{ Decoder, Encoder } import io.circe.generic.semiauto.{ deriveDecoder, deriveEncoder } import sttp.tapir._ diff --git a/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala b/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala index 1748d553a..4acc85fab 100644 --- a/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala +++ b/tapiro/core/src/main/scala/io/buildo/tapiro/Util.scala @@ -1,9 +1,7 @@ package io.buildo.tapiro - +import io.buildo.metarpheus.core.extractors._ import io.buildo.metarpheus.core.{Config, Metarpheus} import io.buildo.metarpheus.core.intermediate.{ - CaseClass, - CaseEnum, Route, TaggedUnion, Type => MetarpheusType, @@ -55,7 +53,7 @@ class Util() { NonEmptyList.fromList(`package`) match { case Some(nonEmptyPackage) => val config = Config(Set.empty) - val models = Metarpheus.run(modelsPaths, config).models + val models = Metarpheus.run(modelsPaths, config).models //this is needed because Metarpheus removes auth from params when authentication type is "Auth" //see https://github.com/buildo/retro/blob/dfe62fa54d4f34c1861d694ac0cd8fa82f0a8703/metarpheus/core/src/main/scala/io.buildo.metarpheus/core/extractors/controller.scala#L35 val routesWithAuthParams: List[Route] = Metarpheus @@ -74,19 +72,13 @@ class Util() { ) else r, }) - val routes: List[TapiroRoute] = + val routes: List[TapiroRoute] = routesWithAuthParams.map(toTapiroRoute(models)) + val imports = Metarpheus.parseFiles(routesPaths).map(extractImports).flatten.map(_.importers).flatten.toSet val controllersRoutes = routes.groupBy( route => (route.route.controllerType, route.route.pathName), ) - val modelsPackages = models.map { - case c: CaseClass => c.`package` - case c: CaseEnum => c.`package` - case t: TaggedUnion => t.`package` - }.collect { - case head :: tail => NonEmptyList(head, tail) - } controllersRoutes.foreach { case ((controllerType, pathName), routes) => val controllerName = typeNameString(controllerType) @@ -94,21 +86,16 @@ class Util() { val pathNameOrController = pathName.getOrElse(controllerName) val tapirEndpointsName = s"${pathNameOrController}TapirEndpoints".capitalize val httpEndpointsName = s"${pathNameOrController}HttpEndpoints".capitalize + val controllerImports = Meta.controllerImports(routes,imports,nonEmptyPackage.toList.mkString(".")) val tapirEndpoints = createTapirEndpoints( tapirEndpointsName, authTypeString, routes, nonEmptyPackage, - modelsPackages, + controllerImports, ) writeToFile(outputPath, tapirEndpoints, tapirEndpointsName) - - val routesPackages = routes - .map(_.route.controllerPackage) - .collect { - case head :: tail => NonEmptyList(head, tail) - } server match { case Server.Http4s => val http4sEndpoints = @@ -119,7 +106,7 @@ class Util() { tapirEndpointsName, authTypeString, httpEndpointsName, - modelsPackages ++ routesPackages, + controllerImports, routes, ) http4sEndpoints.foreach(writeToFile(outputPath, _, httpEndpointsName)) @@ -132,7 +119,7 @@ class Util() { tapirEndpointsName, authTypeString, httpEndpointsName, - modelsPackages ++ routesPackages, + controllerImports, routes, ) akkaHttpEndpoints.foreach( @@ -150,12 +137,12 @@ class Util() { authTokenName: String, routes: List[TapiroRoute], `package`: NonEmptyList[String], - requiredPackages: List[NonEmptyList[String]], + requiredPackages: Set[Import], ): String = { format( TapirMeta.`class`( Meta.packageFromList(`package`), - requiredPackages.toSet.map(Meta.packageFromList), + requiredPackages, Term.Name(tapirEndpointsName), Type.Name(authTokenName), Meta.codecsImplicits(routes, authTokenName), @@ -173,7 +160,7 @@ class Util() { tapirEndpointsName: String, authTokenName: String, httpEndpointsName: String, - requiredPackages: List[NonEmptyList[String]], + requiredPackages: Set[Import], tapiroRoutes: List[TapiroRoute], ): Option[String] = { val routes = tapiroRoutes.map(_.route) @@ -184,7 +171,7 @@ class Util() { format( Http4sMeta.`class`( Meta.packageFromList(`package`), - requiredPackages.toSet.map(Meta.packageFromList), + requiredPackages, Type.Name(controllerName), Term.Name(tapirEndpointsName), Type.Name(authTokenName), @@ -206,7 +193,7 @@ class Util() { tapirEndpointsName: String, authTokenName: String, httpEndpointsName: String, - requiredPackages: List[NonEmptyList[String]], + requiredPackages: Set[Import], tapiroRoutes: List[TapiroRoute], ): Option[String] = { val routes = tapiroRoutes.map(_.route) @@ -217,7 +204,7 @@ class Util() { format( AkkaHttpMeta.`class`( Meta.packageFromList(`package`), - requiredPackages.toSet.map(Meta.packageFromList), + requiredPackages, Type.Name(controllerName), Term.Name(tapirEndpointsName), Type.Name(authTokenName),