Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for additional static inputs #184

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tf/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
LC0_MINOR_WITH_INPUT_TYPE_3 = 25
LC0_MINOR_WITH_INPUT_TYPE_4 = 26
LC0_MINOR_WITH_INPUT_TYPE_5 = 27
LC0_MINOR_WITH_STATIC_SQUARES = 29
LC0_PATCH = 0
WEIGHTS_MAGIC = 0x1c0

Expand Down Expand Up @@ -78,6 +79,11 @@ def set_input(self, input_format):
elif input_format != pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE:
self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_3

def set_input_static(self, input_static_format):
    """Record the static-input format on the network protobuf.

    Any format other than INPUT_STATIC_NONE is only understood by newer
    lc0 builds, so the minimum supported minor version is raised.
    """
    network_format = self.pb.format.network_format
    network_format.input_static = input_static_format
    uses_static_inputs = (
        input_static_format != pb.NetworkFormat.INPUT_STATIC_NONE)
    if uses_static_inputs:
        self.pb.min_version.minor = LC0_MINOR_WITH_STATIC_SQUARES

def get_weight_amounts(self):
value_weights = 8
policy_weights = 6
Expand Down
6 changes: 5 additions & 1 deletion tf/net_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
'--ignore-errors',
action='store_true',
help='Ignore missing and wrong sized values.')
argparser.add_argument('-a',
'--attempt_adapt',
action='store_true',
help='If -e, then try to reshape wrong sized rather than drop.')
args = argparser.parse_args()
cfg = yaml.safe_load(args.cfg.read())
print(yaml.dump(cfg, default_flow_style=False))
START_FROM = args.start

tfp = tfprocess.TFProcess(cfg)
tfp.init_net()
tfp.replace_weights(args.net, args.ignore_errors)
tfp.replace_weights(args.net, args.ignore_errors, args.attempt_adapt)
tfp.global_step.assign(START_FROM)

root_dir = os.path.join(cfg['training']['path'], cfg['name'])
Expand Down
88 changes: 83 additions & 5 deletions tf/tfprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def call(self, inputs):
return tf.matmul(h_conv_pol_flat,
tf.cast(self.fc1, h_conv_pol_flat.dtype))


class Metric:
def __init__(self, short_name, long_name, suffix='', **kwargs):
self.short_name = short_name
Expand Down Expand Up @@ -123,11 +122,13 @@ def __init__(self, cfg):
value_head = self.cfg['model'].get('value', 'wdl')
moves_left_head = self.cfg['model'].get('moves_left', 'v1')
input_mode = self.cfg['model'].get('input_type', 'classic')
input_static_mode = self.cfg['model'].get('input_static_type', 'none')

self.POLICY_HEAD = None
self.VALUE_HEAD = None
self.MOVES_LEFT_HEAD = None
self.INPUT_MODE = None
self.INPUT_STATIC_MODE = None

if policy_head == "classical":
self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL
Expand Down Expand Up @@ -183,6 +184,16 @@ def __init__(self, cfg):

self.net.set_input(self.INPUT_MODE)

if input_static_mode == "none":
self.INPUT_STATIC_MODE = pb.NetworkFormat.INPUT_STATIC_NONE
elif input_static_mode == "squares":
self.INPUT_STATIC_MODE = pb.NetworkFormat.INPUT_STATIC_SQUARES
else:
raise ValueError(
"Unknown input static mode format: {}".format(input_static_mode))

self.net.set_input_static(self.INPUT_STATIC_MODE)

self.swa_enabled = self.cfg['training'].get('swa', False)

# Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)
Expand Down Expand Up @@ -474,7 +485,21 @@ def accuracy(target, output):
keep_checkpoint_every_n_hours=24,
checkpoint_name=self.cfg['name'])

def replace_weights(self, proto_filename, ignore_errors=False):
def expand_trim(self, new_weight, weight):
    """Force `new_weight` to match the shape of `weight`.

    Oversized dimensions of `new_weight` are trimmed to `weight`'s
    extent; undersized dimensions are padded by concatenating slices
    taken from `weight` itself. Trimming runs first because the expand
    step assumes no dimension of `new_weight` exceeds the matching
    dimension of `weight`.

    Returns the reshaped tensor (rank is preserved throughout).
    """
    ndims = len(new_weight.shape)
    for i in range(ndims):
        if new_weight.shape[i] > weight.shape[i]:
            # Fixed typo in message: "Triming" -> "Trimming".
            print('Trimming index {} from {} to {}'.format(
                i, new_weight.shape[i], weight.shape[i]))
            # Keep only the leading weight.shape[i] entries along axis i;
            # trailing axes are implicitly kept in full.
            index = (slice(None),) * i + (slice(weight.shape[i]),)
            new_weight = new_weight[index]
    for i in range(ndims):
        if new_weight.shape[i] < weight.shape[i]:
            print('Expanding index {} from {} to {}'.format(
                i, new_weight.shape[i], weight.shape[i]))
            extra = weight.shape[i] - new_weight.shape[i]
            # Slice of `weight` that matches new_weight on every axis
            # except i, where the first `extra` entries are taken.
            # NOTE(review): this pads with the *leading* entries of
            # `weight` along axis i — confirm the trailing (newly added)
            # entries were not intended instead.
            pad_slice = tuple(
                slice(extra) if j == i else slice(new_weight.shape[j])
                for j in range(ndims))
            new_weight = tf.concat([new_weight, weight[pad_slice]], axis=i)
    return new_weight

