@@ -153,6 +153,19 @@ void Agent::Initialize(const mjModel* model) {
153153 planner_threads_ =
154154 std::max (1 , NumAvailableHardwareThreads () - 3 - 2 * estimator_threads_);
155155
156+ // differentiable planning model
157+ // by default gradient-based planners use a differentiable model
158+ int gradient_planner = false ;
159+ if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
160+ planner_ == kILQSPlanner ) {
161+ gradient_planner = true ;
162+ }
163+ differentiable_ =
164+ GetNumberOrDefault (gradient_planner, model, " agent_differentiable" );
165+ jnt_solimp_.resize (model->njnt );
166+ geom_solimp_.resize (model->ngeom );
167+ pair_solimp_.resize (model->npair );
168+
156169 // delete the previous model after all the planners have been updated to use
157170 // the new one.
158171 if (old_model) {
@@ -279,6 +292,22 @@ void Agent::PlanIteration(ThreadPool* pool) {
279292 steps_ =
280293 mju_max (mju_min (horizon_ / timestep_ + 1 , kMaxTrajectoryHorizon ), 1 );
281294
295+ // make model differentiable
296+ int differentiable = differentiable_;
297+ if (differentiable) {
298+ // cache element 0 of each joint/geom/pair solimp (index mjNIMP * i) so it can be restored after planning
299+ for (int i = 0 ; i < model_->njnt ; i++) {
300+ jnt_solimp_[i] = model_->jnt_solimp [mjNIMP * i];
301+ }
302+ for (int i = 0 ; i < model_->ngeom ; i++) {
303+ geom_solimp_[i] = model_->geom_solimp [mjNIMP * i];
304+ }
305+ for (int i = 0 ; i < model_->npair ; i++) {
306+ pair_solimp_[i] = model_->pair_solimp [mjNIMP * i];
307+ }
308+ MakeDifferentiable (model_);
309+ }
310+
282311 // plan
283312 if (!allocate_enabled) {
284313 // set state
@@ -312,6 +341,19 @@ void Agent::PlanIteration(ThreadPool* pool) {
312341 // release the planning residual function
313342 residual_fn_.reset ();
314343 }
344+
345+ // restore the cached element 0 of each joint/geom/pair solimp (index mjNIMP * i)
346+ if (differentiable) {
347+ for (int i = 0 ; i < model_->njnt ; i++) {
348+ model_->jnt_solimp [mjNIMP * i] = jnt_solimp_[i];
349+ }
350+ for (int i = 0 ; i < model_->ngeom ; i++) {
351+ model_->geom_solimp [mjNIMP * i] = geom_solimp_[i];
352+ }
353+ for (int i = 0 ; i < model_->npair ; i++) {
354+ model_->pair_solimp [mjNIMP * i] = pair_solimp_[i];
355+ }
356+ }
315357}
316358
317359// call planner to update nominal policy
@@ -644,21 +686,23 @@ void Agent::GUI(mjUI& ui) {
644686 }
645687
646688 // ----- agent ----- //
647- mjuiDef defAgent[] = {{mjITEM_SECTION, " Agent" , 1 , nullptr , " AP" },
648- {mjITEM_BUTTON, " Reset" , 2 , nullptr , " #459" },
649- {mjITEM_SELECT, " Planner" , 2 , &planner_, " " },
650- {mjITEM_SELECT, " Estimator" , 2 , &estimator_, " " },
651- {mjITEM_CHECKINT, " Plan" , 2 , &plan_enabled, " " },
652- {mjITEM_CHECKINT, " Action" , 2 , &action_enabled, " " },
653- {mjITEM_CHECKINT, " Plots" , 2 , &plot_enabled, " " },
654- {mjITEM_CHECKINT, " Traces" , 2 , &visualize_enabled, " " },
655- {mjITEM_SEPARATOR, " Agent Settings" , 1 },
656- {mjITEM_SLIDERNUM, " Horizon" , 2 , &horizon_, " 0 1" },
657- {mjITEM_SLIDERNUM, " Timestep" , 2 , &timestep_, " 0 1" },
658- {mjITEM_SELECT, " Integrator" , 2 , &integrator_,
659- " Euler\n RK4\n Implicit\n ImplicitFast" },
660- {mjITEM_SEPARATOR, " Planner Settings" , 1 },
661- {mjITEM_END}};
689+ mjuiDef defAgent[] = {
690+ {mjITEM_SECTION, " Agent" , 1 , nullptr , " AP" },
691+ {mjITEM_BUTTON, " Reset" , 2 , nullptr , " #459" },
692+ {mjITEM_SELECT, " Planner" , 2 , &planner_, " " },
693+ {mjITEM_SELECT, " Estimator" , 2 , &estimator_, " " },
694+ {mjITEM_CHECKINT, " Plan" , 2 , &plan_enabled, " " },
695+ {mjITEM_CHECKINT, " Action" , 2 , &action_enabled, " " },
696+ {mjITEM_CHECKINT, " Plots" , 2 , &plot_enabled, " " },
697+ {mjITEM_CHECKINT, " Traces" , 2 , &visualize_enabled, " " },
698+ {mjITEM_SEPARATOR, " Agent Settings" , 1 },
699+ {mjITEM_SLIDERNUM, " Horizon" , 2 , &horizon_, " 0 1" },
700+ {mjITEM_SLIDERNUM, " Timestep" , 2 , &timestep_, " 0 1" },
701+ {mjITEM_SELECT, " Integrator" , 2 , &integrator_,
702+ " Euler\n RK4\n Implicit\n ImplicitFast" },
703+ {mjITEM_CHECKINT, " Differentiable" , 2 , &differentiable_, " " },
704+ {mjITEM_SEPARATOR, " Planner Settings" , 1 },
705+ {mjITEM_END}};
662706
663707 // planner names
664708 mju::strcpy_arr (defAgent[2 ].other , planner_names_);
@@ -730,6 +774,14 @@ void Agent::AgentEvent(mjuiItem* it, mjData* data,
730774 this ->PlotInitialize ();
731775 this ->PlotReset ();
732776
777+ // by default gradient-based planners use a differentiable model
778+ if (planner_ == kGradientPlanner || planner_ == kILQGPlanner ||
779+ planner_ == kILQSPlanner ) {
780+ differentiable_ = true ;
781+ } else {
782+ differentiable_ = false ;
783+ }
784+
733785 // reset agent
734786 uiloadrequest.fetch_sub (1 );
735787 }
0 commit comments