Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,5 @@
from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp
from .train_unizero_with_loss_landscape import train_unizero_with_loss_landscape

# from .utils import (
# symlog,
# inv_symlog,
# initialize_zeros_batch,
# freeze_non_lora_parameters,
# compute_task_weights,
# TemperatureScheduler,
# tasks_per_stage,
# compute_unizero_mt_normalized_stats,
# allocate_batch_size,
# is_ddp_enabled,
# ddp_synchronize,
# ddp_all_reduce_sum,
# calculate_update_per_collect,
# initialize_pad_batch,
# random_collect,
# convert_to_batch_for_unizero,
# create_unizero_loss_metrics,
# UniZeroDataLoader,
# log_module_trainable_status,
# log_param_statistics,
# log_buffer_memory_usage,
# log_buffer_run_time,
# )
from .train_unizero_multitask import train_unizero_multitask
from .train_unizero_multitask_ddp import train_unizero_multitask_ddp
357 changes: 357 additions & 0 deletions lzero/entry/train_unizero_multitask.py

Large diffs are not rendered by default.

449 changes: 449 additions & 0 deletions lzero/entry/train_unizero_multitask_ddp.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions lzero/entry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,8 @@ def safe_eval(
evaluator: Evaluator,
learner: BaseLearner,
collector: Collector,
rank: int,
world_size: int,
rank: int = 0,
world_size: int = 1,
timeout: int = EVALUATION_TIMEOUT
) -> Tuple[Optional[bool], Optional[Any]]:
"""
Expand Down
5 changes: 4 additions & 1 deletion lzero/mcts/buffer/game_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
# image obs input, e.g. atari environments
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1])
else:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])
if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
self.zero_obs_shape = config.model.observation_shape
elif len(config.model.observation_shape) == 3:
self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1])

self.obs_segment = []
self.action_segment = []
Expand Down
19 changes: 12 additions & 7 deletions lzero/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,20 @@ def __init__(self,
torch.distributed.barrier()
if get_rank() != 0:
self.pretrained_model = AutoModel.from_pretrained(model_path)

if get_rank() != 0:
logging.info(f"Worker process is loading model from cache: {model_path}")
self.model = AutoModel.from_pretrained(model_path)
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)

if tokenizer is not None:
if tokenizer is None:
# Only rank 0 downloads the tokenizer, and then other processes load it from cache.
if get_rank() == 0:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
if get_world_size() > 1:
torch.distributed.barrier()
if get_rank() != 0:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
else:
self.tokenizer = tokenizer

for p in self.pretrained_model.parameters():
p.requires_grad_(False)

self.embedding_size = embedding_size
self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
Expand Down
6 changes: 3 additions & 3 deletions lzero/model/unizero_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ def __init__(
else:
raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}")

self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer,
with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network,
with_lpips=False, obs_type=world_model_cfg.obs_type)
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)

# --- Log parameter counts for analysis ---
self._log_model_parameters(obs_type)
self._log_model_parameters(world_model_cfg.obs_type)

logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
logging.info('==' * 20)
Expand Down
19 changes: 18 additions & 1 deletion lzero/model/unizero_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from easydict import EasyDict

from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
HFLanguageRepresentationNetwork
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model_multitask import WorldModelMT
from .vit import ViT, ViTConfig
Expand Down Expand Up @@ -84,6 +85,8 @@ def __init__(
self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim)
elif obs_type == 'image_memory':
self._init_image_memory_components(world_model_cfg)
elif obs_type == 'text':
self._init_text_components(world_model_cfg, encoder_url=kwargs['encoder_url'])
else:
raise ValueError(f"Unsupported observation type: {obs_type}")

Expand All @@ -93,6 +96,20 @@ def __init__(
# --- Log parameter counts for analysis ---
self._log_model_parameters(obs_type)

def _init_text_components(self, world_model_cfg: EasyDict, encoder_url: str) -> None:
"""Initializes components for 'text' observation type."""
self.representation_network = HFLanguageRepresentationNetwork(
model_path=encoder_url,
embedding_size=world_model_cfg.embed_dim,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder
)
self.tokenizer = Tokenizer(
encoder=self.representation_network,
decoder=None,
with_lpips=False,
obs_type=world_model_cfg.obs_type
)

def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim: int) -> None:
"""Initializes components for 'vector' observation type."""
self.representation_network = RepresentationNetworkMLP(
Expand Down
6 changes: 6 additions & 0 deletions lzero/model/unizero_world_models/world_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
dtype=batch['observations'].dtype)
perceptual_loss = torch.tensor(0., device=batch['observations'].device,
dtype=batch['observations'].dtype)

elif self.obs_type == 'text':
perceptual_loss = torch.tensor(0., device=batch['observations'].device,
dtype=torch.float32)
latent_recon_loss = self.latent_recon_loss


# Action tokens
if self.continuous_action_space:
act_tokens = batch['actions']
Expand Down
4 changes: 2 additions & 2 deletions lzero/policy/unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
self._cfg.device,
pad_token_id=self.pad_token_id
)
self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)]


# We must handle both single int and list of ints for env_id.
Expand Down Expand Up @@ -1707,7 +1707,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
)
logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape)

self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)]

# This logic handles the crucial end-of-episode cache clearing for evaluation.
# The evaluator calls `_policy.reset([env_id])` when an episode is done.
Expand Down
40 changes: 17 additions & 23 deletions lzero/worker/muzero_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self._logger, _ = build_logger(
path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False
)
self._tb_logger = None
self._tb_logger = tb_logger

self.policy_config = policy_config
self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy
Expand Down Expand Up @@ -126,7 +126,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
self._logger.debug(
f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))"
)
self._policy.reset()
self._policy.reset(task_id=self.task_id)

def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Expand Down Expand Up @@ -359,7 +359,7 @@ def collect(
chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}

# Initialize game segments and observation stacks for each environment.
game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)]
game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) for _ in range(env_nums)]
observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)]
for env_id in range(env_nums):
for _ in range(self.policy_config.model.frame_stack_num):
Expand Down Expand Up @@ -525,7 +525,7 @@ def collect(
last_game_priorities[env_id] = priorities

# Start a new game segment.
game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config)
game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id)
game_segments[env_id].reset(observation_window_stack[env_id])

self._env_info[env_id]['step'] += 1
Expand Down Expand Up @@ -570,7 +570,7 @@ def collect(
chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])

# Reset game segment and observation stack.
game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config)
game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id)
observation_window_stack[env_id].clear()
for _ in range(self.policy_config.model.frame_stack_num):
observation_window_stack[env_id].append(init_obs[env_id]['observation'])
Expand All @@ -585,7 +585,7 @@ def collect(
completed_value_lst[env_id] = 0

# Reset policy and collector stats for the finished environment.
self._policy.reset([env_id])
self._policy.reset([env_id], task_id=self.task_id)
self._reset_stat(env_id)
ready_env_id.remove(env_id)

Expand Down Expand Up @@ -628,9 +628,6 @@ def _output_log(self, train_iter: int) -> None:
Arguments:
- train_iter (:obj:`int`): The current training iteration number, used as the logging step.
"""
if self._rank != 0:
return

