Alpha zero refactor (testing and polishing) #1362
Conversation
Migrated the AlphaZero model from TensorFlow to JAX; the migrated model is in `model_jax.py`.
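For readers skimming the diff, here is a minimal sketch of what an AlphaZero-style `flax.linen` model with separate policy and value heads typically looks like. The class and parameter names are assumptions for illustration, not the actual contents of `model_jax.py`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class AlphaZeroMLP(nn.Module):
    """Hypothetical AlphaZero torso with policy and value heads (not the PR's code)."""
    num_actions: int
    hidden_size: int = 128

    @nn.compact
    def __call__(self, obs):
        x = nn.relu(nn.Dense(self.hidden_size)(obs))
        x = nn.relu(nn.Dense(self.hidden_size)(x))
        policy_logits = nn.Dense(self.num_actions)(x)  # one logit per action
        value = jnp.tanh(nn.Dense(1)(x)).squeeze(-1)   # scalar value in [-1, 1]
        return policy_logits, value

# Tic-tac-toe in OpenSpiel: 9 actions, 27-dim flattened observation tensor.
model = AlphaZeroMLP(num_actions=9)
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 27)))
logits, value = model.apply(params, jnp.zeros((1, 27)))
```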
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Awesome, thanks!
Hi @alexunderch, @parvxh, @harmanagrawal. Just a quick question: why is it labeled "the first major part"? Because `flax.nnx` isn't fully working? What's missing? Does this currently work with `flax.linen`? I'd really like to see a few graphs of it convincingly learning a small game. Would that be possible to show? How long would it take on, e.g., Tic-Tac-Toe?
Yes, exactly.
The lads had their priority on getting the PR submitted, which is why we did it this way, but we'll finish all the fixes in the next week or two. Same with benchmarking on tic-tac-toe. Appreciate your patience.
I see. Ok, I'll wait until it's fully working to take a deeper look. Feel free to use this PR branch to collaborate amongst yourselves (i.e. update as necessary) and then just let me know when you think it's ready. Thanks for doing this!
Surely, we will notify you!
@lanctot, for both APIs: I've run only one benchmark so far. P.S. I need to fix the rendering as well xD. The code runs fine, so if you have the ability to run a test or two on a GPU, that would be great. I will run them by the weekend, I hope.
@lanctot, I ran a TTT experiment for a much longer time. Doesn't look good, does it? Can the picture tell you what I should look at to find bugs?
I'm not sure. I will solicit the help of @tewalds: Timo, does this look like what you expect?

The top-left one looks ok...? Should they move closer to 0, though? I'm not sure. Maybe not for a small MLP? So far I'm not too concerned about that one.

I'm not sure what the top two are showing yet (or what the differences are between 0 and 6), but I would expect accuracy to go up over time. The top-right one I would expect to go up over time, but it doesn't seem to..? (The one from two days ago does -- maybe it's because it's learning to draw against MCTS-100 really fast, which is possible and would be good.)

Can you explain what you did differently from the graph two days ago? The one from two days ago seemed to use a very small (too small) replay buffer and batch size. Did you increase those in the second run? Also, how many MCTS simulations are you using? And you said you let it run for longer, but I see roughly the same number of steps on the x-axis.

First step would be to, every epoch or training step, play 10 games against random (as P0) and 10 games against random (as P1), and dump those... and let's inspect the games. We can also track the values versus random over time. If those don't go up then there's something very seriously wrong, but I roughly know what that graph should look like.
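To make that first debugging step concrete, here is a rough sketch of playing and dumping games against a uniform-random opponent using the OpenSpiel API. The `choose_action` callable is a hypothetical stand-in for whatever the trained AlphaZero policy exposes:

```python
import random
import pyspiel

def play_vs_random(game, agent_player, choose_action, num_games=10):
    """Play num_games vs a uniform-random opponent; print histories for inspection."""
    returns = []
    for _ in range(num_games):
        state = game.new_initial_state()
        history = []
        while not state.is_terminal():
            player = state.current_player()
            if player == agent_player:
                action = choose_action(state)  # hypothetical AlphaZero policy
            else:
                action = random.choice(state.legal_actions())
            history.append(state.action_to_string(player, action))
            state.apply_action(action)
        returns.append(state.returns()[agent_player])
        print(history, state.returns())
    return sum(returns) / num_games

game = pyspiel.load_game("tic_tac_toe")
# avg_p0 = play_vs_random(game, 0, my_policy)  # 10 games as player 0
# avg_p1 = play_vs_random(game, 1, my_policy)  # 10 games as player 1
```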
@lanctot I use the default setting of averaging over 100 games, with 300 simulations each.
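For reference, a sketch of how an evaluation opponent with those settings could be built using OpenSpiel's MCTS implementation. The simulation count is taken from the comment above; `uct_c` and the evaluator settings are assumed defaults, not values confirmed by the PR:

```python
import numpy as np
import pyspiel
from open_spiel.python.algorithms import mcts

game = pyspiel.load_game("tic_tac_toe")
rng = np.random.RandomState(42)
bot = mcts.MCTSBot(
    game,
    uct_c=2,                 # exploration constant: an assumed value
    max_simulations=300,     # 300 simulations per move, as quoted above
    evaluator=mcts.RandomRolloutEvaluator(n_rollouts=1, random_state=rng),
    random_state=rng,
)
state = game.new_initial_state()
action = bot.step(state)     # averaging would repeat this over 100 games
```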
Will share some progress tomorrow; you may approve the checks later <3
I guess we're making slight progress, aren't we? Give it a look, @lanctot.
Hey @alexunderch, a bit swamped at the moment. I will need the help of @tewalds. I have emailed him and he says he can take a look at some point, but it may require a call / catch-up. I will be in touch. This will be slow on our side, sorry.
Hello @lanctot!

We (@parvxh, me, and @harmanagrawal) present the first part of the AlphaZero refactor. We (with major help from the guys) have rewritten the models using `flax.linen` and `flax.nnx` (not full support yet, but we'll complete it in the near future). Moreover, we added a replay buffer compatible with GPU execution.

Problems that we've faced:

- The `multiprocessing` vs `jax` combo: since `jax` operates in multithreaded mode, it doesn't work with the `fork` start method, so we had to override it (see the sketch below). However, there are still spontaneous idle sections in the execution, which may be connected with unclear usage of synchronisation primitives.

With this PR, we want to know whether we're headed in the right direction; we want to contribute to this persistent problem rather than keep the solution process behind closed doors.
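To make the `fork` issue concrete, here is a minimal sketch of the general workaround pattern (not the exact code in this PR): JAX starts threads when it initialises, and forking a multithreaded process can deadlock, so worker processes are created with the `spawn` start method instead of Linux's default `fork`:

```python
import multiprocessing as mp

def worker(queue):
    # Import JAX inside the child so it initialises after the spawn,
    # never in a forked copy of an already-multithreaded parent.
    import jax
    queue.put(jax.device_count())

if __name__ == "__main__":
    ctx = mp.get_context("spawn")   # avoid fork() with a live JAX runtime
    queue = ctx.Queue()
    p = ctx.Process(target=worker, args=(queue,))
    p.start()
    print(queue.get())
    p.join()
```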