Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: e2e: compile Rougier's mandelbrot implementation #165

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ compiles, the performance is on par or slightly worse than the original NumPy.
## Mandelbrot fractal

Results strongly depend on the implementation: a straightforward NumPy implementation
uses a data-dependent loop, which does not compile.
uses complex-valued arrays which are not supported by triton.
Working around this and several other dynamo issues leads to speedups of about 3x to 5x.

The implementation based on the [Mojo benchmark](https://shashankprasanna.com/benchmarking-modular-mojo-and-pytorch-torch.compile-on-mandelbrot-function/index.html#benchmarking-pytorch-cpu-with-torchcompile) makes it possible to compile the inner loop. The performance
increase relative to numpy is substantial and strongly data size and machine
Expand Down
14 changes: 5 additions & 9 deletions e2e/kmeans/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,9 @@
cfg.numpy_ndarray_as_tensor = True


# np.linalg.norm replacement (2-norm only), https://github.com/pytorch/pytorch/issues/105269
def norm(a, axis):
    """Return the Euclidean (2-) norm of ``a`` along ``axis``.

    Limited replacement for ``np.linalg.norm``: only the 2-norm is
    implemented. ``a.conj() * a`` is the squared modulus of each element,
    so complex input is handled as well as real input.
    """
    squared_modulus = (a.conj() * a).real
    return np.sqrt(squared_modulus.sum(axis=axis))


#@torch.compile
# this will be compiled
def get_labels(X, centroids) -> np.ndarray:
    """Return, for every sample in ``X``, the index of its nearest centroid.

    ``centroids`` must be 2-D; it is expanded to ``(n_centroids, 1,
    n_features)`` so that subtracting it from ``X`` (expected to be
    ``(n_samples, n_features)``) broadcasts to one difference vector per
    (centroid, sample) pair. The Euclidean length of each difference is
    taken over the last axis and the closest centroid is picked per sample.
    """
    distances = np.linalg.norm(X - centroids[:, None, :], ord=2, axis=2)
    # Axis 0 enumerates the centroids, so argmin over it yields the label.
    return np.argmin(distances, axis=0)


Expand All @@ -31,7 +25,7 @@ def init(npts):
import time

# ### numpy ###
# Number of points handed to init(); a single assignment (the earlier
# int(2e7) value was overwritten before ever being read).
npts = int(1e8)
X, centroids = init(npts)

start_time = time.time()
Expand All @@ -53,6 +47,8 @@ def init(npts):
start_time = time.time()
labels = get_labels_c(X, centroids)
# CUDA kernels are launched asynchronously: wait for them to finish *before*
# reading the clock, otherwise the measurement excludes pending GPU work.
# The guard keeps the script runnable on machines without CUDA.
if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.time()
compiled_time = end_time - start_time
print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time)


Binary file modified e2e/mandelbrot/mandelbrot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
125 changes: 99 additions & 26 deletions e2e/mandelbrot/mandelbrot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
# Copyright (2017) Nicolas P. Rougier - BSD license
# More information at https://github.com/rougier/numpy-book
# -----------------------------------------------------------------------------
#import numpy as np
import torch_np as np
import math
import numpy as np
import time

# need to import before torch
from matplotlib import colors
import matplotlib.pyplot as plt

# To run on CUDA, change "cpu" to "cuda" below.
import torch
torch.set_default_device("cpu")
import torch._dynamo.config as cfg
cfg.numpy_ndarray_as_tensor = True


# from mandelbrot_numpy_1 import mandelbrot # copy-paste below
# ### Original NumPy version. ###

def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
# Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/...
Expand All @@ -30,45 +35,113 @@ def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
return Z, N


if __name__ == '__main__':
from matplotlib import colors
import matplotlib.pyplot as plt
## from timeit import timeit

# Benchmark
xmin, xmax, xn = -2.25, +0.75, int(3000/3)
ymin, ymax, yn = -1.25, +1.25, int(2500/3)
maxiter = 200
## timeit("mandelbrot_1(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
## timeit("mandelbrot_2(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
## timeit("mandelbrot_3(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
# ### Compiled analog. ###

# Visualization
xmin, xmax, xn = -2.25, +0.75, int(3000/2)
ymin, ymax, yn = -1.25, +1.25, int(2500/2)
maxiter = 200
horizon = 2.0 ** 40
log_horizon = np.log(np.log(horizon))/np.log(2)
Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon)
# For TorchDynamo, we need to work around two limitations:
# 1. Complex numbers: add a trailing length-2 dimension for Re and Im parts.
# 2. Fancy indexing: use np.where instead, to avoid data dependency.
#
# Also:
# 1. Only compile the inner loop, to keep compile time and memory consumption
# under control (otherwise, we can run into OOM while compiling).

