pecos.ann.hnsw is a PECOS Approximate Nearest Neighbor (ANN) search module that implements the Hierarchical Navigable Small World Graphs (HNSW) algorithm (Malkov et al., TPAMI 2018).
- Supports both sparse and dense input features
- SIMD-optimized distance computation for both dense and sparse inputs
- Supports thread-safe graph construction in parallel on multi-core shared-memory machines
- Supports thread-safe searchers for parallel inference, which reduces inference overhead
Basic training (building the HNSW index) and prediction (HNSW inference):
python3 -m pecos.ann.hnsw.train -x ${X_path} -m ${model_folder}
python3 -m pecos.ann.hnsw.predict -x ${Xt_path} -m ${model_folder} -o ${Yp_path}
where
- X_path and Xt_path are the paths to the CSR npz or row-major npy files of the training/test feature matrices with shape (N, d) and (Nt, d), respectively;
- model_folder is the path to the model folder where the trained model will be saved; it will be created if it does not exist;
- Yp_path is the path to save the prediction label matrix with shape (Nt, N).
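As a minimal sketch of preparing these input files, the snippet below writes a dense float32 matrix with np.save (which stores C-order, i.e. row-major, arrays) and a sparse matrix with scipy.sparse.save_npz (CSR format). That the CLI reads exactly these file layouts is an assumption based on the description above; the file names and shapes are illustrative only.

```python
import os
import tempfile

import numpy as np
import scipy.sparse as smat

# Toy feature matrices (shapes are illustrative only).
X_dense = np.random.randn(1000, 64).astype(np.float32)
X_sparse = smat.random(1000, 5000, density=0.01, format="csr", dtype=np.float32)

out_dir = tempfile.mkdtemp()

# Dense features: np.save writes a C-ordered (row-major) .npy file.
dense_path = os.path.join(out_dir, "X.npy")
np.save(dense_path, np.ascontiguousarray(X_dense))

# Sparse features: save_npz stores the CSR matrix in a .npz file.
sparse_path = os.path.join(out_dir, "X.npz")
smat.save_npz(sparse_path, X_sparse)
```

Either path could then be passed as ${X_path}, e.g. python3 -m pecos.ann.hnsw.train -x <out_dir>/X.npy -m ${model_folder}.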
The most commonly used training parameters are:
- --metric-type: the distance metric; we currently support two metrics, namely ip and l2.
- --max-edge-per-node: maximum number of edges per node, M, for layers l=1,...,L (default 32). For the base layer l=0, it becomes 2M.
- --efConstruction: size of the priority queue when performing best-first search during construction (default 100).
For more details, please refer to
python3 -m pecos.ann.hnsw.train --help
The most commonly used prediction parameters are:
- --efSearch: size of the priority queue when performing best-first search during inference (default 100).
- --only-topk: maximum number of candidates (sorted by distance, nearest first) to be returned (default 10).
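To illustrate what the priority-queue size controls, here is a minimal single-layer best-first search in plain Python. It is a sketch, not PECOS's SIMD-optimized implementation: the adjacency-dict graph layout, the function name best_first_search, and the fixed entry node are hypothetical, and a full HNSW search would first descend greedily through the upper layers to choose the entry point for layer 0. The visited set and the two heaps are the kind of per-query state a search needs to keep.

```python
import heapq

import numpy as np


def best_first_search(graph, X, q, entry, ef, topk, dist):
    """Best-first search on one graph layer, keeping at most `ef` results.

    graph: dict mapping node id -> list of neighbor ids (hypothetical layout).
    Returns up to `topk` (distance, node) pairs, nearest first.
    """
    d0 = dist(q, X[entry])
    visited = {entry}
    candidates = [(d0, entry)]   # min-heap: closest unexpanded node on top
    results = [(-d0, entry)]     # max-heap via negation: worst kept result on top
    while candidates:
        d, node = heapq.heappop(candidates)
        if d > -results[0][0]:
            break                # nearest candidate is farther than the worst result
        for nbr in graph[node]:
            if nbr in visited:
                continue
            visited.add(nbr)
            dn = dist(q, X[nbr])
            if len(results) < ef or dn < -results[0][0]:
                heapq.heappush(candidates, (dn, nbr))
                heapq.heappush(results, (-dn, nbr))
                if len(results) > ef:
                    heapq.heappop(results)  # evict the current worst result
    return sorted((-neg_d, n) for neg_d, n in results)[:topk]


def sq_l2(a, b):
    """Squared Euclidean distance between two vectors."""
    return float(np.sum((a - b) ** 2))


# Toy usage: 20 points on a line, each linked to its left/right neighbor.
X = np.arange(20, dtype=np.float32).reshape(-1, 1)
graph = {i: [j for j in (i - 1, i + 1) if 0 <= j < 20] for i in range(20)}
res = best_first_search(graph, X, np.array([7.2], dtype=np.float32),
                        entry=0, ef=5, topk=3, dist=sq_l2)
print([node for _, node in res])  # prints [7, 8, 6]
```

A larger ef keeps more candidates alive, which usually improves recall at the cost of more distance computations; topk only truncates the final sorted list, so in this sketch it should not exceed ef.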
Remark: for metric_type=ip, we define the distance as 1 - <q,x>.
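The remark can be checked numerically. Below is a minimal NumPy sketch of the ip distance and an exact (brute-force) ranking under it; the helper names ip_distance and brute_force_topk are hypothetical and not part of PECOS. Since the vectors are not required to be normalized, 1 - <q,x> can be negative, so it is not a true metric; ranking by it is the same as ranking by largest inner product.

```python
import numpy as np


def ip_distance(Q, X):
    """Distance 1 - <q, x> for every query row q of Q and database row x of X."""
    return 1.0 - Q @ X.T


def brute_force_topk(Q, X, k):
    """Exact k nearest rows of X for each row of Q under the ip distance."""
    D = ip_distance(Q, X)
    idx = np.argsort(D, axis=1, kind="stable")[:, :k]  # ascending: nearest first
    return idx, np.take_along_axis(D, idx, axis=1)


Q = np.array([[1.0, 0.0]], dtype=np.float32)
X = np.array([[0.5, 0.0], [2.0, 0.0], [0.0, 1.0]], dtype=np.float32)
idx, dist = brute_force_topk(Q, X, k=2)
# idx is [[1, 0]]: x = [2, 0] has the largest inner product, hence distance -1.
```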
For more details, please refer to
python3 -m pecos.ann.hnsw.predict --help
First, let's create the database matrix X_trn and the query matrix X_tst. We use dense numpy matrices in this illustration, but keep in mind that sparse input features are supported as well!
import numpy as np
X_trn = np.random.randn(10000, 100).astype(np.float32)
X_tst = np.random.randn(1000, 100).astype(np.float32)
Note that the data type needs to be np.float32.
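Inputs in another dtype or memory layout can be converted up front. The sketch below uses a hypothetical helper, to_float32_features, to produce a C-contiguous (row-major) float32 array for dense inputs or a float32 CSR matrix for sparse ones; that HNSW.train accepts such a scipy.sparse CSR matrix is an assumption based on the sparse-input support mentioned above.

```python
import numpy as np
import scipy.sparse as smat


def to_float32_features(X):
    """Hypothetical helper: cast features to float32 in a row-major layout.

    Sparse inputs become float32 CSR matrices; dense inputs become
    C-contiguous float32 numpy arrays.
    """
    if smat.issparse(X):
        return X.tocsr().astype(np.float32)
    return np.ascontiguousarray(X, dtype=np.float32)


# np.random.randn returns float64, so it needs a cast.
X_dense = to_float32_features(np.random.randn(100, 16))
# A sparse matrix in COO format is converted to float32 CSR.
X_sparse = to_float32_features(smat.random(100, 1000, density=0.01, format="coo"))
```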
Train the HNSW model (i.e., build the graph-based indexing data structure) using the maximum number of threads available on your machine (threads=-1):
from pecos.ann.hnsw import HNSW
train_params = HNSW.TrainParams(M=32, efC=300, metric_type="ip", threads=-1)
model = HNSW.train(X_trn, train_params=train_params, pred_params=None)
Users can also train with the default parameters via
model = HNSW.train(X_trn)
After training, we can save the model to disk and reload it:
model_folder = "./tmp-hnsw-model"
model.save(model_folder)
del model
model = HNSW.load(model_folder)
Next, we initialize multiple searchers for the inference stage. The searchers pre-allocate intermediate variables that are later used by the HNSW graph search (e.g., which nodes have been visited, the priority queues storing the candidates, etc.).
# here we would like to use FOUR threads for parallel inference
searchers = model.searchers_create(num_searcher=4)
Finally, we conduct ANN inference by passing the searchers to the HNSW model.
pred_params = HNSW.PredParams(efS=100, topk=10)
Yt_pred = model.predict(X_tst, pred_params=pred_params, searchers=searchers)
where Yt_pred is a scipy.sparse.csr_matrix whose column indices in each row are sorted by distance in ascending order (nearest first).
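Because the ranking lives in the stored order of each CSR row, a row is best read directly through indptr rather than after re-sorting. The sketch below uses a hypothetical helper, row_neighbors, on a toy matrix; that Yt_pred stores the distances as its values is an assumption not stated above.

```python
import numpy as np
import scipy.sparse as smat


def row_neighbors(Y, i):
    """Return (column indices, stored values) of row i of CSR matrix Y, in stored order."""
    start, end = Y.indptr[i], Y.indptr[i + 1]
    return Y.indices[start:end], Y.data[start:end]


# Toy 2 x 5 matrix whose rows are stored nearest-first (values = distances).
Y = smat.csr_matrix(
    (np.array([0.1, 0.3, 0.7, 0.2, 0.9], dtype=np.float32),
     np.array([4, 0, 2, 1, 3]),
     np.array([0, 3, 5])),
    shape=(2, 5),
)
nbrs, dists = row_neighbors(Y, 0)  # nbrs: [4, 0, 2], dists in ascending order

# Pitfall: sort_indices() reorders each row by column index, losing the ranking.
Z = Y.copy()
Z.sort_indices()
```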
Alternatively, it is also possible to do inference without pre-allocating searchers, which may incur larger overhead since the intermediate graph-searching variables are re-allocated for each query matrix X_tst.
pred_params.threads = 2
indices, distances = model.predict(X_tst, pred_params=pred_params, ret_csr=False)
When ret_csr=False, the prediction function returns the indices and distances as numpy arrays.
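One common way to sanity-check the returned indices is recall@k against an exact brute-force search under the same metric (for metric_type=ip, ranking by 1 - <q,x>). The helper below is a hypothetical sketch, not part of pecos.ann.hnsw, and assumes both inputs are integer arrays holding one row of k neighbor ids per query.

```python
import numpy as np


def recall_at_k(pred_indices, true_indices):
    """Average fraction of the exact top-k neighbors recovered per query.

    Both arguments are integer arrays of shape (Nt, k); the order within a
    row does not matter for this metric.
    """
    k = true_indices.shape[1]
    hits = [len(set(p.tolist()) & set(t.tolist()))
            for p, t in zip(pred_indices, true_indices)]
    return float(np.mean(hits)) / k


pred = np.array([[0, 1, 2], [3, 4, 5]])
true = np.array([[2, 1, 9], [5, 4, 3]])
r = recall_at_k(pred, true)  # (2 + 3) / 2 hits averaged, divided by k = 3
```

With the example above, and assuming indices holds one row of topk ids per query, the exact neighbors could be obtained with np.argsort(1 - X_tst @ X_trn.T, axis=1)[:, :10] and compared to indices.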