diff --git a/lib/nexrad/db.py b/lib/nexrad/db.py index b4fe427..2af9515 100644 --- a/lib/nexrad/db.py +++ b/lib/nexrad/db.py @@ -1,186 +1,11 @@ -import enum import sqlite3 -from xenu_nntp.config import Config - -class DatabaseOrder(enum.Enum): - DEFAULT = 0 - ASC = 1 - DESC = 2 - -class DatabaseTable(): - __slots__ = '__dirty__', '__dirty_columns__', - - def __init__(self): - object.__setattr__(self, '__dirty__', False) - object.__setattr__(self, '__dirty_columns__', dict()) - - def __reset__(self): - object.__setattr__(self, '__dirty__', False) - object.__setattr__(self, '__dirty_columns__', {k: 0 for k in self.columns}) - - def __setattr__(self, k, v): - object.__setattr__(self, k, v) - - values = object.__getattribute__(self, '__dirty_columns__') - - if k in values: - object.__setattr__(self, '__dirty__', True) - values[k] += 1 - -class DatabaseTableCursor(): - __slots__ = 'cr', 'table', - - def __init__(self, table, cr): - self.cr = cr - self.table = table - - def __getattr__(self, name): - return getattr(self.cr, name) - - def __map__(self, row): - fn = getattr(self.table, '__from_row__', None) - - if fn is not None: - return fn(row) - - obj = self.table() - - for name in self.table.columns: - try: - setattr(obj, name, getattr(row, name)) - except IndexError: - setattr(obj, name, None) - - obj.__reset__() - - return obj - - def fetchone(self): - row = self.cr.fetchone() - - return self.__map__(row) if row is not None else None - - def fetchall(self): - return map(self.__map__, self.cr.fetchall()) - - def each(self): - while True: - obj = self.fetchone() - - if obj is None: - break - - yield obj - class Database(): - __slots__ = 'db', - - def __init__(self, db): - self.db = db - - def __getattr__(self, name): - return getattr(self.db, name) - - @staticmethod def connect(path: str): - db = sqlite3.connect(path, row_factory=sqlite3.Row) + db = sqlite3.connect(path) + db.enable_load_extension(True) - return Database(db) + db.execute('select load_extension("mod_spatialite")') + db.execute('select InitSpatialMetadata(1)') - def add(self, obj): - table = type(obj) - sql = f"insert into {table.name} (" - sql += ", ".join([c for c in table.columns if c != table.key]) - sql += ') values (' - sql += ", ".join(['?' for c in table.columns if c != table.key]) - sql += f") returning {table.key}" - - fn = getattr(obj, '__values__', None) - - if fn is not None: - values = fn() - else: - values = list() - - for column in table.columns: - if column != table.key: - values.append(getattr(obj, column, None)) - - cr = self.db.execute(sql, values) - - setattr(obj, table.key, cr.fetchone()[0]) - - def update(self, obj): - if not obj.__dirty__: - return - - dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0] - table = type(obj) - sql = f"update {table.name} set " - sql += ", ".join([f"{k} = ?" for k in dirty]) - sql += f" where {table.key} = ?" - - values = [getattr(obj, k) for k in dirty] - values.append(getattr(obj, table.key)) - - self.db.execute(sql, values) - - def query_sql(self, table, sql, values=list()): - cr = DatabaseTableCursor(table, self.db.cursor()) - cr.execute(sql, values) - - return cr - - def query(self, table, values=dict(), order_by=list()): - sql = "select %s from %s" % ( - ', '.join(table.columns), - table.name - ) - - if len(values) > 0: - sql += " where " - sql += " and ".join([f"{table.name}.{k} = ?" for k in values]) - - if len(order_by) > 0: - sql += " order by" - - first = True - - for column, order in order_by: - if first: - first = False - else: - sql += ", " - - if order is None or order is DatabaseOrder.DEFAULT: - sql += f" {column}" - elif order is DatabaseOrder.ASC: - sql += f" {column} asc" - elif order is DatabaseOrder.DESC: - sql += f" {column} desc" - - return self.query_sql(table, sql, list(values.values())) - - def get(self, table, values: dict=dict()): - return self.query(table, values).fetchone() - - def _call(self, table, fn: str, column: str, values: dict=dict()) -> int: - sql = f"select {fn}({column}) as ret from {table.name}" - - if len(values) > 0: - sql += " where " - sql += " and ".join([f"{k} = ?" for k in values]) - - row = self.db.execute(sql, list(values.values())).fetchone() - - return row[0] if row is not None else None - - def min(self, table, column: str, values: dict=dict()) -> int: - return self._call(table, 'min', column, values) - - def max(self, table, column: str, values: dict=dict()) -> int: - return self._call(table, 'max', column, values) - - def count(self, table, values: dict=dict()) -> int: - return self._call(table, 'count', table.key, values) + return db