def abs2(a):
    r"""Return the squared modulus ``|a|**2`` of a split-complex array.

    ``a`` carries complex values in a trailing length-2 axis:
    ``a[..., 0]`` is the real part and ``a[..., 1]`` the imaginary part.
    The result drops that trailing axis.

    Note that this is the *squared* magnitude, not ``abs(a)``; compare it
    against a squared threshold.
    """
    real, imag = a[..., 0], a[..., 1]
    return real**2 + imag**2


def sq2(a):
    """Return ``a**2`` for a split-complex array.

    ``a`` carries complex values in a trailing length-2 axis (real part at
    index 0, imaginary part at index 1). The result has the same shape and
    dtype as ``a`` and holds ``(re + i*im)**2 = (re**2 - im**2) + i*(2*re*im)``.
    """
    real, imag = a[..., 0], a[..., 1]
    squared = np.empty_like(a)
    squared[..., 0] = real**2 - imag**2
    squared[..., 1] = 2 * real * imag
    return squared


@torch.compile(dynamic=True)
def step(n0, c, Z, N, horizon, chunksize):
    """Run ``chunksize`` Mandelbrot iterations, numbered from ``n0``.

    Parameters
    ----------
    n0 : iteration index of the first step in this chunk.
    c : split-complex grid of constants (trailing length-2 axis).
    Z : current split-complex iterates, same layout as ``c``.
    N : per-point index of the last iteration at which the point was still
        inside the horizon.
    horizon : escape radius; points with ``|Z| >= horizon`` are frozen.
    chunksize : number of iterations to perform in this call.

    Returns
    -------
    The updated ``(Z, N)`` pair.
    """
    # |Z| < horizon is tested as |Z|**2 < horizon**2; the right-hand side
    # does not change inside the loop, so compute it once.
    horizon_sq = horizon**2
    for j in range(chunksize):
        n = n0 + j
        I = abs2(Z) < horizon_sq
        N = np.where(I, n, N)                       # N[I] = n
        Z = np.where(I[..., None], sq2(Z) + c, Z)   # Z[I] = Z[I]**2 + C[I]
    return Z, N


def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5):
    """Compute the Mandelbrot iteration on an ``yn`` x ``xn`` grid.

    Complex values are stored as float32 arrays with a trailing length-2
    axis (real, imaginary). The iteration is driven in chunks through
    ``step`` so that only the inner loop is compiled.

    Parameters
    ----------
    xmin, xmax, xn : range and number of samples along the real axis.
    ymin, ymax, yn : range and number of samples along the imaginary axis.
    horizon : escape radius passed to ``step``.
    maxiter : total number of iterations to perform.

    Returns
    -------
    ``(Z, N)``: the final split-complex iterates and, per point, the index
    of the last iteration at which it was still inside the horizon (reset
    to 0 for points that never escaped).
    """
    x = np.linspace(xmin, xmax, xn, dtype='float32')
    y = np.linspace(ymin, ymax, yn, dtype='float32')
    c = np.stack(np.broadcast_arrays(x[None, :], y[:, None]), axis=-1)

    N = np.zeros(c.shape[:-1], dtype='int')
    Z = np.zeros_like(c, dtype='float32')

    chunksize = 10

    # Walk over *all* maxiter iterations. The last chunk is shortened when
    # maxiter is not a multiple of chunksize, instead of being dropped
    # (which previously meant e.g. maxiter=5 performed no iterations at all).
    # A shortened final chunk may cost one additional compilation of `step`.
    for n0 in range(0, maxiter, chunksize):
        Z, N = step(n0, c, Z, N, horizon, min(chunksize, maxiter - n0))

    N = np.where(N == maxiter-1, 0, N)   # N[N == maxiter-1] = 0
    return Z, N

# Normalized recount as explained in:
# http://linas.org/art-gallery/escape/smooth.html


# plot a nice figure
def visualize(Z, N, horizon, xn, yn):
    """Render the fractal, save it to ``mandelbrot.png`` and show it.

    Parameters
    ----------
    Z : complex array of final iterates.
    N : integer array of escape iteration counts, same shape as ``Z``.
    horizon : escape radius used for the computation; must exceed 1.
    xn, yn : grid size, used only to set the figure's aspect ratio.

    NOTE: the image extent is read from the module-level names ``xmin``,
    ``xmax``, ``ymin`` and ``ymax``; they must be defined before calling.
    """
    # Normalized recount as explained in:
    # http://linas.org/art-gallery/escape/smooth.html
    # The offset is log2(log(horizon)), matching the log2(log|Z|) term below.
    log_horizon = math.log(math.log(horizon), 2)
    M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2) + log_horizon)

    dpi = 72
    width = 10
    height = 10*yn/xn

    fig = plt.figure(figsize=(width, height), dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1)

    light = colors.LightSource(azdeg=315, altdeg=10)

    shaded = light.shade(M, cmap=plt.cm.hot, vert_exag=1.5,
                         norm=colors.PowerNorm(0.3), blend_mode='hsv')
    plt.imshow(shaded,
               extent=[xmin, xmax, ymin, ymax], interpolation="bicubic")
    ax.set_xticks([])
    ax.set_yticks([])
    plt.savefig("mandelbrot.png")
    plt.show()



