Skip to content

Commit ddc9b76

Browse files
authored
Merge pull request #96 from Herb-AI/dev-deepcoder
Add DeepCoder benchmark
2 parents 498af09 + 4a06c17 commit ddc9b76

File tree

10 files changed

+6154
-11
lines changed

10 files changed

+6154
-11
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
module DeepCoder_2016
2+
3+
using HerbCore
4+
using HerbSpecification
5+
using HerbGrammar
6+
7+
using JSON
8+
9+
include("data.jl")
10+
include("base_grammar.jl")
11+
include("grammars.jl")
12+
13+
include("list_functions.jl")
14+
15+
export
16+
parse_deepcoder_problem_and_grammar
17+
base_grammar_deepcoder
18+
19+
"""
20+
parse_deepcoder_problem(filename::AbstractString, base_grammar::AbstractGrammar)::Problem
21+
Parses a DeepCoder problem from a file given a base grammar.
22+
"""
23+
function parse_deepcoder_problem_and_grammar(filename::AbstractString,
24+
base_grammar::AbstractGrammar)
25+
raw = JSON.parsefile(filename)
26+
27+
examples = IOExample[]
28+
for ex in raw["examples"]
29+
args = split_inputs(ex["input"])
30+
out = normalize_value(ex["output"])
31+
push!(examples, IOExample(args, out))
32+
end
33+
34+
number = match(r"\d+", raw["name"])
35+
number === nothing && error("Could not extract problem number from: $filename")
36+
problem_name = "problem_" * lpad(number.match, 3, '0')
37+
problem = Problem(problem_name, examples)
38+
39+
# infer from first example (DeepCoder tasks are consistent)
40+
sig = infer_signature(examples[1].in)
41+
start_nt = infer_output_nt(examples[1].out)
42+
43+
# combine base + extras
44+
g = deepcopy(base_grammar)
45+
add_extras!(g, sig, start_nt)
46+
47+
return problem, g
48+
end
49+
50+
function split_inputs(raw_in)::Dict{Symbol,Any}
51+
@assert raw_in isa Vector "DeepCoder 'input' must be an array"
52+
n = length(raw_in)
53+
@assert 1 <= n <= 2 "Expected 1 or 2 inputs, got $n"
54+
55+
tojl(v) = v isa Vector ? map(Int, v) : Int(v)
56+
57+
args = Dict{Symbol,Any}()
58+
args[:_arg_1] = tojl(raw_in[1])
59+
if n == 2
60+
args[:_arg_2] = tojl(raw_in[2])
61+
end
62+
return args
63+
end
64+
65+
function infer_signature(args::Dict{Symbol,Any})::Dict{Symbol,Symbol}
66+
sig = Dict{Symbol,Symbol}()
67+
for (k, v) in args
68+
if v isa AbstractVector{<:Integer}
69+
sig[k] = :ExprArr
70+
elseif v isa Integer
71+
sig[k] = :ExprNum
72+
else
73+
error("Unsupported input type for $(k): $(typeof(v))")
74+
end
75+
end
76+
sig
77+
end
78+
79+
function add_extras!(g::AbstractGrammar, sig::Dict{Symbol,Symbol}, start_nt::String)
80+
add_rule!(g, make_sym_rule(:Start, start_nt))
81+
for (arg, nt) in sig
82+
add_rule!(g, make_sym_rule(nt, arg))
83+
end
84+
g
85+
end
86+
87+
infer_output_nt(out)::String = out isa AbstractVector{<:Any} ? "ExprArr" :
88+
out isa Integer ? "ExprNum" :
89+
error("Unsupported output type: $(typeof(out)): $out")
90+
91+
normalize_value(x) = x isa Vector ? map(v -> Int(v), x) : Int(x)
92+
93+
make_sym_rule(lhs::Symbol, rhs::Symbol)::Expr = Expr(:(=), lhs, rhs)
94+
95+
end # module DeepCoder_2016

