Feature request: add code example of multi-GPU processing #6186

Closed
NielsRogge opened this issue Aug 28, 2023 · 18 comments · Fixed by #6415
Labels: documentation, enhancement

Comments

@NielsRogge (Contributor) commented Aug 28, 2023

Feature request

Would be great to add a code example of how to do multi-GPU processing with 🤗 Datasets in the documentation. cc @stevhliu

Currently the docs have a small section on this that says "your big GPU call goes here", but it didn't work for me out of the box.

Let's say you have a PyTorch model that can do translation, and you have multiple GPUs. In that case, you'd like to duplicate the model on each GPU, with each copy processing (translating) a chunk of the data in parallel.

Here's how I tried to do that:

from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from multiprocess import set_start_method
import torch
import os

dataset = load_dataset("mlfoundations/datacomp_small")

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# put model on each available GPU
# also, should I do it like this or use nn.DataParallel?
model.to("cuda:0")
model.to("cuda:1")

set_start_method("spawn")

def translate_captions(batch, rank):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % torch.cuda.device_count())
    
    texts = batch["text"]
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(model.device)

    translated_tokens = model.generate(
        **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"], max_length=30
    )
    translated_texts = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)

    batch["translated_text"] = translated_texts
    
    return batch

updated_dataset = dataset.map(translate_captions, with_rank=True, num_proc=2, batched=True, batch_size=256)

I've personally tried running this script on a machine with 2 A100 GPUs.

Error 1

Running the code snippet above from the terminal (python script.py) resulted in the following error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/niels/python_projects/datacomp/datasets_multi_gpu.py", line 16, in <module>
    set_start_method("spawn")
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/context.py", line 247, in set_start_method
    raise RuntimeError('context has already been set')
RuntimeError: context has already been set

Error 2

Then, based on this Stack Overflow answer, I wrapped the set_start_method("spawn") call in a try/except block.
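Roughly this pattern (the except clause here is my paraphrase of that workaround):

from multiprocess import set_start_method

try:
    set_start_method("spawn")
except RuntimeError:
    # "context has already been set": keep whatever start method is already configured
    pass

This resulted in the following error: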

File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/datasets/dataset_dict.py", line 817, in <dictcomp>
    k: dataset.map(
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 2926, in map
    with Pool(nb_of_missing_shards, initargs=initargs, initializer=initializer) as pool:
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/pool.py", line 215, in __init__
    self._repopulate_pool()
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/pool.py", line 306, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/pool.py", line 329, in _repopulate_pool_static
    w.start()
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 154, in get_preparation_data
    _check_not_importing_main()
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/multiprocess/spawn.py", line 134, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

So I then put the last line under an if __name__ == '__main__': block. The code snippet then seemed to work, but it only leveraged a single GPU (based on monitoring nvidia-smi):

Mon Aug 28 12:19:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   55C    P0    76W / 275W |   8747MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   67C    P0   274W / 275W |  59835MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |

Both GPUs should show roughly equal usage, but I consistently see that the last GPU has far more usage than the others. This makes me think that os.environ["CUDA_VISIBLE_DEVICES"] = str(rank % torch.cuda.device_count()) might not work from inside a Python script, especially when set after PyTorch has already been imported.

Motivation

It would be great to clarify how to do multi-GPU data processing.

Your contribution

If my code snippet can be fixed, I can contribute it to the docs :)

NielsRogge added the enhancement label on Aug 28, 2023
mariosasko added the documentation label on Aug 29, 2023
@stevhliu (Member) commented:

That'd be a great idea! @mariosasko or @lhoestq, would it be possible to fix the code snippet or do you have another suggested way for doing this?

@lhoestq (Member) commented Aug 30, 2023

Indeed, the if __name__ == "__main__" guard is important in this case.

Not sure about the imbalanced GPU usage though, but maybe you can try using the torch.cuda.device context manager?

also, should I do it like this or use nn.DataParallel?

In this case you wouldn't need a multiprocessed map, no? nn.DataParallel would take care of the parallelism.

@NielsRogge (Contributor, Author) commented:

Adding this Tweet for reference: https://twitter.com/jxmnop/status/1716834517909119019.

@lhoestq (Member) commented Nov 14, 2023

I think the issue is that we set CUDA_VISIBLE_DEVICES after PyTorch is imported.

We should use torch.cuda.set_device(...) instead.
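For example, a minimal self-contained sketch of that idea (the toy dataset and column names are just for illustration):

import torch
from datasets import Dataset
from multiprocess import set_start_method

def double_on_gpu(batch, rank):
    # pin this worker process to one GPU; subsequent "cuda" tensors land on that device
    torch.cuda.set_device((rank or 0) % torch.cuda.device_count())
    x = torch.tensor(batch["x"], device="cuda")
    batch["y"] = (x * 2).cpu().tolist()
    return batch

if __name__ == "__main__":
    set_start_method("spawn")
    ds = Dataset.from_dict({"x": list(range(1000))})
    ds = ds.map(double_on_gpu, with_rank=True, num_proc=torch.cuda.device_count(), batched=True)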

@kopyl commented Jan 27, 2024

@lhoestq

In this case you wouldn't need a multiprocessed map, no?

Yes. But how do you load a model onto 2 GPUs simultaneously without something like accelerate?

@forrestbao commented:

@lhoestq

In this case you wouldn't need a multiprocessed map, no?

Yes. But how do you load a model onto 2 GPUs simultaneously without something like accelerate?

Take a look at this fix: #6550. Basically, you move the model to each GPU inside the function being mapped, as in the sketch below.
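Putting the pieces from this thread together, the full pattern looks roughly like this (reusing the dataset and model from the first post; a sketch, not necessarily the exact example that ended up in the docs):

from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from multiprocess import set_start_method
import torch

# load once at module level; each worker process gets its own copy and moves it to its own GPU
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

def translate_captions(batch, rank):
    # select this worker's GPU and move the model there, instead of touching CUDA_VISIBLE_DEVICES
    device = f"cuda:{(rank or 0) % torch.cuda.device_count()}"
    model.to(device)

    inputs = tokenizer(
        batch["text"], padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    translated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"],
        max_length=30,
    )
    batch["translated_text"] = tokenizer.batch_decode(
        translated_tokens, skip_special_tokens=True
    )
    return batch

if __name__ == "__main__":
    set_start_method("spawn")
    dataset = load_dataset("mlfoundations/datacomp_small")
    updated_dataset = dataset.map(
        translate_captions,
        with_rank=True,
        num_proc=torch.cuda.device_count(),  # one worker per GPU
        batched=True,
        batch_size=256,
    )

The key changes compared to the first snippet are the if __name__ == '__main__' guard, deriving the device from rank inside the mapped function, and moving the model to that device per worker instead of setting CUDA_VISIBLE_DEVICES.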

@forrestbao commented Jan 31, 2024

In case someone else runs into this issue: I wrote a blog post with a complete working example by compiling information from several PRs and issues here. This issue cost me a few hours; I hope the post can save you time before the official documentation gets fixed.

@lhoestq (Member) commented Jan 31, 2024

Thanks! I updated the docs in #6550.

@StephennFernandes commented Feb 19, 2024

Hey @forrestbao, I was also struggling with this issue for weeks, so I checked out your blog. Great work on it.
However, I wanted to ask: could we scale up the process by initializing the same model multiple times on the same GPU for even more speedup?

I mean, given a multi-GPU setup where each GPU has more than 40 GB of VRAM, after initializing the translation model (which takes barely 1-2 GB of VRAM) the rest of the VRAM sits idle. How can I keep creating multiple instances of the same model on each GPU to maximize FLOPS?

@lhoestq (Member) commented Feb 20, 2024

You can use a single instance of the model per GPU and increase the batch size until you fill the VRAM.

@StephennFernandes commented Feb 22, 2024

@lhoestq I tried that, but I noticed that beyond a certain batch_size, using a larger batch_size makes the overall process much slower than using a smaller one.

@mizhazha commented:

Hi @lhoestq, could you help with my two questions:

  1. You mentioned if __name__ == "__main__"; why is that? I tried with a toy dataset without this line, and my two GPUs' usage looks balanced.
  2. Is there any difference between from multiprocess import set_start_method and from multiprocessing import set_start_method? The latter is Python's built-in library. The official doc uses from multiprocess import set_start_method, but it gives me an error like:
[jobuser@f6e2419a0a63d45638da-n0-0 ~]$ python test.py
Traceback (most recent call last):
  File "/home/jobuser/test.py", line 33, in <module>
    updated_dataset = dataset.map(
  File "/home/jobuser/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 593, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 558, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3189, in map
    with Pool(len(kwargs_per_job)) as pool:
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/pool.py", line 191, in __init__
    self._setup_queues()
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/pool.py", line 343, in _setup_queues
    self._inqueue = self._ctx.SimpleQueue()
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/context.py", line 113, in SimpleQueue
    return SimpleQueue(ctx=self.get_context())
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/queues.py", line 339, in __init__
    self._rlock = ctx.Lock()
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/context.py", line 68, in Lock
    return Lock(ctx=self.get_context())
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/synchronize.py", line 168, in __init__
    SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/synchronize.py", line 86, in __init__
    register(self._semlock.name, "semaphore")
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/resource_tracker.py", line 150, in register
    self._send('REGISTER', name, rtype)
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/resource_tracker.py", line 157, in _send
    self.ensure_running()
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/resource_tracker.py", line 124, in ensure_running
    pid = util.spawnv_passfds(exe, args, fds_to_pass)
  File "/home/jobuser/.local/lib/python3.10/site-packages/multiprocess/util.py", line 452, in spawnv_passfds
    return _posixsubprocess.fork_exec(
TypeError: fork_exec() takes exactly 21 arguments (17 given)

which seems to be caused by the Python version. I am using Python 3.10.2.

@lhoestq (Member) commented Mar 20, 2024

Hi!

You mentioned if __name__ == "__main__"; why is that? I tried with a toy dataset without this line, and my two GPUs' usage looks balanced.

It's good practice when doing multiprocessing in Python. Depending on the multiprocessing start method and your Python version, Python may re-run the code in your main.py inside the subprocesses, which you don't want (e.g. recursively spawning processes and failing). Some start methods don't re-run main.py though, and that appears to be your case ;)

Is there any difference between from multiprocess import set_start_method and from multiprocessing import set_start_method? The latter is Python's built-in library. The official doc uses from multiprocess import set_start_method, but it gives me an error like:

Yes, datasets uses multiprocess, which is a separate library from the built-in multiprocessing.

multiprocess is an extended version of multiprocessing that allows, for example, passing lambda functions to subprocesses.
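For instance, something like this toy example (names made up) works with the multiprocessed map even though the stdlib pickle cannot serialize a lambda:

from datasets import Dataset

if __name__ == "__main__":
    ds = Dataset.from_dict({"x": [1, 2, 3, 4]})
    # the lambda is serialized with dill (via multiprocess), which stdlib multiprocessing cannot do
    ds = ds.map(lambda example: {"y": example["x"] * 2}, num_proc=2)
    print(ds[:])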

@mizhazha commented:

Thanks @lhoestq for the explanation. Is it okay to use multiprocessing for set_start_method, given the above-mentioned issue with multiprocess? From my run with a toy example it's fine; I just want to check whether you foresee any problems.

@lhoestq (Member) commented Mar 21, 2024

Not sure whether multiprocessing.set_start_method actually has any effect, since we use dill for the multiprocessed map().

@guynich commented Mar 28, 2024

I'm running the multi-GPU processing code example on a Linux 8x A100 instance. The entire Python run is 30 seconds faster if I add one line to set the number of torch threads immediately after the import torch statement. The model loads to the eight GPUs faster (although the map() progress bars take a similar amount of time with and without this additional line).

import torch
torch.set_num_threads(1)  # I added this line.

from multiprocess import set_start_method

FWIW: my instance has these versions.

CUDA 12.2 driver 535.161.08
Python 3.10.12
torch '2.2.2'
multiprocess '0.70.16'
transformers '4.39.2'
datasets '2.18.0'

@scopello commented Oct 4, 2024

@lhoestq Thanks for the updated GPU multiprocessing documentation! When I add updated_dataset.save_to_disk() after the map call with multiple GPUs, I get an error during saving:

Saving the dataset (0/20 shards):   0%|     | 78000/84761821 [01:07<20:22:25, 1154.59 examples/s]Exception ignored in: <generator object Dataset._save_to_disk_single at 0x7f2498f15070>
Traceback (most recent call last):
  File "/home/ubuntu/lib/python3.10/site-packages/datasets/utils/py_utils.py", line 679, in _write_generator_to_queue
    queue.put(result)
RuntimeError: generator ignored GeneratorExit

Do you have any thoughts?

@lhoestq (Member) commented Oct 7, 2024

Hmm, this is the first time I've seen this, and it's even more surprising given there is no generator in _write_generator_to_queue. Could you open a new issue?
