Skip to content

Commit b27a7e0

Browse files
committed
feature(pu): add priorzero test version
1 parent b1efa60 commit b27a7e0

File tree

6 files changed

+1410
-0
lines changed

6 files changed

+1410
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# game_segment_priorzero.py
2+
from lzero.mcts.buffer.game_segment import GameSegment as OriginalGameSegment
3+
import numpy as np
4+
5+
class GameSegment(OriginalGameSegment):
6+
"""
7+
[PRIORZERO-MODIFIED]
8+
继承自原始 GameSegment 并添加了存储 MCTS 策略的功能。
9+
"""
10+
def __init__(self, action_space, game_segment_length=200, config=None, task_id=None):
11+
super().__init__(action_space, game_segment_length, config, task_id)
12+
# [PRIORZERO-NEW] 新增 mcts_policy_segment 用于存储 RFT 的目标
13+
self.mcts_policy_segment = []
14+
15+
def append(self, action, obs, reward, action_mask, to_play, timestep):
16+
super().append(action, obs, reward, action_mask, to_play, timestep)
17+
# 在 append 时,我们还没有 MCTS 策略,所以先用一个占位符
18+
self.mcts_policy_segment.append(None)
19+
20+
def store_search_stats(self, root_visit_dist, value, *args, **kwargs):
21+
"""
22+
[PRIORZERO-MODIFIED]
23+
在存储搜索统计信息时,将 MCTS 访问计数分布也存起来。
24+
"""
25+
super().store_search_stats(root_visit_dist, value, *args, **kwargs)
26+
# 最后一个被 append 的状态对应的 MCTS 策略
27+
# root_visit_dist 是一个 list, 我们需要它是一个 numpy array
28+
policy_array = np.array(root_visit_dist, dtype=np.float32)
29+
# 归一化
30+
if policy_array.sum() > 0:
31+
policy_array /= policy_array.sum()
32+
else: # 如果没有访问,则为均匀分布
33+
policy_array = np.ones_like(policy_array) / len(policy_array)
34+
35+
# 存储到最后一个位置
36+
self.mcts_policy_segment[-1] = policy_array
37+
38+
def game_segment_to_array(self):
39+
"""
40+
[PRIORZERO-MODIFIED]
41+
将 mcts_policy_segment 也转换为 numpy 数组。
42+
"""
43+
super().game_segment_to_array()
44+
self.mcts_policy_segment = np.array(self.mcts_policy_segment, dtype=object)

0 commit comments

Comments
 (0)