@@ -108,6 +108,21 @@ def terminalReplacement (oldTacticName newTacticName : String) (oldTacticKind :
108108 }
109109
110110
111+ /-- Convert a term syntax to a grindParam syntax (wrapping in grindLemma).
112+ If the term is a simple identifier (like `pi_pos`), wrap it in an explicit application
113+ `(id pi_pos)` so grind treats it as a term rather than an e-matching theorem. -/
114+ private def termToGrindParam (t : Syntax) : Syntax :=
115+ -- grindLemma := ppGroup((Attr.grindMod ppSpace)? term)
116+ -- grindParam := grindErase <|> grindLemmaMin <|> grindLemma <|> anchor
117+ -- With no modifier, the first child is a null node
118+ -- If t is a simple identifier, wrap as `(id t)` to force term interpretation
119+ let t' : Syntax := if t.isIdent then
120+ -- Create `id t` application - this ensures grind sees it as a term, not an e-match candidate
121+ mkNode ``Lean.Parser.Term.app #[mkIdent `id, mkNullNode #[t]]
122+ else t
123+ let grindLemma := mkNode ``Lean.Parser.Tactic.grindLemma #[mkNullNode, t']
124+ mkNode ``Lean.Parser.Tactic.grindParam #[grindLemma]
125+
111126/--
112127Define a pass that tries replacing a specific tactic with `grind`.
113128
@@ -117,12 +132,46 @@ all produce the same message.
117132
118133`tacticKind` is the `SyntaxNodeKind` for the tactic's main parser,
119134for example `Mathlib.Tactic.linarith`.
135+
136+ If `extractArgs` is provided, it extracts term arguments from the original tactic
137+ (e.g., `linarith [X, Y]`) and passes them to grind (e.g., `grind [X, Y]`).
138+ Local hypotheses are filtered out since grind uses them automatically.
120139-/
121140def grindReplacementWith (tacticName : String) (tacticKind : SyntaxNodeKind)
141+ (extractArgs : Syntax → Option (Syntax.TSepArray `term "," ) := fun _ => none)
122142 (reportFailure : Bool := true ) (reportSuccess : Bool := false )
123143 (reportSlowdown : Bool := false ) (maxSlowdown : Float := 1 ) :
124144 TacticAnalysis.Config :=
125- terminalReplacement tacticName "grind" tacticKind (fun _ _ _ => `(tactic| grind))
145+ let newTactic : ContextInfo → TacticInfo → Syntax → CommandElabM (TSyntax `tactic) :=
146+ fun _ctxI tacI stx => do
147+ match extractArgs stx with
148+ | some args =>
149+ if args.getElems.isEmpty then
150+ return ← `(tactic| grind)
151+ -- Get local hypothesis names from the goal's local context
152+ let lctxNames : Std.HashSet Name :=
153+ match tacI.goalsBefore.head? with
154+ | some goal =>
155+ let goalDecl := tacI.mctxBefore.decls.find! goal
156+ goalDecl.lctx.foldl (init := {}) fun s decl =>
157+ if decl.isImplementationDetail then s else s.insert decl.userName
158+ | none => {}
159+ -- Filter out terms that are simple identifiers matching local hypotheses
160+ let filteredElems := args.getElems.filter fun term =>
161+ match term.raw with
162+ | .ident _ _ name _ => !lctxNames.contains name
163+ | _ => true -- Keep non-identifier terms (like `foo.bar x`)
164+ if filteredElems.isEmpty then
165+ return ← `(tactic| grind)
166+ -- Build comma-separated list from filtered elements
167+ let grindElemsAndSeps := filteredElems.foldl (init := #[]) fun acc elem =>
168+ if acc.isEmpty then #[termToGrindParam elem]
169+ else acc.push (mkAtom "," ) |>.push (termToGrindParam elem)
170+ let grindArgs : Syntax.TSepArray ``Lean.Parser.Tactic.grindParam "," :=
171+ ⟨grindElemsAndSeps⟩
172+ `(tactic| grind [$grindArgs,*])
173+ | none => `(tactic| grind)
174+ terminalReplacement tacticName "grind" tacticKind newTactic
126175 reportFailure reportSuccess reportSlowdown maxSlowdown
127176
128177end Mathlib.TacticAnalysis
@@ -136,6 +185,14 @@ register_option linter.tacticAnalysis.regressions.linarithToGrind : Bool := {
136185@[tacticAnalysis linter.tacticAnalysis.regressions.linarithToGrind,
137186 inherit_doc linter.tacticAnalysis.regressions.linarithToGrind]
138187def linarithToGrindRegressions := grindReplacementWith "linarith" `Mathlib.Tactic.linarith
188+ (extractArgs := fun stx => do
189+ -- linarith syntax: "linarith" "!"? linarithArgsRest
190+ -- linarithArgsRest := optConfig (&" only")? (" [" term,* "]")?
191+ -- The args are in the last child of linarithArgsRest (index 2), which is at index 2 of linarith
192+ let rest := stx[2 ] -- linarithArgsRest
193+ let argsGroup := rest[2 ] -- the optional bracket group
194+ guard (argsGroup.getNumArgs >= 2 ) -- has at least "[" and "]"
195+ return ⟨argsGroup[1 ].getArgs⟩) -- the term,* between brackets (getArgs gives the array)
139196
140197/-- Debug `grind` by identifying places where it does not yet supersede `ring`. -/
141198register_option linter.tacticAnalysis.regressions.ringToGrind : Bool := {
0 commit comments