Skip to content

Commit 6264afd

Browse files
KDr2yebaipenelopeysm
authored
static_parameter expr (#175)
* comment about TypedSlot * find an Instruction dynamically * translate static_parameter expr * translate literal variables * test case * test case for static parameter * drop compat for Julia < 1.10; mark as breaking release * Restrict CI to 1.10 * eval static parameters in args * Test whether Julia 1.7 still works * eval static parameter when binding vars * [skip ci] fix typos --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Penelope Yong <[email protected]>
1 parent d176f19 commit 6264afd

File tree

6 files changed

+42
-6
lines changed

6 files changed

+42
-6
lines changed

Diff for: .github/workflows/Testing.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ jobs:
1212
matrix:
1313
version:
1414
- '1.7'
15+
- '1.10'
1516
- '1'
1617
- 'nightly'
1718
os:

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.8.7"
6+
version = "0.8.8"
77

88
[deps]
99
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"

Diff for: perf/benchmark.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ println("======= breakdown benchmark =======")
114114
x = rand(100000)
115115
tf = Libtask.TapedFunction(ackley, x, nothing)
116116
tf(x, nothing);
117-
ins = tf.tape[45]
117+
idx = findlast((x)->isa(x, Libtask.Instruction), tf.tape)
118+
ins = tf.tape[idx]
118119
b = ins.input[1]
119120

120121
@show ins.input |> length

Diff for: src/Libtask.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ export TArray, tzeros, tfill, TRef # legacy types back compat
99

1010

1111
@static if isdefined(Core, :TypedSlot) || isdefined(Core.Compiler, :TypedSlot)
12-
# Julia v1.10 removed Core.TypedSlot
12+
# Julia v1.10 moved Core.TypedSlot to Core.Compiler
13+
# Julia v1.11 removed Core.Compiler.TypedSlot
1314
const TypedSlot = @static if isdefined(Core, :TypedSlot)
1415
Core.TypedSlot
1516
else

Diff for: src/tapedfunction.jl

+15-3
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,10 @@ end
269269

270270
const IRVar = Union{Core.SSAValue, Core.SlotNumber}
271271

272-
function bind_var!(var_literal, bindings::Bindings, ir::Core.CodeInfo)
273-
# for literal constants
274-
push!(bindings, var_literal)
272+
function bind_var!(var, bindings::Bindings, ir::Core.CodeInfo)
273+
# for literal constants, and static parameters
274+
var = Meta.isexpr(var, :static_parameter) ? ir.parent.sparam_vals[var.args[1]] : var
275+
push!(bindings, var)
275276
idx = length(bindings)
276277
return idx
277278
end
@@ -368,6 +369,14 @@ function translate!!(var::IRVar, line::Core.SlotNumber,
368369
return Instruction(func, input, output)
369370
end
370371

372+
function translate!!(var::IRVar, line::Number, # literal vars
373+
bindings::Bindings, isconst::Bool, ir)
374+
func = identity
375+
input = (bind_var!(line, bindings, ir),)
376+
output = bind_var!(var, bindings, ir)
377+
return Instruction(func, input, output)
378+
end
379+
371380
function translate!!(var::IRVar, line::NTuple{N, Symbol},
372381
bindings::Bindings, isconst::Bool, ir) where {N}
373382
# for syntax (; x, y, z), see Turing.jl#1873
@@ -439,6 +448,9 @@ function translate!!(var::IRVar, line::Expr,
439448
end
440449
return Instruction(identity, (_bind_fn(rhs),), _bind_fn(lhs))
441450
end
451+
elseif head === :static_parameter
452+
v = ir.parent.sparam_vals[line.args[1]]
453+
return Instruction(identity, (_bind_fn(v),), _bind_fn(var))
442454
else
443455
@error "Unknown Expression: " typeof(var) var typeof(line) line
444456
throw(ErrorException("Unknown Expression"))

Diff for: test/issues.jl

+21
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,25 @@
5959
r = tf(1, 2)
6060
@test r == (c=3, x=1, y=2)
6161
end
62+
63+
@testset "Issue-Libtask-174, SSAValue=Int and static parameter" begin
64+
# SSAValue = Int
65+
function f()
66+
# this line generates: %1 = 1::Core.Const(1)
67+
r = (a = 1)
68+
return nothing
69+
end
70+
tf = Libtask.TapedFunction(f)
71+
r = tf()
72+
@test r == nothing
73+
74+
# static parameter
75+
function g(::Type{T}) where {T}
76+
a = zeros(T, 10)
77+
end
78+
tf = Libtask.TapedFunction(g, Float64)
79+
r = tf(Float64)
80+
@test r == zeros(Float64, 10)
81+
end
82+
6283
end

0 commit comments

Comments
 (0)