Skip to content

Commit

Permalink
USearch indexing for Hoplite DB.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684845341
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 22, 2024
1 parent 1b654c1 commit 8844cf1
Show file tree
Hide file tree
Showing 6 changed files with 665 additions and 14 deletions.
18 changes: 13 additions & 5 deletions chirp/projects/agile2/1_embed_audio_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
"from chirp.projects.agile2 import colab_utils\n",
"from chirp.projects.agile2 import embed\n",
"from chirp.projects.agile2 import source_info\n",
"from chirp.projects.hoplite import interface\n"
"from chirp.projects.hoplite import interface\n",
"from chirp.projects.hoplite import brutalism\n",
"from chirp.projects.hoplite import db_loader\n",
"from chirp.projects.hoplite import sqlite_usearch_impl\n"
]
},
{
Expand Down Expand Up @@ -55,7 +58,7 @@
"#@markdown like '/home/me/myproject/site_XYZ/audio_ABC.wav'\n",
"dataset_name = '' #@param {type:'string'}\n",
"dataset_base_path = '' #@param {type:'string'}\n",
"dataset_fileglob = '' #@param {type:'string'}\n",
"dataset_fileglob = '*.wav' #@param {type:'string'}\n",
"\n",
"#@markdown Choose a supported model: `perch_8` or `birdnet_v2.3` are most common\n",
"#@markdown for birds. Other choices include `surfperch` for coral reefs or\n",
Expand All @@ -82,7 +85,8 @@
"configs = colab_utils.load_configs(\n",
" source_info.AudioSources((audio_glob,)),\n",
" db_path,\n",
" model_config_key=model_choice)\n",
" model_config_key=model_choice,\n",
" db_key = 'sqlite_usearch')\n",
"configs"
]
},
Expand Down Expand Up @@ -164,10 +168,14 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hr_AUAfI7UG_"
"id": "ihBNRbwuuwal"
},
"outputs": [],
"source": []
"source": [
"q = db.get_embedding(444)\n",
"%time results, scores = brutalism.brute_search(worker.db, query_embedding=q, search_list_size=128, score_fn=np.dot)\n",
"print([r.embedding_id for r in results])"
]
}
],
"metadata": {
Expand Down
19 changes: 13 additions & 6 deletions chirp/projects/agile2/colab_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from chirp.projects.zoo import model_configs
from etils import epath
from ml_collections import config_dict
import numpy as np


@dataclasses.dataclass
Expand All @@ -49,6 +50,7 @@ def load_configs(
audio_sources: source_info.AudioSources,
db_path: str | None = None,
model_config_key: str = 'perch_8',
db_key: str = 'sqlite_usearch',
) -> AgileConfigs:
"""Load default configs for the notebook and return them as an AgileConfigs.
Expand All @@ -68,10 +70,7 @@ def load_configs(
'db_path must be specified when embedding multiple datasets.'
)
# Put the DB in the same directory as the audio.
db_path = (
epath.Path(next(iter(audio_sources.audio_globs)).base_path)
/ 'hoplite_db.sqlite'
)
db_path = epath.Path(next(iter(audio_sources.audio_globs)).base_path)

model_key, embedding_dim, model_config = (
model_configs.get_preset_model_config(model_config_key)
Expand All @@ -83,11 +82,19 @@ def load_configs(
)
db_config = config_dict.ConfigDict({
'db_path': db_path,
'embedding_dim': embedding_dim,
})
if db_key == 'sqlite_usearch':
# A sane default.
db_config.usearch_cfg = config_dict.ConfigDict({
'embedding_dim': embedding_dim,
'metric_name': 'IP',
'expansion_add': 256,
'expansion_search': 128,
'dtype': 'float16',
})

return AgileConfigs(
audio_sources_config=audio_sources,
db_config=db_loader.DBConfig('sqlite', db_config),
db_config=db_loader.DBConfig(db_key, db_config),
model_config=db_model_config,
)
3 changes: 3 additions & 0 deletions chirp/projects/hoplite/db_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from chirp.projects.hoplite import in_mem_impl
from chirp.projects.hoplite import interface
from chirp.projects.hoplite import sqlite_impl
from chirp.projects.hoplite import sqlite_usearch_impl
from ml_collections import config_dict
import numpy as np
import tqdm
Expand All @@ -40,6 +41,8 @@ def load_db(self) -> interface.GraphSearchDBInterface:
"""Load the database from the specified path."""
if self.db_key == 'sqlite':
return sqlite_impl.SQLiteGraphSearchDB.create(**self.db_config)
elif self.db_key == 'sqlite_usearch':
return sqlite_usearch_impl.SQLiteUsearchDB.create(**self.db_config)
elif self.db_key == 'in_mem':
return in_mem_impl.InMemoryGraphSearchDB.create(**self.db_config)
else:
Expand Down
Loading

0 comments on commit 8844cf1

Please sign in to comment.