diff --git a/chirp/projects/agile2/1_embed_audio_v2.ipynb b/chirp/projects/agile2/1_embed_audio_v2.ipynb index 06badf1b..ecaaaece 100644 --- a/chirp/projects/agile2/1_embed_audio_v2.ipynb +++ b/chirp/projects/agile2/1_embed_audio_v2.ipynb @@ -110,7 +110,6 @@ " os.unlink(configs.db_config.db_config.db_path)\n", " print('\\n Deleted previous db at: ', configs.db_config.db_config.db_path)\n", " db = configs.db_config.load_db()\n", - " db.setup()\n", "\n", "drop_existing_db = True #@param[True, False]\n", "\n", diff --git a/chirp/projects/agile2/convert_legacy.py b/chirp/projects/agile2/convert_legacy.py index 91392c4b..a6d2bbe3 100644 --- a/chirp/projects/agile2/convert_legacy.py +++ b/chirp/projects/agile2/convert_legacy.py @@ -60,7 +60,6 @@ def convert_tfrecords( ) else: raise ValueError(f'Unknown db type: {db_type}') - db.setup() # Convert embedding config to new format and insert into the DB. legacy_config = embed_lib.load_embedding_config(embeddings_path) diff --git a/chirp/projects/agile2/ingest_annotations.py b/chirp/projects/agile2/ingest_annotations.py index 4222989d..47be6bae 100644 --- a/chirp/projects/agile2/ingest_annotations.py +++ b/chirp/projects/agile2/ingest_annotations.py @@ -204,7 +204,6 @@ def embed_annotated_dataset( ) ) db = db_config.load_db() - db.setup() print('Initialized DB located at ', db_filepath) worker = embed.EmbedWorker( audio_sources=audio_srcs_config, db=db, model_config=db_model_config diff --git a/chirp/projects/agile2/tests/test_utils.py b/chirp/projects/agile2/tests/test_utils.py index 0221e75e..e2063a5b 100644 --- a/chirp/projects/agile2/tests/test_utils.py +++ b/chirp/projects/agile2/tests/test_utils.py @@ -69,7 +69,6 @@ def make_db( ) else: raise ValueError(f'Unknown db type: {db_type}') - db.setup() # Insert a few embeddings... 
graph_utils.insert_random_embeddings(db, embedding_dim, num_embeddings, rng) diff --git a/chirp/projects/hoplite/db_loader.py b/chirp/projects/hoplite/db_loader.py index c2d28e74..31042fa0 100644 --- a/chirp/projects/hoplite/db_loader.py +++ b/chirp/projects/hoplite/db_loader.py @@ -57,7 +57,6 @@ def duplicate_db( ): """Create a new DB and copy all data in source_db into it.""" target_db = DBConfig(target_db_key, target_db_config).load_db() - target_db.setup() target_db.commit() # Check that the target_db is empty. If not, we'll have to do something more diff --git a/chirp/projects/hoplite/in_mem_impl.py b/chirp/projects/hoplite/in_mem_impl.py index 86f4d5ea..aa6a30f2 100644 --- a/chirp/projects/hoplite/in_mem_impl.py +++ b/chirp/projects/hoplite/in_mem_impl.py @@ -64,21 +64,14 @@ def create(cls, **kwargs): db.drop_all_edges() return db - embeddings = np.zeros([ - 1, - ]) + embeddings = np.zeros( + [kwargs['max_size'], kwargs['embedding_dim']], + dtype=kwargs.get('embedding_dtype', np.float16), + ) db = cls(embeddings=embeddings, **kwargs) - db.setup() + db.drop_all_edges() return db - def setup(self): - """Initialize an empty database.""" - self.embeddings = np.zeros( - (self.max_size, self.embedding_dim), dtype=self.embedding_dtype - ) - # Dropping all edges initializes the edge table. 
- self.drop_all_edges() - @functools.cached_property def empty_edges(self): return -1 * np.ones((self.degree_bound,), dtype=np.int64) diff --git a/chirp/projects/hoplite/interface.py b/chirp/projects/hoplite/interface.py index 3552c550..bffa956e 100644 --- a/chirp/projects/hoplite/interface.py +++ b/chirp/projects/hoplite/interface.py @@ -103,10 +103,6 @@ class GraphSearchDBInterface(abc.ABC): def create(cls, **kwargs): """Connect to and, if needed, initialize the database.""" - @abc.abstractmethod - def setup(self): - """Initialize an empty database.""" - @abc.abstractmethod def commit(self) -> None: """Commit any pending transactions to the database.""" diff --git a/chirp/projects/hoplite/sqlite_impl.py b/chirp/projects/hoplite/sqlite_impl.py index 00e7a1a1..01e90dd9 100644 --- a/chirp/projects/hoplite/sqlite_impl.py +++ b/chirp/projects/hoplite/sqlite_impl.py @@ -18,7 +18,6 @@ import collections from collections.abc import Sequence import dataclasses -import functools import json import sqlite3 from typing import Any @@ -51,7 +50,9 @@ def create( db = sqlite3.connect(db_path) cursor = db.cursor() cursor.execute('PRAGMA journal_mode=WAL;') # Enable WAL mode + _setup_sqlite_tables(cursor) db.commit() + if embedding_dim is None: # Get an embedding from the DB to check its dimension. 
cursor = db.cursor() @@ -64,7 +65,6 @@ def create( ) from exc embedding = deserialize_embedding(embedding, embedding_dtype) embedding_dim = embedding.shape[-1] - return SQLiteGraphSearchDB(db, db_path, embedding_dim, embedding_dtype) def thread_split(self): @@ -85,79 +85,18 @@ def deserialize_edges(self, serialized_edges: bytes) -> np.ndarray: dtype=np.dtype(np.int64).newbyteorder('<'), ) - def setup(self, index=True): - cursor = self._get_cursor() - # Create embedding sources table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS hoplite_sources ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - dataset STRING NOT NULL, - source STRING NOT NULL - ); - """) - - # Create embeddings table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS hoplite_embeddings ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - embedding BLOB NOT NULL, - source_idx INTEGER NOT NULL, - offsets BLOB NOT NULL, - FOREIGN KEY (source_idx) REFERENCES hoplite_sources(id) - ); - """) - - cursor.execute(""" - CREATE TABLE IF NOT EXISTS hoplite_metadata ( - key STRING PRIMARY KEY, - data STRING NOT NULL - ); - """) - - # Create hoplite_edges table. - cursor.execute(""" - CREATE TABLE IF NOT EXISTS hoplite_edges ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - source_embedding_id INTEGER NOT NULL, - target_embedding_ids BLOB NOT NULL, - FOREIGN KEY (source_embedding_id) REFERENCES embeddings(id) - ); - """) - - # Create hoplite_labels table. - cursor.execute(""" - CREATE TABLE IF NOT EXISTS hoplite_labels ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - embedding_id INTEGER NOT NULL, - label STRING NOT NULL, - type INT NOT NULL, - provenance STRING NOT NULL, - FOREIGN KEY (embedding_id) REFERENCES embeddings(id) - )""") - - if index: - # Create indices for efficient lookups. 
- cursor.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS - idx_embedding ON hoplite_embeddings (id); - """) - cursor.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS - source_pairs ON hoplite_sources (dataset, source); - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS embedding_source ON hoplite_embeddings (source_idx); - """) - cursor.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id); - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_label ON hoplite_labels (embedding_id, label); - """) + def commit(self) -> None: self.db.commit() - def commit(self): + def vacuum_db(self) -> None: + """Clears out the WAL log and defragments data.""" + cursor = self._get_cursor() + cursor.execute('VACUUM;') self.db.commit() + cursor.close() + self._cursor = None + self.db.close() + self.db = sqlite3.connect(self.db_path) def get_embedding_ids(self) -> np.ndarray: cursor = self._get_cursor() @@ -508,6 +447,83 @@ def print_table_values(self, table_name): print(', '.join(str(value) for value in row)) +def _setup_sqlite_tables(cursor: sqlite3.Cursor) -> None: + """Create all needed tables in the SQLite database.""" + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name='hoplite_labels'; + """) + if cursor.fetchone() is not None: + return + + # Create embedding sources table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS hoplite_sources ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + dataset STRING NOT NULL, + source STRING NOT NULL + ); + """) + + # Create embeddings table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS hoplite_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embedding BLOB NOT NULL, + source_idx INTEGER NOT NULL, + offsets BLOB NOT NULL, + FOREIGN KEY (source_idx) REFERENCES hoplite_sources(id) + ); + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS hoplite_metadata ( + key STRING PRIMARY KEY, + data STRING NOT NULL + ); + """) + + # Create 
hoplite_edges table. + cursor.execute(""" + CREATE TABLE IF NOT EXISTS hoplite_edges ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_embedding_id INTEGER NOT NULL, + target_embedding_ids BLOB NOT NULL, + FOREIGN KEY (source_embedding_id) REFERENCES embeddings(id) + ); + """) + + # Create hoplite_labels table. + cursor.execute(""" + CREATE TABLE IF NOT EXISTS hoplite_labels ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + embedding_id INTEGER NOT NULL, + label STRING NOT NULL, + type INT NOT NULL, + provenance STRING NOT NULL, + FOREIGN KEY (embedding_id) REFERENCES embeddings(id) + )""") + + # Create indices for efficient lookups. + cursor.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS + idx_embedding ON hoplite_embeddings (id); + """) + cursor.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS + source_pairs ON hoplite_sources (dataset, source); + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS embedding_source ON hoplite_embeddings (source_idx); + """) + cursor.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id); + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_label ON hoplite_labels (embedding_id, label); + """) + + def serialize_embedding( embedding: np.ndarray, embedding_dtype: type[Any] ) -> bytes: diff --git a/chirp/projects/hoplite/sqlite_usearch_impl.py b/chirp/projects/hoplite/sqlite_usearch_impl.py index 0567ebca..ef23603f 100644 --- a/chirp/projects/hoplite/sqlite_usearch_impl.py +++ b/chirp/projects/hoplite/sqlite_usearch_impl.py @@ -167,10 +167,7 @@ def deserialize_edges(self, serialized_edges: bytes) -> np.ndarray: dtype=np.dtype(np.int64).newbyteorder('<'), ) - def setup(self): - pass - - def commit(self): + def commit(self) -> None: self.db.commit() if self._cursor is not None: self._cursor.close() @@ -180,6 +177,16 @@ def commit(self): # This check is sufficient because the index is strictly additive. 
self.ui.save(self._usearch_filepath.as_posix()) + def vacuum_db(self) -> None: + """Clears out the WAL log and defragments data.""" + cursor = self._get_cursor() + cursor.execute('VACUUM;') + self.db.commit() + cursor.close() + self._cursor = None + self.db.close() + self.db = sqlite3.connect(self.db_path) + def get_embedding_ids(self) -> np.ndarray: # Note that USearch can also create a list of all keys, but it seems # quite slow. diff --git a/chirp/projects/hoplite/tests/test_utils.py b/chirp/projects/hoplite/tests/test_utils.py index 7620cf7a..503c2030 100644 --- a/chirp/projects/hoplite/tests/test_utils.py +++ b/chirp/projects/hoplite/tests/test_utils.py @@ -56,7 +56,6 @@ def make_db( ) else: raise ValueError(f'Unknown db type: {db_type}') - db.setup() # Insert a few embeddings... graph_utils.insert_random_embeddings(db, embedding_dim, num_embeddings, rng) config = config_dict.ConfigDict()