Skip to content

Commit

Permalink
Implement V4TrainingData (#722)
Browse files Browse the repository at this point in the history
* Include Q evaluation in training data

* Flip Q to match winner

* Set probabilities of illegal moves to -1

* Include best move info

* Remove root_q for now

* Revert "Remove root_q for now"

This reverts commit 318b6cb.

* clang-format

* Comment about illegal moves

* Store draw probability of root and best move

* Add missing .
  • Loading branch information
Cyanogenoid authored and Tilps committed Feb 10, 2019
1 parent dc4364d commit 3eda703
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 15 deletions.
22 changes: 16 additions & 6 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,19 +297,21 @@ uint64_t ReverseBitsInBytes(uint64_t v) {
}
} // namespace

V3TrainingData Node::GetV3TrainingData(
GameResult game_result, const PositionHistory& history,
FillEmptyHistory fill_empty_history) const {
V3TrainingData result;
V4TrainingData Node::GetV4TrainingData(GameResult game_result,
const PositionHistory& history,
FillEmptyHistory fill_empty_history,
float best_eval) const {
V4TrainingData result;

// Set version.
result.version = 3;
result.version = 4;

// Populate probabilities.
float total_n = static_cast<float>(GetChildrenVisits());
// Prevent garbage/invalid training data from being uploaded to server.
if (total_n <= 0.0f) throw Exception("Search generated invalid data!");
std::memset(result.probabilities, 0, sizeof(result.probabilities));
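// Intended to mark illegal moves with a probability of -1.
// NOTE(review): memset writes the byte value 0xFF into every byte of the
// float array, so each entry reads back as a NaN bit pattern rather than
// -1.0f. Confirm that consumers of the training data expect this; otherwise
// assign -1.0f to each element explicitly.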
std::memset(result.probabilities, -1, sizeof(result.probabilities));
for (const auto& child : Edges()) {
result.probabilities[child.edge()->GetMove().as_nn_index()] =
child.GetN() / total_n;
Expand Down Expand Up @@ -343,6 +345,14 @@ V3TrainingData Node::GetV3TrainingData(
result.result = 0;
}

// Aggregate evaluation Q.
result.root_q = -GetQ();
result.best_q = best_eval;

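// Draw probability of WDL head. Both fields are written as the constant 0.0
// here; no draw probability is computed in this function.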
result.root_d = 0.0;
result.best_d = 0.0;

return result;
}

Expand Down
5 changes: 3 additions & 2 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ class Node {
// in depth parameter, and returns true if it was indeed updated.
bool UpdateFullDepth(uint16_t* depth);

V3TrainingData GetV3TrainingData(GameResult result,
V4TrainingData GetV4TrainingData(GameResult result,
const PositionHistory& history,
FillEmptyHistory fill_empty_history) const;
FillEmptyHistory fill_empty_history,
float best_eval) const;

// Returns range for iterating over edges.
ConstIterator Edges() const;
Expand Down
2 changes: 1 addition & 1 deletion src/neural/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ TrainingDataWriter::TrainingDataWriter(int game_id) {
if (!fout_) throw Exception("Cannot create gzip file " + filename_);
}

void TrainingDataWriter::WriteChunk(const V3TrainingData& data) {
void TrainingDataWriter::WriteChunk(const V4TrainingData& data) {
auto bytes_written =
gzwrite(fout_, reinterpret_cast<const char*>(&data), sizeof(data));
if (bytes_written != sizeof(data)) {
Expand Down
10 changes: 7 additions & 3 deletions src/neural/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace lczero {

#pragma pack(push, 1)

struct V3TrainingData {
struct V4TrainingData {
uint32_t version;
float probabilities[1858];
uint64_t planes[104];
Expand All @@ -47,8 +47,12 @@ struct V3TrainingData {
uint8_t rule50_count;
uint8_t move_count;
int8_t result;
float root_q;
float best_q;
float root_d;
float best_d;
} PACKED_STRUCT;
static_assert(sizeof(V3TrainingData) == 8276, "Wrong struct size");
static_assert(sizeof(V4TrainingData) == 8292, "Wrong struct size");

#pragma pack(pop)

Expand All @@ -63,7 +67,7 @@ class TrainingDataWriter {
}

// Writes a chunk.
void WriteChunk(const V3TrainingData& data);
void WriteChunk(const V4TrainingData& data);

// Flushes file and closes it.
void Finalize();
Expand Down
4 changes: 2 additions & 2 deletions src/selfplay/game.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ void SelfPlayGame::Play(int white_threads, int black_threads, bool training,

if (training) {
// Append training data. The GameResult is later overwritten.
training_data_.push_back(tree_[idx]->GetCurrentHead()->GetV3TrainingData(
training_data_.push_back(tree_[idx]->GetCurrentHead()->GetV4TrainingData(
GameResult::UNDECIDED, tree_[idx]->GetPositionHistory(),
search_->GetParams().GetHistoryFill()));
search_->GetParams().GetHistoryFill(), search_->GetBestEval()));
}

float eval = search_->GetBestEval();
Expand Down
2 changes: 1 addition & 1 deletion src/selfplay/game.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class SelfPlayGame {
std::mutex mutex_;

// Training data to send.
std::vector<V3TrainingData> training_data_;
std::vector<V4TrainingData> training_data_;
};

} // namespace lczero

0 comments on commit 3eda703

Please sign in to comment.