-
Notifications
You must be signed in to change notification settings - Fork 0
/
p3o_trainer.py
1458 lines (1256 loc) · 62 KB
/
p3o_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import os
import time
import typing
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Union
import datasets
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, is_deepspeed_available
from datasets import Dataset
from huggingface_hub import whoami
from packaging import version
from torch.optim import Adam
from transformers import (
DataCollatorForLanguageModeling,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
from ..core import (
WANDB_PADDING,
PPODecorators,
convert_to_scalar,
entropy_from_logits,
flatten_dict,
logprobs_from_logits,
masked_mean,
set_seed,
stack_dicts,
stats_to_np,
)
from ..import_utils import is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, P3OConfig, RunningMoments
if is_deepspeed_available():
import deepspeed
# Markdown model card (YAML front matter + usage instructions) with two
# `str.format`-style placeholders: `{model_name}` (title) and `{model_id}`
# (hub identifier used in the code snippets).
# NOTE(review): the code that fills and writes this template is outside this
# section of the file — presumably a model-card helper on the trainer; confirm.
MODEL_CARD_TEMPLATE = """---
license: apache-2.0
tags:
- trl
- transformers
- reinforcement-learning
---
# {model_name}
This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
## Usage
To use this model for inference, first install the TRL library:
```bash
python -m pip install trl
```
You can then generate text as follows:
```python
from transformers import pipeline
generator = pipeline("text-generation", model="{model_id}")
outputs = generator("Hello, my llama is cute")
```
If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
tokenizer = AutoTokenizer.from_pretrained("{model_id}")
model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
```
"""
class P3OTrainer(BaseTrainer):
"""
The PPOTrainer uses Proximal Policy Optimization to optimise language models.
Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
https://github.com/openai/summarize-from-feedback
Attributes:
**config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
details.
**model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
Check the documentation of `PreTrainedModelWrapper` for more details.
**ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper`
for more details. If no reference model is provided, the trainer will create a reference model with the same
architecture as the model to be optimized with shared layers.
**tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
data. Check the documentation of `transformers.PreTrainedTokenizer` and
`transformers.PreTrainedTokenizerFast` for more details.
**dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
created outside the trainer users needs to design their own dataloader and make sure the batch
size that is used is the same as the one specified in the configuration object.
**optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
object.
**data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
passed along the dataloader
**num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
model, if no reference model is passed. If no number is provided, all the layers will be shared.
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""
def __init__(
    self,
    config: P3OConfig = None,
    model: PreTrainedModelWrapper = None,
    ref_model: Optional[PreTrainedModelWrapper] = None,
    tokenizer: PreTrainedTokenizerBase = None,
    dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    data_collator: Optional[typing.Callable] = None,
    num_shared_layers: Optional[int] = None,
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
    """
    Initialize P3OTrainer.

    Validates the arguments, creates the `Accelerator`, resolves the reference model, builds the
    dataloader/optimizer/KL controller and hands everything to `accelerator.prepare`.

    Args:
        config (`P3OConfig`):
            Configuration object for P3OTrainer. Must be a `P3OConfig` instance.
        model (`PreTrainedModelWrapper`):
            Model to optimise; must be an instance of one of `SUPPORTED_ARCHITECTURES`.
        ref_model (`PreTrainedModelWrapper`, *optional*):
            Reference model used for the KL penalty. If `None` and `model` is not a PEFT model, a reference
            model is created with `create_reference_model`. If `model` is a PEFT model and `ref_model` is not
            a supported architecture, no separate reference model is kept (`self.ref_model` is `None`).
        tokenizer (`transformers.PreTrainedTokenizerBase`):
            Hugging Face tokenizer; must be a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`.
        dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
            PyTorch dataset or Hugging Face dataset used to build `self.dataloader`. If `None`, a warning is
            emitted and `self.dataloader` is `None`.
        optimizer (Optional[`torch.optim.Optimizer`]):
            Optimizer used for training. If `None`, `Adam` over the trainable parameters is used with
            `config.learning_rate`.
        data_collator (Optional[function]):
            Collate function forwarded to `prepare_dataloader`.
        num_shared_layers (Optional[int]):
            Number of layers shared between the model and the created reference model. Only used when
            `ref_model` is `None`; a warning is emitted if it is given together with `ref_model`.
        lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
            Learning rate scheduler; stepped once at the end of every `step` call.

    Raises:
        ValueError: if `config`, `tokenizer`, `model`, `ref_model`, `dataset` or `lr_scheduler` has an
            unsupported type, or if `config.push_to_hub_if_best_kwargs` is set without a `repo_id`.
    """
    super().__init__(config)
    # initial seed for reproducible experiments
    # NOTE(review): this reads `config.seed` before the isinstance check below, so `config=None`
    # fails here with AttributeError instead of the ValueError raised in Step 0.
    set_seed(config.seed)
    # Step 0: check positional arguments validity
    if not isinstance(config, P3OConfig):
        raise ValueError(f"config must be a P3OConfig, got {type(config)}")
    if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
        raise ValueError(
            f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
        )
    if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
        raise ValueError(
            f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
        )
    # Step 1: Initialize Accelerator
    self.accelerator = Accelerator(
        log_with=config.log_with,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        project_config=ProjectConfiguration(**config.project_kwargs),
        **config.accelerator_kwargs,
    )
    # Step 1.1 Runtime variables filled by the accelerator
    config.world_size = self.accelerator.num_processes
    config.global_backward_batch_size = config.backward_batch_size * config.world_size
    config.global_batch_size = config.batch_size * config.world_size
    self.model = model
    # `filter` returns a one-shot iterator: it is exhausted after a single pass.
    self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
    # True whenever the attribute exists on the model, whatever its value.
    self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
    self.is_peft_model = getattr(self.model, "is_peft_model", False)
    config.is_encoder_decoder = self.is_encoder_decoder
    config.is_peft_model = self.is_peft_model
    is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
    # For tensorboard the config dict is passed flat; for other trackers it is nested under one key.
    self.accelerator.init_trackers(
        config.tracker_project_name,
        config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
        init_kwargs=config.tracker_kwargs,
    )
    self.is_using_text_environment = getattr(config, "use_text_environment", False)
    # Step 2: resolve the reference model used for the KL penalty
    if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
        self.ref_model = ref_model
        if num_shared_layers is not None:
            warnings.warn(
                "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
                "model and the reference model and no layers are shared.",
                UserWarning,
            )
    elif ref_model is None and not self.is_peft_model:
        self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
    elif self.is_peft_model:
        # PEFT: the policy itself is used as reference inside `optional_peft_ctx` (see `generate`/`step`).
        self.ref_model = None
    else:
        raise ValueError(
            f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
            f"architectures are: {SUPPORTED_ARCHITECTURES} "
        )
    # Context-manager factory: `disable_adapter` of the wrapped model for PEFT, a no-op otherwise.
    self.optional_peft_ctx = (
        self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
        if self.is_peft_model
        else nullcontext
    )
    if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
        raise ValueError(
            "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
        )
    self.tokenizer = tokenizer
    if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
        raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
    elif dataset is None:
        warnings.warn(
            "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
            UserWarning,
        )
    self.dataset = dataset
    # Lazily filled by `_set_signature_columns_if_needed`.
    self._signature_columns = None
    if self.dataset is not None:
        self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
    elif self.dataset is None and self.accelerator.num_processes > 1:
        warnings.warn(
            "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
            " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
            " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
            " refer to the documentation for more details.",
            UserWarning,
        )
        self.dataloader = None
    else:
        self.dataloader = None
    # Step 3: Initialize optimizer and data collator
    # NOTE(review): the `data_collator` argument is only used for the dataloader above;
    # `self.data_collator` is always a causal-LM collator built from the tokenizer.
    self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
    if optimizer is None:
        self.optimizer = Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.config.learning_rate,
        )
    else:
        self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    if self.lr_scheduler is not None:
        # The public scheduler base class is `LRScheduler` on torch >= 2.0, `_LRScheduler` before.
        lr_scheduler_class = (
            torch.optim.lr_scheduler._LRScheduler
            if not is_torch_greater_2_0()
            else torch.optim.lr_scheduler.LRScheduler
        )
        if not isinstance(self.lr_scheduler, lr_scheduler_class):
            raise ValueError(
                "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
            )
    if self.config.adap_kl_ctrl:
        self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
    else:
        self.kl_ctl = FixedKLController(self.config.init_kl_coef)
    # Safety checkers for DS integration
    is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
        self.accelerator.state, "deepspeed_plugin"
    )
    (
        self.model,
        self.optimizer,
        self.data_collator,
        self.dataloader,
        self.lr_scheduler,
    ) = self.accelerator.prepare(
        self.model,
        self.optimizer,
        self.data_collator,
        self.dataloader,
        self.lr_scheduler,
    )
    if is_deepspeed_used:
        # Quantized models are already set on the correct device
        if not self.is_peft_model and not (
            getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
            or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
        ):
            # `_prepare_deepspeed` is defined outside this section of the file.
            self.ref_model = self._prepare_deepspeed(self.ref_model)
    else:
        self.ref_model = self.accelerator.prepare(self.ref_model)
    # In a distributed setup, only logging needs to be performed on the main process
    # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
    # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
    self.is_distributed = self.accelerator.distributed_type == "MULTI_GPU"
    # init the current step
    self.current_step = 0
    # init variables for pushing model to hub
    # These attributes only exist when `push_to_hub_if_best_kwargs` is set; `step` relies on
    # `hasattr(self, "highest_reward")` to detect that.
    if config.push_to_hub_if_best_kwargs:
        if "repo_id" not in config.push_to_hub_if_best_kwargs:
            raise ValueError("You have to specify repo_id in order to push the model to the hub!")
        self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
        self.compare_step = 0
        self.highest_reward = torch.tensor(-float("inf"))
    # post process for PP
    # Models flagged `is_sequential_parallel` get their inputs on device 0 (xpu or cuda).
    if not getattr(self.model, "is_sequential_parallel", False):
        self.current_device = self.accelerator.device
    else:
        if is_xpu_available():
            self.current_device = torch.device("xpu:0")
        else:
            self.current_device = torch.device("cuda:0")
    # Class-level switch read by the `PPODecorators.empty_device_cache` decorator on `step`.
    PPODecorators.optimize_device_cache = self.config.optimize_device_cache
    # Running score statistics used by `step` when `config.use_score_scaling` is set.
    self.running = RunningMoments(self.accelerator)
def _filter_kwargs(self, kwargs, target_func):
"""
filter the keyword arguments that are supported by the target function.
Args:
kwargs (dict):
Keyword arguments
target_func (function):
Target function
"""
return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
    """
    Build the training dataloader.

    Args:
        dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
            PyTorch dataset or Hugging Face dataset. A Hugging Face dataset is first passed through
            `_remove_unused_columns`.
        data_collator (Optional[function]):
            Collate function handed to the dataloader.

    Returns:
        `torch.utils.data.DataLoader`: shuffling dataloader with `config.batch_size` items per batch;
        an incomplete last batch is dropped.
    """
    source = self._remove_unused_columns(dataset) if isinstance(dataset, Dataset) else dataset
    loader_options = {
        "batch_size": self.config.batch_size,
        "collate_fn": data_collator,
        "shuffle": True,
        "drop_last": True,
    }
    return torch.utils.data.DataLoader(source, **loader_options)
# Adapted from transformers.Trainer._set_signature_columns_if_needed
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# label => sentiment | we need query and response for logging purpose
self._signature_columns += ["label", "query", "response"]
# Adapted from transformers.Trainer._remove_unused_columns
def _remove_unused_columns(self, dataset: "Dataset"):
if not self.config.remove_unused_columns:
return dataset
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
columns = [k for k in signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"],
columns=columns,
format_kwargs=dataset.format["format_kwargs"],
)
return dataset
else:
return dataset.remove_columns(ignored_columns)
def generate(
self,
query_tensor: Union[torch.Tensor, List[torch.Tensor]],
length_sampler: Callable = None,
batch_size: int = 4,
return_prompt: bool = True,
generate_ref_response: bool = False,
responses_per_query: int = 2,
**generation_kwargs,
):
"""
Generate response with the model given the query tensor.
call the `generate` method of the model. Automatically generate two responses for each query
Args:
query_tensor (`torch.LongTensor`):
A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
generation_kwargs (dict[str, Any]):
Keyword arguments for generation.
length_sampler (`Callable`, *optional*):
Callable that returns the number of newly generated tokens.
batch_size (`int`, *optional):
Batch size used for generation, defaults to `4`.
return_prompt (`bool`, *optional*):
If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
generate_ref_response (`bool`, *optional*):
If set to `True` the reference response is also generated, defaults to `False`.
Returns:
`torch.LongTensor`: A tensor of shape (`2 * batch_size`, `gen_len`) containing response tokens.
"""
if responses_per_query != 2:
raise NotImplementedError("Only 2 responses per query are supported by P3O at the moment")
if generate_ref_response:
ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
query_tensor_repeated = query_tensor * responses_per_query
response = self._generate_batched(
self.model,
query_tensor_repeated,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self._generate_batched(
ref_model,
query_tensor_repeated,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
else:
if len(query_tensor.shape) == 2:
raise ValueError(
"query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
)
query_tensor_repeated = torch.vstack([query_tensor] * responses_per_query)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
response = self.accelerator.unwrap_model(self.model).generate(
input_ids=query_tensor_repeated.unsqueeze(dim=0), **generation_kwargs
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = ref_model.generate(
input_ids=query_tensor_repeated.unsqueeze(dim=0), **generation_kwargs
)
if not return_prompt and not self.is_encoder_decoder:
response = response[:, query_tensor.shape[0] :]
if generate_ref_response:
ref_response = ref_response[:, query_tensor.shape[0] :]
if generate_ref_response:
return response, ref_response
return response
def _generate_batched(
self,
model: PreTrainedModelWrapper,
query_tensors: List[torch.Tensor],
length_sampler: Callable = None,
batch_size: int = 4,
return_prompt: bool = True,
pad_to_multiple_of: int = None,
remove_padding: bool = True,
**generation_kwargs,
) -> List[torch.Tensor]:
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not return_prompt and not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
if remove_padding and self.tokenizer.eos_token_id in output:
pad_mask = output == self.tokenizer.eos_token_id
pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
output = output[: pad_start + 1] # keep the eos token at the end
outputs.append(output)
self.tokenizer.padding_side = padding_side_default
return outputs
def _step_safety_checker(
self,
batch_size: int,
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
masks: Optional[List[torch.LongTensor]] = None,
):
"""s
Check if the input data is valid for training.
Args:
batch_size (int):
Batch size from the config file.
queries (List[`torch.LongTensor`]):
List of tensors containing the encoded queries of shape (`query_length`)
responses (List[`torch.LongTensor`]):
List of tensors containing the encoded responses of shape (`response_length`)
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
masks (List[`torch.LongTensor`], *optional*):
list of optional tensors containing the masks of shape (`query_length` + `response_length`)
Returns:
`tuple`: The input processed data.
"""
num_per_query = {"queries": 1, "responses": 2, "scores": 2}
for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
if not isinstance(tensor_list, list):
raise ValueError(f"queries must be a list of tensors - got {type(tensor_list)}")
if not isinstance(tensor_list[0], torch.Tensor):
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
if batch_size is not None and len(tensor_list) != num_per_query[name] * batch_size:
raise ValueError(
f"Batch size ({num_per_query[name] * batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
)
# add queries, scores and responses on the correct device, duplicate queries to match the number of responses
queries = [tensor.to(self.current_device) for tensor in queries]
double_queries = queries + queries
responses = [tensor.to(self.current_device) for tensor in responses]
scores = [tensor.to(self.current_device) for tensor in scores]
masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
# squeeze scores if needed
for i, score in enumerate(scores):
if score.dim() > 1:
raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
elif score.dim() == 1:
scores[i] = score.squeeze()
return double_queries, responses, scores, masks
@PPODecorators.empty_device_cache()
def step(
    self,
    queries: List[torch.LongTensor],
    responses: List[torch.LongTensor],
    scores: List[torch.FloatTensor],
    response_masks: Optional[List[torch.LongTensor]] = None,
):
    """
    Run a P3O optimisation step given a list of queries, two model responses per query, and their rewards.

    Args:
        queries (List[`torch.LongTensor`]):
            `config.batch_size` tensors containing the encoded queries of shape (`query_length`)
        responses (List[`torch.LongTensor`]):
            `2 * config.batch_size` tensors containing the encoded responses of shape (`response_length`);
            the queries are duplicated as `queries + queries`, so response `i` and response
            `i + batch_size` belong to query `i`.
        scores (List[`torch.FloatTensor`]):
            `2 * config.batch_size` tensors containing the scores, aligned with `responses`.
        response_masks (List[`torch.LongTensor`], *optional*):
            List of tensors containing masks of the response tokens.

    Returns:
        `dict[str, Any]`: A summary of the training statistics
    """
    bs = self.config.batch_size
    queries_repeated, responses, scores, response_masks = self._step_safety_checker(
        bs, queries, responses, scores, response_masks
    )  # queries come back duplicated so that they line up one-to-one with the responses
    scores = torch.tensor(scores, device=self.current_device)
    if self.config.use_score_scaling:
        # Score scaling
        # The update's return values are unused; the accumulated running statistics are read instead.
        scores_mean, scores_std = self.running.update(scores)
        tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
        # eps keeps the division below finite when the running std is zero
        score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
        if self.config.use_score_norm:
            scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
        else:
            scores /= score_scaling_factor
    if self.config.score_clip is not None:
        # Score clipping
        # Clip in float32 and cast back to the original dtype.
        scores_dtype = scores.dtype
        scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
    # if we want to push best model to the hub
    # (`highest_reward` only exists when `push_to_hub_if_best_kwargs` was configured in `__init__`)
    if hasattr(self, "highest_reward"):
        if self.compare_step % self.config.compare_steps == 0:
            curr_mean_reward = scores.mean()
            # if the best reward ever seen
            if curr_mean_reward > self.highest_reward:
                self.highest_reward = curr_mean_reward
                # push model to hub
                self.push_to_hub(**self.push_to_hub_kwargs)
        self.compare_step += 1
    timing = dict()
    t0 = time.time()  # start of the whole step
    t = time.time()  # start of the current phase
    model_inputs = self.prepare_model_inputs(queries_repeated, responses)
    if self.is_distributed:
        # Pad every input to a common sequence length across processes, on the tokenizer's padding side.
        pad_first = self.tokenizer.padding_side == "left"
        model_inputs["input_ids"] = self.accelerator.pad_across_processes(
            model_inputs["input_ids"],
            dim=1,
            pad_index=self.tokenizer.pad_token_id,
            pad_first=pad_first,
        )
        model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
            model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
        )
        if self.is_encoder_decoder:
            model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
                model_inputs["decoder_input_ids"],
                dim=1,
                pad_index=self.tokenizer.pad_token_id,
                pad_first=pad_first,
            )
            model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
                model_inputs["decoder_attention_mask"],
                dim=1,
                pad_index=0,
                pad_first=pad_first,
            )
    model_inputs_names = list(model_inputs.keys())
    # With the "full" KL penalty the complete logits are needed, not only the chosen-token logprobs.
    full_kl_penalty = self.config.kl_penalty == "full"
    # Rollout statistics of the current policy and of the reference model; no gradients needed.
    with torch.no_grad():
        all_logprobs, logits_or_none, masks = self.batched_forward_pass(
            self.model,
            queries_repeated,
            responses,
            model_inputs,
            response_masks=response_masks,
            return_logits=full_kl_penalty,
        )
        # For PEFT models the policy with the adapters disabled serves as the reference model.
        with self.optional_peft_ctx():
            ref_logprobs, ref_logits_or_none, _ = self.batched_forward_pass(
                self.model if self.is_peft_model else self.ref_model,
                queries_repeated,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )
    timing["time/p3o/forward_pass"] = time.time() - t
    with torch.no_grad():
        t = time.time()
        if full_kl_penalty:
            active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
            ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
            rewards, non_score_reward = self.compute_rewards(
                scores, active_full_logprobs, ref_full_logprobs, masks
            )
        else:
            rewards, non_score_reward = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
        timing["time/p3o/compute_rewards"] = time.time() - t
    # upcast to float32 to avoid dataset issues
    batch_dict = {
        "queries": queries_repeated,
        "responses": responses,
        "logprobs": all_logprobs.to(torch.float32),
        "masks": masks,
        "rewards": rewards,
    }
    batch_dict.update(model_inputs)
    t = time.time()
    all_stats = []
    early_stop = False
    for _ in range(self.config.p3o_epochs):
        if early_stop:
            break
        # NOTE(review): the permutation covers range(bs) only, while `queries_repeated`/`responses`
        # hold 2 * bs entries (see `_step_safety_checker`), so indices >= bs are never selected
        # here. Confirm whether the paired responses are meant to be picked up elsewhere.
        b_inds = np.random.permutation(bs)
        for backward_batch_start in range(0, bs, self.config.backward_batch_size):
            backward_batch_end = backward_batch_start + self.config.backward_batch_size
            backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
            for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                mini_batch_end = mini_batch_start + self.config.mini_batch_size
                mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
                mini_batch_dict = {
                    "logprobs": batch_dict["logprobs"][mini_batch_inds],
                    "masks": batch_dict["masks"][mini_batch_inds],
                    # hacks: the queries and responses are ragged.
                    "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
                    "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
                    "rewards": batch_dict["rewards"][mini_batch_inds],
                }
                for k in model_inputs_names:
                    mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                with self.accelerator.accumulate(self.model):
                    # Rebinds `model_inputs` to the mini-batch slice; the full batch stays in `batch_dict`.
                    model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
                    logprobs, logits, _ = self.batched_forward_pass(
                        self.model,
                        mini_batch_dict["queries"],
                        mini_batch_dict["responses"],
                        model_inputs,
                        return_logits=True,
                    )
                    train_stats = self.train_minibatch(
                        logprobs,
                        mini_batch_dict["logprobs"],
                        mini_batch_dict["rewards"],
                        logits,
                        mini_batch_dict["masks"],
                    )
                    all_stats.append(train_stats)
        # typically, early stopping is done at the epoch level
        # (uses the policy KL of the last mini-batch of this epoch)
        if self.config.early_stopping:
            policykl = train_stats["policy/policykl"]
            early_stop = self._early_stop(policykl)
            if early_stop:
                break
    timing["time/p3o/optimize_step"] = time.time() - t
    t = time.time()
    train_stats = stack_dicts(all_stats)
    stats = self.record_step_stats(
        scores=scores,
        logprobs=all_logprobs,
        ref_logprobs=ref_logprobs,
        non_score_reward=non_score_reward,
        train_stats=train_stats,
        kl_coef=self.kl_ctl.value,
        masks=masks,
        queries=queries_repeated,
        responses=responses,
    )
    # Gather/Reduce stats from all processes
    if self.is_distributed:
        stats = self.gather_stats(stats)
    stats = stats_to_np(stats)
    timing["time/p3o/calc_stats"] = time.time() - t
    stats["p3o/learning_rate"] = self.optimizer.param_groups[0]["lr"]
    # Update the KL control - multiply the batch_size by the number of processes
    self.kl_ctl.update(
        stats["objective/kl"],
        self.config.batch_size * self.accelerator.num_processes,
    )
    # Log the total P3O step time
    timing["time/p3o/total"] = time.time() - t0
    stats.update(timing)
    # post-process stats for tensorboard and other loggers
    if self.config.log_with != "wandb":
        stats = convert_to_scalar(stats)
    # The scheduler advances once per `step` call, not once per mini-batch.
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return stats
def _early_stop(self, policykl):
r"""
Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and
the optimization step is skipped.
This also handles the multi-gpu case where the policy KL is averaged across all processes.
Args:
policy_kl (torch.Tensor):
the policy KL
Returns:
`bool`: whether to early stop or not
"""
early_stop = False
if not self.config.early_stopping:
return early_stop
if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
self.optimizer.zero_grad()
early_stop = True
elif self.is_distributed:
import torch.distributed as dist
# Wait for all processes to finish
dist.barrier()
# all gather the policykl
dist.all_reduce(policykl, dist.ReduceOp.SUM)
policykl /= self.accelerator.num_processes
if policykl > 1.5 * self.config.target_kl:
self.optimizer.zero_grad()
early_stop = True
return early_stop
def gather_stats(self, stats):
    """
    Average the tensor-valued stats across all processes. Useful in the context of distributed training.

    Args:
        stats (dict[str, Any]):
            a dictionary of stats to be gathered. Entries that are torch tensors are reduced; every other entry is
            left untouched.

    Returns:
        `dict[str, Any]`: The same dictionary object, with each tensor entry replaced by its mean over all
        processes. Reduced tensors live on `self.accelerator.device`.
    """
    import torch.distributed as dist

    # Wait for all processes to finish
    dist.barrier()

    # Only values of existing keys are reassigned below, so iterating over `items()` stays valid.
    for k, v in stats.items():
        if isinstance(v, torch.Tensor):
            # `Tensor.to` returns a *new* tensor whenever `v` is not already on the target device. Keep a handle
            # on that tensor: `all_reduce` works in place, so reducing an unnamed temporary would throw the summed
            # result away and leave only the local value divided by the process count.
            v = v.to(self.accelerator.device)
            dist.all_reduce(v, dist.ReduceOp.SUM)
            # Sum -> mean over processes.
            v /= self.accelerator.num_processes
        stats[k] = v
    return stats
def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
    """
    Build the collated keyword inputs for a forward pass over query/response pairs.

    For encoder-decoder models the queries are collated as the encoder inputs and the responses as
    `decoder_input_ids` / `decoder_attention_mask`. Otherwise each query is concatenated with its response and the
    joined sequences are collated as a single input. Every sequence is given an all-ones attention mask before
    collation, and any `labels` entry produced by the collator is removed.

    Args:
        queries: iterable of per-example query token tensors.
        responses: iterable of per-example response token tensors, aligned with `queries`.

    Returns:
        The object returned by `self.data_collator`, moved to `self.current_device` (presumably a padded batch;
        the collator itself is defined elsewhere).
    """

    def _collate(sequences):
        # One feature dict per sequence; the attention mask marks every real token.
        features = [{"input_ids": seq, "attention_mask": torch.ones_like(seq)} for seq in sequences]
        return self.data_collator(features).to(self.current_device)

    if not self.is_encoder_decoder:
        # Decoder-only: the model sees query and response as one contiguous sequence.
        joined = [torch.cat([q, r]) for q, r in zip(queries, responses)]
        input_data = _collate(joined)
    else:
        input_data = _collate(queries)
        decoder_batch = _collate(responses)
        input_data["decoder_input_ids"] = decoder_batch["input_ids"]
        input_data["decoder_attention_mask"] = decoder_batch["attention_mask"]

    input_data.pop("labels", None)  # we don't want to compute LM losses
    return input_data
def prepare_loss_input(self, *data):
    """
    Split every input along its first dimension into a `[first_half, second_half]` pair.

    The batch size is read from the first input's `shape[0]` and must be even; the error message states that every
    query has two responses. All inputs are cut at the same midpoint.

    Args:
        *data: sliceable batches; the first one must expose `.shape`.

    Returns:
        `tuple`: one two-element list `[d[:half], d[half:]]` per input, in the order given.

    Raises:
        ValueError: if the first input holds an odd number of examples.
    """
    n_examples = data[0].shape[0]
    midpoint, remainder = divmod(n_examples, 2)
    if remainder:
        raise ValueError(
            "Input to prepare_loss_input must have an even number of examples, since every query has two responses."
        )
    return tuple([item[:midpoint], item[midpoint:]] for item in data)
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: PreTrainedModelWrapper,
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
"""
Calculate model outputs in multiple batches.
Args:
queries (`torch.LongTensor`):
List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
responses (`torch.LongTensor`):
List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
return_logits (`bool`, *optional*, defaults to `False`):
Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
Returns:
(tuple):
- all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
shape (`batch_size`, `response_length`)
- all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
shape (`batch_size`, `response_length`)
- all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
# all_values = []
model.eval()
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
logits, *_ = model(**input_kwargs) # TODO: need to add position_ids
if self.is_encoder_decoder:
input_ids = input_kwargs["decoder_input_ids"]
attention_mask = input_kwargs["decoder_attention_mask"]
else:
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
if self.is_encoder_decoder:
# Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
start = 1
end = attention_mask[j, :].sum() - 1
else:
start = len(query_batch[j]) - 1 # logprobs starts from the second query token