@@ -91,7 +91,8 @@ template <typename T, typename = void>
9191struct has_cast : std::false_type {};
9292
9393template <typename T>
94- struct has_cast <T, std::void_t <decltype (std::declval<const T>().template cast<int >())>> : std::true_type {};
94+ struct has_cast <T, std::void_t <decltype (std::declval<const T>().template cast<int >())>>
95+ : std::true_type {};
9596
9697// Helper variable template for easier usage
9798template <typename T>
@@ -228,16 +229,17 @@ struct params_trait<T, std::enable_if_t<is_sparse_matrix_v<T>>> {
228229template <typename _Scalar>
229230struct params_trait <std::vector<_Scalar>> {
230231 using T = typename std::vector<_Scalar>;
231- using Scalar = _Scalar; // The scalar type
232+ using Scalar = _Scalar; // The scalar type
233+ using ScalarParamsTraits = params_trait<Scalar>;
232234 static constexpr Index Dims = Dynamic; // Compile-time parameters dimensions
233235 // Execution-time parameters dimensions
234236 static Index dims (const T& v) {
235- constexpr int ScalarDims = params_trait<Scalar> ::Dims;
237+ constexpr int ScalarDims = ScalarParamsTraits ::Dims;
236238 if constexpr (std::is_scalar_v<Scalar> || ScalarDims == 1 ) {
237239 return static_cast <int >(v.size ());
238240 } else if constexpr (ScalarDims == Dynamic) {
239241 int d = 0 ;
240- for (std::size_t i = 0 ; i < v.size (); ++i) d += params_trait<Scalar> ::dims (v[i]);
242+ for (std::size_t i = 0 ; i < v.size (); ++i) d += ScalarParamsTraits ::dims (v[i]);
241243 return d;
242244 } else {
243245 return static_cast <int >(v.size ()) * ScalarDims;
@@ -246,20 +248,23 @@ struct params_trait<std::vector<_Scalar>> {
246248 // Cast to a new type, only needed when using automatic differentiation
247249 template <typename T2>
248250 static auto cast (const T& v) {
249- std::vector<T2> o (v.size ());
250- for (std::size_t i = 0 ; i < v.size (); ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
251+ using Scalar2 =
252+ std::decay_t <decltype (ScalarParamsTraits::template cast<T2>(std::declval<Scalar>()))>;
253+ std::vector<Scalar2> o;
254+ o.reserve (v.size ());
255+ for (auto & x : v) o.emplace_back (ScalarParamsTraits::template cast<T2>(x));
251256 return o;
252257 }
253258 // Define update / manifold
254259 static void PlusEq (T& v, const auto & delta) {
255260 for (std::size_t i = 0 ; i < v.size (); ++i) {
256- if constexpr (std::is_scalar_v<Scalar> || params_trait<Scalar> ::Dims == 1 )
261+ if constexpr (std::is_scalar_v<Scalar> || ScalarParamsTraits ::Dims == 1 )
257262 v[i] += delta[i];
258- else if constexpr (params_trait<Scalar> ::Dims != Dynamic) {
259- params_trait<Scalar> ::PlusEq (v[i], delta. template segment <params_trait<Scalar>::Dims> (
260- i * params_trait<Scalar> ::Dims));
263+ else if constexpr (ScalarParamsTraits ::Dims != Dynamic) {
264+ ScalarParamsTraits ::PlusEq (
265+ v[i], delta. template segment <ScalarParamsTraits::Dims>( i * ScalarParamsTraits ::Dims));
261266 } else {
262- params_trait<Scalar> ::PlusEq (v[i], delta.segment (i, i * params_trait<Scalar> ::dims (v[i])));
267+ ScalarParamsTraits ::PlusEq (v[i], delta.segment (i, i * ScalarParamsTraits ::dims (v[i])));
263268 }
264269 }
265270 }
@@ -270,18 +275,19 @@ template <typename _Scalar, std::size_t N>
270275struct params_trait <std::array<_Scalar, N>> {
271276 using T = typename std::array<_Scalar, N>;
272277 using Scalar = _Scalar; // The scalar type
278+ using ScalarParamsTraits = params_trait<Scalar>;
273279 static constexpr Index Dims =
274- params_trait<Scalar> ::Dims == Dynamic
280+ ScalarParamsTraits ::Dims == Dynamic
275281 ? Dynamic
276- : N * params_trait<Scalar> ::Dims; // Compile-time parameters dimensions
282+ : N * ScalarParamsTraits ::Dims; // Compile-time parameters dimensions
277283 // Execution-time parameters dimensions
278284 static Index dims (const T& v) {
279- constexpr int ScalarDims = params_trait<Scalar> ::Dims;
285+ constexpr int ScalarDims = ScalarParamsTraits ::Dims;
280286 if constexpr (std::is_scalar_v<Scalar> || ScalarDims == 1 ) {
281287 return N;
282288 } else if constexpr (ScalarDims == Dynamic) {
283289 int d = 0 ;
284- for (std::size_t i = 0 ; i < N; ++i) d += params_trait<Scalar> ::dims (v[i]);
290+ for (std::size_t i = 0 ; i < N; ++i) d += ScalarParamsTraits ::dims (v[i]);
285291 return d;
286292 } else {
287293 return static_cast <Index>(v.size ()) * ScalarDims;
@@ -291,20 +297,22 @@ struct params_trait<std::array<_Scalar, N>> {
291297 // Cast to a new type, only needed when using automatic differentiation
292298 template <typename T2>
293299 static auto cast (const T& v) {
294- std::array<T2, N> o;
295- for (std::size_t i = 0 ; i < N; ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
300+ using Scalar2 =
301+ std::decay_t <decltype (ScalarParamsTraits::template cast<T2>(std::declval<Scalar>()))>;
302+ std::array<Scalar2, N> o;
303+ for (std::size_t i = 0 ; i < N; ++i) o[i] = ScalarParamsTraits::template cast<T2>(v[i]);
296304 return o;
297305 }
298306 // Define update / manifold
299307 static void PlusEq (T& v, const auto & delta) {
300308 for (std::size_t i = 0 ; i < N; ++i) {
301- if constexpr (std::is_scalar_v<Scalar> || params_trait<Scalar> ::Dims == 1 )
309+ if constexpr (std::is_scalar_v<Scalar> || ScalarParamsTraits ::Dims == 1 )
302310 v[i] += delta[i];
303- else if constexpr (params_trait<Scalar> ::Dims != Dynamic) {
304- params_trait<Scalar> ::PlusEq (v[i], delta. template segment <params_trait<Scalar>::Dims> (
305- i * params_trait<Scalar> ::Dims));
311+ else if constexpr (ScalarParamsTraits ::Dims != Dynamic) {
312+ ScalarParamsTraits ::PlusEq (
313+ v[i], delta. template segment <ScalarParamsTraits::Dims>( i * ScalarParamsTraits ::Dims));
306314 } else {
307- params_trait<Scalar> ::PlusEq (v[i], delta.segment (i, i * params_trait<Scalar> ::dims (v[i])));
315+ ScalarParamsTraits ::PlusEq (v[i], delta.segment (i, i * ScalarParamsTraits ::dims (v[i])));
308316 }
309317 }
310318 }
@@ -315,27 +323,39 @@ template <typename T1, typename T2>
315323struct params_trait <std::pair<T1, T2>> {
316324 using T = std::pair<T1, T2>;
317325 using Scalar = typename params_trait<T1>::Scalar;
326+ using Scalar1ParamsTraits = params_trait<T1>;
327+ using Scalar2ParamsTraits = params_trait<T2>;
328+ // Compile-time parameters dimensions
318329 static constexpr Index Dims =
319- (params_trait<T1> ::Dims == Dynamic || params_trait<T2> ::Dims == Dynamic)
330+ (Scalar1ParamsTraits ::Dims == Dynamic || Scalar2ParamsTraits ::Dims == Dynamic)
320331 ? Dynamic
321- : params_trait<T1> ::Dims + params_trait<T2> ::Dims; // Compile-time parameters dimensions
332+ : Scalar1ParamsTraits ::Dims + Scalar2ParamsTraits ::Dims;
322333
323334 // Execution-time parameters dimensions
324335 static Index dims (const T& v) {
325- return params_trait<T1> ::dims (v.first ) + params_trait<T2> ::dims (v.second );
336+ return Scalar1ParamsTraits ::dims (v.first ) + Scalar2ParamsTraits ::dims (v.second );
326337 }
327338 // Cast to a new type, only needed when using automatic differentiation
328339 template <typename T3>
329340 static auto cast (const T& v) {
330- std::pair<T1, T2> o;
331- o.first = params_trait<T1>::template cast<T3>(v.first );
332- o.second = params_trait<T2>::template cast<T3>(v.second );
341+ using Scalar1 =
342+ std::decay_t <decltype (Scalar1ParamsTraits::template cast<T3>(std::declval<T1>()))>;
343+ using Scalar2 =
344+ std::decay_t <decltype (Scalar2ParamsTraits::template cast<T3>(std::declval<T2>()))>;
345+ std::pair<Scalar1, Scalar2> o{Scalar1ParamsTraits::template cast<T3>(v.first ),
346+ Scalar2ParamsTraits::template cast<T3>(v.second )};
333347 return o;
334348 }
335349 // Define update / manifold
336350 static void PlusEq (T& v, const auto & delta) {
337- params_trait<T1>::PlusEq (v.first , delta.head (params_trait<T1>::dims (v.first )));
338- params_trait<T2>::PlusEq (v.second , delta.tail (params_trait<T1>::dims (v.second )));
351+ if constexpr (Scalar1ParamsTraits::Dims == Dynamic)
352+ Scalar1ParamsTraits::PlusEq (v.first , delta.head (Scalar1ParamsTraits::dims (v.first )));
353+ else
354+ Scalar1ParamsTraits::PlusEq (v.first , delta.template head <Scalar1ParamsTraits::Dims>());
355+ if constexpr (Scalar2ParamsTraits::Dims == Dynamic)
356+ Scalar2ParamsTraits::PlusEq (v.second , delta.tail (Scalar2ParamsTraits::dims (v.second )));
357+ else
358+ Scalar2ParamsTraits::PlusEq (v.first , delta.template tail <Scalar2ParamsTraits::Dims>());
339359 }
340360};
341361
0 commit comments