Some micro-optimizations for tf2. #94

Open

wants to merge 41 commits into master

41 commits
6c63629
Silence some warnings in 1.14 at the expense of making 1.14 the minim…
Tilps Dec 7, 2019
be869dd
Merge remote-tracking branch 'upstream/master' into migrate_1.14
Tilps Dec 7, 2019
a9f7129
Update requirements.
Tilps Dec 7, 2019
7ad15ce
Minimal changes to get training running on 2.0
Tilps Dec 7, 2019
d170e20
Update requirements.
Tilps Dec 7, 2019
3fcbe2b
Remove uses of get_variable since its initializers don't seem to be ca…
Tilps Dec 7, 2019
162fa40
Fix typo.
Tilps Dec 7, 2019
ff2e256
Migration in progress some more.
Tilps Dec 8, 2019
f202c78
Add grad norm clipping back in.
Tilps Dec 8, 2019
1398d2a
Add swa support and test result reporting.
Tilps Dec 8, 2019
2d97b80
Actually sum the gradients...
Tilps Dec 8, 2019
cd3c69d
Fix bad bug with residual blocks, also include the stub for the incom…
Tilps Dec 8, 2019
d3f95eb
Constant policy map shouldn't be a variable, or it gets saved.
Tilps Dec 8, 2019
c90a9c4
Add untested network saving
Tilps Dec 9, 2019
61158e2
Fix permutation which didn't take into account the inserted gammas.
Tilps Dec 9, 2019
8681df6
GlobalAveragePooling2D needs to know the data format.
Tilps Dec 9, 2019
215445d
Use tf.function on the inner loop of training for massive speed up
Tilps Dec 9, 2019
bdec903
Extract test summary inner loop to tf.function for a bit more perform…
Tilps Dec 9, 2019
0b17356
Basic tensorboard summary data writing and some other cleanup relate…
Tilps Dec 11, 2019
43cdaf2
Remove some v1 code which I've done a second pass checking conversion.
Tilps Dec 11, 2019
53c6057
Add a missing _v2.
Tilps Dec 11, 2019
6892b1c
Fix bug in saving swa_weights that I hadn't tested before committing...
Tilps Dec 11, 2019
c610936
Add mixed precision support to tf2 version.
Tilps Dec 11, 2019
5a2a95b
Renorm also doesn't actually support fused.
Tilps Dec 11, 2019
6a37e9f
Re-add accuracy reporting.
Tilps Dec 11, 2019
f2f32a2
Some more cleanup of pre-v2 code that is converted or close enough to…
Tilps Dec 11, 2019
7cbe775
Re-add basic update ratios.
Tilps Dec 11, 2019
47208df
Small performance optimization to offset the cost of update ratios.
Tilps Dec 11, 2019
d22b9b8
Optimize compute update ratio since adding update_ratio_log10 had a n…
Tilps Dec 11, 2019
ef06365
Some more cleanup, add weight histograms back.
Tilps Dec 12, 2019
753b918
Fix net saving for renorm mode.
Tilps Dec 12, 2019
72f7cb3
Add net_to_model and update_steps support. Remove more unsupported co…
Tilps Dec 12, 2019
9f7defc
Some more cleanup.
Tilps Dec 12, 2019
7823f0d
More cleanup.
Tilps Dec 12, 2019
5f4d8c4
Small fix needed to make script work in ubuntu.
Tilps Dec 12, 2019
da94d76
Explicitly add the training flags to model calls, as apparently needed.
Tilps Dec 12, 2019
20fe10a
More cleanup.
Tilps Dec 12, 2019
3ac88fa
Pull iterator advancement out of the tf.function loops, it's not suppo…
Tilps Dec 13, 2019
681ab74
Add experimental dataset loader.
Tilps Dec 13, 2019
121545f
Some micro-optimizations.
Tilps Dec 13, 2019
b016a64
Merge branch 'master' into migrate_2.0_part2
Tilps Dec 13, 2019
tf/tfprocess.py (31 changes: 21 additions & 10 deletions)
@@ -360,6 +360,10 @@ def process_inner_loop(self, x, y, z, q):
         value_loss = self.value_loss_fn(self.qMix(z, q), value)
         return policy_loss, value_loss, mse_loss, reg_term, tape.gradient(total_loss, self.model.trainable_weights)
 
+    @tf.function()
+    def add_lists(self, x, y):
+        return [tf.math.add(a, b) for (a, b) in zip(x, y)]
+
     def process_v2(self, batch_size, test_batches, batch_splits=1):
         if not self.time_start:
             self.time_start = time.time()
@@ -410,7 +414,7 @@ def process_v2(self, batch_size, test_batches, batch_splits=1):
             if not grads:
                 grads = new_grads
             else:
-                grads = [tf.math.add(a, b) for (a, b) in zip(grads, new_grads)]
+                grads = self.add_lists(grads, new_grads)
         # Keep running averages
         # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to
         # get comparable values.
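
For context, the two hunks above sum the per-split gradient lists inside a single traced tf.function graph instead of dispatching one eager tf.math.add per tensor. A minimal standalone sketch of the pattern (shapes and names invented here, not taken from the PR):

import tensorflow as tf

@tf.function()
def add_lists(x, y):
    # Traced once per list structure; one graph call replaces
    # len(x) separate eager op dispatches.
    return [tf.math.add(a, b) for (a, b) in zip(x, y)]

# Toy stand-ins for the per-batch-split gradient lists in process_v2.
grads = [tf.ones([256, 256]), tf.ones([256])]
new_grads = [tf.ones([256, 256]), tf.ones([256])]

grads = add_lists(grads, new_grads)
print([g.shape for g in grads])  # [(256, 256), (256,)]

Since a fixed model yields a fixed gradient list structure, the function traces once and the graph is reused every step, which is where the micro-optimization comes from.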
@@ -490,16 +494,25 @@ def process_v2(self, batch_size, test_batches, batch_splits=1):
             self.save_swa_weights_v2(swa_path)
             print("SWA Weights saved in file: {}".format(swa_path))
 
-    def calculate_swa_summaries_v2(self, test_batches, steps):
+    @tf.function()
+    def switch_to_swa(self):
         backup = self.read_weights()
         for (swa, w) in zip(self.swa_weights, self.model.weights):
             w.assign(swa.read_value())
+        return backup
+
+    @tf.function()
+    def restore_weights(self, backup):
+        for (old, w) in zip(backup, self.model.weights):
+            w.assign(old)
+
+    def calculate_swa_summaries_v2(self, test_batches, steps):
+        backup = self.switch_to_swa()
         true_test_writer, self.test_writer = self.test_writer, self.swa_writer
         print('swa', end=' ')
         self.calculate_test_summaries_v2(test_batches, steps)
         self.test_writer = true_test_writer
-        for (old, w) in zip(backup, self.model.weights):
-            w.assign(old)
+        self.restore_weights(backup)
 
     @tf.function()
     def calculate_test_summaries_inner_loop(self, x, y, z, q):
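
A minimal sketch of the swap/restore idiom factored out above, using a two-variable stand-in model (all names here are invented for illustration, assuming the real code keeps model.weights and swa_weights aligned):

import tensorflow as tf

weights = [tf.Variable([1.0, 2.0]), tf.Variable([3.0])]         # stand-in for model.weights
swa_weights = [tf.Variable([10.0, 20.0]), tf.Variable([30.0])]  # stand-in for self.swa_weights

@tf.function()
def switch_to_swa():
    # Snapshot the live weights, then overwrite them with the SWA averages.
    backup = [w.read_value() for w in weights]
    for (swa, w) in zip(swa_weights, weights):
        w.assign(swa.read_value())
    return backup

@tf.function()
def restore_weights(backup):
    for (old, w) in zip(backup, weights):
        w.assign(old)

backup = switch_to_swa()   # weights now hold the SWA averages
# ... run test summaries or save the net here ...
restore_weights(backup)    # live training weights are back

Wrapping both loops in tf.function turns the whole swap into one graph call rather than one eager assign per tensor.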
@@ -576,19 +589,17 @@ def compute_update_ratio_v2(self, before_weights, after_weights, steps):
         ratios = [tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299, lambda: 200.) for (_, r) in ratios]
         tf.summary.histogram('update_ratios_log10', tf.stack(ratios), buckets=1000, step=steps)
 
+    @tf.function()
     def update_swa_v2(self):
         num = self.swa_count.read_value()
         for (w, swa) in zip(self.model.weights, self.swa_weights):
             swa.assign(swa.read_value() * (num / (num + 1.)) + w.read_value() * (1. / (num + 1.)))
-        self.swa_count.assign(min(num + 1., self.swa_max_n))
+        self.swa_count.assign(tf.math.minimum(num + 1., self.swa_max_n))
 
     def save_swa_weights_v2(self, filename):
-        backup = self.read_weights()
-        for (swa, w) in zip(self.swa_weights, self.model.weights):
-            w.assign(swa.read_value())
+        backup = self.switch_to_swa()
         self.save_leelaz_weights_v2(filename)
-        for (old, w) in zip(backup, self.model.weights):
-            w.assign(old)
+        self.restore_weights(backup)
 
     def save_leelaz_weights_v2(self, filename):
         all_tensors = []
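
One subtlety in the last hunk: once update_swa_v2 is wrapped in @tf.function, Python's built-in min() can no longer be used on its operands, because min() must evaluate a tensor comparison as a host-side bool, which a symbolic tensor in a traced graph cannot provide. tf.math.minimum keeps the clamp inside the graph. A toy sketch, with a counter and cap invented here and loosely modeled on swa_count and swa_max_n:

import tensorflow as tf

count = tf.Variable(0.0)   # stand-in for self.swa_count
max_n = tf.constant(3.0)   # stand-in for self.swa_max_n

@tf.function()
def bump():
    num = count.read_value()
    # min(num + 1.0, max_n) would raise inside this traced function;
    # tf.math.minimum performs the same clamp as a graph op.
    count.assign(tf.math.minimum(num + 1.0, max_n))
    return count.read_value()

for _ in range(5):
    print(bump().numpy())  # 1.0, 2.0, 3.0, 3.0, 3.0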