From f460b317a9a671db08ac59439ca8810c69a8510d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Apr 2025 17:57:55 +0100 Subject: [PATCH 1/5] feat: handle anonymous function edgecases --- src/utils.jl | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9b7a48b..c9b5c96 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -259,7 +259,34 @@ isshortdef(ex) = (@capture(ex, (fcall_ = body_)) && function longdef1(ex) if @capture(ex, (arg_ -> body_)) - Expr(:function, arg isa Symbol ? :($arg,) : arg, body) + + if isexpr(arg, :tuple) && length(arg.args) == 1 && isexpr(arg.args[1], :parameters) + # Special case (; kws...) -> + fcall = Expr(:call, :tuple, arg.args[1]) + + Expr(:function, fcall, body) + elseif isexpr(arg, :block) && any(a -> isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw), arg.args) + # Has keywords in a block + pos_args = [] + kw_args = [] + for a in arg.args + if !(a isa LineNumberNode) + if isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw) + push!(kw_args, a) + else + push!(pos_args, a) + end + end + end + fcall = Expr(:call, :tuple, Expr(:parameters, kw_args...), pos_args...) + + Expr(:function, fcall, body) + elseif isexpr(arg, :...) + # Special case for a varargs argument + Expr(:function, Expr(:tuple, arg), body) + else + Expr(:function, arg isa Symbol ? :($arg,) : arg, body) + end elseif isshortdef(ex) @assert @capture(ex, (fcall_ = body_)) Expr(:function, fcall, body) From 20e9d3abc409bb0c3d555a697cbdafd682b7617d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Apr 2025 17:58:06 +0100 Subject: [PATCH 2/5] test: anonymous function edgecases --- test/split.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/split.jl b/test/split.jl index 36bf768..b28b879 100644 --- a/test/split.jl +++ b/test/split.jl @@ -83,6 +83,38 @@ let @test (@splitcombine function (x::T, y::Vector{U}) where T <: U where U (T, U) end)(1, Number[2.0]) == (Int, Number) + + # Test for lambda expressions with keyword arguments + @test (@splitcombine (a::Int; b=2) -> a + b)(1) === 3 + @test (@splitcombine (a::Int; b::Float64=2.0) -> Float64(a) + b)(1) === 3.0 + @test (@splitcombine (a::Int, x; b=2, c=3) -> a + b + c + x)(1, 4) === 10 + @test (@splitcombine (a::Int, x=2) -> a + x)(1) === 3 + @test (@splitcombine (a::Int, x=2; y) -> a + x + y)(1; y=3) === 6 + @test (@splitcombine (a, x::Int=2; y) -> a + x + y)(1; y=3) === 6 + @test (@splitcombine (a::Int, x::Int=2; y) -> a + x + y)(1; y=3) === 6 + + # With tuple unpacking + @test (@splitcombine (((a, b)::Tuple{Int, Int}, c; d=1) -> a + b + c + d))((1, 2), 3; d=4) === 10 + @test (@splitcombine ((c, (a, b); d=1) -> a + b + c + d))(3, (1, 2); d=4) === 10 + @test (@splitcombine ((c, (a, b); d) -> a + b + c + d))(3, (1, 2); d=4) === 10 + + # Test for single varargs argument in lambda + @test splitdef(Meta.parse("(args...) -> 0"))[:args] == [:(args...)] + @test (@splitcombine (args...) -> sum(args))(1, 2, 3) == 6 + @test (@splitcombine (args::Int...) -> sum(args))(1, 2, 3) == 6 + @test (@splitcombine (args::Int...; y=2) -> sum(args) + y)(1, 2, 3) == 8 + @test (@splitcombine (arg, args::Int...; y=2) -> arg + sum(args) + y)(1, 2, 3) == 8 + @test (@splitcombine (::Int...) -> 1)(1, 2, 3) === 1 + + # Splatted keyword arguments + @test (@splitcombine (a::Int; kws...) -> a + sum(values(kws)))(1; b=2, c=3) == 6 + @test (@splitcombine (; kws...) -> sum(values(kws)))(b=2, c=3) == 5 + @test (@splitcombine (a::Int; b, kws...) -> a + b + sum(values(kws)))(1; b=2, c=3) == 6 + @test (@splitcombine (a::Int; b=2, kws...) -> a + b + sum(values(kws)))(1; c=3) == 6 + + # Both splatted positional and keyword arguments + @test (@splitcombine (a::Int, args::Int...; kws...) -> a + sum(args) + sum(values(kws)))(1, 2, 3; b=4, c=5) == 15 + @test (@splitcombine (a, ::Int...; b, kws...) -> a + sum(values(kws)))(1, 2, 3; b=4, c=5) == 1 + 5 end @testset "combinestructdef, splitstructdef" begin From 9c62605d70e81a60f7222c8917b686d49c9da5e8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Apr 2025 18:26:06 +0100 Subject: [PATCH 3/5] fix: misnaming function as `tuple` --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c9b5c96..45998db 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -262,7 +262,7 @@ function longdef1(ex) if isexpr(arg, :tuple) && length(arg.args) == 1 && isexpr(arg.args[1], :parameters) # Special case (; kws...) -> - fcall = Expr(:call, :tuple, arg.args[1]) + fcall = Expr(:tuple, arg.args[1]) Expr(:function, fcall, body) elseif isexpr(arg, :block) && any(a -> isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw), arg.args) @@ -278,7 +278,7 @@ function longdef1(ex) end end end - fcall = Expr(:call, :tuple, Expr(:parameters, kw_args...), pos_args...) + fcall = Expr(:tuple, Expr(:parameters, kw_args...), pos_args...) Expr(:function, fcall, body) elseif isexpr(arg, :...) From 0786002c6b2d085e6a03503ecc333df2a5e918ab Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 13 Apr 2025 18:43:26 +0100 Subject: [PATCH 4/5] fix: no-arg but kwarg splat --- src/utils.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 45998db..9668a06 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -351,8 +351,13 @@ function splitdef(fdef) (func_(args__)) | (func_(args__)::rtype_))) elseif isexpr(fcall_nowhere, :tuple) - if length(fcall_nowhere.args) > 1 && isexpr(fcall_nowhere.args[1], :parameters) - args = fcall_nowhere.args[2:end] + if length(fcall_nowhere.args) > 0 && isexpr(fcall_nowhere.args[1], :parameters) + # Handle both cases: parameters with args and parameters only + if length(fcall_nowhere.args) > 1 + args = fcall_nowhere.args[2:end] + else + args = [] + end kwargs = fcall_nowhere.args[1].args else args = fcall_nowhere.args From b3e148c90229ec8ac82c220bd98fa9cdffd01084 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 14 Apr 2025 17:43:32 +0100 Subject: [PATCH 5/5] fix: issue in generating `:kw` for anonymous functions --- src/utils.jl | 7 ++++++- test/split.jl | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 9668a06..cdce3ac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -271,7 +271,12 @@ function longdef1(ex) kw_args = [] for a in arg.args if !(a isa LineNumberNode) - if isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw) + if isexpr(a, :...) + push!(kw_args, a) + elseif isexpr(a, :(=)) + # Transform = to :kw for keyword arguments + push!(kw_args, Expr(:kw, a.args[1], a.args[2])) + elseif isexpr(a, :kw) push!(kw_args, a) else push!(pos_args, a) diff --git a/test/split.jl b/test/split.jl index b28b879..d00500b 100644 --- a/test/split.jl +++ b/test/split.jl @@ -115,6 +115,14 @@ let # Both splatted positional and keyword arguments @test (@splitcombine (a::Int, args::Int...; kws...) -> a + sum(args) + sum(values(kws)))(1, 2, 3; b=4, c=5) == 15 @test (@splitcombine (a, ::Int...; b, kws...) -> a + sum(values(kws)))(1, 2, 3; b=4, c=5) == 1 + 5 + + # Issue with longdef + ex = longdef(:((a::Int; b=2) -> a + b)) + any_kw(ex) = ex isa Expr ? (any_kw(ex.head) || any(any_kw, ex.args)) : ex == :kw + @test any_kw(ex) + ## ^Ensure we get a :kw expression in the output AST + @test eval(ex) isa Function + ## Shouldn't have issues evaluating end @testset "combinestructdef, splitstructdef" begin