Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sealir-tutorials/ch03_egraph_program_rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from sealir import rvsdg
from sealir.eqsat import rvsdg_eqsat
from sealir.eqsat.rvsdg_eqsat import GraphRoot, Term, TermList
from utils import IN_NOTEBOOK
Copy link

Copilot AI Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this script is intended to be run as a module, consider using a relative import (from .utils import IN_NOTEBOOK) to avoid ImportError when the package is installed.

Suggested change
from utils import IN_NOTEBOOK
from .utils import IN_NOTEBOOK

Copilot uses AI. Check for mistakes.
from egglog_to_latex import visualize_ruleset_latex

# We'll be extending from chapter 2.
from ch02_egraph_basic import (
Expand Down Expand Up @@ -116,8 +118,11 @@ def ruleset_const_propagate(a: Term, ival: i64):
IsConstantFalse(a)
)

if IN_NOTEBOOK:
# Visualize the constant propagation ruleset
visualize_ruleset_latex(ruleset_const_propagate)

Comment on lines +121 to 124
Copy link

Copilot AI Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] These if IN_NOTEBOOK blocks are repeated—consider extracting to a helper function to DRY up the visualization calls.

Suggested change
if IN_NOTEBOOK:
# Visualize the constant propagation ruleset
visualize_ruleset_latex(ruleset_const_propagate)
def visualize_if_in_notebook(ruleset):
if IN_NOTEBOOK:
# Visualize the given ruleset
visualize_ruleset_latex(ruleset)
visualize_if_in_notebook(ruleset_const_propagate)

Copilot uses AI. Check for mistakes.
# Now, well test our newly defined ruleset. This complete ruleset combines a
# Now, we'll test our newly defined ruleset. This complete ruleset combines a
# few built-in RVSDG rules with our recently crafted simple constant-propagation
# rules.

Expand Down Expand Up @@ -173,6 +178,9 @@ def ruleset_const_fold_if_else(a: Term, b: Term, c: Term, operands: TermList):
IsConstantFalse(a),
)

if IN_NOTEBOOK:
# Visualize the if-else folding ruleset
visualize_ruleset_latex(ruleset_const_fold_if_else)

