Proposal for more generic output head format. #7

Open · wants to merge 5 commits into master
proto/net.proto: 38 changes (36 additions, 2 deletions)
@@ -34,20 +34,36 @@ message EngineVersion {
optional uint32 patch = 3;
}

// Weights are a collection of layers that flow inputs to multiple outputs.
// The input size is defined by the InputFormat enum.
// The final output sizes are defined by the OutputFormat enum.
// In-between sizes are implicit in the sizes of the data provided to the
// layers.
// For example, a fully connected layer with x inputs and x*y params must
// have y outputs.
message Weights {
message Layer {
optional float min_val = 1;
optional float max_val = 2;
optional bytes params = 3;
}

// All layers except weights are optional. A block can represent a
// convolution with biases, batch normalization (with or without gammas and
// betas), and an optional relu activation.
message ConvBlock {
optional Layer weights = 1;
optional Layer biases = 2;
optional Layer bn_means = 3;
optional Layer bn_stddivs = 4;
optional Layer bn_gammas = 5;
optional Layer bn_betas = 6;
// If present, specifies the filter size. If 0 or missing, it is assumed to
// be 3 before the final heads and 1 in the final heads.
optional int32 filter_size = 7;
// Allows a convolution block with no relu applied; ignored in residual
// blocks.
optional bool disable_relu = 8;
}

message SEunit {
@@ -59,29 +75,46 @@ message Weights {
optional Layer b2 = 4;
}

// Two convolution blocks, the first with relu and the second without; the
// output of the second is summed with the block input and passed through
// relu.
// If se is present, its transform is applied to the output of the second
// conv block.
message Residual {
optional ConvBlock conv1 = 1;
optional ConvBlock conv2 = 2;
optional SEunit se = 3;
}

// Applies weights as a fully connected layer. Biases are optional.
message FCLayer {
optional Layer weights = 1;
optional Layer biases = 2;
optional bool use_relu = 3;
}

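// Generic output head: a sequence of convolution blocks (outputConvs)
// followed by fully connected layers (outputReductions).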
message Output {
repeated ConvBlock outputConvs = 1;
repeated FCLayer outputReductions = 2;
}

// Input convnet.
optional ConvBlock input = 1;

// Residual tower.
repeated Residual residual = 2;

// Policy head
// Legacy Policy head
optional ConvBlock policy = 3;
optional Layer ip_pol_w = 4;
optional Layer ip_pol_b = 5;

// Value head
// Legacy Value head
optional ConvBlock value = 6;
optional Layer ip1_val_w = 7;
optional Layer ip1_val_b = 8;
optional Layer ip2_val_w = 9;
optional Layer ip2_val_b = 10;

optional Output policy_head = 11;
optional Output value_head = 12;
}

message TrainingParams {
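To make the proposal concrete, here is a minimal Python sketch of how the legacy policy head (the `policy` ConvBlock plus `ip_pol_w`/`ip_pol_b`) could be expressed with the new Output message. The generated module name `net_pb2` and the placeholder byte strings are assumptions; only the message and field names come from this file.

import net_pb2 as pb  # assumed name of the module generated from net.proto

w = pb.Weights()

# 1x1 convolution taking the place of the legacy `policy` ConvBlock.
conv = w.policy_head.outputConvs.add()
conv.filter_size = 1
conv.weights.params = b"\x00" * 4   # placeholder weight bytes, not real data
conv.biases.params = b"\x00" * 4    # placeholder bias bytes

# Fully connected reduction taking the place of ip_pol_w / ip_pol_b.
fc = w.policy_head.outputReductions.add()
fc.use_relu = False
fc.weights.params = b"\x00" * 4     # placeholder
fc.biases.params = b"\x00" * 4      # placeholder

print(w.policy_head)

The same pattern would cover the value head, with two entries in outputReductions mirroring the legacy ip1_val_* and ip2_val_* layers.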
@@ -115,6 +148,7 @@ message NetworkFormat {
NETWORK_UNKNOWN = 0;
NETWORK_CLASSICAL = 1;
NETWORK_SE = 2;
NETWORK_SE_AND_EXTENSIBLE_OUTPUT_HEADS = 3;
}
optional NetworkStructure network = 3;
}
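As an illustration of how a reader might use the new enum value, here is a hedged Python sketch that picks between the legacy fixed policy head and the new extensible one. The module name `net_pb2` and the idea of handing the NetworkFormat and Weights messages to a helper are assumptions; the field and enum names come from this file.

import net_pb2 as pb  # assumed name of the module generated from net.proto

def policy_head_messages(network_format, weights):
    """Return (conv blocks, FC layers) describing the policy head."""
    if (network_format.network ==
            pb.NetworkFormat.NETWORK_SE_AND_EXTENSIBLE_OUTPUT_HEADS):
        # New format: arbitrary conv blocks followed by FC reductions.
        return (list(weights.policy_head.outputConvs),
                list(weights.policy_head.outputReductions))
    # Legacy formats: one fixed ConvBlock followed by a single FC layer.
    legacy_fc = pb.Weights.FCLayer(weights=weights.ip_pol_w,
                                   biases=weights.ip_pol_b)
    return ([weights.policy], [legacy_fc])

Because the new heads live in separate fields (11 and 12), backends that only understand the legacy layout can keep reading fields 3 through 10 unchanged.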