Skip to content

Commit

Permalink
Merge pull request #64 from rolobio/feature/remove-db-kind-strings
Browse files Browse the repository at this point in the history
Removing db kind strings in favor of enum.
  • Loading branch information
rolobio authored Sep 17, 2019
2 parents 072eb32 + 8a06b2c commit f7757d0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
38 changes: 20 additions & 18 deletions dictorm/dictorm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""What if you could insert a Python dictionary into the database? DictORM allows you to select/insert/update rows of a database as if they were Python Dictionaries."""
import enum
from typing import Union, Optional, List

__version__ = '4.1.2'
__version__ = '4.1.3'

from contextlib import contextmanager
from itertools import chain
Expand Down Expand Up @@ -62,6 +63,11 @@ class NoCache(Exception):
pass


class DBKind(enum.Enum):
postgres = enum.auto()
sqlite3 = enum.auto()


class Dict(dict):
"""
This is a representation of a database row that behaves exactly like a
Expand Down Expand Up @@ -305,7 +311,7 @@ def __execute_once(self):

def __len__(self) -> int:
self.__execute_once()
if self.db_kind == 'sqlite3':
if self.db_kind == DBKind.sqlite3:
# sqlite3's cursor.rowcount doesn't support select statements
# returns a 0 because this method is called when a ResultsGenerator
# is converted into a list()
Expand Down Expand Up @@ -472,11 +478,11 @@ def _refresh_pks(self):
"""
Get a list of Primary Keys set for this table in the DB.
"""
if self.db.kind == 'sqlite3':
if self.db.kind == DBKind.sqlite3:
self.curs.execute('pragma table_info(%s)' % self.name)
self.pks = [i['name'] for i in self.curs.fetchall() if i['pk']]

elif self.db.kind == POSTGRES_KIND:
elif self.db.kind == DBKind.postgres:
self.curs.execute('''SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid
Expand Down Expand Up @@ -539,7 +545,7 @@ def get_where(self, *a, **kw) -> ResultsGenerator:
"""
# When column names are quoted in an SQLite statement and the column doesn't exist, SQLite doesn't raise
# an exception. We'll raise an exception if any columns don't exist.
if self.db.kind == 'sqlite3':
if self.db.kind == DBKind.sqlite3:
bad_columns = set(kw.keys()).difference(self.column_names)
if bad_columns:
raise sqlite3.OperationalError(f'no such column: {bad_columns.pop()}')
Expand Down Expand Up @@ -599,7 +605,7 @@ def columns(self) -> List[str]:
"""
Get a list of columns of a table.
"""
if self.db.kind == 'sqlite3':
if self.db.kind == DBKind.sqlite3:
key = 'name'
else:
key = 'column_name'
Expand All @@ -614,7 +620,7 @@ def columns_info(self) -> List[dict]:
if self.cached_columns_info:
return self.cached_columns_info

if self.db.kind == 'sqlite3':
if self.db.kind == DBKind.sqlite3:
sql = "PRAGMA TABLE_INFO(" + str(self.name) + ")"
self.curs.execute(sql)
self.cached_columns_info = [dict(i) for i in self.curs.fetchall()]
Expand All @@ -627,7 +633,7 @@ def columns_info(self) -> List[dict]:
@property
def column_names(self) -> set:
if not self.cached_column_names:
if self.db.kind == 'sqlite3':
if self.db.kind == DBKind.sqlite3:
self.cached_column_names = set(i['name'] for i in
self.columns_info)
else:
Expand Down Expand Up @@ -677,10 +683,6 @@ def __contains__(self, item: Dict):
raise ValueError('Cannot check if item is in this Table because it is not a Dict.')


SQLITE_KIND = 'sqlite3'
POSTGRES_KIND = 'postgresql'


class DictDB(dict):
"""
Get all the tables from the provided Psycopg2/Sqlite3 connection. Create a
Expand All @@ -702,12 +704,12 @@ def __init__(self, db_conn: db_conn_type):
self._real_getitem = super().__getitem__
self.conn = db_conn
if 'sqlite3' in modules and isinstance(db_conn, sqlite3.Connection):
self.kind = SQLITE_KIND
self.kind = DBKind.sqlite3
self.insert = SqliteInsert
self.update = SqliteUpdate
self.column = SqliteColumn
else:
self.kind = POSTGRES_KIND
self.kind = DBKind.postgres
self.insert = Insert
self.update = Update
self.column = Column
Expand All @@ -730,7 +732,7 @@ def table_factory(cls) -> Table:
return Table

def __list_tables(self):
if self.kind == SQLITE_KIND:
if self.kind == DBKind.sqlite3:
self.curs.execute('SELECT name FROM sqlite_master WHERE type ='
'"table"')
else:
Expand All @@ -744,11 +746,11 @@ def get_cursor(self) -> CursorHint:
Returns a cursor from the provided database connection that DictORM
objects expect.
"""
if self.kind == SQLITE_KIND:
if self.kind == DBKind.sqlite3:
self.conn.row_factory = sqlite3.Row
curs = self.conn.cursor()
return curs
elif self.kind == POSTGRES_KIND:
elif self.kind == DBKind.postgres:
curs = self.conn.cursor(cursor_factory=DictCursor)
return curs

Expand All @@ -760,7 +762,7 @@ def refresh_tables(self):
# Reset this DictDB because it contains old tables
super(DictDB, self).__init__()
table_cls = self.table_factory()
name_key = 'name' if self.kind == SQLITE_KIND else 'table_name'
name_key = 'name' if self.kind == DBKind.sqlite3 else 'table_name'
for table in self.__list_tables():
name = table[name_key]
self[name] = table_cls(name, self)
Expand Down
2 changes: 1 addition & 1 deletion dictorm/test/test_dictorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_delete(self):
self.assertEqual(list(Person.get_where()), [bob, alice])

# get_where accepts a tuple of ids, and returns those rows
if self.db.kind != 'sqlite3':
if self.db.kind != dictorm.DBKind.sqlite3:
self.assertEqual(list(Person.get_where(Person['id'].In([1, 3]))),
[bob, alice])

Expand Down

0 comments on commit f7757d0

Please sign in to comment.