Skip to content

Commit 775f8fd

Browse files
committed
Fix memoization on high-arity functions and extension functions
1 parent 429ed2d commit 775f8fd

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

compiler-plugin/src/main/kotlin/Memoizer.kt

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,30 @@ class Memoizer(private val context: IrPluginContext) : IrElementTransformerVoid(
4242
Name.identifier("mutableMapOf")
4343
)).single { it.owner.valueParameters.isEmpty() }
4444

45+
private val listOf = context.referenceFunctions(CallableId(
46+
FqName("kotlin.collections"),
47+
null,
48+
Name.identifier("listOf")
49+
)).single { it.owner.valueParameters.singleOrNull()?.isVararg == true }
50+
4551
private val pair = context.referenceClass(ClassId(FqName("kotlin"), FqName("Pair"), false))!!
4652
private val triple = context.referenceClass(ClassId(FqName("kotlin"), FqName("Triple"), false))!!
4753

4854
private val memoizeAnnotation = FqName("com.sschr15.aoc.annotations.Memoize")
4955

50-
private fun IrPluginContext.keyFor(declaration: IrFunction): IrType = when (declaration.valueParameters.size) {
51-
1 -> declaration.valueParameters.single().type
52-
2 -> pair.typeWith(declaration.valueParameters.map { it.type })
53-
3 -> triple.typeWith(declaration.valueParameters.map { it.type })
54-
else -> irBuiltIns.arrayClass.typeWith(irBuiltIns.anyType)
56+
private fun IrPluginContext.keyFor(params: List<IrValueParameter>): IrType = when (params.size) {
57+
1 -> params.single().type
58+
2 -> pair.typeWith(params.map { it.type })
59+
3 -> triple.typeWith(params.map { it.type })
60+
else -> irBuiltIns.listClass.typeWith(irBuiltIns.anyType)
61+
}
62+
63+
fun IrPluginContext.keyFor(declaration: IrFunction): IrType {
64+
val params = declaration.valueParameters.toMutableList()
65+
if (declaration.extensionReceiverParameter != null) {
66+
params.add(0, declaration.extensionReceiverParameter!!)
67+
}
68+
return keyFor(params)
5569
}
5670

5771
override fun visitFunction(declaration: IrFunction): IrStatement {
@@ -137,20 +151,28 @@ class Memoizer(private val context: IrPluginContext) : IrElementTransformerVoid(
137151
return declaration
138152
}
139153

154+
fun IrBuilderWithScope.createKeyFor(declaration: IrFunction): IrExpression {
155+
val params = declaration.valueParameters.toMutableList()
156+
if (declaration.extensionReceiverParameter != null) {
157+
params.add(0, declaration.extensionReceiverParameter!!)
158+
}
159+
return createKeyFor(params)
160+
}
161+
140162
@OptIn(UnsafeDuringIrConstructionAPI::class)
141-
private fun IrBuilderWithScope.createKeyFor(function: IrFunction): IrExpression = when (function.valueParameters.size) {
142-
1 -> irGet(function.valueParameters.single())
163+
private fun IrBuilderWithScope.createKeyFor(params: List<IrValueParameter>): IrExpression = when (params.size) {
164+
1 -> irGet(params.single())
143165
2 -> irCall(pair.constructors.single()).apply {
144-
putValueArgument(0, irGet(function.valueParameters[0]))
145-
putValueArgument(1, irGet(function.valueParameters[1]))
166+
putValueArgument(0, irGet(params[0]))
167+
putValueArgument(1, irGet(params[1]))
146168
}
147169
3 -> irCall(triple.constructors.single()).apply {
148-
putValueArgument(0, irGet(function.valueParameters[0]))
149-
putValueArgument(1, irGet(function.valueParameters[1]))
150-
putValueArgument(2, irGet(function.valueParameters[2]))
170+
putValueArgument(0, irGet(params[0]))
171+
putValueArgument(1, irGet(params[1]))
172+
putValueArgument(2, irGet(params[2]))
151173
}
152-
else -> irCall(context.irBuiltIns.arrayOf).apply {
153-
function.valueParameters.forEachIndexed { index, parameter ->
174+
else -> irCall(listOf).apply {
175+
params.forEachIndexed { index, parameter ->
154176
putValueArgument(index, irGet(parameter))
155177
}
156178
}

0 commit comments

Comments
 (0)