Update for v0.6.0 (see patch notes for more details!)
tysam-code committed Apr 17, 2023
1 parent 50eab05 · commit 3bb104c
Showing 2 changed files with 31 additions and 17 deletions.
README.md: 7 changes (4 additions, 3 deletions)
@@ -21,16 +21,17 @@ Goals:
* torch- and python-idiomatic
* hackable
* few external dependencies (currently only torch and torchvision)
* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!) seconds on an A100, down from ~18.1 seconds originally).
* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!!) seconds on an A100, down from ~18.1 seconds originally).
* <2 seconds training time in <2 years (yep!)

This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, though it will likely get to be _much_ harder for us to continue to improve our speeds significantly from here on out.
This is a neural network implementation of a very speedily-training network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but written nearly from the ground-up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single GPU training speeds on CIFAR10, for example.

What we've added:
* squeeze and excite layers
* custom architecture that is somehow even faster
* way too much hyperparameter tuning
* miscellaneous architecture trimmings (see the patch notes)
* memory format changes (and more!) to better use tensor cores/etc
* dirac initializations on non-depth-transitional layers (information passthrough on init)
* and more!

This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploring -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know that might like it!
main.py: 41 changes (27 additions, 14 deletions)
@@ -43,14 +43,14 @@
default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}

batchsize = 1024
bias_scaler = 48
# To replicate the ~95.80%-accuracy-in-120-seconds runs, you can change the base_depth from 64->128, num_epochs from 10->90, ['ema'] epochs 9->78, and cutmix_size 0->9
bias_scaler = 56
# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75
hyp = {
    'opt': {
        'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
        'non_bias_lr': 1.64 / 512,
        'bias_decay': 1.05 * 6.45e-4 * batchsize/bias_scaler,
        'non_bias_decay': 1.05 * 6.45e-4 * batchsize,
        'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler,
        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
        'scaling_factor': 1./9,
        'percent_start': .23,
        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
@@ -61,19 +61,20 @@
            'num_examples': 50000,
        },
        'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
        'conv_norm_pow': 2.6,
        'cutmix_size': 3,
        'cutmix_epochs': 5,
        'cutmix_epochs': 6,
        'pad_amount': 2,
        'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
    },
    'misc': {
        'ema': {
            'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
            'decay_base': .95,
            'decay_pow': 3,
            'decay_pow': 3.,
            'every_n_steps': 5,
        },
        'train_epochs': 12.6,
        'train_epochs': 12.1,
        'device': 'cuda',
        'data_location': 'data.pt',
    }
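As a reference for the comment above the hyp dict: the ~95.78%-accuracy long-run configuration amounts to a handful of value changes. A minimal sketch (editorial, not part of this commit) of those overrides, assuming the key layout shown in this hunk and that they are applied after `hyp` is defined rather than by editing the literals in place:

```python
# Long-run settings from the comment above, expressed as overrides on hyp.
# The values come from that comment; the override style itself is just a sketch.
hyp['net']['base_depth']     = 128   # 64   -> 128
hyp['net']['cutmix_size']    = 9     # 3    -> 9
hyp['net']['cutmix_epochs']  = 75    # 6    -> 75
hyp['misc']['ema']['epochs'] = 75    # 10   -> 75
hyp['misc']['train_epochs']  = 85    # 12.1 -> 85
```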
@@ -189,14 +190,14 @@ def forward(self, x):

# can hack any changes to each residual group that you want directly in here
class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out):
    def __init__(self, channels_in, channels_out, norm):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out

        self.pool1 = nn.MaxPool2d(2)
        self.conv1 = Conv(channels_in, channels_out)
        self.conv2 = Conv(channels_out, channels_out)
        self.conv1 = Conv(channels_in, channels_out, norm=norm)
        self.conv2 = Conv(channels_out, channels_out, norm=norm)

        self.norm1 = BatchNorm(channels_out)
        self.norm2 = BatchNorm(channels_out)
@@ -337,9 +338,9 @@ def make_net():
            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 for this as the power for the norm, each step
            'activation': nn.GELU(),
        }),
        'residual1': ConvGroup(depths['init'], depths['block1']),
        'residual2': ConvGroup(depths['block1'], depths['block2']),
        'residual3': ConvGroup(depths['block2'], depths['block3']),
        'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']),
        'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']),
        'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']),
        'pooling': FastGlobalMaxPooling(),
        'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.),
        'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
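The `norm` argument threaded through `Conv` and `ConvGroup` in these hunks is described in the comment above as renormalizing the weights each step, with the value acting as the power of the norm. A rough, hedged sketch of that idea (the repo's actual `Conv` wrapper may implement it differently, and `renorm_filters_` is a hypothetical helper, not code from this commit):

```python
import torch

def renorm_filters_(weight: torch.Tensor, p: float) -> None:
    # Rescale each output filter of a conv weight (out, in, kH, kW) to unit p-norm.
    # Call under torch.no_grad() in practice, since it mutates the weight in place.
    flat = weight.view(weight.shape[0], -1)
    norms = flat.norm(p=p, dim=1, keepdim=True).clamp_min(1e-12)
    flat.div_(norms)

conv = torch.nn.Conv2d(64, 64, kernel_size=3, padding='same', bias=False)
with torch.no_grad():
    renorm_filters_(conv.weight, p=2.6)  # 2.6 mirrors hyp['net']['conv_norm_pow'] above
```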
@@ -364,6 +365,18 @@ def make_net():
    ## the index lookup in the dataloader may give you some trouble depending
    ## upon exactly how memory-limited you are

    ## We initialize the projections layer to return exactly the spatial inputs, this way we start
    ## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there.
    torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight)

    for layer_name in net.net_dict.keys():
        if 'residual' in layer_name:
            ## We do the same for the second layer in each residual block, since this only
            ## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized
            ## convolution. This can be easily scaled down by the network, and the weights can more easily
            ## pivot in whichever direction they need to go now.
            torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight)

    return net

#############################################
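To see why the dirac_ initializations added above give information passthrough at init: `torch.nn.init.dirac_` fills a conv kernel so each output channel simply copies its matching input channel, making the layer an identity map until training moves the weights. A minimal, self-contained illustration (editorial; the 8-channel conv is hypothetical, not one of the repo's layers):

```python
import torch
import torch.nn as nn

# A dirac_-initialized conv with matching in/out channels and 'same' padding
# passes its input through unchanged at initialization.
conv = nn.Conv2d(8, 8, kernel_size=3, padding='same', bias=False)
nn.init.dirac_(conv.weight)
x = torch.randn(1, 8, 16, 16)
print(torch.allclose(conv(x), x, atol=1e-6))  # True: identity map before any training
```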
@@ -547,7 +560,7 @@ def main():
    ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 --
    ## This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training)
    initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
    final_lr_ratio = .05 # Actually pretty important, apparently!
    final_lr_ratio = .07 # Actually pretty important, apparently!
    lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps, anneal_strategy='linear', cycle_momentum=False)
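Spelling out the endpoint arithmetic the two comments above refer to (this is standard OneCycleLR behavior; the only repo-specific value is the new final_lr_ratio): the starting LR is max_lr / div_factor, roughly max_lr * 1e-16 and therefore effectively zero, and the final LR is that starting LR divided by final_div_factor, which collapses to max_lr * final_lr_ratio, i.e. about 0.07 * max_lr after this change.

```python
# Worked check of the schedule endpoints with the values set above.
initial_div_factor = 1e16
final_lr_ratio = .07
final_div_factor = 1. / (initial_div_factor * final_lr_ratio)
start_over_max = 1. / initial_div_factor            # initial_lr / max_lr ~= 1e-16
end_over_max = start_over_max / final_div_factor    # final_lr / max_lr
assert abs(end_over_max - final_lr_ratio) < 1e-12   # schedule ends at ~0.07 * max_lr
```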

