Skip to content

Commit b881154

Browse files
implement the single argument version
1 parent ef258b4 commit b881154

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

solvers/solve.cc

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,21 +111,27 @@ std::vector<MathematicalProgramResult> SolveInParallel(
111111
const SolverOptions* solver_options,
112112
const std::optional<SolverId>& solver_id, const Parallelism parallelism,
113113
const bool dynamic_schedule) {
114-
// Broadcast the option and id arguments into vectors (if given).
115-
std::optional<std::vector<const SolverOptions*>> broadcast_options;
116-
std::optional<std::vector<std::optional<SolverId>>> broadcast_ids;
117-
if (solver_options != nullptr) {
118-
broadcast_options.emplace(progs.size(), solver_options);
119-
}
120-
if (solver_id.has_value()) {
121-
broadcast_ids.emplace(progs.size(), solver_id);
122-
}
123-
// Delegate to the primary overload.
124-
return SolveInParallel(
125-
progs, initial_guesses,
126-
broadcast_options.has_value() ? &(*broadcast_options) : nullptr,
127-
broadcast_ids.has_value() ? &(*broadcast_ids) : nullptr, // BR
128-
parallelism, dynamic_schedule);
114+
// TODO(Alexandre.Amice) is there a way around this clone?
115+
auto prog_generator = [&progs](const int thread_num, const int64_t i) {
116+
unused(thread_num);
117+
return progs[i]->Clone();
118+
};
119+
120+
auto initial_guess_generator =
121+
[&initial_guesses](int64_t thread_num,
122+
int64_t i) -> std::optional<Eigen::VectorXd> {
123+
unused(thread_num);
124+
if (initial_guesses != nullptr && initial_guesses->at(i) != nullptr) {
125+
return *(initial_guesses->at(i));
126+
} else {
127+
return std::nullopt;
128+
}
129+
};
130+
131+
return SolveInParallel(prog_generator, initial_guess_generator,
132+
static_cast<int64_t>(0),
133+
static_cast<int64_t>(progs.size()), solver_options,
134+
solver_id, parallelism, dynamic_schedule);
129135
}
130136

131137
std::vector<MathematicalProgramResult> SolveInParallel(
@@ -257,9 +263,35 @@ std::vector<MathematicalProgramResult> SolveInParallel(
257263
solve_ith_serial(i);
258264
}
259265
}
260-
261266
return results;
262267
}
263268

269+
std::vector<MathematicalProgramResult> SolveInParallel(
270+
const std::function<std::unique_ptr<MathematicalProgram>(int64_t, int64_t)>&
271+
prog_generator,
272+
const std::function<std::optional<Eigen::VectorXd>(int64_t, int64_t)>&
273+
initial_guesses_generator,
274+
const int64_t range_start, const int64_t range_end,
275+
const SolverOptions* solver_options,
276+
const std::optional<SolverId>& solver_id, Parallelism parallelism,
277+
bool dynamic_schedule) {
278+
auto solver_options_generator =
279+
[&solver_options](int64_t thread_num,
280+
int64_t i) -> std::optional<SolverOptions> {
281+
unused(thread_num, i);
282+
return solver_options == nullptr
283+
? std::nullopt
284+
: std::optional<SolverOptions>{*solver_options};
285+
};
286+
auto solver_id_generator =
287+
[&solver_id](int64_t thread_num, int64_t i) -> std::optional<SolverId> {
288+
unused(thread_num, i);
289+
return solver_id;
290+
};
291+
return SolveInParallel(prog_generator, initial_guesses_generator,
292+
solver_options_generator, solver_id_generator,
293+
range_start, range_end, parallelism, dynamic_schedule);
294+
}
295+
264296
} // namespace solvers
265297
} // namespace drake

solvers/solve.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ std::vector<MathematicalProgramResult> SolveInParallel(
120120
* solver_ids can be nullptr or std::optional.
121121
*
122122
* @note Please ensure that all generators are thread safe.
123-
* @return
123+
* @return A vector of size range_end with range_start to range_end populated
124+
* with the results of solving prog_generator(*, range_start-range_end)
124125
*/
125126
std::vector<MathematicalProgramResult> SolveInParallel(
126127
const std::function<std::unique_ptr<MathematicalProgram>(int64_t, int64_t)>&
@@ -143,13 +144,14 @@ std::vector<MathematicalProgramResult> SolveInParallel(
143144
* The input to the generator is an integer i and the output is the ith program.
144145
*
145146
* The output of prog_generator cannot be a nullptr.
146-
* @return
147147
*/
148148
std::vector<MathematicalProgramResult> SolveInParallel(
149149
const std::function<std::unique_ptr<MathematicalProgram>(int64_t, int64_t)>&
150150
prog_generator,
151-
const const std::function<std::optional<Eigen::VectorXd>(int64_t, int64_t)>&
152-
initial_guesses_generatorconst SolverOptions* solver_options = nullptr,
151+
const std::function<std::optional<Eigen::VectorXd>(int64_t, int64_t)>&
152+
initial_guesses_generator,
153+
const int64_t range_start, const int64_t range_end,
154+
const SolverOptions* solver_options = nullptr,
153155
const std::optional<SolverId>& solver_id = std::nullopt,
154156
Parallelism parallelism = Parallelism::Max(),
155157
bool dynamic_schedule = false);

0 commit comments

Comments
 (0)