Skip to content

Commit b0f5477

Browse files
committed
more simplification rules
1 parent f7a0362 commit b0f5477

File tree

2 files changed

+123
-26
lines changed

2 files changed

+123
-26
lines changed

src/api/expressions.jl

+95-17
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ isonenum(root) = isnum(root) && isone(root)
8686
isdiv(ex) = ex in [/, div, pdiv, aq]
8787
isexpr(ex) = isa(ex, Expr)
8888
issym(ex) = isa(ex, Symbol)
89+
isexprsym(ex) = isexpr(ex) || issym(ex)
90+
isbinexpr(ex) = isexpr(ex) && length(ex.args) == 3
8991

9092
function evaluate(ex::Expr, psyms::Dict{Symbol,Int}, vals::T...)::T where {T}
9193
exprm = ex.args
@@ -136,24 +138,100 @@ function simplifybinary!(root)
136138
elseif (fn == (+)) && (op1 == op2) # x+x = 2x
137139
root.args[1] = (*)
138140
root.args[2] = 2
139-
elseif (fn == (+) || fn == (-))
140-
# n1+(n2+x) = n1+(x+n2) = (n2+x)+n1 = (x+n2)+n1 = x+n3, s.t. n3=n1+n2
141-
if (isexpr(op1) && isnum(op2)) || (isnum(op1) && isexpr(op2))
142-
# swap so op1 is expr
143-
if isnum(op1) && isexpr(op2)
144-
op1, op2 = op2, op1
141+
elseif (fn == (+) || fn == (-)) && (isbinexpr(op1) && isnum(op2))
142+
fn2, op11, op12 = op1.args
143+
if fn == (+) && fn2 == (-) && isnum(op12)
144+
# (x-m)+n = x+(n-m)
145+
root.args[2] = op11
146+
root.args[3] = op2-op12
147+
elseif fn == (-) && fn2 == (-) && isnum(op12)
148+
# (x-m)-n = x-(n+m)
149+
root.args[2] = op11
150+
root.args[3] = op12+op2
151+
elseif fn2 == (+) && (isnum(op11) || isnum(op12))
152+
# (m+x)±n = (x+m)±n = x+(n±m)
153+
var, n2 = isnum(op11) ? (op12, op11) : (op11, op12)
154+
n3 = fn(n2, op2)
155+
root.args[1] = fn2
156+
root.args[2] = var
157+
root.args[3] = n3
158+
elseif fn2 == (-) && isnum(op11)
159+
# (m-x)±n = (n±m)-x
160+
n3 = fn(op11, op2)
161+
root.args[1] = fn2
162+
root.args[2] = n3
163+
root.args[3] = op12
164+
end
165+
elseif fn == (+) && isnum(op1) && isbinexpr(op2)
166+
fn2, op21, op22 = op2.args
167+
if fn2 == (+) || fn2 == (-)
168+
if isnum(op21)
169+
# n+(m±x) = (n+m)±x
170+
root.args[1] = fn2
171+
root.args[2] = op1 + op21
172+
root.args[3] = op22
173+
elseif isnum(op22)
174+
# n+(x±m) = (n±m)+x
175+
root.args[1] = fn
176+
root.args[2] = fn2(op1, op22)
177+
root.args[3] = op21
178+
end
179+
end
180+
elseif fn == (-) && isnum(op1) && isbinexpr(op2)
181+
fn2, op21, op22 = op2.args
182+
if fn2 == (+) && (isnum(op21) || isnum(op22))
183+
# n-(x+m) = n-(m+x) = (n-m)-x
184+
var, n2 = isnum(op21) ? (op22, op21) : (op21, op22)
185+
root.args[2] = op1 - n2
186+
root.args[3] = var
187+
elseif fn2 == (-) && (isnum(op21) || isnum(op22))
188+
# n-(m-x) = (n-m)+x
189+
# n-(x-m) = (n+m)-x
190+
var, n2, f1p = isnum(op21) ? (op22, op21, true) : (op21, op22, false)
191+
root.args[1] = f1p ? (+) : (-)
192+
root.args[2] = f1p ? op1 - n2 : op1 + n2
193+
root.args[3] = var
194+
end
195+
elseif fn == (+) && (isexpr(op1) || isexpr(op2))
196+
if isbinexpr(op2)
197+
# x + (n - x) = n
198+
fn2, op21, op22 = op2.args
199+
if fn2 == (-) && op1 == op22
200+
root = op21
201+
end
202+
elseif isbinexpr(op1)
203+
# (n - x) + x = n
204+
fn2, op11, op12 = op1.args
205+
if fn2 == (-) && op12 == op2
206+
root = op11
207+
end
208+
end
209+
elseif fn == (-) && (isexpr(op1) || isexpr(op2))
210+
if isbinexpr(op2)
211+
fn2, op21, op22 = op2.args
212+
if op1 == op21
213+
# x - (x ± n) = -±n
214+
if fn2 == (-)
215+
root = op22
216+
else
217+
pop!(root.args)
218+
root.args[end] = op22
219+
end
220+
elseif op1 == op22 && fn2 == (+)
221+
# x - (n + x) = -n
222+
pop!(root.args)
223+
root.args[end] = op21
145224
end
146-
#println("ex1: $fn ($op1, $op2)"
147-
# op2 is binexpr
148-
if isexpr(op1) && length(op1) == 3
149-
fn2, op21, op22 = op1.args
150-
# some operand has to be num
151-
if (fn2 == (+) || fn2 == (-)) && (isnum(op21) || isnum(op22))
152-
var, n2 = isnum(op21) ? (op22, op21) : (op21, op22)
153-
n3 = fn(n2, op2)
225+
elseif isbinexpr(op1)
226+
# (x ± n) - x = ±n
227+
fn2, op11, op12 = op1.args
228+
if op11 == op2
229+
if fn2 == (+)
230+
root = op12
231+
else
154232
root.args[1] = fn2
155-
root.args[2] = var
156-
root.args[3] = n3
233+
root.args[2] = op12
234+
pop!(root.args)
157235
end
158236
end
159237
end
@@ -196,7 +274,7 @@ function infix(io::IO, root; digits=3)
196274
print(io, "(")
197275
infix(io, root.args[2])
198276
infix(io, root.args[1])
199-
infix(io, root.args[3])
277+
length(root.args)>2 && infix(io, root.args[3])
200278
print(io, ")")
201279
else
202280
infix(io, root.args[1])

