diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 2e1a98cc..0fc9179b 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -145,6 +145,7 @@ class CheckpointsArgs: checkpoints_path: Path checkpoint_interval: int save_initial_state: Optional[bool] = False + save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[Path] = None checkpoints_path_is_shared_file_system: Optional[bool] = False diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0eda00dc..70d023fb 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -442,7 +442,10 @@ def train( self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - + + if self.config.checkpoints.save_final_state: + self.save_checkpoint() + self.post_training() def training_step(