Conversation

alexunderch
Contributor

Hello @lanctot !

We (@parvxh, @harmanagrawal, and I) present the first part of the AlphaZero refactor. With major help from the guys, we have rewritten the models using flax.linen and flax.nnx (nnx is not fully supported yet, but we will complete it in the near future). We have also added a replay buffer compatible with GPU execution.

Problems we have faced:

  • multiprocessing vs. JAX: since JAX runs in multithreaded mode, it does not work with the fork start method, so we had to override it. However, there are still sporadic idle periods during execution, which may be related to unclear usage of synchronisation primitives.
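The fork-vs-spawn workaround described in the bullet above can be sketched roughly like this (a minimal illustration, not the PR's actual code; the `actor_loop` name is hypothetical):

```python
import multiprocessing as mp

def actor_loop(queue):
    # In the real setup, `import jax` would happen here, inside the child:
    # a freshly spawned interpreter initializes its own XLA backend instead
    # of inheriting half-copied thread state from a fork.
    queue.put("ready")

if __name__ == "__main__":
    # "fork" duplicates the parent process while JAX's internal thread
    # pools may hold locks, which can deadlock the child; "spawn" starts
    # a clean interpreter and avoids that failure mode.
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=actor_loop, args=(q,))
    p.start()
    print(q.get())  # blocks until the child reports in
    p.join()
```

Note that spawned children re-import the main module, so any top-level JAX initialization in the parent script should sit behind the `__main__` guard.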

With this PR, we want to know whether we are heading in the right direction; we want to contribute to this long-standing problem rather than keep the solution process behind closed doors.


google-cla bot commented Aug 17, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@lanctot
Collaborator

lanctot commented Aug 18, 2025

Awesome, thanks!

@lanctot
Collaborator

lanctot commented Aug 19, 2025

Hi @alexunderch @parvxh @harmanagrawal ,

Just a quick question: why is it labeled "the first major part"? Because flax.nnx isn't fully working?

What's missing? Does this currently work with flax.linen? I'd really like to see a few graphs of it convincingly learning a small game. Would that be possible to show? How long would it take on, e.g., Tic-Tac-Toe?

@alexunderch
Contributor Author

alexunderch commented Aug 19, 2025

Yes, exactly.

  1. flax.nnx is not fully working; some minor fixes (e.g. conv dimension calculation and tests) remain
  2. the implementation does work, but some problems occur due to multiprocessing (one of the processes stalls)

The team had its own priorities for submitting the PR, which is why we opened it now, but we'll finish all the fixes in the next week or two. The same goes for benchmarking on tic-tac-toe.

Appreciate your patience.

@lanctot
Collaborator

lanctot commented Aug 19, 2025

I see. Ok, I'll wait until it's fully working to take a deeper look. Feel free to use this PR branch to collaborate amongst yourselves (i.e. update as necessary) and then just let me know when you think it's ready. Thanks for doing this!

@alexunderch
Contributor Author

Surely, we will notify you!

@alexunderch
Contributor Author

alexunderch commented Aug 27, 2025

@lanctot, there are now tests for both APIs (linen and nnx), and they are passing. The only minor things left on the development side are model export and the changelog benchmarks.

I've run only one benchmark, on CPU, with a ridiculously small replay buffer/batch size of 256/2.
[image: Figure_1]

P.S. Need to fix the rendering as well xD.

The code runs fine, so if you have a chance to run a test or two on a GPU, that would be fire. I hope to run them by the weekend.

@alexunderch
Contributor Author

@lanctot, I ran a TTT experiment for a much longer time
[image: training curves]

It doesn't look good, does it? Can the picture tell you what I should look at to find bugs?

@lanctot
Collaborator

lanctot commented Aug 29, 2025

I'm not sure. I will solicit the help of @tewalds : Timo, does this look like what you expect?

The top-left one looks ok...? Should they move closer to 0, though? I'm not sure. Maybe not for a small MLP? So far I'm not too concerned about that one.

I'm not sure what the top two are showing yet (or what the differences are between 0 and 6), but I would expect accuracy to go up over time.

The top-right one: I would expect it to go up over time, but it doesn't seem to..? (The one from two days ago does, though -- maybe because it's learning to draw against MCTS-100 really fast, which is possible and would be good.)

Can you explain what you did differently from the graph two days ago? That run seemed to use a very small (too small) replay buffer and batch size. Did you increase those in the second run? Also, how many MCTS simulations are you using?

Also, you said you let it run for longer, but I see roughly the same number of steps on the x-axis.

The first step would be to, every epoch or training step, play 10 games against random as P0 and 10 games against random as P1, and dump those... and let's inspect the games. We can also track the values versus random over time. If those don't go up, then there's something very seriously wrong, but I roughly know what that graph should look like.

@alexunderch
Contributor Author

alexunderch commented Aug 29, 2025

@lanctot
The main difference between the graphs is the buffer/batch size: 2 ** 16 and 2 ** 10, which were the default values for the model.

I use the default averaging window of 100 games, with 300 simulations each:

{
  "actors": 2,
  "checkpoint_freq": 20,
  "device": "cpu",
  "eval_levels": 7,
  "evaluation_window": 100,
  "evaluators": 1,
  "game": "tic_tac_toe",
  "learning_rate": 0.001,
  "max_simulations": 300,
  "max_steps": 0,
  "nn_api_version": "linen",
  "nn_depth": 10,
  "nn_model": "mlp",
  "nn_width": 128,
  "observation_shape": [
    3,
    3,
    3
  ],
  "output_size": 9,
  "path": "checkpoints",
  "policy_alpha": 1.0,
  "policy_epsilon": 0.25,
  "quiet": true,
  "replay_buffer_reuse": 3,
  "replay_buffer_size": 65536,
  "temperature": 1.0,
  "temperature_drop": 10,
  "train_batch_size": 1024,
  "uct_c": 2,
  "weight_decay": 0.0001
}

@alexunderch
Contributor Author

I will share some progress tomorrow; you may approve the checks later <3

@alexunderch
Contributor Author

I guess we're making slight progress, aren't we?

[image: updated training curves]

Give it a look, @lanctot

@lanctot
Collaborator

lanctot commented Sep 3, 2025

Hey @alexunderch, a bit swamped at the moment. I will need the help of @tewalds. I have emailed him and he says he can take a look at some point, but it may require a call / catch-up. I will be in touch. This will be slow on our side, sorry.

@alexunderch alexunderch changed the title Alpha zero refactor (the 1st major part) Alpha zero refactor (testing and polishing) Sep 10, 2025
@alexunderch
Contributor Author

The latest plots (minor tweaks and fixes here and there). Maybe, with many more resources (I used a toy config), there is something here:
[image]

@alexunderch
Contributor Author

I found an example with hyperparameters for tic-tac-toe, and the results look somewhat more intuitive (although I had to reduce the batch size fourfold due to resource constraints).
[image: Figure_linen]
[image: Figure_nnx]
