Skip to content

Commit 7bb0eae

Browse files
committed
Support any compounded parameters
1 parent da457b0 commit 7bb0eae

File tree

4 files changed

+111
-11
lines changed

4 files changed

+111
-11
lines changed

docs/API.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,11 @@ With `tinyopt`, you can directly optimize several types of parameters `x`, namel
121121
* `std::vector` of a scalars or another type
122122
* `Eigen::Vector` of fixed or dynamic size
123123
* `Eigen::Matrix` of fixed or dynamic size
124-
* Your custorm class or struct, see below
124+
* Your custom class or struct, see [User defined parameters](#user-defined-parameters)
125125

126-
You can also use a one level nesting of types as long as the dimensions of the nested type are known at compile time,
127-
e.g. `std::array<Vec2f, 2>` or `std::vector<Vec3>`.
126+
You can also use any levels of nesting for known types, e.g. `std::array<Vec2f, 2>` or `std::pair<std::vector<Vec3>, VecX>`.
128127

129-
`tinyopt` will detect whether the size is known at compile time and use optimized data structs to make the optimization faster.
128+
`tinyopt` will detect whether the size is known at compile time and use this information to make the optimization faster.
130129

131130
Residuals of the following types can also be returned
132131

include/tinyopt/traits.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ template <typename T, typename = void>
9191
struct has_cast : std::false_type {};
9292

9393
template <typename T>
94-
struct has_cast<T, std::void_t<decltype(std::declval<const T>().template cast<int>())>> : std::true_type {};
94+
struct has_cast<T, std::void_t<decltype(std::declval<const T>().template cast<int>())>>
95+
: std::true_type {};
9596

9697
// Helper variable template for easier usage
9798
template <typename T>
@@ -246,8 +247,10 @@ struct params_trait<std::vector<_Scalar>> {
246247
// Cast to a new type, only needed when using automatic differentiation
247248
template <typename T2>
248249
static auto cast(const T& v) {
249-
std::vector<T2> o(v.size());
250-
for (std::size_t i = 0; i < v.size(); ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
250+
using Scalar2 = params_trait<Scalar>::Scalar;
251+
std::vector<Scalar2> o;
252+
o.reserve(v.size());
253+
for (auto& x : v) o.emplace_back(params_trait<Scalar>::template cast<T2>(x));
251254
return o;
252255
}
253256
// Define update / manifold
@@ -291,7 +294,8 @@ struct params_trait<std::array<_Scalar, N>> {
291294
// Cast to a new type, only needed when using automatic differentiation
292295
template <typename T2>
293296
static auto cast(const T& v) {
294-
std::array<T2, N> o;
297+
using Scalar2 = params_trait<Scalar>::Scalar;
298+
std::array<Scalar2, N> o;
295299
for (std::size_t i = 0; i < N; ++i) o[i] = params_trait<Scalar>::template cast<T2>(v[i]);
296300
return o;
297301
}
@@ -327,9 +331,10 @@ struct params_trait<std::pair<T1, T2>> {
327331
// Cast to a new type, only needed when using automatic differentiation
328332
template <typename T3>
329333
static auto cast(const T& v) {
330-
std::pair<T1, T2> o;
331-
o.first = params_trait<T1>::template cast<T3>(v.first);
332-
o.second = params_trait<T2>::template cast<T3>(v.second);
334+
using Scalar1 = params_trait<T1>::Scalar;
335+
using Scalar2 = params_trait<T2>::Scalar;
336+
std::pair<Scalar1, Scalar2> o{params_trait<T1>::template cast<T3>(v.first),
337+
params_trait<T2>::template cast<T3>(v.second)};
333338
return o;
334339
}
335340
// Define update / manifold

tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
add_executable(tinyopt_test_traits traits.cpp)
3+
target_link_libraries(tinyopt_test_traits PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)
4+
add_test_target(tinyopt_test_traits)
25

36
add_executable(tinyopt_test_diff diff.cpp)
47
target_link_libraries(tinyopt_test_diff PRIVATE ${THIRDPARTY_TEST_LIBS} tinyopt)

tests/traits.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (C) 2025 Julien Michot. All Rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#if CATCH2_VERSION == 2
16+
#include <catch2/catch.hpp>
17+
#else
18+
#include <catch2/catch_approx.hpp>
19+
#include <catch2/catch_test_macros.hpp>
20+
#endif
21+
22+
#include <tinyopt/tinyopt.h>
23+
#include <tinyopt/traits.h>
24+
25+
using namespace tinyopt;
26+
using namespace tinyopt::nlls;
27+
28+
TEST_CASE("tinyopt_params_traits_stl") {
29+
SECTION("std::array<float, 5>") {
30+
using Params = std::array<float, 5>;
31+
Params x;
32+
REQUIRE(traits::params_trait<Params>::Dims == 5);
33+
REQUIRE(traits::params_trait<Params>::dims(x) == 5);
34+
}
35+
SECTION("std::vector<float>") {
36+
using Params = std::vector<double>;
37+
Params x{{0.4, 0.5}};
38+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
39+
REQUIRE(traits::params_trait<Params>::dims(x) == 2);
40+
}
41+
SECTION("std::array<Vec4, 5>") {
42+
using Params = std::array<Vec4, 5>;
43+
Params x;
44+
REQUIRE(traits::params_trait<Params>::Dims == 4 * 5);
45+
REQUIRE(traits::params_trait<Params>::dims(x) == 4 * 5);
46+
}
47+
SECTION("std::array<VecX, 2>") {
48+
using Params = std::array<VecX, 2>;
49+
Params x;
50+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
51+
REQUIRE(traits::params_trait<Params>::dims(x) == 0);
52+
}
53+
SECTION("std::vector<Vec2>") {
54+
using Params = std::vector<Vec2>;
55+
Params x{{Vec2::Zero(), Vec2::Zero(), Vec2::Zero()}};
56+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
57+
REQUIRE(traits::params_trait<Params>::dims(x) == 6);
58+
}
59+
SECTION("std::pair<Vec2, Vec3>") {
60+
using Params = std::pair<Vec2, Vec3>;
61+
Params x;
62+
REQUIRE(traits::params_trait<Params>::Dims == 2 + 3);
63+
REQUIRE(traits::params_trait<Params>::dims(x) == 2 + 3);
64+
}
65+
SECTION("std::pair<Vec2, VecX>") {
66+
using Params = std::pair<Vec2, VecX>;
67+
Params x = std::make_pair(Vec2::Zero(), VecX::Random(4));
68+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
69+
REQUIRE(traits::params_trait<Params>::dims(x) == 2 + 4);
70+
}
71+
SECTION("std::pair<vector<float>, Vec3>") {
72+
using Params = std::pair<std::vector<float>, Vec3>;
73+
Params x = std::make_pair(std::vector<float>{{1, 2, 3, 4}}, Vec3::Zero());
74+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
75+
REQUIRE(traits::params_trait<Params>::dims(x) == 4 + 3);
76+
}
77+
SECTION("std::pair<std::vector<Vec3>, std::array<VecX, 4>>") {
78+
using Params = std::pair<std::vector<Vec3>, std::array<VecX, 4>>;
79+
std::vector<Vec3> a{Vec3::Zero(), Vec3::Zero()};
80+
std::array<VecX, 4> b{{VecX::Random(5), VecX::Random(2), VecX::Random(0), VecX::Random(0)}};
81+
Params x = std::make_pair(a, b);
82+
REQUIRE(traits::params_trait<Params>::Dims == Dynamic);
83+
REQUIRE(traits::params_trait<Params>::dims(x) == 6 + 7);
84+
}
85+
SECTION("std::pair<std::array<Vec3>, std::array<Vec2, 4>>") {
86+
using Params = std::pair<std::array<Vec3, 5>, std::array<Vec2, 4>>;
87+
std::array<Vec3, 5> a;
88+
std::array<Vec2, 4> b;
89+
Params x = std::make_pair(a, b);
90+
REQUIRE(traits::params_trait<Params>::Dims == 15 + 8);
91+
REQUIRE(traits::params_trait<Params>::dims(x) == 15 + 8);
92+
}
93+
}

0 commit comments

Comments
 (0)