From 3bb104ce16d169e14783c406a3ac9b5a1e5f6678 Mon Sep 17 00:00:00 2001
From: TySam& B
Date: Mon, 17 Apr 2023 00:06:54 -0400
Subject: [PATCH] Update for v0.6.0 (see patch notes for more details!)

---
 README.md |  7 ++++---
 main.py   | 41 +++++++++++++++++++++++++++--------------
 2 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index a064b94..fbf63aa 100644
--- a/README.md
+++ b/README.md
@@ -21,16 +21,17 @@ Goals:
 * torch- and python-idiomatic
 * hackable
 * few external dependencies (currently only torch and torchvision)
-* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!) seconds on an A100, down from ~18.1 seconds originally).
+* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!!) seconds on an A100, down from ~18.1 seconds originally).
 * <2 seconds training time in <2 years (yep!)
 
-This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, though it will likely get to be _much_ harder for us to continue to improve our speeds significantly from here on out.
+This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, for example.
 
 What we've added:
-* squeeze and excite layers
+* custom architecture that is somehow even faster
 * way too much hyperparameter tuning
 * miscellaneous architecture trimmings (see the patch notes)
 * memory format changes (and more!) to better use tensor cores/etc
+* dirac initializations on non-depth-transitional layers (information passthrough on init)
 * and more!
 
 This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploring -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know that might like it!
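Aside (not part of the patch itself): the "dirac initializations on non-depth-transitional layers (information passthrough on init)" bullet above is easy to see concretely. A dirac-initialized convolution whose input and output channel counts match simply forwards its input unchanged at initialization. The sketch below is standalone and uses illustrative layer sizes, not the repo's actual modules:

```python
import torch
import torch.nn as nn

# Standalone sketch: a dirac-initialized conv acts as an identity map at init,
# so a freshly-initialized layer passes information straight through instead of
# injecting random-projection noise. Sizes here are illustrative only.
conv = nn.Conv2d(64, 64, kernel_size=3, padding='same', bias=False)
torch.nn.init.dirac_(conv.weight)

x = torch.randn(2, 64, 32, 32)
with torch.no_grad():
    y = conv(x)

print(torch.allclose(x, y))  # True: the input passes through unchanged at init
```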
diff --git a/main.py b/main.py
index 7a55a92..1e5bb12 100644
--- a/main.py
+++ b/main.py
@@ -43,14 +43,14 @@
 default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}
 
 batchsize = 1024
-bias_scaler = 48
-# To replicate the ~95.80%-accuracy-in-120-seconds runs, you can change the base_depth from 64->128, num_epochs from 10->90, ['ema'] epochs 9->78, and cutmix_size 0->9
+bias_scaler = 56
+# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75
 hyp = {
     'opt': {
         'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
         'non_bias_lr': 1.64 / 512,
-        'bias_decay': 1.05 * 6.45e-4 * batchsize/bias_scaler,
-        'non_bias_decay': 1.05 * 6.45e-4 * batchsize,
+        'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler,
+        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
         'scaling_factor': 1./9,
         'percent_start': .23,
         'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
@@ -61,8 +61,9 @@
             'num_examples': 50000,
         },
         'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
+        'conv_norm_pow': 2.6,
         'cutmix_size': 3,
-        'cutmix_epochs': 5,
+        'cutmix_epochs': 6,
         'pad_amount': 2,
         'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
     },
@@ -70,10 +71,10 @@
         'ema': {
             'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
             'decay_base': .95,
-            'decay_pow': 3,
+            'decay_pow': 3.,
             'every_n_steps': 5,
         },
-        'train_epochs': 12.6,
+        'train_epochs': 12.1,
         'device': 'cuda',
         'data_location': 'data.pt',
     }
@@ -189,14 +190,14 @@ def forward(self, x):
 
 # can hack any changes to each residual group that you want directly in here
 class ConvGroup(nn.Module):
-    def __init__(self, channels_in, channels_out):
+    def __init__(self, channels_in, channels_out, norm):
         super().__init__()
         self.channels_in = channels_in
         self.channels_out = channels_out
 
         self.pool1 = nn.MaxPool2d(2)
-        self.conv1 = Conv(channels_in, channels_out)
-        self.conv2 = Conv(channels_out, channels_out)
+        self.conv1 = Conv(channels_in, channels_out, norm=norm)
+        self.conv2 = Conv(channels_out, channels_out, norm=norm)
 
         self.norm1 = BatchNorm(channels_out)
         self.norm2 = BatchNorm(channels_out)
@@ -337,9 +338,9 @@ def make_net():
             'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step
             'activation': nn.GELU(),
         }),
-        'residual1': ConvGroup(depths['init'], depths['block1']),
-        'residual2': ConvGroup(depths['block1'], depths['block2']),
-        'residual3': ConvGroup(depths['block2'], depths['block3']),
+        'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']),
+        'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']),
+        'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']),
         'pooling': FastGlobalMaxPooling(),
         'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.),
         'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
@@ -364,6 +365,18 @@ def make_net():
                                             ## the index lookup in the dataloader may give you some trouble depending
                                             ## upon exactly how memory-limited you are
 
+    ## We initialize the projections layer to return exactly the spatial inputs, this way we start
+    ## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there.
+    torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight)
+
+    for layer_name in net.net_dict.keys():
+        if 'residual' in layer_name:
+            ## We do the same for the second layer in each residual block, since this only
+            ## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized
+            ## convolution. This can be easily scaled down by the network, and the weights can more easily
+            ## pivot in whichever direction they need to go now.
+            torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight)
+
     return net
 
 #############################################
@@ -547,7 +560,7 @@ def main():
     ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 --
     ## This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training)
     initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
-    final_lr_ratio = .05 # Actually pretty important, apparently!
+    final_lr_ratio = .07 # Actually pretty important, apparently!
     lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
     lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
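Aside (not part of the patch itself): the final_lr_ratio change above leans on the OneCycleLR quirk noted in the comments: the ending lr is derived from the starting lr (max_lr / div_factor) via final_div_factor, not from max_lr directly. Choosing final_div_factor = 1./(initial_div_factor*final_lr_ratio) therefore makes the schedule finish at final_lr_ratio * max_lr. The check below is standalone; the dummy SGD optimizer and total_steps=100 are placeholders, not values from the repo:

```python
import torch

max_lr = 1.64 / 512              # matches non_bias_lr in the hyperparameters above
initial_div_factor = 1e16        # starting lr ~= max_lr / 1e16, i.e. effectively 0
final_lr_ratio = .07
total_steps = 100                # placeholder; the real run uses total_train_steps

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=max_lr)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=max_lr, pct_start=.23, div_factor=initial_div_factor,
    final_div_factor=1./(initial_div_factor*final_lr_ratio),
    total_steps=total_steps, anneal_strategy='linear', cycle_momentum=False)

for _ in range(total_steps - 1):  # walk the schedule to its final step
    opt.step()
    sched.step()

print(sched.get_last_lr()[0])     # ~2.24e-04
print(final_lr_ratio * max_lr)    # ~2.24e-04 -- the lr the run finishes on
```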
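Aside (not part of the patch itself): the norm values threaded through the layers above (norm=2.2 on the projection, conv_norm_pow=2.6 for the residual convs, norm=5. on the linear layer) are described in the in-diff comment as renormalizing the weights to length 1 under the given norm power each step. The renormalization code itself is not shown in this patch, so the helper below is only a hypothetical, whole-tensor sketch of that idea; the function name is invented here, and the repo may apply it differently (e.g. per-channel, or fused into the training step):

```python
import torch
import torch.nn as nn

def renorm_weight_(layer: nn.Module, norm_pow: float, eps: float = 1e-12):
    # Hypothetical helper (not from the repo): rescale the layer's weight tensor
    # in place so that its p-norm (p = norm_pow) equals 1, keeping its direction.
    with torch.no_grad():
        layer.weight.div_(layer.weight.norm(p=norm_pow) + eps)

conv = nn.Conv2d(64, 64, kernel_size=3, padding='same', bias=False)
renorm_weight_(conv, norm_pow=2.6)   # e.g. hyp['net']['conv_norm_pow']
print(conv.weight.norm(p=2.6))       # ~1.0 after renormalization
```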