src/data/DeepCoder_2016/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# DeepCoder Benchmark
2+
3+
The DeepCoder specializes in functional programs that manipulate lists.
4+
Each problem is written as a set of input-output examples.
5+
6+
The DeepCoder benchmark is derived from Balog et al. (2016) using the setup from Neo (Feng et al., 2018), as the evaluation benchmarks are not publicly available.
7+
Neo thus generated 100 benchmarks following this workflow:
8+
9+
> We enumerate DSL programs with
10+
> at least 5 components and randomly generate inputs and the
11+
> corresponding output. This procedure is repeated for a fixed
12+
> number of times until we either obtain 5 valid input-output
13+
> examples or no examples have been found within the iter-
14+
> ation limit. In the latter case, we restart this process and
15+
> randomly search for a different program.
16+
17+
See
18+
> Balog, M., Gaunt, A. L., Brockschmidt, M., Nowozin, S., & Tarlow, D. (2016). Deepcoder: Learning to write programs. arXiv preprint arXiv:1611.01989.
19+
and
20+
> Feng, Y., Martins, R., Bastani, O., & Dillig, I. (2018). Program synthesis using conflict-driven learning. ACM SIGPLAN Notices, 53(4), 420-435.
21+
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
base_grammar_deepcoder = @csgrammar begin
2+
Int = |(-3:3)
3+
4+
ExprNum = Int
5+
6+
ExprNum = maximum(ExprArr) := (length(x1) > 1, maximum(y) == maximum(x1), minimum(y) > minimum(x1))
7+
ExprNum = minimum(ExprArr) := (length(x1) > 1, maximum(y) < maximum(x1), minimum(y) == minimum(x1))
8+
ExprNum = sum(ExprArr) := (length(x1) > 1)
9+
ExprNum = first(ExprArr) := (length(x1) > 1, maximum(y) <= maximum(x1), minimum(y) >= minimum(x1), first(y) == first(x1), last(y) == first(x1))
10+
ExprNum = last(ExprArr) := (length(x1) > 1, maximum(y) <= maximum(x1), minimum(y) >= minimum(x1), first(y) == last(x1), last(y) == last(x1))
11+
ExprNum = getindex(ExprArr, ExprNum) := (length(x1) > 1, maximum(y) <= maximum(x1), minimum(y) >= minimum(x1), first(x2) > 0, length(x1) > first(x2))
12+
13+
ExprNum = countSt(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
14+
ExprNum = countGt(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
15+
ExprNum = countEq(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
16+
ExprNum = countNeq(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
17+
ExprNum = countMod(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
18+
ExprNum = countNmod(ExprArr, Int) := (length(x1) > 1, last(y) <= length(x1), last(y) >= 0)
19+
20+
ExprArr = drop(ExprArr, ExprNum) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1), last(y) == last(x1), first(x2) > 0, length(x1) > maximum(x2))
21+
ExprArr = take(ExprArr, ExprNum) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1), first(y) == first(x1), first(x2) > 0, length(x1) > maximum(x2))
22+
ExprArr = sort(ExprArr) := (length(y) == length(x1), maximum(y) == maximum(x1), minimum(y) == minimum(x1), first(y) == minimum(x1), last(y) == maximum(x1))
23+
ExprArr = reverse(ExprArr) := (length(y) == length(x1), maximum(y) == maximum(x1), minimum(y) == minimum(x1), first(y) == last(x1), last(y) == first(x1))
24+
25+
ExprArr = filterSt(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
26+
ExprArr = filterGt(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
27+
ExprArr = filterEq(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
28+
ExprArr = filterNeq(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
29+
ExprArr = filterMod(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
30+
ExprArr = filterNmod(ExprArr, Int) := (length(y) < length(x1), maximum(y) <= maximum(x1), minimum(y) >= minimum(x1))
31+
32+
ExprArr = mapPlus(ExprArr, Int) := (length(y) == length(x1))
33+
ExprArr = mapMult(ExprArr, Int) := (length(y) == length(x1))
34+
ExprArr = mapDiv(ExprArr, Int) := (length(y) == length(x1))
35+
ExprArr = mapPow(ExprArr, Int) := (length(y) == length(x1))
36+
37+
ExprArr = zipwithMax(ExprArr, ExprArr) := (length(y) == length(x1), length(y) == length(x2))
38+
ExprArr = zipwithMin(ExprArr, ExprArr) := (length(y) == length(x1), length(y) == length(x2))
39+
ExprArr = zipwithPlus(ExprArr, ExprArr) := (length(y) == length(x1), length(y) == length(x2))
40+
ExprArr = zipwithMinus(ExprArr, ExprArr) := (length(y) == length(x1), length(y) == length(x2))
41+
ExprArr = zipwithMult(ExprArr, ExprArr) := (length(y) == length(x1), length(y) == length(x2))
42+
43+
ExprArr = scanl1Plus(ExprArr) := (length(y) == length(x1), first(y) == first(x1))
44+
ExprArr = scanl1Minus(ExprArr) := (length(y) == length(x1), first(y) == first(x1))
45+
ExprArr = scanl1Mult(ExprArr) := (length(y) == length(x1), first(y) == first(x1))
46+
ExprArr = scanl1Max(ExprArr) := (length(y) == length(x1), first(y) == first(x1), maximum(y) == maximum(x1), minimum(y) >= minimum(x1), last(y) == maximum(x1))
47+
ExprArr = scanl1Min(ExprArr) := (length(y) == length(x1), first(y) == first(x1), maximum(y) <= maximum(x1), minimum(y) == minimum(x1), last(y) == minimum(x1))
48+
end
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@article{article,
2+
author = {Balog, Matej and Gaunt, Alexander and Brockschmidt, Marc and Nowozin, Sebastian and Tarlow, Daniel},
3+
year = {2016},
4+
month = {11},
5+
pages = {},
6+
title = {DeepCoder: Learning to Write Programs},
7+
doi = {10.48550/arXiv.1611.01989}
8+
}

0 commit comments

Comments
 (0)