Skip to content

Commit 2cef9fc

Browse files
Merge pull request #47 from julien-michot/feature/params-traits-stl-compound
Support any compounded parameters
2 parents da457b0 + 85fd10c commit 2cef9fc

File tree

4 files changed

+162
-34
lines changed

4 files changed

+162
-34
lines changed

docs/API.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,11 @@ With `tinyopt`, you can directly optimize several types of parameters `x`, namel
121121
* `std::vector` of a scalars or another type
122122
* `Eigen::Vector` of fixed or dynamic size
123123
* `Eigen::Matrix` of fixed or dynamic size
124-
* Your custorm class or struct, see below
124+
* Your custom class or struct, see [User defined parameters](#user-defined-parameters)
125125

126-
You can also use a one level nesting of types as long as the dimensions of the nested type are known at compile time,
127-
e.g. `std::array<Vec2f, 2>` or `std::vector<Vec3>`.
126+
You can also use any levels of nesting for known types, e.g. `std::array<Vec2f, 2>` or `std::pair<std::vector<Vec3>, VecX>`.
128127

129-
`tinyopt` will detect whether the size is known at compile time and use optimized data structs to make the optimization faster.
128+
`tinyopt` will detect whether the size is known at compile time and use this information to make the optimization faster.
130129

131130
Residuals of the following types can also be returned
132131

include/tinyopt/traits.h

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ template <typename T, typename = void>
9191
struct has_cast : std::false_type {};
9292

9393
template <typename T>
94-
struct has_cast<T, std::void_t<decltype(std::declval<const T>().template cast<int>())>> : std::true_type {};
94+
struct has_cast<T, std::void_t<decltype(std::declval<const T>().template cast<int>())>>
95+
: std::true_type {};
9596

9697
// Helper variable template for easier usage
9798
template <typename T>
@@ -228,16 +229,17 @@ struct params_trait<T, std::enable_if_t<is_sparse_matrix_v<T>>> {
228229
template <typename _Scalar>
229230
struct params_trait<std::vector<_Scalar>> {
230231
using T = typename std::vector<_Scalar>;
231-
using Scalar = _Scalar; // The scalar type
232+
using Scalar = _Scalar; // The scalar type
233+
using ScalarParamsTraits = params_trait<Scalar>;
232234
static constexpr Index Dims = Dynamic; // Compile-time parameters dimensions
233235
// Execution-time parameters dimensions
234236
static Index dims(const T& v) {
235-
constexpr int ScalarDims = params_trait<Scalar>::Dims;
237+
constexpr int ScalarDims = ScalarParamsTraits::Dims;
236238
if constexpr (std::is_scalar_v<Scalar> || ScalarDims == 1) {
237239
return static_cast<int>(v.size());
238240
} else if constexpr (ScalarDims == Dynamic) {
239241
int d = 0;
240-
for (std::size_t i = 0; i < v.size(); ++i) d += params_trait<Scalar>::dims(v[i]);
242+
for (std::size_t i = 0; i < v.size(); ++i) d += ScalarParamsTraits::dims(v[i]);
241243
return d;
242244
} else {
243245
return static_cast<int>(v.size()) * ScalarDims;
@@ -246,20 +248,23 @@ struct params_trait<std::vector<_Scalar>> {
246248
// Cast to a new type, only needed when using automatic differentiation
247249
template <typename T2>
248250
static auto cast(const T& v) {
249-
std::vector<T2> o(v.size());
250-
for (std::size_t i = 0; i < v.size(); ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
251+
using Scalar2 =
252+
std::decay_t<decltype(ScalarParamsTraits::template cast<T2>(std::declval<Scalar>()))>;
253+
std::vector<Scalar2> o;
254+
o.reserve(v.size());
255+
for (auto& x : v) o.emplace_back(ScalarParamsTraits::template cast<T2>(x));
251256
return o;
252257
}
253258
// Define update / manifold
254259
static void PlusEq(T& v, const auto& delta) {
255260
for (std::size_t i = 0; i < v.size(); ++i) {
256-
if constexpr (std::is_scalar_v<Scalar> || params_trait<Scalar>::Dims == 1)
261+
if constexpr (std::is_scalar_v<Scalar> || ScalarParamsTraits::Dims == 1)
257262
v[i] += delta[i];
258-
else if constexpr (params_trait<Scalar>::Dims != Dynamic) {
259-
params_trait<Scalar>::PlusEq(v[i], delta.template segment<params_trait<Scalar>::Dims>(
260-
i * params_trait<Scalar>::Dims));
263+
else if constexpr (ScalarParamsTraits::Dims != Dynamic) {
264+
ScalarParamsTraits::PlusEq(
265+
v[i], delta.template segment<ScalarParamsTraits::Dims>(i * ScalarParamsTraits::Dims));
261266
} else {
262-
params_trait<Scalar>::PlusEq(v[i], delta.segment(i, i * params_trait<Scalar>::dims(v[i])));
267+
ScalarParamsTraits::PlusEq(v[i], delta.segment(i, i * ScalarParamsTraits::dims(v[i])));
263268
}
264269
}
265270
}
@@ -270,18 +275,19 @@ template <typename _Scalar, std::size_t N>
270275
struct params_trait<std::array<_Scalar, N>> {
271276
using T = typename std::array<_Scalar, N>;
272277
using Scalar = _Scalar; // The scalar type
278+
using ScalarParamsTraits = params_trait<Scalar>;
273279
static constexpr Index Dims =
274-
params_trait<Scalar>::Dims == Dynamic
280+
ScalarParamsTraits::Dims == Dynamic
275281
? Dynamic
276-
: N * params_trait<Scalar>::Dims; // Compile-time parameters dimensions
282+
: N * ScalarParamsTraits::Dims; // Compile-time parameters dimensions
277283
// Execution-time parameters dimensions
278284
static Index dims(const T& v) {
279-
constexpr int ScalarDims = params_trait<Scalar>::Dims;
285+
constexpr int ScalarDims = ScalarParamsTraits::Dims;
280286
if constexpr (std::is_scalar_v<Scalar> || ScalarDims == 1) {
281287
return N;
282288
} else if constexpr (ScalarDims == Dynamic) {
283289
int d = 0;
284-
for (std::size_t i = 0; i < N; ++i) d += params_trait<Scalar>::dims(v[i]);
290+
for (std::size_t i = 0; i < N; ++i) d += ScalarParamsTraits::dims(v[i]);
285291
return d;
286292
} else {
287293
return static_cast<Index>(v.size()) * ScalarDims;
@@ -291,20 +297,22 @@ struct params_trait<std::array<_Scalar, N>> {
291297
// Cast to a new type, only needed when using automatic differentiation
292298
template <typename T2>
293299
static auto cast(const T& v) {
294-
std::array<T2, N> o;
295-
for (std::size_t i = 0; i < N; ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
300+
using Scalar2 =
301+
std::decay_t<decltype(ScalarParamsTraits::template cast<T2>(std::declval<Scalar>()))>;
302+
std::array<Scalar2, N> o;
303+
for (std::size_t i = 0; i < N; ++i) o[i] = ScalarParamsTraits::template cast<T2>(v[i]);
296304
return o;
297305
}
298306
// Define update / manifold
299307
static void PlusEq(T& v, const auto& delta) {
300308
for (std::size_t i = 0; i < N; ++i) {
301-
if constexpr (std::is_scalar_v<Scalar> || params_trait<Scalar>::Dims == 1)
309+
if constexpr (std::is_scalar_v<Scalar> || ScalarParamsTraits::Dims == 1)
302310
v[i] += delta[i];
303-
else if constexpr (params_trait<Scalar>::Dims != Dynamic) {
304-
params_trait<Scalar>::PlusEq(v[i], delta.template segment<params_trait<Scalar>::Dims>(
305-
i * params_trait<Scalar>::Dims));
311+
else if constexpr (ScalarParamsTraits::Dims != Dynamic) {
312+
ScalarParamsTraits::PlusEq(
313+
v[i], delta.template segment<ScalarParamsTraits::Dims>(i * ScalarParamsTraits::Dims));
306314
} else {
307-
params_trait<Scalar>::PlusEq(v[i], delta.segment(i, i * params_trait<Scalar>::dims(v[i])));
315+
ScalarParamsTraits::PlusEq(v[i], delta.segment(i, i * ScalarParamsTraits::dims(v[i])));
308316
}
309317
}
310318
}
@@ -315,27 +323,39 @@ template <typename T1, typename T2>
315323
struct params_trait<std::pair<T1, T2>> {
316324
using T = std::pair<T1, T2>;
317325
using Scalar = typename params_trait<T1>::Scalar;
326+
using Scalar1ParamsTraits = params_trait<T1>;
327+
using Scalar2ParamsTraits = params_trait<T2>;
328+
// Compile-time parameters dimensions
318329
static constexpr Index Dims =
319-
(params_trait<T1>::Dims == Dynamic || params_trait<T2>::Dims == Dynamic)
330+
(Scalar1ParamsTraits::Dims == Dynamic || Scalar2ParamsTraits::Dims == Dynamic)
320331
? Dynamic
321-
: params_trait<T1>::Dims + params_trait<T2>::Dims; // Compile-time parameters dimensions
332+
: Scalar1ParamsTraits::Dims + Scalar2ParamsTraits::Dims;
322333

323334
// Execution-time parameters dimensions
324335
static Index dims(const T& v) {
325-
return params_trait<T1>::dims(v.first) + params_trait<T2>::dims(v.second);
336+
return Scalar1ParamsTraits::dims(v.first) + Scalar2ParamsTraits::dims(v.second);
326337
}
327338
// Cast to a new type, only needed when using automatic differentiation
328339
template <typename T3>
329340
static auto cast(const T& v) {
330-
std::pair<T1, T2> o;
331-
o.first = params_trait<T1>::template cast<T3>(v.first);
332-
o.second = params_trait<T2>::template cast<T3>(v.second);
341+
using Scalar1 =
342+
std::decay_t<decltype(Scalar1ParamsTraits::template cast<T3>(std::declval<T1>()))>;
343+
using Scalar2 =
344+
std::decay_t<decltype(Scalar2ParamsTraits::template cast<T3>(std::declval<T2>()))>;
345+
std::pair<Scalar1, Scalar2> o{Scalar1ParamsTraits::template cast<T3>(v.first),
346+
Scalar2ParamsTraits::template cast<T3>(v.second)};
333347
return o;
334348
}
335349
// Define update / manifold
336350
static void PlusEq(T& v, const auto& delta) {
337-
params_trait<T1>::PlusEq(v.first, delta.head(params_trait<T1>::dims(v.first)));
338-
params_trait<T2>::PlusEq(v.second, delta.tail(params_trait<T1>::dims(v.second)));
351+
if constexpr (Scalar1ParamsTraits::Dims == Dynamic)
352+
Scalar1ParamsTraits::PlusEq(v.first, delta.head(Scalar1ParamsTraits::dims(v.first)));
353+
else
354+
Scalar1ParamsTraits::PlusEq(v.first, delta.template head<Scalar1ParamsTraits::Dims>());
355+
if constexpr (Scalar2ParamsTraits::Dims == Dynamic)
356+
Scalar2ParamsTraits::PlusEq(v.second, delta.tail(Scalar2ParamsTraits::dims(v.second)));
357+
else
358+
Scalar2ParamsTraits::PlusEq(v.first, delta.template tail<Scalar2ParamsTraits::Dims>());
339359
}
340360
};
341361

tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
add_executable(tinyopt_test_traits traits.cpp)
3+
target_link_libraries(tinyopt_test_traits PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)
4+
add_test_target(tinyopt_test_traits)
25

36
add_executable(tinyopt_test_diff diff.cpp)
47
target_link_libraries(tinyopt_test_diff PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)

tests/traits.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright (C) 2025 Julien Michot. All Rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <type_traits>
16+
#if CATCH2_VERSION == 2
17+
#include <catch2/catch.hpp>
18+
#else
19+
#include <catch2/catch_approx.hpp>
20+
#include <catch2/catch_test_macros.hpp>
21+
#endif
22+
23+
#include <tinyopt/tinyopt.h>
24+
#include <tinyopt/traits.h>
25+
26+
using namespace tinyopt;
27+
using namespace tinyopt::nlls;
28+
29+
TEST_CASE("tinyopt_params_traits_stl") {
30+
SECTION("std::array<float, 5>") {
31+
using Params = std::array<float, 5>;
32+
Params x;
33+
REQUIRE(traits::params_trait<Params>::Dims == 5);
34+
REQUIRE(traits::params_trait<Params>::dims(x) == 5);
35+
}
36+
SECTION("std::vector<float>") {
37+
using Params = std::vector<double>;
38+
Params x{{0.4, 0.5}};
39+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
40+
REQUIRE(traits::params_trait<Params>::dims(x) == 2);
41+
}
42+
SECTION("std::array<Vec4, 5>") {
43+
using Params = std::array<Vec4, 5>;
44+
Params x;
45+
REQUIRE(traits::params_trait<Params>::Dims == 4 * 5);
46+
REQUIRE(traits::params_trait<Params>::dims(x) == 4 * 5);
47+
}
48+
SECTION("std::array<VecX, 2>") {
49+
using Params = std::array<VecX, 2>;
50+
using ptraits = traits::params_trait<Params>;
51+
Params x;
52+
REQUIRE(ptraits::Dims == Dynamic);
53+
REQUIRE(ptraits::dims(x) == 0);
54+
static_assert(std::is_same_v<std::decay_t<decltype(ptraits::template cast<float>(x))>,
55+
std::array<VecXf, 2>>,
56+
"Wrong casting");
57+
}
58+
SECTION("std::vector<Vec2>") {
59+
using Params = std::vector<Vec2>;
60+
using ptraits = traits::params_trait<Params>;
61+
Params x{{Vec2::Zero(), Vec2::Zero(), Vec2::Zero()}};
62+
REQUIRE(ptraits::Dims == Dynamic);
63+
REQUIRE(ptraits::dims(x) == 6);
64+
static_assert(std::is_same_v<std::decay_t<decltype(ptraits::template cast<float>(x))>,
65+
std::vector<Vec2f>>,
66+
"Wrong casting");
67+
}
68+
SECTION("std::pair<Vec2, Vec3>") {
69+
using Params = std::pair<Vec2, Vec3>;
70+
using ptraits = traits::params_trait<Params>;
71+
Params x;
72+
REQUIRE(ptraits::Dims == 2 + 3);
73+
REQUIRE(ptraits::dims(x) == 2 + 3);
74+
static_assert(std::is_same_v<std::decay_t<decltype(ptraits::template cast<float>(x))>,
75+
std::pair<Vec2f, Vec3f>>,
76+
"Wrong casting");
77+
}
78+
SECTION("std::pair<Vec2, VecX>") {
79+
using Params = std::pair<Vec2, VecX>;
80+
Params x = std::make_pair(Vec2::Zero(), VecX::Random(4));
81+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
82+
REQUIRE(traits::params_trait<Params>::dims(x) == 2 + 4);
83+
}
84+
SECTION("std::pair<vector<float>, Vec3>") {
85+
using Params = std::pair<std::vector<float>, Vec3>;
86+
Params x = std::make_pair(std::vector<float>{{1, 2, 3, 4}}, Vec3::Zero());
87+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
88+
REQUIRE(traits::params_trait<Params>::dims(x) == 4 + 3);
89+
}
90+
SECTION("std::pair<std::vector<Vec3>, std::array<VecX, 4>>") {
91+
using Params = std::pair<std::vector<Vec3>, std::array<VecX, 4>>;
92+
std::vector<Vec3> a{Vec3::Zero(), Vec3::Zero()};
93+
std::array<VecX, 4> b{{VecX::Random(5), VecX::Random(2), VecX::Random(0), VecX::Random(0)}};
94+
Params x = std::make_pair(a, b);
95+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
96+
REQUIRE(traits::params_trait<Params>::dims(x) == 6 + 7);
97+
}
98+
SECTION("std::pair<std::array<Vec3>, std::array<Vec2, 4>>") {
99+
using Params = std::pair<std::array<Vec3, 5>, std::array<Vec2, 4>>;
100+
std::array<Vec3, 5> a;
101+
std::array<Vec2, 4> b;
102+
Params x = std::make_pair(a, b);
103+
REQUIRE(traits::params_trait<Params>::Dims == 15 + 8);
104+
REQUIRE(traits::params_trait<Params>::dims(x) == 15 + 8);
105+
}
106+
}

0 commit comments

Comments
 (0)