Merge pull request #1 from tysam-code/cleaner-logging
Updated the logging to better align with repo goals, README.md
tysam-code authored Dec 29, 2022
2 parents d683ee9 + bd341d7 commit 132829f
Showing 2 changed files with 52 additions and 37 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -8,7 +8,7 @@ Welcome to the hyperlightspeedbench CIFAR-10 (HLB-CIFAR10) repo.
`git clone https://github.com/tysam-code/hlb-CIFAR10 && cd hlb-CIFAR10 && python -m pip install -r requirements.txt && python main.py`


If you're curious, this code is generally Colab friendly and is built to appropriately reset state without having to reload the instance (in fact -- most of this was developed in Colab!)
If you're curious, this code is generally Colab friendly (in fact -- most of this was developed in Colab!). Just be sure to uncomment the reset block at the top of the code.
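For reference, here is a minimal sketch of the kind of reset guard that note refers to, assuming an IPython/Colab session (an illustration only, not necessarily the repo's exact block):

```python
# Hypothetical reset guard: clears notebook state on re-runs, does nothing for plain `python main.py`.
try:
    get_ipython()  # only defined inside IPython/Jupyter/Colab sessions
    get_ipython().run_line_magic('reset', '-sf')  # silently clear user-defined variables
except NameError:
    pass  # regular Python interpreter: there is no notebook state to clear
```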


### Main
@@ -22,10 +22,10 @@ Goals:
* near world-record single-GPU training time (~<18.1 seconds on an A100).
* <2 seconds training time in <2 years

This is a neural network implementation that recreates and reproduces, from nearly the ground up and in a painstakingly accurate manner, a hacking-friendly version of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/) -- 94% accuracy in ~<18.1 seconds on an A100 GPU. There is only one primary functional difference that I am aware of. The code has been rewritten practically from scratch in an annotated, hackable flat structure that for me has been extremely fast to prototype ideas in. This code took about 120-130 hours of work from start to finish, and about 80-90+ of those hours were mind-numbingly tedious debugging of the minutiae between my implementation and David's implementation. It turns out that there are so many little things to consider to actually achieve and hold the accuracy David achieved; I find it an interesting balance of tons of wiggle room in places and none at all in others.
This is a neural network implementation that painstakingly reproduces, from nearly the ground up, a hacking-friendly version of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/) -- 94% accuracy in ~<18.1 seconds on an A100 GPU. There is only one primary functional difference that I am aware of. The code is laid out in a flat structure intended for quick hacking in practically _any_ (!!!) stage of the training pipeline. This code took about 120-130 hours of work from start to finish, and about 80-90+ of those hours were mind-numbingly tedious debugging of performance differences between my work and David's original work. It was somewhat surprising in places which things really mattered, and which did not. To that end, I found it very educational to write (and may do a writeup someday if enough people, myself included, have enough interest in it).


I built this because I loved David's work, but for my personal experimentation, his nearly-purely-functional style made implementing radical idea sketches nearly impossible. As a complement to his work, this code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. The upside is that since making this repository, I've already gone from idea-to-new-single-GPU-world-record in under 10 minutes for one idea, and maybe under an hour or so for doing the same thing with a second, different idea as well. I personally find this code a delight to use, and hope you do too! :D Please let me know, whichever way it ends up going for you. I hope to publish those updates in the future, but for now, this is a (relatively) accurate baseline.
I built this because I loved David's work but found it difficult for my quick-experiment-and-hacking use cases. As a complement to his work, this code is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. The upside for me of this repository is that I've already been able to explore a wide variety of ideas rapidly, some of which already improve over the baseline (hopefully more of that in future releases). I truly enjoy using this code personally, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, but for now, this is a (relatively) accurate baseline.


Your support helps a lot -- even if it's a dollar a month. I have several more projects in various stages of progress, and you can help me have the money and time to get them over the finish line! If you like what I'm doing, or this project has brought you some value, please consider subscribing on my [Patreon](https://www.patreon.com/user/posts?u=83632131). There aren't too many extra rewards besides better software, more frequently. Alternatively, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at hire.tysam@gmail.com. I'd love to hear from you.
@@ -49,4 +49,4 @@ Currently, submissions to this codebase as a benchmark are closed as we figure o

#### Bugs & Etc.

If you find a bug, open an issue! :D If you have a success story, let me know! It helps me understand what works and what doesn't, more than you might expect -- if I know how this is specifically helping people, that can help me further improve as a developer, as I can keep that in mind when developing other software for people in the future. :D :)
If you find a bug, open an issue! :D If you have a success story, let me know! It helps me understand what works and what doesn't, more than you might expect -- if I know how this is specifically helping people, that can help me further improve as a developer, as I can keep that in mind when developing other software for people in the future. :D :)
81 changes: 48 additions & 33 deletions main.py
@@ -1,4 +1,5 @@
# Bug: Currently we're having some trouble with differentiating between a colab notebook and a .py file, so right now you'll just need to uncomment the below if you're wanting to properly clear your variable state in your colab notebook each time.
# BUG: Currently, I haven't found a good way to distinguish between python and ipython, so we're leaving the 'colab mode' to require a manual
# uncomment of this code to fully guard against state errors. This lets people run the one-line download-and-train command appropriately.
"""
# If we are in an ipython session or a notebook, clear the state to avoid bugs
try:
@@ -17,11 +18,6 @@
import torch.nn.functional as F
from torch import nn

import rich
from rich.live import Live
from rich.progress import Progress, Console, track
from rich.table import Table

import torchvision
from torchvision import transforms

@@ -457,11 +453,26 @@ def init_split_parameter_dictionaries(network):
## Hey look, it's the soft-targets/label-smoothed loss! Native to PyTorch. Now, _that_ is pretty cool, and simplifies things a lot, to boot! :D :)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction='none')
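As a quick illustration of why `reduction='none'` is handy here, a minimal sketch (with made-up shapes and labels; the repo's actual batch handling differs) of what the unreduced, label-smoothed loss returns:

```python
# Minimal sketch: label-smoothed cross-entropy with reduction='none' yields one loss per sample.
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction='none')
logits = torch.randn(4, 10)                  # hypothetical batch of 4, 10 classes (CIFAR-10)
targets = torch.tensor([3, 1, 0, 7])         # hypothetical integer class labels
per_sample_loss = loss_fn(logits, targets)   # shape (4,): not yet averaged
total_loss = per_sample_loss.sum()           # the caller decides how to reduce/scale it
```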

console = Console()
progress_bar = Progress()
progress_table = Table()
for column_name in ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'val_ema_acc', 'timing (s)']:
progress_table.add_column(column_name)
logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds']
# define the printing function and print the column heads
def print_training_details(columns_list, separator_left='| ', separator_right=' ', final="|", column_heads_only=False, is_final_entry=False):
    print_string = ""
    if column_heads_only:
        for column_head_name in columns_list:
            print_string += separator_left + column_head_name + separator_right
        print_string += final
        print('-'*(len(print_string))) # print the top bar
        print(print_string)
        print('-'*(len(print_string))) # print the bottom bar
    else:
        for column_value in columns_list:
            print_string += separator_left + column_value + separator_right
        print_string += final
        print(print_string)
        if is_final_entry:
            print('-'*(len(print_string))) # print the final output bar

print_training_details(logging_columns_list, column_heads_only=True) # print out the training column heads.
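For context, a small hypothetical usage example of this plain-`print` logger (the values are made up; the real rows are produced inside the training loop further below):

```python
# Hypothetical usage of the logger above: print the header once, then one pre-padded row per epoch;
# is_final_entry=True also prints a closing bar under the last row.
print_training_details(logging_columns_list, column_heads_only=True)
example_row = [str(v).rjust(len(name)) for name, v in
               zip(logging_columns_list, [0, 1.9021, 1.5533, 0.3514, 0.4629, '', 2.04])]
print_training_details(example_row, is_final_entry=True)
```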

########################################
# Train and Eval #
@@ -471,13 +482,13 @@ def init_split_parameter_dictionaries(network):
def main():
# Initializing constants for the whole run.
net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)
## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)
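In case the EMA bookkeeping is unfamiliar, here is a minimal generic sketch of the idea the comment above describes (an illustration with an assumed decay value, not the repo's actual EMA implementation):

```python
# Generic EMA sketch (assumed decay value; not the repo's exact implementation).
import copy
import torch

def ema_update(ema_net, net, decay=0.98):
    with torch.no_grad():
        for p_ema, p in zip(ema_net.parameters(), net.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1. - decay)

# The EMA copy is created from the *live* network only once `ema_epoch_start` is reached,
# so it never has to catch up from a random initialization:
#     if net_ema is None:
#         net_ema = copy.deepcopy(net)
#     else:
#         ema_update(net_ema, net)
```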

total_time = 0.
total_time_seconds = 0.
current_steps = 0.

# TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a tad bit of cleanup here. :) :>
num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a bit of a tad of cleanup here. ::::))) :>>>
total_train_steps = num_steps_per_epoch * hyp['misc']['train_epochs']
ema_epoch_start = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch
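To make the bookkeeping above concrete, a quick worked example with assumed values (CIFAR-10's 50,000 training images plus hypothetical batchsize and epoch settings; the repo's actual hyperparameters may differ):

```python
# Worked example of the step/epoch bookkeeping above, using assumed values.
num_train_images = 50_000   # CIFAR-10 training set size
batchsize = 512             # assumed
train_epochs = 10           # assumed
ema_epochs = 2              # assumed

num_steps_per_epoch = num_train_images // batchsize          # 97 (the partial final batch is dropped)
total_train_steps = num_steps_per_epoch * train_epochs       # 970
ema_epoch_start = train_epochs - ema_epochs                  # the EMA starts tracking at epoch 8
num_low_lr_steps_for_ema = ema_epochs * num_steps_per_epoch  # 194 steps run at the final low LR
```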
@@ -506,18 +517,16 @@ def main():
lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)
lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)
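For readers unfamiliar with `OneCycleLR`, here is a small self-contained sketch of the same scheduler pattern with toy values (the real `max_lr`, `pct_start`, div factors, and step counts come from `hyp` and the lines above):

```python
# Toy OneCycleLR sketch: linear warmup for the first pct_start fraction of steps, then linear decay.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=0.1, pct_start=0.25, div_factor=25., final_div_factor=100.,
    total_steps=100, anneal_strategy='linear', cycle_momentum=False)

for step in range(100):
    opt.step()    # a real step would compute a loss and call backward() first
    sched.step()  # LR climbs from max_lr/25 to max_lr over 25 steps, then decays linearly
```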


## For accurately timing GPU code
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
## There's another repository that mainly reorganized David's code while still maintaining some of the functional structure, and it
## has a timing feature too, but there are no synchronize calls, so I suspect the times it reports are faster than they would be in actuality
## due to some of the quirks of timing GPU operations.
torch.cuda.synchronize() ## clean up any pre-net setup operations
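The comment above is the key subtlety with GPU timing: kernel launches are asynchronous, so event-based timings are only meaningful with explicit synchronization. A minimal sketch of the pattern used here (assumes a CUDA device is available):

```python
# Sketch of CUDA-event timing with explicit synchronization (requires a CUDA device).
import torch

starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device='cuda')
torch.cuda.synchronize()   # don't let leftover setup work leak into the measurement
starter.record()
y = x @ x                  # the GPU work being timed (launched asynchronously)
ender.record()
torch.cuda.synchronize()   # wait until the recorded kernels have actually finished
elapsed_seconds = 1e-3 * starter.elapsed_time(ender)  # elapsed_time() reports milliseconds
print(f"{elapsed_seconds:.4f} s")
```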

# Hack -- disabling the below until we figure out how to properly nest these tables during training.
#with Live(progress_table, refresh_per_second=1):
if True:
for epoch in track(range(hyp['misc']['train_epochs'])):

if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent
for epoch in range(hyp['misc']['train_epochs']):
#################
# Training Mode #
#################
@@ -539,8 +548,8 @@

# we only take the last-saved accs and losses from train
if epoch_step % 50 == 0:
accuracy_train = (outputs.detach().argmax(-1) == targets).float().mean().item()
loss_train = loss.cpu().item()/batchsize
train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
train_loss = loss.detach().cpu().item()/batchsize

loss.backward()

@@ -567,7 +576,7 @@

ender.record()
torch.cuda.synchronize()
total_time += 1e-3 * starter.elapsed_time(ender)
total_time_seconds += 1e-3 * starter.elapsed_time(ender)

####################
# Evaluation Mode #
@@ -579,7 +588,7 @@
loss_list_val, acc_list, acc_list_ema = [], [], []

with torch.no_grad():
# TODO: Copy is probably slow, we can def avoid this I think....
# TODO: Copy is probably slow, we can def avoid this somehow, I think....
for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
if epoch >= ema_epoch_start:
outputs = net_ema(inputs)
@@ -588,18 +597,24 @@
loss_list_val.append(loss_fn(outputs, targets).float().mean())
acc_list.append((outputs.argmax(-1) == targets).float().mean())

accuracy_non_ema = torch.stack(acc_list).mean().item()
accuracy_ema = None
val_acc = torch.stack(acc_list).mean().item()
ema_val_acc = None
# TODO: We can fuse these two operations (just above and below) all-together like :D :))))
if epoch >= ema_epoch_start:
accuracy_ema = torch.stack(acc_list_ema).mean().item()

loss_val = torch.stack(loss_list_val).mean().item()

format_for_table = lambda x: "{0:.4f}".format(x) if x is not None else None
progress_table.add_row(*list(map(format_for_table, [epoch, loss_train, loss_val, accuracy_train, accuracy_non_ema, accuracy_ema, total_time])))
#progress_bar.advance() ## Now that we've finished everything, we update our eta (and we do it last so the ETA is more accurate.)
console.print(progress_table)
ema_val_acc = torch.stack(acc_list_ema).mean().item()

val_loss = torch.stack(loss_list_val).mean().item()
# We basically need to look up local variables by name so that we have both the values and their column names, letting us pad each value to the proper column width.
## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even
## more heinous than this, unfortunately. So we switched to the "more simple" version of this!
format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
if locals[x] is not None \
else " "*len(x)

# Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))
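To illustrate the `locals()`-based formatting trick with hypothetical values: integers are right-justified as-is, floats get four decimals, and `None` becomes blank padding, each padded to the width of its column heading.

```python
# Tiny standalone demo of the locals()-based column formatting above (hypothetical values).
from functools import partial

epoch, train_loss, ema_val_acc = 3, 1.234567, None
fmt = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
    if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
    if locals[x] is not None \
    else " "*len(x)
print(list(map(partial(fmt, locals=locals()), ['epoch', 'train_loss', 'ema_val_acc'])))
# -> ['    3', '    1.2346', '           ']  (each padded to the width of its column head)
```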

if __name__ == "__main__":
for run_num in range(5):
