From b71877828497e96c5f4d466406e6be819942f4a3 Mon Sep 17 00:00:00 2001 From: ilyajob05 Date: Wed, 30 Nov 2022 16:54:38 +0300 Subject: [PATCH 1/3] add func clear_mem() --- hnswlib/hnswalg.h | 266 ++++++++++++++++++++++++++++++----- python_bindings/bindings.cpp | 5 + 2 files changed, 232 insertions(+), 39 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e95e0b52..e3a283a9 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -53,7 +53,7 @@ namespace hnswlib { cur_element_count = 0; - visited_list_pool_ = new VisitedListPool(1, max_elements); + visited_list_pool_ = new VisitedListPool(1, max_elements_); //initializations for special treatment of the first node enterpoint_node_ = -1; @@ -137,12 +137,20 @@ namespace hnswlib { memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); } + inline labeltype *getExternalLabeLp(tableint internal_id, char *data_mem) const { + return (labeltype *) (data_mem + internal_id * size_data_per_element_ + label_offset_); + } + inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + return getExternalLabeLp(internal_id, data_level0_memory_); + } + + inline char *getDataByInternalId(tableint internal_id, char *data_mem) const { + return (data_mem + internal_id * size_data_per_element_ + offsetData_); } inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + return getDataByInternalId(internal_id, data_level0_memory_); } int getRandomLevel(double reverse_size) { @@ -151,10 +159,14 @@ namespace hnswlib { return (int) r; } + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int level) { + return searchBaseLayer(ep_id, data_point, level, data_level0_memory_, linkLists_, visited_list_pool_, 
link_list_locks_);
+        }
         std::priority_queue, std::vector>, CompareByFirst>
