Skip to content

Commit

Permalink
Merge pull request #17 from formbay/min_points_error
Browse files Browse the repository at this point in the history
Fixed error when n of points is less than self.min_points
  • Loading branch information
carsonfarmer authored May 18, 2021
2 parents b8b4d30 + e6fb05f commit d3170fd
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 21 deletions.
17 changes: 9 additions & 8 deletions fastpair/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

class attrdict(dict):
"""Simple dict with support for accessing elements as attributes."""

def __init__(self, *args, **kwargs):
super(attrdict, self).__init__(*args, **kwargs)
self.__dict__ = self
Expand All @@ -53,6 +54,7 @@ def __init__(self, *args, **kwargs):
class FastPair(object):
"""FastPair 'sketch' class.
"""

def __init__(self, min_points=10, dist=dist.euclidean):
"""Initialize an empty FastPair data-structure.
Expand Down Expand Up @@ -145,11 +147,11 @@ def build(self, points=None):
np = len(self)

# Go through and find all neighbors, placing then in a conga line
for i in range(np-1):
for i in range(np - 1):
# Find neighbor to p[0] to start
nbr = i + 1
nbd = float("inf")
for j in range(i+1, np):
for j in range(i + 1, np):
d = self.dist(self.points[i], self.points[j])
if d < nbd:
nbr = j
Expand All @@ -161,8 +163,8 @@ def build(self, points=None):
self.points[i + 1] = self.neighbors[self.points[i]].neigh
# No more neighbors, terminate conga line.
# Last person on the line has no neigbors :(
self.neighbors[self.points[np-1]].neigh = self.points[np-1]
self.neighbors[self.points[np-1]].dist = float("inf")
self.neighbors[self.points[np - 1]].neigh = self.points[np - 1]
self.neighbors[self.points[np - 1]].dist = float("inf")
self.initialized = True
return self

Expand All @@ -187,7 +189,7 @@ def closest_pair(self):

def closest_pair_brute_force(self):
"""Find closest pair using brute-force algorithm."""
return _closest_pair_brute_force(self.points)
return _closest_pair_brute_force(self.points, self.dist)

def sdist(self, p):
"""Compute distances from input to all other points in data-structure.
Expand All @@ -204,8 +206,7 @@ def sdist(self, p):
>>> fp = FastPair().build(points)
>>> min(fp.sdist(point), key=itemgetter(0))
"""
return ((self.dist(a, b), b) for a, b in
zip(cycle([p]), self.points) if b != a)
return ((self.dist(a, b), b) for a, b in zip(cycle([p]), self.points) if b != a)

def _find_neighbor(self, p):
"""Find and update nearest neighbor of a given point."""
Expand All @@ -221,7 +222,7 @@ def _find_neighbor(self, p):
self.neighbors[p].neigh = self.points[first_nbr]
self.neighbors[p].dist = self.dist(p, self.neighbors[p].neigh)
# Now test whether each other point is closer
for q in self.points[first_nbr+1:]:
for q in self.points[first_nbr + 1 :]:
if p != q:
d = self.dist(p, q)
if d < self.neighbors[p].dist:
Expand Down
115 changes: 102 additions & 13 deletions fastpair/test/test_fastpair.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,38 @@
from math import isinf, isnan

from scipy import mean, array, unique
import numpy as np


def normalized_distance(_a, _b):
b = _b.astype(int)
a = _a.astype(int)
norm_diff = np.linalg.norm(b - a)
norm1 = np.linalg.norm(b)
norm2 = np.linalg.norm(a)
return norm_diff / (norm1 + norm2)


def image_distance(image1, image2):
(sig1, _) = image1
(sig2, _) = image2
sig1 = np.frombuffer(sig1, np.int8)
sig2 = np.frombuffer(sig2, np.int8)
return normalized_distance(sig1, sig2)


# Setup fixtures
@pytest.fixture(scope="module")
def image_array():
return [
(b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x00\x00", "0"),
(b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x00\x00", "1"),
(b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x08\x00", "2"),
(b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x60\x00", "3"),
(b"\x00\x00\x07\x20\x00\x00\x03\x21\x08\x02\x00\x30\x01", "4"),
(b"\x00\x50\x07\x60\x00\x00\x03\x21\x06\x02\x00\x00\x10", "5"),
]


def contains_same(s, t):
s, t = set(s), set(t)
Expand All @@ -29,9 +61,11 @@ def contains_same(s, t):

def all_close(s, t, tol=1e-8):
# Ignores inf and nan values...
return all(abs(a - b) < tol for a, b in zip(s, t)
if not isinf(a) and not isinf(b) and
not isnan(a) and not isnan(b))
return all(
abs(a - b) < tol
for a, b in zip(s, t)
if not isinf(a) and not isinf(b) and not isnan(a) and not isnan(b)
)


def rand_tuple(dim=2):
Expand Down Expand Up @@ -89,7 +123,7 @@ def test_sub(self, PointSet):
assert end["neigh"] != ps[-1]
# This is risky, because it might legitimately be the same...?
assert start["dist"] != end["dist"]
assert len(fp) == len(ps)-1
assert len(fp) == len(ps) - 1
with pytest.raises(ValueError):
fp -= rand_tuple(len(ps[0]))

Expand Down Expand Up @@ -124,7 +158,10 @@ def test_all_closest_pairs(self, PointSet):
# dc = fp.closest_pair_divide_conquer() # Maybe different ordering
assert abs(cp[0] - bf[0]) < 1e-8
assert cp[1] == bf[1] # Tuple comparison
test = min([(fp.dist(a, b), (a, b)) for a, b in combinations(ps, r=2)], key=itemgetter(0))
test = min(
[(fp.dist(a, b), (a, b)) for a, b in combinations(ps, r=2)],
key=itemgetter(0),
)
assert abs(cp[0] - test[0]) < 1e-8
assert sorted(cp[1]) == sorted(test[1]) # Tuple comparison
# assert abs(dc[0] - cp[0]) < 1e-8 # Compare distance
Expand All @@ -137,7 +174,7 @@ def test_find_neighbor_and_sdist(self, PointSet):
rando = rand_tuple(len(ps[0]))
neigh = fp._find_neighbor(rando) # Abusing find_neighbor!
dist = fp.dist(rando, neigh["neigh"])
assert abs(dist - neigh["dist"]) < 1e-8
assert abs(dist - neigh["dist"]) < 1e-8
assert len(fp) == len(ps) # Make sure we didn't add a point...
l = [(fp.dist(a, b), b) for a, b in zip(cycle([rando]), ps)]
res = min(l, key=itemgetter(0))
Expand All @@ -150,16 +187,18 @@ def test_find_neighbor_and_sdist(self, PointSet):
def test_cluster(self, PointSet):
ps = PointSet
fp = FastPair().build(ps)
for i in range(len(fp)-1):
for i in range(len(fp) - 1):
# Version one
dist, (a, b) = fp.closest_pair()
c = interact(a, b)
fp -= b # Drop b
fp -= a
fp += c
# Order gets reversed here...
d, (e, f) = min([(fp.dist(i, j), (i, j)) for i, j in
combinations(ps, r=2)], key=itemgetter(0))
d, (e, f) = min(
[(fp.dist(i, j), (i, j)) for i, j in combinations(ps, r=2)],
key=itemgetter(0),
)
g = interact(e, f)
assert abs(d - dist) < 1e-8
assert (a == e or b == e) and (b == f or a == f)
Expand All @@ -184,8 +223,8 @@ def test_update_point(self, PointSet):
l = [(fp.dist(a, b), b) for a, b in zip(cycle([new]), ps)]
res = min(l, key=itemgetter(0))
neigh = fp.neighbors[new]
#assert abs(res[0] - neigh["dist"]) < 1e-8
#assert res[1] == neigh["neigh"]
# assert abs(res[0] - neigh["dist"]) < 1e-8
# assert res[1] == neigh["neigh"]

def test_merge_closest(self):
# This needs to be 'fleshed' out more... lots of things to test here
Expand All @@ -201,10 +240,60 @@ def test_merge_closest(self):
fp._update_point(a, new)
n -= 1
assert len(fp) == 1 == n
points = [(0.69903599809571437, 0.52457534006594131,
0.7614753848101149, 0.37011695654655385)]
points = [
(
0.69903599809571437,
0.52457534006594131,
0.7614753848101149,
0.37011695654655385,
)
]
assert all_close(fp.points[0], points[0])
# Should have < 2 points now...
with pytest.raises(ValueError):
fp.closest_pair()
# fp2.closest_pair()

def test_call_and_closest_pair_min_points(self, image_array):
ps = image_array
fp = FastPair(dist=image_distance)
for p in ps:
fp += p
assert fp.initialized is False
assert len(fp) == 6
cp = fp.closest_pair()
bf = fp.closest_pair_brute_force()
assert fp() == cp
assert abs(cp[0] - bf[0]) < 1e-8
assert cp[1] == bf[1]

def test_iter(self, PointSet):
ps = PointSet
fp = FastPair().build(ps)
assert fp.min_points == 10
assert isinstance(fp.dist, FunctionType)
my_iter = iter(fp)
assert next(my_iter) in set(ps)
assert fp[ps[0]].neigh in set(ps)

try:
myitem = fp[(2, 3, 4)]
except KeyError as err:
print(err)

fp[ps[0]] = fp[ps[0]].neigh
try:
fp[(2, 3, 4)] = fp[ps[0]].neigh
except KeyError as err:
print(err)

def test_update_point_less_points(self, PointSet):
ps = PointSet
fp = FastPair()
for p in ps[:9]:
fp += p
assert fp.initialized is False
old = ps[0] # Just grab the first point...
new = rand_tuple(len(ps[0]))
res = fp._update_point(old, new)
assert len(fp) == 1

0 comments on commit d3170fd

Please sign in to comment.