Support layer parallelism in transformer application (#2420)
This PR adds support for layer parallelism in the transformer application, a variable-length version of The Pile pre-tokenized dataset, updates to the LBANN graph visualizer script, and some minor tweaks to the weights layer.
tbennun authored Feb 22, 2024
1 parent 6011d03 commit f3172ac
Showing 18 changed files with 660 additions and 128 deletions.
@@ -0,0 +1,34 @@
from tqdm import trange
from multiprocessing import Pool
import numpy as np


class Processor:

def __init__(self, total_threads: int):
self.threads = total_threads

def __call__(self, tid: int):
import thepile as dataset
num_samples = dataset.num_val_samples()
        filename = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val.bin'
        len_filename = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val-seqlen.bin'

with open(filename, 'ab') as fp:
with open(len_filename, 'ab') as slfp:
for i in trange(num_samples):
text = dataset.dataset_val[i]['text']
tokenized = dataset.tokenize(text)
sample = np.array(tokenized, dtype=np.uint16)
sample_len = np.array([len(sample)], dtype=np.uint32)
sample.tofile(fp)
sample_len.tofile(slfp)

print('Done')


if __name__ == '__main__':
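    # A single worker appends to one shared output file pair; multiple workers
    # would interleave their writes.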
threads = 1
with Pool(threads) as pool:
pool.map(Processor(threads), range(threads))
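
A quick way to sanity-check the output of this script (a sketch assuming the same hard-coded paths as above) is to confirm that the uint32 lengths in val-seqlen.bin account for every uint16 token written to val.bin:

import os
import numpy as np

data_dir = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen'
lengths = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'), dtype=np.uint32)
token_bytes = os.path.getsize(os.path.join(data_dir, 'val.bin'))

# Each token is stored as uint16 (2 bytes), so the token file must hold
# exactly 2 * sum(lengths) bytes.
assert token_bytes == 2 * int(lengths.sum(dtype=np.uint64)), 'token/length files disagree'
print(len(lengths), 'sequences,', token_bytes // 2, 'tokens')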
@@ -0,0 +1,78 @@
from tqdm import trange
from multiprocessing import Pool
import numpy as np
import os
import argparse
from pathlib import Path


class Processor:

def __init__(self, total_threads: int):
self.threads = total_threads

def __call__(self, tid: int):
import thepile as dataset
num_samples = dataset.num_train_samples()
np.random.seed(20231023)
indices = np.random.permutation(num_samples)
local_samples = num_samples // self.threads
offset = tid * local_samples
# Add remainder
if tid == self.threads - 1:
local_samples += num_samples % self.threads
section = indices[offset:offset + local_samples]
filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-pretokenized-{tid:02d}-of-{self.threads}.bin'
len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-seqlen-{tid:02d}-of-{self.threads}.bin'

# Create file
if not os.path.isfile(filename):
Path(filename).touch()
if not os.path.isfile(len_filename):
Path(len_filename).touch()

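        # Resume support: the length file stores one uint32 (4 bytes) per
        # sequence already written, so its size tells us how many sequences
        # from this chunk can be skipped.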
sz = os.path.getsize(len_filename)
assert sz % 4 == 0
sequences_processed = sz // 4
print(tid, ': Size in bytes:', sz, '. Sequences processed:',
sequences_processed)

with open(filename, 'ab') as fp:
with open(len_filename, 'ab') as slfp:
for i in trange(sequences_processed,
section.shape[0],
desc=f'Thread {tid}'):
text = dataset.dataset_train[int(section[i])]['text']
sample = dataset.tokenize(text)
sample = np.array(sample, dtype=np.uint16)
sample.tofile(fp)
sample_len = np.array([len(sample)], dtype=np.uint32)
sample_len.tofile(slfp)


if __name__ == '__main__':
parser = argparse.ArgumentParser()

parser.add_argument('-j',
action='store',
default=0,
type=int,
help='Threads (default 0 = number of cores)')
parser.add_argument('-t',
action='store',
default=0,
type=int,
help='Total Chunks (default 0 = number of threads)')
parser.add_argument('-o',
action='store',
default=0,
type=int,
help='Chunk offset (default 0)')
args = parser.parse_args()

threads = args.j or os.cpu_count()
total_chunks = args.t or threads
offset = args.o
assert offset + threads <= total_chunks
with Pool(threads) as pool:
pool.map(Processor(total_chunks), range(offset, offset + threads))
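
The variable-length data reader introduced later in this diff loads a single train.bin / train-seqlen.bin pair, so the per-chunk shards written here presumably need to be concatenated afterwards. That step is not part of this commit; a minimal sketch, assuming the shard naming above and fewer than 100 chunks (so lexicographic sorting preserves chunk order), could look like:

import glob
import shutil

data_dir = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen'
token_shards = sorted(glob.glob(f'{data_dir}/train-pretokenized-*-of-*.bin'))
length_shards = sorted(glob.glob(f'{data_dir}/train-seqlen-*-of-*.bin'))
assert len(token_shards) == len(length_shards)

# Concatenate token and length shards in the same chunk order so that the two
# output files stay aligned sample-for-sample.
with open(f'{data_dir}/train.bin', 'wb') as out_tokens, \
        open(f'{data_dir}/train-seqlen.bin', 'wb') as out_lengths:
    for tok, ln in zip(token_shards, length_shards):
        with open(tok, 'rb') as f:
            shutil.copyfileobj(f, out_tokens)
        with open(ln, 'rb') as f:
            shutil.copyfileobj(f, out_lengths)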
11 changes: 10 additions & 1 deletion applications/nlp/transformer/datasets/thepile.py
@@ -91,7 +91,7 @@ def get_train_sample(index):

def get_val_sample(index):
"""Token indices for a data sample from the validation set."""
- text = dataset_train[index]['text']
+ text = dataset_val[index]['text']
tokenized = tokenize(text)

# Trim long sequences, left-pad short sequences
@@ -120,3 +120,12 @@ def sample_dims():

def vocab_size():
return tokenizer.get_vocab_size()


if __name__ == '__main__':
print('Training samples:', num_train_samples())
print('Validation samples:', num_val_samples())
print('Training sample 101:')
print(tokenizer.decode(get_train_sample(101)))
print('Validation sample 233:')
print(tokenizer.decode(get_val_sample(233)))
6 changes: 4 additions & 2 deletions applications/nlp/transformer/datasets/thepile_pretokenized.py
@@ -1,5 +1,5 @@
"""
- The Pile dataset, stored as pre-tokenized binary files for optimized processing.
+ The Pile dataset, stored as pre-tokenized, pre-packed binary files for optimized processing.
"""
import os
import os.path
@@ -10,7 +10,9 @@
# Options
# ----------------------------------------------

- sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))
+ # Sequence length is hardcoded to 512 in the pre-packed binary dataset.
+ # To use other sequence lengths, see ``thepile_pretokenized_varlen.py``
+ sequence_length = 512

# ----------------------------------------------
# Setup
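Since the fixed-length loader is now pinned to 512 tokens, other sequence lengths go through the variable-length loader added below, which reads THE_PILE_SEQUENCE_LENGTH at import time. A usage sketch (assuming, as the pretokenization scripts above do, that the datasets directory is on sys.path and the pre-tokenized files are present):

import os

# The value is read when the module is imported, so set it first.
os.environ['THE_PILE_SEQUENCE_LENGTH'] = '1024'

import thepile_pretokenized_varlen as dataset
print(dataset.sequence_length)  # 1024
print(dataset.sample_dims())    # (1024,)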
105 changes: 105 additions & 0 deletions applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py
@@ -0,0 +1,105 @@
"""
The Pile dataset, stored as pre-tokenized, variable-length binary files for optimized processing.
"""
import os
import os.path

import numpy as np
# ----------------------------------------------
# Options
# ----------------------------------------------

sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))

# ----------------------------------------------
# Setup
# ----------------------------------------------

# Load the datasets
data_dir = os.getenv('THE_PILE_DATA_DIR',
'/p/vast1/data/datasets/the-pile-pretokenized')
dataset_train = np.memmap(os.path.join(data_dir, 'train.bin'),
dtype=np.uint16,
mode='r')
sample_lengths_train = np.fromfile(os.path.join(data_dir, 'train-seqlen.bin'),
dtype=np.uint32).astype(np.uint64)
sample_offsets_train = np.zeros_like(sample_lengths_train)
sample_offsets_train[1:] = np.cumsum(sample_lengths_train)[:-1]
dataset_val = np.memmap(os.path.join(data_dir, 'val.bin'),
dtype=np.uint16,
mode='r')
sample_lengths_val = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'),
dtype=np.uint32).astype(np.uint64)
sample_offsets_val = np.zeros_like(sample_lengths_val)
sample_offsets_val[1:] = np.cumsum(sample_lengths_val)[:-1]

# Uses the definition from the GPT-NeoX-20B tokenizer
pad_index = 1 # '<|padding|>'
_vocab_size = 50277

# ----------------------------------------------
# Sample access functions
# ----------------------------------------------


def trim_and_pad(sample, random: bool):
# Trim long sequences
if len(sample) > sequence_length:
if random:
pos = np.random.rand()
offset = (len(sample) - sequence_length + 1) * pos
offset = int(np.floor(offset))
sample = sample[offset:offset + sequence_length]
else:
sample = sample[0:sequence_length]

# Left-pad short sequences
if len(sample) < sequence_length:
sample_pad = np.full(sequence_length, pad_index, dtype=np.int32)
if len(sample) > 0:
sample_pad[-len(sample):] = sample
return sample_pad

return sample


def get_train_sample(index: int):
sample = np.copy(
dataset_train[sample_offsets_train[index]:sample_offsets_train[index] +
sample_lengths_train[index]]).astype(np.int32)
return trim_and_pad(sample, True)


def get_val_sample(index):
sample = np.copy(
dataset_val[sample_offsets_val[index]:sample_offsets_val[index] +
sample_lengths_val[index]]).astype(np.int32)
return trim_and_pad(sample, False)


def num_train_samples():
return sample_lengths_train.shape[0]


def num_val_samples():
return sample_lengths_val.shape[0]


def sample_dims():
return (sequence_length, )


def vocab_size():
return _vocab_size


if __name__ == '__main__':
print('Training samples:', num_train_samples())
print('Validation samples:', num_val_samples())
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file(
os.path.join(data_dir, '20B_tokenizer.json'))
print('Training sample 101:')
print(tokenizer.decode(get_train_sample(101)))
print('Validation sample 233:')
print(tokenizer.decode(get_val_sample(233)))
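
The trim-and-pad convention above is straightforward to verify in isolation: samples longer than sequence_length are cropped (at a random offset for training, from the start for validation) and shorter samples are left-padded with pad_index. A standalone check, with the two constants inlined so it runs without the dataset files:

import numpy as np

sequence_length = 512
pad_index = 1  # '<|padding|>' in the GPT-NeoX-20B tokenizer

def trim_and_pad(sample, random: bool):
    # Same logic as in thepile_pretokenized_varlen.py, reproduced for a standalone check.
    if len(sample) > sequence_length:
        if random:
            offset = int(np.floor((len(sample) - sequence_length + 1) * np.random.rand()))
            sample = sample[offset:offset + sequence_length]
        else:
            sample = sample[0:sequence_length]
    if len(sample) < sequence_length:
        sample_pad = np.full(sequence_length, pad_index, dtype=np.int32)
        if len(sample) > 0:
            sample_pad[-len(sample):] = sample
        return sample_pad
    return sample

short = np.arange(10, dtype=np.int32)
padded = trim_and_pad(short, random=False)
assert padded.shape == (sequence_length,)
assert (padded[:-10] == pad_index).all() and (padded[-10:] == short).all()

long_sample = np.arange(2000, dtype=np.int32)
assert trim_and_pad(long_sample, random=True).shape == (sequence_length,)
assert (trim_and_pad(long_sample, random=False) == long_sample[:sequence_length]).all()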
4 changes: 4 additions & 0 deletions applications/nlp/transformer/modeling.py
@@ -87,6 +87,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
transformer, args)
parallelism.apply_ffn_model_parallelism(transformer, args)
parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
parallelism.apply_layer_parallelism(transformer, args)

# Run through transformer
result = transformer(encoder_input, decoder_input, sequence_length - 1)
@@ -124,6 +125,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
)

parallelism.apply_fsdp_allweights(result, args)
parallelism.apply_layer_parallelism_postamble(result, args)
return result


@@ -186,6 +188,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
transformer, args)
parallelism.apply_ffn_model_parallelism(transformer, args)
parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
parallelism.apply_layer_parallelism(transformer, args)

# Run through transformer with the same sequence
result = transformer(decoder_input, decoder_input, sequence_length)
@@ -227,6 +230,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
)

parallelism.apply_fsdp_allweights(result, args)
parallelism.apply_layer_parallelism_postamble(result, args)
return result
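
Both model builders wire the new hooks in the same way: apply_layer_parallelism is applied to the module stack after the other model-parallel and FSDP passes and before the forward pass, while apply_layer_parallelism_postamble is applied to the object the builder returns (result), right after apply_fsdp_allweights. A schematic summary of that ordering follows; only the call names and their relative order come from this diff, and the wrapper function and import line are placeholders:

import parallelism  # the transformer application's parallelism module


def apply_parallelism_passes(transformer, embedding_weights, result, args):
    # Module-level passes, applied before the forward pass is traced.
    parallelism.apply_ffn_model_parallelism(transformer, args)
    parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
    parallelism.apply_layer_parallelism(transformer, args)        # new in this PR

    # (In modeling.py, the forward pass, losses, metrics, and model assembly
    # happen between these two groups of calls.)

    # Post-assembly passes, applied to the returned object.
    parallelism.apply_fsdp_allweights(result, args)
    parallelism.apply_layer_parallelism_postamble(result, args)   # new in this PR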

