From e27b5f88a620bf1e16ee9e8d87d66d205d907ea2 Mon Sep 17 00:00:00 2001 From: XANTRONIX Industrial Date: Thu, 13 Feb 2025 15:20:46 -0500 Subject: [PATCH] Bring significant object mapping enhancements to db.py --- lib/nexrad/db.py | 186 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 185 insertions(+), 1 deletion(-) diff --git a/lib/nexrad/db.py b/lib/nexrad/db.py index 691d345..f657c07 100644 --- a/lib/nexrad/db.py +++ b/lib/nexrad/db.py @@ -1,6 +1,86 @@ +import enum import sqlite3 +class DatabaseOrder(enum.Enum): + DEFAULT = 0 + ASC = 1 + DESC = 2 + +class DatabaseTable(): + __slots__ = '__dirty__', '__dirty_columns__', + + def __init__(self): + object.__setattr__(self, '__dirty__', False) + object.__setattr__(self, '__dirty_columns__', dict()) + + def __reset__(self): + object.__setattr__(self, '__dirty__', False) + object.__setattr__(self, '__dirty_columns__', {k: 0 for k in self.columns}) + + def __setattr__(self, k, v): + object.__setattr__(self, k, v) + + values = object.__getattribute__(self, '__dirty_columns__') + + if k in values: + object.__setattr__(self, '__dirty__', True) + values[k] += 1 + +class DatabaseTableCursor(): + __slots__ = 'cr', 'table', + + def __init__(self, table, cr): + self.cr = cr + self.table = table + + def __getattr__(self, name): + return getattr(self.cr, name) + + def __map__(self, row): + fn = getattr(self.table, '__from_row__', None) + + if fn is not None: + return fn(row) + + obj = self.table() + + for name in self.table.__columns__: + try: + setattr(obj, name, getattr(row, name)) + except IndexError: + setattr(obj, name, None) + + obj.__reset__() + + return obj + + def fetchone(self): + row = self.cr.fetchone() + + return self.__map__(row) if row is not None else None + + def fetchall(self): + return map(self.__map__, self.cr.fetchall()) + + def each(self): + while True: + obj = self.fetchone() + + if obj is None: + break + + yield obj + class Database(): + __slots__ = 'db', + + def __init__(self, db): + self.db = db + + def __getattr__(self, name): + return getattr(self.db, name) + + @staticmethod def connect(path: str): db = sqlite3.connect(path) db.row_factory = sqlite3.Row @@ -9,4 +89,108 @@ class Database(): db.execute('select load_extension("mod_spatialite.so.8")') db.execute('select InitSpatialMetadata(1)') - return db + return Database(db) + + def add(self, obj): + table = type(obj) + sql = f"insert into {table.__table__} (" + sql += ", ".join([c for c in table.__columns__ if c != table.__key__]) + sql += ') values (' + sql += ", ".join(['?' for c in table.__columns__ if c != table.__key__]) + sql += f") returning {table.__key__}" + + fn = getattr(obj, '__values__', None) + + if fn is not None: + values = fn() + else: + values = list() + + for column in table.__columns__: + if column != table.__key__: + values.append(getattr(obj, column, None)) + + cr = self.db.execute(sql, values) + + setattr(obj, table.__key__, cr.fetchone()[0]) + + def update(self, obj): + if not obj.__dirty__: + return + + dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0] + table = type(obj) + sql = f"update {table.__table__} set " + sql += ", ".join([f"{k} = ?" for k in dirty]) + sql += f" where {table.__key__} = ?" + + values = [getattr(obj, k) for k in dirty] + values.append(getattr(obj, table.__key__)) + + self.db.execute(sql, values) + + def query_sql(self, table, sql, values=list()): + cr = DatabaseTableCursor(table, self.db.cursor()) + cr.execute(sql, values) + + return cr + + def query(self, table, values=dict(), order_by=list()): + selectors = getattr(table, '__selectors__') + + if selectors is None: + columns = table.__columns__ + else: + columns = [selectors[c] if c in selectors else c for c in table.__columns__] + + sql = "select %s from %s" % ( + ', '.join(columns), + table.__table__ + ) + + if len(values) > 0: + sql += " where " + sql += " and ".join([f"{table.__table__}.{k} = ?" for k in values]) + + if len(order_by) > 0: + sql += " order by" + + first = True + + for column, order in order_by: + if first: + first = False + else: + sql += ", " + + if order is None or order is DatabaseOrder.DEFAULT: + sql += f" {column}" + elif order is DatabaseOrder.ASC: + sql += f" {column} asc" + elif order is DatabaseOrder.DESC: + sql += f" {column} desc" + + return self.query_sql(table, sql, list(values.values())) + + def get(self, table, values: dict=dict()): + return self.query(table, values).fetchone() + + def _call(self, table, fn: str, column: str, values: dict=dict()) -> int: + sql = f"select {fn}({column}) as ret from {table.__table__}" + + if len(values) > 0: + sql += " where " + sql += " and ".join([f"{k} = ?" for k in values]) + + row = self.db.execute(sql, list(values.values())).fetchone() + + return row[0] if row is not None else None + + def min(self, table, column: str, values: dict=dict()) -> int: + return self._call(table, 'min', column, values) + + def max(self, table, column: str, values: dict=dict()) -> int: + return self._call(table, 'max', column, values) + + def count(self, table, values: dict=dict()) -> int: + return self._call(table, 'count', table.__key__, values)