diff --git a/tf/net.py b/tf/net.py
index de779ca0..b052e139 100755
--- a/tf/net.py
+++ b/tf/net.py
@@ -11,6 +11,7 @@
 LC0_MINOR_WITH_INPUT_TYPE_3 = 25
 LC0_MINOR_WITH_INPUT_TYPE_4 = 26
 LC0_MINOR_WITH_INPUT_TYPE_5 = 27
+LC0_MINOR_WITH_STATIC_SQUARES = 29
 LC0_PATCH = 0
 WEIGHTS_MAGIC = 0x1c0
 
@@ -78,6 +79,11 @@ def set_input(self, input_format):
         elif input_format != pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE:
             self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_3
 
+    def set_input_static(self, input_static_format):
+        self.pb.format.network_format.input_static = input_static_format
+        if input_static_format != pb.NetworkFormat.INPUT_STATIC_NONE:
+            self.pb.min_version.minor = LC0_MINOR_WITH_STATIC_SQUARES
+
     def get_weight_amounts(self):
         value_weights = 8
         policy_weights = 6
diff --git a/tf/net_to_model.py b/tf/net_to_model.py
index 02772f82..390ecb53 100755
--- a/tf/net_to_model.py
+++ b/tf/net_to_model.py
@@ -19,6 +19,10 @@
                        '--ignore-errors',
                        action='store_true',
                        help='Ignore missing and wrong sized values.')
+argparser.add_argument('-a',
+                       '--attempt_adapt',
+                       action='store_true',
+                       help='If --ignore-errors is set, try to reshape wrong sized values rather than drop them.')
 args = argparser.parse_args()
 cfg = yaml.safe_load(args.cfg.read())
 print(yaml.dump(cfg, default_flow_style=False))
@@ -26,7 +30,7 @@
 tfp = tfprocess.TFProcess(cfg)
 tfp.init_net()
-tfp.replace_weights(args.net, args.ignore_errors)
+tfp.replace_weights(args.net, args.ignore_errors, args.attempt_adapt)
 tfp.global_step.assign(START_FROM)
 
 root_dir = os.path.join(cfg['training']['path'], cfg['name'])
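Note: a hypothetical invocation of the new option, assuming the script's existing positional net argument and the -e short form of --ignore-errors (the config and net paths are placeholders):

    python tf/net_to_model.py --cfg configs/example.yaml -e -a net-to-convert.pb.gz

The flag is deliberately inert on its own: replace_weights only attempts a reshape when ignore_errors is also set; with -e alone, wrong sized tensors are still dropped, and with neither flag they still raise.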
diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index bcfc1c0d..57f6e682 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -57,7 +57,6 @@ def call(self, inputs):
         return tf.matmul(h_conv_pol_flat,
                          tf.cast(self.fc1, h_conv_pol_flat.dtype))
 
-
 class Metric:
     def __init__(self, short_name, long_name, suffix='', **kwargs):
         self.short_name = short_name
@@ -123,11 +122,13 @@ def __init__(self, cfg):
         value_head = self.cfg['model'].get('value', 'wdl')
         moves_left_head = self.cfg['model'].get('moves_left', 'v1')
         input_mode = self.cfg['model'].get('input_type', 'classic')
+        input_static_mode = self.cfg['model'].get('input_static_type', 'none')
 
         self.POLICY_HEAD = None
         self.VALUE_HEAD = None
         self.MOVES_LEFT_HEAD = None
         self.INPUT_MODE = None
+        self.INPUT_STATIC_MODE = None
 
         if policy_head == "classical":
             self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL
@@ -183,6 +184,16 @@ def __init__(self, cfg):
 
         self.net.set_input(self.INPUT_MODE)
 
+        if input_static_mode == "none":
+            self.INPUT_STATIC_MODE = pb.NetworkFormat.INPUT_STATIC_NONE
+        elif input_static_mode == "squares":
+            self.INPUT_STATIC_MODE = pb.NetworkFormat.INPUT_STATIC_SQUARES
+        else:
+            raise ValueError(
+                "Unknown input static mode: {}".format(input_static_mode))
+
+        self.net.set_input_static(self.INPUT_STATIC_MODE)
+
         self.swa_enabled = self.cfg['training'].get('swa', False)
 
         # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)
@@ -474,7 +485,21 @@ def accuracy(target, output):
             keep_checkpoint_every_n_hours=24,
             checkpoint_name=self.cfg['name'])
 
-    def replace_weights(self, proto_filename, ignore_errors=False):
+    def expand_trim(self, new_weight, weight):
+        # Trim first, as the expand logic depends on new_weight having no dimension larger than weight.
+        for i in range(len(new_weight.shape)):
+            if new_weight.shape[i] > weight.shape[i]:
+                print('Trimming index {} from {} to {}'.format(i, new_weight.shape[i], weight.shape[i]))
+                new_weight = new_weight[(slice(None), ) * i + (slice(weight.shape[i]), ) + (slice(None), ) * (len(new_weight.shape) - i - 1)]
+        for i in range(len(new_weight.shape)):
+            if new_weight.shape[i] < weight.shape[i]:
+                print('Expanding index {} from {} to {}'.format(i, new_weight.shape[i], weight.shape[i]))
+                extra = weight.shape[i] - new_weight.shape[i]
+                out_slice = sum([(slice(extra), ) if j == i else (slice(new_weight.shape[j]), ) for j in range(len(new_weight.shape))], ())
+                new_weight = tf.concat([new_weight, weight[out_slice]], axis=i)
+        return new_weight
+
+    def replace_weights(self, proto_filename, ignore_errors=False, adapt_length=False):
         self.net.parse_proto(proto_filename)
 
         filters, blocks = self.net.filters(), self.net.blocks()
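To make expand_trim's trim-then-pad behaviour concrete, here is a hypothetical standalone sketch of the expand step on 1-D tensors (toy values, not part of the patch):

    import tensorflow as tf

    # Converted tensor has 3 entries; the target variable expects 5.
    new_weight = tf.constant([1.0, 2.0, 3.0])
    weight = tf.constant([10.0, 20.0, 30.0, 40.0, 50.0])

    # Expand step: append the first (5 - 3) = 2 entries of the existing
    # weight, which is what out_slice selects along the short axis.
    extra = weight.shape[0] - new_weight.shape[0]
    adapted = tf.concat([new_weight, weight[:extra]], axis=0)
    print(adapted.numpy())  # [ 1.  2.  3. 10. 20.]

A tensor that is too large along an axis takes the first loop instead, which slices it down to the variable's size before any padding is considered.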
@@ -510,11 +535,17 @@ def replace_weights(self, proto_filename, ignore_errors=False):
                 else:
                     raise KeyError(error_string)
 
+            length_wrong = False
+
             if reduce(operator.mul, weight.shape.as_list(), 1) != len(new_weight):
                 error_string = 'Tensor {} has wrong length. Tensorflow shape {}, size in protobuf {}'.format(
                     weight.name, weight.shape.as_list(), len(new_weight))
-                if ignore_errors:
+                if adapt_length and ignore_errors:
+                    print(error_string)
+                    print('Will attempt to reshape')
+                    length_wrong = True
+                elif ignore_errors:
                     print(error_string)
                     continue
                 else:
@@ -530,6 +561,7 @@ def replace_weights(self, proto_filename, ignore_errors=False):
                     if (i % (num_inputs * 9)) // 9 == rule50_input:
                         new_weight[i] = new_weight[i] * 99
 
+
             if weight.shape.ndims == 4:
                 # Convolution weights need a transpose
                 #
                 # TF (kYXInputOutput)
@@ -539,8 +571,27 @@ def replace_weights(self, proto_filename, ignore_errors=False):
                 # [filter_size, filter_size, input, output]
                 #
                 # Leela/cuDNN/Caffe (kOutputInputYX)
                 # [output, input, filter_size, filter_size]
                 s = weight.shape.as_list()
                 shape = [s[i] for i in [3, 2, 0, 1]]
+                if length_wrong and adapt_length:
+                    if shape[0] == self.RESIDUAL_FILTERS:
+                        shape[0] = filters
+                    # The input convolution is a special case: its input
+                    # channel count changes with the static planes.
+                    if weight.name == 'input/conv2d/kernel:0':
+                        if shape[1] == 176 and shape[0] * shape[1] * shape[2] * shape[3] > len(new_weight):
+                            shape[1] = 112
+                        elif shape[1] == 112 and shape[0] * shape[1] * shape[2] * shape[3] < len(new_weight):
+                            shape[1] = 176
+                    else:
+                        if shape[1] == self.RESIDUAL_FILTERS:
+                            shape[1] = filters
+                    if shape[0] * shape[1] * shape[2] * shape[3] != len(new_weight):
+                        print('Failed to reshape')
+                        continue
                 new_weight = tf.constant(new_weight, shape=shape)
-                weight.assign(tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))
+                new_weight = tf.transpose(a=new_weight, perm=[2, 3, 1, 0])
+                if length_wrong:
+                    new_weight = self.expand_trim(new_weight, weight)
+                weight.assign(new_weight)
             elif weight.shape.ndims == 2:
                 # Fully connected layers are [in, out] in TF
                 #
@@ -548,11 +599,30 @@ def replace_weights(self, proto_filename, ignore_errors=False):
                 # Leela/cuDNN/Caffe [out, in]
                 #
                 s = weight.shape.as_list()
                 shape = [s[i] for i in [1, 0]]
+                if length_wrong and adapt_length:
+                    if shape[0] == self.RESIDUAL_FILTERS:
+                        shape[0] = filters
+                    if shape[1] == self.RESIDUAL_FILTERS:
+                        shape[1] = filters
+                    if shape[0] * shape[1] != len(new_weight):
+                        print('Failed to reshape')
+                        continue
                 new_weight = tf.constant(new_weight, shape=shape)
-                weight.assign(tf.transpose(a=new_weight, perm=[1, 0]))
+                new_weight = tf.transpose(a=new_weight, perm=[1, 0])
+                if length_wrong:
+                    new_weight = self.expand_trim(new_weight, weight)
+                weight.assign(new_weight)
             else:
                 # Biases, batchnorm etc
-                new_weight = tf.constant(new_weight, shape=weight.shape)
+                shape = weight.shape.as_list()
+                if length_wrong and adapt_length:
+                    if shape[0] == self.RESIDUAL_FILTERS:
+                        shape[0] = filters
+                    if shape[0] != len(new_weight):
+                        print('Failed to reshape')
+                        continue
+                new_weight = tf.constant(new_weight, shape=shape)
+                if length_wrong:
+                    new_weight = self.expand_trim(new_weight, weight)
                 weight.assign(new_weight)
         # Replace the SWA weights as well, ensuring swa accumulation is reset.
         if self.swa_enabled:
@@ -1099,6 +1169,14 @@ def residual_block(self, inputs, channels, name):
                 [inputs, out2]))
 
     def construct_net(self, inputs):
+        if self.INPUT_STATIC_MODE == pb.NetworkFormat.INPUT_STATIC_SQUARES:
+            # Append 64 static planes; plane k is one-hot at square k.
+            inputs = tf.concat([
+                inputs,
+                tf.reshape(tf.eye(64, batch_shape=[tf.shape(inputs)[0]]),
+                           [-1, 64, 8, 8])
+            ], axis=-3)
+
         flow = self.conv_block(inputs,
                                filter_size=3,
                                output_channels=self.RESIDUAL_FILTERS,
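The 176/112 constants in the input-convolution special case above follow from this concat in construct_net: the classical input format has 112 planes, and INPUT_STATIC_SQUARES appends 64 one-hot planes, one per board square, for 176 channels in total. A hypothetical standalone shape check of the added tensor logic (not part of the patch):

    import tensorflow as tf

    batch = 4
    inputs = tf.zeros([batch, 112, 8, 8])  # NCHW board planes

    # One identity matrix per batch entry; after the reshape, plane k
    # holds a 1.0 only at square k.
    static = tf.reshape(tf.eye(64, batch_shape=[tf.shape(inputs)[0]]),
                        [-1, 64, 8, 8])
    print(tf.concat([inputs, static], axis=-3).shape)  # (4, 176, 8, 8)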