Learning rate restart broken with Nanoset? #233

Open
Pclanglais opened this issue Sep 25, 2024 · 5 comments

Comments

@Pclanglais

Resuming training from a checkpoint works perfectly with on-the-fly tokenization, but breaks when using Nanoset: training restarts with a different learning rate, one that does not match the schedule saved in lr_schedule.pt.

We also have two additional issues that are likely connected:

  • Loading a different stage dataset produces several anomalous log messages (the same number of tokens as the previous stage, 0 steps remaining). From this information it is unclear whether the new dataset is loaded at all.
  • Training continues even when there are no tokens remaining (probably looping over past tokens?).
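One way to check whether each stage has enough data is to compare the size of its tokenized dataset with the number of tokens the stage will consume. Below is a minimal sketch using the values from the configuration posted in this issue (dp: 32, micro_batch_size: 8, batch_accumulation_per_replica: 2, sequence_length: 4096, stages starting at steps 1 and 4501, train_steps: 7500). The helper names are hypothetical, and the count ignores any extra token per sample the dataloader may read for labels.

```python
# Global batch per optimizer step, in tokens:
# data-parallel replicas x micro-batch size x gradient-accumulation steps x sequence length.
def tokens_per_step(dp, micro_batch_size, batch_accumulation_per_replica, sequence_length):
    return dp * micro_batch_size * batch_accumulation_per_replica * sequence_length


def stage_token_budget(start_step, end_step, per_step):
    # Steps are counted inclusively: a stage running from step 1 to 4500 covers 4500 steps.
    return (end_step - start_step + 1) * per_step


per_step = tokens_per_step(dp=32, micro_batch_size=8,
                           batch_accumulation_per_replica=2, sequence_length=4096)
base_budget = stage_token_budget(1, 4500, per_step)          # "Base corpus" stage
annealing_budget = stage_token_budget(4501, 7500, per_step)  # "Annealing corpus" stage

print(f"tokens per step:  {per_step:,}")
print(f"base corpus:      {base_budget:,}")
print(f"annealing corpus: {annealing_budget:,}")
```

If a stage's dataset holds fewer tokens than its budget (about 9.44B for the base stage and 6.29B for the annealing stage here), the run has to reuse samples to keep training, which would be consistent with the looping suspected above. Whether Nanoset actually wraps around in that case is not confirmed here.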

Training tested with this configuration:

```yaml
checkpoints:
  checkpoint_interval: 1500  # Adjusted to save checkpoints less frequently
  checkpoints_path: checkpoints_marianne_300m_pretrain_mini
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: checkpoints_marianne_300m_pretrain_mini
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder: /lustre/fsn1/projects/rech/fmr/uft12cr/mini_corpus_base_300m_tokenized
    num_loading_workers: 96
    seed: 42
  name: Base corpus
  start_training_step: 1
- data:
    dataset:
      dataset_folder: /lustre/fsn1/projects/rech/fmr/uft12cr/mini_corpus_annealing_300m_tokenized
    num_loading_workers: 96
    seed: 42
  name: Annealing corpus
  start_training_step: 4501
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: pretrain
  run: marianne_3b_pretrain_%date_%jobid
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 960
    initializer_range: 0.02
    intermediate_size: 2560
    is_llama_config: true
    max_position_embeddings: 4096
    num_attention_heads: 15
    num_hidden_layers: 32
    num_key_value_heads: 5
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 65536
    rope_theta: 500000
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.003
    lr_decay_starting_step: 1501  # Start decay after warmup
    lr_decay_steps: 6000  # Decay over the remaining 80% of training
    lr_decay_style: cosine
    lr_warmup_steps: 1500  # 20% warmup
    lr_warmup_style: linear
    min_decay_lr: 0.0001
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 1
parallelism:
  dp: 32
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 1
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: /lustre/fswork/projects/rech/fmr/uft12cr/tokenizer/pleias_300m_65k_tokenizer
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 2
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 8
  sequence_length: 4096
  train_steps: 7500
  val_check_interval: -1
```
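To tell whether the learning rate after a restart is wrong, it helps to know what the scheduler settings in this configuration should produce at a given step. The following is a minimal sketch, assuming textbook linear warmup followed by cosine decay down to min_decay_lr. Nanotron's own scheduler may use different step offsets (for example 0- vs 1-indexed steps), so small differences at the boundaries are expected. The function name `expected_lr` is hypothetical.

```python
import math


def expected_lr(step, lr=0.003, warmup_steps=1500, decay_start=1501,
                decay_steps=6000, min_lr=1e-4):
    """Reference LR for a given training step under linear warmup + cosine decay.

    Defaults mirror learning_rate_scheduler in the config; the exact
    step-indexing convention is an assumption.
    """
    if step <= warmup_steps:
        # Linear warmup from 0 up to the peak learning rate.
        return lr * step / warmup_steps
    if step < decay_start:
        # Hold at the peak between the end of warmup and the start of decay.
        return lr
    # Cosine decay from lr to min_lr over decay_steps, then stay at min_lr.
    progress = min((step - decay_start) / decay_steps, 1.0)
    return min_lr + (lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


for s in (300, 1500, 4501, 7500):
    print(s, expected_lr(s))
```

For the run that was interrupted after 300 steps, `expected_lr(300)` is about 0.0006, still in the warmup phase. The first learning rates logged after the restart should therefore continue the linear ramp from roughly that value.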
@eliebak
Contributor

eliebak commented Sep 25, 2024

Hey, thanks for opening the issue! Can you add the error message you get and the logs?

@Pclanglais
Author

Pclanglais commented Sep 25, 2024

Here they are:

  • start_run.out: the first 300 steps, after which the run was interrupted.
  • restart_run.out: 100 additional steps after resuming, with the wrong learning rate.

(Renamed to .txt due to GitHub upload restrictions.)

start_run.txt

restart_run.txt

lr_scheduler.txt

@pchizhov

Hi! About the "0 steps remaining" message mentioned in this issue: here

```python
if metadata.last_train_step > stage.start_training_step:
    # NOTE: if the last_train_step is larger than the start_training_step of the current stage,
    # it means that the training has already passed this stage
    # so there is no remaining steps
    return 0
```

there seems to be a bug. This code reports 0 remaining steps whenever the last training step is greater than the stage's first step, even though the stage may not be finished yet. In our case, the stage has 100,000 steps in total and we are trying to restart from step 42501.
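A sketch of how the remaining steps of a stage could be computed instead, assuming a stage ends right before the next stage's start_training_step (or at train_steps for the last stage). The `Stage` class and `remaining_steps_in_stage` function are hypothetical illustrations, not nanotron's API.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stage:
    # Hypothetical minimal stand-in for a data stage entry in the config.
    name: str
    start_training_step: int


def remaining_steps_in_stage(stages: List[Stage], stage_idx: int,
                             last_train_step: int, train_steps: int) -> int:
    stage = stages[stage_idx]
    # Assumption: a stage ends just before the next stage starts,
    # and the last stage runs until train_steps (inclusive).
    if stage_idx + 1 < len(stages):
        end_step = stages[stage_idx + 1].start_training_step - 1
    else:
        end_step = train_steps
    if last_train_step < stage.start_training_step:
        # Stage not started yet: all of its steps remain.
        return end_step - stage.start_training_step + 1
    # Stage started: only steps after last_train_step remain, never negative.
    return max(0, end_step - last_train_step)


stages = [Stage("Base corpus", 1), Stage("Annealing corpus", 4501)]
print(remaining_steps_in_stage(stages, 0, last_train_step=300, train_steps=7500))
```

With the two stages from the configuration in this issue and a restart after step 300, the base stage still has 4200 steps left rather than 0. For a single stage of 100,000 steps resumed from step 42501, the same logic gives 57499.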

@eliebak
Contributor

eliebak commented Sep 26, 2024

cc @zzhhjjj, maybe you can take a look at this? (I took screenshots of the part that shows two different learning rates despite the same lr_schedule in the config and resuming from the checkpoint.)
Screenshot 2024-09-26 at 15 46 06
Screenshot 2024-09-26 at 15 45 58

@zzhhjjj
Collaborator

zzhhjjj commented Sep 26, 2024

I think you are correct. I'll take a look; I remember seeing the same issue before. A temporary workaround would be to edit the checkpoint metadata file by hand.
