33import copy
44
55import numpy as np
6-
7- from worldmodel_agents .base import AgentConfig , BaseAgent
86from worldmodel_models .registry import create_world_model
97from worldmodel_planners .mcts import MCTSPlanner
108
9+ from worldmodel_agents .base import AgentConfig , BaseAgent
10+
1111
1212class SearchMCTSAgent (BaseAgent ):
1313 """Minimal MuZero-style skeleton: learned model + MCTS planning."""
1414
1515 def __init__ (self , config : AgentConfig | None = None ):
1616 super ().__init__ (config = config )
1717 self .world_model = create_world_model ("deterministic" )
18- self .planner = MCTSPlanner (action_space_n = self .config .action_space_n , num_simulations = 56 , max_depth = 14 )
18+ self .planner = MCTSPlanner (
19+ action_space_n = self .config .action_space_n , num_simulations = 56 , max_depth = 14
20+ )
1921 self .latent = self .world_model .init_state (batch_size = 1 )
2022 self .buffer : list [dict ] = []
2123 self .rng = np .random .default_rng (0 )
@@ -32,7 +34,9 @@ def act(self, obs, info: dict) -> int:
3234 self .latent = self .world_model .observe (self .latent , obs )
3335
3436 def transition_fn (state , action ):
35- next_state , _pred_obs , pred_reward , pred_done , _aux = self .world_model .predict (state , int (action ))
37+ next_state , _pred_obs , pred_reward , pred_done , _aux = self .world_model .predict (
38+ state , int (action )
39+ )
3640 return next_state , float (pred_reward ), bool (pred_done )
3741
3842 result = self .planner .plan (
0 commit comments