@@ -6,12 +6,21 @@ Authors: Dany Fabian
66module
77
88prelude
9- public import Lean.Meta.Transform
10- public import Lean.Elab.Deriving.Basic
11- public import Lean.Elab.Deriving.Util
9+ public import Lean.Data.Options
10+ import Lean.Meta.Transform
11+ import Lean.Elab.Deriving.Basic
12+ import Lean.Elab.Deriving.Util
13+ import Lean.Meta.Constructions.CtorIdx
14+ import Lean.Meta.Constructions.CasesOnSameCtor
1215import Lean.Meta.SameCtorUtils
1316
14- public section
17+ register_builtin_option deriving.ord.linear_construction_threshold : Nat := {
18+ defValue := 10
19+ descr := "If the inductive data type has this many or more constructors, use a different \
20+ implementation for implementing `Ord` that avoids the quadratic code size produced by the \
21+ default implementation.\n\n \
22+ The alternative construction compiles to less efficient code in some cases, so by default \
23+ it is only used for inductive types with 10 or more constructors." }
1524
1625namespace Lean.Elab.Deriving.Ord
1726open Lean.Parser.Term
@@ -20,7 +29,7 @@ open Meta
2029def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do
2130 mkHeader `Ord 2 indVal
2231
23- def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
32+ def mkMatchOld (header : Header) (indVal : InductiveVal) : TermElabM Term := do
2433 let discrs ← mkDiscrs header indVal
2534 let alts ← mkAlts
2635 `(match $[$discrs],* with $alts:matchAlt*)
7483 alts := alts ++ (alt : Array (TSyntax ``matchAlt))
7584 return alts.pop.pop
7685
86+ def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do
87+ assert! header.targetNames.size == 2
88+
89+ let x1 := mkIdent header.targetNames[0 ]!
90+ let x2 := mkIdent header.targetNames[1 ]!
91+ let ctorIdxName := mkCtorIdxName indVal.name
92+ -- NB: the getMatcherInfo? assumes all mathcers are called `match_`
93+ let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor)
94+ mkCasesOnSameCtor casesOnSameCtorName indVal.name
95+ let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do
96+ let ctorName := indVal.ctors[ctorIdx]!
97+ let ctorInfo ← getConstInfoCtor ctorName
98+ forallTelescopeReducing ctorInfo.type fun xs type => do
99+ let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
100+ let mut ctorArgs1 : Array Term := #[]
101+ let mut ctorArgs2 : Array Term := #[]
102+
103+ let mut rhsCont : Term → TermElabM Term := fun rhs => pure rhs
104+ for i in *...ctorInfo.numFields do
105+ let x := xs[indVal.numParams + i]!
106+ if occursOrInType (← getLCtx) x type then
107+ -- If resulting type depends on this field, we don't need to compare
108+ -- and the casesOnSameCtor only has a parameter for it once
109+ ctorArgs1 := ctorArgs1.push (← `(_))
110+ else
111+ let userName ← x.fvarId!.getUserName
112+ let a := mkIdent (← mkFreshUserName userName)
113+ let b := mkIdent (← mkFreshUserName (userName.appendAfter "'" ))
114+ ctorArgs1 := ctorArgs1.push a
115+ ctorArgs2 := ctorArgs2.push b
116+ let xType ← inferType x
117+ if (← isProp xType) then
118+ continue
119+ else
120+ rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont
121+ let rhs ← rhsCont (← `(Ordering.eq))
122+ `(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
123+ if indVal.numCtors == 1 then
124+ `( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
125+ else
126+ `( match h : compare ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with
127+ | Ordering.lt => Ordering.lt
128+ | Ordering.gt => Ordering.gt
129+ | Ordering.eq =>
130+ $(mkCIdent casesOnSameCtorName) $x1:term $x2:term (Nat.compare_eq_eq.mp h) $alts:term*
131+ )
132+
133+ def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
134+ if indVal.numCtors ≥ deriving.ord.linear_construction_threshold.get (← getOptions) then
135+ mkMatchNew header indVal
136+ else
137+ mkMatchOld header indVal
138+
77139def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
78140 let auxFunName := ctx.auxFunNames[i]!
79141 let indVal := ctx.typeInfos[i]!
@@ -105,13 +167,31 @@ private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
105167 trace[Elab.Deriving.ord] "\n {cmds}"
106168 return cmds
107169
170+ private def mkOrdEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
171+ let auxFunName := ctx.auxFunNames[0 ]!
172+ `(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Ordering := compare x.ctorIdx y.ctorIdx)
173+
174+ private def mkOrdEnumCmd (name : Name): TermElabM (Array Syntax) := do
175+ let ctx ← mkContext ``Ord "ord" name
176+ let cmds := #[← mkOrdEnumFun ctx name] ++ (← mkInstanceCmds ctx `Ord #[name])
177+ trace[Elab.Deriving.ord] "\n {cmds}"
178+ return cmds
179+
108180open Command
109181
182+ def mkOrdInstance (declName : Name) : CommandElabM Unit := do
183+ withoutExposeFromCtors declName do
184+ let cmds ← liftTermElabM <|
185+ if (← isEnumType declName) then
186+ mkOrdEnumCmd declName
187+ else
188+ mkOrdInstanceCmds declName
189+ cmds.forM elabCommand
190+
110191def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
111192 if (← declNames.allM isInductive) then
112193 for declName in declNames do
113- let cmds ← withoutExposeFromCtors declName <| liftTermElabM <| mkOrdInstanceCmds declName
114- cmds.forM elabCommand
194+ mkOrdInstance declName
115195 return true
116196 else
117197 return false
0 commit comments