diff --git a/src/utils.jl b/src/utils.jl index 9b7a48b..cdce3ac 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -259,7 +259,39 @@ isshortdef(ex) = (@capture(ex, (fcall_ = body_)) && function longdef1(ex) if @capture(ex, (arg_ -> body_)) - Expr(:function, arg isa Symbol ? :($arg,) : arg, body) + + if isexpr(arg, :tuple) && length(arg.args) == 1 && isexpr(arg.args[1], :parameters) + # Special case (; kws...) -> + fcall = Expr(:tuple, arg.args[1]) + + Expr(:function, fcall, body) + elseif isexpr(arg, :block) && any(a -> isexpr(a, :...) || isexpr(a, :(=)) || isexpr(a, :kw), arg.args) + # Has keywords in a block + pos_args = [] + kw_args = [] + for a in arg.args + if !(a isa LineNumberNode) + if isexpr(a, :...) + push!(kw_args, a) + elseif isexpr(a, :(=)) + # Transform = to :kw for keyword arguments + push!(kw_args, Expr(:kw, a.args[1], a.args[2])) + elseif isexpr(a, :kw) + push!(kw_args, a) + else + push!(pos_args, a) + end + end + end + fcall = Expr(:tuple, Expr(:parameters, kw_args...), pos_args...) + + Expr(:function, fcall, body) + elseif isexpr(arg, :...) + # Special case for a varargs argument + Expr(:function, Expr(:tuple, arg), body) + else + Expr(:function, arg isa Symbol ? :($arg,) : arg, body) + end elseif isshortdef(ex) @assert @capture(ex, (fcall_ = body_)) Expr(:function, fcall, body) @@ -324,8 +356,13 @@ function splitdef(fdef) (func_(args__)) | (func_(args__)::rtype_))) elseif isexpr(fcall_nowhere, :tuple) - if length(fcall_nowhere.args) > 1 && isexpr(fcall_nowhere.args[1], :parameters) - args = fcall_nowhere.args[2:end] + if length(fcall_nowhere.args) > 0 && isexpr(fcall_nowhere.args[1], :parameters) + # Handle both cases: parameters with args and parameters only + if length(fcall_nowhere.args) > 1 + args = fcall_nowhere.args[2:end] + else + args = [] + end kwargs = fcall_nowhere.args[1].args else args = fcall_nowhere.args diff --git a/test/split.jl b/test/split.jl index 36bf768..d00500b 100644 --- a/test/split.jl +++ b/test/split.jl @@ -83,6 +83,46 @@ let @test (@splitcombine function (x::T, y::Vector{U}) where T <: U where U (T, U) end)(1, Number[2.0]) == (Int, Number) + + # Test for lambda expressions with keyword arguments + @test (@splitcombine (a::Int; b=2) -> a + b)(1) === 3 + @test (@splitcombine (a::Int; b::Float64=2.0) -> Float64(a) + b)(1) === 3.0 + @test (@splitcombine (a::Int, x; b=2, c=3) -> a + b + c + x)(1, 4) === 10 + @test (@splitcombine (a::Int, x=2) -> a + x)(1) === 3 + @test (@splitcombine (a::Int, x=2; y) -> a + x + y)(1; y=3) === 6 + @test (@splitcombine (a, x::Int=2; y) -> a + x + y)(1; y=3) === 6 + @test (@splitcombine (a::Int, x::Int=2; y) -> a + x + y)(1; y=3) === 6 + + # With tuple unpacking + @test (@splitcombine (((a, b)::Tuple{Int, Int}, c; d=1) -> a + b + c + d))((1, 2), 3; d=4) === 10 + @test (@splitcombine ((c, (a, b); d=1) -> a + b + c + d))(3, (1, 2); d=4) === 10 + @test (@splitcombine ((c, (a, b); d) -> a + b + c + d))(3, (1, 2); d=4) === 10 + + # Test for single varargs argument in lambda + @test splitdef(Meta.parse("(args...) -> 0"))[:args] == [:(args...)] + @test (@splitcombine (args...) -> sum(args))(1, 2, 3) == 6 + @test (@splitcombine (args::Int...) -> sum(args))(1, 2, 3) == 6 + @test (@splitcombine (args::Int...; y=2) -> sum(args) + y)(1, 2, 3) == 8 + @test (@splitcombine (arg, args::Int...; y=2) -> arg + sum(args) + y)(1, 2, 3) == 8 + @test (@splitcombine (::Int...) -> 1)(1, 2, 3) === 1 + + # Splatted keyword arguments + @test (@splitcombine (a::Int; kws...) -> a + sum(values(kws)))(1; b=2, c=3) == 6 + @test (@splitcombine (; kws...) -> sum(values(kws)))(b=2, c=3) == 5 + @test (@splitcombine (a::Int; b, kws...) -> a + b + sum(values(kws)))(1; b=2, c=3) == 6 + @test (@splitcombine (a::Int; b=2, kws...) -> a + b + sum(values(kws)))(1; c=3) == 6 + + # Both splatted positional and keyword arguments + @test (@splitcombine (a::Int, args::Int...; kws...) -> a + sum(args) + sum(values(kws)))(1, 2, 3; b=4, c=5) == 15 + @test (@splitcombine (a, ::Int...; b, kws...) -> a + sum(values(kws)))(1, 2, 3; b=4, c=5) == 1 + 5 + + # Issue with longdef + ex = longdef(:((a::Int; b=2) -> a + b)) + any_kw(ex) = ex isa Expr ? (any_kw(ex.head) || any(any_kw, ex.args)) : ex == :kw + @test any_kw(ex) + ## ^Ensure we get a :kw expression in the output AST + @test eval(ex) isa Function + ## Shouldn't have issues evaluating end @testset "combinestructdef, splitstructdef" begin