
Blogpost for the PyTorch blog #174

Merged: 11 commits, Sep 22, 2023
252 changes: 252 additions & 0 deletions — blogpost/post.md
# Compiling NumPy into C++ or CUDA via `torch.compile`

Tracing through NumPy code via `torch.compile` is now possible in PyTorch 2.1.
This feature leverages PyTorch's compiler to generate efficient fused
vectorized code without having to modify your original code. What's more, it
also allows executing NumPy functions on CUDA simply by running them through
`torch.compile` under `torch.device("cuda")`!

In this post, we go over how to use this feature and give a few tips and tricks
to make the most of it.


## Compiling NumPy into Parallel C++

We will take as our running example the iteration step of a K-Means algorithm
presented in this [NumPy book](https://realpython.com/numpy-array-programming/#clustering-algorithms):

```python
import numpy as np

def get_labels(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)
```

We create a synthetic dataset with two clusters of 10M random 2-D points each.
Given that the means are chosen appropriately, the function returns the correct
cluster for virtually all of them:

```python
npts = 10_000_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape) # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])
pred = get_labels(X, means)
```

Benchmarking this function gives us a baseline of **1.26s** on an AMD 3970X.
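As a point of reference, the baseline can be reproduced with a quick timing harness along these lines. This is a sketch, not the post's exact benchmark: we shrink `npts` so it runs in seconds, so the absolute numbers will of course differ from the 1.26s quoted above.

```python
from time import perf_counter

import numpy as np

def get_labels(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

# Smaller dataset than the 10M-per-blob one in the post, purely for illustration
npts = 100_000
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
means = np.array([[5, 5], [10, 10]])

start = perf_counter()
pred = get_labels(X, means)
print(f"eager NumPy: {perf_counter() - start:.3f}s")
```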

Compiling this function is now as easy as wrapping it with `torch.compile` and
executing it with the example inputs

```python
import torch

compiled_fn = torch.compile(get_labels)
new_pred = compiled_fn(X, means)
assert np.allclose(pred, new_pred)
```

The compiled function yields a 9x speed-up when running it on 1 core. Even
better, since the compiled code also runs on multiple cores, we get a **57x speed-up**
when running it on 32 cores. Note that vanilla NumPy always runs on
just one core.
> **Review comment:** Have you compared the performance to other options for compiling NumPy? The `get_labels` example contains functions which are supported by Numba; it might be an interesting data point to see how the `torch.compile` speed-up compares to `numba.jit`.
>
> **Author (lezcano):** Not really, as this would just make our announcement post way too long. I expect other people to put up posts comparing this and other approaches (Julia / Numba / `torch.jit` / Mojo), but that's beyond the scope of this post IMO.
>
> **Collaborator:** I plan to write a follow-up, based on #168 or https://github.com/Quansight-Labs/numpy_pytorch_interop/tree/main/e2e/smoke; comparing to `numba.jit` may fit there.
>
> **Author (lezcano):** That'd be super cool!
>
> **Collaborator:** Weirdly enough, throwing `@numba.njit` in place of `torch.compile` in this example runs into several rough edges of Numba:
>
> - `np.linalg.norm(..., axis=2)` does not compile; it chokes on `axis`.
> - Replacing `np.linalg.norm` with
>
>   ```python
>   def norm(a, axis):
>       s = (a.conj() * a).real
>       return np.sqrt(s.sum(axis=axis))
>   ```
>
>   and njit-ting both `norm` and `get_labels` gives a slowdown of about 60%.
> - Trying `@njit(parallel=True)` on `get_labels` crashes the compiler, again on the `axis` argument (`TypeError: Failed in nopython mode pipeline (step: Preprocessing for parfors) got an unexpected keyword argument 'axis'`).
> - Trying `@njit(parallel=True)` on the `norm` only and `@njit` on `get_labels` yields a slowdown w.r.t. NumPy of 'only' 20%.
>
> Got to admit I never had much luck with Numba in non-toy situations.

We may inspect the generated C++ code by running the script with
`TORCH_LOGS=output_code`. We can see that `torch.compile` was able to compile
the broadcasting and the two reductions into just one for-loop, which it
parallelizes using OpenMP:
```c++
extern "C" void kernel(const double* in_ptr0, const long* in_ptr1, long* out_ptr0) {
    #pragma omp parallel num_threads(32)
    #pragma omp for
    for(long i0=0L; i0<20000000L; i0+=1L) {
        auto tmp0 = in_ptr0[2L*i0];
        auto tmp1 = in_ptr1[0L];
        auto tmp5 = in_ptr0[1L + (2L*i0)];
        auto tmp6 = in_ptr1[1L];
        ...
```

## Compiling NumPy into CUDA

Compiling our code so that it runs on CUDA is as simple as locally setting the
default device to be CUDA:

```python
with torch.device("cuda"):
    cuda_pred = compiled_fn(X, means)
    assert np.allclose(pred, cuda_pred)
```

By inspecting the generated code via `TORCH_LOGS=output_code`, we see that,
rather than generating CUDA code directly, `torch.compile` generates quite
readable [Triton](https://triton-lang.org/main/index.html) code:

```python
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 20000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2*x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    ...
```

Running this small snippet on an RTX 2060 gives an **8x speed-up** over the
original NumPy code. This is something, but it is not particularly impressive
given the speed-ups we saw on CPU. Let's look at how to get some proper
speed-ups on CUDA via a couple of minor changes.

**`float64` vs `float32`**. Many GPUs, in particular consumer-grade ones, are
rather sluggish when running operations on `float64`. For this reason, if we
change the data generation to `float32`, the original NumPy code only gets
about 9% faster, but our CUDA code gets **40% faster**, yielding an **11x
speed-up** over the plain NumPy code.
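Concretely, the only change needed on the NumPy side is the dtype of the generated data; every intermediate (the broadcasted difference, the norm) then stays in `float32`. A minimal sketch, with a hypothetical smaller `npts` for illustration:

```python
import numpy as np

npts = 1_000  # illustrative size; the post uses 10M per blob
X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
X = (X + np.random.randn(*X.shape)).astype(np.float32)  # generate in float32
means = np.array([[5, 5], [10, 10]], dtype=np.float32)

# The broadcasted difference and the norm both stay in float32
diff = X - means[:, None]
norms = np.linalg.norm(diff, axis=2)
```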

`torch.compile`, by default, respects the NumPy semantics, and as such, it uses
`np.float64` as its default dtype for all its creation ops. As discussed, this
can hinder performance, so it is possible to change this default by setting

```python
from torch._dynamo import config
config.numpy_default_float = "float32"
```

**CPU <> CUDA copies**. An 11x speed-up is good, but it is not even close to
the CPU numbers. This is caused by a small transformation that `torch.compile`
does behind the scenes. The code above takes NumPy arrays and returns NumPy
arrays. All of these arrays live on CPU, but the computations are performed on
the GPU. This means that every time the function is called, `torch.compile` has
to copy all these arrays from CPU to the GPU, and then copy the result back
from CUDA to CPU to preserve the original semantics. There is no native
solution to this issue in NumPy, as NumPy does not have the notion of a
`device`. That being said, we can work around it by writing a wrapper for this
function so that it accepts PyTorch tensors and returns PyTorch tensors.

```python
@torch.compile
def tensor_fn(X, means):
    X, means = X.numpy(), means.numpy()
    ret = get_labels(X, means)
    return torch.from_numpy(ret)

def cuda_fn(X, means):
    with torch.device("cuda"):
        return tensor_fn(X, means)
```

This function now takes tensors in CUDA memory and returns tensors in CUDA
memory, but the function itself is written in NumPy! When we keep the tensors
in CUDA and perform the computations in `float32`, we see a **200x speed-up**
over the initial NumPy implementation on `float32` arrays.
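For illustration, here is a device-agnostic sketch of this wrapper pattern with hypothetical sizes. `tensor_fn` is deliberately left undecorated so the sketch runs eagerly without invoking the compiler, and `x.numpy(force=True)` stands in for the plain `x.numpy()` call that `torch.compile` supports on CUDA tensors:

```python
import numpy as np
import torch

def get_labels(X, means):
    return np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)

# In the real setup this function would be decorated with @torch.compile
def tensor_fn(X, means):
    # force=True copies CUDA tensors to CPU in eager mode; under torch.compile
    # a plain .numpy() traces without the copy
    X, means = X.numpy(force=True), means.numpy(force=True)
    return torch.from_numpy(get_labels(X, means))

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back on CPU-only machines
X_t = torch.randn(1000, 2, dtype=torch.float64, device=device) + 5.0  # one blob near (5, 5)
means_t = torch.tensor([[5.0, 5.0], [10.0, 10.0]], dtype=torch.float64, device=device)
labels = tensor_fn(X_t, means_t)
```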

**Mixing NumPy and PyTorch**. In this example, we had to write a small adaptor
to move the data from CPU to CUDA and back. In programs that mix PyTorch and
NumPy, this is already done by calling `x.detach().cpu().numpy()` (or simply
`x.numpy(force=True)`). Since under `torch.compile` we can run NumPy code on
CUDA, we can simply modify this code to call `x.numpy()` instead; when run
under `device("cuda")`, as we did above, it will generate efficient CUDA code
from the original NumPy calls without copying the data from CUDA to CPU at all.

## Further Speed-up tricks

**General advice**. The CUDA code we have shown is already quite efficient, but
it is true that this is a rather tiny program. When dealing with larger
programs, we may need to tweak parts of it to make it more efficient. A good
place to start is the [`torch.compile` troubleshooting
page](https://pytorch.org/docs/stable/dynamo/troubleshooting.html#performance-profiling).
It showcases a number of ways to inspect the tracing process and how to
identify problematic code that may cause slowdowns.

**Advice when compiling NumPy code**. NumPy, even if it is rather similar to
PyTorch, is often used very differently. It is rather common to perform
computations in NumPy and then branch on the value of the resulting array, or
to perform operations in-place, perhaps via boolean masks. These constructions,
while supported by `torch.compile`, hamper its performance. Changes like moving
from in-place indexing to using `np.where`, writing the code in a branchless
way, or avoiding in-place ops in favor of out-of-place ops can go a long way.
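As a hypothetical illustration of that last point, here are two equivalent ways of clamping values from below; the in-place masked assignment is the pattern to avoid, and `np.where` is the branchless, out-of-place rewrite (the `clamp_*` names are made up for this sketch):

```python
import numpy as np

def clamp_inplace(x, lo):
    # in-place boolean-mask update: the pattern that hampers torch.compile
    out = x.copy()
    out[out < lo] = lo
    return out

def clamp_branchless(x, lo):
    # out-of-place, branchless equivalent
    return np.where(x < lo, lo, x)

x = np.array([-2.0, 0.5, 3.0])
```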

To write fast NumPy code, it is best to avoid loops, but sometimes they are
unavoidable. When tracing through a loop, `torch.compile` will try to fully
unroll it. This is sometimes desirable, but sometimes it may not even be
possible, as when we have a dynamic stopping condition, as in a while loop.
In these cases, it may be best to compile just the body of the loop, perhaps
compiling a few iterations at a time (loop unrolling).
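For instance, in a full K-Means implementation, the natural unit to compile is the per-iteration step, while the dynamic convergence check stays in Python. A sketch under that assumption follows; `step` and `kmeans` are hypothetical names, and `step` is left uncompiled here so the sketch runs as plain NumPy (one would wrap it with `torch.compile`):

```python
import numpy as np

def step(X, means):
    # One K-Means iteration: assign labels, then recompute the means.
    # This is the unit one would wrap with torch.compile.
    labels = np.argmin(np.linalg.norm(X - means[:, None], axis=2), axis=0)
    new_means = np.stack([X[labels == k].mean(axis=0) for k in range(len(means))])
    return labels, new_means

def kmeans(X, means, tol=1e-4, max_iter=100):
    # The dynamic stopping condition stays in Python and is never traced
    for _ in range(max_iter):
        labels, new_means = step(X, means)
        converged = np.linalg.norm(new_means - means) < tol
        means = new_means
        if converged:
            break
    return labels, means

rng = np.random.default_rng(0)
X = np.concatenate([rng.normal([5, 5], 1, (500, 2)), rng.normal([10, 10], 1, (500, 2))])
labels, final_means = kmeans(X, np.array([[4.0, 4.0], [11.0, 11.0]]))
```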

**Debugging NumPy code**. Debugging is rather tricky when a compiler is
involved. To figure out whether an error you are hitting is a `torch.compile`
error or an error in your program, you can execute your NumPy program without
`torch.compile` by replacing the NumPy import with `import torch._numpy as np`.
This should only be used for **debugging purposes** and is in no way a
replacement for the PyTorch API, as it is **much slower**.

## Differences between NumPy and `torch.compile`d NumPy
> **ev-br (Sep 15, 2023):** May want to link to #5 (or move it to the list in that issue somewhere more appropriate).
>
> **Author (lezcano):** Yes, that's on my TODO list for next week!

**NumPy scalars**. NumPy returns NumPy scalars in almost any case where PyTorch
would return a 0-D tensor (e.g. from `np.sum`). Under `torch.compile`, NumPy
scalars are treated as 0-D arrays. This is just fine in most cases. The only
case when their behavior diverges is when NumPy scalars are implicitly used as
Python scalars. For example,
```python
>>> np.asarray(2) * [1, 2, 3] # 0-D array is an array-like
array([2, 4, 6])
>>> u = np.int32(2)
>>> u * [1, 2, 3] # scalar decays into a Python int
[1, 2, 3, 1, 2, 3]
>>> torch.compile(lambda: u * [1, 2, 3])()
array([2, 4, 6]) # acts as a 0-D array, not as a scalar ?!?!
```

As the compiled line above shows, `torch.compile` treats `u` as a 0-D array.
To recover the eager semantics, we just need to make the cast explicit:
```python
>>> torch.compile(lambda: int(u) * [1, 2, 3])()
[1, 2, 3, 1, 2, 3]
```

**Type promotion and versioning**. NumPy's type promotion rules may be, at
times, a bit surprising
```python
>>> np.asarray([1], dtype=np.int8) + 127
array([-128], dtype=int8)
>>> np.asarray([1], dtype=np.int8) + 128
array([129], dtype=int16)
```
In NumPy 2.0, these rules change to a set that is closer to PyTorch's. The
relevant technical document is [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html).
`torch.compile` went ahead and implemented NEP 50 rather than the soon-to-be-deprecated rules.

In general, `torch.compile` will match the semantics of the last NumPy release.

## Beyond NumPy: SciPy and scikit-learn

In parallel to this effort of making `torch.compile` understand NumPy code,
other Quansight engineers have designed, proposed, and got merged a way to
support PyTorch tensors within SciPy and scikit-learn. This was received with
great enthusiasm by the maintainers of these libraries, as it was shown that
using PyTorch as a backend would often yield considerable speed-ups.

> **Member:** The "other" here is a bit of a surprise perhaps, since Quansight hasn't been mentioned. Or will it be in the author attribution at the top of the post?
>
> **Member:** For the first part of the sentence, maybe be more explicit? E.g. "In parallel to this effort of making torch.compile understand NumPy code, ..."
>
> **Author (lezcano):** Re "other": I agree. I want to start mentioning Quansight here, but it seems a bit out of the blue. Perhaps also mention it in the first paragraph of the blogpost?

This can of course be combined with `torch.compile` to compile programs that
rely on these other libraries.

> **Member:** Not quite for SciPy, since it also encompasses compiled code without matching functionality in PyTorch.


Note that the initial support is restricted to a few algorithms in
scikit-learn and to `scipy.cluster` in SciPy.

If you want to learn more about this effort, how to use it, or how to help
move it forward, see this post. [TODO link post]

## Conclusion

[TODO Make sure Greg approves this wording]

PyTorch has been committed since its inception to being a framework compatible
with the rest of the Python ecosystem. Enabling the compilation of NumPy
programs, and establishing the tools necessary to do the same for other
prominent libraries, are two more steps in this direction. Quansight and Meta
continue working in this direction, improving the compatibility between
PyTorch and the rest of the ecosystem.

From Quansight, we would like to thank Meta for funding this project and all
the previous work that led to it, like improving the NumPy compatibility
within PyTorch and developing the [Python Array API](https://data-apis.org/array-api/latest/).
Without this consistent support, this would not have been possible.

> **Member:** This didn't involve any Meta funding actually. Quansight and Intel have been the largest funders; the others are listed on https://data-apis.org/. Meta did fund the work on adding PyTorch support to https://github.com/data-apis/array-api-compat, but that work and funding acknowledgement are covered in Thomas' blog post on the scikit-learn work.