def replace_weights(self, proto_filename, ignore_errors=False, adapt_length=False):
self.net.parse_proto(proto_filename)

filters, blocks = self.net.filters(), self.net.blocks()
Expand Down Expand Up @@ -510,11 +535,17 @@ def replace_weights(self, proto_filename, ignore_errors=False):
else:
raise KeyError(error_string)

length_wrong = False

if reduce(operator.mul, weight.shape.as_list(),
1) != len(new_weight):
error_string = 'Tensor {} has wrong length. Tensorflow shape {}, size in protobuf {}'.format(
weight.name, weight.shape.as_list(), len(new_weight))
if ignore_errors:
if adapt_length and ignore_errors:
print(error_string)
print('Will attempt to reshape')
length_wrong = True
elif ignore_errors:
print(error_string)
continue
else:
Expand All @@ -530,6 +561,7 @@ def replace_weights(self, proto_filename, ignore_errors=False):
if (i % (num_inputs * 9)) // 9 == rule50_input:
new_weight[i] = new_weight[i] * 99


# Convolution weights need a transpose
#
# TF (kYXInputOutput)
Expand All @@ -539,20 +571,58 @@ def replace_weights(self, proto_filename, ignore_errors=False):
# [output, input, filter_size, filter_size]
s = weight.shape.as_list()
shape = [s[i] for i in [3, 2, 0, 1]]
if length_wrong and adapt_length:
if shape[0] == self.RESIDUAL_FILTERS:
shape[0] = filters
# Input is a special case
if weight.name == 'input/conv2d/kernel:0':
if shape[1] == 176 and shape[0]*shape[1]*shape[2]*shape[3] > len(new_weight):
shape[1] = 112
elif shape[1] == 112 and shape[0]*shape[1]*shape[2]*shape[3] < len(new_weight):
shape[1] = 176
else:
if shape[1] == self.RESIDUAL_FILTERS:
shape[1] = filters
if shape[0]*shape[1]*shape[2]*shape[3] != len(new_weight):
print('Failed to reshape')
continue

new_weight = tf.constant(new_weight, shape=shape)
weight.assign(tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))
new_weight = tf.transpose(a=new_weight, perm=[2, 3, 1, 0])
if length_wrong:
new_weight = self.expand_trim(new_weight, weight)
weight.assign(new_weight)
elif weight.shape.ndims == 2:
# Fully connected layers are [in, out] in TF
#
# [out, in] in Leela
#
s = weight.shape.as_list()
shape = [s[i] for i in [1, 0]]
if length_wrong and adapt_length:
if shape[0] == self.RESIDUAL_FILTERS:
shape[0] = filters
if shape[1] == self.RESIDUAL_FILTERS:
shape[1] = filters
if shape[0]*shape[1] != len(new_weight):
print('Failed to reshape')
continue
new_weight = tf.constant(new_weight, shape=shape)
weight.assign(tf.transpose(a=new_weight, perm=[1, 0]))
new_weight = tf.transpose(a=new_weight, perm=[1, 0])
if length_wrong:
new_weight = self.expand_trim(new_weight, weight)
weight.assign(new_weight)
else:
# Biases, batchnorm etc
new_weight = tf.constant(new_weight, shape=weight.shape)
if length_wrong and adapt_length:
if shape[0] == self.RESIDUAL_FILTERS:
shape[0] = filters
if shape[0] != len(new_weight):
print('Failed to reshape')
continue
if length_wrong:
new_weight = self.expand_trim(new_weight, weight)
weight.assign(new_weight)
# Replace the SWA weights as well, ensuring swa accumulation is reset.
if self.swa_enabled:
Expand Down Expand Up @@ -1099,6 +1169,14 @@ def residual_block(self, inputs, channels, name):
[inputs, out2]))

def construct_net(self, inputs):
if self.INPUT_STATIC_MODE == pb.NetworkFormat.INPUT_STATIC_SQUARES:
inputs = tf.concat([
inputs,
tf.reshape(tf.eye(64, batch_shape=[tf.shape(inputs)[0]]),
[-1, 64, 8, 8])
],
axis=-3)

flow = self.conv_block(inputs,
filter_size=3,
output_channels=self.RESIDUAL_FILTERS,
Expand Down