Skip to content

Commit

Permalink
get rid of the poorly architected list_distance option
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo Dell'Amico committed Jan 29, 2018
1 parent 687759b commit 5125346
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 55 deletions.
42 changes: 8 additions & 34 deletions flexible_clustering/fishdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class FISHDBC:
"""Flexible Incremental Scalable Hierarchical Density-Based Clustering."""

def __init__(self, d, min_samples=5, m=5, ef=200, m0=None, level_mult=None,
heuristic=True, balanced_add=True, list_distance=False):
heuristic=True, balanced_add=True):
"""Setup the algorithm. The only mandatory parameter is d, the
dissimilarity function. min_samples is passed to hdbscan, and
the other parameters are all passed to HNSW."""
Expand Down Expand Up @@ -92,41 +92,15 @@ def __init__(self, d, min_samples=5, m=5, ef=200, m0=None, level_mult=None,
self._distance_cache = distance_cache = {}

# decorated_d will cache the computed distances in distance_cache.
if not list_distance: # d is defined to work on scalars
def decorated_d(i, j):
# assert i == len(data) - 1 # 1st argument is the new item
if j in distance_cache:
return distance_cache[j]
distance_cache[j] = dist = d(data[i], data[j])
return dist
if list_distance: # d is defined to work on a scalar and a list
def decorated_d(i, js):
assert i == len(data) - 1 # 1st argument is the new item
known = []
unknown_j, unknown_items = [], []
for pos, j in enumerate(js):
k = j in distance_cache
known.append(k)
if not k:
unknown_j.append(j)
unknown_items.append(k)
new_d = d(data[i], unknown_items)
for j, dist in zip(unknown_j, new_d):
distance_cache[j] = dist
old_d = (distance_cache[j] for j, k in zip(js, known) if k)
new_d = iter(new_d)

res = []
for k in known:
if k:
res.append(next(old_d))
else:
res.append(next(new_d))
return res
def decorated_d(i, j):
# assert i == len(data) - 1 # 1st argument is the new item
if j in distance_cache:
return distance_cache[j]
distance_cache[j] = dist = d(data[i], data[j])
return dist

# We create the HNSW
the_hnsw = hnsw.HNSW(decorated_d, m, ef, m0, level_mult, heuristic,
list_distance)
the_hnsw = hnsw.HNSW(decorated_d, m, ef, m0, level_mult, heuristic)
self._hnsw_add = (the_hnsw.balanced_add if balanced_add
else the_hnsw.add)

Expand Down
35 changes: 14 additions & 21 deletions flexible_clustering/hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,14 @@ class HNSW(object):
# where j is a neighbor of i and dist is distance

def __init__(self, d, m=5, ef=200, m0=None, level_mult=None,
heuristic=True, list_distance=False):
heuristic=True):
"""d the dissimilarity function
See other parameters in http://arxiv.org/pdf/1603.09320v2.pdf"""

self.data = []

if list_distance:
def d1(x, y):
return d(x, [y])[0]
self.distance = d1
self.list_distance = d
else:
self.distance = d
def vd(x, ys):
return [d(x, y) for y in ys]
self.list_distance = vd
self.distance = d

self._m = m
self._ef = ef
Expand Down Expand Up @@ -219,7 +210,7 @@ def search(self, q, k=None, ef=None):
def _search_graph_ef1(self, q, entry, dist, g):
"""Equivalent to _search_graph when ef=1."""

ld = self.list_distance
d = self.distance
data = self.data

best = entry
Expand All @@ -231,10 +222,11 @@ def _search_graph_ef1(self, q, entry, dist, g):
dist, c = heappop(candidates)
if dist > best_dist:
break
edges = [e for e in g[c] if e not in visited]
visited.update(edges)
dists = ld(q, [data[e] for e in edges])
for e, dist in zip(edges, dists):
for e in g[c]:
if e in visited:
continue
visited.add(e)
dist = d(q, data[e])
if dist < best_dist:
best = e
best_dist = dist
Expand All @@ -245,7 +237,7 @@ def _search_graph_ef1(self, q, entry, dist, g):

def _search_graph(self, q, ep, g, ef):

ld = self.list_distance
d = self.distance
data = self.data

candidates = [(-mdist, p) for mdist, p in ep]
Expand All @@ -259,10 +251,11 @@ def _search_graph(self, q, ep, g, ef):
if dist > ref:
break

edges = [e for e in g[c] if e not in visited]
visited.update(edges)
dists = ld(q, [data[e] for e in edges])
for e, dist in zip(edges, dists):
for e in g[c]:
if e in visited:
continue
visited.add(e)
dist = d(q, data[e])
mdist = -dist
if len(ep) < ef:
heappush(candidates, (dist, e))
Expand Down

0 comments on commit 5125346

Please sign in to comment.