From bf19c9cac473425041aa1144bc9dea37ed1c3943 Mon Sep 17 00:00:00 2001 From: Carson Farmer Date: Sun, 3 Jul 2016 17:54:11 -0600 Subject: [PATCH] Simplify and move API around MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No substantial changes here… --- fastpair/__init__.py | 2 +- fastpair/base.py | 85 ++++++++++++++++------------------ fastpair/test/test_fastpair.py | 46 +++++++----------- 3 files changed, 59 insertions(+), 74 deletions(-) diff --git a/fastpair/__init__.py b/fastpair/__init__.py index 9b5d511..fb2ddc3 100644 --- a/fastpair/__init__.py +++ b/fastpair/__init__.py @@ -10,4 +10,4 @@ # Copyright (c) 2002-2015, David Eppstein # Licensed under the MIT Licence (http://opensource.org/licenses/MIT). -from .base import FastPair, interact +from .base import FastPair diff --git a/fastpair/base.py b/fastpair/base.py index 1bc7a37..f5c8749 100644 --- a/fastpair/base.py +++ b/fastpair/base.py @@ -38,29 +38,22 @@ from operator import itemgetter from collections import defaultdict import scipy.spatial.distance as dist -from scipy import mean as _mean, array as _array -__all__ = ["interact", "FastPair", "dist", "default_dist"] -default_dist = dist.euclidean +__all__ = ["FastPair", "dist"] -def interact(u, v): - """Compute element-wise mean(s) from two arrays.""" - return tuple(_mean(_array([u, v]), axis=0)) - - -class _adict(dict): +class attrdict(dict): """Simple dict with support for accessing elements as attributes.""" def __init__(self, *args, **kwargs): - super(_adict, self).__init__(*args, **kwargs) + super(attrdict, self).__init__(*args, **kwargs) self.__dict__ = self class FastPair(object): """FastPair 'sketch' class. """ - def __init__(self, min_points=10, dist=default_dist, merge=interact): + def __init__(self, min_points=10, dist=dist.euclidean): """Initialize an empty FastPair data-structure. 
Parameters @@ -75,19 +68,11 @@ def __init__(self, min_points=10, dist=default_dist, merge=interact): from `scipy.spatial.distance` will do the trick. By default, the Euclidean distance function is used. This function should play nicely with the `merge` function. - merge : function, default=scipy.mean - Can be any Python function that returns a single 'point' from two - input 'points'. By default, the element-wise mean(s) from two input - point arrays is used. If a user has a 'special' point class; for - example, one that represents cluster centroids, then the user can - specify a function that returns valid clusters. This function - should play nicely with the `dist` function. """ self.min_points = min_points self.dist = dist - self.merge = merge self.initialized = False # Has the data-structure been initialized? - self.neighbors = defaultdict(_adict) # Dict of neighbor points and dists + self.neighbors = defaultdict(attrdict) # Dict of neighbor points and dists self.points = list() # Internal point set; entries may be non-unique def __add__(self, p): @@ -129,6 +114,16 @@ def __contains__(self, p): def __iter__(self): return iter(self.points) + def __getitem__(self, item): + if not item in self: + raise KeyError("{} not found".format(item)) + return self.neighbors[item] + + def __setitem__(self, item, value): + if not item in self: + raise KeyError("{} not found".format(item)) + self._update_point(item, value) + def build(self, points=None): """Build a FastPairs data-structure from a set of (new) points. 
@@ -179,7 +174,7 @@ def closest_pair(self): """ if len(self) < 2: raise ValueError("Must have `npoints >= 2` to form a pair.") - elif len(self) < self.min_points: + elif not self.initialized: return self.closest_pair_brute_force() a = self.points[0] # Start with first point d = self.neighbors[a].dist @@ -194,6 +189,24 @@ def closest_pair_brute_force(self): """Find closest pair using brute-force algorithm.""" return _closest_pair_brute_force(self.points) + def sdist(self, p): + """Compute distances from input to all other points in data-structure. + + This returns an iterator over all other points and their distance + from the input point `p`. The resulting iterator returns tuples with + the first item giving the distance, and the second item giving the + neighbor point. The `min` of this iterator is essentially a brute- + force 'nearest-neighbor' calculation. To do this, supply `itemgetter` + (or a lambda version) as the `key` argument to `min`. + + Examples + -------- + >>> fp = FastPair().build(points) + >>> min(fp.sdist(point), key=itemgetter(0)) + """ + return ((self.dist(a, b), b) for a, b in + zip(cycle([p]), self.points) if b != a) + def _find_neighbor(self, p): """Find and update nearest neighbor of a given point.""" # If no neighbors available, set flag for `_update_point` to find @@ -216,12 +229,6 @@ def _find_neighbor(self, p): self.neighbors[p].neigh = q return dict(self.neighbors[p]) # Return plain ol' dict - def merge_closest(self): - dist, (a, b) = self.closest_pair() - c = self.merge(a, b) - self -= b - return self._update_point(a, c) - def _update_point(self, old, new): """Update point location, neighbors, and distances. @@ -255,26 +262,14 @@ def _update_point(self, old, new): self.neighbors[q].dist = d return dict(self.neighbors[new]) - def sdist(self, p): - """Compute distances from input to all other points in data-structure. - - This returns an iterator over all other points and their distance - from the input point `p`. 
The resulting iterator returns tuples with - the first item giving the distance, and the second item giving in - neighbor point. The `min` of this iterator is essentially a brute- - force 'nearest-neighbor' calculation. To do this, supply `itemgetter` - (or a lambda version) as the `key` argument to `min`. - - Examples - -------- - >>> fp = FastPair().build(points) - >>> min(fp.sdist(point), key=itemgetter(0)) - """ - return ((self.dist(a, b), b) for a, b in - zip(cycle([p]), self.points) if b != a) + # def merge_closest(self): + # dist, (a, b) = self.closest_pair() + # c = self.merge(a, b) + # self -= b + # return self._update_point(a, c) -def _closest_pair_brute_force(pts, dst=default_dist): +def _closest_pair_brute_force(pts, dst=dist.euclidean): """Compute closest pair of points using brute-force algorithm. Notes diff --git a/fastpair/test/test_fastpair.py b/fastpair/test/test_fastpair.py index 9a05cd9..e8b063a 100644 --- a/fastpair/test/test_fastpair.py +++ b/fastpair/test/test_fastpair.py @@ -17,7 +17,7 @@ from itertools import cycle, combinations, groupby import random import pytest -from fastpair import FastPair, interact +from fastpair import FastPair from math import isinf, isnan from scipy import mean, array, unique @@ -38,13 +38,9 @@ def rand_tuple(dim=2): return tuple([random.random() for _ in range(dim)]) -def to_codebook(X, part): - """Calculates centroids according to flat cluster assignment.""" - codebook = [] - X = array(X) - for i in unique(part): - codebook.append(tuple(X[part == i].mean(0))) - return codebook +def interact(u, v): + """Compute element-wise mean(s) from two arrays.""" + return tuple(mean(array([u, v]), axis=0)) # Setup fixtures @@ -192,29 +188,23 @@ def test_update_point(self): assert res[1] == neigh["neigh"] def test_merge_closest(self): - # Still failing sometimes... - ps = PointSet() - fp1 = FastPair().build(ps) - fp2 = FastPair().build(ps) + # This needs to be 'fleshed' out more... 
lots of things to test here + random.seed(1234) + ps = PointSet(d=4) + fp = FastPair().build(ps) + # fp2 = FastPair().build(ps) n = len(ps) while n >= 2: - dist, (a, b) = fp1.closest_pair() + dist, (a, b) = fp.closest_pair() new = interact(a, b) - fp1 -= b # Drop b - fp1._update_point(a, new) - fp2.merge_closest() + fp -= b # Drop b + fp._update_point(a, new) n -= 1 - assert len(fp1) == len(fp2) == 1 # == len(fp2) - assert fp1.points == fp2.points # == fp2.points - # Compare points - assert contains_same(list(fp1.neighbors.keys()), list(fp2.neighbors.keys())) - # Compare neighbors - assert contains_same([n["neigh"] for n in fp1.neighbors.values()], - [n["neigh"] for n in fp2.neighbors.values()]) - # Compare dists - assert all_close([n["dist"] for n in fp1.neighbors.values()], - [n["dist"] for n in fp2.neighbors.values()]) + assert len(fp) == 1 == n + points = [(0.69903599809571437, 0.52457534006594131, + 0.7614753848101149, 0.37011695654655385)] + assert all_close(fp.points[0], points[0]) # Should have < 2 points now... with pytest.raises(ValueError): - fp1.closest_pair() - fp2.closest_pair() + fp.closest_pair() + # fp2.closest_pair()