Skip to content

Commit 8359b2a

Browse files
authored
Feature/support non wrapped abstracts (#11)
* Add support for forwarding abstrat methods not wrapped by the algebra's kind * Detect when abstract return types are parameterised by the algebra's kind parameter and abort with a nice message. * Simplify extraction of supported annottees.
1 parent ce75227 commit 8359b2a

File tree

3 files changed

+140
-106
lines changed

3 files changed

+140
-106
lines changed

core/src/main/scala/diesel/internal/KTransImpl.scala

Lines changed: 123 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -89,35 +89,10 @@ object KTransImpl {
8989

9090
def build(): Defn.Def = {
9191
ensureSoundness()
92-
val forwardedAbstracts = abstracts.flatMap {
93-
case Decl.Val(mods, pats, Type.Apply(_, declTpeParams)) => {
94-
val newdeclTpe = Type.Apply(targetKType, declTpeParams)
95-
pats.map { pat =>
96-
val access = q"""$natTransArg.apply($currentTraitHandle.${pat.name})"""
97-
Defn.Val(mods, Seq(pat), Some(newdeclTpe), access)
98-
}
99-
}
100-
case origDef: Decl.Def => {
101-
/*
102-
First, "bump" type parameters by adding a TransK-related suffix.
103-
Bumping all of them across the board helps us avoid having to deal with
104-
possible collisions with our new, transformed Kind
105-
*/
106-
val defWithTransKedTParams = bumpTypeParamsToTransKed(origDef)
107-
val mods = defWithTransKedTParams.mods
108-
val name = defWithTransKedTParams.name
109-
val tparams = defWithTransKedTParams.tparams
110-
val paramss = defWithTransKedTParams.paramss
111-
val declTpe = defWithTransKedTParams.decltpe
112-
val tparamTypes = tparams.map(tp => Type.Name(tp.name.value))
113-
val paramNames = paramss.map(_.map(tp => Term.Name(tp.name.value)))
114-
val body =
115-
if (tparamTypes.nonEmpty)
116-
q"""$natTransArg.apply($currentTraitHandle.$name[..$tparamTypes](...$paramNames))"""
117-
else
118-
q"""$natTransArg($currentTraitHandle.$name(...$paramNames))"""
119-
Seq(Defn.Def(mods, name, tparams, paramss, Some(declTpe), body))
120-
}
92+
val forwardedAbstracts = forwardableAbstracts.flatMap {
93+
case decl: Decl.Val => toForwardedDefnVals(decl)
94+
case decl: Decl.Def => Seq(toForwardedDefnDef(decl))
95+
case _ => abort("Oh noes! You found a bug in the macro! Please file an issue :)")
12196
}
12297

12398
q"""
@@ -145,7 +120,7 @@ object KTransImpl {
145120
Nil
146121
}
147122

148-
val dslMembersSet = (abstracts: List[Stat]).toSet
123+
val dslMembersSet = (forwardableAbstracts: List[Stat]).toSet
149124
val concreteMembersSet = (concretes: List[Stat]).toSet
150125
// The spaces in multiline strings are significant
151126
val statsWithErrors = findErrors(
@@ -157,15 +132,15 @@ object KTransImpl {
157132
privateMembersPf(concreteMembersSet)),
158133
("Return types must be explicitly stated.", noReturnTypePf(concreteMembersSet)),
159134
("Abstract type members are not supported.", abstractType),
160-
(s"""The return type of this method is not wrapped in $tparamName[_].""".stripMargin,
135+
(s"""The return type of this method references $tparamName[_] as a type argument.""".stripMargin,
161136
nonMatchingKindPf(dslMembersSet ++ concreteMembersSet)),
162137
("Vars are not allowed.", varsPf(Set.empty)),
163138
("Vals that are not assignments are not allowed at the moment.",
164139
patternMatchingVals(concreteMembersSet)),
165140
(s"Type member shadows the algebra's kind $tparamName[_] (same name or otherwise points to it).",
166141
typeMemberPointsToKind),
167142
(s"""This method has a type parameter that shadows the $tparamName[_] used to annotate the trait.
168-
| Besides being confusing for readers of your code, this is not currently supported by diesel.""".stripMargin,
143+
| Besides being confusing for readers of your code, this is not currently supported.""".stripMargin,
169144
methodsShadowingTParamPF)
170145
)
171146
)
@@ -233,22 +208,22 @@ object KTransImpl {
233208
case v: Defn => v
234209
}.toList
235210

236-
private val abstracts: List[Decl] =
211+
private val forwardableAbstracts: List[Decl] =
237212
templateStatements.collect {
238-
case d @ Decl.Def(_, _, _, _, Type.Apply(retName: Type.Select, _))
239-
if retName.name.value == tparamName =>
240-
d
241-
case v @ Decl.Val(_, _, Type.Apply(retName: Type.Select, _))
242-
if retName.name.value == tparamName =>
243-
v
244-
case d @ Decl.Def(_, _, _, _, Type.Apply(retName: Type.Name, _))
245-
if retName.value == tparamName =>
246-
d
247-
case v @ Decl.Val(_, _, Type.Apply(retName: Type.Name, _))
248-
if retName.value == tparamName =>
249-
v
213+
case d: Decl.Def if algKindWrapped(d.decltpe) || !typeRefsAlgKind(d.decltpe) => d
214+
case v: Decl.Val if algKindWrapped(v.decltpe) || !typeRefsAlgKind(v.decltpe) => v
250215
}.toList
251216

217+
private def algKindWrapped(t: Type): Boolean = t match {
218+
case Type.Apply(retName: Type.Select, tArgs)
219+
if retName.name.value == tparamName && !tArgs.map(typeRefsAlgKind).exists(identity) =>
220+
true
221+
case Type.Apply(retName: Type.Name, tArgs)
222+
if retName.value == tparamName && !tArgs.map(typeRefsAlgKind).exists(identity) =>
223+
true
224+
case _ => false
225+
}
226+
252227
private def findErrors(
253228
msgsToPfs: Seq[(String, PartialFunction[Stat, Stat])]): Seq[(Stat, Seq[String])] = {
254229
templateStatements.foldLeft(Seq.empty[(Stat, Seq[String])]) {
@@ -340,24 +315,71 @@ object KTransImpl {
340315
v
341316
}
342317

343-
private def bumpTypeParamsToTransKed(meth: Decl.Def): Decl.Def = {
318+
private def toForwardedDefnVals(abstVal: Decl.Val): Seq[Defn.Val] = {
319+
val declTypeWrappedByAlgKind = algKindWrapped(abstVal.decltpe)
320+
val newdeclTpe =
321+
if (declTypeWrappedByAlgKind)
322+
suffixTypeNames(Set(tparamName))(abstVal.decltpe)
323+
else
324+
abstVal.decltpe
325+
abstVal.pats.map { pat =>
326+
val forwardingCall = q"""$currentTraitHandle.${pat.name}"""
327+
val body =
328+
if (declTypeWrappedByAlgKind)
329+
q"""$natTransArg.apply($forwardingCall)"""
330+
else
331+
forwardingCall
332+
Defn.Val(abstVal.mods, Seq(pat), Some(newdeclTpe), body)
333+
}
334+
}
335+
336+
private def toForwardedDefnDef(abstrMeth: Decl.Def): Defn.Def = {
337+
val declTypeWrappedByAlgKind = algKindWrapped(abstrMeth.decltpe)
338+
339+
val defWithTransKedTParams = addSuffixToTypeParams(abstrMeth)
340+
val mods = defWithTransKedTParams.mods
341+
val name = defWithTransKedTParams.name
342+
val tparams = defWithTransKedTParams.tparams
343+
val paramss = defWithTransKedTParams.paramss
344+
345+
// Do not use KTrans wrapped return type if the original type was not wrapped in the
346+
// algebra kind.
347+
val declTpe = defWithTransKedTParams.decltpe
348+
val tparamTypes = tparams.map(tp => Type.Name(tp.name.value))
349+
val paramNames = paramss.map(_.map(tp => Term.Name(tp.name.value)))
350+
val forwardingCall =
351+
if (tparamTypes.nonEmpty)
352+
q"""$currentTraitHandle.$name[..$tparamTypes](...$paramNames)"""
353+
else
354+
q"""$currentTraitHandle.$name(...$paramNames)"""
355+
val body =
356+
if (declTypeWrappedByAlgKind)
357+
q"""$natTransArg.apply($forwardingCall)"""
358+
else
359+
forwardingCall
360+
Defn.Def(mods, name, tparams, paramss, Some(declTpe), body)
361+
}
362+
363+
private def addSuffixToTypeParams(meth: Decl.Def): Decl.Def = {
344364
// Add on the Kind param of the original algebra because we want it to be properly suffixed
345365
// when referenced to in the methods of our new algebra implementation.
346-
val tParamsToBump = meth.tparams.map(_.name.value).toSet + tparamName
366+
val tParamsToBump = meth.tparams.map(_.name.value).toSet
347367

348368
def bumpTParam(tparam: Type.Param): Type.Param = {
369+
val nameStr: String = tparam.name.value
349370
val bumpedTparamName: Type.Param.Name =
350-
if (tParamsToBump.contains(tparam.name.value))
351-
Type.Name(s"${tparam.name.value}$transKTypeSuffix")
371+
if (tParamsToBump.contains(nameStr))
372+
transKSuffixed(Type.Name(nameStr))
352373
else
353374
tparam.name
354375
val bumpedTParamTParams = tparam.tparams.map(tp => bumpTParam(tp))
355376

356-
val bumpedCBounds = tparam.cbounds.map(cb => bumpType(cb))
357-
val bumpedVBounds = tparam.vbounds.map(vb => bumpType(vb))
377+
val bumpedCBounds = tparam.cbounds.map(suffixTypeNames(tParamsToBump))
378+
val bumpedVBounds = tparam.vbounds.map(suffixTypeNames(tParamsToBump))
358379
val bumpedTBounds = {
359380
val Type.Bounds(lo, hi) = tparam.tbounds
360-
Type.Bounds(lo.map(l => bumpType(l)), hi.map(h => bumpType(h)))
381+
Type.Bounds(lo.map(suffixTypeNames(tParamsToBump)),
382+
hi.map(suffixTypeNames(tParamsToBump)))
361383
}
362384
tparam.copy(name = bumpedTparamName,
363385
tparams = bumpedTParamTParams,
@@ -366,50 +388,18 @@ object KTransImpl {
366388
cbounds = bumpedCBounds)
367389
}
368390

369-
def bumpType(tpe: Type): Type = tpe match {
370-
case tName @ Type.Name(v) if tParamsToBump.contains(v) => transKSuffixed(tName)
371-
case tApply @ Type.Apply(tpeInner, args) =>
372-
tApply.copy(tpe = bumpType(tpeInner), args = args.map(a => bumpType(a)))
373-
case tApplyInfix @ Type.ApplyInfix(lhs, opTName @ Type.Name(op), rhs) => {
374-
val opBumped =
375-
if (tParamsToBump.contains(op))
376-
transKSuffixed(opTName)
377-
else
378-
Type.Name(op)
379-
tApplyInfix.copy(lhs = bumpType(lhs), op = opBumped, rhs = bumpType(rhs))
380-
}
381-
case tWith @ Type.With(lhs, rhs) =>
382-
tWith.copy(lhs = bumpType(lhs), rhs = bumpType(rhs))
383-
case Type.Placeholder(Type.Bounds(lo, hi)) =>
384-
Type.Placeholder(
385-
Type.Bounds(lo = lo.map(l => bumpType(l)), hi = hi.map(h => bumpType(h))))
386-
case Type.And(lhs, rhs) => Type.And(lhs = bumpType(lhs), rhs = bumpType(rhs))
387-
case Type.Or(lhs, rhs) => Type.Or(lhs = bumpType(lhs), rhs = bumpType(rhs))
388-
case typeAnnotate @ Type.Annotate(t, _) =>
389-
typeAnnotate.copy(tpe = bumpType(t))
390-
case typeExist @ Type.Existential(t, _) => typeExist.copy(tpe = bumpType(t))
391-
case typeFunc @ Type.Function(params, res) => {
392-
val bumpedParams = params.map(transformTArgType(bumpType))
393-
val bumpedRes = bumpType(res)
394-
typeFunc.copy(params = bumpedParams, res = bumpedRes)
395-
}
396-
case typeRefine @ Type.Refine(maybeTpe, _) =>
397-
typeRefine.copy(tpe = maybeTpe.map(t => bumpType(t)))
398-
case Type.Tuple(tpes) => Type.Tuple(tpes.map(t => bumpType(t)))
399-
case Type.Project(q, tName @ Type.Name(v)) if tParamsToBump.contains(v) =>
400-
Type.Project(qual = bumpType(q), name = transKSuffixed(tName))
401-
case Type.Select(r, tName @ Type.Name(v)) if tParamsToBump.contains(v) =>
402-
Type.Select(r, transKSuffixed(tName))
403-
case other => other // Singleton, I believe ... which can't point to method type params ?
404-
}
405-
406391
val newtParams = meth.tparams.map { tparam =>
407392
bumpTParam(tparam)
408393
}
409-
val newDeclTpe = bumpType(meth.decltpe)
394+
// If the algebra kind wraps the return type, then suffix the kind too (G[_] -> GTransK[_])
395+
val newDeclTpe =
396+
if (algKindWrapped(meth.decltpe))
397+
suffixTypeNames(tParamsToBump + tparamName)(meth.decltpe)
398+
else
399+
suffixTypeNames(tParamsToBump)(meth.decltpe)
410400
val newParamss = meth.paramss.map { params =>
411401
params.map { param =>
412-
val bumpedTArg = param.decltpe.map(transformTArgType(bumpType))
402+
val bumpedTArg = param.decltpe.map(transformTArgType(suffixTypeNames(tParamsToBump)))
413403
param.copy(decltpe = bumpedTArg)
414404
}
415405
}
@@ -453,6 +443,48 @@ object KTransImpl {
453443
case Type.Arg.ByName(tpe) => Type.Arg.ByName(f(tpe))
454444
case t: Type => f(t)
455445
}
446+
447+
// Adds a suffix to a type if it matches any of the names given
448+
def suffixTypeNames(tNamesToSuffix: Set[String])(tpe: Type): Type = {
449+
def suffixTypeInner(tpe: Type): Type = tpe match {
450+
case tName @ Type.Name(v) if tNamesToSuffix.contains(v) => transKSuffixed(tName)
451+
case tApply @ Type.Apply(tpeInner, args) =>
452+
tApply.copy(tpe = suffixTypeInner(tpeInner), args = args.map(a => suffixTypeInner(a)))
453+
case tApplyInfix @ Type.ApplyInfix(lhs, opTName @ Type.Name(op), rhs) => {
454+
val opBumped =
455+
if (tNamesToSuffix.contains(op))
456+
transKSuffixed(opTName)
457+
else
458+
Type.Name(op)
459+
tApplyInfix.copy(lhs = suffixTypeInner(lhs), op = opBumped, rhs = suffixTypeInner(rhs))
460+
}
461+
case tWith @ Type.With(lhs, rhs) =>
462+
tWith.copy(lhs = suffixTypeInner(lhs), rhs = suffixTypeInner(rhs))
463+
case Type.Placeholder(Type.Bounds(lo, hi)) =>
464+
Type.Placeholder(
465+
Type.Bounds(lo = lo.map(l => suffixTypeInner(l)),
466+
hi = hi.map(h => suffixTypeInner(h))))
467+
case Type.And(lhs, rhs) => Type.And(lhs = suffixTypeInner(lhs), rhs = suffixTypeInner(rhs))
468+
case Type.Or(lhs, rhs) => Type.Or(lhs = suffixTypeInner(lhs), rhs = suffixTypeInner(rhs))
469+
case typeAnnotate @ Type.Annotate(t, _) =>
470+
typeAnnotate.copy(tpe = suffixTypeInner(t))
471+
case typeExist @ Type.Existential(t, _) => typeExist.copy(tpe = suffixTypeInner(t))
472+
case typeFunc @ Type.Function(params, res) => {
473+
val bumpedParams = params.map(transformTArgType(suffixTypeInner))
474+
val bumpedRes = suffixTypeInner(res)
475+
typeFunc.copy(params = bumpedParams, res = bumpedRes)
476+
}
477+
case typeRefine @ Type.Refine(maybeTpe, _) =>
478+
typeRefine.copy(tpe = maybeTpe.map(t => suffixTypeInner(t)))
479+
case Type.Tuple(tpes) => Type.Tuple(tpes.map(t => suffixTypeInner(t)))
480+
case Type.Project(q, tName @ Type.Name(v)) if tNamesToSuffix.contains(v) =>
481+
Type.Project(qual = suffixTypeInner(q), name = transKSuffixed(tName))
482+
case Type.Select(r, tName @ Type.Name(v)) if tNamesToSuffix.contains(v) =>
483+
Type.Select(r, transKSuffixed(tName))
484+
case other => other // Singleton, I believe ... which can't point to method type params ?
485+
}
486+
suffixTypeInner(tpe)
487+
}
456488
}
457489

458490
}

core/src/main/scala/diesel/internal/SupportedAnnottee.scala

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,10 @@ trait SupportedAnnottee {
2626
case class TraitAnnottee(mods: Seq[Mod],
2727
tname: Type.Name,
2828
tparams: Seq[Type.Param],
29-
template: Template)
29+
template: Template,
30+
underlying: Stat)
3031
extends SupportedAnnottee {
3132

32-
def underlying: Stat = {
33-
q"..$mods trait $tname[..$tparams] extends $template"
34-
}
35-
3633
def appendStat(stat: Stat): Stat = {
3734
val newTempl =
3835
template.copy(stats = template.stats.map(s => s :+ stat).orElse(Some(Seq(stat))))
@@ -50,11 +47,9 @@ case class ClassAnnottee(mods: Seq[Mod],
5047
tname: Type.Name,
5148
tparams: Seq[Type.Param],
5249
ctor: Ctor.Primary,
53-
template: Template)
50+
template: Template,
51+
underlying: Stat)
5452
extends SupportedAnnottee {
55-
def underlying: Stat = {
56-
Defn.Class(mods, tname, tparams, ctor, template)
57-
}
5853

5954
def appendStat(stat: Stat): Stat = {
6055
val newTempl =
@@ -74,10 +69,10 @@ case class ClassAnnottee(mods: Seq[Mod],
7469
object SupportedAnnottee {
7570

7671
def unapply(tree: Tree): Option[SupportedAnnottee] = tree match {
77-
case q"..$mods trait $tname[..$tparams] extends $template" =>
78-
Some(TraitAnnottee(mods, tname, tparams, template))
79-
case Defn.Class(mods, tname, tparams, ctor, template) =>
80-
Some(ClassAnnottee(mods, tname, tparams, ctor, template))
72+
case stat @ Defn.Trait(mods, tname, tparams, _, template) =>
73+
Some(TraitAnnottee(mods, tname, tparams, template, stat))
74+
case stat @ Defn.Class(mods, tname, tparams, ctor, template) =>
75+
Some(ClassAnnottee(mods, tname, tparams, ctor, template, stat))
8176
case _ => None
8277
}
8378

core/src/test/scala/diesel/KtransAnnotationCompile.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package diesel
33
import cats._
44
import cats.kernel.Monoid
55

6+
import scala.util.Try
7+
68
object KtransAnnotationCompileTests {
79

810
@ktrans
@@ -11,11 +13,14 @@ object KtransAnnotationCompileTests {
1113
type Hey = Int
1214
val eh: Int = 3
1315

16+
def simpleAbstDefMeth(yo: Int): Option[Either[Boolean, Try[Seq[Double]]]]
17+
val simpleAbstValMeth: Option[Either[Boolean, Seq[Try[Double]]]]
18+
1419
def implementedDef(a: Int): Option[Boolean] = None
1520

1621
val valThing, valThing2: F[Int]
1722
protected val protValThing: F[Option[Int]]
18-
protected[diesel] val packProtValThing: F[Option[Int]]
23+
protected[diesel] val packProtValThing: F[Option[Try[Either[String,Int]]]]
1924
private[diesel] val packPrivValThing: F[Option[Int]]
2025

2126
def noArg: F[Int]
@@ -34,8 +39,10 @@ object KtransAnnotationCompileTests {
3439

3540
// The following should fail compilation if uncommented
3641
// private def wut: F[Byte]
37-
// def hmm(yo: F[Int]): Int
3842
// var eh = 9
43+
// def hmm(yo: F[Int]): Int
44+
// def hmmSuperNestedKParam(yo: Option[Either[Boolean, Try[Seq[F[Double]]]]]): Int
45+
// def hmmSuperNestedKRet(yo: Option[Either[Boolean, Try[Seq[Double]]]]): Option[Either[Boolean, Try[Seq[F[Double]]]]]
3946
// type Yo[A] = F[A]
4047
// protected def shadowGames[F[_]]: F[Int]
4148

0 commit comments

Comments
 (0)