import enum
import sqlite3
from typing import Any, Optional


class DatabaseOrder(enum.Enum):
    """Sort direction for one ``order by`` term passed to ``Database.query``."""

    DEFAULT = 0  # emit the bare column; SQLite's default direction applies
    ASC = 1
    DESC = 2


class DatabaseTable:
    """Marker base class for record types mapped onto a table.

    From how ``Database`` and ``DatabaseTableCursor`` use table types, a
    subclass is expected to provide these class attributes (TODO confirm
    against the concrete tables):

    * ``name``: the SQL table name.
    * ``columns``: an iterable of column names.
    * ``key``: the column that ``Database.add`` omits from inserts and that
      ``Database.count`` counts. Presumably this is the auto-assigned
      primary key.

    It may also define two optional hooks:

    * ``__from_row__(row)``: called on the class with one row to build an
      instance. It should be a classmethod or staticmethod.
    * ``__values__()``: called on an instance to produce the insert values,
      in the order of the non-key ``columns``.
    """


class DatabaseTableCursor():
    """Wrap a ``sqlite3.Cursor`` so fetched rows come back as ``table`` objects.

    Attributes not defined here (``execute``, ``rowcount``, ``close``, ...)
    are forwarded to the wrapped cursor.
    """

    __slots__ = 'cr', 'table',

    def __init__(self, table, cr):
        self.cr = cr
        self.table = table

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If one of our own
        # slots is unset (e.g. the instance was made via __new__ by copy or
        # pickle), forwarding would look up self.cr again and recurse until
        # RecursionError. Fail with a plain AttributeError instead.
        if name in DatabaseTableCursor.__slots__:
            raise AttributeError(name)
        return getattr(self.cr, name)

    def __map__(self, row):
        """Convert one fetched row into an instance of ``self.table``.

        If the table defines ``__from_row__``, it is called as
        ``table.__from_row__(row)`` and its result is returned. Otherwise the
        table is instantiated with no arguments. Each name in
        ``table.columns`` is then copied from the row onto the instance, with
        ``None`` for columns absent from the result set.
        """
        from_row = getattr(self.table, '__from_row__', None)
        if from_row is not None:
            return from_row(row)
        obj = self.table()
        for name in self.table.columns:
            try:
                value = row[name]
            except IndexError:
                # sqlite3.Row raises IndexError for a name not in the result.
                value = None
            setattr(obj, name, value)
        return obj

    def fetchone(self):
        """Return the next row as a table object, or ``None`` when exhausted."""
        row = self.cr.fetchone()
        return self.__map__(row) if row is not None else None

    def fetchall(self):
        """Return a ``map`` of table objects over all remaining rows.

        The rows themselves are fetched immediately. Only the conversion is
        lazy, and the returned iterator can be consumed once.
        """
        return map(self.__map__, self.cr.fetchall())


def _where_clause(values, prefix=''):
    """Build a `` where a = ? and b = ?`` SQL suffix and its parameter list.

    ``values`` maps column names to the values they must equal; terms are
    ANDed. Each column name is written as ``prefix + name``. Column names are
    interpolated into the SQL text unescaped (only the values are bound), so
    the keys must come from trusted code. Returns ``('', [])`` when
    ``values`` is empty or None.
    """
    if not values:
        return '', []
    terms = ' and '.join(f'{prefix}{k} = ?' for k in values)
    return ' where ' + terms, list(values.values())


def _order_by_clause(order_by):
    """Build an `` order by ...`` SQL suffix from ``(column, order)`` pairs.

    ``order`` may be ``None``, ``DatabaseOrder.DEFAULT``, ``ASC`` or ``DESC``.
    Returns ``''`` when ``order_by`` is empty or None.

    Raises:
        ValueError: ``order`` is anything else. Previously such a term was
            silently dropped, which produced malformed SQL.
    """
    terms = []
    for column, order in order_by or ():
        if order is None or order is DatabaseOrder.DEFAULT:
            terms.append(f'{column}')
        elif order is DatabaseOrder.ASC:
            terms.append(f'{column} asc')
        elif order is DatabaseOrder.DESC:
            terms.append(f'{column} desc')
        else:
            raise ValueError(f'invalid sort order for {column!r}: {order!r}')
    return (' order by ' + ', '.join(terms)) if terms else ''


