Skip to content

Commit

Permalink
e2e: kmeans: remove a workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Jul 21, 2023
1 parent 12be7fb commit 58df095
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions e2e/kmeans/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@
# https://realpython.com/numpy-array-programming/#clustering-algorithms
import numpy as np
import torch
torch.set_default_device("cpu")
torch.set_default_device("cuda")
import torch._dynamo.config as cfg
cfg.numpy_ndarray_as_tensor = True


# np.linalg.norm replacement (2-norm only), https://github.com/pytorch/pytorch/issues/105269
def norm(a, axis):
s = (a.conj() * a).real
return np.sqrt(s.sum(axis=axis))


#@torch.compile
# this will be compiled
def get_labels(X, centroids) -> np.ndarray:
return np.argmin(norm(X - centroids[:, None], axis=2),
return np.argmin(np.linalg.norm(X - centroids[:, None, :], ord=2, axis=2),
axis=0)


Expand All @@ -31,7 +25,7 @@ def init(npts):
import time

# ### numpy ###
npts = int(2e7)
npts = int(1e8)
X, centroids = init(npts)

start_time = time.time()
Expand All @@ -53,6 +47,8 @@ def init(npts):
start_time = time.time()
labels = get_labels_c(X, centroids)
end_time = time.time()
torch.cuda.synchronize()
compiled_time = end_time - start_time
print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time)


0 comments on commit 58df095

Please sign in to comment.