Skip to content

Commit 1066b79

Browse files
committed
Allow safe_cfunction to take regular pointer arguments
1 parent bdf5c80 commit 1066b79

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

examples/functions.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,20 @@ JLCXX_MODULE init_test_module(jlcxx::Module& mod)
353353
{
354354
b = !b;
355355
});
356+
357+
mod.method("test_safe_cfunction_uint", [](jlcxx::SafeCFunction f_data)
358+
{
359+
auto f = jlcxx::make_function_pointer<int(unsigned int*,int)>(f_data);
360+
unsigned int buffer[] = {1,2,3};
361+
return f(buffer, 3);
362+
});
363+
364+
mod.method("test_safe_cfunction_uint64", [](jlcxx::SafeCFunction f_data)
365+
{
366+
auto f = jlcxx::make_function_pointer<uint64_t*(int)>(f_data);
367+
uint64_t* buffer = f(3);
368+
return buffer[0] + buffer[1] + buffer[2];
369+
});
356370
}
357371

358372
}

include/jlcxx/functions.hpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,29 @@ struct ConvertToCpp<SafeCFunction>
137137

138138
namespace detail
139139
{
140+
// Allow fundamental pointer types to be passed as e.g. Ptr{Int32} instead of CxxPtr{Int32}
141+
template<typename T>
142+
struct FundamentalPtrT
143+
{
144+
static jl_datatype_t* value()
145+
{
146+
return julia_type<T>();
147+
}
148+
};
149+
150+
template<typename T>
151+
struct FundamentalPtrT<T*>
152+
{
153+
static jl_datatype_t* value()
154+
{
155+
if constexpr (std::is_fundamental_v<T>)
156+
{
157+
return (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)julia_type<T>());
158+
}
159+
return julia_type<T*>();
160+
}
161+
};
162+
140163
template<typename SignatureT>
141164
struct SplitSignature;
142165

@@ -146,11 +169,28 @@ namespace detail
146169
typedef R return_type;
147170
typedef R(*fptr_t)(ArgsT...);
148171

149-
std::vector<jl_datatype_t*> operator()()
172+
jl_datatype_t* expected_return_type()
150173
{
174+
create_if_not_exists<R>();
175+
return julia_type<return_type>();
176+
}
177+
178+
jl_datatype_t* fundamental_ptr_return_type()
179+
{
180+
return FundamentalPtrT<return_type>::value();
181+
}
182+
183+
std::vector<jl_datatype_t*> arg_types()
184+
{
185+
(create_if_not_exists<ArgsT>(), ...);
151186
return std::vector<jl_datatype_t*>({julia_type<ArgsT>()...});
152187
}
153188

189+
std::vector<jl_datatype_t*> fundamental_ptr_types()
190+
{
191+
return std::vector<jl_datatype_t*>({FundamentalPtrT<ArgsT>::value()...});
192+
}
193+
154194
fptr_t cast_ptr(void* ptr)
155195
{
156196
return reinterpret_cast<fptr_t>(ptr);
@@ -166,15 +206,16 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
166206
JL_GC_PUSH3(&data.fptr, &data.return_type, &data.argtypes);
167207

168208
// Check return type
169-
jl_datatype_t* expected_rt = julia_type<typename SplitterT::return_type>();
170-
if(expected_rt != data.return_type)
209+
jl_datatype_t* expected_rt = SplitterT().expected_return_type();
210+
if(expected_rt != data.return_type && SplitterT().fundamental_ptr_return_type() != data.return_type)
171211
{
172212
JL_GC_POP();
173213
throw std::runtime_error("Incorrect datatype for cfunction return type, expected " + julia_type_name(expected_rt) + " but got " + julia_type_name(data.return_type));
174214
}
175215

176216
// Check arguments
177-
const std::vector<jl_datatype_t*> expected_argstypes = SplitterT()();
217+
const std::vector<jl_datatype_t*> expected_argstypes = SplitterT().arg_types();
218+
const std::vector<jl_datatype_t*> fundamental_ptr_argstypes = SplitterT().fundamental_ptr_types();
178219
ArrayRef<jl_value_t*> argtypes(data.argtypes);
179220
const int nb_args = expected_argstypes.size();
180221
if(nb_args != static_cast<int>(argtypes.size()))
@@ -187,7 +228,7 @@ typename detail::SplitSignature<SignatureT>::fptr_t make_function_pointer(SafeCF
187228
for(int i = 0; i != nb_args; ++i)
188229
{
189230
jl_datatype_t* argt = (jl_datatype_t*)argtypes[i];
190-
if(argt != expected_argstypes[i])
231+
if(argt != expected_argstypes[i] && argt != fundamental_ptr_argstypes[i])
191232
{
192233
std::stringstream err_sstr;
193234
err_sstr << "Incorrect argument type for cfunction at position " << i+1 << ", expected: " << julia_type_name(expected_argstypes[i]) << ", obtained: " << julia_type_name(argt);

include/jlcxx/jlcxx_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#define JLCXX_VERSION_MAJOR 0
1818
#define JLCXX_VERSION_MINOR 14
19-
#define JLCXX_VERSION_PATCH 7
19+
#define JLCXX_VERSION_PATCH 8
2020

2121
// From https://stackoverflow.com/questions/5459868/concatenate-int-to-string-using-c-preprocessor
2222
#define __JLCXX_STR_HELPER(x) #x

0 commit comments

Comments
 (0)