nexrad-archive/lib/nexrad/db.py

197 lines
5.4 KiB
Python
Raw Normal View History

import enum
2025-02-10 20:05:00 -05:00
import sqlite3
class DatabaseOrder(enum.Enum):
    """Sort direction for one ``(column, order)`` entry of ``Database.query``.

    ``DEFAULT`` emits the bare column with no direction keyword (SQLite's
    default ordering applies); ``ASC`` and ``DESC`` append ``asc`` / ``desc``.
    """

    DEFAULT = 0  # no keyword after the column
    ASC = 1      # "<column> asc"
    DESC = 2     # "<column> desc"
class DatabaseTable():
    """Base class for row objects that record which columns get assigned.

    Right after ``__init__`` no columns are watched, so assignments are stored
    without being tracked. Calling ``__reset__`` marks the object clean and
    starts watching every name in the subclass's ``__columns__``. From then
    on, each assignment to a watched column increments its counter in
    ``__dirty_columns__`` and sets ``__dirty__`` to True.
    """

    __slots__ = '__dirty__', '__dirty_columns__',

    def __init__(self):
        # object.__setattr__ bypasses the tracking done by __setattr__ below.
        object.__setattr__(self, '__dirty__', False)
        object.__setattr__(self, '__dirty_columns__', dict())

    def __reset__(self):
        """Mark the object clean and watch every declared column from now on."""
        # The rest of this module declares columns as ``__columns__``; reading
        # ``self.columns`` here raised AttributeError for such tables.
        columns = getattr(self, '__columns__', None)
        if columns is None:
            # Fall back to a ``columns`` attribute for subclasses that declare
            # their column list under that name.
            columns = self.columns
        object.__setattr__(self, '__dirty__', False)
        object.__setattr__(self, '__dirty_columns__', {k: 0 for k in columns})

    def __setattr__(self, k, v):
        object.__setattr__(self, k, v)
        counts = object.__getattribute__(self, '__dirty_columns__')
        # Only watched columns count as modifications; any other attribute is
        # stored without touching the dirty state.
        if k in counts:
            object.__setattr__(self, '__dirty__', True)
            counts[k] += 1
class DatabaseTableCursor():
    """Wrap a DB-API cursor so fetched rows come back as ``table`` objects.

    Any attribute not defined here (``execute``, ``close``, ...) is looked up
    on the wrapped cursor ``cr``.
    """

    __slots__ = 'cr', 'table',

    def __init__(self, table, cr):
        self.cr = cr
        self.table = table

    def __getattr__(self, name):
        # Called only for names not found on this object or its class.
        return getattr(self.cr, name)

    @staticmethod
    def _column(row, name):
        """Return column ``name`` of ``row``, or None if the row lacks it."""
        try:
            # sqlite3.Row (installed by Database.connect) supports lookup by
            # column name with [] but NOT attribute access.
            return row[name]
        except IndexError:
            # sqlite3.Row raises IndexError for a name not in the result set.
            return None
        except TypeError:
            # Rows indexable only by position (plain tuples, namedtuples from a
            # custom row_factory): fall back to attribute access.
            return getattr(row, name)

    def __map__(self, row):
        """Convert one fetched ``row`` into an instance of ``self.table``.

        If the table class has a ``__from_row__`` callable it does the whole
        conversion. Otherwise a new instance is created with no arguments,
        each name in ``__columns__`` is copied from the row (None when the
        row lacks that column), and ``__reset__`` is then called on the
        instance (for ``DatabaseTable`` subclasses this clears the dirty state
        produced by the copying).
        """
        from_row = getattr(self.table, '__from_row__', None)
        if from_row is not None:
            return from_row(row)
        obj = self.table()
        for name in self.table.__columns__:
            setattr(obj, name, self._column(row, name))
        obj.__reset__()
        return obj

    def fetchone(self):
        """Return the next row as a table object, or None when exhausted."""
        row = self.cr.fetchone()
        return self.__map__(row) if row is not None else None

    def fetchall(self):
        """Return a ``map`` of table objects over all remaining rows.

        The raw rows are fetched immediately; the result is a one-shot
        iterator, not a list.
        """
        return map(self.__map__, self.cr.fetchall())

    def each(self):
        """Yield table objects one at a time until the cursor is exhausted."""
        while True:
            obj = self.fetchone()
            if obj is None:
                break
            yield obj
