Skip to content

Commit 38b4062

Browse files
authored
feat: linear-size Ord instance (#10270)
This PR adds an alternative implementation of `deriving Ord` that first compares the constructor indices (`.ctorIdx`) of the two values and then uses a dedicated matcher (added in #10152) to compare the fields of values built with the same constructor. The new option `deriving.ord.linear_construction_threshold` sets the minimum number of constructors (10 by default) at which the new construction is used. The PR also (unconditionally) changes the implementation for enumeration types to simply compare the `ctorIdx`.
1 parent ae8dc41 commit 38b4062

File tree

3 files changed

+90
-9
lines changed

3 files changed

+90
-9
lines changed

src/Lean/DocString/Types.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ prelude
1010

1111
public import Init.Data.Repr
1212
public import Init.Data.Ord
13+
import Init.Data.Nat.Compare
1314

1415
set_option linter.missingDocs true
1516

src/Lean/Elab/Deriving/Ord.lean

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,21 @@ Authors: Dany Fabian
66
module
77

88
prelude
9-
public import Lean.Meta.Transform
10-
public import Lean.Elab.Deriving.Basic
11-
public import Lean.Elab.Deriving.Util
9+
public import Lean.Data.Options
10+
import Lean.Meta.Transform
11+
import Lean.Elab.Deriving.Basic
12+
import Lean.Elab.Deriving.Util
13+
import Lean.Meta.Constructions.CtorIdx
14+
import Lean.Meta.Constructions.CasesOnSameCtor
1215
import Lean.Meta.SameCtorUtils
1316

14-
public section
17+
-- Option consulted by `mkMatch`: when the inductive type has at least this many constructors,
-- `mkMatchNew` is used instead of `mkMatchOld`. The default value is 10.
register_builtin_option deriving.ord.linear_construction_threshold : Nat := {
18+
defValue := 10
19+
descr := "If the inductive data type has this many or more constructors, use a different \
20+
implementation for implementing `Ord` that avoids the quadratic code size produced by the \
21+
default implementation.\n\n\
22+
The alternative construction compiles to less efficient code in some cases, so by default \
23+
it is only used for inductive types with 10 or more constructors." }
1524

1625
namespace Lean.Elab.Deriving.Ord
1726
open Lean.Parser.Term
@@ -20,7 +29,7 @@ open Meta
2029
/-- Builds the `Header` used by the derived `Ord` function for `indVal` by calling `mkHeader`
with the class name `Ord` and the argument `2` (presumably the number of values being
compared; `mkMatchNew` asserts that the header has exactly two target names). -/
def mkOrdHeader (indVal : InductiveVal) : TermElabM Header := do
2130
mkHeader `Ord 2 indVal
2231

23-
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
32+
def mkMatchOld (header : Header) (indVal : InductiveVal) : TermElabM Term := do
2433
let discrs ← mkDiscrs header indVal
2534
let alts ← mkAlts
2635
`(match $[$discrs],* with $alts:matchAlt*)
@@ -74,6 +83,59 @@ where
7483
alts := alts ++ (alt : Array (TSyntax ``matchAlt))
7584
return alts.pop.pop
7685

86+
/-- Builds the body term of the derived comparison function using the construction based on
`ctorIdx` and a same-constructor matcher.

A matcher is requested from `mkCasesOnSameCtor` under a fresh name, and one alternative is
generated per constructor of `indVal`. Each alternative chains `compare` on the constructor
fields with `Ordering.then`, ending in `Ordering.eq`. If the type has more than one
constructor, the resulting term first compares the `ctorIdx` of both targets and only invokes
the matcher when that comparison yields `Ordering.eq`. -/
def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do
87+
-- The two target names of the header are the two values being compared.
assert! header.targetNames.size == 2
88+
89+
let x1 := mkIdent header.targetNames[0]!
90+
let x2 := mkIdent header.targetNames[1]!
91+
let ctorIdxName := mkCtorIdxName indVal.name
92+
-- NB: the getMatcherInfo? assumes all matchers are called `match_`
93+
let casesOnSameCtorName ← mkFreshUserName (indVal.name ++ `match_on_same_ctor)
94+
mkCasesOnSameCtor casesOnSameCtorName indVal.name
95+
-- One alternative (a `fun` over the constructor fields) per constructor, in constructor order.
let alts ← Array.ofFnM (n := indVal.numCtors) fun ⟨ctorIdx, _⟩ => do
96+
let ctorName := indVal.ctors[ctorIdx]!
97+
let ctorInfo ← getConstInfoCtor ctorName
98+
forallTelescopeReducing ctorInfo.type fun xs type => do
99+
let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
100+
-- Binders for the fields of the first and of the second value, respectively.
let mut ctorArgs1 : Array Term := #[]
101+
let mut ctorArgs2 : Array Term := #[]
102+
103+
-- Continuation that wraps a result term in the field comparisons collected so far.
-- Each new field is wrapped inside the previous continuation, so earlier fields end up
-- outermost and are therefore compared first.
let mut rhsCont : Term → TermElabM Term := fun rhs => pure rhs
104+
for i in *...ctorInfo.numFields do
105+
-- `xs` starts with the type parameters; the fields follow after `indVal.numParams`.
let x := xs[indVal.numParams + i]!
106+
if occursOrInType (← getLCtx) x type then
107+
-- If resulting type depends on this field, we don't need to compare
108+
-- and the casesOnSameCtor only has a parameter for it once
109+
ctorArgs1 := ctorArgs1.push (← `(_))
110+
else
111+
let userName ← x.fvarId!.getUserName
112+
let a := mkIdent (← mkFreshUserName userName)
113+
let b := mkIdent (← mkFreshUserName (userName.appendAfter "'"))
114+
ctorArgs1 := ctorArgs1.push a
115+
ctorArgs2 := ctorArgs2.push b
116+
let xType ← inferType x
117+
-- Fields whose type is a proposition are bound but not compared.
if (← isProp xType) then
118+
continue
119+
else
120+
rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont
121+
let rhs ← rhsCont (← `(Ordering.eq))
122+
`(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
123+
-- With a single constructor there is no index comparison; `rfl` is passed in the position
-- where the other branch passes the proof obtained from `h`.
if indVal.numCtors == 1 then
124+
`( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
125+
else
126+
`( match h : compare ($(mkCIdent ctorIdxName) $x1:ident) ($(mkCIdent ctorIdxName) $x2:ident) with
127+
| Ordering.lt => Ordering.lt
128+
| Ordering.gt => Ordering.gt
129+
| Ordering.eq =>
130+
$(mkCIdent casesOnSameCtorName) $x1:term $x2:term (Nat.compare_eq_eq.mp h) $alts:term*
131+
)
132+
133+
/-- Chooses the construction for the comparison body: `mkMatchNew` when the number of
constructors of `indVal` is at least the value of the option
`deriving.ord.linear_construction_threshold`, and `mkMatchOld` otherwise. -/
def mkMatch (header : Header) (indVal : InductiveVal) : TermElabM Term := do
134+
if indVal.numCtors ≥ deriving.ord.linear_construction_threshold.get (← getOptions) then
135+
mkMatchNew header indVal
136+
else
137+
mkMatchOld header indVal
138+
77139
def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
78140
let auxFunName := ctx.auxFunNames[i]!
79141
let indVal := ctx.typeInfos[i]!
@@ -105,13 +167,31 @@ private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
105167
trace[Elab.Deriving.ord] "\n{cmds}"
106168
return cmds
107169

170+
/-- Generates the auxiliary comparison function for the enumeration type `name`: a `def`
named after the first entry of `ctx.auxFunNames` that takes two values `x y` of that type
and returns `compare x.ctorIdx y.ctorIdx`. -/
private def mkOrdEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
171+
let auxFunName := ctx.auxFunNames[0]!
172+
`(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Ordering := compare x.ctorIdx y.ctorIdx)
173+
174+
/-- Builds the commands that derive `Ord` for the enumeration type `name`: the auxiliary
function produced by `mkOrdEnumFun`, followed by the commands returned by `mkInstanceCmds`.
The resulting commands are also written to the `Elab.Deriving.ord` trace class. -/
private def mkOrdEnumCmd (name : Name): TermElabM (Array Syntax) := do
175+
let ctx ← mkContext ``Ord "ord" name
176+
let cmds := #[← mkOrdEnumFun ctx name] ++ (← mkInstanceCmds ctx `Ord #[name])
177+
trace[Elab.Deriving.ord] "\n{cmds}"
178+
return cmds
179+
108180
open Command
109181

182+
/-- Derives `Ord` for the single declaration `declName`. The commands come from
`mkOrdEnumCmd` when `isEnumType declName` holds and from `mkOrdInstanceCmds` otherwise;
each command is then elaborated with `elabCommand`. The work is wrapped in
`withoutExposeFromCtors declName`. -/
def mkOrdInstance (declName : Name) : CommandElabM Unit := do
183+
withoutExposeFromCtors declName do
184+
let cmds ← liftTermElabM <|
185+
-- Enumeration types always use the `ctorIdx`-based comparison.
if (← isEnumType declName) then
186+
mkOrdEnumCmd declName
187+
else
188+
mkOrdInstanceCmds declName
189+
cmds.forM elabCommand
190+
110191
/-- Deriving handler entry point for `Ord`. If every name in `declNames` satisfies
`isInductive`, an instance is derived for each name in order and `true` is returned;
otherwise nothing is derived and `false` is returned. -/
def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
111192
if (← declNames.allM isInductive) then
112193
for declName in declNames do
113-
let cmds ← withoutExposeFromCtors declName <| liftTermElabM <| mkOrdInstanceCmds declName
114-
cmds.forM elabCommand
194+
mkOrdInstance declName
115195
return true
116196
else
117197
return false

tests/lean/run/Ord.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ inductive ManyConstructors | A | B | C | D | E | F | G | H | I | J | K | L
1919
| M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
2020
deriving Ord
2121

22-
structure Person :=
22+
structure Person where
2323
firstName : String
2424
lastName : String
2525
age : Nat
2626
deriving Ord
2727

2828
example : compare { firstName := "A", lastName := "B", age := 10 : Person } ⟨"B", "A", 9⟩ = Ordering.lt := rfl
2929

30-
structure Company :=
30+
structure Company where
3131
name : String
3232
ceo : Person
3333
numberOfEmployees : Nat

0 commit comments

Comments
 (0)