Skip to content

Commit 73295c2

Browse files
authored
serial training league demo (#12)
* feature(nyz): add naive 1v1 two player demo * feature(nyz): add 1v1 evaluator and 2 rule-based policy for evaluation * feature(nyz): modify game env and adjust hyper-param * feature(nyz): add naive league training multi player demo * feature(nyz): enable force snapshot to support init historical league player; finish league demo basic code * feature(nyz): modify selfplay demo and add two type game env * style(nyz): correct format style * polish(nyz): correct format style and adapt league demo main * feature(nyz): add league payoff viz and enable payoff update in league demo * feature(nyz): modify win rate calculation with draws * test(nyz): fix one vs one league test compatibility bug * test(nyz): add selfplay and league demo into unittest and algotest * style(nyz): correct format * hotfix(nyz): fix ppo continuous compatibility bug
1 parent dd4de1a commit 73295c2

21 files changed

+1285
-26
lines changed

ding/entry/tests/test_serial_entry.py

+19
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from dizoo.multiagent_particle.config import cooperative_navigation_coma_config, cooperative_navigation_coma_create_config # noqa
2929
from dizoo.multiagent_particle.config import cooperative_navigation_collaq_config, cooperative_navigation_collaq_create_config # noqa
3030
from dizoo.multiagent_particle.config import cooperative_navigation_atoc_config, cooperative_navigation_atoc_create_config # noqa
31+
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
32+
from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
33+
from dizoo.league_demo.league_demo_ppo_main import main as league_main
3134

3235

3336
@pytest.mark.unittest
@@ -254,6 +257,22 @@ def test_sqn():
254257
os.popen('rm -rf log ckpt*')
255258

256259

260+
@pytest.mark.unittest
261+
def test_selfplay():
262+
try:
263+
selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
264+
except Exception:
265+
assert False, "pipeline fail"
266+
267+
268+
@pytest.mark.unittest
269+
def test_league():
270+
try:
271+
league_main(deepcopy(league_demo_ppo_config), seed=0, max_iterations=1)
272+
except Exception as e:
273+
assert False, "pipeline fail"
274+
275+
257276
@pytest.mark.unittest
258277
def test_acer():
259278
config = [deepcopy(cartpole_acer_config), deepcopy(cartpole_acer_create_config)]

ding/entry/tests/test_serial_entry_algo.py

+23
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from dizoo.multiagent_particle.config import cooperative_navigation_coma_config, cooperative_navigation_coma_create_config # noqa
2929
from dizoo.multiagent_particle.config import cooperative_navigation_collaq_config, cooperative_navigation_collaq_create_config # noqa
3030
from dizoo.multiagent_particle.config import cooperative_navigation_atoc_config, cooperative_navigation_atoc_create_config # noqa
31+
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
32+
from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
33+
from dizoo.league_demo.league_demo_ppo_main import main as league_main
3134

3235
with open("./algo_record.log", "w+") as f:
3336
f.write("ALGO TEST STARTS\n")
@@ -274,3 +277,23 @@ def test_acer():
274277
assert False, "pipeline fail"
275278
with open("./algo_record.log", "a+") as f:
276279
f.write("22. acer\n")
280+
281+
282+
@pytest.mark.algotest
283+
def test_selfplay():
284+
try:
285+
selfplay_main(league_demo_ppo_config, seed=0)
286+
except Exception:
287+
assert False, "pipeline fail"
288+
with open("./algo_record.log", "a+") as f:
289+
f.write("23. selfplay\n")
290+
291+
292+
@pytest.mark.algotest
293+
def test_league():
294+
try:
295+
league_main(league_demo_ppo_config, seed=0)
296+
except Exception:
297+
assert False, "pipeline fail"
298+
with open("./algo_record.log", "a+") as f:
299+
f.write("24. league\n")

ding/league/base_league.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22
import copy
3+
import os
34
from abc import abstractmethod
45
from easydict import EasyDict
56
import os.path as osp
@@ -35,6 +36,9 @@ def __init__(self, cfg: EasyDict) -> None:
3536
"""
3637
self.cfg = cfg
3738
self.path_policy = cfg.path_policy
39+
if not osp.exists(self.path_policy):
40+
os.mkdir(self.path_policy)
41+
3842
self.league_uid = str(uuid.uuid1())
3943
self.active_players = []
4044
self.historical_players = []
@@ -52,13 +56,11 @@ def _init_players(self) -> None:
5256
for cate in self.cfg.player_category: # Player's category (Depends on the env)
5357
for k, n in self.cfg.active_players.items(): # Active player's type
5458
for i in range(n): # This type's active player number
55-
name = '{}_{}_{}_{}'.format(k, cate, i, self.league_uid)
56-
ckpt_path = '{}_ckpt.pth'.format(name)
59+
name = '{}_{}_{}'.format(k, cate, i)
60+
ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
5761
player = create_player(self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0)
5862
if self.cfg.use_pretrain:
59-
self.save_checkpoint(
60-
self.cfg.pretrain_checkpoint_path[cate], osp.join(self.path_policy, player.checkpoint_path)
61-
)
63+
self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
6264
self.active_players.append(player)
6365
self.payoff.add_player(player)
6466

@@ -68,7 +70,7 @@ def _init_players(self) -> None:
6870
main_player_name = [k for k in self.cfg.keys() if 'main_player' in k]
6971
assert len(main_player_name) == 1, main_player_name
7072
main_player_name = main_player_name[0]
71-
name = '{}_{}_0_pretrain'.format(main_player_name, cate)
73+
name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate)
7274
parent_name = '{}_{}_0'.format(main_player_name, cate)
7375
hp = HistoricalPlayer(
7476
self.cfg.get(main_player_name),
@@ -122,7 +124,7 @@ def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
122124
"""
123125
raise NotImplementedError
124126

125-
def judge_snapshot(self, player_id: str) -> bool:
127+
def judge_snapshot(self, player_id: str, force: bool = False) -> bool:
126128
"""
127129
Overview:
128130
Judge whether a player is trained enough for snapshot. If yes, call player's ``snapshot``, create a
@@ -136,12 +138,10 @@ def judge_snapshot(self, player_id: str) -> bool:
136138
with self._active_players_lock:
137139
idx = self.active_players_ids.index(player_id)
138140
player = self.active_players[idx]
139-
if player.is_trained_enough():
141+
if force or player.is_trained_enough():
140142
# Snapshot
141143
hp = player.snapshot()
142-
self.save_checkpoint(
143-
osp.join(self.path_policy, player.checkpoint_path), osp.join(self.path_policy, hp.checkpoint_path)
144-
)
144+
self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
145145
self.historical_players.append(hp)
146146
self.payoff.add_player(hp)
147147
# Mutate

ding/league/player.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def snapshot(self) -> HistoricalPlayer:
185185
self.category,
186186
self.payoff,
187187
path,
188-
self.player_id + '_{}'.format(int(self._total_agent_step)),
188+
self.player_id + '_{}_historical'.format(int(self._total_agent_step)),
189189
self._total_agent_step,
190190
parent_id=self.player_id
191191
)

ding/league/shared_payoff.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections import defaultdict
33
from typing import Tuple, Optional
44
from easydict import EasyDict
5+
from tabulate import tabulate
56
import numpy as np
67

78
from ding.utils import LockContext, LockContextType
@@ -76,6 +77,23 @@ def __init__(self, cfg: EasyDict):
7677
# Thread lock.
7778
self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
7879

80+
def __repr__(self) -> str:
81+
headers = ["Home Player", "Away Player", "Wins", "Draws", "Losses", "Naive Win Rate"]
82+
data = []
83+
for k, v in self._data.items():
84+
k1 = k.split('-')
85+
# k has the format '{}-{}'.format(name1, name2), and every HistoricalPlayer name carries a `historical` suffix
86+
if 'historical' in k1[0]:
87+
# reverse representation
88+
naive_win_rate = (v['losses'] + v['draws'] / 2) / (v['wins'] + v['losses'] + v['draws'] + 1e-8)
89+
data.append([k1[1], k1[0], v['losses'], v['draws'], v['wins'], naive_win_rate])
90+
else:
91+
naive_win_rate = (v['wins'] + v['draws'] / 2) / (v['wins'] + v['losses'] + v['draws'] + 1e-8)
92+
data.append([k1[0], k1[1], v['wins'], v['draws'], v['losses'], naive_win_rate])
93+
data = sorted(data, key=lambda x: x[0])
94+
s = tabulate(data, headers=headers, tablefmt='grid')
95+
return s
96+
7997
def __getitem__(self, players: tuple) -> np.ndarray:
8098
"""
8199
Overview:
@@ -172,18 +190,21 @@ def _win_loss_reverse(result_: str, reverse_: bool) -> str:
172190

173191
with self._lock:
174192
home_id, away_id = job_info['player_id']
193+
job_info_result = job_info['result']
194+
# For compatibility with a one-layer (flat) result list: wrap it so it can be iterated as a nested list
195+
if not isinstance(job_info_result[0], list):
196+
job_info_result = [job_info_result]
175197
try:
176-
assert home_id in self._players_ids
177-
assert away_id in self._players_ids
198+
assert home_id in self._players_ids, "home_id error"
199+
assert away_id in self._players_ids, "away_id error"
178200
# Assert all results are in ['wins', 'losses', 'draws']
179-
assert all([i in BattleRecordDict.data_keys[:3] for j in job_info['result'] for i in j])
201+
assert all([i in BattleRecordDict.data_keys[:3] for j in job_info_result for i in j]), "results error"
180202
except Exception as e:
181-
print("[ERROR] invalid job_info: {}".format(job_info))
182-
print(e)
203+
print("[ERROR] invalid job_info: {}\n\tError reason is: {}".format(job_info, e))
183204
return False
184205
key, reverse = self.get_key(home_id, away_id)
185206
# Update with decay
186-
for j in job_info['result']:
207+
for j in job_info_result:
187208
for i in j:
188209
# All categories should decay
189210
self._data[key] *= self._decay

ding/league/tests/test_one_vs_one_league.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def test_naive(self):
6666
active_player_ckpt = league.active_players[0].checkpoint_path
6767
tmp = torch.tensor([1, 2, 3])
6868
path_policy = one_vs_one_league_default_config.league.path_policy
69-
os.makedirs(path_policy)
70-
torch.save(tmp, os.path.join(path_policy, active_player_ckpt))
69+
torch.save(tmp, active_player_ckpt)
7170

7271
# judge_snapshot & update_active_player
7372
assert not league.judge_snapshot(active_player_id)

ding/worker/collector/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
to_tensor_transitions
44
from .sample_serial_collector import SampleCollector
55
from .episode_serial_collector import EpisodeCollector
6+
from .episode_one_vs_one_serial_collector import Episode1v1Collector
67
from .base_serial_evaluator import BaseSerialEvaluator
8+
from .one_vs_one_serial_evaluator import OnevOneEvaluator
79
# parallel
810
from .base_parallel_collector import BaseCollector, create_parallel_collector, get_parallel_collector_cls
911
from .zergling_collector import ZerglingCollector

ding/worker/collector/base_serial_collector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def __init__(self, maxlen: int, *args, **kwargs) -> None:
177177
Overview:
178178
Initialization trajBuffer.
179179
Arguments:
180-
- maxlen (:obj:`int`): the max len of trajBuffer
180+
- maxlen (:obj:`int`): The maximum length of trajectory buffer.
181181
"""
182182
self._maxlen = maxlen
183183
super().__init__(*args, **kwargs)

ding/worker/collector/comm/flask_fs_collector.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def get_policy_update_info(self, path: str) -> dict:
161161
"""
162162
if self._collector_close_flag:
163163
return
164-
path = os.path.join(self._path_policy, path)
164+
if self._path_policy not in path:
165+
path = os.path.join(self._path_policy, path)
165166
return read_file(path, use_lock=True)
166167

167168
# override

0 commit comments

Comments
 (0)