LSTM Timeseries prediction example #1532
Conversation
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:
@@ Coverage Diff @@
## main #1532 +/- ##
==========================================
- Coverage 86.38% 86.31% -0.07%
==========================================
Files 693 683 -10
Lines 80473 78091 -2382
==========================================
- Hits 69519 67408 -2111
+ Misses 10954 10683 -271

☔ View full report in Codecov by Sentry.
Nice example, ping me when you think it's going to be ready for a review.
@nathanielsimard Hello, I am always interested in the implementation of LSTM in Burn, and I still think the LSTM is buggy right now: if a linear layer is added after the LSTM, the parameters of the LSTM and all layers before it will not be updated during training. I've been stuck on this problem for a long time. The example of using the LSTM in this PR further confirms that it has problems. I added some code:

let pjr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
model.input_layer.clone().save_file("./input-before.json", &pjr).unwrap();
model.lstm.clone().save_file("./lstm-before.json", &pjr).unwrap();
model.output_layer.clone().save_file("./output-before.json", &pjr).unwrap();
// ......
model_trained.input_layer.clone().save_file("./input-after.json", &pjr).unwrap();
model_trained.lstm.clone().save_file("./lstm-after.json", &pjr).unwrap();
model_trained.output_layer.clone().save_file("./output-after.json", &pjr).unwrap();

After training, only the parameters of the output_layer changed. Nevertheless, for the dataset in the example, only one linear layer might be enough to overfit.
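For reference, here is a rough sketch of the kind of model structure being described: a Linear layer before and after the Lstm. The layer sizes and names are made up for illustration, and the exact `Lstm::forward` signature is an assumption that may not match the code in this PR or the current Burn API:

```rust
use burn::module::Module;
use burn::nn::{Linear, LinearConfig, Lstm, LstmConfig};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct TimeseriesModel<B: Backend> {
    input_layer: Linear<B>,
    lstm: Lstm<B>,
    output_layer: Linear<B>,
}

impl<B: Backend> TimeseriesModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            // Layer sizes are placeholders.
            input_layer: LinearConfig::new(1, 32).init(device),
            lstm: LstmConfig::new(32, 64, true).init(device),
            output_layer: LinearConfig::new(64, 1).init(device),
        }
    }

    /// Input is assumed to be [batch, seq_length, features].
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.input_layer.forward(x);
        // Assumed forward signature: returns the batched cell and hidden states.
        let (_cell, hidden) = self.lstm.forward(x, None);
        self.output_layer.forward(hidden)
    }
}
```

With a structure like this, dumping each sub-module with PrettyJsonFileRecorder before and after training, as shown above, makes it easy to see which parameters actually changed.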
I was hoping I could spark the development of the LSTM implementation a bit with an example. I would love to use Burn for this purpose as well.
Happy to incorporate your suggestions! Feel free to create a PR that makes changes to this branch.
@wcshds @NicoZweifel I have identified the issue and we already have a planned fix; we will prioritize it since it directly affects a real-world use case. The problem lies in the autodiff graph, which is always attached to a tensor. When two tensors with different graphs interact, we merge the graphs. However, this process assumes that all nodes in the graph will eventually interact, which is not the case for LSTM: for instance, you may only use part of the output. We already want to implement a client/server architecture for autodiff.
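This is not the planned fix, just a small sketch of the access pattern described above: if only the final time step of the LSTM output ever reaches the loss, the autodiff nodes created for the earlier steps never interact with the rest of the graph. Shapes and the helper name are assumptions:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical helper: take the batched hidden states of an LSTM, shaped
/// [batch, seq_length, d_hidden], and keep only the last time step. Everything
/// computed for the earlier steps is never used downstream, which is the
/// situation described in the comment above.
fn last_step<B: Backend>(hidden: Tensor<B, 3>) -> Tensor<B, 2> {
    let [batch, seq_length, d_hidden] = hidden.dims();
    hidden
        .slice([0..batch, (seq_length - 1)..seq_length, 0..d_hidden])
        .reshape([batch, d_hidden])
}
```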
@nathanielsimard Do we have a separate ticket for the "planned fix"? It would be good to track it and link it here.
@nathanielsimard Kinda off topic, but it would be cool to have a generic way to do this.
@NicoZweifel, that would be a great addition. You can file an issue for this and we can assign it to you.
Thanks, I created a separate issue to discuss the details 👍
This PR has been marked as stale because it has not been updated for over a month.
@NicoZweifel Hey 👋 I'm going through open issues/PRs right now, and it looks like there hasn't been a lot of activity here for a while. I'll close the draft PR, but if you want to take it up eventually and need a review, feel free to reopen and ping us 🙏
@laggui Thanks, will do 👍. I've made some progress in the meantime on my local fork, but I need to finish a separate feature PR, as well as update/maintain this one before re-opening, since it also contains some other changes. I'd love to reopen and finish this eventually.
Checklist

- The run-checks all script has been executed.

Changes
A timeseries prediction example using the LSTM that was added in Feat/lstm #370, with a Partial Dataset from Huggingface (SqliteDataset).

There seems to be an issue with the SqliteDataset. I have not narrowed it down yet, as I am using custom Datasets in my other burn project and they work fine (InMemory with data from alphavantage). I might need to spend some more time on it to figure it out, but since it doesn't block me in my other goals and the example seems to work with 10000 entries, I thought I could publish this as a draft for now.
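For anyone unfamiliar with the dataset setup referenced above, here is a rough sketch of how a Huggingface dataset can be pulled into a SqliteDataset and then limited with PartialDataset in Burn. The dataset name and item fields are placeholders, and the loader paths and constructor signatures are from memory, so they may differ between Burn versions:

```rust
use burn::data::dataset::source::huggingface::HuggingfaceDatasetLoader;
use burn::data::dataset::transform::PartialDataset;
use burn::data::dataset::SqliteDataset;
use serde::Deserialize;

// Hypothetical item type; the fields must match the Huggingface dataset schema.
#[derive(Clone, Debug, Deserialize)]
pub struct PriceItem {
    pub timestamp: String,
    pub close: f32,
}

fn load_train() -> PartialDataset<SqliteDataset<PriceItem>, PriceItem> {
    // "some-user/some-timeseries-dataset" is a placeholder dataset name.
    let full: SqliteDataset<PriceItem> =
        HuggingfaceDatasetLoader::new("some-user/some-timeseries-dataset")
            .dataset("train")
            .expect("dataset should download and convert to sqlite");

    // Keep only the first 10000 entries, similar to what the draft example does.
    PartialDataset::new(full, 0, 10_000)
}
```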
Testing

cargo run --example lstm --features tch-cpu