Grad Norm and SAFE encoding Misunderstanding #49
Comments
Hi @Anri-Lombard, a few questions:
What happens when you use the same small model on this dataset?
Hi @maclandrol, thank you for the quick response.

To address your last question first: no, I currently have access to a single 80GB A100, so I am not training with multiple GPUs (and thus not trying to replicate your larger model on the 1.1B-molecule dataset).

Thank you for clarifying the gradient norm; that is how I intuited it, although I was unsure. It explains the results I saw when trying more iterations 🙏.

To plot the validation loss I need to alter the safe library, since the code currently computes perplexity only at the end (and there is no flag for intermediate recording of the validation loss, unless I'm mistaken?).

The dataset mentioned is in SMILES format, but I passed the is_tokenized False flag; I did the same for the small model when training on MOSES. I am retraining the small model on this dataset, that is a great suggestion.

My intuition on tokenizers could be stronger, but I suspected that since this is a 20M-molecule ZINC dataset and you trained the original tokenizer on 1.1B molecules, of which a large subset was ZINC, retraining the tokenizer on a smaller ZINC subset wouldn't change the results?

Would you mind keeping the issue open for the time being? Once training is done I can come back and record my findings for others who happen upon the same situation. (For context, the batch size I used was 64 with 2 gradient accumulation steps, just to address your second point.)
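For what it's worth, if the safe trainer passes through the standard `transformers` `TrainingArguments`, intermediate eval logging can usually be requested without patching the library. A minimal sketch (the `output_dir` and step counts are placeholders; batch size, accumulation, and learning rate are the values from this thread):

```python
from transformers import TrainingArguments

# Sketch: evaluate every 500 optimizer steps so eval_loss is logged to wandb
# throughout training, instead of perplexity being computed only at the end.
args = TrainingArguments(
    output_dir="safe-small-zinc",    # placeholder
    evaluation_strategy="steps",     # run evaluation periodically during training
    eval_steps=500,
    logging_steps=100,
    report_to="wandb",
    per_device_train_batch_size=64,  # values from this thread
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    max_grad_norm=1.0,               # transformers' default, made explicit
)
```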
I think you can potentially increase your batch size a bit. Also, if your sequences are not very long after tokenization, try to reduce the model size. I would really suggest some light hyperparameter tuning here.

You normally should be able to plot the validation loss if you use wandb. Just make sure that your dataset is a DatasetDict that includes an evaluation split.
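A quick sketch of building such a DatasetDict from the ZINC dataset mentioned below, assuming it exposes a single train split (the 1% holdout and seed are arbitrary choices):

```python
from datasets import DatasetDict, load_dataset

# Hold out 1% of the ~20M ZINC SMILES for validation.
raw = load_dataset("sagawa/ZINC-canonicalized", split="train")
splits = raw.train_test_split(test_size=0.01, seed=42)
dataset = DatasetDict(train=splits["train"], validation=splits["test"])
```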
You are right that the tokenizer should work for both SAFE and SMILES strings, and a good model should just ignore any tokens that are not in your training data. There is an argument to be made that this is wasteful, as you could likely reduce the vocabulary, and with it the model size and thus the training time.

Also, just to be sure: if you train on the 20M SMILES dataset without converting to SAFE first, you will get a SMILES model. We don't convert automatically, so that the same training code can be reused for any molecular line representation the user wants.
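If a SAFE model is what you want, the conversion has to be done explicitly before training. A hedged sketch using `safe.encode`, assuming the SMILES column is named `smiles` (some molecules cannot be fragmented, so failures are caught broadly here and dropped):

```python
import safe
from datasets import Dataset

def to_safe(example):
    # Convert one SMILES string into its SAFE representation; mark failures.
    try:
        example["safe"] = safe.encode(example["smiles"])
    except Exception:  # e.g. molecules that cannot be fragmented
        example["safe"] = None
    return example

ds = Dataset.from_dict({"smiles": ["CC(=O)Oc1ccccc1C(=O)O", "c1ccc(-c2ccccc2)cc1"]})
ds = ds.map(to_safe).filter(lambda ex: ex["safe"] is not None)
```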
I suspect the missing eval loss was down to the default behaviour. I have fixed it in #53; please let me know if you get the eval_loss with your setup now.
When training a model on a different dataset, in this case https://huggingface.co/datasets/sagawa/ZINC-canonicalized (somewhat larger than MOSES and quite a bit smaller than the SAFE-GPT training set), the perplexity ends up very bad.
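As a point of reference, the reported perplexity here is just the exponential of the eval cross-entropy loss, so the two numbers can be cross-checked directly (the loss value below is only an example):

```python
import math

eval_loss = 2.3                    # example cross-entropy from an evaluation run
perplexity = math.exp(eval_loss)   # perplexity = exp(cross-entropy), ~9.97 here
```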
Looking into it further, I discovered that the grad_norm is very large despite explicitly setting max_grad_norm.
The model then does not generate any valid molecules and seems to overfit.
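For measuring validity, a simple RDKit check over generated samples works (sketch; `samples` stands in for a list of generated SMILES strings):

```python
from rdkit import Chem

def validity(samples: list[str]) -> float:
    # Fraction of generated SMILES that RDKit parses into a valid molecule.
    valid = sum(Chem.MolFromSmiles(s) is not None for s in samples)
    return valid / len(samples)

print(validity(["CCO", "c1ccccc1", "not-a-smiles"]))  # -> 0.666...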
I tried adjusting the library myself and realised that transformers sets max_grad_norm to 1.0 by default, which made sense when I replicated your small-model results, since the grad norm stayed between 0 and 1 throughout and gave good results at the end.
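One detail worth noting (my reading of the Trainer internals, worth double-checking): the grad_norm that gets logged is the value returned by `torch.nn.utils.clip_grad_norm_`, i.e. the norm measured *before* clipping, so a large logged grad_norm does not by itself mean clipping is being skipped. A toy demonstration:

```python
import torch

model = torch.nn.Linear(10, 10)
loss = model(torch.randn(4, 10)).pow(2).sum() * 1e3  # inflate the gradients
loss.backward()

# clip_grad_norm_ rescales gradients in place but returns the PRE-clip norm,
# which is what ends up in the training logs.
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(f"logged grad_norm: {pre_clip:.1f}, norm actually used: {post_clip:.3f}")
```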
Do you have a solution in mind? It might be that the default is ignored during warmup steps, but I haven't found any evidence for this while reading through the Trainer code (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py).
For more context, this is a 50M-parameter model with a learning rate of 1e-4, and the dataset contains ~20M ZINC molecules. Do you have an intuition about what the problem might be?