"""Small object mapper over sqlite3 connections with SpatiaLite loaded.

Table classes describe themselves through class attributes read below:

* ``__table__``         -- SQL table name.
* ``__key__``           -- primary-key column; ``Database.add`` fills it in
                           from ``insert ... returning``.
* ``__columns__``       -- ordered column names, mapped to object attributes.
* ``__columns_read__``  -- optional {column: select expression}. Rows are
                           mapped by column name, so the expression presumably
                           needs ``as <column>``; otherwise that attribute is
                           left unset when mapping.
* ``__values_read__``   -- optional {column: callable} applied to the raw row
                           value before it is set on the object.
* ``__columns_write__`` -- optional {column: SQL value expression} used in
                           place of ``:column`` when writing.
* ``__values_write__``  -- optional {column: callable}; it returns either a
                           mapping of bind-parameter names to values or a
                           single value bound as ``:column``.
* ``__from_row__``      -- optional callable replacing the generic row mapping.
* ``__insert__``        -- optional instance method replacing the generic insert.
"""

import enum
import sqlite3
from collections.abc import Mapping


class DatabaseOrder(enum.Enum):
    """Sort direction for ``Database.query(order_by=...)``."""

    DEFAULT = 0  # emit no explicit direction
    ASC = 1
    DESC = 2


class DatabaseTable:
    """Base class for mapped rows that records which columns were assigned.

    Recording starts at ``__reset__`` (called after a row is loaded, inserted
    or updated); assignments made before that are not tracked.
    """

    __slots__ = ('__dirty__', '__dirty_columns__')

    def __init__(self):
        object.__setattr__(self, '__dirty__', False)
        object.__setattr__(self, '__dirty_columns__', {})

    def __reset__(self):
        """Mark the object clean and start counting writes to every column."""
        object.__setattr__(self, '__dirty__', False)
        object.__setattr__(self, '__dirty_columns__', {k: 0 for k in self.__columns__})

    def __setattr__(self, k, v):
        object.__setattr__(self, k, v)
        try:
            counts = object.__getattribute__(self, '__dirty_columns__')
        except AttributeError:
            # __init__ has not run yet (e.g. a subclass assigned attributes
            # before calling super().__init__()); there is nothing to track.
            return
        if k in counts:
            object.__setattr__(self, '__dirty__', True)
            counts[k] += 1

    def __format_columns_select__(self):
        """Return the select-list entries for this table's columns.

        Columns present in ``__columns_read__`` are replaced by their
        expression; without ``__columns_read__`` the plain ``__columns__``
        are returned unchanged.
        """
        readers = getattr(self, '__columns_read__', None)
        if readers is None:
            return self.__columns__
        return [readers[c] if c in readers else c for c in self.__columns__]


class DatabaseTableCursor:
    """sqlite3 cursor wrapper whose fetch methods return ``table`` objects.

    Attributes not defined here (``execute``, ``close``, ...) are forwarded
    to the wrapped cursor.
    """

    __slots__ = ('cr', 'table')

    def __init__(self, table, cr):
        self.cr = cr
        self.table = table

    def __getattr__(self, name):
        # object.__getattribute__ raises AttributeError directly when the
        # 'cr' slot is unset instead of re-entering __getattr__ forever.
        return getattr(object.__getattribute__(self, 'cr'), name)

    def __map__(self, row):
        """Build a ``table`` object from one result row."""
        from_row = getattr(self.table, '__from_row__', None)
        if from_row is not None:
            return from_row(row)
        obj = self.table()
        readers = getattr(self.table, '__values_read__', None) or {}
        for name in self.table.__columns__:
            try:
                value = row[name]
            except IndexError:
                # sqlite3.Row raises IndexError for a name missing from the
                # result set; leave that attribute unset.
                continue
            if name in readers:
                value = readers[name](value)
            setattr(obj, name, value)
        obj.__reset__()
        return obj

    def fetchone(self):
        """Return the next row as a ``table`` object, or None when exhausted."""
        row = self.cr.fetchone()
        return None if row is None else self.__map__(row)

    def fetchall(self):
        """Fetch all remaining rows now; return a lazy iterator of mapped objects."""
        return map(self.__map__, self.cr.fetchall())

    def each(self):
        """Yield ``table`` objects one row at a time until the cursor is exhausted."""
        yield from iter(self.fetchone, None)


