@@ -928,6 +928,57 @@ def createLocalPreDiscrTree
928928def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do
929929 keys.foldlM (init := t) (·.dropKey ·)
930930
931+ /-- Collect all values from a subtree recursively and clear them. -/
932+ partial def collectSubtreeAux (next : TrieIndex) : MatchM α (Array α) :=
933+ if next = 0 then
934+ pure #[]
935+ else do
936+ let (values, star, children) ← evalNode next
937+ -- Collect from star subtrie
938+ let starVals ← collectSubtreeAux star
939+ -- Collect from all children
940+ let mut childVals : Array α := #[]
941+ for (_, childIdx) in children do
942+ childVals := childVals ++ (← collectSubtreeAux childIdx)
943+ -- Clear this node (keep structure but remove values)
944+ modify (·.set! next {values := #[], star, children})
945+ return values ++ starVals ++ childVals
946+
947+ /-- Navigate to a key path and return all values in that subtree, then drop them. -/
948+ def extractKeyAux (next : TrieIndex) (rest : List Key) :
949+ MatchM α (Array α) :=
950+ if next = 0 then
951+ pure #[]
952+ else do
953+ let (_, star, children) ← evalNode next
954+ match rest with
955+ | [] =>
956+ -- At the target node: collect ALL values from entire subtree
957+ collectSubtreeAux next
958+ | k :: r => do
959+ let next := if k == .star then star else children.getD k 0
960+ extractKeyAux next r
961+
962+ /-- Extract and drop entries at a specific key, returning the dropped entries. -/
963+ def extractKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) :
964+ MetaM (Array α × LazyDiscrTree α) :=
965+ match path with
966+ | [] => pure (#[], t)
967+ | rootKey :: rest => do
968+ let idx := t.roots.getD rootKey 0
969+ runMatch t (extractKeyAux idx rest)
970+
971+ /-- Extract entries at the given keys and also drop them from the tree. -/
972+ def extractKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) :
973+ MetaM (Array α × LazyDiscrTree α) := do
974+ let mut allExtracted : Array α := #[]
975+ let mut tree := t
976+ for path in keys do
977+ let (extracted, newTree) ← extractKey tree path
978+ allExtracted := allExtracted ++ extracted
979+ tree := newTree
980+ return (allExtracted, tree)
981+
931982def logImportFailure [Monad m] [MonadLog m] [AddMessageContext m] [MonadOptions m] (f : ImportFailure) : m Unit :=
932983 logError m!"Processing failure with {f.const} in {f.module}:\n {f.exception.toMessageData}"
933984
@@ -979,6 +1030,7 @@ def findImportMatches
9791030 (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α)))
9801031 (droppedKeys : List (List LazyDiscrTree.Key) := [])
9811032 (constantsPerTask : Nat := 1000 )
1033+ (droppedEntriesRef : Option (IO.Ref (Option (Array α))) := none)
9821034 (ty : Expr) : MetaM (MatchResult α) := do
9831035 let cctx ← (read : CoreM Core.Context)
9841036 let ngen ← getNGen
@@ -990,7 +1042,13 @@ def findImportMatches
9901042 profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do
9911043 let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry
9921044 (constantsPerTask := constantsPerTask)
993- dropKeys t droppedKeys
1045+ -- If a reference is provided, extract and store dropped entries
1046+ if let some droppedRef := droppedEntriesRef then
1047+ let (extracted, t) ← extractKeys t droppedKeys
1048+ droppedRef.set (some extracted)
1049+ pure t
1050+ else
1051+ dropKeys t droppedKeys
9941052 let (importCandidates, importTree) ← importTree.getMatch ty
9951053 ref.set (some importTree)
9961054 pure importCandidates
@@ -1064,10 +1122,11 @@ def findMatchesExt
10641122 (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α)))
10651123 (droppedKeys : List (List LazyDiscrTree.Key) := [])
10661124 (constantsPerTask : Nat := 1000 )
1125+ (droppedEntriesRef : Option (IO.Ref (Option (Array α))) := none)
10671126 (adjustResult : Nat → α → β)
10681127 (ty : Expr) : MetaM (Array β) := do
10691128 let moduleMatches ← findModuleMatches moduleTreeRef ty
1070- let importMatches ← findImportMatches ext addEntry droppedKeys constantsPerTask ty
1129+ let importMatches ← findImportMatches ext addEntry droppedKeys constantsPerTask droppedEntriesRef ty
10711130 return Array.mkEmpty (moduleMatches.size + importMatches.size)
10721131 |> moduleMatches.appendResultsAux (f := adjustResult)
10731132 |> importMatches.appendResultsAux (f := adjustResult)
@@ -1080,13 +1139,15 @@ def findMatchesExt
10801139* `addEntry` is the function for creating discriminator tree entries from constants.
10811140* `droppedKeys` contains keys we do not want to consider when searching for matches.
10821141 It is used for dropping very general keys.
1142+ * `droppedEntriesRef` optionally stores entries dropped from the tree for later use.
10831143 -/
10841144def findMatches (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
10851145 (addEntry : Name → ConstantInfo → MetaM (Array (InitEntry α)))
10861146 (droppedKeys : List (List LazyDiscrTree.Key) := [])
10871147 (constantsPerTask : Nat := 1000 )
1148+ (droppedEntriesRef : Option (IO.Ref (Option (Array α))) := none)
10881149 (ty : Expr) : MetaM (Array α) := do
10891150
10901151 let moduleTreeRef ← createModuleTreeRef addEntry droppedKeys
10911152 let incPrio _ v := v
1092- findMatchesExt moduleTreeRef ext addEntry droppedKeys constantsPerTask incPrio ty
1153+ findMatchesExt moduleTreeRef ext addEntry droppedKeys constantsPerTask droppedEntriesRef incPrio ty
0 commit comments