See https://discourse.julialang.org/t/input-convex-neural-network-is-not-convex-at-origin-in-lux-jl/133565/2 for context.

JAX also uses the `logaddexp` formulation: https://github.com/jax-ml/jax/blob/ff99faf218b29d65a2aeeee8b0c2bda09f75b6f4/jax/_src/nn/functions.py#L139C14-L139C23
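Below is a minimal sketch in plain Python (standard library only). It is not Lux.jl, NNlib, or JAX code. It assumes the function in question is `softplus(x) = log(1 + exp(x))` written as `logaddexp(x, 0)`. The sketch shows the stable `logaddexp` identity and its analytic derivative, `sigmoid(x)`, which equals 0.5 at the origin.

For contrast, it includes a hypothetical term-by-term chain rule over `max(x, 0) + log1p(exp(-|x|))`. That version uses the subgradient conventions `d/dx max(x, 0) = 0` and `sign(0) = 0` at `x == 0`. Under those conventions it returns 0 at the origin, even though it agrees with `sigmoid(x)` everywhere else. This is one way a gradient-based convexity check could fail at the origin. Whether it is exactly what happens in the linked thread is an assumption.

```python
import math


def softplus_logaddexp(x: float) -> float:
    """softplus(x) = log(1 + e^x), written as logaddexp(x, 0).

    Uses logaddexp(a, b) = max(a, b) + log1p(exp(-|a - b|)), so exp() is only
    ever called on a non-positive argument and cannot overflow.
    """
    a, b = x, 0.0
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def softplus_grad(x: float) -> float:
    """Analytic derivative: d/dx logaddexp(x, 0) = exp(x - logaddexp(x, 0)) = sigmoid(x)."""
    return math.exp(x - softplus_logaddexp(x))


def naive_piecewise_grad(x: float) -> float:
    """Hypothetical term-by-term chain rule over max(x, 0) + log1p(exp(-|x|)).

    Assumes the subgradient conventions d/dx max(x, 0) = 0 and sign(0) = 0
    at x == 0. Away from 0 this equals sigmoid(x); at 0 it returns 0.
    """
    relu_grad = 1.0 if x > 0 else 0.0
    sign = (x > 0) - (x < 0)  # -1, 0 or 1
    e = math.exp(-abs(x))
    # d/dx log1p(exp(-|x|)) = -exp(-|x|) / (1 + exp(-|x|)) * sign(x)
    tail_grad = -e / (1.0 + e) * sign
    return relu_grad + tail_grad


if __name__ == "__main__":
    # Compare both derivatives just left of, at, and just right of the origin.
    for x in (-1e-3, 0.0, 1e-3):
        print(x, softplus_grad(x), naive_piecewise_grad(x))
    # The logaddexp form stays finite where log1p(exp(x)) would overflow.
    print(softplus_logaddexp(1000.0))
```

The first function is the overflow-safe evaluation. The comparison between the last two shows why a formulation with a smooth derivative at `x == 0` matters when convexity is checked through gradients.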