class Database:
    """Wrapper around a sqlite3 connection that reads and writes mapped objects.

    Attributes not defined here (``commit``, ``close``, ...) are forwarded to
    the wrapped connection. Table, column and ``order_by`` names are
    interpolated into the SQL text and must come from trusted code; row
    values are always passed as bound parameters.
    """

    __slots__ = ('db',)

    def __init__(self, db):
        self.db = db

    def __getattr__(self, name):
        # object.__getattribute__ raises AttributeError directly when the
        # 'db' slot is unset instead of re-entering __getattr__ forever.
        return getattr(object.__getattribute__(self, 'db'), name)

    @staticmethod
    def connect(path: str, extension: str = 'mod_spatialite.so.8'):
        """Open *path*, load the SpatiaLite *extension* and initialise its metadata.

        Rows come back as ``sqlite3.Row``. The connection is closed again if
        any setup step fails.
        """
        db = sqlite3.connect(path)
        try:
            db.row_factory = sqlite3.Row
            db.enable_load_extension(True)
            db.execute("select load_extension(?)", (extension,))
            db.execute("select InitSpatialMetadata(1)")
        except Exception:
            db.close()
            raise
        return Database(db)

    @staticmethod
    def _bind_value(writers, column, value, out):
        """Add the bind parameter(s) for *column* = *value* to the dict *out*.

        A writer from ``__values_write__`` may return a mapping (merged into
        *out*) or a single value (bound under *column*).
        """
        if column not in writers:
            out[column] = value
            return
        converted = writers[column](value)
        if isinstance(converted, Mapping):
            out.update(converted)
        else:
            out[column] = converted

    def column_placeholders(self, table, obj) -> list:
        """Return the columns of *obj* holding a non-None value (unset counts as None)."""
        return [c for c in table.__columns__ if getattr(obj, c, None) is not None]

    def value_placeholders(self, table, obj) -> list:
        """Return value expressions aligned one-to-one with ``column_placeholders``."""
        exprs = getattr(table, '__columns_write__', None) or {}
        return [exprs[c] if c in exprs else f':{c}'
                for c in self.column_placeholders(table, obj)]

    def row_values(self, table, obj) -> dict:
        """Return the named bind parameters for inserting *obj*."""
        writers = getattr(table, '__values_write__', None) or {}
        params = {}
        for c in self.column_placeholders(table, obj):
            self._bind_value(writers, c, getattr(obj, c), params)
        return params

    def add(self, obj):
        """Insert *obj* and store the generated key on it.

        If *obj* defines ``__insert__``, it is called with the connection
        instead and its result is returned. Columns whose value is None (or
        unset) are omitted from the insert. Afterwards change tracking is
        reset (when *obj* supports it) so later assignments reach ``update``.
        """
        custom = getattr(obj, '__insert__', None)
        if custom is not None:
            return custom(self.db)
        table = type(obj)
        columns = self.column_placeholders(table, obj)
        if columns:
            body = (f"({', '.join(columns)}) "
                    f"values ({', '.join(self.value_placeholders(table, obj))})")
        else:
            # "insert into t () values ()" is not valid SQLite syntax.
            body = "default values"
        sql = f"insert into {table.__table__} {body} returning {table.__key__}"
        cr = self.db.execute(sql, self.row_values(table, obj))
        setattr(obj, table.__key__, cr.fetchone()[0])
        reset = getattr(obj, '__reset__', None)
        if reset is not None:
            # A freshly constructed object tracks nothing until reset; without
            # this, a later update() on it would silently do nothing.
            reset()

    def update(self, obj):
        """Write the columns assigned since the last reset to *obj*'s row.

        Does nothing when no tracked column changed. Non-None values go
        through ``__columns_write__`` / ``__values_write__`` exactly as in
        ``add``; None is written as NULL. The object is marked clean after
        the statement succeeds.
        """
        if not obj.__dirty__:
            return
        table = type(obj)
        key = table.__key__
        exprs = getattr(table, '__columns_write__', None) or {}
        writers = getattr(table, '__values_write__', None) or {}
        params = {key: getattr(obj, key)}
        assignments = []
        for column, count in obj.__dirty_columns__.items():
            if count <= 0:
                continue
            value = getattr(obj, column)
            if value is None:
                assignments.append(f"{column} = :{column}")
                params[column] = None
            else:
                expr = exprs[column] if column in exprs else f':{column}'
                assignments.append(f"{column} = {expr}")
                self._bind_value(writers, column, value, params)
        sql = (f"update {table.__table__} set {', '.join(assignments)} "
               f"where {key} = :{key}")
        self.db.execute(sql, params)
        obj.__reset__()

    def query_sql(self, table, sql, values=None):
        """Execute raw *sql* with *values* and return a cursor yielding *table* objects."""
        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, () if values is None else values)
        return cr

    @staticmethod
    def _select_columns(table):
        """Return the select list for *table*, applying ``__columns_read__``."""
        readers = getattr(table, '__columns_read__', None)
        if readers is None:
            return list(table.__columns__)
        return [readers[c] if c in readers else c for c in table.__columns__]

    @staticmethod
    def _order_term(column, order):
        """Render one ``order by`` term; raise ValueError for an unknown *order*."""
        if order is None or order is DatabaseOrder.DEFAULT:
            return column
        if order is DatabaseOrder.ASC:
            return f"{column} asc"
        if order is DatabaseOrder.DESC:
            return f"{column} desc"
        raise ValueError(f"unsupported sort order: {order!r}")

    def query(self, table, values=None, order_by=None):
        """Select rows of *table* whose columns equal *values*; return a cursor.

        *values* maps column name to required value (conditions joined with
        ``and``); *order_by* is a sequence of ``(column, DatabaseOrder or
        None)`` pairs.
        """
        values = {} if values is None else values
        sql = f"select {', '.join(self._select_columns(table))} from {table.__table__}"
        if values:
            sql += " where " + " and ".join(f"{table.__table__}.{k} = ?" for k in values)
        if order_by:
            sql += " order by " + ", ".join(self._order_term(c, o) for c, o in order_by)
        return self.query_sql(table, sql, list(values.values()))

    def get(self, table, values=None):
        """Return the first *table* object matching *values*, or None."""
        return self.query(table, values).fetchone()

    def _call(self, table, fn: str, column: str, values=None):
        """Return the aggregate ``fn(column)`` over rows of *table* matching *values*."""
        values = {} if values is None else values
        sql = f"select {fn}({column}) as ret from {table.__table__}"
        if values:
            sql += " where " + " and ".join(f"{k} = ?" for k in values)
        row = self.db.execute(sql, list(values.values())).fetchone()
        return None if row is None else row[0]

    def min(self, table, column: str, values=None):
        """Return the smallest *column* value among matching rows (None if no rows)."""
        return self._call(table, 'min', column, values)

    def max(self, table, column: str, values=None):
        """Return the largest *column* value among matching rows (None if no rows)."""
        return self._call(table, 'max', column, values)

    def count(self, table, values=None) -> int:
        """Return the number of matching rows with a non-NULL key."""
        return self._call(table, 'count', table.__key__, values)