Skip to content

Commit

Permalink
Merge pull request #431 from Jeadie/main
Browse files Browse the repository at this point in the history
Documentation, Typing and Refactors
  • Loading branch information
erikbern authored Jul 17, 2023
2 parents 8799a82 + 55f3dd3 commit 61378f4
Show file tree
Hide file tree
Showing 11 changed files with 1,032 additions and 454 deletions.
75 changes: 59 additions & 16 deletions ann_benchmarks/algorithms/base/module.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,79 @@
from multiprocessing.pool import ThreadPool

from typing import Any, Dict, Optional
import psutil

import numpy

class BaseANN(object):
def done(self):
"""Base class/interface for Approximate Nearest Neighbors (ANN) algorithms used in benchmarking."""

def done(self) -> None:
"""Clean up BaseANN once it is finished being used."""
pass

def get_memory_usage(self):
"""Return the current memory usage of this algorithm instance
(in kilobytes), or None if this information is not available."""
# return in kB for backwards compatibility
def get_memory_usage(self) -> Optional[float]:
"""Returns the current memory usage of this ANN algorithm instance in kilobytes.
Returns:
float: The current memory usage in kilobytes (for backwards compatibility), or None if
this information is not available.
"""

return psutil.Process().memory_info().rss / 1024

def fit(self, X):
def fit(self, X: numpy.array) -> None:
"""Fits the ANN algorithm to the provided data.
Note: This is a placeholder method to be implemented by subclasses.
Args:
X (numpy.array): The data to fit the algorithm to.
"""
pass

def query(self, q, n):
def query(self, q: numpy.array, n: int) -> numpy.array:
"""Performs a query on the algorithm to find the nearest neighbors.
Note: This is a placeholder method to be implemented by subclasses.
Args:
q (numpy.array): The vector to find the nearest neighbors of.
n (int): The number of nearest neighbors to return.
Returns:
numpy.array: An array of indices representing the nearest neighbors.
"""
return [] # array of candidate indices

def batch_query(self, X, n):
"""Provide all queries at once and let algorithm figure out
how to handle it. Default implementation uses a ThreadPool
to parallelize query processing."""
def batch_query(self, X: numpy.array, n: int) -> None:
"""Performs multiple queries at once and lets the algorithm figure out how to handle it.
The default implementation uses a ThreadPool to parallelize query processing.
Args:
X (numpy.array): An array of vectors to find the nearest neighbors of.
n (int): The number of nearest neighbors to return for each query.
Returns:
None: self.get_batch_results() is responsible for retrieving batch result
"""
pool = ThreadPool()
self.res = pool.map(lambda q: self.query(q, n), X)

def get_batch_results(self):
def get_batch_results(self) -> numpy.array:
"""Retrieves the results of a batch query (from .batch_query()).
Returns:
numpy.array: An array of nearest neighbor results for each query in the batch.
"""
return self.res

def get_additional(self):
def get_additional(self) -> Dict[str, Any]:
"""Returns additional attributes to be stored with the result.
Returns:
dict: A dictionary of additional attributes.
"""
return {}

def __str__(self):
return self.name
def __str__(self) -> str:
return self.name
8 changes: 4 additions & 4 deletions ann_benchmarks/algorithms/bruteforce/module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy
import sklearn.neighbors

from ...distance import metrics as pd
from ...distance import compute_distance, metrics as pd
from ..base.module import BaseANN

class BruteForce(BaseANN):
Expand Down Expand Up @@ -87,17 +87,17 @@ def query_with_distances(self, v, n):
# Just compute hamming distance using euclidean distance
dists = self.lengths - 2 * numpy.dot(self.index, v)
elif self._metric == "jaccard":
dists = [pd[self._metric]["distance"](v, e) for e in self.index]
dists = [pd[self._metric].distance(v, e) for e in self.index]
else:
# shouldn't get past the constructor!
assert False, "invalid metric"
# partition-sort by distance, get `n` closest
nearest_indices = numpy.argpartition(dists, n)[:n]
indices = [idx for idx in nearest_indices if pd[self._metric]["distance_valid"](dists[idx])]
indices = [idx for idx in nearest_indices if pd[self._metric].distance_valid(dists[idx])]

def fix(index):
ep = self.index[index]
ev = v
return (index, pd[self._metric]["distance"](ep, ev))
return (index, pd[self._metric].distance(ep, ev))

return map(fix, indices)
2 changes: 1 addition & 1 deletion ann_benchmarks/algorithms/qdrant/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ float:
[ 8, 16, 24, 32, 40, 48, 64, 72 ], #m
[ 64, 128, 256, 512 ], #ef_construct
]
query-args: [
query_args: [
[null, 8, 16, 32, 64, 128, 256, 512, 768], #hnsw_ef
[True, False], # re-score
]
32 changes: 8 additions & 24 deletions ann_benchmarks/data.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,25 @@
import numpy
from typing import List, Union, Set, FrozenSet


def float_parse_entry(line):
def float_parse_entry(line: str) -> List[float]:
return [float(x) for x in line.strip().split()]


def float_unparse_entry(entry):
def float_unparse_entry(entry: List[float]) -> str:
return " ".join(map(str, entry))


def int_parse_entry(line):
def int_parse_entry(line: str) -> FrozenSet[int]:
return frozenset([int(x) for x in line.strip().split()])


def int_unparse_entry(entry):
def int_unparse_entry(entry: Union[Set[int], FrozenSet[int]]) -> str:
return " ".join(map(str, map(int, entry)))


def bit_parse_entry(line):
def bit_parse_entry(line: str) -> List[bool]:
return [bool(int(x)) for x in list(line.strip().replace(" ", "").replace("\t", ""))]


def bit_unparse_entry(entry):
return " ".join(map(lambda el: "1" if el else "0", entry))


type_info = {
"float": {
"type": numpy.float,
"parse_entry": float_parse_entry,
"unparse_entry": float_unparse_entry,
"finish_entries": numpy.vstack,
},
"bit": {"type": numpy.bool_, "parse_entry": bit_parse_entry, "unparse_entry": bit_unparse_entry},
"int": {
"type": numpy.object,
"parse_entry": int_parse_entry,
"unparse_entry": int_unparse_entry,
},
}
def bit_unparse_entry(entry: List[bool]) -> str:
return " ".join(map(lambda el: "1" if el else "0", entry))
Loading

0 comments on commit 61378f4

Please sign in to comment.