Skip to content

Policy temp decay #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ install:
- cmd: IF NOT EXIST c:\cache\protobuf\ cmake -G "Visual Studio 15 2017 Win64" -Dprotobuf_BUILD_SHARED_LIBS=NO -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_INSTALL_PREFIX=c:/cache/protobuf ../cmake
- cmd: IF NOT EXIST c:\cache\protobuf\ msbuild INSTALL.vcxproj /p:Configuration=Release /p:Platform=x64 /m
- cmd: set PATH=c:\cache\protobuf\bin;%PATH%
- cmd: IF NOT EXIST c:\cache\testnet appveyor DownloadFile http://training.lczero.org/get_network?sha=7170f639ba1cdc407283b8e52377283e36845b954788c6ada8897937637ef032 -Filename c:\cache\testnet
- cmd: IF NOT EXIST c:\cache\testnet appveyor DownloadFile https://training.lczero.org/get_network?sha=47e3f899519dc1bc95496a457b77730fce7b0b89b6187af5c01ecbbd02e88398 -Filename c:\cache\testnet
- cmd: IF %GTEST%==true IF NOT EXIST C:\cache\syzygy mkdir C:\cache\syzygy
- cmd: IF %GTEST%==true cd C:\cache\syzygy
- cmd: IF %GTEST%==true IF NOT EXIST KQvK.rtbz curl --remote-name-all https://tablebase.lichess.ovh/tables/standard/3-4-5/K{P,N,R,B,Q}vK.rtb{w,z}
Expand Down
39 changes: 38 additions & 1 deletion src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*/

#include "mcts/node.h"
#include "utils/fastmath.h"

#include <algorithm>
#include <bitset>
Expand Down Expand Up @@ -303,7 +304,8 @@ void Node::CancelScoreUpdate(int multivisit) {
best_child_cached_ = nullptr;
}