-        searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
-            VisitedList *vl = visited_list_pool_->getFreeVisitedList();
+        searchBaseLayer(tableint ep_id, const void *data_point, int level, char* data_mem, char** linkLists, VisitedListPool *vlp, std::vector &link_list_locks) {
+            VisitedList *vl = vlp->getFreeVisitedList();
             vl_type *visited_array = vl->mass;
             vl_type visited_array_tag = vl->curV;
@@ -182,13 +194,16 @@ namespace hnswlib {
                 tableint curNodeNum = curr_el_pair.second;
-                std::unique_lock lock(link_list_locks_[curNodeNum]);
+                // The lock must outlive the `if`, otherwise it is released before the list is read; an empty vector means "no locking".
+                std::unique_lock<std::mutex> lock;  // NOTE(review): assumes the vector holds std::mutex (template args not visible here) - confirm
+                if (!link_list_locks.empty())
+                    lock = std::unique_lock<std::mutex>(link_list_locks[curNodeNum]);
                 int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
-                if (layer == 0) {
-                    data = (int*)get_linklist0(curNodeNum);
+                if (level == 0) {
+                    data = (int*)get_linklist0(curNodeNum, data_mem);
                 } else {
-                    data = (int*)get_linklist(curNodeNum, layer);
+                    data = (int*)get_linklist(curNodeNum, level, linkLists);
 //                    data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
                 }
                 size_t size = getListCount((linklistsizeint*)data);
@@ -196,8 +211,8 @@ namespace hnswlib {
 #ifdef USE_SSE
                 _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
                 _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
-                _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
-                _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
+                _mm_prefetch(getDataByInternalId(*datal, data_mem), _MM_HINT_T0);
+                _mm_prefetch(getDataByInternalId(*(datal + 1), data_mem), _MM_HINT_T0);
 #endif
                 for (size_t j = 0; j < size; j++) {
@@ -205,11 +220,11 @@ namespace hnswlib {
 //                    if (candidate_id == 0) continue;
 #ifdef USE_SSE
                     _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
-                    _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
+                    _mm_prefetch(getDataByInternalId(*(datal + j + 1),
data_mem), _MM_HINT_T0); #endif if (visited_array[candidate_id] == visited_array_tag) continue; visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); + char *currObj1 = (getDataByInternalId(candidate_id, data_mem)); dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { @@ -218,7 +233,7 @@ namespace hnswlib { _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); #endif - if (!isMarkedDeleted(candidate_id)) + if (!isMarkedDeleted(candidate_id, data_mem)) top_candidates.emplace(dist1, candidate_id); if (top_candidates.size() > ef_construction_) @@ -229,7 +244,7 @@ namespace hnswlib { } } } - visited_list_pool_->releaseVisitedList(vl); + vlp->releaseVisitedList(vl); return top_candidates; } @@ -326,10 +341,16 @@ namespace hnswlib { } void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { - if (top_candidates.size() < M) { - return; + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + getNeighborsByHeuristic2(top_candidates, M, data_level0_memory_); + } + + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M, char* data_mem) { + if (top_candidates.size() < M) { + return; } std::priority_queue> queue_closest; @@ -349,8 +370,8 @@ namespace hnswlib { for (std::pair second_pair : return_list) { dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), + fstdistfunc_(getDataByInternalId(second_pair.second, data_mem), + getDataByInternalId(curent_pair.second, data_mem), dist_func_param_);; if (curdist < dist_to_query) { good = false; @@ -372,23 +393,37 @@ namespace hnswlib { return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); }; - 
linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + linklistsizeint *get_linklist0(tableint internal_id, char *data_mem) const { + return (linklistsizeint *) (data_mem + internal_id * size_data_per_element_ + offsetLevel0_); }; linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + return get_linklist(internal_id, level, linkLists_); }; + linklistsizeint *get_linklist(tableint internal_id, int level, char **link_lists) const { + return (linklistsizeint *) (link_lists[internal_id] + (level - 1) * size_links_per_element_); + }; + +// linklistsizeint *get_linklist(tableint internal_id, int level) const { +// return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); +// }; + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); }; tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level, bool isUpdate) { + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, bool isUpdate) { + return mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, isUpdate, data_level0_memory_, linkLists_, element_levels_, link_list_locks_); + }; + + tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, bool isUpdate, char *data_mem, char **link_lists, std::vector &element_levels, std::vector &link_list_locks) { size_t Mcurmax = level ? 
maxM_ : maxM0_;
-            getNeighborsByHeuristic2(top_candidates, M_);
+            getNeighborsByHeuristic2(top_candidates, M_, data_mem);
             if (top_candidates.size() > M_)
                 throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
@@ -404,9 +439,9 @@ namespace hnswlib {
             {
                 linklistsizeint *ll_cur;
                 if (level == 0)
-                    ll_cur = get_linklist0(cur_c);
+                    ll_cur = get_linklist0(cur_c, data_mem);
                 else
-                    ll_cur = get_linklist(cur_c, level);
+                    ll_cur = get_linklist(cur_c, level, link_lists);
                 if (*ll_cur && !isUpdate) {
                     throw std::runtime_error("The newly inserted element should have blank link list");
@@ -416,7 +451,7 @@ namespace hnswlib {
                 for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
                     if (data[idx] && !isUpdate)
                         throw std::runtime_error("Possible memory corruption");
-                    if (level > element_levels_[selectedNeighbors[idx]])
+                    if (level > element_levels[selectedNeighbors[idx]])
                         throw std::runtime_error("Trying to make a link on a non-existent level");
                     data[idx] = selectedNeighbors[idx];
@@ -426,13 +461,15 @@ namespace hnswlib {
             for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
-                std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]);
+                std::unique_lock<std::mutex> lock;  // held until the end of this iteration; stays unlocked when no locks were supplied (NOTE(review): assumes std::mutex elements - confirm)
+                if (!link_list_locks.empty())
+                    lock = std::unique_lock<std::mutex>(link_list_locks[selectedNeighbors[idx]]);
                 linklistsizeint *ll_other;
                 if (level == 0)
-                    ll_other = get_linklist0(selectedNeighbors[idx]);
+                    ll_other = get_linklist0(selectedNeighbors[idx], data_mem);
                 else
-                    ll_other = get_linklist(selectedNeighbors[idx], level);
+                    ll_other = get_linklist(selectedNeighbors[idx], level, link_lists);
                 size_t sz_link_list_other = getListCount(ll_other);
@@ -440,7 +477,7 @@ namespace hnswlib {
                     throw std::runtime_error("Bad value of sz_link_list_other");
                 if (selectedNeighbors[idx] == cur_c)
                     throw std::runtime_error("Trying to connect an element to itself");
-                if (level > element_levels_[selectedNeighbors[idx]])
+                if (level > element_levels[selectedNeighbors[idx]])
                     throw std::runtime_error("Trying to make a 
link on a non-existent level");
                 tableint *data = (tableint *) (ll_other + 1);
@@ -462,7 +499,7 @@ namespace hnswlib {
                     setListCount(ll_other, sz_link_list_other + 1);
                 } else {
                     // finding the "weakest" element to replace it with the new one
-                    dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
+                    dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c, data_mem), getDataByInternalId(selectedNeighbors[idx], data_mem),
                                                 dist_func_param_);
                     // Heuristic:
                     std::priority_queue, std::vector>, CompareByFirst> candidates;
@@ -470,11 +507,11 @@ namespace hnswlib {
                     for (size_t j = 0; j < sz_link_list_other; j++) {
                         candidates.emplace(
-                                fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
+                                fstdistfunc_(getDataByInternalId(data[j], data_mem), getDataByInternalId(selectedNeighbors[idx], data_mem),
                                              dist_func_param_), data[j]);
                     }
-                    getNeighborsByHeuristic2(candidates, Mcurmax);
+                    getNeighborsByHeuristic2(candidates, Mcurmax, data_mem);
                     int indx = 0;
                     while (candidates.size() > 0) {
@@ -746,6 +783,152 @@ namespace hnswlib {
             return data;
         }
+        /**
+         * Rebuilds the index without the elements marked as deleted and releases the old storage.
+         * Live elements are re-inserted, in internal-id order, into freshly allocated buffers, so their
+         * internal ids and levels change. The rebuild itself takes no locks; only the final swap holds
+         * `global`, and the members are replaced only after the whole rebuild succeeded.
+         * NOTE(review): a deleted-element counter, if the class keeps one elsewhere, is not reset here.
+         * @param level_param when > 0, every element is placed on this level instead of a random one
+         * @throws std::runtime_error on allocation failure or when an inconsistent link is found
+         */
+        void clear_mem(int level_param=-1)
+        {
+            // decltype keeps the temporaries type-compatible with the members they replace below.
+            decltype(label_lookup_) label_lookup_tmp;  // NOTE(review): member declared outside this view - confirm its name
+            decltype(element_levels_) element_levels_tmp(max_elements_);
+            // An empty lock vector makes searchBaseLayer/mutuallyConnectNewElement skip locking.
+            decltype(link_list_locks_) link_list_locks_tmp;
+            size_t cur_element_count_tmp = 0;
+            //initializations for special treatment of the first node
+            tableint enterpoint_node_tmp = -1;
+            int maxlevel_tmp = -1;
+            char *data_level0_memory_tmp = (char *) malloc(max_elements_ * size_data_per_element_);
+            if (data_level0_memory_tmp == nullptr)
+                throw std::runtime_error("Not enough memory: clear_mem failed to allocate level0");
+            char **linkLists_tmp = (char **) malloc(sizeof(void *) * max_elements_);
+            if (linkLists_tmp == nullptr) {
+                free(data_level0_memory_tmp);
+                throw std::runtime_error("Not enough memory: clear_mem failed to allocate linklists");
+            }
+            VisitedListPool *visited_list_pool_tmp = nullptr;
+            try {
+                visited_list_pool_tmp = new VisitedListPool(1, max_elements_);
+                for (size_t i = 0; i < cur_element_count; i++) {
+                    if (isMarkedDeleted(i))
+                        continue;
+                    // get data from src
+                    const auto label = getExternalLabel(i);
+                    const auto data_point = getDataByInternalId(i);
+                    auto cur_c_tmp = cur_element_count_tmp;
+                    cur_element_count_tmp++;
+                    label_lookup_tmp[label] = cur_c_tmp;
+                    int curlevel = getRandomLevel(mult_);
+                    if (level_param > 0)
+                        curlevel = level_param;
+                    element_levels_tmp[cur_c_tmp] = curlevel;
+
+                    int maxlevelcopy = maxlevel_tmp;
+                    tableint currObj = enterpoint_node_tmp;
+                    tableint enterpoint_copy = enterpoint_node_tmp;
+                    memset(data_level0_memory_tmp + cur_c_tmp * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);
+                    // Initialisation of the data and label
+                    memcpy(getExternalLabeLp(cur_c_tmp, data_level0_memory_tmp), &label, sizeof(labeltype));
+                    memcpy(getDataByInternalId(cur_c_tmp, data_level0_memory_tmp), data_point, data_size_);
+
+                    if (curlevel) {
+                        linkLists_tmp[cur_c_tmp] = (char *) malloc(size_links_per_element_ * curlevel + 1);
+                        if (linkLists_tmp[cur_c_tmp] == nullptr)
+                            throw std::runtime_error("Not enough memory: clear_mem failed to allocate linklist");
+                        memset(linkLists_tmp[cur_c_tmp], 0, size_links_per_element_ * curlevel + 1);
+                    }
+
+                    if ((signed)currObj != -1) {
+                        if (curlevel < maxlevelcopy) {
+                            // Greedy descent through the levels above curlevel to find a close entry point.
+                            dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj, data_level0_memory_tmp), dist_func_param_);
+                            for (int level = maxlevelcopy; level > curlevel; level--) {
+                                bool changed = true;
+                                while (changed) {
+                                    changed = false;
+                                    unsigned int *data = get_linklist(currObj, level, linkLists_tmp);
+                                    int size = getListCount(data);
+                                    tableint *datal = (tableint *) (data + 1);
+                                    for (int k = 0; k < size; k++) {
+                                        tableint cand = datal[k];
+                                        if (cand < 0 || cand > max_elements_)
+                                            throw std::runtime_error("cand error");
+                                        dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand, data_level0_memory_tmp), dist_func_param_);
+                                        if (d < curdist) {
+                                            curdist = d;
+                                            currObj = cand;
+                                            changed = true;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+
+                        bool epDeleted = isMarkedDeleted(enterpoint_copy, data_level0_memory_tmp);
+                        for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
+                            if (level > maxlevelcopy || level < 0)  // possible?
+                                throw std::runtime_error("Level error");
+                            auto top_candidates = searchBaseLayer(
+                                    currObj, data_point, level, data_level0_memory_tmp, linkLists_tmp, visited_list_pool_tmp, link_list_locks_tmp);
+                            if (epDeleted) {
+                                top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy, data_level0_memory_tmp), dist_func_param_), enterpoint_copy);
+                                if (top_candidates.size() > ef_construction_)
+                                    top_candidates.pop();
+                            }
+                            currObj = mutuallyConnectNewElement(data_point, cur_c_tmp, top_candidates, level, false, data_level0_memory_tmp, linkLists_tmp, element_levels_tmp, link_list_locks_tmp);
+                        }
+                    } else {
+                        // Do nothing for the first element
+                        enterpoint_node_tmp = 0;
+                        maxlevel_tmp = curlevel;
+                    }
+
+                    // An element on a new top level becomes the entry point.
+                    if (curlevel > maxlevelcopy) {
+                        enterpoint_node_tmp = cur_c_tmp;
+                        maxlevel_tmp = curlevel;
+                    }
+                }
+            } catch (...) {
+                // Roll back: the live index is still untouched, so only the temporaries are released.
+                for (size_t k = 0; k < cur_element_count_tmp; k++) {
+                    if (element_levels_tmp[k] > 0)
+                        free(linkLists_tmp[k]);
+                }
+                free(linkLists_tmp);
+                free(data_level0_memory_tmp);
+                delete visited_list_pool_tmp;
+                throw;
+            }
+
+            // swap and free mem
+            {
+                std::unique_lock<decltype(global)> templock(global);
+                // Free the old per-element lists while cur_element_count still holds the OLD count.
+                for (tableint i = 0; i < cur_element_count; i++) {
+                    if (element_levels_[i] > 0)
+                        free(linkLists_[i]);
+                }
+                free(linkLists_);
+                free(data_level0_memory_);
+                delete visited_list_pool_;
+
+                linkLists_ = linkLists_tmp;
+                data_level0_memory_ = data_level0_memory_tmp;
+                visited_list_pool_ = visited_list_pool_tmp;
+                cur_element_count = cur_element_count_tmp;
+                label_lookup_ = std::move(label_lookup_tmp);
+                element_levels_ = std::move(element_levels_tmp);
+                enterpoint_node_ = enterpoint_node_tmp;
+                maxlevel_ = maxlevel_tmp;
+            }
+        }
+
         static const unsigned char DELETE_MARK = 0x01;
 //        static const unsigned char REUSE_MARK = 0x10;
/** @@ -823,6 +1006,11 @@ namespace hnswlib { return *ll_cur & DELETE_MARK; } + bool isMarkedDeleted(tableint internalId, char *data_mem) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId, data_mem))+2; + return *ll_cur & DELETE_MARK; + } + unsigned short int getListCount(linklistsizeint * ptr) const { return *((unsigned short int *)ptr); } @@ -832,7 +1020,7 @@ namespace hnswlib { } void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label,-1); + addPoint(data_point, label, -1); } void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { @@ -1003,7 +1191,7 @@ namespace hnswlib { unmarkDeletedInternal(existingInternalId); } updatePoint(data_point, existingInternalId, 1.0); - + return existingInternalId; } diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 12f38e2e..2a106482 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -639,6 +639,10 @@ class Index { appr_alg->unmarkDelete(label); } + void clearDeleted(int level){ + appr_alg->clear_mem(level); + } + void resizeIndex(size_t new_size) { appr_alg->resizeIndex(new_size); } @@ -854,6 +858,7 @@ PYBIND11_PLUGIN(hnswlib) { .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) + .def("clear_deleted", &Index::clearDeleted, py::arg("level")=-1) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) .def("get_max_elements", &Index::getMaxElements) .def("get_current_count", &Index::getCurrentCount) From a09d8c65fafaaaa8db8a2b69f9c8ae694b69cda3 Mon Sep 17 00:00:00 2001 From: ilyajob05 Date: Wed, 30 Nov 2022 17:25:17 +0300 Subject: [PATCH 2/3] added bindings_test_clear.py --- python_bindings/tests/bindings_test_clear.py | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 
python_bindings/tests/bindings_test_clear.py
diff --git a/python_bindings/tests/bindings_test_clear.py b/python_bindings/tests/bindings_test_clear.py
new file mode 100644
index 00000000..d3231152
--- /dev/null
+++ b/python_bindings/tests/bindings_test_clear.py
@@ -0,0 +1,71 @@
+import unittest
+
+import numpy as np
+
+import hnswlib
+
+
+class RandomSelfTestCase(unittest.TestCase):
+    def testRandomSelf(self):
+        """Fill half the index, mark two items deleted, call clear_deleted(), then add the second half."""
+        for idx in range(16):
+            print("\n**** Index resize test ****\n")
+            np.random.seed(idx)
+            dim = 16
+            num_elements = 10000
+
+            # Generating sample data
+            data = np.float32(np.random.random((num_elements, dim)))
+
+            # Declaring index
+            p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
+
+            # Initiating index
+            # max_elements - the maximum number of elements, should be known beforehand
+            #     (probably will be made optional in the future)
+            #
+            # ef_construction - controls index search speed/build speed tradeoff
+            # M - is tightly connected with internal dimensionality of the data
+            #     strongly affects the memory consumption
+
+            p.init_index(max_elements=num_elements//2, ef_construction=100, M=16)
+
+            # Controlling the recall by setting ef:
+            # higher ef leads to better accuracy, but slower search
+            p.set_ef(20)
+
+            p.set_num_threads(idx % 8)  # by default using all available cores
+
+            # We split the data in two batches:
+            data1 = data[:num_elements // 2]
+            data2 = data[num_elements // 2:]
+
+            print("Adding first batch of %d elements" % (len(data1)))
+            p.add_items(data1)
+
+            # Query the elements for themselves and measure recall:
+            labels, distances = p.knn_query(data1, k=1)
+
+            items = p.get_items(list(range(len(data1))))
+
+            # Check the recall:
+            self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3)
+
+            # Check that the returned element data is correct:
+            diff_with_gt_labels = np.max(np.abs(data1-items))
+            self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)
+
+            print("Resizing the index")
+            p.resize_index(num_elements)
+
+            print("Remove index")
+            p.mark_deleted(1)
+            p.mark_deleted(2)
+            p.clear_deleted()
+
+            print("Adding the second batch of %d elements" % (len(data2)))
+            p.add_items(data2)
+
+            # The two deleted items were dropped by clear_deleted(), so the count is two short of num_elements:
+            self.assertEqual(p.get_current_count(), num_elements - 2)
+
From 9d5e063effd319c04acfa2dc86dba6bd21b4e4b7 Mon Sep 17 00:00:00 2001
From: ilyajob05
Date: Thu, 12 Jan 2023 14:33:28 +0300
Subject: [PATCH 3/3] Update hnswalg.h
---
 hnswlib/hnswalg.h | 1 +
 1 file changed, 1 insertion(+)
diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h
index e3a283a9..d2b527df 100644
--- a/hnswlib/hnswalg.h
+++ b/hnswlib/hnswalg.h
@@ -1392,3 +1392,4 @@ namespace hnswlib {
     };
 }
+