class Database():
    """Thin wrapper over ``sqlite3.Connection`` that maps rows to table classes.

    Attributes not defined here (``commit``, ``execute``, ``close``, ...) are
    forwarded to the wrapped connection. This class never commits; call
    ``commit()`` (forwarded to the connection) to persist ``add`` results.

    Table and column names are interpolated into SQL text; only values are
    bound as parameters. Pass names from trusted code only.
    """

    __slots__ = 'db',

    def __init__(self, db):
        self.db = db

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If the 'db' slot is
        # itself unset (instance made via __new__ by copy/pickle), forwarding
        # would look up self.db again and recurse until RecursionError.
        if name in Database.__slots__:
            raise AttributeError(name)
        return getattr(self.db, name)

    @staticmethod
    def connect(path):
        """Open the SQLite database at ``path`` and wrap it in a ``Database``."""
        db = sqlite3.connect(path)
        # Rows must support access by column name for DatabaseTableCursor.
        db.row_factory = sqlite3.Row
        return Database(db)

    def add(self, obj):
        """Insert ``obj`` into its table, leaving the key column to SQLite.

        The values come from ``obj.__values__()`` if defined; otherwise from
        the attributes named by the table's non-key columns, in order.

        Returns:
            The ``lastrowid`` of the insert. Previously ``None`` was returned,
            so existing callers that ignore the result are unaffected.
        """
        table = type(obj)
        columns = [c for c in table.columns if c != table.key]
        sql = (f"insert into {table.name} ({', '.join(columns)})"
               f" values ({', '.join('?' for _ in columns)})")
        values_fn = getattr(obj, '__values__', None)
        if values_fn is not None:
            values = values_fn()
        else:
            values = [getattr(obj, column) for column in columns]
        return self.db.execute(sql, values).lastrowid

    def query(self, table, values=None, order_by=None):
        """Run ``select * from <table>`` with optional filters and ordering.

        Args:
            table: A table class (see ``DatabaseTable``).
            values: Mapping of column name to required value. Terms are ANDed
                and each column is qualified as ``<table.name>.<column>``.
            order_by: Iterable of ``(column, order)`` pairs, where ``order``
                is ``None`` or a ``DatabaseOrder`` member.

        Returns:
            A ``DatabaseTableCursor`` on which the query has been executed.

        Raises:
            ValueError: An ``order`` in ``order_by`` is not recognized.
        """
        where_sql, params = _where_clause(values, f'{table.name}.')
        sql = f"select * from {table.name}{where_sql}{_order_by_clause(order_by)}"
        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, params)
        return cr

    def get(self, table, values: Optional[dict] = None):
        """Return the first row matching ``values`` as a table object, or ``None``."""
        return self.query(table, values).fetchone()

    def _sqlite3_function(self, table, fn: str, column: str,
                          values: Optional[dict] = None) -> Any:
        """Evaluate the SQL aggregate ``fn(column)`` over rows matching ``values``.

        Returns the aggregate's value. For ``min``/``max`` over no rows this
        is ``None``.
        """
        where_sql, params = _where_clause(values)
        sql = f"select {fn}({column}) as ret from {table.name}{where_sql}"
        row = self.db.execute(sql, params).fetchone()
        return row[0] if row is not None else None

    def min(self, table, column: str, values: Optional[dict] = None) -> Any:
        """Return ``min(column)`` over matching rows (``None`` if none match)."""
        return self._sqlite3_function(table, 'min', column, values)

    def max(self, table, column: str, values: Optional[dict] = None) -> Any:
        """Return ``max(column)`` over matching rows (``None`` if none match)."""
        return self._sqlite3_function(table, 'max', column, values)

    def count(self, table, values: Optional[dict] = None) -> int:
        """Return the number of matching rows whose key column is not NULL."""
        return self._sqlite3_function(table, 'count', table.key, values)