test/gp.jl

+28-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
@testset for (func, arity) in t.functions
1515
@test arity == 2
1616
end
17-
@test_skip summary(t) == "TreeGP[P=10,Parameter[x,y],Function[*, +, /, -]]"
17+
show(IOBuffer(), summary(t))
1818

1919
# population initialization
2020
popexp = Evolutionary.initial_population(t, rng=rng);
@@ -36,16 +36,16 @@
3636
@test Evolutionary.nodes(ft) == 15
3737
@test Evolutionary.height(ft) == 3
3838
@test length(ft) == 15
39-
# @test Evolutionary.depth(ft, :x) == 3
40-
# ft[3] = :z
41-
# @test Evolutionary.depth(ft, :z) == 3
39+
@test Evolutionary.depth(ft, :x) == 3
40+
ft[3] = :z
41+
@test Evolutionary.depth(ft, :z) == 3
4242
@test Evolutionary.depth(ft, ft) == 0
4343
@test Evolutionary.depth(ft, ft[3]) > 0
4444
@test Evolutionary.depth(ft, :w) == -1
45-
@test Evolutionary.evaluate([1.0, 2.0], :y, [:y, :z]) == 1.0
45+
@test Evolutionary.evaluate(:y, Dict(:y=>1, :z=>2), 1.0, 2.0) == 1.0
4646
copyto!(ft, gt)
4747
@test ft == gt
48-
# @test Evolutionary.symbols(ft) |> sort == [:x, :y]
48+
@test Evolutionary.symbols(ft) |> sort == [:x, :y]
4949

