|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Cameron Zwarich |
| 5 | +-/ |
| 6 | +prelude |
| 7 | +import Lean.Compiler.LCNF.Basic |
| 8 | +import Lean.Compiler.LCNF.InferType |
| 9 | +import Lean.Compiler.LCNF.MonoTypes |
| 10 | +import Lean.Compiler.LCNF.PassManager |
| 11 | +import Lean.Compiler.LCNF.PrettyPrinter |
| 12 | + |
| 13 | +namespace Lean.Compiler.LCNF |
| 14 | +namespace StructProjCases |
| 15 | + |
| 16 | +def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do |
| 17 | + let .inductInfo info ← getConstInfo typeName | return none |
| 18 | + let [ctorName] := info.ctors | return none |
| 19 | + let some (.ctorInfo ctorInfo) := (← getEnv).find? ctorName | return none |
| 20 | + return ctorInfo |
| 21 | + |
| 22 | +def mkFieldParamsForCtorType (e : Expr) (numParams : Nat): CompilerM (Array Param) := do |
| 23 | + let rec loop (params : Array Param) (e : Expr) (numParams : Nat): CompilerM (Array Param) := do |
| 24 | + match e with |
| 25 | + | .forallE name type body _ => |
| 26 | + if numParams == 0 then |
| 27 | + let param ← mkParam name (← toMonoType type) false |
| 28 | + loop (params.push param) body numParams |
| 29 | + else |
| 30 | + loop params body (numParams - 1) |
| 31 | + | _ => return params |
| 32 | + loop #[] e numParams |
| 33 | + |
| 34 | +structure StructProjState where |
| 35 | + projMap : Std.HashMap FVarId (Array FVarId) := {} |
| 36 | + fvarMap : Std.HashMap FVarId FVarId := {} |
| 37 | + |
| 38 | +abbrev M := StateRefT StructProjState CompilerM |
| 39 | + |
| 40 | +def M.run (x : M α) : CompilerM α := do |
| 41 | + x.run' {} |
| 42 | + |
| 43 | +def remapFVar (fvarId : FVarId) : M FVarId := do |
| 44 | + match (← get).fvarMap[fvarId]? with |
| 45 | + | some newFvarId => return newFvarId |
| 46 | + | none => return fvarId |
| 47 | + |
| 48 | +mutual |
| 49 | + |
| 50 | +partial def visitCode (code : Code) : M Code := do |
| 51 | + match code with |
| 52 | + | .let decl k => |
| 53 | + match decl.value with |
| 54 | + | .proj typeName i base => |
| 55 | + eraseLetDecl decl |
| 56 | + let base ← remapFVar base |
| 57 | + if let some projVars := (← get).projMap[base]? then |
| 58 | + modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId projVars[i]! } |
| 59 | + visitCode k |
| 60 | + else |
| 61 | + let some ctorInfo ← findStructCtorInfo? typeName | panic! "expected struct constructor" |
| 62 | + let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams |
| 63 | + assert! params.size == ctorInfo.numFields |
| 64 | + let fvars := params.map (·.fvarId) |
| 65 | + modify fun s => { s with projMap := s.projMap.insert base fvars, |
| 66 | + fvarMap := s.fvarMap.insert decl.fvarId fvars[i]! } |
| 67 | + let k ← visitCode k |
| 68 | + modify fun s => { s with projMap := s.projMap.erase base } |
| 69 | + let resultType ← toMonoType (← k.inferType) |
| 70 | + let alts := #[.alt ctorInfo.name params k] |
| 71 | + return .cases { typeName, resultType, discr := base, alts } |
| 72 | + | _ => return code.updateLet! (← decl.updateValue (← visitLetValue decl.value)) (← visitCode k) |
| 73 | + | .fun decl k => |
| 74 | + let decl ← decl.updateValue (← visitCode decl.value) |
| 75 | + return code.updateFun! decl (← visitCode k) |
| 76 | + | .jp decl k => |
| 77 | + let decl ← decl.updateValue (← visitCode decl.value) |
| 78 | + return code.updateFun! decl (← visitCode k) |
| 79 | + | .jmp fvarId args => |
| 80 | + return code.updateJmp! (← remapFVar fvarId) (← args.mapM visitArg) |
| 81 | + | .cases cases => |
| 82 | + let discr ← remapFVar cases.discr |
| 83 | + if let #[.alt ctorName params k] := cases.alts then |
| 84 | + if let some projVars := (← get).projMap[discr]? then |
| 85 | + assert! projVars.size == params.size |
| 86 | + for param in params, projVar in projVars do |
| 87 | + modify fun s => { s with fvarMap := s.fvarMap.insert param.fvarId projVar } |
| 88 | + eraseParam param |
| 89 | + visitCode k |
| 90 | + else |
| 91 | + let fvars := params.map (·.fvarId) |
| 92 | + modify fun s => { s with projMap := s.projMap.insert discr fvars } |
| 93 | + let k ← visitCode k |
| 94 | + modify fun s => { s with projMap := s.projMap.erase discr } |
| 95 | + -- TODO: This should preserve the .alt allocation, but binding it to |
| 96 | + -- a variable above while also destructuring an array doesn't work. |
| 97 | + return code.updateCases! cases.resultType discr #[.alt ctorName params k] |
| 98 | + else |
| 99 | + let alts ← cases.alts.mapM (visitAlt ·) |
| 100 | + return code.updateCases! cases.resultType discr alts |
| 101 | + | .return fvarId => return code.updateReturn! (← remapFVar fvarId) |
| 102 | + | .unreach .. => return code |
| 103 | + |
| 104 | +partial def visitLetValue (v : LetValue) : M LetValue := do |
| 105 | + match v with |
| 106 | + | .const _ _ args => |
| 107 | + return v.updateArgs! (← args.mapM visitArg) |
| 108 | + | .fvar fvarId args => |
| 109 | + return v.updateFVar! (← remapFVar fvarId) (← args.mapM visitArg) |
| 110 | + | .value _ | .erased => return v |
| 111 | + -- Projections should be handled directly by `visitCode`. |
| 112 | + | .proj .. => unreachable! |
| 113 | + |
| 114 | +partial def visitAlt (alt : Alt) : M Alt := do |
| 115 | + return alt.updateCode (← visitCode alt.getCode) |
| 116 | + |
| 117 | +partial def visitArg (arg : Arg) : M Arg := |
| 118 | + match arg with |
| 119 | + | .fvar fvarId => return arg.updateFVar! (← remapFVar fvarId) |
| 120 | + | .type _ | .erased => return arg |
| 121 | + |
| 122 | +end |
| 123 | + |
| 124 | +def visitDecl (decl : Decl) : M Decl := do |
| 125 | + let value ← decl.value.mapCodeM (visitCode ·) |
| 126 | + return { decl with value } |
| 127 | + |
| 128 | +end StructProjCases |
| 129 | + |
| 130 | +def structProjCases : Pass := |
| 131 | + .mkPerDeclaration `structProjCases (StructProjCases.visitDecl · |>.run) .mono |
| 132 | + |
| 133 | +end Lean.Compiler.LCNF |
0 commit comments