diff --git a/e2e/mandelbrot/mandelbrot.png b/e2e/mandelbrot/mandelbrot.png index 840635f0..0a5aa621 100644 Binary files a/e2e/mandelbrot/mandelbrot.png and b/e2e/mandelbrot/mandelbrot.png differ diff --git a/e2e/mandelbrot/mandelbrot.py b/e2e/mandelbrot/mandelbrot.py index 874fdac8..6264ba2e 100644 --- a/e2e/mandelbrot/mandelbrot.py +++ b/e2e/mandelbrot/mandelbrot.py @@ -59,11 +59,13 @@ def sq2(a): return z -@torch.compile -def step(n, c, Z, N, horizon): - I = abs2(Z) < horizon**2 - N = np.where(I, n, N) # N[I] = n - Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I] +@torch.compile(dynamic=True) +def step(n0, c, Z, N, horizon, chunksize): + for j in range(chunksize): + n = n0 + j + I = abs2(Z) < horizon**2 + N = np.where(I, n, N) # N[I] = n + Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I] return Z, N @@ -75,8 +77,12 @@ def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5): N = np.zeros(c.shape[:-1], dtype='int') Z = np.zeros_like(c, dtype='float32') - for n in range(maxiter): - Z, N = step(n, c, Z, N, horizon) + chunksize=10 + n_chunks = maxiter // chunksize + + for i_chunk in range(n_chunks): + n0 = i_chunk*chunksize + Z, N = step(n0, c, Z, N, horizon, chunksize) N = np.where(N == maxiter-1, 0, N) # N[N == maxiter-1] = 0 return Z, N