Skip to content

Commit 1df9fb6

Browse files
Merge pull request #1877 from ChrisRackauckas-Claude/fix-ctarget-ssqrt
Map solver-safe functions (ssqrt/scbrt/slog) to math.h names in CTarget
2 parents abd8b82 + 36a9300 commit 1df9fb6

2 files changed

Lines changed: 24 additions & 0 deletions

File tree

src/build_function.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,11 @@ numbered_expr(c,args...;kwargs...) = c
675675
numbered_expr(c::Num,args...;kwargs...) = error("Num found")
676676

677677

678+
# The solver introduces "safe" variants of some math functions that handle
679+
# negative/complex inputs in Julia. They have no C equivalent, so map them back
680+
# to the standard math.h names when generating C code.
681+
const c_safe_function_renames = Dict(:ssqrt => :sqrt, :scbrt => :cbrt, :slog => :log)
682+
678683
# Replace certain multiplication and power expressions so they form valid C code
679684
# Extra factors of 1 are hopefully eliminated by the C compiler
680685
function coperators(expr)
@@ -684,6 +689,9 @@ function coperators(expr)
684689
coperators(e)
685690
end
686691
end
692+
if expr.head == :call && expr.args[1] isa Symbol
693+
expr.args[1] = get(c_safe_function_renames, expr.args[1], expr.args[1])
694+
end
687695
for i in eachindex(expr.args)
688696
if expr.args[i] isa Rational
689697
expr.args[i] = float(expr.args[i]) # Evaluate rational numbers to floating-point

test/build_targets.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ let
7070
end
7171

7272

73+
# Safe solver functions (ssqrt/scbrt/slog) must map to their math.h names in C
74+
# (https://github.com/JuliaSymbolics/Symbolics.jl/issues/1873)
75+
let
76+
@variables y
77+
expression = (1//2) * Symbolics.term(Symbolics.ssqrt, 4y) +
78+
Symbolics.term(Symbolics.scbrt, y) + Symbolics.term(Symbolics.slog, y)
79+
cfunc = build_function([expression], y; target = Symbolics.CTarget(), expression = Val{true})
80+
81+
@test occursin("sqrt(", cfunc)
82+
@test occursin("cbrt(", cfunc)
83+
@test occursin("log(", cfunc)
84+
@test !occursin("ssqrt", cfunc)
85+
@test !occursin("scbrt", cfunc)
86+
@test !occursin("slog", cfunc)
87+
end
88+
7389
# Matrix StanTarget test
7490
let
7591
@variables x[1:4] y[1:4] z[1:4]

0 commit comments

Comments
 (0)