-
Notifications
You must be signed in to change notification settings - Fork 758
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #431 from Jeadie/main
Documentation, Typing and Refactors
- Loading branch information
Showing
11 changed files
with
1,032 additions
and
454 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,79 @@ | ||
from multiprocessing.pool import ThreadPool | ||
|
||
from typing import Any, Dict, Optional | ||
import psutil | ||
|
||
import numpy | ||
|
||
class BaseANN(object): | ||
def done(self): | ||
"""Base class/interface for Approximate Nearest Neighbors (ANN) algorithms used in benchmarking.""" | ||
|
||
def done(self) -> None: | ||
"""Clean up BaseANN once it is finished being used.""" | ||
pass | ||
|
||
def get_memory_usage(self): | ||
"""Return the current memory usage of this algorithm instance | ||
(in kilobytes), or None if this information is not available.""" | ||
# return in kB for backwards compatibility | ||
def get_memory_usage(self) -> Optional[float]: | ||
"""Returns the current memory usage of this ANN algorithm instance in kilobytes. | ||
Returns: | ||
float: The current memory usage in kilobytes (for backwards compatibility), or None if | ||
this information is not available. | ||
""" | ||
|
||
return psutil.Process().memory_info().rss / 1024 | ||
|
||
def fit(self, X): | ||
def fit(self, X: numpy.array) -> None: | ||
"""Fits the ANN algorithm to the provided data. | ||
Note: This is a placeholder method to be implemented by subclasses. | ||
Args: | ||
X (numpy.array): The data to fit the algorithm to. | ||
""" | ||
pass | ||
|
||
def query(self, q, n): | ||
def query(self, q: numpy.array, n: int) -> numpy.array: | ||
"""Performs a query on the algorithm to find the nearest neighbors. | ||
Note: This is a placeholder method to be implemented by subclasses. | ||
Args: | ||
q (numpy.array): The vector to find the nearest neighbors of. | ||
n (int): The number of nearest neighbors to return. | ||
Returns: | ||
numpy.array: An array of indices representing the nearest neighbors. | ||
""" | ||
return [] # array of candidate indices | ||
|
||
def batch_query(self, X, n): | ||
"""Provide all queries at once and let algorithm figure out | ||
how to handle it. Default implementation uses a ThreadPool | ||
to parallelize query processing.""" | ||
def batch_query(self, X: numpy.array, n: int) -> None: | ||
"""Performs multiple queries at once and lets the algorithm figure out how to handle it. | ||
The default implementation uses a ThreadPool to parallelize query processing. | ||
Args: | ||
X (numpy.array): An array of vectors to find the nearest neighbors of. | ||
n (int): The number of nearest neighbors to return for each query. | ||
Returns: | ||
None: self.get_batch_results() is responsible for retrieving batch result | ||
""" | ||
pool = ThreadPool() | ||
self.res = pool.map(lambda q: self.query(q, n), X) | ||
|
||
def get_batch_results(self): | ||
def get_batch_results(self) -> numpy.array: | ||
"""Retrieves the results of a batch query (from .batch_query()). | ||
Returns: | ||
numpy.array: An array of nearest neighbor results for each query in the batch. | ||
""" | ||
return self.res | ||
|
||
def get_additional(self): | ||
def get_additional(self) -> Dict[str, Any]: | ||
"""Returns additional attributes to be stored with the result. | ||
Returns: | ||
dict: A dictionary of additional attributes. | ||
""" | ||
return {} | ||
|
||
def __str__(self): | ||
return self.name | ||
def __str__(self) -> str: | ||
return self.name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,25 @@ | ||
import numpy | ||
from typing import List, Union, Set, FrozenSet | ||
|
||
|
||
def float_parse_entry(line): | ||
def float_parse_entry(line: str) -> List[float]: | ||
return [float(x) for x in line.strip().split()] | ||
|
||
|
||
def float_unparse_entry(entry): | ||
def float_unparse_entry(entry: List[float]) -> str: | ||
return " ".join(map(str, entry)) | ||
|
||
|
||
def int_parse_entry(line): | ||
def int_parse_entry(line: str) -> FrozenSet[int]: | ||
return frozenset([int(x) for x in line.strip().split()]) | ||
|
||
|
||
def int_unparse_entry(entry): | ||
def int_unparse_entry(entry: Union[Set[int], FrozenSet[int]]) -> str: | ||
return " ".join(map(str, map(int, entry))) | ||
|
||
|
||
def bit_parse_entry(line): | ||
def bit_parse_entry(line: str) -> List[bool]: | ||
return [bool(int(x)) for x in list(line.strip().replace(" ", "").replace("\t", ""))] | ||
|
||
|
||
def bit_unparse_entry(entry): | ||
return " ".join(map(lambda el: "1" if el else "0", entry)) | ||
|
||
|
||
type_info = { | ||
"float": { | ||
"type": numpy.float, | ||
"parse_entry": float_parse_entry, | ||
"unparse_entry": float_unparse_entry, | ||
"finish_entries": numpy.vstack, | ||
}, | ||
"bit": {"type": numpy.bool_, "parse_entry": bit_parse_entry, "unparse_entry": bit_unparse_entry}, | ||
"int": { | ||
"type": numpy.object, | ||
"parse_entry": int_parse_entry, | ||
"unparse_entry": int_unparse_entry, | ||
}, | ||
} | ||
def bit_unparse_entry(entry: List[bool]) -> str: | ||
return " ".join(map(lambda el: "1" if el else "0", entry)) |
Oops, something went wrong.