Skip to content
This repository was archived by the owner on Jul 16, 2024. It is now read-only.

Commit 9d34ec6

Browse files
committed
setting up support for card type in lut
1 parent 9a2b93f commit 9d34ec6

File tree

4 files changed

+44
-15
lines changed

4 files changed

+44
-15
lines changed

poker_ai/ai/runner.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -205,7 +205,7 @@ def resume(server_config_path: str):
205205
)
206206
@click.option(
207207
"--lut_path",
208-
default="./card_info_lut.joblib",
208+
default=".",
209209
help=(
210210
"The path to the files for clustering the infosets."
211211
),
@@ -275,6 +275,8 @@ def start(
275275
simple_search(
276276
config=config,
277277
save_path=save_path,
278+
lut_path=lut_path,
279+
pickle_dir=pickle_dir,
278280
strategy_interval=strategy_interval,
279281
n_iterations=n_iterations,
280282
lcfr_threshold=lcfr_threshold,

poker_ai/ai/singleprocess/train.py

+10-3
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,8 @@ def print_strategy(strategy: Dict[str, Dict[str, int]]):
3939
def simple_search(
4040
config: Dict[str, int],
4141
save_path: Path,
42+
lut_path: Union[str, Path],
43+
pickle_dir: bool,
4244
strategy_interval: int,
4345
n_iterations: int,
4446
lcfr_threshold: int,
@@ -81,14 +83,19 @@ def simple_search(
8183
"""
8284
utils.random.seed(42)
8385
agent = Agent(use_manager=False)
84-
info_set_lut = {}
86+
card_info_lut = {}
8587
for t in trange(1, n_iterations + 1, desc="train iter"):
8688
if t == 2:
8789
logging.disable(logging.DEBUG)
8890
for i in range(n_players): # fixed position i
8991
# Create a new state.
90-
state: ShortDeckPokerState = new_game(n_players, info_set_lut)
91-
info_set_lut = state.info_set_lut
92+
state: ShortDeckPokerState = new_game(
93+
n_players,
94+
card_info_lut,
95+
lut_path=lut_path,
96+
pickle_dir=pickle_dir
97+
)
98+
card_info_lut = state.card_info_lut
9299
if t > update_threshold and t % strategy_interval == 0:
93100
ai.update_strategy(agent=agent, state=state, i=i, t=t)
94101
if t > prune_threshold:

poker_ai/clustering/card_info_lut_builder.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,8 @@ def compute(
6464
"""
6565
log.info("Starting computation of clusters.")
6666
start = time.time()
67-
if "preflop" not in self.card_info_lut:
68-
self.card_info_lut["preflop"] = compute_preflop_lossless_abstraction(
67+
if "pre_flop" not in self.card_info_lut:
68+
self.card_info_lut["pre_flop"] = compute_preflop_lossless_abstraction(
6969
builder=self
7070
)
7171
joblib.dump(self.card_info_lut, self.card_info_lut_path)

poker_ai/games/short_deck/state.py

+29-9
Original file line number | Diff line number | Diff line change
@@ -48,11 +48,18 @@ def new_game(
4848
]
4949
if card_info_lut:
5050
# Don't reload massive files, it takes ages.
51-
state = ShortDeckPokerState(players=players, load_card_lut=False, **kwargs)
51+
state = ShortDeckPokerState(
52+
players=players,
53+
load_card_lut=False,
54+
**kwargs
55+
)
5256
state.card_info_lut = card_info_lut
5357
else:
5458
# Load massive files.
55-
state = ShortDeckPokerState(players=players, **kwargs)
59+
state = ShortDeckPokerState(
60+
players=players,
61+
**kwargs
62+
)
5663
return state
5764

5865

@@ -79,8 +86,9 @@ def __init__(
7986
f"At least 2 players must be provided but only {n_players} "
8087
f"were provided."
8188
)
89+
self._pickle_dir = pickle_dir
8290
if load_card_lut:
83-
self.card_info_lut = self.load_card_lut(lut_path, pickle_dir)
91+
self.card_info_lut = self.load_card_lut(lut_path, self._pickle_dir)
8492
else:
8593
self.card_info_lut = {}
8694
# Get a reference of the pot from the first player.
@@ -227,7 +235,10 @@ def apply_action(self, action_str: Optional[str]) -> ShortDeckPokerState:
227235
return new_state
228236

229237
@staticmethod
230-
def load_card_lut(lut_path: str = ".", pickle_dir: bool = False) -> Dict[str, Dict[Tuple[int, ...], str]]:
238+
def load_card_lut(
239+
lut_path: str = ".",
240+
pickle_dir: bool = False
241+
) -> Dict[str, Dict[Tuple[int, ...], str]]:
231242
"""
232243
Load card information lookup table.
233244
@@ -267,7 +278,7 @@ def load_card_lut(lut_path: str = ".", pickle_dir: bool = False) -> Dict[str, Di
267278
card_info_lut[betting_stage] = joblib.load(fp)
268279
elif lut_path:
269280
logger.info(f"Loading card from single file at path: {lut_path}")
270-
card_info_lut = joblib.load(lut_path)
281+
card_info_lut = joblib.load(lut_path + '/card_info_lut.joblib')
271282
else:
272283
card_info_lut = {}
273284
return card_info_lut
@@ -373,20 +384,29 @@ def betting_round(self) -> int:
373384
@property
374385
def info_set(self) -> str:
375386
"""Get the information set for the current player."""
387+
if self._pickle_dir:
388+
key = operator.attrgetter("eval_card")
389+
else:
390+
key = None
376391
cards = sorted(
377392
self.current_player.cards,
378-
key=operator.attrgetter("eval_card"),
393+
key=key,
379394
reverse=True,
380395
)
381396
cards += sorted(
382397
self._table.community_cards,
383-
key=operator.attrgetter("eval_card"),
398+
key=key,
384399
reverse=True,
385400
)
386-
eval_cards = tuple([card.eval_card for card in cards])
401+
if self._pickle_dir:
402+
lookup_cards = tuple([card.eval_card for card in cards])
403+
else:
404+
lookup_cards = tuple(cards)
387405
try:
388-
cards_cluster = self.card_info_lut[self._betting_stage][eval_cards]
406+
cards_cluster = self.card_info_lut[self._betting_stage][lookup_cards]
389407
except KeyError:
408+
import ipdb;
409+
ipdb.set_trace()
390410
return "default info set, please ensure you load it correctly"
391411
# Convert history from a dict of lists to a list of dicts as I'm
392412
# paranoid about JSON's lack of care with insertion order.

0 commit comments

Comments
 (0)