Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b6e80d8
Cleaned and added refactored cfr versions for torch and jax
alexunderch Dec 14, 2025
8bda550
Misinput
alexunderch Dec 14, 2025
730dd30
Annotations, improvement fixes
alexunderch Dec 15, 2025
4b45250
Annotations, improvement fixes
alexunderch Dec 15, 2025
bc554fc
Minor changes, added debugging
alexunderch Dec 15, 2025
2f98560
A typo
alexunderch Dec 15, 2025
a20799f
Performance improvements
alexunderch Dec 15, 2025
05d61ea
Fixed a mistake in the network
alexunderch Dec 16, 2025
bcb85e4
Speed improvements
alexunderch Dec 17, 2025
01775ec
Jax perf improvements
alexunderch Dec 17, 2025
c69a211
Clean code. Should pass the tests
alexunderch Dec 18, 2025
c8abc91
Decorator details, type hints
alexunderch Dec 19, 2025
8b25d84
Structured Loops, added numpy backend as well
alexunderch Dec 20, 2025
ed025d9
delete tqdm
alexunderch Dec 20, 2025
8e26def
Fixed failing tests due to the cornercase
alexunderch Dec 20, 2025
508192f
Final remarks
alexunderch Dec 21, 2025
192626f
Cleaning
alexunderch Dec 22, 2025
efdff3a
Fixes
alexunderch Dec 24, 2025
e54274c
Optimisation improvements, trimming
alexunderch Dec 30, 2025
7fb6cc9
Speedups
alexunderch Jan 11, 2026
8832ad9
Updated to the most recent reqs
alexunderch Jan 13, 2026
1d1b907
Typo
alexunderch Jan 13, 2026
018e15f
Another typo
alexunderch Jan 13, 2026
cc3a6b1
Force pytorch to cpu
alexunderch Jan 13, 2026
7a600fd
openblas attempt
alexunderch Jan 13, 2026
3e24895
Disabled ARM wheels
alexunderch Jan 13, 2026
26161df
manylinux change
alexunderch Jan 13, 2026
39c250a
Don't run tests. To be reverted
alexunderch Jan 13, 2026
024e6f9
Brought back linux workflows
alexunderch Jan 13, 2026
2b7103e
Re-added the terminating newline
alexunderch Jan 13, 2026
ce109c5
A merge conflict
alexunderch Jan 13, 2026
137172c
Merge branch 'master' into CI_fix
alexunderch Jan 13, 2026
83d847c
Update wheels.yml
lanctot Jan 13, 2026
a97de84
Added deep_cfr_jax_test
alexunderch Jan 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,26 @@ jobs:
strategy:
matrix:
include:
# Disabled temporarily while broken.
# See https://github.com/google-deepmind/open_spiel/issues/1441 for details.
# - os: ubuntu-24.04
# NAME: "Linux"
# OS_TYPE: "Linux"
# CI_PYBIN: python3.12
# OS_PYTHON_VERSION: 3.12
# OPEN_SPIEL_ABSL_VERSION: "20250814.1"
# CIBW_VERSION: 3.2.1
# CIBW_ENABLE: all
# CIBW_ENVIRONMENT: "CXX=$(which g++) OPEN_SPIEL_BUILDING_WHEEL='ON' OPEN_SPIEL_BUILD_WITH_ACPC='ON' OPEN_SPIEL_BUILD_WITH_HANABI='ON' OPEN_SPIEL_BUILD_WITH_ROSHAMBO='ON'"
# CIBW_BUILD: cp311-manylinux_x86_64 cp312-manylinux_x86_64 cp313-manylinux_x86_64 cp314-manylinux_x86_64
# - os: ubuntu-24.04-arm
# NAME: "Linux_arm64"
# OS_TYPE: "Linux"
# CI_PYBIN: python3.12
# OS_PYTHON_VERSION: 3.12
# OPEN_SPIEL_ABSL_VERSION: "20250814.1"
# CIBW_VERSION: 3.2.1
# CIBW_ENABLE: all
# CIBW_ENVIRONMENT: "CXX=$(which g++) OPEN_SPIEL_BUILDING_WHEEL='ON' OPEN_SPIEL_BUILD_WITH_ACPC='ON' OPEN_SPIEL_BUILD_WITH_HANABI='ON' OPEN_SPIEL_BUILD_WITH_ROSHAMBO='ON'"
# CIBW_BUILD: cp311-manylinux_aarch64 cp312-manylinux_aarch64 cp313-manylinux_aarch64 cp314-manylinux_aarch64
- os: ubuntu-24.04
NAME: "Linux"
OS_TYPE: "Linux"
CI_PYBIN: python3.12
OS_PYTHON_VERSION: 3.12
OPEN_SPIEL_ABSL_VERSION: "20250814.1"
CIBW_VERSION: 3.2.1
CIBW_ENABLE: all
CIBW_ENVIRONMENT: "CXX=$(which g++) OPEN_SPIEL_BUILDING_WHEEL='ON' OPEN_SPIEL_BUILD_WITH_ACPC='ON' OPEN_SPIEL_BUILD_WITH_HANABI='ON' OPEN_SPIEL_BUILD_WITH_ROSHAMBO='ON'"
CIBW_BUILD: cp311-manylinux_x86_64 cp312-manylinux_x86_64 cp313-manylinux_x86_64 cp314-manylinux_x86_64
- os: ubuntu-24.04-arm
NAME: "Linux_arm64"
OS_TYPE: "Linux"
CI_PYBIN: python3.12
OS_PYTHON_VERSION: 3.12
OPEN_SPIEL_ABSL_VERSION: "20250814.1"
CIBW_VERSION: 3.2.1
CIBW_ENABLE: all
CIBW_ENVIRONMENT: "CXX=$(which g++) OPEN_SPIEL_BUILDING_WHEEL='ON' OPEN_SPIEL_BUILD_WITH_ACPC='ON' OPEN_SPIEL_BUILD_WITH_HANABI='ON' OPEN_SPIEL_BUILD_WITH_ROSHAMBO='ON'"
CIBW_BUILD: cp311-manylinux_aarch64 cp312-manylinux_aarch64 cp313-manylinux_aarch64 cp314-manylinux_aarch64
- os: macos-14
OS_TYPE: "Darwin"
CI_PYBIN: python3.12
Expand Down Expand Up @@ -78,8 +76,8 @@ jobs:
CI_PYBIN: ${{ matrix.CI_PYBIN }}
CIBW_VERSION: ${{ matrix.CIBW_VERSION }}
CIBW_ENABLE: ${{ matrix.CIBW_ENABLE }}
CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014
CIBW_MANYLINUX_AARCH64_IMAGE: manylinux2014
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
CIBW_MANYLINUX_AARCH64_IMAGE: manylinux_2_28
CIBW_BUILD: ${{ matrix.CIBW_BUILD }}
CIBW_BEFORE_TEST: python -m pip install --upgrade pip
CIBW_TEST_COMMAND: /bin/bash {project}/open_spiel/scripts/test_wheel.sh basic {project}
Expand Down Expand Up @@ -113,7 +111,7 @@ jobs:
python -m pip install --upgrade -r requirements.txt -q
source ./open_spiel/scripts/python_extra_deps.sh python
python -m pip install --no-cache-dir --upgrade $OPEN_SPIEL_PYTHON_JAX_DEPS
python -m pip install --no-cache-dir --upgrade $OPEN_SPIEL_PYTHON_PYTORCH_DEPS
python -m pip install --no-cache-dir --upgrade $OPEN_SPIEL_PYTHON_PYTORCH_DEPS --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pypi.org/simple
python -m pip install --no-cache-dir --upgrade $OPEN_SPIEL_PYTHON_MISC_DEPS
python -m pip install twine
python -m pip install cibuildwheel==${CIBW_VERSION}
Expand Down Expand Up @@ -146,3 +144,4 @@ jobs:
path: |
dist/*.tar.gz
./wheelhouse/*.whl

1 change: 1 addition & 0 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ set(PYTHON_TESTS ${PYTHON_TESTS}
# Add Jax tests if it is enabled.
if (OPEN_SPIEL_ENABLE_JAX)
set (PYTHON_TESTS ${PYTHON_TESTS}
jax/deep_cfr_jax_test.py
jax/dqn_jax_test.py
jax/nfsp_jax_test.py
jax/opponent_shaping_jax_test.py
Expand Down
72 changes: 48 additions & 24 deletions open_spiel/python/examples/deep_cfr_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,71 @@

FLAGS = flags.FLAGS

flags.DEFINE_integer("num_iterations", 100, "Number of iterations")
flags.DEFINE_integer("num_traversals", 1500, "Number of traversals/games")
flags.DEFINE_string("game_name", "leduc_poker", "Name of the game")
flags.DEFINE_integer("num_iterations", 101, "Number of iterations")
flags.DEFINE_integer("num_traversals", 375, "Number of traversals/games")
flags.DEFINE_string("game_name", "kuhn_poker", "Name of the game")

# Recommended parameters:
# For more, see https://github.com/aicenter/openspiel_reproductions/

# Parameter Value
# ---------------------------------------
# num_traversals 1500
# batch_size_advantage 2048
# batch_size_strategy 2048
# num_hidden 64
# num_layers 3
# reinitialize_advantage_networks True
# learning_rate 1e-3
# memory_capacity 1e6
# policy_network_train_steps 5000
# advantage_network_train_steps 750


def main(unused_argv):
logging.info("Loading %s", FLAGS.game_name)
logging.info(f"Loading {FLAGS.game_name}")

game = pyspiel.load_game(FLAGS.game_name)
deep_cfr_solver = deep_cfr.DeepCFRSolver(
game,
policy_network_layers=(64, 64, 64),
advantage_network_layers=(64, 64, 64),
num_iterations=FLAGS.num_iterations,
num_traversals=FLAGS.num_traversals,
learning_rate=1e-3,
batch_size_advantage=2048,
batch_size_strategy=2048,
memory_capacity=1e7,
policy_network_train_steps=5000,
advantage_network_train_steps=750,
reinitialize_advantage_networks=True)
game,
policy_network_layers=(64,),
advantage_network_layers=(64,),
num_iterations=FLAGS.num_iterations,
num_traversals=FLAGS.num_traversals,
reinitialize_advantage_networks=True,
learning_rate=1e-3,
batch_size_advantage=256,
batch_size_strategy=256,
memory_capacity=100000,
policy_network_train_steps=2500,
advantage_network_train_steps=375,
print_nash_convs=False # for debugging purposes
)

_, advantage_losses, policy_loss = deep_cfr_solver.solve()
for player, losses in advantage_losses.items():
logging.info("Advantage for player %d: %s", player,
losses[:2] + ["..."] + losses[-2:])
logging.info("Advantage Buffer Size for player %s: '%s'", player,
len(deep_cfr_solver.advantage_buffers[player]))
logging.info("Strategy Buffer Size: '%s'",
len(deep_cfr_solver.strategy_buffer))
logging.info("Final policy loss: '%s'", policy_loss)
logging.info(f"Advantage Buffer Size for player {player}: {len(deep_cfr_solver.advantage_buffers[player])}")
logging.info(f"Strategy Buffer Size: {len(deep_cfr_solver.strategy_buffer)}")
logging.info(f"Final policy loss: {policy_loss}")

average_policy = policy.tabular_policy_from_callable(
game, deep_cfr_solver.action_probabilities)

conv = exploitability.nash_conv(game, average_policy)
logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)
logging.info(f"Deep CFR in {FLAGS.game_name} - NashConv: {conv}")


average_policy_values = expected_game_score.policy_value(
game.new_initial_state(), [average_policy] * 2)
print("Computed player 0 value: {}".format(average_policy_values[0]))
print("Computed player 1 value: {}".format(average_policy_values[1]))
if FLAGS.game_name == "kuhn_poker":
# The expected values for Kuhn poker are known: -1/18 for player 0 and 1/18 for player 1.
logging.info(f"Computed player 0 value: {average_policy_values[0]:.2f} (expected: {-1/18:.2f}).")
logging.info(f"Computed player 1 value: {average_policy_values[1]:.2f} (expected: {1/18:.2f}).")
else:
logging.info(f"Computed player 0 value: {average_policy_values[0]:.2f}")
logging.info(f"Computed player 1 value: {average_policy_values[1]:.2f}")


if __name__ == "__main__":
Expand Down
70 changes: 45 additions & 25 deletions open_spiel/python/examples/deep_cfr_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,73 @@
import pyspiel
from open_spiel.python.pytorch import deep_cfr

# Recommended parameters:
# For more, see https://github.com/aicenter/openspiel_reproductions/

# Parameter Value
# ---------------------------------------
# num_traversals 1500
# batch_size_advantage 2048
# batch_size_strategy 2048
# num_hidden 64
# num_layers 3
# reinitialize_advantage_networks True
# learning_rate 1e-3
# memory_capacity 1e6
# policy_network_train_steps 5000
# advantage_network_train_steps 750


FLAGS = flags.FLAGS

flags.DEFINE_integer("num_iterations", 400, "Number of iterations")
flags.DEFINE_integer("num_traversals", 160, "Number of traversals/games")
flags.DEFINE_integer("num_iterations", 101, "Number of iterations")
flags.DEFINE_integer("num_traversals", 375, "Number of traversals/games")
flags.DEFINE_string("game_name", "kuhn_poker", "Name of the game")


def main(unused_argv):
logging.info("Loading %s", FLAGS.game_name)
logging.info(f"Loading {FLAGS.game_name}")
game = pyspiel.load_game(FLAGS.game_name)

deep_cfr_solver = deep_cfr.DeepCFRSolver(
game,
policy_network_layers=(32, 32),
advantage_network_layers=(16, 16),
num_iterations=FLAGS.num_iterations,
num_traversals=FLAGS.num_traversals,
learning_rate=1e-3,
batch_size_advantage=512,
batch_size_strategy=None,
memory_capacity=int(1e7),
reinitialize_advantage_networks=False,
policy_network_train_steps=100
)
game,
policy_network_layers=(64,),
advantage_network_layers=(64,),
num_iterations=FLAGS.num_iterations,
num_traversals=FLAGS.num_traversals,
reinitialize_advantage_networks=True,
learning_rate=1e-3,
batch_size_advantage=256,
batch_size_strategy=256,
memory_capacity=100000,
policy_network_train_steps=2500,
advantage_network_train_steps=375,
print_nash_convs=False # for debugging purposes
)

_, advantage_losses, policy_loss = deep_cfr_solver.solve()
for player, losses in advantage_losses.items():
logging.info("Advantage for player %d: %s", player,
losses[:2] + ["..."] + losses[-2:])
logging.info("Advantage Buffer Size for player %s: '%s'", player,
len(deep_cfr_solver.advantage_buffers[player]))
logging.info("Strategy Buffer Size: '%s'",
len(deep_cfr_solver.strategy_buffer))
logging.info("Final policy loss: '%s'", policy_loss)
logging.info(f"Advantage Buffer Size for player {player}: {len(deep_cfr_solver.advantage_buffers[player])}")
logging.info(f"Strategy Buffer Size: {len(deep_cfr_solver.strategy_buffer)}")
logging.info(f"Final policy loss: {policy_loss}")

average_policy = policy.tabular_policy_from_callable(
game, deep_cfr_solver.action_probabilities)
pyspiel_policy = policy.python_policy_to_pyspiel_policy(average_policy)
conv = pyspiel.nash_conv(game, pyspiel_policy)
logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)
logging.info(f"Deep CFR in {FLAGS.game_name} - NashConv: {conv}")

average_policy_values = expected_game_score.policy_value(
game.new_initial_state(), [average_policy] * 2)
logging.info("Computed player 0 value: %.2f (expected: %.2f).",
average_policy_values[0], -1 / 18)
logging.info("Computed player 1 value: %.2f (expected: %.2f).",
average_policy_values[1], 1 / 18)
if FLAGS.game_name == "kuhn_poker":
# The expected values for Kuhn poker are known: -1/18 for player 0 and 1/18 for player 1.
logging.info(f"Computed player 0 value: {average_policy_values[0]:.2f} (expected: {-1/18:.2f}).")
logging.info(f"Computed player 1 value: {average_policy_values[1]:.2f} (expected: {1/18:.2f}).")
else:
logging.info(f"Computed player 0 value: {average_policy_values[0]:.2f}")
logging.info(f"Computed player 1 value: {average_policy_values[1]:.2f}")


if __name__ == "__main__":
Expand Down
Loading
Loading