class Database():
    """Small object mapper over a ``sqlite3`` connection.

    Table classes describe themselves with ``__table__`` (table name),
    ``__columns__`` (column names) and ``__key__`` (the key column, filled in
    by ``add`` and used to locate rows in ``update``). Optional hooks are
    ``__values__`` (insert values) and ``__selectors__`` (per-column select
    expressions). Attributes not defined here (``commit``, ``execute``,
    ``close``, ...) are forwarded to the wrapped connection ``db``.

    Table and column names are interpolated into the SQL text, so they must
    come from trusted class definitions; only values are bound as parameters.
    """

    __slots__ = 'db',

    def __init__(self, db):
        self.db = db

    def __getattr__(self, name):
        # Only reached for names not found on this object or its class.
        return getattr(self.db, name)

    @staticmethod
    def connect(path: str, extension: str = 'mod_spatialite.so.8'):
        """Open ``path`` with SpatiaLite loaded and wrap it in a Database.

        :param path: database filename passed to ``sqlite3.connect``.
        :param extension: shared library passed to SQLite's
            ``load_extension()``; defaults to ``mod_spatialite.so.8``.
        :returns: a new ``Database`` whose rows come back as ``sqlite3.Row``.

        SpatiaLite's ``InitSpatialMetadata(1)`` is executed on every connect.
        If any setup step raises, the connection is closed before the error
        propagates instead of being leaked.
        """
        db = sqlite3.connect(path)
        try:
            db.row_factory = sqlite3.Row
            db.enable_load_extension(True)
            # Bind the library name rather than writing a double-quoted
            # literal: in SQL "..." is an identifier and is accepted as a
            # string only through SQLite's legacy DQS fallback, which builds
            # may disable.
            db.execute('select load_extension(?)', (extension,))
            db.execute('select InitSpatialMetadata(1)')
        except BaseException:
            db.close()
            raise
        return Database(db)

    def add(self, obj):
        """Insert ``obj`` as a new row and store the generated key on it.

        Every column in ``__columns__`` except ``__key__`` is inserted. Values
        come from ``obj.__values__()`` when defined (assumed to be in that
        same column order), otherwise from ``obj``'s attributes, with missing
        ones inserted as NULL. Uses ``RETURNING``, which needs SQLite 3.35+.
        """
        table = type(obj)
        key = table.__key__
        columns = [c for c in table.__columns__ if c != key]
        sql = (
            f"insert into {table.__table__} ({', '.join(columns)}) "
            f"values ({', '.join('?' for _ in columns)}) returning {key}"
        )
        fn = getattr(obj, '__values__', None)
        if fn is not None:
            values = fn()
        else:
            values = [getattr(obj, c, None) for c in columns]
        cr = self.db.execute(sql, values)
        try:
            row = cr.fetchone()
        finally:
            # Finalize the RETURNING statement instead of leaving it pending.
            cr.close()
        setattr(obj, key, row[0])

    def update(self, obj):
        """Write the columns of ``obj`` whose change counter is above zero.

        Does nothing when ``obj.__dirty__`` is false. The row is located by
        the current value of ``obj``'s ``__key__`` attribute.
        """
        if not obj.__dirty__:
            return
        dirty = [k for k, n in obj.__dirty_columns__.items() if n > 0]
        table = type(obj)
        sql = (
            f"update {table.__table__} set "
            + ", ".join(f"{k} = ?" for k in dirty)
            + f" where {table.__key__} = ?"
        )
        values = [getattr(obj, k) for k in dirty]
        values.append(getattr(obj, table.__key__))
        self.db.execute(sql, values)

    def query_sql(self, table, sql, values=()):
        """Run raw ``sql`` with bound ``values``; return a DatabaseTableCursor."""
        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, values)
        return cr

    @staticmethod
    def _order_term(column, order):
        """Render one ``order by`` term for ``column`` in direction ``order``."""
        if order is None or order is DatabaseOrder.DEFAULT:
            return column
        if order is DatabaseOrder.ASC:
            return f"{column} asc"
        if order is DatabaseOrder.DESC:
            return f"{column} desc"
        # Previously an unknown value was skipped, yielding malformed SQL.
        raise ValueError(f"unsupported sort order for {column!r}: {order!r}")

    def query(self, table, values=None, order_by=None):
        """Select rows of ``table`` and return a DatabaseTableCursor over them.

        :param table: table class; ``__selectors__`` (optional) maps a column
            name to the select expression used in its place.
        :param values: optional mapping of column -> value; rows must match
            every entry (``=`` comparisons joined with ``and``).
        :param order_by: optional sequence of ``(column, order)`` pairs, where
            ``order`` is None or a ``DatabaseOrder`` member.
        :raises ValueError: if an ``order`` is neither None nor a
            ``DatabaseOrder`` member.
        """
        if values is None:
            values = {}
        if order_by is None:
            order_by = []
        # __selectors__ is optional; without a default, getattr raised
        # AttributeError for every table that does not define it.
        selectors = getattr(table, '__selectors__', None)
        if selectors is None:
            columns = table.__columns__
        else:
            columns = [selectors[c] if c in selectors else c for c in table.__columns__]
        sql = f"select {', '.join(columns)} from {table.__table__}"
        if len(values) > 0:
            sql += " where " + " and ".join(
                f"{table.__table__}.{k} = ?" for k in values
            )
        if len(order_by) > 0:
            sql += " order by " + ", ".join(
                self._order_term(column, order) for column, order in order_by
            )
        return self.query_sql(table, sql, list(values.values()))

    def get(self, table, values: 'dict | None' = None):
        """Return the first row matching ``values`` as a table object, or None."""
        return self.query(table, values).fetchone()

    def _call(self, table, fn: str, column: str, values: 'dict | None' = None):
        """Return the SQL aggregate ``fn(column)`` over rows of ``table``.

        ``values`` optionally restricts rows with ``=`` comparisons joined by
        ``and``. The result is whatever SQLite returns, e.g. None for
        ``min``/``max`` over no rows.
        """
        if values is None:
            values = {}
        sql = f"select {fn}({column}) as ret from {table.__table__}"
        if len(values) > 0:
            sql += " where " + " and ".join(f"{k} = ?" for k in values)
        row = self.db.execute(sql, list(values.values())).fetchone()
        return row[0] if row is not None else None

    def min(self, table, column: str, values: 'dict | None' = None):
        """Return ``min(column)`` over matching rows (None if there are none)."""
        return self._call(table, 'min', column, values)

    def max(self, table, column: str, values: 'dict | None' = None):
        """Return ``max(column)`` over matching rows (None if there are none)."""
        return self._call(table, 'max', column, values)

    def count(self, table, values: 'dict | None' = None) -> int:
        """Return the number of matching rows with a non-NULL key column."""
        return self._call(table, 'count', table.__key__, values)