xenu_nntp/lib/nntp/tiny/db.py

144 lines
3.7 KiB
Python
Raw Normal View History

2024-11-08 23:11:09 -05:00
import enum
2024-11-08 15:47:21 -05:00
import sqlite3
2024-11-08 23:11:09 -05:00
class DatabaseOrder(enum.Enum):
    """Sort direction attached to a column in an ``order by`` clause."""

    # Emit the bare column name and leave the direction to the database.
    DEFAULT = 0
    # Append an explicit ``asc`` after the column name.
    ASC = 1
    # Append an explicit ``desc`` after the column name.
    DESC = 2
class DatabaseTable:
    """Empty marker class that defines no attributes or behaviour.

    NOTE(review): presumably the common base for table-mapping classes;
    confirm against the modules that subclass it.
    """
class DatabaseTableCursor:
    """Cursor wrapper that converts fetched rows into table objects.

    ``table`` is a class describing the mapped table. It must expose a
    ``columns`` iterable of column names and be callable without
    arguments, unless it provides a ``__from_row__`` callable, in which
    case that callable builds the object from the raw row instead.

    Any attribute not defined here (``execute``, ``close``, ...) is
    forwarded to the wrapped cursor, so those calls return whatever the
    underlying cursor returns (rows are *not* mapped on that path).
    """

    __slots__ = 'cr', 'table',

    def __init__(self, table, cr):
        """Wrap cursor ``cr`` so its rows are mapped to ``table`` objects."""
        self.cr = cr
        self.table = table

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped cursor."""
        # __getattr__ only runs after normal lookup failed.  If one of our
        # own slots is the missing name (instance created without __init__,
        # e.g. while being copied or unpickled), touching self.cr below
        # would re-enter this method forever; fail cleanly instead.
        if name in ('cr', 'table'):
            raise AttributeError(name)

        return getattr(self.cr, name)

    def __map__(self, row):
        """Build a table object from a single result ``row``.

        Uses ``table.__from_row__(row)`` when the table defines it.
        Otherwise instantiates ``table()`` and copies every name listed in
        ``table.columns`` from the row; columns absent from the row are
        set to ``None``.
        """
        fn = getattr(self.table, '__from_row__', None)

        if fn is not None:
            return fn(row)

        obj = self.table()

        for name in self.table.columns:
            # Keep the try body to the lookup only so that errors raised by
            # the attribute assignment itself are not swallowed.
            # sqlite3.Row raises IndexError for an unknown column name;
            # mapping-style rows raise KeyError.
            try:
                value = row[name]
            except (IndexError, KeyError):
                value = None

            setattr(obj, name, value)

        return obj

    def fetchone(self):
        """Return the next row as a table object, or ``None`` if exhausted."""
        row = self.cr.fetchone()

        return self.__map__(row) if row is not None else None

    def fetchall(self):
        """Return a lazy iterator of table objects for all remaining rows.

        The rows are fetched from the cursor immediately; only the mapping
        to table objects is deferred until the result is iterated.
        """
        return map(self.__map__, self.cr.fetchall())
2024-11-08 15:47:21 -05:00
class Database:
    """Small object-mapping helper around a ``sqlite3`` connection.

    Table classes passed to the methods below are expected to expose:

    * ``name``    -- the SQL table name,
    * ``columns`` -- an iterable of column names,
    * ``key``     -- the name of the key column (skipped on insert and
      used as the argument of ``count``).

    Any attribute not defined here (``execute``, ``commit``, ``close``,
    ...) is forwarded to the wrapped connection.

    NOTE: table names, column names and the keys of ``values`` are
    interpolated into the SQL text as identifiers (only the *values* are
    bound as parameters), so they must never come from untrusted input.
    """

    __slots__ = 'db',

    def __init__(self, db):
        """Wrap the already-open connection ``db``."""
        self.db = db

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped connection."""
        # __getattr__ only runs after normal lookup failed.  If the missing
        # name is our own slot (instance created without __init__, e.g.
        # while being copied or unpickled), reading self.db below would
        # re-enter this method forever; fail cleanly instead.
        if name == 'db':
            raise AttributeError(name)

        return getattr(self.db, name)

    @classmethod
    def connect(cls, path):
        """Open the SQLite database at ``path`` and return a wrapper for it.

        The connection's ``row_factory`` is set to ``sqlite3.Row`` so rows
        can be indexed by column name.  Being a classmethod, subclasses
        get an instance of their own type back.
        """
        db = sqlite3.connect(path)
        db.row_factory = sqlite3.Row

        return cls(db)

    def add(self, obj):
        """Insert ``obj`` into the table described by ``type(obj)``.

        The key column is left out of the insert.  The bound values come
        from ``obj.__values__()`` when the object defines it, otherwise
        from the object's attributes named after the non-key columns.
        No commit is issued here.
        """
        table = type(obj)
        columns = [c for c in table.columns if c != table.key]

        column_list = ", ".join(columns)
        placeholders = ", ".join("?" for _ in columns)
        sql = f"insert into {table.name} ({column_list}) values ({placeholders})"

        fn = getattr(obj, '__values__', None)

        if fn is not None:
            values = fn()
        else:
            values = [getattr(obj, column) for column in columns]

        self.db.execute(sql, values)

    def query(self, table, values=None, order_by=None):
        """Select rows of ``table`` and return a mapping cursor over them.

        ``values`` maps column names to the values they must equal; all
        conditions are combined with ``and``.  ``order_by`` is a sequence
        of ``(column, order)`` pairs where ``order`` is a
        ``DatabaseOrder`` member or ``None`` (same as ``DEFAULT``).

        Returns a ``DatabaseTableCursor`` on which the statement has
        already been executed.

        Raises ``ValueError`` for an unrecognised ``order`` instead of
        sending malformed SQL to the database.
        """
        values = {} if values is None else values
        order_by = [] if order_by is None else order_by

        sql = f"select * from {table.name}"
        params = list(values.values())

        if values:
            conditions = [f"{table.name}.{key} = ?" for key in values]
            sql += " where " + " and ".join(conditions)

        if order_by:
            terms = []

            for column, order in order_by:
                if order is None or order is DatabaseOrder.DEFAULT:
                    terms.append(f"{column}")
                elif order is DatabaseOrder.ASC:
                    terms.append(f"{column} asc")
                elif order is DatabaseOrder.DESC:
                    terms.append(f"{column} desc")
                else:
                    raise ValueError(f"unsupported sort order: {order!r}")

            sql += " order by " + ", ".join(terms)

        cr = DatabaseTableCursor(table, self.db.cursor())
        cr.execute(sql, params)

        return cr

    def get(self, table, values: 'dict | None' = None):
        """Return the first row of ``table`` matching ``values``, or ``None``."""
        return self.query(table, values).fetchone()

    def _sqlite3_function(self, table, fn: str, column: str, values: 'dict | None' = None) -> int:
        """Evaluate the SQL aggregate ``fn(column)`` over ``table``.

        ``values`` optionally restricts the rows with equality conditions
        combined with ``and``.  The result is the first column of the
        single returned row; it is ``None`` when the aggregate yields
        NULL (e.g. ``min``/``max`` over no rows).
        """
        values = {} if values is None else values

        sql = f"select {fn}({column}) as ret from {table.name}"

        if values:
            sql += " where " + " and ".join(f"{k} = ?" for k in values)

        row = self.db.execute(sql, list(values.values())).fetchone()

        return row[0] if row is not None else None

    def min(self, table, column: str, values: 'dict | None' = None) -> int:
        """Return the smallest ``column`` value among the matching rows."""
        return self._sqlite3_function(table, 'min', column, values)

    def max(self, table, column: str, values: 'dict | None' = None) -> int:
        """Return the largest ``column`` value among the matching rows."""
        return self._sqlite3_function(table, 'max', column, values)

    def count(self, table, values: 'dict | None' = None) -> int:
        """Return how many matching rows have a non-NULL key column."""
        return self._sqlite3_function(table, 'count', table.key, values)