From 1397a32a45b94223aa27d931f8f111f1f5b27217 Mon Sep 17 00:00:00 2001 From: Axlgrep Date: Thu, 11 Jul 2024 20:26:06 +0800 Subject: [PATCH] support DataLevel0BlocksMemory data struct to solve the realloc huge memory doubling problem --- hnswlib/hnswalg.h | 306 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 264 insertions(+), 42 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e269ae69..24114d7f 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -14,6 +14,141 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; + +class DataLevel0BlocksMemory { + public: + DataLevel0BlocksMemory(size_t size_data_per_element) + : size_data_per_element_(size_data_per_element) { + + assert(size_data_per_element_ > 0); + if (size_data_per_element_ >= MIN_MEMORY_BLOCK_SIZE) { + element_count_per_block_ = 1; + } else { + element_count_per_block_ = (MIN_MEMORY_BLOCK_SIZE + (size_data_per_element_ - 1)) / size_data_per_element_; + } + } + + ~DataLevel0BlocksMemory() { + for (size_t i = 0; i < memory_blocks_.size(); i++) { + assert(memory_blocks_[i] != nullptr); + free(memory_blocks_[i]); + memory_blocks_[i] = nullptr; + } + std::vector().swap(memory_blocks_); + } + + size_t Capacity() { + return capacity_; + } + + size_t ElementCountPerBlock() { + return element_count_per_block_; + } + + void Malloc(size_t max_elements) { + assert(memory_blocks_.empty()); + if (max_elements == 0) { + return; + } + + if (max_elements < element_count_per_block_) { + AppendNonstandardBlock(max_elements); + } else { + size_t added_blocks = max_elements / element_count_per_block_; + AppendStandardBlocks(added_blocks); + + size_t last_block_elements = max_elements % element_count_per_block_; + if (last_block_elements != 0) { + AppendNonstandardBlock(last_block_elements); + } + } + capacity_ = max_elements; + } + + void Realloc(size_t max_elements) { + if (max_elements <= capacity_) { + return; + } + + size_t full_block_count = max_elements / element_count_per_block_; + if (capacity_ % element_count_per_block_ != 0) { + if (full_block_count < memory_blocks_.size()) { + ReallocLastBlocks(max_elements % element_count_per_block_); + capacity_ = max_elements; + return; + } else { + ReallocLastBlocks(element_count_per_block_); + } + } + + assert(full_block_count >= memory_blocks_.size()); + size_t added_blocks = full_block_count - memory_blocks_.size(); + AppendStandardBlocks(added_blocks); + + size_t last_block_elements = max_elements % element_count_per_block_; + if (last_block_elements != 0) { + AppendNonstandardBlock(last_block_elements); + } + capacity_ = max_elements; + } + + char* GetElementPtr(tableint internal_id) { + assert(internal_id < capacity_); + size_t index = internal_id / element_count_per_block_; + size_t elements_offset_in_block = internal_id % element_count_per_block_; + assert(index < memory_blocks_.size()); + return memory_blocks_[index] + elements_offset_in_block * size_data_per_element_; + } + + char* GetMemoryBlockPtr(size_t index) { + assert(index < memory_blocks_.size()); + return memory_blocks_[index]; + } + + private: + DataLevel0BlocksMemory(const DataLevel0BlocksMemory&) = delete; + DataLevel0BlocksMemory& operator=(const DataLevel0BlocksMemory&) = delete; + + void AppendStandardBlocks(size_t count) { + for (size_t i = 0; i < count; i++) { + char* ptr = (char *) malloc(element_count_per_block_ * size_data_per_element_); + if (ptr == nullptr) { + throw std::runtime_error("Not enough memory"); + } else { + memory_blocks_.emplace_back(ptr); + } + } + } + + void AppendNonstandardBlock(size_t elements) { + char* ptr = (char *) malloc(elements * size_data_per_element_); + if (ptr == nullptr) { + throw std::runtime_error("Not enough memory"); + } else { + memory_blocks_.emplace_back(ptr); + } + } + + void ReallocLastBlocks(size_t elements) { + assert(!memory_blocks_.empty()); + assert(capacity_ % element_count_per_block_ < elements && elements <= element_count_per_block_); + size_t last_block_index = memory_blocks_.size() - 1; + char* ptr = (char *) realloc(memory_blocks_[last_block_index], elements * size_data_per_element_); + if (ptr == nullptr) { + throw std::runtime_error("Not enough memory"); + } else { + memory_blocks_[last_block_index] = ptr; + } + } + + size_t capacity_{0}; + size_t size_data_per_element_{0}; + size_t element_count_per_block_{0}; + std::vector memory_blocks_; + + static const size_t MIN_MEMORY_BLOCK_SIZE = 128 * 1024 * 1024; +}; + template class HierarchicalNSW : public AlgorithmInterface { public: @@ -47,7 +182,10 @@ class HierarchicalNSW : public AlgorithmInterface { size_t size_links_level0_{0}; size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; + const bool use_blocks_memory_{true}; char *data_level0_memory_{nullptr}; + DataLevel0BlocksMemory *data_level0_blocks_memory_{nullptr}; + char **linkLists_{nullptr}; std::vector element_levels_; // keeps level of each element @@ -80,8 +218,10 @@ class HierarchicalNSW : public AlgorithmInterface { const std::string &location, bool nmslib = false, size_t max_elements = 0, - bool allow_replace_deleted = false) - : allow_replace_deleted_(allow_replace_deleted) { + bool allow_replace_deleted = false, + bool use_small_blocks_memory = true) + : allow_replace_deleted_(allow_replace_deleted), + use_blocks_memory_(use_small_blocks_memory) { loadIndex(location, s, max_elements); } @@ -92,11 +232,13 @@ class HierarchicalNSW : public AlgorithmInterface { size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100, - bool allow_replace_deleted = false) + bool allow_replace_deleted = false, + bool use_small_blocks_memory = true) : label_op_locks_(MAX_LABEL_OPERATION_LOCKS), link_list_locks_(max_elements), element_levels_(max_elements), - allow_replace_deleted_(allow_replace_deleted) { + allow_replace_deleted_(allow_replace_deleted), + use_blocks_memory_(use_small_blocks_memory) { max_elements_ = max_elements; num_deleted_ = 0; data_size_ = s->get_data_size(); @@ -123,9 +265,15 @@ class HierarchicalNSW : public AlgorithmInterface { label_offset_ = size_links_level0_ + data_size_; offsetLevel0_ = 0; - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); + if (use_blocks_memory_) { + data_level0_blocks_memory_ = new DataLevel0BlocksMemory(size_data_per_element_); + data_level0_blocks_memory_->Malloc(max_elements_); + } else { + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + } + cur_element_count = 0; @@ -149,8 +297,12 @@ class HierarchicalNSW : public AlgorithmInterface { } void clear() { - free(data_level0_memory_); - data_level0_memory_ = nullptr; + if (use_blocks_memory_) { + delete data_level0_blocks_memory_; + } else { + free(data_level0_memory_); + data_level0_memory_ = nullptr; + } for (tableint i = 0; i < cur_element_count; i++) { if (element_levels_[i] > 0) free(linkLists_[i]); @@ -184,23 +336,39 @@ class HierarchicalNSW : public AlgorithmInterface { inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; - memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + if (use_blocks_memory_) { + memcpy(&return_label, (data_level0_blocks_memory_->GetElementPtr(internal_id) + label_offset_), sizeof(labeltype)); + } else { + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + } return return_label; } inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + if (use_blocks_memory_) { + memcpy((data_level0_blocks_memory_->GetElementPtr(internal_id) + label_offset_), &label, sizeof(labeltype)); + } else { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } } inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + if (use_blocks_memory_) { + return (labeltype *) (data_level0_blocks_memory_->GetElementPtr(internal_id) + label_offset_); + } else { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } } inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + if (use_blocks_memory_) { + return (data_level0_blocks_memory_->GetElementPtr(internal_id) + offsetData_); + } else { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } } @@ -266,8 +434,10 @@ class HierarchicalNSW : public AlgorithmInterface { #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + if (1 < size) { + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + } #endif for (size_t j = 0; j < size; j++) { @@ -275,7 +445,9 @@ class HierarchicalNSW : public AlgorithmInterface { // if (candidate_id == 0) continue; #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + if (j + 1 < size) { + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + } #endif if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; @@ -304,7 +476,6 @@ class HierarchicalNSW : public AlgorithmInterface { return top_candidates; } - // bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance template std::priority_queue, std::vector>, CompareByFirst> @@ -370,7 +541,13 @@ class HierarchicalNSW : public AlgorithmInterface { #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + if (use_blocks_memory_) { + if (1 < size) { + _mm_prefetch(data_level0_blocks_memory_->GetElementPtr(*(data + 1)) + offsetData_, _MM_HINT_T0); + } + } else { + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + } _mm_prefetch((char *) (data + 2), _MM_HINT_T0); #endif @@ -379,8 +556,14 @@ class HierarchicalNSW : public AlgorithmInterface { // if (candidate_id == 0) continue; #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, - _MM_HINT_T0); //////////// + if (use_blocks_memory_) { + if (j + 1 <= size) { + _mm_prefetch(data_level0_blocks_memory_->GetElementPtr(*(data + j + 1)) + offsetData_, _MM_HINT_T0); + } + } else { + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// + } #endif if (!(visited_array[candidate_id] == visited_array_tag)) { visited_array[candidate_id] = visited_array_tag; @@ -398,9 +581,13 @@ class HierarchicalNSW : public AlgorithmInterface { if (flag_consider_candidate) { candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + - offsetLevel0_, /////////// - _MM_HINT_T0); //////////////////////// + if (use_blocks_memory_) { + _mm_prefetch(data_level0_blocks_memory_->GetElementPtr(candidate_set.top().second) + offsetLevel0_, _MM_HINT_T0); + } else { + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// + } #endif if (bare_bone_search || @@ -439,7 +626,6 @@ class HierarchicalNSW : public AlgorithmInterface { return top_candidates; } - void getNeighborsByHeuristic2( std::priority_queue, std::vector>, CompareByFirst> &top_candidates, const size_t M) { @@ -484,12 +670,11 @@ class HierarchicalNSW : public AlgorithmInterface { linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - } - - - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + if (use_blocks_memory_) { + return (linklistsizeint *) (data_level0_blocks_memory_->GetElementPtr(internal_id) + offsetLevel0_); + } else { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } } @@ -641,10 +826,14 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector(new_max_elements).swap(link_list_locks_); // Reallocate base layer - char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - data_level0_memory_ = data_level0_memory_new; + if (use_blocks_memory_) { + data_level0_blocks_memory_->Realloc(new_max_elements); + } else { + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; + } // Reallocate all other layers char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); @@ -701,7 +890,19 @@ class HierarchicalNSW : public AlgorithmInterface { writeBinaryPOD(output, mult_); writeBinaryPOD(output, ef_construction_); - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + if (use_blocks_memory_) { + size_t block_index = 0; + size_t left_element = cur_element_count; + assert(max_elements_ == data_level0_blocks_memory_->Capacity()); + while (left_element > 0) { + size_t write_element_count = std::min(left_element, data_level0_blocks_memory_->ElementCountPerBlock()); + output.write(data_level0_blocks_memory_->GetMemoryBlockPtr(block_index), write_element_count * size_data_per_element_); + left_element -= write_element_count; + block_index++; + } + } else { + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); + } for (size_t i = 0; i < cur_element_count; i++) { unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; @@ -774,10 +975,23 @@ class HierarchicalNSW : public AlgorithmInterface { input.seekg(pos, input.beg); - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + if (use_blocks_memory_) { + size_t block_index = 0; + size_t left_element = cur_element_count; + data_level0_blocks_memory_ = new DataLevel0BlocksMemory(size_data_per_element_); + data_level0_blocks_memory_->Malloc(max_elements); + while (left_element > 0) { + size_t read_element_count = std::min(left_element, data_level0_blocks_memory_->ElementCountPerBlock()); + input.read(data_level0_blocks_memory_->GetMemoryBlockPtr(block_index), read_element_count * size_data_per_element_); + left_element -= read_element_count; + block_index++; + } + } else { + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + } size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -1090,11 +1304,15 @@ class HierarchicalNSW : public AlgorithmInterface { int size = getListCount(data); tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + if (0 < size) { + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + } #endif for (int i = 0; i < size; i++) { #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); + if (i + 1 < size) { + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); + } #endif tableint cand = datal[i]; dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); @@ -1197,7 +1415,11 @@ class HierarchicalNSW : public AlgorithmInterface { tableint currObj = enterpoint_node_; tableint enterpoint_copy = enterpoint_node_; - memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + if (use_blocks_memory_) { + memset(data_level0_blocks_memory_->GetElementPtr(cur_c) + offsetLevel0_, 0, size_data_per_element_); + } else { + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + } // Initialisation of the data and label memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));