Skip to content

Commit

Permalink
Repair graph method nmslib#515
Browse files Browse the repository at this point in the history
  • Loading branch information
kishorenc committed Oct 23, 2023
1 parent 5aba40d commit 37921cf
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@ jobs:
./multiThread_replace_test
./test_updates
./test_updates update
./repair_test
shell: bash
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)

add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
target_link_libraries(main hnswlib)

add_executable(repair_test tests/cpp/repair_test.cpp)
target_link_libraries(repair_test hnswlib)
endif()
155 changes: 124 additions & 31 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::mutex deleted_elements_lock; // lock for deleted_elements
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements

std::mutex repair_lock; // locks graph repair


HierarchicalNSW(SpaceInterface<dist_t> *s) {
}
Expand Down Expand Up @@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


int getRandomLevel(double reverse_size) {
int getRandomLevel(double ml) {
std::uniform_real_distribution<double> distribution(0.0, 1.0);
double r = -log(distribution(level_generator_)) * reverse_size;
double r = -log(distribution(level_generator_)) * ml;
return (int) r;
}

Expand Down Expand Up @@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
size_t size = getListCount((linklistsizeint*)data);
linklistsizeint *data = get_linklist_at_level(curNodeNum, layer);
size_t size = getListCount(data);
tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
Expand Down Expand Up @@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
candidate_set.pop();

tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
linklistsizeint *data = get_linklist0(current_node_id);
size_t size = getListCount(data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
if (collect_metrics) {
metric_hops++;
Expand Down Expand Up @@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if (isUpdate) {
lock.lock();
}
linklistsizeint *ll_cur;
if (level == 0)
ll_cur = get_linklist0(cur_c);
else
ll_cur = get_linklist(cur_c, level);
linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level);

if (*ll_cur && !isUpdate) {
throw std::runtime_error("The newly inserted element should have blank link list");
Expand All @@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

linklistsizeint *ll_other;
if (level == 0)
ll_other = get_linklist0(selectedNeighbors[idx]);
else
ll_other = get_linklist(selectedNeighbors[idx], level);

linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level);
size_t sz_link_list_other = getListCount(ll_other);

if (sz_link_list_other > Mcurmax)
Expand Down Expand Up @@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

{
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
linklistsizeint *ll_cur;
ll_cur = get_linklist_at_level(neigh, layer);
linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer);
size_t candSize = candidates.size();
setListCount(ll_cur, candSize);
tableint *data = (tableint *) (ll_cur + 1);
Expand Down Expand Up @@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist_at_level(currObj, level);
int size = getListCount(data);
Expand Down Expand Up @@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
unsigned int *data = get_linklist_at_level(internalId, level);
linklistsizeint *data = get_linklist_at_level(internalId, level);
int size = getListCount(data);
std::vector<tableint> result(size);
tableint *ll = (tableint *) (data + 1);
Expand Down Expand Up @@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

cur_c = cur_element_count;
// use the element level as a flag to show that an element is not added yet
// the element count is increased but no lock is aquired
// so someone can start using the new element
element_levels_[cur_c] = -1;
cur_element_count++;
label_lookup_[label] = cur_c;
}
Expand Down Expand Up @@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level);
int size = getListCount(data);
Expand Down Expand Up @@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

data = (unsigned int *) get_linklist(currObj, level);
linklistsizeint *data = get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;
Expand Down Expand Up @@ -1271,5 +1259,110 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}


void repair_zero_indegree() {
// only one repair is allowed to be in progress at a time
std::unique_lock <std::mutex> lock_repair(repair_lock);

int maxlevel_copy = maxlevel_;
size_t element_count_copy = cur_element_count;
std::vector<size_t> indegree(element_count_copy);

for (int level = maxlevel_copy; level >=0 ; level--) {
std::fill(indegree.begin(), indegree.end(), 0);

size_t m_max = level ? maxM_ : maxM0_;
int num_elements = 0;
// calculate in-degree
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
// lock until addition is finished
std::unique_lock <std::mutex> lock_el(link_list_locks_[internal_id]);
// skip elements that are not in the current level
// Note: if the element was not added to the graph before the lock
// then element_level = -1 and we skip it as well
int element_level = element_levels_[internal_id];
if (element_level < level) {
continue;
}

linklistsizeint *ll = get_linklist_at_level(internal_id, level);
int size = getListCount(ll);
tableint *datal = (tableint *) (ll + 1);
for (int i = 0; i < size; i++) {
tableint nei_id = datal[i];
// skip newly added elements
if (nei_id >= element_count_copy) {
continue;
}
indegree[nei_id] += 1;
}
num_elements += 1;
}

// skip levels with 1 element
if (num_elements <= 1) {
continue;
}

// fix elements with 0 in-degree
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
int element_level = element_levels_[internal_id];
if (element_level < level || indegree[internal_id] > 0) {
continue;
}

char* data_point = getDataByInternalId(internal_id);
tableint currObj = enterpoint_node_;

dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
for (int level_above = maxlevel_copy; level_above > level; level_above--) {
bool changed = true;
while (changed) {
changed = false;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist_at_level(currObj, level_above);
int size = getListCount(data);

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates = searchBaseLayer(
currObj, data_point, level);

while (candidates.size() > 0) {
tableint cand_id = candidates.top().second;
// skip same element
if (cand_id == internal_id) {
candidates.pop();
continue;
}

// try to connect candidate to the element
// add an edge if there is space
std::unique_lock <std::mutex> lock(link_list_locks_[cand_id]);
linklistsizeint *ll_cand = get_linklist_at_level(cand_id, level);
tableint *data_cand = (tableint *) (ll_cand + 1);
size_t size = getListCount(ll_cand);
if (size < m_max) {
data_cand[size] = internal_id;
setListCount(ll_cand, size + 1);
}
candidates.pop();
}
}
}
}
};
} // namespace hnswlib

0 comments on commit 37921cf

Please sign in to comment.