
Save only the best model files #1787

Open
ricardo-ervilha opened this issue Jun 25, 2024 · 2 comments

Comments

@ricardo-ervilha

Hey everyone,

I'm looking for a way to save only the best model during training, discarding the other checkpoints as better ones appear. Currently, I'm doing it like this:

checker = dde.callbacks.ModelCheckpoint(
    "model/model.ckpt", save_better_only=True, period=1000, verbose=1
)

losshistory, trainstate = model.train(
    iterations=ITERATIONS,
    batch_size=BATCH_SIZE,
    callbacks=[checker]
)

model.restore("model/model.ckpt-" + str(trainstate.best_step) + ".ckpt", verbose=1)

But this approach saves multiple files like:

model.ckpt-1000.ckpt.data-00000-of-00001,
model.ckpt-2000.ckpt.data-00000-of-00001,
...

These extra files are unnecessary for me since I only need the weights of the best model at the end. Is there a way to always overwrite just one file that stores the weights for the best loss?

@praksharma
Contributor

I think you have to save the weights of the model whenever the training loss in the current iteration is lower than the best loss seen so far. You will have to modify the source code.

If you are using Adam, you need to modify _train_sgd() defined in the Model class. There you can find the for-loop of the training. You can create a new model, named bestPINN for example, and copy the weights using copy.deepcopy from model.net.state_dict(), where model.net is the network you pass to dde.Model().

Now you can easily wrap this weight-copying step in an if-condition that compares the training loss of the current iteration against the best loss so far. You can compute the training loss by summing the entries of model.train_state.loss_train, which is a list.
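The pattern above can be sketched framework-free. This is a minimal toy, not DeepXDE code: train_step, the weights dict, and the loss sequence are all hypothetical stand-ins (a real implementation would call copy.deepcopy on model.net.state_dict() inside the training loop), but the keep-only-the-best logic is the same.

```python
import copy

def train_step(weights, it):
    # Hypothetical stand-in for one optimizer step: returns the training
    # loss for this iteration (a toy sequence that bottoms out at it == 3)
    # and mutates the weights so each iteration's state is distinguishable.
    losses = [5.0, 4.0, 4.5, 2.0, 3.0]
    weights["w"] = it
    return losses[it]

weights = {"w": -1}          # stands in for model.net.state_dict()
best_loss = float("inf")
best_state = None            # only one copy (the best) is ever kept

for it in range(5):
    loss = train_step(weights, it)
    if loss < best_loss:     # copy weights only when the loss improves
        best_loss = loss
        best_state = copy.deepcopy(weights)

# best_state now holds the weights from the lowest-loss iteration,
# and no per-iteration checkpoint files were written.
```

Because only best_state is retained, nothing accumulates on disk; at the end you would load it back with model.net.load_state_dict(best_state) (or the equivalent for your backend).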

A few years ago, I did exactly the same thing with PINNs, but in PyTorch (not DeepXDE). You can find the relevant code in the section named main.

@ricardo-ervilha
Author

Thank you very much for the response; your PyTorch implementation helped me a lot!
