@@ -42,10 +42,10 @@ void ModelDerivatives::Reset(int dim_state_derivative, int dim_action,
4242}
4343
4444// compute derivatives at all time steps
45- void ModelDerivatives::Compute (const mjModel* m,
45+ void ModelDerivatives::Compute (mjModel* m,
4646 const std::vector<UniqueMjData>& data,
4747 const double * x, const double * u,
48- const double * h, int dim_state,
48+ const double * a, const double * h, int dim_state,
4949 int dim_state_derivative, int dim_action,
5050 int dim_sensor, int T, double tol, int mode,
5151 ThreadPool& pool, int skip) {
@@ -71,10 +71,23 @@ void ModelDerivatives::Compute(const mjModel* m,
7171 }
7272 }
7373
74+ // warmstart
75+ int saved_flags = m->opt .disableflags ;
76+ m->opt .disableflags &= ~mjDSBL_WARMSTART;
77+
78+ // solver settings
79+ int saved_iterations = m->opt .iterations ;
80+ mjtNum saved_tolerance = m->opt .tolerance ;
81+ if (m->opt .solver == mjSOL_NEWTON) {
82+ m->opt .iterations = 1 ;
83+ m->opt .tolerance = 0.0 ;
84+ }
85+ // TODO(taylorhowell): settings for CG and PGS
86+
7487 // evaluate derivatives
7588 int count_before = pool.GetCount ();
7689 for (int t : evaluate_) {
77- pool.Schedule ([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &h,
90+ pool.Schedule ([&m, &data, &A = A, &B = B, &C = C, &D = D, &x, &u, &a, & h,
7891 dim_state, dim_state_derivative, dim_action, dim_sensor, tol,
7992 mode, t, T]() {
8093 mjData* d = data[ThreadPool::WorkerId ()].get ();
@@ -85,6 +98,9 @@ void ModelDerivatives::Compute(const mjModel* m,
8598 // set action
8699 mju_copy (d->ctrl , u + t * dim_action, dim_action);
87100
101+ // set acceleration
102+ mju_copy (d->qacc_warmstart , a + t * m->nv , m->nv );
103+
88104 // Jacobians
89105 if (t == T - 1 ) {
90106 // Jacobians
@@ -105,6 +121,11 @@ void ModelDerivatives::Compute(const mjModel* m,
105121 pool.WaitCount (count_before + evaluate_.size ());
106122 pool.ResetCount ();
107123
124+ // restore settings
125+ m->opt .tolerance = saved_tolerance;
126+ m->opt .iterations = saved_iterations;
127+ m->opt .disableflags = saved_flags;
128+
108129 // interpolate derivatives
109130 count_before = pool.GetCount ();
110131 for (int t : interpolate_) {
0 commit comments