Skip to content

Commit 0c76c95

Browse files
thowell authored and copybara-github committed
Modify solver settings for model derivative finite difference computations and warmstart with accelerations.
PiperOrigin-RevId: 733362111 Change-Id: I85cc7d7fae3178d8ccc4acfa8171810302d9a045
1 parent 14ddc28 commit 0c76c95

File tree

9 files changed

+74
-30
lines changed

9 files changed

+74
-30
lines changed

mjpc/agent.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,8 @@ void Agent::Initialize(const mjModel* model) {
8787
}
8888

8989
// planner
90-
planner_ = GetNumberOrDefault(0, model, "agent_planner");
90+
planner_ =
91+
static_cast<PlannerType>(GetNumberOrDefault(0, model, "agent_planner"));
9192

9293
// estimator
9394
estimator_ =

mjpc/agent.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ class Agent {
213213

214214
// planners
215215
std::vector<std::unique_ptr<mjpc::Planner>> planners_;
216-
int planner_;
216+
PlannerType planner_;
217217

218218
// estimators
219219
std::vector<std::unique_ptr<mjpc::Estimator>> estimators_;

mjpc/planners/gradient/planner.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -203,9 +203,9 @@ void GradientPlanner::OptimizePolicy(int horizon, ThreadPool& pool) {
203203
// compute model and sensor Jacobians
204204
model_derivative.Compute(
205205
model, data_, trajectory[0].states.data(), trajectory[0].actions.data(),
206-
trajectory[0].times.data(), dim_state, dim_state_derivative, dim_action,
207-
dim_sensor, horizon, settings.fd_tolerance, settings.fd_mode, pool,
208-
skip);
206+
trajectory[0].accelerations.data(), trajectory[0].times.data(),
207+
dim_state, dim_state_derivative, dim_action, dim_sensor, horizon,
208+
settings.fd_tolerance, settings.fd_mode, pool, skip);
209209

210210
// stop timer
211211
model_derivative_time += GetDuration(model_derivative_start);

mjpc/planners/ilqg/planner.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,7 @@ void iLQGPlanner::Iteration(int horizon, ThreadPool& pool) {
395395
model_derivative.Compute(
396396
model, data_, candidate_policy[0].trajectory.states.data(),
397397
candidate_policy[0].trajectory.actions.data(),
398+
candidate_policy[0].trajectory.accelerations.data(),
398399
candidate_policy[0].trajectory.times.data(), dim_state,
399400
dim_state_derivative, dim_action, dim_sensor, horizon,
400401
settings.fd_tolerance, settings.fd_mode, pool, derivative_skip_);

mjpc/planners/model_derivatives.cc

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,10 +42,10 @@ void ModelDerivatives::Reset(int dim_state_derivative, int dim_action,
4242
}
4343

4444
// compute derivatives at all time steps
45-
void ModelDerivatives::Compute(const mjModel* m,
45+
void ModelDerivatives::Compute(mjModel* m,
4646
const std::vector<UniqueMjData>& data,
4747
const double* x, const double* u,
48-
const double* h, int dim_state,
48+
const double* a, const double* h, int dim_state,
4949
int dim_state_derivative, int dim_action,
5050
int dim_sensor, int T, double tol, int mode,
5151
ThreadPool& pool, int skip) {
@@ -71,10 +71,23 @@ void ModelDerivatives::Compute(const mjModel* m,
7171
}
7272
}
7373

74+
// warmstart
75+
int saved_flags = m->opt.disableflags;
76+
m->opt.disableflags &= ~mjDSBL_WARMSTART;
77+
78+
// solver settings
79+
int saved_iterations = m->opt.iterations;
80+
mjtNum saved_tolerance = m->opt.tolerance;
81+
if (m->opt.solver == mjSOL_NEWTON) {
82+
m->opt.iterations = 1;
83+
m->opt.tolerance = 0.0;
84+
}
85+
// TODO(taylorhowell): settings for CG and PGS
86+
7487
// evaluate derivatives
7588
int count_before = pool.GetCount();
7689
for (int t : evaluate_) {
77-
pool.Schedule([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h,
90+
pool.Schedule([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &a, &h,
7891
dim_state, dim_state_derivative, dim_action, dim_sensor, tol,
7992
mode, t, T]() {
8093
mjData* d = data[ThreadPool::WorkerId()].get();
@@ -85,6 +98,9 @@ void ModelDerivatives::Compute(const mjModel* m,
8598
// set action
8699
mju_copy(d->ctrl, u + t * dim_action, dim_action);
87100

101+
// set acceleration
102+
mju_copy(d->qacc_warmstart, a + t * m->nv, m->nv);
103+
88104
// Jacobians
89105
if (t == T - 1) {
90106
// Jacobians
@@ -105,6 +121,11 @@ void ModelDerivatives::Compute(const mjModel* m,
105121
pool.WaitCount(count_before + evaluate_.size());
106122
pool.ResetCount();
107123

124+
// restore settings
125+
m->opt.tolerance = saved_tolerance;
126+
m->opt.iterations = saved_iterations;
127+
m->opt.disableflags = saved_flags;
128+
108129
// interpolate derivatives
109130
count_before = pool.GetCount();
110131
for (int t : interpolate_) {

mjpc/planners/model_derivatives.h

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,10 +42,11 @@ class ModelDerivatives {
4242
void Reset(int dim_state_derivative, int dim_action, int dim_sensor, int T);
4343

4444
// compute derivatives at all time steps
45-
void Compute(const mjModel* m, const std::vector<UniqueMjData>& data,
46-
const double* x, const double* u, const double* h, int dim_state,
47-
int dim_state_derivative, int dim_action, int dim_sensor, int T,
48-
double tol, int mode, ThreadPool& pool, int skip = 0);
45+
void Compute(mjModel* m, const std::vector<UniqueMjData>& data,
46+
const double* x, const double* u, const double* a,
47+
const double* h, int dim_state, int dim_state_derivative,
48+
int dim_action, int dim_sensor, int T, double tol, int mode,
49+
ThreadPool& pool, int skip = 0);
4950

5051
// Jacobians
5152
std::vector<double> A; // model Jacobians wrt state

mjpc/test/agent/agent_test.cc

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
#include "gtest/gtest.h"
2323
#include <mujoco/mujoco.h>
24+
#include "mjpc/planners/include.h"
2425
#include "mjpc/planners/ilqs/planner.h"
2526
#include "mjpc/planners/sampling/planner.h"
2627
#include "mjpc/task.h"
@@ -71,7 +72,7 @@ class AgentTest : public ::testing::Test {
7172
// test
7273
EXPECT_EQ(agent->integrator_, 0);
7374
EXPECT_NEAR(agent->timestep_, 0.1, 1.0e-5);
74-
EXPECT_EQ(agent->planner_, 0);
75+
EXPECT_EQ(agent->planner_, kSamplingPlanner);
7576
EXPECT_NEAR(agent->horizon_, 1, 1.0e-5);
7677
EXPECT_EQ(agent->steps_, 11);
7778
EXPECT_FALSE(agent->plan_enabled);
@@ -152,7 +153,7 @@ class AgentTest : public ::testing::Test {
152153
0.0, 1.0e-1);
153154

154155
// ----- switch to iLQG planner ----- //
155-
agent->planner_ = 2;
156+
agent->planner_ = kILQGPlanner;
156157
agent->Allocate();
157158
agent->Reset();
158159
exitrequest.store(false);
@@ -215,7 +216,7 @@ class AgentTest : public ::testing::Test {
215216
agent->plan_enabled = true;
216217

217218
bool success = false;
218-
agent->planner_ = 0; // sampling
219+
agent->planner_ = kSamplingPlanner;
219220
reinterpret_cast<SamplingPlanner*>(&agent->ActivePlanner())
220221
->num_trajectory_ = 128;
221222

@@ -278,7 +279,7 @@ class AgentTest : public ::testing::Test {
278279
agent->Reset();
279280
agent->plan_enabled = true;
280281

281-
agent->planner_ = 2; // iLQG
282+
agent->planner_ = kILQGPlanner;
282283

283284
agent->Reset();
284285
data->qpos[0] = 0;
@@ -328,7 +329,7 @@ class AgentTest : public ::testing::Test {
328329
agent->Reset();
329330
agent->plan_enabled = true;
330331

331-
agent->planner_ = 3; // iLQS
332+
agent->planner_ = kILQSPlanner;
332333
iLQSPlanner* planner =
333334
reinterpret_cast<iLQSPlanner*>(&agent->ActivePlanner());
334335

mjpc/trajectory.cc

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,9 @@ void Trajectory::Allocate(int T) {
5959

6060
// traces
6161
trace.resize(dim_trace * T);
62+
63+
// accelerations
64+
accelerations.resize(dim_state * T); // TODO(taylorhowell): allocate nv x T
6265
}
6366

6467
// reset memory to zeros
@@ -86,6 +89,9 @@ void Trajectory::Reset(int T, const double* initial_repeated_action) {
8689

8790
// traces
8891
std::fill(trace.begin(), trace.begin() + dim_trace * T, 0.0);
92+
93+
// accelerations
94+
std::fill(accelerations.begin(), accelerations.begin() + dim_state * T, 0.0);
8995
}
9096

9197
// simulate model forward in time with continuous-time indexed policy
@@ -177,6 +183,9 @@ void Trajectory::NoisyRollout(
177183
mju_copy(DataAt(states, (t + 1) * dim_state + nq), data->qvel, nv);
178184
mju_copy(DataAt(states, (t + 1) * dim_state + nq + nv), data->act, na);
179185
times[t + 1] = data->time;
186+
187+
// record acceleration
188+
mju_copy(DataAt(accelerations, t * nv), data->qacc, nv);
180189
}
181190

182191
// check for step warnings
@@ -205,6 +214,9 @@ void Trajectory::NoisyRollout(
205214
GetTraces(DataAt(trace, (horizon - 1) * 3 * task->num_trace), model, data,
206215
task->num_trace);
207216

217+
// final acceleration
218+
mju_copy(DataAt(accelerations, (horizon - 1) * nv), data->qacc, nv);
219+
208220
// compute return
209221
UpdateReturn(task);
210222
}
@@ -276,6 +288,9 @@ void Trajectory::RolloutDiscrete(
276288
mju_copy(DataAt(states, (t + 1) * dim_state + nq), data->qvel, nv);
277289
mju_copy(DataAt(states, (t + 1) * dim_state + nq + nv), data->act, na);
278290
times[t + 1] = data->time;
291+
292+
// record acceleration
293+
mju_copy(DataAt(accelerations, t * nv), data->qacc, nv);
279294
}
280295

281296
// check for step warnings
@@ -304,6 +319,9 @@ void Trajectory::RolloutDiscrete(
304319
GetTraces(DataAt(trace, (horizon - 1) * 3 * task->num_trace), model, data,
305320
task->num_trace);
306321

322+
// final acceleration
323+
mju_copy(DataAt(accelerations, (horizon - 1) * nv), data->qacc, nv);
324+
307325
// compute return
308326
UpdateReturn(task);
309327
}

mjpc/trajectory.h

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -71,19 +71,20 @@ class Trajectory {
7171
double time, const double* mocap, const double* userdata, int steps);
7272

7373
// ----- members ----- //
74-
int horizon; // trajectory length
75-
int dim_state; // states dimension
76-
int dim_action; // actions dimension
77-
int dim_residual; // residual dimension
78-
int dim_trace; // traces dimension
79-
std::vector<double> states; // (horizon x nq + nv + na)
80-
std::vector<double> actions; // (horizon-1 x num_action)
81-
std::vector<double> times; // horizon
82-
std::vector<double> residual; // (horizon x num_residual)
83-
std::vector<double> costs; // horizon
84-
std::vector<double> trace; // (horizon x 3)
85-
double total_return; // (1)
86-
bool failure; // true if last rollout had a warning
74+
int horizon; // trajectory length
75+
int dim_state; // states dimension
76+
int dim_action; // actions dimension
77+
int dim_residual; // residual dimension
78+
int dim_trace; // traces dimension
79+
std::vector<double> states; // (horizon x nq + nv + na)
80+
std::vector<double> actions; // (horizon-1 x num_action)
81+
std::vector<double> times; // horizon
82+
std::vector<double> residual; // (horizon x num_residual)
83+
std::vector<double> costs; // horizon
84+
std::vector<double> trace; // (horizon x 3)
85+
double total_return; // (1)
86+
bool failure; // true if last rollout had a warning
87+
std::vector<double> accelerations; // (horizon x nv)
8788

8889
private:
8990
// calculates total_return and costs

0 commit comments

Comments
 (0)