Skip to content

Commit

Permalink
Add multithread search to BF index (#425)
Browse files Browse the repository at this point in the history
* Add multithread search for BF index
  • Loading branch information
dyashuni authored May 12, 2023
1 parent dccd4f9 commit 6aac477
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 7 deletions.
45 changes: 38 additions & 7 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ class BFIndex {
int dim;
bool index_inited;
bool normalize;
int num_threads_default;

hnswlib::labeltype cur_l;
hnswlib::BruteforceSearch<dist_t>* alg;
Expand All @@ -739,6 +740,8 @@ class BFIndex {
}
alg = NULL;
index_inited = false;

num_threads_default = std::thread::hardware_concurrency();
}


Expand All @@ -749,6 +752,21 @@ class BFIndex {
}


// Returns the capacity (maxelements_) of the underlying brute-force index.
// The constructor leaves `alg` as NULL until an index is created, so an
// index without storage reports 0 instead of dereferencing a null pointer.
size_t getMaxElements() const {
    if (!alg) {
        return 0;
    }
    return alg->maxelements_;
}


// Returns the number of elements currently stored (cur_element_count) in the
// underlying brute-force index.
// The constructor leaves `alg` as NULL until an index is created, so an
// index without storage reports 0 instead of dereferencing a null pointer.
size_t getCurrentCount() const {
    if (!alg) {
        return 0;
    }
    return alg->cur_element_count;
}


void set_num_threads(int num_threads) {
this->num_threads_default = num_threads;
}


void init_new_index(const size_t maxElements) {
if (alg) {
throw std::runtime_error("The index is already initiated.");
Expand Down Expand Up @@ -820,15 +838,19 @@ class BFIndex {
py::object knnQuery_return_numpy(
py::object input,
size_t k = 1,
int num_threads = -1,
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype *data_numpy_l;
dist_t *data_numpy_d;
size_t rows, features;

if (num_threads <= 0)
num_threads = num_threads_default;

{
py::gil_scoped_release l;

get_input_array_shapes(buffer, &rows, &features);

data_numpy_l = new hnswlib::labeltype[rows * k];
Expand All @@ -837,16 +859,16 @@ class BFIndex {
CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

for (size_t row = 0; row < rows; row++) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void *) items.data(row), k, p_idFilter);
(void*)items.data(row), k, p_idFilter);
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
auto& result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
}
});
}

py::capsule free_when_done_l(data_numpy_l, [](void *f) {
Expand Down Expand Up @@ -957,13 +979,22 @@ PYBIND11_PLUGIN(hnswlib) {
py::class_<BFIndex<float>>(m, "BFIndex")
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none())
.def("knn_query",
&BFIndex<float>::knnQuery_return_numpy,
py::arg("data"),
py::arg("k") = 1,
py::arg("num_threads") = -1,
py::arg("filter") = py::none())
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
.def("set_num_threads", &BFIndex<float>::set_num_threads, py::arg("num_threads"))
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
.def("load_index", &BFIndex<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0)
.def("__repr__", [](const BFIndex<float> &a) {
return "<hnswlib.BFIndex(space='" + a.space_name + "', dim="+std::to_string(a.dim)+")>";
});
})
.def("get_max_elements", &BFIndex<float>::getMaxElements)
.def("get_current_count", &BFIndex<float>::getCurrentCount)
.def_readwrite("num_threads", &BFIndex<float>::num_threads_default);
return m.ptr();
}
49 changes: 49 additions & 0 deletions python_bindings/tests/bindings_test_bf_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
def testBFIndex(self):

dim = 16
num_elements = 10000
num_queries = 1000
k = 20

# Generating sample data
data = np.float32(np.random.random((num_elements, dim)))

# Declaring index
bf_index = hnswlib.BFIndex(space='l2', dim=dim) # possible options are l2, cosine or ip
bf_index.init_index(max_elements=num_elements)

num_threads = 8
bf_index.set_num_threads(num_threads) # by default using all available cores

print(f"Adding all elements {num_elements}")
bf_index.add_items(data)

self.assertEqual(bf_index.num_threads, num_threads)
self.assertEqual(bf_index.get_max_elements(), num_elements)
self.assertEqual(bf_index.get_current_count(), num_elements)

queries = np.float32(np.random.random((num_queries, dim)))
print("Searching nearest neighbours")
labels, distances = bf_index.knn_query(queries, k=k)

print("Checking results")
for i in range(num_queries):
query = queries[i]
sq_dists = (data - query)**2
dists = np.sum(sq_dists, axis=1)
labels_gt = np.argsort(dists)[:k]
dists_gt = dists[labels_gt]
dists_bf = distances[i]
# we can compare labels but because of numeric errors in distance calculation in C++ and numpy
# sometimes we get different order of labels, therefore we compare distances
max_diff_with_gt = np.max(np.abs(dists_gt - dists_bf))

self.assertTrue(max_diff_with_gt < 1e-5)

0 comments on commit 6aac477

Please sign in to comment.