-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlr_bot.py
More file actions
92 lines (64 loc) · 2.48 KB
/
lr_bot.py
File metadata and controls
92 lines (64 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Bots that carry out orders generated by a logistic regressor model."""
from abc import ABC
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence
from baseline_models.model_code.predict import predict
from diplomacy.utils import strings as diplomacy_strings
from diplomacy.utils.constants import SuggestionType
from chiron_utils.bots.baseline_bot import BaselineBot, BotType
from chiron_utils.utils import return_logger
logger = return_logger(__name__)

# Directory holding the per-(unit, location, phase) model files. The path is
# relative, so it is resolved against the process's current working directory.
MODEL_PATH = Path("lr_model")
@dataclass
class LrBot(BaselineBot, ABC):
    """Baseline logistic regressor model.

    `MODEL_PATH` should point to a folder containing the model .pkl files.
    Each model corresponds to a (unit, location, phase) combination.

    Unit types are 'A', 'F'.
    Location types include all possible locations on the board, such as 'BRE', 'LON', 'STP_SC', etc.
    Phase types are 'SM', 'FM', 'WA', 'SR', 'FR', 'CD'.
    """

    # Unannotated class attribute, so it is shared by all instances and is not
    # a dataclass field. This bot never negotiates (see `do_messaging_round`).
    player_type = diplomacy_strings.NO_PRESS_BOT

    def __post_init__(self) -> None:
        """Verify that the model path exists when instantiated.

        Raises:
            NotADirectoryError: If `MODEL_PATH` does not exist or is not a directory.
        """
        super().__post_init__()
        # Check up front so a missing model directory is reported when the bot
        # is constructed, rather than later when `predict` is given the path.
        if not MODEL_PATH.is_dir():
            raise NotADirectoryError(
                f"Model directory {str(MODEL_PATH)!r} does not exist or is not a directory."
            )

    def get_orders(self) -> List[str]:
        """Get order predictions from the model.

        Takes the current game state and passes it, together with this
        bot's power name and `MODEL_PATH`, to `predict`.

        Returns:
            List of predicted orders.
        """
        state = self.game.get_state()
        orders: List[str] = predict(MODEL_PATH, state, self.power_name)
        logger.info("Orders to suggest: %s", orders)
        return orders

    async def gen_orders(self) -> List[str]:
        """Generate orders for a turn.

        Advisors suggest the predicted orders, and players submit them
        (waiting on `send_orders`). For any other bot type, the orders are
        only returned.

        Returns:
            List of orders to carry out.
        """
        orders = self.get_orders()
        if self.bot_type == BotType.ADVISOR:
            await self.suggest_orders(orders)
        elif self.bot_type == BotType.PLAYER:
            await self.send_orders(orders, wait=True)
        return orders

    async def do_messaging_round(self, orders: Sequence[str]) -> List[str]:
        """Carry out one round of messaging.

        This bot sends no messages, so the orders are passed through unchanged.

        Args:
            orders: Orders decided so far for this turn.

        Returns:
            List of orders to carry out (a new list with the same contents as `orders`).
        """
        return list(orders)
@dataclass
class LrAdvisor(LrBot):
    """`LrBot` running in advisor mode: it suggests its predicted orders rather than submitting them."""

    # Advisor mode makes `gen_orders` hand predictions to `suggest_orders`.
    bot_type = BotType.ADVISOR
    # Default suggestion kind, declared here as MOVE. It is consumed by the
    # base class (presumably when suggestions are sent; confirm in BaselineBot).
    default_suggestion_type = SuggestionType.MOVE
@dataclass
class LrPlayer(LrBot):
    """`LrBot` running in player mode: it submits its predicted orders directly."""

    # Player mode makes `gen_orders` submit predictions via `send_orders(..., wait=True)`.
    bot_type = BotType.PLAYER