From 3eda703fa9cb38cfcf08ad943f7bf311665138aa Mon Sep 17 00:00:00 2001 From: Yan Zhang Date: Sun, 10 Feb 2019 02:50:07 +0000 Subject: [PATCH] Implement V4TrainingData (#722) * Include Q evaluation in training data * Flip Q to match winner * Set probabilities of illegal moves to -1 * Include best move info * Remove root_q for now * Revert "Remove root_q for now" This reverts commit 318b6cb4141247bfe5de3692441eb6d29007587b. * clang-format * Comment about illegal moves * Store draw probability of root and best move * Add missing . --- src/mcts/node.cc | 22 ++++++++++++++++------ src/mcts/node.h | 5 +++-- src/neural/writer.cc | 2 +- src/neural/writer.h | 10 +++++++--- src/selfplay/game.cc | 4 ++-- src/selfplay/game.h | 2 +- 6 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/mcts/node.cc b/src/mcts/node.cc index d968377be2..b3ebc5b653 100644 --- a/src/mcts/node.cc +++ b/src/mcts/node.cc @@ -297,19 +297,21 @@ uint64_t ReverseBitsInBytes(uint64_t v) { } } // namespace -V3TrainingData Node::GetV3TrainingData( - GameResult game_result, const PositionHistory& history, - FillEmptyHistory fill_empty_history) const { - V3TrainingData result; +V4TrainingData Node::GetV4TrainingData(GameResult game_result, + const PositionHistory& history, + FillEmptyHistory fill_empty_history, + float best_eval) const { + V4TrainingData result; // Set version. - result.version = 3; + result.version = 4; // Populate probabilities. float total_n = static_cast<float>(GetChildrenVisits()); // Prevent garbage/invalid training data from being uploaded to server. if (total_n <= 0.0f) throw Exception("Search generated invalid data!"); - std::memset(result.probabilities, 0, sizeof(result.probabilities)); + // Set illegal moves to have -1 probability. 
+ /* Assign per element: memset(-1) fills each byte with 0xFF, which reads back as NaN in IEEE-754 floats instead of -1. */ for (float& probability : result.probabilities) probability = -1.0f; for (const auto& child : Edges()) { result.probabilities[child.edge()->GetMove().as_nn_index()] = child.GetN() / total_n; @@ -343,6 +345,14 @@ V3TrainingData Node::GetV3TrainingData( result.result = 0; } + // Aggregate evaluation Q. + result.root_q = -GetQ(); + result.best_q = best_eval; + + // Draw probability of WDL head. + result.root_d = 0.0; + result.best_d = 0.0; + return result; } diff --git a/src/mcts/node.h b/src/mcts/node.h index 26d020b2ef..c18fff10b3 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -204,9 +204,10 @@ class Node { // in depth parameter, and returns true if it was indeed updated. bool UpdateFullDepth(uint16_t* depth); - V3TrainingData GetV3TrainingData(GameResult result, + V4TrainingData GetV4TrainingData(GameResult result, const PositionHistory& history, - FillEmptyHistory fill_empty_history) const; + FillEmptyHistory fill_empty_history, + float best_eval) const; // Returns range for iterating over edges. 
ConstIterator Edges() const; diff --git a/src/neural/writer.cc b/src/neural/writer.cc index 7ddde501ac..b762d3d319 100644 --- a/src/neural/writer.cc +++ b/src/neural/writer.cc @@ -51,7 +51,7 @@ TrainingDataWriter::TrainingDataWriter(int game_id) { if (!fout_) throw Exception("Cannot create gzip file " + filename_); } -void TrainingDataWriter::WriteChunk(const V3TrainingData& data) { +void TrainingDataWriter::WriteChunk(const V4TrainingData& data) { auto bytes_written = gzwrite(fout_, reinterpret_cast<const char*>(&data), sizeof(data)); if (bytes_written != sizeof(data)) { diff --git a/src/neural/writer.h b/src/neural/writer.h index a642c5f4f8..a373621267 100644 --- a/src/neural/writer.h +++ b/src/neural/writer.h @@ -35,7 +35,7 @@ namespace lczero { #pragma pack(push, 1) -struct V3TrainingData { +struct V4TrainingData { uint32_t version; float probabilities[1858]; uint64_t planes[104]; @@ -47,8 +47,12 @@ struct V3TrainingData { uint8_t rule50_count; uint8_t move_count; int8_t result; + float root_q; + float best_q; + float root_d; + float best_d; } PACKED_STRUCT; -static_assert(sizeof(V3TrainingData) == 8276, "Wrong struct size"); +static_assert(sizeof(V4TrainingData) == 8292, "Wrong struct size"); #pragma pack(pop) @@ -63,7 +67,7 @@ class TrainingDataWriter { } // Writes a chunk. - void WriteChunk(const V3TrainingData& data); + void WriteChunk(const V4TrainingData& data); // Flushes file and closes it. void Finalize(); diff --git a/src/selfplay/game.cc b/src/selfplay/game.cc index ae199ec023..276f846a9e 100644 --- a/src/selfplay/game.cc +++ b/src/selfplay/game.cc @@ -100,9 +100,9 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training, if (training) { // Append training data. The GameResult is later overwritten. 
- training_data_.push_back(tree_[idx]->GetCurrentHead()->GetV3TrainingData( + training_data_.push_back(tree_[idx]->GetCurrentHead()->GetV4TrainingData( GameResult::UNDECIDED, tree_[idx]->GetPositionHistory(), - search_->GetParams().GetHistoryFill())); + search_->GetParams().GetHistoryFill(), search_->GetBestEval())); } float eval = search_->GetBestEval(); diff --git a/src/selfplay/game.h b/src/selfplay/game.h index 1f8a729e1a..462d0583be 100644 --- a/src/selfplay/game.h +++ b/src/selfplay/game.h @@ -102,7 +102,7 @@ class SelfPlayGame { std::mutex mutex_; // Training data to send. - std::vector<V3TrainingData> training_data_; + std::vector<V4TrainingData> training_data_; }; } // namespace lczero