if __name__ == "__main__":
my_ruleset = (
Expand Down
213 changes: 213 additions & 0 deletions sealir-tutorials/egglog_to_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from typing import List, Union
from egglog import EGraph


def tokenize(egglog_str: str) -> List[str]:
"""
Splits an Egglog S-expression string into a flat list of tokens.
Tokens are either "(" or ")", or atoms (any sequence of non-whitespace, non-parenthesis chars).
"""
tokens = []
i = 0
while i < len(egglog_str):
c = egglog_str[i]
if c.isspace():
i += 1
continue
if c in ("(", ")"):
tokens.append(c)
i += 1
else:
j = i
while j < len(egglog_str) and not egglog_str[j].isspace() and egglog_str[j] not in ("(", ")"):
j += 1
tokens.append(egglog_str[i:j])
i = j
return tokens


def parse_sexps(tokens: List[str]) -> List[Union[str, list]]:
"""
Parses a flat list of tokens into a nested list of S-expression forms.
Each form is either an atom (string) or a list whose first element is the head.
Returns a flat list of top-level S-expressions (each itself a nested list).
"""
stack: List[List] = []
current: List[Union[str, list]] = []
for tok in tokens:
if tok == "(":
stack.append(current)
current = []
elif tok == ")":
completed = current
current = stack.pop()
current.append(completed)
else:
current.append(tok)
# If the entire parse wrapped everything in a single list, unwrap it:
if len(current) == 1 and isinstance(current[0], list):
return current[0]
return current


def sexp_to_string(sexp):
"""Convert parsed S-expression back to original string format"""
if isinstance(sexp, str):
return sexp
elif isinstance(sexp, list):
inner = ' '.join(sexp_to_string(item) for item in sexp)
return f"({inner})"
else:
return str(sexp)


LATEX_ESCAPE = str.maketrans({
"_": r"\_",
"#": r"\#",
"$": r"\$",
"%": r"\%",
"&": r"\&",
"{": r"\{",
"}": r"\}",
"~": r"\textasciitilde{}",
"^": r"\^{}",
"\\": r"\\",
})


def _atom_tex(a: str) -> str:
try:
float(a) # leave numerics bare
return a
except ValueError:
return r"\text{" + a.translate(LATEX_ESCAPE) + "}"


INFIX_OPS = {"=", "!=", "<", "<=", ">", ">=", "+", "-", "*", "/", "%", "**"}


def _sexp_tex(x) -> str:
if isinstance(x, str):
return _atom_tex(x)

head, *args = x

# infix pretty-printing for common binary ops
if head in INFIX_OPS and len(args) == 2:
return f"{_sexp_tex(args[0])} {head} {_sexp_tex(args[1])}"

return (
r"\text{" + head.translate(LATEX_ESCAPE) + "}"
+ "(" + ", ".join(_sexp_tex(a) for a in args) + ")"
)


def _is_set_expr(x):
# Detects if x is a set-like S-expression: ['set', lhs, rhs]
return isinstance(x, list) and len(x) == 3 and x[0] == "set"


def _set_tex(x):
# Renders set(lhs, rhs) as lhs \to rhs
return f"{_sexp_tex(x[1])} \\to {_sexp_tex(x[2])}"


def to_latex(sexp):
"""
Render (rewrite …) or (rule …) as KaTeX-safe LaTeX.
"""
if not (isinstance(sexp, list) and sexp):
return None

tag = sexp[0]

# ─────────────── REWRITE ────────────────
if tag == "rewrite" and len(sexp) >= 3:
lhs, rhs = sexp[1], sexp[2]

# harvest optional :when clause (list of conditions)
when_conds = []
i = 3
while i < len(sexp):
if sexp[i] == ":when" and i + 1 < len(sexp):
when_conds = sexp[i + 1] # list of cond S-exps
break
i += 1 # <- step only ONE token

lhs_tex = _sexp_tex(lhs)
rhs_tex = _sexp_tex(rhs)

cond_tex = ""
if when_conds:
joined = r",\; ".join(_sexp_tex(c) for c in when_conds)
cond_tex = rf",\; {joined}"

num = rf"\text{{expr}} = {lhs_tex}{cond_tex}"
den = rf"\text{{expr}} \to {rhs_tex}"

return rf"\frac{{{num}}}{{{den}}}"

# ──────────────── RULE ─────────────────
if tag == "rule" and len(sexp) >= 3:
premises, conclusions = sexp[1], sexp[2]

def render_stack(exprs):
lines = []
for e in exprs:
if _is_set_expr(e):
lines.append(_set_tex(e))
else:
lines.append(_sexp_tex(e))
return r"\\ ".join(lines)

prem_tex = render_stack(premises)
concl_tex = render_stack(conclusions)

num = rf"\begin{{array}}{{c}}{prem_tex}\end{{array}}"
den = rf"\begin{{array}}{{c}}{concl_tex}\end{{array}}"

return rf"\frac{{{num}}}{{{den}}}"

return None


def visualize_ruleset_latex(ruleset, verbose=True):
"""
Visualize an egglog ruleset by converting it to LaTeX representation.
Only works in notebook environments.

Args:
ruleset: The egglog ruleset to visualize
verbose: If True, prints the original S-expression before LaTeX display

Returns:
None, but displays LaTeX representation if in notebook environment
"""
try:
shell = get_ipython().__class__.__name__
is_notebook = shell == "ZMQInteractiveShell"
except NameError:
is_notebook = False

if not is_notebook:
return

# Create demo egraph and run ruleset
demo_egraph = EGraph(save_egglog_string=True)
demo_egraph.run(ruleset)
egglog_str = demo_egraph.as_egglog_string

# Parse into S-expressions
tokens = tokenize(egglog_str)
sexps = parse_sexps(tokens)

from IPython.display import display, Math
Copy link

Copilot AI Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Move the IPython.display import to the top of the module to avoid runtime overhead when calling visualize_ruleset_latex multiple times.

Copilot uses AI. Check for mistakes.

for sexp in sexps:
tex = to_latex(sexp)
if tex:
if verbose:
print(sexp_to_string(sexp))
display(Math(tex))
if verbose:
print()