
Commit

lezcano committed Sep 18, 2023
1 parent 1172ac0 commit f2b1d61
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions blogpost/post.md
@@ -3,7 +3,7 @@
Quansight engineers have implemented support for tracing through NumPy code via
`torch.compile` in PyTorch 2.1.
This feature leverages PyTorch's compiler to generate efficient fused
-vectorized code without having to modify your original code. Even more, it
+vectorized code without having to modify your original NumPy code. Even more, it
also allows for executing NumPy functions on CUDA just by running them through
`torch.compile` under `torch.device("cuda")`!
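
As a concrete sketch of the workflow (the function below is our own minimal
example, not one from the post), a plain NumPy function can be handed to
`torch.compile` unchanged, and the compiled function still consumes and
returns ndarrays:

```python
import numpy as np
import torch

def numpy_fn(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    # Plain NumPy: a broadcasted product followed by a reduction
    return np.sum(X[:, :, None] * Y[:, None, :], axis=0)

compiled_fn = torch.compile(numpy_fn)

X = np.random.randn(1024, 64)
Y = np.random.randn(1024, 64)
out = compiled_fn(X, Y)  # still an ndarray, computed by the fused kernel

# The same NumPy code can be executed on CUDA by running the compiled
# function under the CUDA device context:
with torch.device("cuda"):
    out_cuda = compiled_fn(X, Y)
```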

@@ -54,8 +54,8 @@ so this is the default behavior you get when using `torch.compile`.

We may inspect the generated C++ code by running the script with the
environment variable `TORCH_LOGS=output_code`. When doing so, we can see that
-`torch.compile` was able to compile the broadcasting, together with the two
-reductions into just one for-loop, and parallelizes it using OpenMP
+`torch.compile` was able to compile the broadcasting and the two
+reductions into just one for-loop, and parallelize it using OpenMP
```c++
extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
#pragma omp parallel num_threads(32)
@@ -71,7 +71,7 @@ extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr
## Compiling NumPy code into CUDA
Compiling our code so that it runs on CUDA is as simple as setting the
-default dtype to be CUDA
+default device to be CUDA
```python
with torch.device("cuda"):
@@ -139,19 +139,23 @@ def cuda_fn(X, means):
```

This function now takes tensors in CUDA memory and returns tensors in CUDA
-memory, but the function itself is written in NumPy! When we keep the tensors
-in CUDA and perform the computations in `float32`, we see a **200x speed-up**
-over the initial NumPy implementation on `float32` arrays.
+memory, but the function itself is written in NumPy! `torch.compile` uses the
+`numpy()` and `from_numpy()` calls as hints and optimizes them away;
+internally it simply works with PyTorch tensors without moving the memory at
+all. When we keep the tensors in CUDA and perform the computations in
+`float32`, we see a **200x speed-up** over the initial NumPy implementation on
+`float32` arrays.
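
As a rough sketch of this pattern (a hypothetical stand-in for `cuda_fn`, not
its actual definition), the wrapper converts tensors to ndarrays and back, and
under `torch.compile` those conversions cost nothing:

```python
import numpy as np
import torch

def numpy_fn(X: np.ndarray, means: np.ndarray) -> np.ndarray:
    # Plain NumPy: distance from every point to every mean
    return np.linalg.norm(X[:, None, :] - means[None, :, :], axis=-1)

@torch.compile
def cuda_fn(X: torch.Tensor, means: torch.Tensor) -> torch.Tensor:
    # Hypothetical body for illustration. numpy()/from_numpy() act only as
    # hints to the tracer and are optimized away: no device-to-host copy
    # happens under torch.compile.
    out = numpy_fn(X.numpy(), means.numpy())
    return torch.from_numpy(out)

X = torch.randn(4096, 8, device="cuda")
means = torch.randn(16, 8, device="cuda")
dists = cuda_fn(X, means)  # CUDA tensor in, CUDA tensor out
```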

**Mixing NumPy and PyTorch**. In this example, we had to write a small adaptor
-to convert tensors to ndarrays and then back to tensors. In programs that mix PyTorch and
-NumPy this is already done by calling `x.detach().cpu().numpy()` (or simply
-`x.numpy(force=True)`). Since when running under `torch.compile` we can run
-NumPy code in CUDA, we can simply modify this code to call `x.numpy()` and when
-running it under `device("cuda")`, as we did above, it will generate efficient
-CUDA code from original NumPy calls without copying the data from CUDA to CPU
-at all. Note that the resulting code would not run without `torch.compile`. For
-it to run in eager mode one would need to rollback to `x.numpy(force=True)`.
+to convert tensors to ndarrays and then back to tensors. In programs that mix
+PyTorch and NumPy, converting a tensor into an ndarray is often implemented as
+`x.detach().cpu().numpy()`, or simply `x.numpy(force=True)`. Since NumPy code
+can run on CUDA under `torch.compile`, we can implement this conversion
+pattern as a call to `x.numpy()`, as we did above. Doing so and running the
+resulting code under `device("cuda")` will generate efficient CUDA code from
+the original NumPy calls without copying the data from CUDA to CPU at all.
+Note that the resulting code does not run without `torch.compile`. For it to
+run in eager mode one would need to roll back to `x.numpy(force=True)`.

## Further Speed-up tricks

@@ -215,8 +219,8 @@ explicit
**Type promotion and versioning**. NumPy's type promotion rules may be, at
times, a bit surprising
```python
->>> np.asarray([1], dtype=np.int8) + 127
-array([128], dtype=int8)
+>>> np.asarray([1], dtype=np.int8) + 126
+array([127], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
@@ -254,8 +258,10 @@ are two more steps in this direction. Quansight and Meta continue working hand o
hand, improving the compatibility between PyTorch and the rest of the
ecosystem.

-From Quansight, we would like to thank Meta for funding this project as well as
-previous work on improving NumPy compatibility within PyTorch, and the project
-that led to supporting PyTorch within scikit-learn and SciPy. These are giant leaps
-towards consolidating PyTorch as the framework of choice within the open source
-Python data ecosystem.
+From Quansight, we would like to thank Mengwei, Voz, and Ed for their
+invaluable help in integrating our work with `torch.compile`. We would also
+like to thank Meta for funding this project as well as previous work on
+improving NumPy compatibility within PyTorch, and the project that led to
+supporting PyTorch within scikit-learn and SciPy. These are giant leaps towards
+consolidating PyTorch as the framework of choice within the open source Python
+data ecosystem.
