diff --git a/lib/xmet/db.py b/lib/xmet/db.py index 9e8d5eb..a17a6f6 100644 --- a/lib/xmet/db.py +++ b/lib/xmet/db.py @@ -1,7 +1,7 @@ import enum import sqlite3 -from typing import Self +from typing import Self, Optional, Union, Iterable, Callable from xmet.config import Config @@ -122,7 +122,7 @@ class Database(): def from_config(config: Config) -> Self: return Database.connect(config['database']['path']) - def column_placeholders(self, table, obj) -> list: + def column_placeholders(self, table: DatabaseTable, obj) -> list: ret = list() for c in table.__columns__: @@ -133,7 +133,7 @@ class Database(): return ret - def value_placeholders(self, table, obj) -> list: + def value_placeholders(self, table: DatabaseTable, obj) -> list: ci = getattr(table, '__columns_write__', None) ret = list() @@ -151,7 +151,7 @@ class Database(): return ret - def row_values(self, table, obj) -> dict: + def row_values(self, table: DatabaseTable, obj) -> dict: ret = dict() vi = getattr(table, '__values_write__', None) @@ -216,13 +216,21 @@ class Database(): self.db.execute(sql, values) - def query_sql(self, table, sql, values): + def query_sql(self, + table: DatabaseTable, + sql: str, + values: Union[dict, Iterable]): cr = DatabaseTableCursor(table, self.db.cursor()) cr.execute(sql, values) return cr - def query(self, table, clauses=list(), values=None, order_by=list(), limit=None): + def query(self, + table: DatabaseTable, + clauses: Iterable=list(), + values: Optional[Union[dict, Iterable]]=None, + order_by: Iterable=list(), + limit: Optional[int]=None): selectors = getattr(table, '__columns_read__', None) if selectors is None: @@ -262,15 +270,23 @@ class Database(): return self.query_sql(table, sql, values) - def get_many(self, table, values: dict=dict()): + def get_many(self, + table: DatabaseTable, + values: Union[dict, Iterable]=dict()): clauses = [f"{k} = :{k}" for k in values] return self.query(table, clauses, values=values) - def get(self, table, values: dict=dict()): + def get(self, + table: DatabaseTable, + values: Union[dict, Iterable]=dict()): return self.get_many(table, values).fetchone() - def _call(self, table, fn: str, column: str, values: dict=dict()) -> int: + def _call(self, + table: DatabaseTable, + fn: str, + column: str, + values: Union[dict, Iterable]=dict()) -> int: sql = f"select {fn}({column}) as ret from {table.__table__}" if len(values) > 0: @@ -281,11 +297,19 @@ class Database(): return row[0] if row is not None else None - def min(self, table, column: str, values: dict=dict()) -> int: + def min(self, + table: DatabaseTable, + column: str, + values: Union[dict, Iterable]=dict()) -> int: return self._call(table, 'min', column, values) - def max(self, table, column: str, values: dict=dict()) -> int: + def max(self, + table: DatabaseTable, + column: str, + values: Union[dict, Iterable]=dict()) -> int: return self._call(table, 'max', column, values) - def count(self, table, values: dict=dict()) -> int: + def count(self, + table: DatabaseTable, + values: Union[dict, Iterable]=dict()) -> int: return self._call(table, 'count', table.__key__, values)