if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
self._last_train_iter = train_iter
episode_count = len(self._episode_info)
Expand Down Expand Up @@ -663,21 +660,18 @@ def _output_log(self, train_iter: int) -> None:

self._episode_info.clear()

# Log to console
self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()])))

self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()]))
# Log to TensorBoard and WandB
for k, v in info.items():
if k in ['each_reward']:
continue
if self.task_id is None:
tb_prefix_iter = f'{self._instance_name}_iter/'
tb_prefix_step = f'{self._instance_name}_step/'
# Log for single-task setting
self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter)
if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']:
self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count)
else:
tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/'
tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/'

self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter)
self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count)

if self.policy_config.use_wandb:
wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()}
wandb.log(wandb_log_data, step=self._total_envstep_count)
# Log for multi-task setting
self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter)
if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']:
self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count)
6 changes: 4 additions & 2 deletions zoo/jericho/configs/jericho_unizero_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
),
),
),
accumulation_steps=1, # TODO: Accumulated gradient steps (currently default)
accumulation_steps=1,
model=dict(
observation_shape=512,
action_space_size=action_space_size,
Expand All @@ -135,7 +135,9 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
obs_type="text",
env_num=max(collector_env_num, evaluator_env_num),
decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
latent_recon_loss_weight=0.1
latent_recon_loss_weight=0.1,
game_segment_length=50,
use_priority=False,
),
),
update_per_collect=int(collector_env_num*max_steps*replay_ratio ), # Important for DDP
Expand Down
14 changes: 6 additions & 8 deletions zoo/jericho/configs/jericho_unizero_ddp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
else:
raise ValueError(f"Unsupported encoder option: {encoder_option}")

# TODO
# batch_size = batch_size * 2

# ------------------------------------------------------------------
# Base environment parameters (Note: these values might be adjusted for different env_id)
# ------------------------------------------------------------------
Expand Down Expand Up @@ -115,7 +112,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
),
),
),
accumulation_steps=accumulation_steps, # TODO: Accumulated gradient steps (currently default)
accumulation_steps=accumulation_steps,
model=dict(
observation_shape=512,
action_space_size=action_space_size,
Expand All @@ -138,13 +135,14 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
num_layers=num_layers,
num_heads=24,
embed_dim=embed_dim,
obs_type="text", # TODO: Modify as needed.
obs_type="text",
env_num=max(collector_env_num, evaluator_env_num),
decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None.
latent_recon_loss_weight=0.1 # TODO: decoder loss weight
latent_recon_loss_weight=0, # TODO: decoder loss weight
game_segment_length=50,
use_priority=False,
),
),
# TODO
update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps), # Important for DDP
action_type="varied_action_space",
model_path=None,
Expand All @@ -159,7 +157,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e

num_simulations=num_simulations,
n_episode=n_episode,
train_start_after_envsteps=0, # TODO: Adjust training start trigger if needed.
train_start_after_envsteps=0,
replay_buffer_size=int(5e5),
eval_freq=int(3e4),
collector_env_num=collector_env_num,
Expand Down
Loading