diff --git a/lib/xenu_nntp/db.py b/lib/xenu_nntp/db.py index 1ae02ac..1502a15 100644 --- a/lib/xenu_nntp/db.py +++ b/lib/xenu_nntp/db.py @@ -3,6 +3,9 @@ import psycopg from xenu_nntp.config import Config +def default(a, b): + return b if a is None else a + class DatabaseOrder(enum.Enum): DEFAULT = 0 ASC = 1 @@ -145,13 +148,16 @@ class Database(): self.db.execute(sql, values) - def query_sql(self, table, sql, values=list()): + def query_sql(self, table, sql, values: list=None): cr = DatabaseTableCursor(table, self.db.cursor()) - cr.execute(sql, values) + cr.execute(sql, default(values, list())) return cr - def query(self, table, values=dict(), order_by=list()): + def query(self, table, values: dict=None, order_by: list=None): + values = default(values, dict()) + order_by = default(order_by, list()) + sql = "select %s from %s" % ( ', '.join(table.columns), table.name @@ -181,10 +187,12 @@ class Database(): return self.query_sql(table, sql, list(values.values())) - def get(self, table, values: dict=dict()): - return self.query(table, values).fetchone() + def get(self, table, values: dict=None): + return self.query(table, default(values, dict())).fetchone() + + def _call(self, table, fn: str, column: str, values: dict=None) -> int: + values = default(values, dict()) - def _call(self, table, fn: str, column: str, values: dict=dict()) -> int: sql = f"select {fn}({column}) as ret from {table.name}" if len(values) > 0: @@ -195,11 +203,11 @@ class Database(): return row[0] if row is not None else None - def min(self, table, column: str, values: dict=dict()) -> int: - return self._call(table, 'min', column, values) + def min(self, table, column: str, values: dict=None) -> int: + return self._call(table, 'min', column, default(values, dict())) - def max(self, table, column: str, values: dict=dict()) -> int: - return self._call(table, 'max', column, values) + def max(self, table, column: str, values: dict=None) -> int: + return self._call(table, 'max', column, default(values, dict())) - def count(self, table, values: dict=dict()) -> int: - return self._call(table, 'count', table.key, values) + def count(self, table, values: dict=None) -> int: + return self._call(table, 'count', table.key, default(values, dict()))