void Node::FinalizeScoreUpdate(float v, float d, int multivisit) {
void Node::FinalizeScoreUpdate(float v, float d, int multivisit,
float policy_temperature, float policy_temp_decay, float intermediate[]) {
// Recompute Q.
q_ += multivisit * (v - q_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
Expand All @@ -314,6 +316,41 @@ void Node::FinalizeScoreUpdate(float v, float d, int multivisit) {
}
// Increment N.
n_ += multivisit;


// Dynamic policy temperature: re-temper the children's priors so that they
// correspond to the temperature implied by the current visit count.
//
// The temperature after n visits is
//   T(n) = policy_temperature - policy_temp_decay * log2(1 + n).
// Priors last adjusted at T(n_last_temp_) are moved to T(n_) by raising each
// one to the power T(n_last_temp_) / T(n_) and renormalizing so they sum to 1.
//
// FastPow is approximate, so applying it on every visit would accumulate
// error. The update is therefore only performed once the exponent has drifted
// far enough from 1. The author's measurement (P = 0.2 and 0.5, 2m visits)
// reports 4 decimal places of agreement with a single exact pow call.
// NOTE(review): that measurement was quoted for a threshold of 0.0001, while
// the code uses 0.005 -- confirm which value is intended.
constexpr float kMinExponentDeviation = 0.005f;

const float old_policy_temp =
    policy_temperature - policy_temp_decay * FastLog2(1 + n_last_temp_);
const float new_policy_temp =
    policy_temperature - policy_temp_decay * FastLog2(1 + n_);

// A positive decay can push the temperature to zero or below after enough
// visits, which would make the exponent infinite or negative and corrupt the
// priors. Leave the priors untouched in that case.
if (old_policy_temp > 0.0f && new_policy_temp > 0.0f) {
  const float exponent = old_policy_temp / new_policy_temp;
  const float deviation = exponent - 1.0f;

  // Explicit two-sided comparison rather than unqualified abs(): depending on
  // which headers are visible, abs() may resolve to the C int overload and
  // truncate the deviation to 0, silently disabling the update.
  if (deviation > kMinExponentDeviation ||
      deviation < -kMinExponentDeviation) {
    // First pass: compute un-normalized priors and their sum.
    // `intermediate` must have room for one entry per edge; it is not
    // bounds-checked here.
    float total = 0.0f;
    int counter = 0;
    for (auto& child : Edges()) {
      const float new_p = FastPow(child.GetP(), exponent);
      intermediate[counter++] = new_p;
      total += new_p;
    }

    // Second pass: normalize. Skipped when the sum is not positive to avoid
    // dividing by zero.
    if (total > 0.0f) {
      counter = 0;
      for (auto& child : Edges()) {
        child.SetP(intermediate[counter++] / total);
      }
    }

    // Remember the visit count the priors now correspond to.
    n_last_temp_ = n_;
  }
}


// Decrement virtual loss.
n_in_flight_ -= multivisit;
// Best child is potentially no longer valid.
Expand Down
9 changes: 7 additions & 2 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ class Node {
// * Q (weighted average of all V in a subtree)
// * N (+=1)
// * N-in-flight (-=1)
void FinalizeScoreUpdate(float v, float d, int multivisit);
void FinalizeScoreUpdate(float v, float d, int multivisit,
float policy_temperature, float policy_temp_decay, float intermediate[]);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
Expand Down Expand Up @@ -421,6 +422,8 @@ class Node {
float visited_policy_ = 0.0f;
// How many completed visits this node had.
uint32_t n_ = 0;
// Value of n_ at the time the dynamic policy temperature was last applied.
uint32_t n_last_temp_ = 0;
// (AKA virtual loss.) How many threads currently process this node (started
// but not finished). This value is added to n during selection which node
// to pick in MCTS, and also when selecting the best move.
Expand Down Expand Up @@ -451,7 +454,7 @@ class Node {

// A basic sanity check. This must be adjusted when Node members are adjusted.
#if defined(__i386__) || (defined(__arm__) && !defined(__aarch64__))
static_assert(sizeof(Node) == 52, "Unexpected size of Node for 32bit compile");
static_assert(sizeof(Node) == 56, "Unexpected size of Node for 32bit compile");
#else
static_assert(sizeof(Node) == 80, "Unexpected size of Node");
#endif
Expand Down Expand Up @@ -495,6 +498,8 @@ class EdgeAndNode {

// Edge related getters.
float GetP() const { return edge_->GetP(); }
// Sets the P value of the wrapped edge by forwarding to Edge::SetP. Like
// GetP(), this dereferences edge_ without a null check, so it must only be
// called when an edge is present.
void SetP(float val) { edge_->SetP(val); }

Move GetMove(bool flip = false) const {
return edge_ ? edge_->GetMove(flip) : Move();
}
Expand Down
9 changes: 8 additions & 1 deletion src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ const OptionId SearchParams::kPolicySoftmaxTempId{
"policy-softmax-temp", "PolicyTemperature",
"Policy softmax temperature. Higher values make priors of move candidates "
"closer to each other, widening the search."};
// Option identifier for the policy temperature decay: how strongly the policy
// softmax temperature changes as a node accumulates visits. The sign
// convention is described in the help text below.
const OptionId SearchParams::kPolicyTempDecayId{
"policy-temp-decay", "PolicyTemperatureDecay",
"Policy softmax temperature decay. Positive values make temperature "
"smaller as visits increase, negative values make it larger. Bigger "
"magnitude means there is a stronger effect."};
const OptionId SearchParams::kMaxCollisionVisitsId{
"max-collision-visits", "MaxCollisionVisits",
"Total allowed node collision visits, per batch."};
Expand Down Expand Up @@ -226,7 +231,8 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<ChoiceOption>(kFpuStrategyAtRootId, fpu_strategy) = "absolute";
options->Add<FloatOption>(kFpuValueAtRootId, -100.0f, 100.0f) = 1.0f;
options->Add<IntOption>(kCacheHistoryLengthId, 0, 7) = 0;
options->Add<FloatOption>(kPolicySoftmaxTempId, 0.1f, 10.0f) = 1.607f;
options->Add<FloatOption>(kPolicySoftmaxTempId, 0.1f, 10.0f) = 1.5427f;
options->Add<FloatOption>(kPolicyTempDecayId, -10.0f, 10.0f) = -0.2702f;
options->Add<IntOption>(kMaxCollisionEventsId, 1, 1024) = 64;
options->Add<IntOption>(kMaxCollisionVisitsId, 1, 1000000) = 9999;
options->Add<BoolOption>(kOutOfOrderEvalId) = true;
Expand Down Expand Up @@ -267,6 +273,7 @@ SearchParams::SearchParams(const OptionsDict& options)
: options.Get<float>(kFpuValueAtRootId.GetId())),
kCacheHistoryLength(options.Get<int>(kCacheHistoryLengthId.GetId())),
kPolicySoftmaxTemp(options.Get<float>(kPolicySoftmaxTempId.GetId())),
kPolicyTempDecay(options.Get<float>(kPolicyTempDecayId.GetId())),
kMaxCollisionEvents(options.Get<int>(kMaxCollisionEventsId.GetId())),
kMaxCollisionVisits(options.Get<int>(kMaxCollisionVisitsId.GetId())),
kOutOfOrderEval(options.Get<bool>(kOutOfOrderEvalId.GetId())),
Expand Down
3 changes: 3 additions & 0 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class SearchParams {
float GetFpuValue(bool at_root) const { return at_root ? kFpuValueAtRoot : kFpuValue; }
int GetCacheHistoryLength() const { return kCacheHistoryLength; }
float GetPolicySoftmaxTemp() const { return kPolicySoftmaxTemp; }
float GetPolicyTempDecay() const { return kPolicyTempDecay; }
int GetMaxCollisionEvents() const { return kMaxCollisionEvents; }
int GetMaxCollisionVisitsId() const { return kMaxCollisionVisits; }
bool GetOutOfOrderEval() const { return kOutOfOrderEval; }
Expand Down Expand Up @@ -127,6 +128,7 @@ class SearchParams {
static const OptionId kFpuValueAtRootId;
static const OptionId kCacheHistoryLengthId;
static const OptionId kPolicySoftmaxTempId;
static const OptionId kPolicyTempDecayId;
static const OptionId kMaxCollisionEventsId;
static const OptionId kMaxCollisionVisitsId;
static const OptionId kOutOfOrderEvalId;
Expand Down Expand Up @@ -161,6 +163,7 @@ class SearchParams {
const float kFpuValueAtRoot;
const int kCacheHistoryLength;
const float kPolicySoftmaxTemp;
const float kPolicyTempDecay;
const int kMaxCollisionEvents;
const int kMaxCollisionVisits;
const bool kOutOfOrderEval;
Expand Down
4 changes: 3 additions & 1 deletion src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,9 @@ void SearchWorker::DoBackupUpdateSingleNode(
if (n->GetOwnEdge()->IsLBounded() && v < 0.0f) v = 0.00f;
}

n->FinalizeScoreUpdate(v, d, node_to_process.multivisit);
n->FinalizeScoreUpdate(v, d, node_to_process.multivisit,
params_.GetPolicySoftmaxTemp(),
params_.GetPolicyTempDecay(), intermediate_);

// Certainty propagation: adjust Qs along the path as if all visits already
// had propagated the certain result.
Expand Down
7 changes: 7 additions & 0 deletions src/mcts/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ class SearchWorker {
int number_out_of_order_ = 0;
const SearchParams& params_;
std::unique_ptr<Node> precached_node_;

// Scratch buffer holding the un-normalized child priors while policy
// temperature decay is applied (passed to Node::FinalizeScoreUpdate).
// It has 256 entries because, according to a lichess developer post
// (https://lichess.org/blog/Wqa7GiAAAOIpBLoY/
// developer-update-275-improved-game-compression),
// no legal position has more than 256 legal moves.
float intermediate_[256];
};

} // namespace lczero
5 changes: 4 additions & 1 deletion src/utils/fastmath.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,8 @@ inline float FastPow2(const float a) {
inline float FastLog(const float a) {
  // ln(a) = ln(2) * log2(a); the result inherits whatever approximation
  // FastLog2 makes.
  constexpr float kLn2 = 0.6931471805599453f;
  return kLn2 * FastLog2(a);
}


// Fast approximate pow(a, b), evaluated as 2^(b * log2(a)) with the fast
// base-2 helpers.
inline float FastPow(const float a, const float b) {
  const float scaled_log2 = b * FastLog2(a);
  return FastPow2(scaled_log2);
}

} // namespace lczero