5050
# simplification
5151
using Evolutionary: simplify!
@@ -64,8 +64,27 @@
6464
@test Expr(:call, log, Expr(:call, exp, 1)) |> simplify! == 1
6565
@test Expr(:call, -, Expr(:call, +, :x, 1), 2) |> simplify! == Expr(:call, +, :x, -1)
6666
@test Expr(:call, -, Expr(:call, +, 1, :x), 2) |> simplify! == Expr(:call, +, :x, -1)
67-
@test Expr(:call, +, 2, Expr(:call, +, 1, :x)) |> simplify! == Expr(:call, +, :x, 3)
68-
@test Expr(:call, +, 2, Expr(:call, +, :x, 1)) |> simplify! == Expr(:call, +, :x, 3)
67+
@test Expr(:call, +, Expr(:call, +, :x, 1), 2) |> simplify! == Expr(:call, +, :x, 3)
68+
@test Expr(:call, +, Expr(:call, +, 1, :x), 2) |> simplify! == Expr(:call, +, :x, 3)
69+
@test Expr(:call, +, Expr(:call, -, 1, :x), 2) |> simplify! == Expr(:call, -, 3, :x)
70+
@test Expr(:call, -, Expr(:call, -, 1, :x), 2) |> simplify! == Expr(:call, -, -1, :x)
71+
@test Expr(:call, +, Expr(:call, -, :x, 1), 2) |> simplify! == Expr(:call, +, :x, 1)
72+
@test Expr(:call, -, Expr(:call, -, :x, 1), 2) |> simplify! == Expr(:call, -, :x, 3)
73+
@test Expr(:call, +, :x, Expr(:call, -, 1, :x)) |> simplify! == 1
74+
@test Expr(:call, +, Expr(:call, -, 2, :x), :x) |> simplify! == 2
75+
@test Expr(:call, -, :x, Expr(:call, +, :x, :y)) |> simplify! == Expr(:call, -, :y)
76+
@test Expr(:call, -, :x, Expr(:call, -, :x, :y)) |> simplify! == :y
77+
@test Expr(:call, -, :x, Expr(:call, +, :y, :x)) |> simplify! == Expr(:call, -, :y)
78+
@test Expr(:call, -, Expr(:call, -, :x, :y), :x) |> simplify! == Expr(:call, -, :y)
79+
@test Expr(:call, -, Expr(:call, +, :x, :y), :x) |> simplify! == :y
80+
@test Expr(:call, +, 2, Expr(:call, +, 1, :x)) |> simplify! == Expr(:call, +, 3, :x)
81+
@test Expr(:call, +, 2, Expr(:call, -, 1, :x)) |> simplify! == Expr(:call, -, 3, :x)
82+
@test Expr(:call, +, 2, Expr(:call, +, :x, 1)) |> simplify! == Expr(:call, +, 3, :x)
83+
@test Expr(:call, +, 2, Expr(:call, -, :x, 1)) |> simplify! == Expr(:call, +, 1, :x)
84+
@test Expr(:call, -, 2, Expr(:call, +, 1, :x)) |> simplify! == Expr(:call, -, 1, :x)
85+
@test Expr(:call, -, 2, Expr(:call, +, :x, 1)) |> simplify! == Expr(:call, -, 1, :x)
86+
@test Expr(:call, -, 1, Expr(:call, -, 2, :x)) |> simplify! == Expr(:call, +, -1, :x)
87+
@test Expr(:call, -, 2, Expr(:call, -, :x, 1)) |> simplify! == Expr(:call, -, 3, :x)
6988

7089
# evaluation
7190
ex = Expr(:call, +, 1, :x) |> Evolutionary.Expression
@@ -100,7 +119,7 @@
100119
ε = 0.1
101120
),
102121
),
103-
Evolutionary.Options(show_trace=true, rng=rng, iterations=50)
122+
Evolutionary.Options(show_trace=false, rng=rng, iterations=50)
104123
)
105124
@test minimum(res) < 1.1
106125

0 commit comments

Comments
 (0)