# game_segment_priorzero.py
from lzero.mcts.buffer.game_segment import GameSegment as OriginalGameSegment
import numpy as np
4+
class GameSegment(OriginalGameSegment):
    """
    [PRIORZERO-MODIFIED]
    Extends the original GameSegment with per-step storage of the MCTS root
    visit-count distribution, used as the RFT (reinforcement fine-tuning) target.

    Attributes:
        mcts_policy_segment: one entry per appended transition; ``None`` until
            ``store_search_stats`` fills in the normalized policy for that step.
    """

    def __init__(self, action_space, game_segment_length=200, config=None, task_id=None):
        super().__init__(action_space, game_segment_length, config, task_id)
        # [PRIORZERO-NEW] Parallel list holding the RFT target policy per step.
        self.mcts_policy_segment = []

    def append(self, action, obs, reward, action_mask, to_play, timestep):
        """
        Append one transition. The MCTS policy is not known at append time,
        so a ``None`` placeholder keeps this list aligned with the segment;
        ``store_search_stats`` later overwrites the last entry.
        """
        super().append(action, obs, reward, action_mask, to_play, timestep)
        self.mcts_policy_segment.append(None)

    def store_search_stats(self, root_visit_dist, value, *args, **kwargs):
        """
        [PRIORZERO-MODIFIED]
        In addition to the base-class bookkeeping, normalize the root visit
        counts into a probability distribution and store it for the most
        recently appended step.

        Arguments:
            - root_visit_dist: per-action visit counts from the MCTS root.
            - value: root value estimate (forwarded to the base class).
        """
        super().store_search_stats(root_visit_dist, value, *args, **kwargs)
        policy_array = np.asarray(root_visit_dist, dtype=np.float32)
        total = policy_array.sum()
        if total > 0:
            policy_array = policy_array / total
        elif policy_array.size > 0:
            # No visits recorded: fall back to a uniform distribution.
            policy_array = np.full_like(policy_array, 1.0 / policy_array.size)
        # else: empty distribution — store as-is (guards ZeroDivisionError).

        if self.mcts_policy_segment:
            # Attach the policy to the last appended step.
            self.mcts_policy_segment[-1] = policy_array
        else:
            # Robustness: called before any append — don't raise IndexError.
            self.mcts_policy_segment.append(policy_array)

    def game_segment_to_array(self):
        """
        [PRIORZERO-MODIFIED]
        Convert mcts_policy_segment to a 1-D object ndarray as well.

        NOTE: ``np.array(list_of_arrays, dtype=object)`` would build a 2-D
        object array when all policies share one shape, changing indexing
        semantics; np.empty + fill always yields a 1-D array of policies.
        """
        super().game_segment_to_array()
        policies = np.empty(len(self.mcts_policy_segment), dtype=object)
        for i, policy in enumerate(self.mcts_policy_segment):
            policies[i] = policy
        self.mcts_policy_segment = policies