About OOM issue related to 'create_graph=True' #7

Open

DOHA-HWANG opened this issue Jan 26, 2022 · 4 comments

Comments

@DOHA-HWANG commented Jan 26, 2022

Dear Author,

First of all, thank you for sharing this nice code.

By the way, regarding quantized training on CIFAR-10: have you ever run into OOM issues caused by loss.backward(create_graph=True) in update_grad_scales?
When I ran it with the arguments below, I hit a "RuntimeError: CUDA out of memory" error.

python train_quant.py --gpu_id '0' \
  --weight_levels 8 \
  --act_levels 8 \
  --baseline False \
  --use_hessian True \
  --load_pretrain True \
  --pretrain_path '../results/ResNet20_CIFAR10/fp/checkpoint/last_checkpoint.pth' \
  --log_dir '../results/ResNet20_CIFAR10/ours(hess)/W8A8/'

Do you have any idea how to avoid this issue?

Thank you in advance.
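
For reference, the sketch below is not this repository's code; it is a minimal PyTorch illustration (with assumed names) of the failure mode: backward(create_graph=True) keeps the backward graph alive so second-order terms can be computed later, and any gradient references that survive the loop keep one graph per batch in GPU memory.

```python
# Minimal sketch (assumed names, not the repository's code): with
# create_graph=True the backward graph is retained for higher-order use.
# Any surviving reference to the gradients (here the `retained` list)
# keeps one graph alive per batch, so GPU memory grows until CUDA
# reports "out of memory".
import torch

model = torch.nn.Linear(512, 512).cuda()
retained = []
for _ in range(1000):
    x = torch.randn(256, 512, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward(create_graph=True)  # p.grad now carries a grad_fn -> graph stays in memory
    retained.append([p.grad for p in model.parameters()])  # dangling references leak one graph per batch
```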

@Joejwu commented May 11, 2022

So, have you solved the problem? I am facing the same issue.

@kartikgupta-at-anu

I run into the same issue when training on ImageNet.

@DOHA-HWANG (Author)

I haven't used this code recently, so I can't remember exactly how I avoided this problem.
However, when I checked my last private version of the code, I had the Hessian option disabled ("--use_hessian False"). (And I remember training this project successfully before.)
As far as I know, using the Hessian was not crucial in this project, according to the experiments in their paper.

@kartikgupta-at-anu commented May 24, 2022

use_hessian is important if we want the scale factors in the EWGS equation to be based on the Hessian.
I figured out a solution to this problem. The code in this part is not well written: some references are left dangling, so it keeps accumulating computation graphs. One possible solution/hack to avoid the OOM issue is to add a loss.backward() in utils.py after line 116, which releases the graph after each batch.
