Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge 0.7.0 into master #436

Merged
merged 57 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
48bbed6
add throw statement to errors in BruteforceSearch
yoshoku Mar 21, 2022
5f074d5
remove unnecessary new operators
yoshoku Mar 21, 2022
eae971a
Merge pull request #375 from yoshoku/fix_errors
yurymalkov Mar 22, 2022
d51c312
chore(ALGO_PARAMS.md): Fix typo
PLNech May 2, 2022
492e15e
Highlight code
korzhenevski Jun 2, 2022
fb3a699
fix global linkage
MasterAler Jun 8, 2022
fb1885b
Merge pull request #380 from korzhenevski/patch-1
yurymalkov Jun 14, 2022
1431632
Merge pull request #383 from MasterAler/develop
yurymalkov Jun 16, 2022
632da8f
add missing quote
jlmelville Jul 30, 2022
d3197c5
initialize fields in constructor
jlmelville Jul 30, 2022
296b687
Merge pull request #395 from jlmelville/doc/py-quote-readme
yurymalkov Aug 2, 2022
25c7383
direct member initialize fields
jlmelville Aug 4, 2022
fdb1632
Merge pull request #396 from jlmelville/bug/ctor-init
yurymalkov Aug 5, 2022
406731d
Add rust implementation
jianshu93 Aug 9, 2022
765c4ab
Filter elements with an optional filtering function.
kishorenc Aug 15, 2022
ad3440c
Filter function should be sent the label and not the internal ID.
kishorenc Aug 19, 2022
4f6dcc3
Ensure that results are not empty when reading from top results.
kishorenc Aug 19, 2022
0fc42b0
Merge pull request #401 from jianshu93/patch-1
yurymalkov Aug 21, 2022
c5be3f5
Update port git_tester.py on Windows
dyashuni Aug 21, 2022
1c833a7
Make allowAllIds static.
kishorenc Aug 25, 2022
aaee13a
Use functor for filtering.
kishorenc Aug 26, 2022
b87f623
Explicitly check for filter functor being default.
kishorenc Aug 27, 2022
e8da5a0
Refactoring
dyashuni Aug 27, 2022
74bf4a3
Add cpp tests to CI
dyashuni Aug 27, 2022
e97b37c
Merge pull request #406 from dyashuni/add_tests_to_ci
yurymalkov Aug 28, 2022
bdd0220
Use shutil
dyashuni Aug 28, 2022
f0dedf3
Remove duplicate assignment.
kishorenc Aug 28, 2022
de22860
Merge branch 'develop' into filter-elements
kishorenc Aug 28, 2022
e4705fd
Add search with filter test to CI.
kishorenc Aug 28, 2022
7f419ea
Remove constexpr for functor in test.
kishorenc Aug 28, 2022
f7d3366
USE_SSE with msvc compilers
alxvth Aug 29, 2022
23f5351
Remove inclusion of cpu_x86.h
alxvth Aug 29, 2022
e8b3e44
Add cpp tests for Windows in CI
dyashuni Aug 28, 2022
6fa8cd0
Merge pull request #404 from dyashuni/windows_tester
dyashuni Sep 2, 2022
a3e399a
Merge pull request #409 from dyashuni/cpp_windows_tests
dyashuni Sep 2, 2022
1fe7baf
Add check for is_filter_disabled.
kishorenc Sep 6, 2022
c9897b0
Add assert header.
kishorenc Sep 6, 2022
5c14e05
Merge pull request #402 from typesense/filter-elements
dyashuni Sep 6, 2022
c481fd4
Merge branch 'develop' into fix/cpu_x86_include
dyashuni Sep 6, 2022
1f49ffe
Merge pull request #408 from alxvth/fix/cpu_x86_include
dyashuni Sep 6, 2022
6d28ec0
Refactoring (#410)
dyashuni Sep 18, 2022
4ab1d61
Remove some code duplication in bindings (#416)
dyashuni Sep 20, 2022
687ca85
Update python recall test (#415)
dyashuni Sep 20, 2022
983cea9
Python: filter elements with an optional filtering function (#417)
gtsoukas Nov 9, 2022
3e006ea
Replace deleted elements at addition (#418)
dyashuni Jan 12, 2023
28681fc
Getters for max elements, element count and num deleted. (#431)
kishorenc Jan 14, 2023
978f713
Fix insufficient results during filtering. (#430)
kishorenc Jan 14, 2023
d86f8f9
Refactoring of project structure (#432)
dyashuni Jan 15, 2023
225b519
Add warning that python filter works slow in multi-threaded mode
Jan 14, 2023
32f4b02
Add comments with warnings that filter works slow in python in multit…
Jan 15, 2023
2175362
Merge pull request #433 from dyashuni/filter_warning
yurymalkov Jan 15, 2023
2c6f244
Merge pull request #379 from PLNech/patch-1
yurymalkov Jan 15, 2023
dd266bc
preliminary release notes
Jan 15, 2023
d35f428
Add construction speed logging
Jan 17, 2023
68a3387
fix a misprint
yurymalkov Jan 18, 2023
488ab52
Add cpp examples (#435)
dyashuni Jan 30, 2023
dd1bdb7
Merge pull request #434 from nmslib/v07release
yurymalkov Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Filter elements with an optional filtering function.
  • Loading branch information
kishorenc committed Aug 15, 2022
commit 765c4ab4ba00e2e8a54b349c5df1f028b08953ed
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
target_link_libraries(searchKnnCloserFirst_test hnswlib)

add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp)
target_link_libraries(searchKnnWithFilter_test hnswlib)

add_executable(main main.cpp sift_1b.cpp)
target_link_libraries(main hnswlib)
endif()
95 changes: 95 additions & 0 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// This is a test file for testing the filtering feature

#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <iostream>
#include <memory>
#include <random>
#include <vector>

namespace
{

using idx_t = hnswlib::labeltype;

// Filter predicate: accepts an id only when it is an exact multiple of three.
bool pickIdsDivisibleByThree(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 3u;
    return remainder == 0u;
}

// Filter predicate: accepts an id only when it is an exact multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 7u;
    return remainder == 0u;
}

template<typename filter_func_t>
void test(filter_func_t filter_func, size_t div_num) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; ++i) {
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}


hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
alg_brute->addPoint(data.data() + d * i, i);
alg_hnsw->addPoint(data.data() + d * i, i);
}

// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}

} // namespace

// Entry point: runs the filtered-search test once per filter predicate.
int main() {
    // Pairs each filter with the divisor its accepted ids must be a multiple of.
    struct FilterCase {
        bool (*accept)(unsigned int);
        size_t divisor;
    };
    const FilterCase cases[] = {
        {pickIdsDivisibleByThree, 3},
        {pickIdsDivisibleBySeven, 7},
    };

    std::cout << "Testing ..." << std::endl;
    for (const FilterCase& filter_case : cases) {
        test(filter_case.accept, filter_case.divisor);
    }
    std::cout << "Test ok" << std::endl;

    return 0;
}
18 changes: 11 additions & 7 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include <algorithm>

namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> {
public:
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0),
cur_element_count(0), size_per_element_(0), data_size_(0),
Expand Down Expand Up @@ -92,20 +92,24 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
lastdist = topResults.top().first;
Expand Down
16 changes: 8 additions & 8 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> {
public:
static const tableint max_update_element_locks = 65536;
HierarchicalNSW(SpaceInterface<dist_t> *s) {
Expand Down Expand Up @@ -238,7 +238,7 @@ namespace hnswlib {

template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -247,7 +247,7 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
if (!has_deletions || !isMarkedDeleted(ep_id)) {
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -307,7 +307,7 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

if (!has_deletions || !isMarkedDeleted(candidate_id))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down Expand Up @@ -1111,7 +1111,7 @@ namespace hnswlib {
};

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1148,11 +1148,11 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}
else{
top_candidates=searchBaseLayerST<false,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

while (top_candidates.size() > k) {
Expand Down
20 changes: 14 additions & 6 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ static bool AVX512Capable() {
namespace hnswlib {
typedef size_t labeltype;

// Default filter predicate: accepts every id, i.e. filtering is disabled.
// Declared inline because this header may be included from several
// translation units; a non-inline definition here would produce
// duplicate-symbol link errors.
inline bool allowAllIds(unsigned int /*ep_id*/) {
    return true;
}

template <typename T>
class pairGreater {
public:
Expand All @@ -137,6 +141,7 @@ namespace hnswlib {
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);

using FILTERFUNC = bool(*)(unsigned int);

template<typename MTYPE>
class SpaceInterface {
Expand All @@ -151,28 +156,31 @@ namespace hnswlib {
virtual ~SpaceInterface() {}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0;

// Return the k nearest neighbors, ordered closest first
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k) const;
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const;

virtual void saveIndex(const std::string &location)=0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
filter_func_t isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k);
auto ret = searchKnn(query_data, k, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
Expand Down