diff --git a/fastpair/base.py b/fastpair/base.py index f5c8749..9f73321 100644 --- a/fastpair/base.py +++ b/fastpair/base.py @@ -45,6 +45,7 @@ class attrdict(dict): """Simple dict with support for accessing elements as attributes.""" + def __init__(self, *args, **kwargs): super(attrdict, self).__init__(*args, **kwargs) self.__dict__ = self @@ -53,6 +54,7 @@ def __init__(self, *args, **kwargs): class FastPair(object): """FastPair 'sketch' class. """ + def __init__(self, min_points=10, dist=dist.euclidean): """Initialize an empty FastPair data-structure. @@ -145,11 +147,11 @@ def build(self, points=None): np = len(self) # Go through and find all neighbors, placing then in a conga line - for i in range(np-1): + for i in range(np - 1): # Find neighbor to p[0] to start nbr = i + 1 nbd = float("inf") - for j in range(i+1, np): + for j in range(i + 1, np): d = self.dist(self.points[i], self.points[j]) if d < nbd: nbr = j @@ -161,8 +163,8 @@ def build(self, points=None): self.points[i + 1] = self.neighbors[self.points[i]].neigh # No more neighbors, terminate conga line. # Last person on the line has no neigbors :( - self.neighbors[self.points[np-1]].neigh = self.points[np-1] - self.neighbors[self.points[np-1]].dist = float("inf") + self.neighbors[self.points[np - 1]].neigh = self.points[np - 1] + self.neighbors[self.points[np - 1]].dist = float("inf") self.initialized = True return self @@ -187,7 +189,7 @@ def closest_pair(self): def closest_pair_brute_force(self): """Find closest pair using brute-force algorithm.""" - return _closest_pair_brute_force(self.points) + return _closest_pair_brute_force(self.points, self.dist) def sdist(self, p): """Compute distances from input to all other points in data-structure. @@ -204,8 +206,7 @@ def sdist(self, p): >>> fp = FastPair().build(points) >>> min(fp.sdist(point), key=itemgetter(0)) """ - return ((self.dist(a, b), b) for a, b in - zip(cycle([p]), self.points) if b != a) + return ((self.dist(a, b), b) for a, b in zip(cycle([p]), self.points) if b != a) def _find_neighbor(self, p): """Find and update nearest neighbor of a given point.""" @@ -221,7 +222,7 @@ def _find_neighbor(self, p): self.neighbors[p].neigh = self.points[first_nbr] self.neighbors[p].dist = self.dist(p, self.neighbors[p].neigh) # Now test whether each other point is closer - for q in self.points[first_nbr+1:]: + for q in self.points[first_nbr + 1 :]: if p != q: d = self.dist(p, q) if d < self.neighbors[p].dist: diff --git a/fastpair/test/test_fastpair.py b/fastpair/test/test_fastpair.py index 7331118..6967ff4 100644 --- a/fastpair/test/test_fastpair.py +++ b/fastpair/test/test_fastpair.py @@ -21,6 +21,38 @@ from math import isinf, isnan from scipy import mean, array, unique +import numpy as np + + +def normalized_distance(_a, _b): + b = _b.astype(int) + a = _a.astype(int) + norm_diff = np.linalg.norm(b - a) + norm1 = np.linalg.norm(b) + norm2 = np.linalg.norm(a) + return norm_diff / (norm1 + norm2) + + +def image_distance(image1, image2): + (sig1, _) = image1 + (sig2, _) = image2 + sig1 = np.frombuffer(sig1, np.int8) + sig2 = np.frombuffer(sig2, np.int8) + return normalized_distance(sig1, sig2) + + +# Setup fixtures +@pytest.fixture(scope="module") +def image_array(): + return [ + (b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x00\x00", "0"), + (b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x00\x00", "1"), + (b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x08\x00", "2"), + (b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x60\x00", "3"), + (b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x30\x01", "4"), + (b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x00\x10", "5"), + ] + def contains_same(s, t): s, t = set(s), set(t) @@ -29,9 +61,11 @@ def contains_same(s, t): def all_close(s, t, tol=1e-8): # Ignores inf and nan values... - return all(abs(a - b) < tol for a, b in zip(s, t) - if not isinf(a) and not isinf(b) and - not isnan(a) and not isnan(b)) + return all( + abs(a - b) < tol + for a, b in zip(s, t) + if not isinf(a) and not isinf(b) and not isnan(a) and not isnan(b) + ) def rand_tuple(dim=2): @@ -89,7 +123,7 @@ def test_sub(self, PointSet): assert end["neigh"] != ps[-1] # This is risky, because it might legitimately be the same...? assert start["dist"] != end["dist"] - assert len(fp) == len(ps)-1 + assert len(fp) == len(ps) - 1 with pytest.raises(ValueError): fp -= rand_tuple(len(ps[0])) @@ -124,7 +158,10 @@ def test_all_closest_pairs(self, PointSet): # dc = fp.closest_pair_divide_conquer() # Maybe different ordering assert abs(cp[0] - bf[0]) < 1e-8 assert cp[1] == bf[1] # Tuple comparison - test = min([(fp.dist(a, b), (a, b)) for a, b in combinations(ps, r=2)], key=itemgetter(0)) + test = min( + [(fp.dist(a, b), (a, b)) for a, b in combinations(ps, r=2)], + key=itemgetter(0), + ) assert abs(cp[0] - test[0]) < 1e-8 assert sorted(cp[1]) == sorted(test[1]) # Tuple comparison # assert abs(dc[0] - cp[0]) < 1e-8 # Compare distance @@ -137,7 +174,7 @@ def test_find_neighbor_and_sdist(self, PointSet): rando = rand_tuple(len(ps[0])) neigh = fp._find_neighbor(rando) # Abusing find_neighbor! dist = fp.dist(rando, neigh["neigh"]) - assert abs(dist - neigh["dist"]) < 1e-8 + assert abs(dist - neigh["dist"]) < 1e-8 assert len(fp) == len(ps) # Make sure we didn't add a point... l = [(fp.dist(a, b), b) for a, b in zip(cycle([rando]), ps)] res = min(l, key=itemgetter(0)) @@ -150,7 +187,7 @@ def test_find_neighbor_and_sdist(self, PointSet): def test_cluster(self, PointSet): ps = PointSet fp = FastPair().build(ps) - for i in range(len(fp)-1): + for i in range(len(fp) - 1): # Version one dist, (a, b) = fp.closest_pair() c = interact(a, b) @@ -158,8 +195,10 @@ def test_cluster(self, PointSet): fp -= a fp += c # Order gets reversed here... - d, (e, f) = min([(fp.dist(i, j), (i, j)) for i, j in - combinations(ps, r=2)], key=itemgetter(0)) + d, (e, f) = min( + [(fp.dist(i, j), (i, j)) for i, j in combinations(ps, r=2)], + key=itemgetter(0), + ) g = interact(e, f) assert abs(d - dist) < 1e-8 assert (a == e or b == e) and (b == f or a == f) @@ -184,8 +223,8 @@ def test_update_point(self, PointSet): l = [(fp.dist(a, b), b) for a, b in zip(cycle([new]), ps)] res = min(l, key=itemgetter(0)) neigh = fp.neighbors[new] - #assert abs(res[0] - neigh["dist"]) < 1e-8 - #assert res[1] == neigh["neigh"] + # assert abs(res[0] - neigh["dist"]) < 1e-8 + # assert res[1] == neigh["neigh"] def test_merge_closest(self): # This needs to be 'fleshed' out more... lots of things to test here @@ -201,10 +240,60 @@ def test_merge_closest(self): fp._update_point(a, new) n -= 1 assert len(fp) == 1 == n - points = [(0.69903599809571437, 0.52457534006594131, - 0.7614753848101149, 0.37011695654655385)] + points = [ + ( + 0.69903599809571437, + 0.52457534006594131, + 0.7614753848101149, + 0.37011695654655385, + ) + ] assert all_close(fp.points[0], points[0]) # Should have < 2 points now... with pytest.raises(ValueError): fp.closest_pair() # fp2.closest_pair() + + def test_call_and_closest_pair_min_points(self, image_array): + ps = image_array + fp = FastPair(dist=image_distance) + for p in ps: + fp += p + assert fp.initialized is False + assert len(fp) == 6 + cp = fp.closest_pair() + bf = fp.closest_pair_brute_force() + assert fp() == cp + assert abs(cp[0] - bf[0]) < 1e-8 + assert cp[1] == bf[1] + + def test_iter(self, PointSet): + ps = PointSet + fp = FastPair().build(ps) + assert fp.min_points == 10 + assert isinstance(fp.dist, FunctionType) + my_iter = iter(fp) + assert next(my_iter) in set(ps) + assert fp[ps[0]].neigh in set(ps) + + try: + myitem = fp[(2, 3, 4)] + except KeyError as err: + print(err) + + fp[ps[0]] = fp[ps[0]].neigh + try: + fp[(2, 3, 4)] = fp[ps[0]].neigh + except KeyError as err: + print(err) + + def test_update_point_less_points(self, PointSet): + ps = PointSet + fp = FastPair() + for p in ps[:9]: + fp += p + assert fp.initialized is False + old = ps[0] # Just grab the first point... + new = rand_tuple(len(ps[0])) + res = fp._update_point(old, new) + assert len(fp) == 1