-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix training resumption #312
base: sichu/match-resumption-loss-curve
Are you sure you want to change the base?
Fix training resumption #312
Conversation
c8a7d1a
to
4575a61
Compare
sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_stop_and_go.py
Outdated
Show resolved
Hide resolved
sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py
Outdated
Show resolved
Hide resolved
sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py
Outdated
Show resolved
Hide resolved
sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py
Outdated
Show resolved
Hide resolved
sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
left some comments on stop and go, when addressed lgtm
68fa7cb
to
332a7b9
Compare
/build-ci |
Confirm that we can reproduce loss curve upon resumption despite a small jump on validation loss at the beginning. |
def create_dummy_parquet_train_val_inputs(tmp_path: Path) -> Tuple[Path, Path]: | ||
"""Create a mock protein train and val cluster parquet.""" | ||
train_cluster_path = tmp_path / "train_clusters.parquet" | ||
train_clusters = pd.DataFrame( | ||
{ | ||
"ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], | ||
} | ||
) | ||
train_clusters.to_parquet(train_cluster_path) | ||
|
||
valid_cluster_path = tmp_path / "valid_clusters.parquet" | ||
valid_clusters = pd.DataFrame( | ||
{ | ||
"ur50_id": ["UniRef50_A", "UniRef50_B", "UniRef90_A", "UniRef90_B"], | ||
} | ||
) | ||
valid_clusters.to_parquet(valid_cluster_path) | ||
return train_cluster_path, valid_cluster_path | ||
|
||
|
||
def create_dummy_protein_dataset(tmp_path) -> Path: | ||
"""Create a mock protein dataset.""" | ||
if not isinstance(tmp_path, Path): | ||
tmp_path = Path(str(tmp_path)) | ||
|
||
db_file = tmp_path / "protein_dataset.db" | ||
conn = sqlite3.connect(str(db_file)) | ||
cursor = conn.cursor() | ||
|
||
cursor.execute( | ||
""" | ||
CREATE TABLE protein ( | ||
id TEXT PRIMARY KEY, | ||
sequence TEXT | ||
) | ||
""" | ||
) | ||
|
||
proteins = [ | ||
("UniRef90_A", "ACDEFGHIKLMNPQRSTVWY"), | ||
("UniRef90_B", "DEFGHIKLMNPQRSTVWYAC"), | ||
("UniRef90_C", "MGHIKLMNPQRSTVWYACDE"), | ||
("UniRef50_A", "MKTVRQERLKSIVRI"), | ||
("UniRef50_B", "MRILERSKEPVSGAQLA"), | ||
] | ||
cursor.executemany("INSERT INTO protein VALUES (?, ?)", proteins) | ||
|
||
conn.commit() | ||
conn.close() | ||
|
||
return db_file |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are testing-specific functions, let's not have them in the library itself
Broke down the PR into smaller pieces. Now we have a PR specifically tackles inconsistency stop-and-go v.s. continuous training curve. |
631a51e
to
4994705
Compare
/build-ci |
/build-ci |
/build-ci |
/build-ci |
/build-ci |
Summary
Loss curve from training resumption is inconsistent with a single uninterrupted loss curve.
Details
It has now been identified partly from incorrect datamodule behavior and is fixed with implementing
state_dict
for datamodule. However, there is still a very minor dip in validation loss curve that does not affect subsequent training curve.