Skip to content

Merge social nav transformer semantic place add pi zero eval #2140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: merge_social_nav_transformer_semantic_place
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,10 @@ class HabitatBaselinesConfig(HabitatBaselinesBaseConfig):
# Function signature: fn(save_file_path: str) -> None
# If not specified, there is no callback.
on_save_ckpt_callback: Optional[HydraCallbackConfig] = None

# If we want to load third party
load_third_party_ckpt: bool = False
third_party_config_path_dir: str = "third_party/config"
third_party_ckpt_root_folder: str = "third_party/model"

@dataclass
class HabitatBaselinesRLConfig(HabitatBaselinesConfig):
Expand Down
142 changes: 142 additions & 0 deletions habitat-baselines/habitat_baselines/config/social_nav/oracle_nav.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# @package _global_

defaults:
- /benchmark/multi_agent: hssd_spot_human_social_nav
- /habitat_baselines: habitat_baselines_rl_config_base
- /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor
- /habitat_baselines/rl/policy/obs_transforms@habitat_baselines.rl.policy.agent_0.obs_transforms.add_virtual_keys:
- add_virtual_keys_base
- /habitat_baselines/rl/policy/obs_transforms@habitat_baselines.rl.policy.agent_1.obs_transforms.add_virtual_keys:
- add_virtual_keys_base
- /habitat_baselines/rl/policy@habitat_baselines.rl.policy.agent_0: hl_fixed
- /habitat_baselines/rl/policy/hierarchical_policy/defined_skills@habitat_baselines.rl.policy.agent_0.hierarchical_policy.defined_skills: oracle_skills_human_multi_agent
- /habitat_baselines/rl/policy@habitat_baselines.rl.policy.agent_1: hl_fixed
- /habitat_baselines/rl/policy/hierarchical_policy/defined_skills@habitat_baselines.rl.policy.agent_1.hierarchical_policy.defined_skills: oracle_skills_human_multi_agent
- _self_

hydra:
job:
name: 'social_nav'

habitat_baselines:
verbose: False
trainer_name: "ddppo"
updater_name: "HRLPPO"
distrib_updater_name: "HRLDDPPO"
torch_gpu_id: 0
video_fps: 30
eval_ckpt_path_dir: "data/checkpoints"
num_environments: 18
num_updates: -1
total_num_steps: 5.0e7
log_interval: 10
num_checkpoints: 100
force_torch_single_threaded: True
eval_keys_to_include_in_name: ['pddl_success']
load_resume_state_config: False
rollout_storage_name: "HrlRolloutStorage"

evaluate: False

eval:
extra_sim_sensors:
third_rgb_sensor:
height: 224
width: 171

should_load_ckpt: True
video_option: ["disk"]

rl:
agent:
type: "MultiAgentAccessMgr"
num_agent_types: 2
num_active_agents_per_type: [1, 1]
num_pool_agents_per_type: [1, 1]
agent_sample_interval: 20
force_partner_sample_idx: -1
policy:
# Motify the action distribution
agent_0:
hierarchical_policy:
high_level_policy:
add_arm_rest: False
policy_input_keys:
- "head_depth"
- "is_holding"
- "obj_start_gps_compass"
- "obj_goal_gps_compass"
- "other_agent_gps"
- "obj_start_sensor"
- "obj_goal_sensor"
allowed_actions:
- nav_to_goal
- nav_to_obj
- pick
- place
- nav_to_receptacle_by_name
# Override to use the oracle navigation skill (which will actually execute navigation).
defined_skills:
nav_to_randcoord:
skill_name: "OracleNavCoordPolicy"
obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"]
max_skill_steps: 1500
ignore_grip: True
agent_1:
hierarchical_policy:
high_level_policy:
add_arm_rest: False
policy_input_keys:
- "head_depth"
- "is_holding"
- "obj_start_gps_compass"
- "obj_goal_gps_compass"
- "other_agent_gps"
- "obj_start_sensor"
- "obj_goal_sensor"
allowed_actions:
- nav_to_goal
- nav_to_obj
- pick
- place
- nav_to_receptacle_by_name
# Override to use the oracle navigation skill (which will actually execute navigation).
defined_skills:
nav_to_randcoord:
skill_name: "OracleNavCoordPolicy"
obs_skill_inputs: ["obj_start_sensor", "abs_obj_start_sensor", "obj_goal_sensor", "abs_obj_goal_sensor"]
max_skill_steps: 1500
ignore_grip: True
ppo:
# ppo params
clip_param: 0.2
ppo_epoch: 1
num_mini_batch: 2
value_loss_coef: 0.5
entropy_coef: 0.0001
lr: 2.5e-4
eps: 1e-5
max_grad_norm: 0.2
num_steps: 128
use_gae: True
gamma: 0.99
tau: 0.95

ddppo:
sync_frac: 0.6
# The PyTorch distributed backend to use
distrib_backend: NCCL
# Visual encoder backbone
pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth
# Initialize with pretrained weights
pretrained: False
# Initialize just the visual encoder backbone with pretrained weights
pretrained_encoder: False
# Whether the visual encoder backbone will be trained.
train_encoder: True
# Whether to reset the critic linear layer
reset_critic: False
# Model parameters
backbone: resnet18
rnn_type: LSTM
num_recurrent_layers: 2
Loading