@@ -137,6 +137,29 @@ struct ConvertToCpp<SafeCFunction>
137137
138138namespace detail
139139{
140+ // Allow fundamental pointer types to be passed as e.g. Ptr{Int32} instead of CxxPtr{Int32}
141+ template <typename T>
142+ struct FundamentalPtrT
143+ {
144+ static jl_datatype_t * value ()
145+ {
146+ return julia_type<T>();
147+ }
148+ };
149+
150+ template <typename T>
151+ struct FundamentalPtrT <T*>
152+ {
153+ static jl_datatype_t * value ()
154+ {
155+ if constexpr (std::is_fundamental_v<T>)
156+ {
157+ return (jl_datatype_t *)jl_apply_type1 ((jl_value_t *)jl_pointer_type, (jl_value_t *)julia_type<T>());
158+ }
159+ return julia_type<T*>();
160+ }
161+ };
162+
140163 template <typename SignatureT>
141164 struct SplitSignature ;
142165
@@ -146,11 +169,28 @@ namespace detail
146169 typedef R return_type;
147170 typedef R (*fptr_t )(ArgsT...);
148171
149- std::vector< jl_datatype_t *> operator () ()
172+ jl_datatype_t * expected_return_type ()
150173 {
174+ create_if_not_exists<R>();
175+ return julia_type<return_type>();
176+ }
177+
178+ jl_datatype_t * fundamental_ptr_return_type ()
179+ {
180+ return FundamentalPtrT<return_type>::value ();
181+ }
182+
183+ std::vector<jl_datatype_t *> arg_types ()
184+ {
185+ (create_if_not_exists<ArgsT>(), ...);
151186 return std::vector<jl_datatype_t *>({julia_type<ArgsT>()...});
152187 }
153188
189+ std::vector<jl_datatype_t *> fundamental_ptr_types ()
190+ {
191+ return std::vector<jl_datatype_t *>({FundamentalPtrT<ArgsT>::value ()...});
192+ }
193+
154194 fptr_t cast_ptr (void * ptr)
155195 {
156196 return reinterpret_cast <fptr_t >(ptr);
@@ -166,15 +206,16 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
166206 JL_GC_PUSH3 (&data.fptr , &data.return_type , &data.argtypes );
167207
168208 // Check return type
169- jl_datatype_t * expected_rt = julia_type< typename SplitterT::return_type> ();
170- if (expected_rt != data.return_type )
209+ jl_datatype_t * expected_rt = SplitterT (). expected_return_type ();
210+ if (expected_rt != data.return_type && SplitterT (). fundamental_ptr_return_type () != data. return_type )
171211 {
172212 JL_GC_POP ();
173213 throw std::runtime_error (" Incorrect datatype for cfunction return type, expected " + julia_type_name (expected_rt) + " but got " + julia_type_name (data.return_type ));
174214 }
175215
176216 // Check arguments
177- const std::vector<jl_datatype_t *> expected_argstypes = SplitterT ()();
217+ const std::vector<jl_datatype_t *> expected_argstypes = SplitterT ().arg_types ();
218+ const std::vector<jl_datatype_t *> fundamental_ptr_argstypes = SplitterT ().fundamental_ptr_types ();
178219 ArrayRef<jl_value_t *> argtypes (data.argtypes );
179220 const int nb_args = expected_argstypes.size ();
180221 if (nb_args != static_cast <int >(argtypes.size ()))
@@ -187,7 +228,7 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
187228 for (int i = 0 ; i != nb_args; ++i)
188229 {
189230 jl_datatype_t * argt = (jl_datatype_t *)argtypes[i];
190- if (argt != expected_argstypes[i])
231+ if (argt != expected_argstypes[i] && argt != fundamental_ptr_argstypes[i] )
191232 {
192233 std::stringstream err_sstr;
193234 err_sstr << " Incorrect argument type for cfunction at position " << i+1 << " , expected: " << julia_type_name (expected_argstypes[i]) << " , obtained: " << julia_type_name (argt);
0 commit comments