Skip to content

Commit

Permalink
Remove `setup` method from hoplite interface, and add a `vacuum` met…
Browse files Browse the repository at this point in the history
…hod for the SQLite DBs.

PiperOrigin-RevId: 691887853
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 31, 2024
1 parent 982e225 commit 2b2bcc5
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 98 deletions.
1 change: 0 additions & 1 deletion chirp/projects/agile2/1_embed_audio_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
" os.unlink(configs.db_config.db_config.db_path)\n",
" print('\\n Deleted previous db at: ', configs.db_config.db_config.db_path)\n",
" db = configs.db_config.load_db()\n",
" db.setup()\n",
"\n",
"drop_existing_db = True #@param[True, False]\n",
"\n",
Expand Down
1 change: 0 additions & 1 deletion chirp/projects/agile2/convert_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def convert_tfrecords(
)
else:
raise ValueError(f'Unknown db type: {db_type}')
db.setup()

# Convert embedding config to new format and insert into the DB.
legacy_config = embed_lib.load_embedding_config(embeddings_path)
Expand Down
1 change: 0 additions & 1 deletion chirp/projects/agile2/ingest_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def embed_annotated_dataset(
)
)
db = db_config.load_db()
db.setup()
print('Initialized DB located at ', db_filepath)
worker = embed.EmbedWorker(
audio_sources=audio_srcs_config, db=db, model_config=db_model_config
Expand Down
1 change: 0 additions & 1 deletion chirp/projects/agile2/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def make_db(
)
else:
raise ValueError(f'Unknown db type: {db_type}')
db.setup()
# Insert a few embeddings...
graph_utils.insert_random_embeddings(db, embedding_dim, num_embeddings, rng)

Expand Down
1 change: 0 additions & 1 deletion chirp/projects/hoplite/db_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def duplicate_db(
):
"""Create a new DB and copy all data in source_db into it."""
target_db = DBConfig(target_db_key, target_db_config).load_db()
target_db.setup()
target_db.commit()

# Check that the target_db is empty. If not, we'll have to do something more
Expand Down
17 changes: 5 additions & 12 deletions chirp/projects/hoplite/in_mem_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,14 @@ def create(cls, **kwargs):
db.drop_all_edges()
return db

embeddings = np.zeros([
1,
])
embeddings = np.zeros(
[kwargs['max_size'], kwargs['embedding_dim']],
dtype=kwargs.get('embedding_dtype', np.float16),
)
db = cls(embeddings=embeddings, **kwargs)
db.setup()
db.drop_all_edges()
return db

def setup(self):
"""Initialize an empty database."""
self.embeddings = np.zeros(
(self.max_size, self.embedding_dim), dtype=self.embedding_dtype
)
# Dropping all edges initializes the edge table.
self.drop_all_edges()

@functools.cached_property
def empty_edges(self):
return -1 * np.ones((self.degree_bound,), dtype=np.int64)
Expand Down
4 changes: 0 additions & 4 deletions chirp/projects/hoplite/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,6 @@ class GraphSearchDBInterface(abc.ABC):
def create(cls, **kwargs):
"""Connect to and, if needed, initialize the database."""

@abc.abstractmethod
def setup(self):
"""Initialize an empty database."""

@abc.abstractmethod
def commit(self) -> None:
"""Commit any pending transactions to the database."""
Expand Down
160 changes: 88 additions & 72 deletions chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import collections
from collections.abc import Sequence
import dataclasses
import functools
import json
import sqlite3
from typing import Any
Expand Down Expand Up @@ -51,7 +50,9 @@ def create(
db = sqlite3.connect(db_path)
cursor = db.cursor()
cursor.execute('PRAGMA journal_mode=WAL;') # Enable WAL mode
_setup_sqlite_tables(cursor)
db.commit()

if embedding_dim is None:
# Get an embedding from the DB to check its dimension.
cursor = db.cursor()
Expand All @@ -64,7 +65,6 @@ def create(
) from exc
embedding = deserialize_embedding(embedding, embedding_dtype)
embedding_dim = embedding.shape[-1]

return SQLiteGraphSearchDB(db, db_path, embedding_dim, embedding_dtype)

def thread_split(self):
Expand All @@ -85,79 +85,18 @@ def deserialize_edges(self, serialized_edges: bytes) -> np.ndarray:
dtype=np.dtype(np.int64).newbyteorder('<'),
)

def setup(self, index=True):
cursor = self._get_cursor()
# Create embedding sources table
cursor.execute("""
CREATE TABLE IF NOT EXISTS hoplite_sources (
id INTEGER PRIMARY KEY AUTOINCREMENT,
dataset STRING NOT NULL,
source STRING NOT NULL
);
""")

# Create embeddings table
cursor.execute("""
CREATE TABLE IF NOT EXISTS hoplite_embeddings (
id INTEGER PRIMARY KEY AUTOINCREMENT,
embedding BLOB NOT NULL,
source_idx INTEGER NOT NULL,
offsets BLOB NOT NULL,
FOREIGN KEY (source_idx) REFERENCES hoplite_sources(id)
);
""")

cursor.execute("""
CREATE TABLE IF NOT EXISTS hoplite_metadata (
key STRING PRIMARY KEY,
data STRING NOT NULL
);
""")

# Create hoplite_edges table.
cursor.execute("""
CREATE TABLE IF NOT EXISTS hoplite_edges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_embedding_id INTEGER NOT NULL,
target_embedding_ids BLOB NOT NULL,
FOREIGN KEY (source_embedding_id) REFERENCES embeddings(id)
);
""")

# Create hoplite_labels table.
cursor.execute("""
CREATE TABLE IF NOT EXISTS hoplite_labels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
embedding_id INTEGER NOT NULL,
label STRING NOT NULL,
type INT NOT NULL,
provenance STRING NOT NULL,
FOREIGN KEY (embedding_id) REFERENCES embeddings(id)
)""")

if index:
# Create indices for efficient lookups.
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS
idx_embedding ON hoplite_embeddings (id);
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS
source_pairs ON hoplite_sources (dataset, source);
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS embedding_source ON hoplite_embeddings (source_idx);
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id);
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_label ON hoplite_labels (embedding_id, label);
""")
def commit(self) -> None:
self.db.commit()

def commit(self):
def vacuum_db(self) -> None:
  """Clears out the WAL log and defragments data.

  Commits any pending work, runs `VACUUM`, then closes the connection and
  reopens it on `self.db_path`.

  Raises:
    sqlite3.OperationalError: if SQLite cannot complete the `VACUUM`.
  """
  # SQLite refuses to VACUUM from within an open transaction, so pending
  # writes must be committed before the statement is issued.
  self.db.commit()
  cursor = self._get_cursor()
  cursor.execute('VACUUM;')
  cursor.close()
  # Forget the cursor reference so a closed cursor is never reused.
  self._cursor = None
  # Close and reopen the connection so later calls get a fresh handle.
  self.db.close()
  self.db = sqlite3.connect(self.db_path)

def get_embedding_ids(self) -> np.ndarray:
cursor = self._get_cursor()
Expand Down Expand Up @@ -508,6 +447,83 @@ def print_table_values(self, table_name):
print(', '.join(str(value) for value in row))


def _setup_sqlite_tables(cursor: sqlite3.Cursor) -> None:
  """Create all needed tables and indices in the SQLite database.

  The function is a no-op when the `hoplite_labels` table already exists,
  which is taken as the signal that the schema was created previously.

  Args:
    cursor: Cursor on the target database. The caller is responsible for
      committing afterwards.
  """
  cursor.execute("""
  SELECT name FROM sqlite_master
  WHERE type='table' AND name='hoplite_labels';
  """)
  if cursor.fetchone() is not None:
    return

  # Create embedding sources table.
  cursor.execute("""
  CREATE TABLE IF NOT EXISTS hoplite_sources (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  dataset STRING NOT NULL,
  source STRING NOT NULL
  );
  """)

  # Create embeddings table.
  cursor.execute("""
  CREATE TABLE IF NOT EXISTS hoplite_embeddings (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  embedding BLOB NOT NULL,
  source_idx INTEGER NOT NULL,
  offsets BLOB NOT NULL,
  FOREIGN KEY (source_idx) REFERENCES hoplite_sources(id)
  );
  """)

  # Create key/value metadata table.
  cursor.execute("""
  CREATE TABLE IF NOT EXISTS hoplite_metadata (
  key STRING PRIMARY KEY,
  data STRING NOT NULL
  );
  """)

  # Create hoplite_edges table. The foreign key must name the real embeddings
  # table (`hoplite_embeddings`); a reference to a non-existent table makes
  # every insert fail once foreign-key enforcement is switched on.
  cursor.execute("""
  CREATE TABLE IF NOT EXISTS hoplite_edges (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  source_embedding_id INTEGER NOT NULL,
  target_embedding_ids BLOB NOT NULL,
  FOREIGN KEY (source_embedding_id) REFERENCES hoplite_embeddings(id)
  );
  """)

  # Create hoplite_labels table (same foreign-key target as the edges table).
  cursor.execute("""
  CREATE TABLE IF NOT EXISTS hoplite_labels (
  id INTEGER PRIMARY KEY AUTOINCREMENT,
  embedding_id INTEGER NOT NULL,
  label STRING NOT NULL,
  type INT NOT NULL,
  provenance STRING NOT NULL,
  FOREIGN KEY (embedding_id) REFERENCES hoplite_embeddings(id)
  )""")

  # Create indices for efficient lookups.
  cursor.execute("""
  CREATE UNIQUE INDEX IF NOT EXISTS
  idx_embedding ON hoplite_embeddings (id);
  """)
  cursor.execute("""
  CREATE UNIQUE INDEX IF NOT EXISTS
  source_pairs ON hoplite_sources (dataset, source);
  """)
  cursor.execute("""
  CREATE INDEX IF NOT EXISTS embedding_source ON hoplite_embeddings (source_idx);
  """)
  cursor.execute("""
  CREATE UNIQUE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id);
  """)
  cursor.execute("""
  CREATE INDEX IF NOT EXISTS idx_label ON hoplite_labels (embedding_id, label);
  """)


def serialize_embedding(
embedding: np.ndarray, embedding_dtype: type[Any]
) -> bytes:
Expand Down
15 changes: 11 additions & 4 deletions chirp/projects/hoplite/sqlite_usearch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ def deserialize_edges(self, serialized_edges: bytes) -> np.ndarray:
dtype=np.dtype(np.int64).newbyteorder('<'),
)

def setup(self):
pass

def commit(self):
def commit(self) -> None:
self.db.commit()
if self._cursor is not None:
self._cursor.close()
Expand All @@ -180,6 +177,16 @@ def commit(self):
# This check is sufficient because the index is strictly additive.
self.ui.save(self._usearch_filepath.as_posix())

def vacuum_db(self) -> None:
  """Clears out the WAL log and defragments data.

  Commits any pending SQLite work, runs `VACUUM`, then closes the connection
  and reopens it on `self.db_path`.

  Raises:
    sqlite3.OperationalError: if SQLite cannot complete the `VACUUM`.
  """
  # SQLite refuses to VACUUM from within an open transaction, so pending
  # writes must be committed before the statement is issued.
  self.db.commit()
  cursor = self._get_cursor()
  cursor.execute('VACUUM;')
  cursor.close()
  # Forget the cursor reference so a closed cursor is never reused.
  self._cursor = None
  self.db.close()
  # NOTE(review): this assumes self.db_path is the SQLite file itself for
  # this implementation - confirm against how `create` opens the connection.
  self.db = sqlite3.connect(self.db_path)

def get_embedding_ids(self) -> np.ndarray:
# Note that USearch can also create a list of all keys, but it seems
# quite slow.
Expand Down
1 change: 0 additions & 1 deletion chirp/projects/hoplite/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def make_db(
)
else:
raise ValueError(f'Unknown db type: {db_type}')
db.setup()
# Insert a few embeddings...
graph_utils.insert_random_embeddings(db, embedding_dim, num_embeddings, rng)
config = config_dict.ConfigDict()
Expand Down

0 comments on commit 2b2bcc5

Please sign in to comment.