Conversation

@alexunderch
Contributor

Hello @lanctot !

We (@parvxh, @harmanagrawal, and I) present the first parts of our AlphaZero refactor. With major help from the guys, the models have been rewritten using flax.linen and flax.nnx (nnx support isn't complete yet, but we'll finish it in the near future). We have also added a replay buffer compatible with GPU execution.
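
For context, here is a minimal sketch (purely illustrative, not the PR's actual code) of what an AlphaZero-style net looks like in flax.linen: an MLP torso with a policy head and a value head. The class name and sizes are made up for the example.

# Purely illustrative sketch -- not the PR's model code.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyAZNet(nn.Module):  # hypothetical name, for illustration only
  num_actions: int
  width: int = 128
  depth: int = 2

  @nn.compact
  def __call__(self, obs):
    x = obs.reshape((obs.shape[0], -1))       # flatten the observation
    for _ in range(self.depth):
      x = nn.relu(nn.Dense(self.width)(x))    # shared MLP torso
    logits = nn.Dense(self.num_actions)(x)    # policy head
    value = jnp.tanh(nn.Dense(1)(x))          # value head in [-1, 1]
    return logits, value

# Usage: initialise parameters and run a forward pass on a dummy batch.
net = TinyAZNet(num_actions=9)
params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 3, 3, 3)))
logits, value = net.apply(params, jnp.zeros((4, 3, 3, 3)))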

Problems we've run into:

  • the multiprocessing + JAX combination: since JAX runs multithreaded, it doesn't work with the fork start method, so we had to override it (see the sketch below). However, there are still sporadic idle periods during execution, which may be related to our somewhat unclear use of synchronisation primitives.
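
A minimal sketch of the kind of override we mean (illustrative only; the worker function is a toy stand-in): force the spawn start method so child processes don't inherit JAX's thread state.

# Illustrative sketch: use "spawn" instead of the default "fork" on Linux,
# because forking after JAX has started its thread pools is unsafe.
import multiprocessing as mp

def actor_loop(queue):  # hypothetical worker, for illustration only
  import jax  # import JAX inside the child process
  queue.put(str(jax.devices()))

if __name__ == "__main__":
  ctx = mp.get_context("spawn")
  queue = ctx.Queue()
  proc = ctx.Process(target=actor_loop, args=(queue,))
  proc.start()
  print(queue.get())
  proc.join()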

With this PR, we want to know whether we're heading in the right direction; we want to contribute to this long-standing problem and not keep the solution process behind closed doors.

@google-cla

google-cla bot commented Aug 17, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@lanctot
Collaborator

lanctot commented Aug 18, 2025

Awesome, thanks!

@lanctot
Collaborator

lanctot commented Aug 19, 2025

Hi @alexunderch @parvxh @harmanagrawal ,

Just a quick question: why is it labeled "the first major part"? Because flax.nnx isn't fully working?

What's missing? Does this currently work with flax.linen? I'd really like to see a few graphs in a small game convincingly learning. Would that be possible to show? How long would it take, on e.g. Tic-Tac-Toe?

@alexunderch
Contributor Author

alexunderch commented Aug 19, 2025

Yes, exactly.

  1. flax.nnx is not fully working; some minor fixes are left (e.g. conv dimension calculation and tests -- see the sketch after this list)
  2. the implementation does work; however, there are some problems due to multiprocessing (one of the processes stalls)
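
To illustrate the conv-dimension point from item 1 (a sketch under assumed shapes, not the PR code): flax.linen infers the input channel count at init time, while flax.nnx requires in_features explicitly, so it has to be derived from the observation shape.

# Illustrative sketch of the linen vs nnx difference for conv layers.
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

# linen: the input channel count is inferred lazily when the module is initialised.
linen_conv = nn.Conv(features=32, kernel_size=(3, 3))

# nnx: in_features must be given up front, e.g. derived from the observation
# shape (3 channels for the assumed 3x3x3 Tic-Tac-Toe observation tensor).
obs_shape = (3, 3, 3)
nnx_conv = nnx.Conv(in_features=obs_shape[-1], out_features=32,
                    kernel_size=(3, 3), rngs=nnx.Rngs(0))

x = jnp.zeros((1, *obs_shape))
y = nnx_conv(x)  # shape (1, 3, 3, 32) with the default 'SAME' padding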

The lads have their own priorities around submitting the PR, which is why we opened it now, but we'll finish all the fixes in the next week or two. Same for the Tic-Tac-Toe benchmarking.

Appreciate your patience.

@lanctot
Collaborator

lanctot commented Aug 19, 2025

I see. Ok, I'll wait until it's fully working to take a deeper look. Feel free to use this PR branch to collaborate amongst yourselves (i.e. update as necessary) and then just let me know when you think it's ready. Thanks for doing this!

@alexunderch
Contributor Author

Surely, we will notify you!

@alexunderch
Contributor Author

alexunderch commented Aug 27, 2025

@lanctot, there are now tests for both APIs, linen and nnx, and they're passing. The only minor things left on the development side are model export and the changelog benchmarks.

I've run only one benchmark, on CPU and with a ridiculously small replay buffer/batch size of 256/2.
[figure: training curves from the initial CPU run]

P.S. I need to fix the rendering as well xD.

The code runs fine, so if you have the ability to run a test or two on a GPU, that would be great. I hope to run them myself by the weekend.

@alexunderch
Contributor Author

@lanctot, I ran a TTT experiment for a much longer time.
[figure: training curves from the longer Tic-Tac-Toe run]

It doesn't look good, does it? Does the picture tell you anything about where I should look to find bugs?

@lanctot
Collaborator

lanctot commented Aug 29, 2025

I'm not sure. I will solicit the help of @tewalds : Timo, does this look like what you expect?

The top left one looks ok...? Should they move closer to 0, though, I'm not sure. Maybe not for a small MLP? So far I'm not too concerned about that one.

I'm not sure what the top two are showing yet (and what the differences are between 0 and 6), but I would expect accuracy to go up over time.

The top-right one: I would expect it to go up over time, but it doesn't seem to...? (The one from two days ago does -- maybe it's because it's learning to draw against MCTS-100 really fast, which is possible and would be good.)

Can you explain what you did differently from the graph two days ago? The run from two days ago seemed to use a very small (too small) replay buffer and batch size. Did you increase those in the second run? Also, how many MCTS simulations are you using?

Also, you said you let it run for longer, but I see roughly the same number of steps on the x-axis.

The first step would be to play, every epoch or training step, 10 games against random as P0 and 10 games against random as P1, dump those, and then inspect the games. We can also track the values versus random over time. If those don't go up, then something is very seriously wrong, but I roughly know what that graph should look like.
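
A minimal sketch of that kind of sanity check using OpenSpiel's existing bot utilities; trained_bot is a placeholder for the MCTS bot backed by the trained network, and the function name is made up.

# Illustrative sketch: play num_games against a uniform-random opponent
# from each seat and return the trained side's average return.
import numpy as np
import pyspiel
from open_spiel.python.bots import uniform_random
from open_spiel.python.algorithms import evaluate_bots

def eval_vs_random(trained_bot, num_games=10, seed=0):
  game = pyspiel.load_game("tic_tac_toe")
  rng = np.random.RandomState(seed)
  returns = []
  for seat in (0, 1):  # trained bot as P0, then as P1
    random_bot = uniform_random.UniformRandomBot(1 - seat, rng)
    bots = [trained_bot, random_bot] if seat == 0 else [random_bot, trained_bot]
    for _ in range(num_games):
      episode_returns = evaluate_bots.evaluate_bots(
          game.new_initial_state(), bots, rng)
      returns.append(episode_returns[seat])
  return float(np.mean(returns))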

@alexunderch
Contributor Author

alexunderch commented Aug 29, 2025

@lanctot
The main difference between the graphs is the buffer/batch size: 2 ** 16 and 2 ** 10, which were the default values for the model.

I use the default evaluation averaging over 100 games, with 300 simulations each:

{
  "actors": 2,
  "checkpoint_freq": 20,
  "device": "cpu",
  "eval_levels": 7,
  "evaluation_window": 100,
  "evaluators": 1,
  "game": "tic_tac_toe",
  "learning_rate": 0.001,
  "max_simulations": 300,
  "max_steps": 0,
  "nn_api_version": "linen",
  "nn_depth": 10,
  "nn_model": "mlp",
  "nn_width": 128,
  "observation_shape": [
    3,
    3,
    3
  ],
  "output_size": 9,
  "path": "checkpoints",
  "policy_alpha": 1.0,
  "policy_epsilon": 0.25,
  "quiet": true,
  "replay_buffer_reuse": 3,
  "replay_buffer_size": 65536,
  "temperature": 1.0,
  "temperature_drop": 10,
  "train_batch_size": 1024,
  "uct_c": 2,
  "weight_decay": 0.0001
}
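
For reference, the two buffer/batch settings mentioned above written out as plain arithmetic, mapped onto the config fields (purely illustrative):

# The buffer/batch sizes from this run, as powers of two:
replay_buffer_size = 2 ** 16  # 65536, i.e. "replay_buffer_size" in the config
train_batch_size = 2 ** 10    # 1024, i.e. "train_batch_size" in the config
print(replay_buffer_size, train_batch_size)  # 65536 1024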

@alexunderch
Contributor Author

I will share some progress tomorrow; you may approve the checks later <3

@alexunderch
Contributor Author

I guess we're making slight progress, are we not?

[figure: updated training curves]

give it a look @lanctot

@lanctot
Collaborator

lanctot commented Jan 7, 2026

Wait a sec... is one of the tests writing to disk? If so, we should definitely disable that during the tests.

@alexunderch
Contributor Author

Good catch. I'll try to look into it. Merry (Orthodox) Christmas!

@alexunderch
Contributor Author

@lanctot, I can't figure out why the space runs out on exactly this platform: my Linux PC can't load GitHub.
I found online that this particular image has a lot of things pre-installed, like the Android SDK, dotnet, and Haskell. Gemini and Google suggest deleting those, which is why I tried to implement a workaround in actions.yaml.

Sorry, it's my first time facing issues like these.

@lanctot
Collaborator

lanctot commented Jan 7, 2026

Ok yeah, I was wondering that... I was even surprised to find that Linux ARM wheels are supported, and I had no idea where they get used :)

@lanctot
Collaborator

lanctot commented Jan 7, 2026

Sorry, it's my first time facing issues like these.

No worries, it's new for me too! We just started supporting Linux for arm64 recently, based on a community contribution.

@alexunderch
Contributor Author

We spoke with Gemini, and it advised upgrading the Linux distribution, because I found that the actions used a different Docker image from the one requested during the build: Error response from daemon: No such image: quay.io/pypa/manylinux2014_x86_64:2025.10.10-1

It all looks more like coping, though, and I think I should just try to reproduce the problem locally rather than flailing around with GitHub Actions.

@alexunderch
Contributor Author

alexunderch commented Jan 8, 2026

We've got a new error: OOM, but for RAM this time, which is something new. Googling suggests turning off multi-threading in make.

It says that the runner has about 7 GB of RAM; I don't know if that applies to your configuration.

@lanctot
Collaborator

lanctot commented Jan 8, 2026

We can't do that. It'll make the build and tests take forever.

I think the best thing we can do right now is exclude AlphaZero from the tests. I've not seen any of these issues for other PRs (many of them recent), so I don't know why it would only show up for AlphaZero. Something is not making sense, but it's awkward to keep iterating on it in this PR which is already quite large.

We could simply exclude AlphaZero tests for now so that we can do the import, merge the code, and then make a follow-up PR where you can try to enable them again and debug why the CI tests are failing for the AlphaZero code separately. Wdyt?

@alexunderch
Contributor Author

alexunderch commented Jan 8, 2026

Yes, let's do that. Blindly debugging doesn't look successful at all. I will remove the AZ tests and revert the last changes. Let's see how it goes. Sorry for this mess.

Looking through the logs, I don't see the tests failing; they don't even get to run - the space runs out even before all the Linux deps are installed.

Also, an interesting detail: clang++ is used for the manual installation, while g++ is used for the Linux wheels. Doesn't that create some sort of discrepancy?

@alexunderch
Contributor Author

I think I might have found the problem, and it may come back to bite other JAX refactors too -- it's checkpointing with orbax: it preallocates a substantial amount of disk space to ensure throughput. But I want to test that separately.

@lanctot
Collaborator

lanctot commented Jan 9, 2026

Ah, this makes sense. But can we maybe just disable checkpointing during the test by setting a flag only in the test?
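
Something along these lines, as a sketch (the flag name and helper are hypothetical, not the PR's actual code): gate every checkpoint write behind an absl flag that the test can flip off.

# Illustrative sketch: gate checkpoint writes behind an absl flag so a test
# can turn them off without touching the training loop.
from absl import flags

flags.DEFINE_bool("az_enable_checkpointing", True,
                  "If False, skip all orbax checkpoint writes (for tests).")
FLAGS = flags.FLAGS

def maybe_save_checkpoint(save_fn, step):
  """Calls save_fn(step) only when checkpointing is enabled."""
  if FLAGS.az_enable_checkpointing:
    save_fn(step)

# In a test, the flag can be flipped with absl's flagsaver:
#   from absl.testing import flagsaver
#   with flagsaver.flagsaver(az_enable_checkpointing=False):
#     run_a_few_training_steps()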

@alexunderch
Contributor Author

alexunderch commented Jan 12, 2026

@lanctot, can you run it once with the AZ tests turned off? I want to make sure the current workflows and requirements are good to go.

@lanctot
Collaborator

lanctot commented Jan 12, 2026

@lanctot, can you run it once with the AZ tests turned off? I want to make sure the current workflows and requirements are good to go.

Yeah - I will do it when I import it. We'll definitely have to find a way to disable checkpointing once we enable tests because I won't be able to test it internally for similar reasons. But we can put that off to later. I'll make sure it runs first before we merge it.

@alexunderch
Contributor Author

alexunderch commented Jan 18, 2026

I think the ETA is the end of next week, because I've got some things to do.
Everything should be fine once I merge the updates from the master branch.

Apologies for the delay.

@alexunderch
Contributor Author

alexunderch commented Jan 24, 2026

@parvxh, could you please sign the CLA, as we are finally preparing the PR for the merge...

Otherwise, I will rebase the code...

@lanctot
Collaborator

lanctot commented Jan 26, 2026

@alexunderch can you sign the CLA again? (maybe it expired?)

I can't import it if the CLA isn't valid

@alexunderch
Contributor Author

@parvxh didn't sign; mine is fine. I erroneously imported some of his changes from master...

@lanctot
Collaborator

lanctot commented Jan 26, 2026

Ok thanks, I will email him.
