Skip to content

Commit

Permalink
import MegatronDataModule
Browse files Browse the repository at this point in the history
  • Loading branch information
sichu2023 committed Oct 25, 2024
1 parent 740f776 commit 714b6ef
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from bionemo.esm2.data import dataset, tokenizer
from bionemo.llm.data import collate
from bionemo.llm.data.datamodule import MegatronDatamodule
from bionemo.llm.data.datamodule import MegatronDataModule
from bionemo.llm.utils.datamodule_utils import infer_num_samples


Expand Down Expand Up @@ -180,6 +180,7 @@ def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader
Args:
dataset: The dataset to create the dataloader for.
mode: Stage of training, which is used to determine if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train).
**kwargs: Additional arguments to pass to the dataloader.
"""
self.update_init_global_step()
assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import collate
from bionemo.llm.data.datamodule import MegatronDatamodule
from bionemo.llm.data.datamodule import MegatronDataModule
from bionemo.llm.utils.datamodule_utils import infer_num_samples


Expand Down Expand Up @@ -180,6 +180,13 @@ def test_dataloader(self) -> EVAL_DATALOADERS: # noqa: D102
return self._create_dataloader(self._test_ds)

def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader:
"""Create dataloader for train, validation, and test stages.
Args:
dataset: The dataset to create the dataloader for.
mode: Stage of training, which is used to determine if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train).
**kwargs: Additional arguments to pass to the dataloader.
"""
self.update_init_global_step()
return WrappedDataLoader(
mode=mode,
Expand Down

0 comments on commit 714b6ef

Please sign in to comment.