Fix multi gpu map example #6415

Merged (3 commits) on Nov 22, 2023

Changes from 2 commits
docs/source/process.mdx: 33 changes (20 additions & 13 deletions)
@@ -340,21 +340,36 @@ You can also use [`~Dataset.map`] with indices if you set `with_indices=True`.
]
```

### Multiprocessing

Multiprocessing significantly speeds up processing by parallelizing the work across multiple CPU processes. Set the `num_proc` parameter in [`~Dataset.map`] to specify the number of processes to use:

```py
>>> updated_dataset = dataset.map(lambda example, idx: {"sentence2": f"{idx}: " + example["sentence2"]}, with_indices=True, num_proc=4)
```

The [`~Dataset.map`] also works with the rank of the process if you set `with_rank=True`. This is analogous to the `with_indices` parameter. The `rank` argument in the mapped function goes after the `index` one if the latter is already present.
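
When both are used, the mapped function receives the index first and the rank second. A minimal sketch (the function name `tag_example` is hypothetical):

```py
>>> def tag_example(example, idx, rank):
...     # `idx` is the row index, `rank` is the worker process rank
...     example["sentence2"] = f"proc {rank}, row {idx}: " + example["sentence2"]
...     return example
>>> updated_dataset = dataset.map(tag_example, with_indices=True, with_rank=True, num_proc=4)
```

In the multi-GPU example below, the rank is used to assign each worker its own device: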

```py
>>> from multiprocess import set_start_method
>>> import torch
>>>
>>> for i in range(torch.cuda.device_count()):  # send model to every GPU
...     model.to(torch.cuda.device(i))
```

@NielsRogge (Contributor) commented on Nov 15, 2023:

This gives me the following error:

```
Traceback (most recent call last):
  File "/home/niels/python_projects/datacomp/datasets_multi_gpu.py", line 14, in <module>
    model.to(torch.cuda.device(i))
  File "/home/niels/anaconda3/envs/datacomp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 968, in to
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
TypeError: to() received an invalid combination of arguments - got (device), but expected one of:
 * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (Tensor tensor, bool non_blocking, bool copy, *, torch.memory_format memory_format)
```

I used this instead:

```py
for i in range(torch.cuda.device_count()):  # send model to every GPU
    model.to(f"cuda:{i}")
```

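(`torch.cuda.device(i)` returns a context manager for selecting the current device rather than a `torch.device` instance, which is why `Module.to()` rejects it; a device string like `f"cuda:{i}"` or `torch.device(f"cuda:{i}")` works.)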

PR author (Member) replied:

fixed it, thanks

The rest of the updated example:

```py
>>> def gpu_computation(example, rank):
...     # NB: passing this context-manager object to `.to()` below hits the same
...     # `torch.cuda.device` issue flagged above; an f"cuda:{...}" string works everywhere
...     device = torch.cuda.device(rank % torch.cuda.device_count())
...     torch.cuda.set_device(device)  # use one GPU per process
...     # Your big GPU call goes here, for example:
...     inputs = tokenizer(example["text"], padding=True, truncation=True, return_tensors="pt").to(device)  # assumes a "text" column
...     outputs = model.generate(**inputs)
...     example["generated_text"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
...     return example
>>>
>>> if __name__ == "__main__":
...     set_start_method("spawn")
...     updated_dataset = dataset.map(gpu_computation, with_rank=True, num_proc=torch.cuda.device_count())
```

The main use-case for rank is to parallelize computation across several GPUs. This requires setting `multiprocess.set_start_method("spawn")`. If you don't, you'll receive the following CUDA error:
@@ -363,14 +378,6 @@ The main use-case for rank is to parallelize computation across several GPUs.
```
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method.
```

### Batch processing

The [`~Dataset.map`] function supports working with batches of examples. Operate on batches by setting `batched=True`. The default batch size is 1000, but you can adjust it with the `batch_size` parameter. Batch processing enables interesting applications such as splitting long sentences into shorter chunks and data augmentation.
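
For example, a batched function receives a dict of lists and may return a different number of rows than it was given. A minimal sketch of the sentence-splitting use case (assuming the `sentence2` column from the earlier examples; the names `chunk_examples`/`chunked_dataset` and the 50-character chunk size are illustrative):

```py
>>> def chunk_examples(batch):
...     chunks = []
...     for sentence in batch["sentence2"]:
...         # split each sentence into 50-character chunks
...         chunks += [sentence[i:i + 50] for i in range(0, len(sentence), 50)]
...     return {"chunks": chunks}
>>> chunked_dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset.column_names)
```

Because the output has more rows than the input, `remove_columns=dataset.column_names` drops the original columns so their lengths don't conflict with the new `chunks` column.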