
fix: ScaleModule and SumModule for DiagHessian. #317

Open
hlzl wants to merge 4 commits into development

Conversation

@hlzl commented Sep 5, 2023

Partially fixes #316. ScaleModule is also used for torch.nn.Identity.

I'm not sure whether hessian_is_zero() should always return True for these two modules. The same goes for accumulate_backpropagated_quantities(), which concatenates dicts here, rather than tensors as for the DiagGGN.
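
For context: both modules apply a fixed linear map to their input(s), so the second derivative of the module output with respect to its input vanishes identically. A quick sanity check with plain autograd (independent of BackPACK internals; alpha is just an arbitrary scale factor) illustrates this:

import torch

alpha = 2.0
x = torch.randn(3)

# Hessian of the (summed) output of y = alpha * x w.r.t. the input x:
# all zeros, since the map is linear.
H = torch.autograd.functional.hessian(lambda z: (alpha * z).sum(), x)
print(torch.allclose(H, torch.zeros_like(H)))  # True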

@hlzl changed the title from "fix: ScaleModule for DiagHessian." to "fix: ScaleModule and SumModule for DiagHessian." on Sep 5, 2023
@hlzl (Author) commented Sep 5, 2023

Commit 74c4173 makes it possible to compute the Hessian diagonal even if the network contains a batch norm layer, simply by not computing the Hessian elements for that layer.

I'm not sure whether this is a reasonable approach, but it can serve as a quick fix.

The other diagonal elements can then be extracted as follows:

import torch

# Concatenate the per-sample Hessian diagonals of all parameters that have one
# (i.e. skipping the batch norm parameters) into a (batch_size, num_params) tensor.
hessian_diag_wo_bn = torch.cat(
    [
        p.diag_h_batch.view(batch.shape[0], -1)
        for p in model.parameters()
        if "diag_h_batch" in p.__dict__.keys()
    ],
    dim=1,
)
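
For completeness, a minimal sketch of the workflow that populates diag_h_batch before running the snippet above, assuming the standard BackPACK usage with the BatchDiagHessian extension on this branch (the toy model, batch, and targets are placeholders):

import torch
from backpack import backpack, extend
from backpack.extensions import BatchDiagHessian

# Toy network with a batch norm layer, in eval mode since per-sample quantities
# are ill-defined while BatchNorm mixes samples across the batch during training.
model = extend(
    torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.BatchNorm1d(10),
        torch.nn.Linear(10, 2),
    )
).eval()
lossfunc = extend(torch.nn.CrossEntropyLoss())

batch, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = lossfunc(model(batch), targets)

with backpack(BatchDiagHessian()):
    loss.backward()  # attaches p.diag_h_batch to the supported parameters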

@hlzl (Author) commented Sep 5, 2023

Commit 48f03e9 tries to actually compute the diagonal elements of the Hessian for the batch norm layer.

If one of you could have a quick look at the commits to see if they make sense, I would really appreciate it. @f-dangel @fKunstner

Thank you!

@f-dangel (Owner) commented

Hi,

just wanted to let you know I read your message above. Please don't expect any reaction before the ICLR deadline (Sep 28).