if __name__ == '__main__':
    # start up
    # Region of the complex plane and grid resolution. The lower bound is
    # bound to *min and the upper bound to *max (they were swapped before).
    xmin, xmax, xn = -2.25, 0.75, 3000 // 2
    ymin, ymax, yn = -1.25, 1.25, 2500 // 2

    maxiter = 200
    horizon = 2**10

    # time numpy
    start_time = time.time()
    Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)
    end_time = time.time()
    numpy_time = end_time - start_time
    print("\n\nnumpy: elapsed=", numpy_time)

    # compile, warm up, time
    for _ in range(3):
        mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)

    # measure
    start_time = time.time()
    Z, N = mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter)
    end_time = time.time()
    compiled_time = end_time - start_time
    print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time)

    # Visualization
    # mandelbrot_c returns (real, imag) in a trailing axis; rebuild a complex
    # array for plotting.
    Z = Z[..., 0] + 1j*Z[..., 1]
    visualize(Z, N, horizon, xn, yn)


74 changes: 74 additions & 0 deletions e2e/mandelbrot/mandelbrot_eager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -----------------------------------------------------------------------------
# From Numpy to Python
# Copyright (2017) Nicolas P. Rougier - BSD license
# More information at https://github.com/rougier/numpy-book
# -----------------------------------------------------------------------------
#import numpy as np
import torch_np as np


# To run on CUDA, change "cpu" to "cuda" below.
import torch
torch.set_default_device("cpu")


# from mandelbrot_numpy_1 import mandelbrot # copy-paste below

def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0):
    """Iterate ``z -> z**2 + c`` on an ``yn`` x ``xn`` grid of the complex plane.

    Parameters
    ----------
    xmin, xmax, xn : range and number of samples along the real axis.
    ymin, ymax, yn : range and number of samples along the imaginary axis.
    maxiter : number of iterations to perform.
    horizon : escape radius; a point stops being updated once ``|z|``
        is no longer below it.

    Returns
    -------
    ``(Z, N)``: the complex64 array of final iterates and the integer array
    holding, per point, the index of the last iteration at which it was
    still inside the horizon (reset to 0 for points that never escaped).
    """
    # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/...
    #  .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en
    X = np.linspace(xmin, xmax, xn, dtype=np.float32)
    Y = np.linspace(ymin, ymax, yn, dtype=np.float32)
    C = X + Y[:, None]*1j
    N = np.zeros(C.shape, dtype=int)
    Z = np.zeros(C.shape, np.complex64)
    for n in range(maxiter):
        # Only points that have not escaped yet keep iterating.
        I = np.less(abs(Z), horizon)
        N[I] = n
        Z[I] = Z[I]**2 + C[I]
    # Points still inside after the last iteration are treated as members
    # of the set and get a count of 0.
    N[N == maxiter-1] = 0
    return Z, N


if __name__ == '__main__':
    from matplotlib import colors
    import matplotlib.pyplot as plt
    ## from timeit import timeit

    # Benchmark
    xmin, xmax, xn = -2.25, +0.75, int(3000/3)
    ymin, ymax, yn = -1.25, +1.25, int(2500/3)
    maxiter = 200
    ## timeit("mandelbrot_1(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
    ## timeit("mandelbrot_2(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())
    ## timeit("mandelbrot_3(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals())

    # Visualization
    xmin, xmax, xn = -2.25, +0.75, int(3000/2)
    ymin, ymax, yn = -1.25, +1.25, int(2500/2)
    maxiter = 200
    horizon = 2.0 ** 40
    log_horizon = np.log(np.log(horizon))/np.log(2)
    Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon)

    # Normalized recount as explained in:
    # http://linas.org/art-gallery/escape/smooth.html
    M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2) + log_horizon)

    dpi = 72
    width = 10
    height = 10*yn/xn

    fig = plt.figure(figsize=(width, height), dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1)

    light = colors.LightSource(azdeg=315, altdeg=10)

    # M is a torch_np array here: unwrap the underlying tensor and move it
    # to host memory before handing it to matplotlib.
    plt.imshow(light.shade(M.tensor.cpu().numpy(), cmap=plt.cm.hot, vert_exag=1.5,
                           norm=colors.PowerNorm(0.3), blend_mode='hsv'),
               extent=[xmin, xmax, ymin, ymax], interpolation="bicubic")
    ax.set_xticks([])
    ax.set_yticks([])
    plt.savefig("mandelbrot.png")
    plt.show()

Loading