@@ -131,10 +131,11 @@ structure State where
131131 `transDeps[i]` is the (non-reflexive) transitive closure of `mods[i].imports`. More specifically,
132132 * `j ∈ transDeps[i].pub` if `i -(public import)->+ j`
133133 * `j ∈ transDeps[i].priv` if `i -(import ...)-> _ -(public import)->* j`
134- * `j ∈ transDeps[i].priv` if `i -(import all)->+ -(public import ...)-> _ -(public import)->* j`
135- * `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import ...)->* j`
136- * `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import ...)->* j`
137- * `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ -(public meta import ...)-> _ -(public (meta)? import ...)->* j`
134+ * `j ∈ transDeps[i].priv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].pub/priv`
135+ * `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import)->* j`
136+ * `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import)->* j`
137+ * `j ∈ transDeps[i].metaPriv` if `i -(import ...)-> i'` and `j ∈ transDeps[i'].metaPub`
138+ * `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].metaPub/metaPriv`
138139 -/
139140 transDeps : Array Needs := #[]
140141 /--
@@ -162,21 +163,24 @@ def addTransitiveImps (transImps : Needs) (imp : Import) (j : Nat) (impTransImps
162163 -- `j ∈ transDeps[i].priv` if `i -(import ...)-> _ -(public import)->* j`
163164 transImps := transImps.union .priv {j} |>.union .priv (impTransImps.get .pub)
164165 if imp.importAll then
165- -- `j ∈ transDeps[i].priv` if `i -(import all)->+ -(public import ...)-> _ -(public import)->* j `
166- transImps := transImps.union .priv (impTransImps.get .pub)
166+ -- `j ∈ transDeps[i].priv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].pub/priv `
167+ transImps := transImps.union .priv (impTransImps.get .pub ∪ impTransImps.get .priv )
167168
168- -- `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import ... )->* j`
169+ -- `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import)->* j`
169170 if imp.isExported then
170171 transImps := transImps.union .metaPub (impTransImps.get .metaPub)
171172 if imp.isMeta then
172173 transImps := transImps.union .metaPub {j} |>.union .metaPub (impTransImps.get .pub ∪ impTransImps.get .metaPub)
173174
174175 if !imp.isExported then
175176 if imp.isMeta then
176- -- `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import ... )->* j`
177+ -- `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import)->* j`
177178 transImps := transImps.union .metaPriv {j} |>.union .metaPriv (impTransImps.get .pub ∪ impTransImps.get .metaPub)
178179 if imp.importAll then
179- -- `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ -(public meta import ...)-> _ -(public (meta)? import ...)->* j`
180+ -- `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].metaPub/metaPriv`
181+ transImps := transImps.union .metaPriv (impTransImps.get .metaPub ∪ impTransImps.get .metaPriv)
182+ else
183+ -- `j ∈ transDeps[i].metaPriv` if `i -(import ...)-> i'` and `j ∈ transDeps[i'].metaPub`
180184 transImps := transImps.union .metaPriv (impTransImps.get .metaPub)
181185
182186 transImps
@@ -185,7 +189,8 @@ def addTransitiveImps (transImps : Needs) (imp : Import) (j : Nat) (impTransImps
185189def calcNeeds (env : Environment) (i : ModuleIdx) : Needs := Id.run do
186190 let mut needs := default
187191 for ci in env.header.moduleData[i]!.constants do
188- let pubCI? := env.setExporting true |>.find? ci.name
192+ -- Added guard for cases like `structure` that are still exported even if private
193+ let pubCI? := guard (!isPrivateName ci.name) *> (env.setExporting true ).find? ci.name
189194 let k := { isExported := pubCI?.isSome, isMeta := isMeta env ci.name }
190195 needs := visitExpr k ci.type needs
191196 if let some e := ci.value? (allowOpaque := true ) then
@@ -216,7 +221,8 @@ def getExplanations (env : Environment) (i : ModuleIdx) :
216221 Std.HashMap (ModuleIdx × NeedsKind) (Option (Name × Name)) := Id.run do
217222 let mut deps := default
218223 for ci in env.header.moduleData[i]!.constants do
219- let pubCI? := env.setExporting true |>.find? ci.name
224+ -- Added guard for cases like `structure` that are still exported even if private
225+ let pubCI? := guard (!isPrivateName ci.name) *> (env.setExporting true ).find? ci.name
220226 let k := { isExported := pubCI?.isSome, isMeta := isMeta env ci.name }
221227 deps := visitExpr k ci.name ci.type deps
222228 if let some e := ci.value? (allowOpaque := true ) then
@@ -286,16 +292,16 @@ and `endPos` is the position of the end of the header.
286292-/
287293def parseHeaderFromString (text path : String) :
288294 IO (System.FilePath × Parser.InputContext ×
289- TSyntaxArray ``Parser.Module.import × String.Pos) := do
295+ TSyntax ``Parser.Module.header × String.Pos.Raw ) := do
290296 let inputCtx := Parser.mkInputContext text path
291297 let (header, parserState, msgs) ← Parser.parseHeader inputCtx
292298 if !msgs.toList.isEmpty then -- skip this file if there are parse errors
293299 msgs.forM fun msg => msg.toString >>= IO.println
294300 throw <| .userError "parse errors in file"
295301 -- the insertion point for `add` is the first newline after the imports
296302 let insertion := header.raw.getTailPos?.getD parserState.pos
297- let insertion := text.findAux (· == '\n ' ) text.endPos insertion + ⟨ 1 ⟩
298- pure (path, inputCtx, .mk header.raw[ 2 ].getArgs , insertion)
303+ let insertion := text.findAux (· == '\n ' ) text.endPos insertion + ' \n '
304+ pure (path, inputCtx, header, insertion)
299305
300306/-- Parse a source file to extract the location of the import lines, for edits and error messages.
301307
@@ -304,13 +310,18 @@ and `endPos` is the position of the end of the header.
304310-/
305311def parseHeader (srcSearchPath : SearchPath) (mod : Name) :
306312 IO (System.FilePath × Parser.InputContext ×
307- TSyntaxArray ``Parser.Module.import × String.Pos) := do
313+ TSyntax ``Parser.Module.header × String.Pos.Raw ) := do
308314 -- Parse the input file
309315 let some path ← srcSearchPath.findModuleWithExt "lean" mod
310316 | throw <| .userError s! "error: failed to find source file for { mod} "
311317 let text ← IO.FS.readFile path
312318 parseHeaderFromString text path.toString
313319
320+ def decodeHeader : TSyntax ``Parser.Module.header → Option (TSyntax `module) × Option (TSyntax `prelude) × TSyntaxArray ``Parser.Module.import
321+ | `(Parser.Module.header| $[module%$moduleTk?]? $[prelude %$preludeTk?]? $imports*) =>
322+ (moduleTk?.map .mk, preludeTk?.map .mk, imports)
323+ | _ => unreachable!
324+
314325def decodeImport : TSyntax ``Parser.Module.import → Import
315326 | `(Parser.Module.import| $[public%$pubTk?]? $[meta%$metaTk?]? import $[all%$allTk?]? $id) =>
316327 { module := id.getId, isExported := pubTk?.isSome, isMeta := metaTk?.isSome, importAll := allTk?.isSome }
@@ -326,11 +337,20 @@ def decodeImport : TSyntax ``Parser.Module.import → Import
326337* `addOnly`: if true, only add missing imports, do not remove unused ones
327338 -/
328339def visitModule (srcSearchPath : SearchPath)
329- (i : Nat) (needs : Needs) (preserve : Needs) (edits : Edits)
340+ (i : Nat) (needs : Needs) (preserve : Needs) (edits : Edits) (headerStx : TSyntax ``Parser.Module.header)
330341 (addOnly := false ) (githubStyle := false ) (explain := false ) : StateT State IO Edits := do
331342 let s ← get
332343 -- Do transitive reduction of `needs` in `deps`.
333344 let mut deps := needs
345+ let (_, prelude ?, imports) := decodeHeader headerStx
346+ if prelude ?.isNone then
347+ deps := deps.union .pub {s.env.getModuleIdx? `Init |>.get!}
348+ for imp in imports do
349+ if addOnly || imp.raw.getTrailing?.any (·.toString.toSlice.contains "shake: keep" ) then
350+ let imp := decodeImport imp
351+ let j := s.env.getModuleIdx? imp.module |>.get!
352+ let k := NeedsKind.ofImport imp
353+ deps := deps.union k {j}
334354 for j in [0 :s.mods.size] do
335355 let transDeps := s.transDeps[j]!
336356 for k in NeedsKind.all do
@@ -354,7 +374,8 @@ def visitModule (srcSearchPath : SearchPath)
354374 newDeps := addTransitiveImps newDeps imp j s.transDeps[j]!
355375 else
356376 let k := NeedsKind.ofImport imp
357- if !addOnly && !deps.has k j && !deps.has { k with isExported := false } j then
377+ -- A private import should also be removed if the public version is needed
378+ if !deps.has k j || !k.isExported && deps.has { k with isExported := true } j then
358379 toRemove := toRemove.push imp
359380 else
360381 newDeps := addTransitiveImps newDeps imp j s.transDeps[j]!
@@ -385,7 +406,8 @@ def visitModule (srcSearchPath : SearchPath)
385406
386407 if githubStyle then
387408 try
388- let (path, inputCtx, imports, endHeader) ← parseHeader srcSearchPath s.modNames[i]!
409+ let (path, inputCtx, stx, endHeader) ← parseHeader srcSearchPath s.modNames[i]!
410+ let (_, _, imports) := decodeHeader stx
389411 for stx in imports do
390412 if toRemove.any fun imp => imp == decodeImport stx then
391413 let pos := inputCtx.fileMap.toPosition stx.raw.getPos?.get!
@@ -529,41 +551,51 @@ def main (args : List String) : IO UInt32 := do
529551 let needs := s.mods.mapIdx fun i _ =>
530552 Task.spawn fun _ => calcNeeds s.env i
531553
554+ -- Parse headers in parallel
555+ let headers ← s.mods.mapIdxM fun i _ =>
556+ BaseIO.asTask (parseHeader srcSearchPath s.modNames[i]! |>.toBaseIO)
557+
532558 if args.fix then
533559 println! "The following changes will be made automatically:"
534560
535561 -- Check all selected modules
536562 let mut edits : Edits := ∅
537563 let mut revNeeds : Needs := default
538- for i in [0 :s.mods.size], t in needs do
539- edits ← visitModule (addOnly := !pkg.isPrefixOf s.modNames[i]!) srcSearchPath i t.get revNeeds edits args.githubStyle args.explain
540- if isExtraRevModUse s.env i then
541- revNeeds := revNeeds.union .priv {i}
564+ for i in [0 :s.mods.size], t in needs, header in headers do
565+ match header.get with
566+ | .ok (_, _, stx, _) =>
567+ edits ← visitModule (addOnly := !pkg.isPrefixOf s.modNames[i]!)
568+ srcSearchPath i t.get revNeeds edits stx args.githubStyle args.explain
569+ if isExtraRevModUse s.env i then
570+ revNeeds := revNeeds.union .priv {i}
571+ | .error e =>
572+ println! e.toString
542573
543574 if !args.fix then
544575 -- return error if any issues were found
545576 return if edits.isEmpty then 0 else 1
546577
547578 -- Apply the edits to existing files
548- let count ← edits.foldM (init := 0 ) fun count mod (remove, add) => do
579+ let mut count := 0
580+ for mod in s.modNames, header? in headers do
581+ let some (remove, add) := edits[mod]? | continue
549582 let add : Array Import := add.qsortOrd
550583
551584 -- Parse the input file
552- let (path, inputCtx, imports, insertion) ←
553- try parseHeader srcSearchPath mod
554- catch e => println! e.toString; return count
585+ let .ok (path, inputCtx, stx, insertion) := header?.get | continue
586+ let (_, _, imports) := decodeHeader stx
555587 let text := inputCtx.fileMap.source
556588
557589 -- Calculate the edit result
558- let mut pos : String.Pos := 0
590+ let mut pos : String.Pos.Raw := 0
559591 let mut out : String := ""
560592 let mut seen : Std.HashSet Import := {}
561593 for stx in imports do
562594 let mod := decodeImport stx
563595 if remove.contains mod || seen.contains mod then
564596 out := out ++ text.extract pos stx.raw.getPos?.get!
565597 -- We use the end position of the syntax, but include whitespace up to the first newline
566- pos := text.findAux (· == '\n ' ) text.rawEndPos stx.raw.getTailPos?.get! + ⟨ 1 ⟩
598+ pos := text.findAux (· == '\n ' ) text.rawEndPos stx.raw.getTailPos?.get! + ' \n '
567599 seen := seen.insert mod
568600 out := out ++ text.extract pos insertion
569601 for mod in add do
@@ -573,7 +605,7 @@ def main (args : List String) : IO UInt32 := do
573605 out := out ++ text.extract insertion text.rawEndPos
574606
575607 IO.FS.writeFile path out
576- return count + 1
608+ count := count + 1
577609
578610 -- Since we throw an error upon encountering issues, we can be sure that everything worked
579611 -- if we reach this point of the script.
0 commit comments