diff --git a/CMakeLists.txt b/CMakeLists.txt index 79ab30b3..ebee6e6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ include_directories("${PROJECT_BINARY_DIR}") -set(SOURCE_EXE main.cpp) +set(SOURCE_EXE main.cpp) set(SOURCE_LIB sift_1b.cpp) @@ -13,5 +13,14 @@ add_library(sift_test STATIC ${SOURCE_LIB}) add_executable(main ${SOURCE_EXE}) +if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + SET( CMAKE_CXX_FLAGS "-Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -ftree-vectorize") +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) +endif() + +add_executable(test_updates examples/updates_test.cpp) + target_link_libraries(main sift_test) diff --git a/README.md b/README.md index c79e24c1..559c5dfd 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,13 @@ # Hnswlib - fast approximate nearest neighbor search -Header-only C++ HNSW implementation with python bindings. Paper code for the HNSW 200M SIFT experiment +Header-only C++ HNSW implementation with python bindings. 
Paper's code for the HNSW 200M SIFT experiment **NEWS:** -**Thanks to Louis Abraham ([@louisabraham](https://github.com/louisabraham)) hnswlib is now can be installed via pip!** +* **Thanks to Apoorv Sharma [@apoorv-sharma](https://github.com/apoorv-sharma), hnswlib now supports true element updates (the interface remained the same, but the performance/memory should not degrade as you update the element embeddings).** + +* **Thanks to Dmitry [@2ooom](https://github.com/2ooom), hnswlib got a boost in performance for vector dimensions that are not a multiple of 4** + +* **Thanks to Louis Abraham ([@louisabraham](https://github.com/louisabraham)) hnswlib can now be installed via pip!** Highlights: 1) Lightweight, header-only, no dependencies other than C++ 11. @@ -23,10 +27,10 @@ Description of the algorithm parameters can be found in [ALGO_PARAMS.md](ALGO_PA | Distance | parameter | Equation | | ------------- |:---------------:| -----------------------:| |Squared L2 |'l2' | d = sum((Ai-Bi)^2) | -|Inner product |'ip' | d = 1.0 - sum(Ai\*Bi)) | +|Inner product |'ip' | d = 1.0 - sum(Ai\*Bi) | |Cosine similarity |'cosine' | d = 1.0 - sum(Ai\*Bi) / sqrt(sum(Ai\*Ai) * sum(Bi\*Bi))| -Note that inner product is not an actual metric. An element can be closer to some other element than to itself. +Note that inner product is not an actual metric. An element can be closer to some other element than to itself. That allows some speedup if you remove all elements that are not the closest to themselves from the index. For other spaces use the nmslib library https://github.com/nmslib/nmslib. 
@@ -42,6 +46,7 @@ Index methods: * `add_items(data, data_labels, num_threads = -1)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. * `labels` is an optional N-size numpy array of integer labels for all elements in `data`. * `num_threads` sets the number of cpu threads to use (-1 means use default). + * `data_labels` specifies the labels for the data. If index already has the elements with the same labels, their features will be updated. Note that update procedure is slower than insertion of a new element, but more memory- and query-efficient. * Thread-safe with other `add_items` calls, but not with `knn_query`. * `mark_deleted(data_label)` - marks the element as deleted, so it will be ommited from search results. @@ -223,6 +228,29 @@ To run the test on 200M SIFT subset: The size of the bigann subset (in millions) is controlled by the variable **subset_size_milllions** hardcoded in **sift_1b.cpp**. +### Updates test +To generate testing data (from root directory): +```bash +cd examples +python update_gen_data.py +``` +To compile (from root directory): +```bash +mkdir build +cd build +cmake .. 
+make +``` +To run test **without** updates (from `build` directory) +```bash +./test_updates +``` + +To run test **with** updates (from `build` directory) +```bash +./test_updates update +``` + ### HNSW example demos - Visual search engine for 1M amazon products (MXNet + HNSW): [website](https://thomasdelteil.github.io/VisualSearch_MXNet/), [code](https://github.com/ThomasDelteil/VisualSearch_MXNet), demo by [@ThomasDelteil](https://github.com/ThomasDelteil) diff --git a/examples/update_gen_data.py b/examples/update_gen_data.py new file mode 100644 index 00000000..6f51bbbe --- /dev/null +++ b/examples/update_gen_data.py @@ -0,0 +1,37 @@ +import numpy as np +import os + +def normalized(a, axis=-1, order=2): + l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) + l2[l2==0] = 1 + return a / np.expand_dims(l2, axis) + +N=100000 +dummy_data_multiplier=3 +N_queries = 1000 +d=8 +K=5 + +np.random.seed(1) + +print("Generating data...") +batches_dummy= [ normalized(np.float32(np.random.random( (N,d)))) for _ in range(dummy_data_multiplier)] +batch_final = normalized (np.float32(np.random.random( (N,d)))) +queries = normalized(np.float32(np.random.random( (N_queries,d)))) +print("Computing distances...") +dist=np.dot(queries,batch_final.T) +topk=np.argsort(-dist)[:,:K] +print("Saving...") + +try: + os.mkdir("data") +except OSError as e: + pass + +for idx, batch_dummy in enumerate(batches_dummy): + batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx) +batch_final.tofile('data/batch_final.bin') +queries.tofile('data/queries.bin') +np.int32(topk).tofile('data/gt.bin') +with open("data/config.txt", "w") as file: + file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K)) \ No newline at end of file diff --git a/examples/updates_test.cpp b/examples/updates_test.cpp new file mode 100644 index 
00000000..c8775877 --- /dev/null +++ b/examples/updates_test.cpp @@ -0,0 +1,298 @@ +#include "../hnswlib/hnswlib.h" +#include +class StopW +{ + std::chrono::steady_clock::time_point time_begin; + +public: + StopW() + { + time_begin = std::chrono::steady_clock::now(); + } + + float getElapsedTimeMicro() + { + std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); + return (std::chrono::duration_cast(time_end - time_begin).count()); + } + + void reset() + { + time_begin = std::chrono::steady_clock::now(); + } +}; + +/* + * replacement for the openmp '#pragma omp parallel for' directive + * only handles a subset of functionality (no reductions etc) + * Process ids from start (inclusive) to end (EXCLUSIVE) + * + * The method is borrowed from nmslib + */ +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if ((id >= end)) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). 
+ */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } + + +} + + +template +std::vector load_batch(std::string path, int size) +{ + std::cout << "Loading " << path << "..."; + // float or int32 (python) + assert(sizeof(datatype) == 4); + + std::ifstream file; + file.open(path); + if (!file.is_open()) + { + std::cout << "Cannot open " << path << "\n"; + exit(1); + } + std::vector batch(size); + + file.read((char *)batch.data(), size * sizeof(float)); + std::cout << " DONE\n"; + return batch; +} + +template +static float +test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, + std::vector> &answers, size_t K) +{ + size_t correct = 0; + size_t total = 0; + //uncomment to test in parallel mode: + + + for (int i = 0; i < qsize; i++) + { + + std::priority_queue> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K); + total += K; + while (result.size()) + { + if (answers[i].find(result.top().second) != answers[i].end()) + { + correct++; + } + else + { + } + result.pop(); + } + } + return 1.0f * correct / total; +} + +static void +test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, + std::vector> &answers, size_t k) +{ + std::vector efs = {1}; + for (int i = k; i < 30; i++) + { + efs.push_back(i); + } + for (int i = 30; i < 400; i+=10) + { + efs.push_back(i); + } + for (int i = 1000; i < 100000; i += 5000) + { + efs.push_back(i); + } + std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; + for (size_t ef : efs) + { + appr_alg.setEf(ef); + + appr_alg.metric_hops=0; + appr_alg.metric_distance_computations=0; + StopW stopw = StopW(); + + float recall = test_approx(queries, qsize, appr_alg, vecdim, answers, k); + float time_us_per_query = stopw.getElapsedTimeMicro() / qsize; + float distance_comp_per_query = appr_alg.metric_distance_computations / 
(1.0f * qsize); + float hops_per_query = appr_alg.metric_hops / (1.0f * qsize); + + std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"< 0.99) + { + std::cout << "Recall is over 0.99! "<2){ + std::cout<<"Usage ./test_updates [update]\n"; + exit(1); + } + + std::string path = "../examples/data/"; + + + int N; + int dummy_data_multiplier; + int N_queries; + int d; + int K; + { + std::ifstream configfile; + configfile.open(path + "/config.txt"); + if (!configfile.is_open()) + { + std::cout << "Cannot open config.txt\n"; + return 1; + } + configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K; + + printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K); + } + + hnswlib::L2Space l2space(d); + hnswlib::HierarchicalNSW appr_alg(&l2space, N + 1, M, efConstruction); + + std::vector dummy_batch = load_batch(path + "batch_dummy_00.bin", N * d); + + // Adding enterpoint: + + appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0); + + StopW stopw = StopW(); + + if (update) + { + std::cout << "Update iteration 0\n"; + + + ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); + }); + appr_alg.checkIntegrity(); + + ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); + }); + appr_alg.checkIntegrity(); + + for (int b = 1; b < dummy_data_multiplier; b++) + { + std::cout << "Update iteration " << b << "\n"; + char cpath[1024]; + sprintf(cpath, "batch_dummy_%02d.bin", b); + std::vector dummy_batchb = load_batch(path + cpath, N * d); + + ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); + }); + appr_alg.checkIntegrity(); + } + } + + std::cout << "Inserting final elements\n"; + std::vector final_batch = load_batch(path + "batch_final.bin", N * d); + + stopw.reset(); + ParallelFor(0, N, 
num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(final_batch.data() + i * d), i); + }); + std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; + std::cout << "Running tests\n"; + std::vector queries_batch = load_batch(path + "queries.bin", N_queries * d); + + std::vector gt = load_batch(path + "gt.bin", N_queries * K); + + std::vector> answers(N_queries); + for (int i = 0; i < N_queries; i++) + { + for (int j = 0; j < K; j++) + { + answers[i].insert(gt[i * K + j]); + } + } + + for (int i = 0; i < 3; i++) + { + std::cout << "Test iteration " << i << "\n"; + test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K); + } + + return 0; +}; \ No newline at end of file diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index afc1222d..97bdcd18 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -2,6 +2,7 @@ #include "visited_list_pool.h" #include "hnswlib.h" +#include #include #include #include @@ -15,7 +16,7 @@ namespace hnswlib { template class HierarchicalNSW : public AlgorithmInterface { public: - + static const tableint max_update_element_locks = 65536; HierarchicalNSW(SpaceInterface *s) { } @@ -25,7 +26,7 @@ namespace hnswlib { } HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), element_levels_(max_elements) { + link_list_locks_(max_elements), element_levels_(max_elements), link_list_update_locks_(max_update_element_locks) { max_elements_ = max_elements; has_deletions_=false; @@ -39,6 +40,7 @@ namespace hnswlib { ef_ = 10; level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); @@ -104,6 +106,10 @@ namespace hnswlib { std::mutex cur_element_count_guard_; std::vector link_list_locks_; + + // Locks to 
prevent race condition during update/insert of an element at same time. + // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. + std::vector link_list_update_locks_; tableint enterpoint_node_; @@ -126,6 +132,7 @@ namespace hnswlib { std::unordered_map label_lookup_; std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; @@ -151,6 +158,7 @@ namespace hnswlib { return (int) r; } + std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); @@ -233,7 +241,10 @@ namespace hnswlib { return top_candidates; } - template + mutable std::atomic metric_distance_computations; + mutable std::atomic metric_hops; + + template std::priority_queue, std::vector>, CompareByFirst> searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); @@ -269,6 +280,10 @@ namespace hnswlib { int *data = (int *) get_linklist0(current_node_id); size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); + if(collect_metrics){ + metric_hops++; + metric_distance_computations+=size; + } #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); @@ -319,10 +334,11 @@ namespace hnswlib { void getNeighborsByHeuristic2( std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { + const size_t M) { if (top_candidates.size() < M) { return; } + std::priority_queue> queue_closest; std::vector> return_list; while (top_candidates.size() > 0) { @@ -337,6 +353,7 @@ namespace hnswlib { dist_t dist_to_query = -curent_pair.first; queue_closest.pop(); 
bool good = true; + for (std::pair second_pair : return_list) { dist_t curdist = fstdistfunc_(getDataByInternalId(second_pair.second), @@ -350,12 +367,9 @@ namespace hnswlib { if (good) { return_list.push_back(curent_pair); } - - } for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); } } @@ -373,10 +387,13 @@ namespace hnswlib { return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); }; - void mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> top_candidates, - int level) { + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + }; + tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, bool isUpdate) { size_t Mcurmax = level ? 
maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) @@ -389,6 +406,8 @@ namespace hnswlib { top_candidates.pop(); } + tableint next_closest_entry_point = selectedNeighbors[0]; + { linklistsizeint *ll_cur; if (level == 0) @@ -396,15 +415,13 @@ namespace hnswlib { else ll_cur = get_linklist(cur_c, level); - if (*ll_cur) { + if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); } setListCount(ll_cur,selectedNeighbors.size()); tableint *data = (tableint *) (ll_cur + 1); - - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx]) + if (data[idx] && !isUpdate) throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level"); @@ -413,11 +430,11 @@ namespace hnswlib { } } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - linklistsizeint *ll_other; if (level == 0) ll_other = get_linklist0(selectedNeighbors[idx]); @@ -434,47 +451,63 @@ namespace hnswlib { throw std::runtime_error("Trying to make a link on a non-existent level"); tableint *data = (tableint *) (ll_other + 1); - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); + bool is_cur_c_present = false; + if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); + if 
(data[j] == cur_c) { + is_cur_c_present = true; + break; + } } + } - getNeighborsByHeuristic2(candidates, Mcurmax); + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } } + if (indx >= 0) { + data[indx] = cur_c; + } */ } - if (indx >= 0) { - data[indx] = cur_c; - } */ } - } + + return next_closest_entry_point; } std::mutex global; @@ -516,15 +549,15 @@ namespace hnswlib { if 
(has_deletions_) { std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); + ef_); top_candidates.swap(top_candidates1); } else{ std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); + ef_); top_candidates.swap(top_candidates1); } - + while (top_candidates.size() > k) { top_candidates.pop(); } @@ -545,7 +578,6 @@ namespace hnswlib { std::vector(new_max_elements).swap(link_list_locks_); - // Reallocate base layer char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); if (data_level0_memory_new == nullptr) @@ -636,8 +668,8 @@ namespace hnswlib { dist_func_param_ = s->get_dist_func_param(); auto pos=input.tellg(); - - + + /// Optional - check if index is ok: input.seekg(cur_element_count * size_data_per_element_,input.cur); @@ -669,7 +701,7 @@ namespace hnswlib { throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -677,6 +709,7 @@ namespace hnswlib { size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); + std::vector(max_update_element_locks).swap(link_list_update_locks_); visited_list_pool_ = new VisitedListPool(1, max_elements); @@ -711,7 +744,7 @@ namespace hnswlib { if(isMarkedDeleted(i)) has_deletions_=true; } - + input.close(); return; @@ -795,26 +828,185 @@ namespace hnswlib { addPoint(data_point, label,-1); } + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); + + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element 
then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; + + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; + + sCand.insert(internalId); + + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); + + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; + + sNeigh.insert(elOneHop); + + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); + } + } + + for (auto&& neigh : sNeigh) { +// if (neigh == internalId) +// continue; + + std::priority_queue, std::vector>, CompareByFirst> candidates; + int size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; + int elementsToKeep = std::min(int(ef_construction_), size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; + + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); + candidates.emplace(distance, cand); + } + } + } + + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); + + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + int candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); + } + } + } + } + + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + }; + + void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif + for (int i = 0; i < size; i++) { +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); + + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if 
(topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); + + topCandidates.pop(); + } + + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); + } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + } + } + } + + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll,size * sizeof(tableint)); + return result; + }; + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; { - std::unique_lock lock(cur_element_count_guard_); + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. 
+ std::unique_lock templock_curr(cur_element_count_guard_); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + + templock_curr.unlock(); + + std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); + updatePoint(data_point, existingInternalId, 1.0); + return existingInternalId; + } + if (cur_element_count >= max_elements_) { throw std::runtime_error("The number of elements exceeds the specified limit"); }; cur_c = cur_element_count; cur_element_count++; - - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - std::unique_lock lock_el(link_list_locks_[search->second]); - has_deletions_ = true; - markDeletedInternal(search->second); - } label_lookup_[label] = cur_c; } + // Take update lock to prevent race conditions on an element with insertion/update at the same time. + std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); std::unique_lock lock_el(link_list_locks_[cur_c]); int curlevel = getRandomLevel(mult_); if (level > 0) @@ -889,9 +1081,7 @@ namespace hnswlib { if (top_candidates.size() > ef_construction_) top_candidates.pop(); } - mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); - - currObj = top_candidates.top().second; + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } @@ -926,6 +1116,9 @@ namespace hnswlib { data = (unsigned int *) get_linklist(currObj, level); int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + tableint *datal = (tableint *) (data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; @@ -943,16 +1136,15 @@ namespace hnswlib { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (has_deletions_) { - std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + if (has_deletions_) { + 
top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); - top_candidates.swap(top_candidates1); } else{ - std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); - top_candidates.swap(top_candidates1); } + while (top_candidates.size() > k) { top_candidates.pop(); } @@ -982,6 +1174,40 @@ namespace hnswlib { return result; } + void checkIntegrity(){ + int connections_checked=0; + std::vector inbound_connections_num(cur_element_count,0); + for(int i = 0;i < cur_element_count; i++){ + for(int l = 0;l <= element_levels_[i]; l++){ + linklistsizeint *ll_cur = get_linklist_at_level(i,l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j=0; j 0); + assert(data[j] < cur_element_count); + assert (data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; + + } + assert(s.size() == size); + } + } + if(cur_element_count > 1){ + int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; + for(int i=0; i < cur_element_count; i++){ + assert(inbound_connections_num[i] > 0); + min1=std::min(inbound_connections_num[i],min1); + max1=std::max(inbound_connections_num[i],max1); + } + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + + } + }; } diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index dbfb1656..c26f80b5 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -25,7 +25,7 @@ #include #include - +#include #include namespace hnswlib { diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index e9467473..d0497ff7 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -211,6 +211,36 @@ namespace hnswlib { #endif +#if defined(USE_SSE) || defined(USE_AVX) + static float + InnerProductSIMD16ExtResiduals(const void *pVect1v, 
const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return res + res_tail - 1.0f; + } + + static float + InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + + return res + res_tail - 1.0f; + } +#endif + class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; @@ -220,11 +250,15 @@ namespace hnswlib { InnerProductSpace(size_t dim) { fstdistfunc_ = InnerProduct; #if defined(USE_AVX) || defined(USE_SSE) - if (dim % 4 == 0) - fstdistfunc_ = InnerProductSIMD4Ext; if (dim % 16 == 0) fstdistfunc_ = InnerProductSIMD16Ext; -#endif + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductSIMD4ExtResiduals; + #endif dim_ = dim; data_size_ = dim * sizeof(float); } diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 4d3ac69a..bc00af72 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -4,16 +4,19 @@ namespace hnswlib { static float - L2Sqr(const void *pVect1, const void *pVect2, const void *qty_ptr) { - //return *((float *)pVect2); + L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; size_t qty = *((size_t *) qty_ptr); + float res = 0; - for (unsigned i = 0; i < qty; i++) 
{ - float t = ((float *) pVect1)[i] - ((float *) pVect2)[i]; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; res += t * t; } return (res); - } #if defined(USE_AVX) @@ -49,10 +52,8 @@ namespace hnswlib { } _mm256_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; - - return (res); -} + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + } #elif defined(USE_SSE) @@ -62,12 +63,9 @@ namespace hnswlib { float *pVect2 = (float *) pVect2v; size_t qty = *((size_t *) qty_ptr); float PORTABLE_ALIGN32 TmpRes[8]; - // size_t qty4 = qty >> 2; size_t qty16 = qty >> 4; const float *pEnd1 = pVect1 + (qty16 << 4); - // const float* pEnd2 = pVect1 + (qty4 << 2); - // const float* pEnd3 = pVect1 + qty; __m128 diff, v1, v2; __m128 sum = _mm_set1_ps(0); @@ -102,10 +100,24 @@ namespace hnswlib { diff = _mm_sub_ps(v1, v2); sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + _mm_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + } +#endif - return (res); +#if defined(USE_SSE) || defined(USE_AVX) + static float + L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + return (res + res_tail); } #endif @@ -119,10 +131,9 @@ namespace hnswlib { size_t qty = *((size_t *) qty_ptr); - // size_t qty4 = qty >> 2; - size_t qty16 = qty >> 2; + size_t qty4 = qty >> 2; - const float *pEnd1 = pVect1 + (qty16 << 2); + const float *pEnd1 = pVect1 + (qty4 << 2); __m128 diff, v1, v2; __m128 sum = _mm_set1_ps(0); 
@@ -136,9 +147,22 @@ namespace hnswlib { sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } _mm_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + } - return (res); + static float + L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; + + float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; + + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + + return (res + res_tail); } #endif @@ -151,13 +175,14 @@ namespace hnswlib { L2Space(size_t dim) { fstdistfunc_ = L2Sqr; #if defined(USE_SSE) || defined(USE_AVX) - if (dim % 4 == 0) - fstdistfunc_ = L2SqrSIMD4Ext; if (dim % 16 == 0) fstdistfunc_ = L2SqrSIMD16Ext; - /*else{ - throw runtime_error("Data type not supported!"); - }*/ + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; #endif dim_ = dim; data_size_ = dim * sizeof(float); @@ -185,10 +210,6 @@ namespace hnswlib { int res = 0; unsigned char *a = (unsigned char *) pVect1; unsigned char *b = (unsigned char *) pVect2; - /*for (int i = 0; i < qty; i++) { - int t = int((a)[i]) - int((b)[i]); - res += t*t; - }*/ qty = qty >> 2; for (size_t i = 0; i < qty; i++) { @@ -241,4 +262,4 @@ namespace hnswlib { }; -} +} \ No newline at end of file diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index ef1dc1d6..1b88ca23 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -294,7 +294,7 @@ class Index { (void *) items.data(row), k); if (result.size() != k) throw std::runtime_error( - "Cannot return the results in a contigious 2D array. 
Probably ef or M is to small"); + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -316,7 +316,7 @@ class Index { (void *) (norm_array.data()+start_idx), k); if (result.size() != k) throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is to small"); + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; diff --git a/python_bindings/setup.py b/python_bindings/setup.py index 2e863c87..a6dfb81b 100644 --- a/python_bindings/setup.py +++ b/python_bindings/setup.py @@ -4,7 +4,7 @@ import sys import setuptools -__version__ = '0.3.4' +__version__ = '0.4.0' source_files = ['bindings.cpp']