126 lines
3 KiB
Python
126 lines
3 KiB
Python
import enum
|
|
import sqlite3
|
|
|
|
class DatabaseOrder(enum.Enum):
    """Direction of a single ``order by`` term used by ``Database.query``.

    ``DEFAULT`` emits the column with no direction keyword, letting SQLite
    apply its own default; ``ASC`` and ``DESC`` append ``asc`` / ``desc``.
    """

    DEFAULT = 0  # no explicit direction keyword
    ASC = 1      # ascending
    DESC = 2     # descending
|
|
|
|
class DatabaseTable:
    """Empty base class for mapped table types.

    It defines no behaviour of its own. Going by how ``Database`` and
    ``DatabaseTableCursor`` use table classes, subclasses presumably provide
    ``name``, ``key`` and ``columns`` class attributes and a no-argument
    constructor. TODO: confirm against the actual subclasses.
    """
|
|
|
|
class DatabaseTableCursor:
    """Wrap a sqlite3 cursor so fetched rows come back as ``table`` instances.

    Any attribute not defined here (``execute``, ``rowcount``, ``close``, ...)
    is forwarded to the wrapped cursor. Rows are expected to be indexable by
    column name, as ``sqlite3.Row`` is.
    """

    __slots__ = ('cr', 'table')

    def __init__(self, table, cr):
        self.cr = cr
        self.table = table

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If one of our own slots is
        # unset (instance built without __init__, e.g. during copy/unpickle),
        # fail cleanly instead of recursing through self.cr forever.
        if name in ('cr', 'table'):
            raise AttributeError(name)
        return getattr(self.cr, name)

    def __iter__(self):
        """Iterate the remaining rows, each mapped onto ``self.table``."""
        return map(self.__map__, self.cr)

    def __map__(self, row):
        """Build a ``self.table()`` instance from *row*.

        Every name in ``self.table.columns`` is set as an attribute. Columns
        absent from the result set become ``None``.
        """
        obj = self.table()
        for name in self.table.columns:
            try:
                value = row[name]
            except IndexError:
                # sqlite3.Row raises IndexError for a name not in the result.
                value = None
            setattr(obj, name, value)
        return obj

    def fetchone(self):
        """Return the next row as a table instance, or ``None`` if exhausted."""
        row = self.cr.fetchone()
        # The underlying cursor returns None when no rows remain; mapping it
        # would fail with TypeError, so pass the sentinel through.
        if row is None:
            return None
        return self.__map__(row)

    def fetchall(self):
        """Return a lazy ``map`` iterator over all remaining rows, mapped."""
        return map(self.__map__, self.cr.fetchall())
|
|
|
|
class Database:
    """Thin wrapper around a ``sqlite3.Connection`` with table-mapping helpers.

    Table classes passed to the helpers must expose ``name`` (the SQL table
    name), ``key`` (the column skipped on insert) and ``columns`` (an
    iterable of column names). Any attribute not defined here (``commit``,
    ``execute``, ``close``, ...) is forwarded to the wrapped connection.

    NOTE: table names, column names and ``values`` keys are interpolated into
    the SQL text verbatim; only the values themselves are bound as
    parameters. Never pass identifiers that come from untrusted input.
    """

    __slots__ = ('db',)

    def __init__(self, db):
        self.db = db

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If the ``db`` slot itself is
        # unset (instance built without __init__, e.g. during copy or
        # unpickling), raise instead of recursing into self.db forever.
        if name == 'db':
            raise AttributeError(name)
        return getattr(self.db, name)

    @staticmethod
    def connect(path):
        """Open the SQLite database at *path* and wrap it.

        Rows are produced as ``sqlite3.Row`` so they can be indexed by column
        name.
        """
        db = sqlite3.connect(path)
        db.row_factory = sqlite3.Row
        return Database(db)

    @staticmethod
    def _where(table, values):
        """Return ``(" where t.a = ? and ...", params)``, or ``("", [])``."""
        if not values:
            return '', []
        clause = ' and '.join(f"{table.name}.{key} = ?" for key in values)
        # Dict iteration order is stable, so params line up with the clause.
        return f" where {clause}", list(values.values())

    def add(self, obj):
        """Insert *obj* into its table (``type(obj)``), skipping ``table.key``.

        If *obj* defines ``__values__()``, its result is used as the parameter
        sequence. It is expected to follow ``table.columns`` order minus the
        key. Otherwise the attributes named by ``table.columns`` are read.
        Nothing is committed here.

        Returns:
            The cursor's ``lastrowid`` for the insert.
        """
        table = type(obj)
        columns = [c for c in table.columns if c != table.key]
        placeholders = ', '.join('?' for _ in columns)
        sql = (f"insert into {table.name} ({', '.join(columns)}) "
               f"values ({placeholders})")

        # __values__ is optional: look it up with a default so objects that
        # don't define it fall back to attribute access instead of crashing.
        fn = getattr(obj, '__values__', None)
        if fn is not None:
            values = fn()
        else:
            values = [getattr(obj, c) for c in columns]

        return self.db.execute(sql, values).lastrowid

    def query(self, table, values=None, order_by=None):
        """Select rows from *table* matching *values*, optionally sorted.

        Args:
            table: table class to query and to map rows onto.
            values: mapping of column name to required value. All conditions
                are combined with ``and``.
            order_by: iterable of ``(column, order)`` pairs, where ``order``
                is a ``DatabaseOrder`` member or ``None`` (same as DEFAULT).

        Returns:
            A ``DatabaseTableCursor`` over the executed statement.

        Raises:
            ValueError: if an ``order`` is neither ``None`` nor a
                ``DatabaseOrder`` member.
        """
        where, params = self._where(table, values)
        sql = f"select * from {table.name}{where}"

        terms = []
        for column, order in order_by or ():
            if order is None or order is DatabaseOrder.DEFAULT:
                terms.append(f"{column}")
            elif order is DatabaseOrder.ASC:
                terms.append(f"{column} asc")
            elif order is DatabaseOrder.DESC:
                terms.append(f"{column} desc")
            else:
                # Silently dropping the term would leave a dangling comma
                # (or a bare "order by") and produce invalid SQL.
                raise ValueError(
                    f"invalid order for column {column!r}: {order!r}")
        if terms:
            sql += " order by " + ", ".join(terms)

        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, params)
        return cr

    def get(self, table, values: 'dict | None' = None):
        """Return the first row of ``query(table, values)`` mapped onto *table*."""
        return self.query(table, values).fetchone()

    def count(self, table, values: 'dict | None' = None):
        """Return the number of rows in *table* matching *values*.

        Uses the same ``table.column = ?`` conditions as ``query``.
        """
        where, params = self._where(table, values)
        sql = f"select count(*) as num from {table.name}{where}"
        # A cursor is not subscriptable; fetch the single aggregate row.
        return self.db.execute(sql, params).fetchone()[0]
|