Add type annotations to lib/xmet/db.py

This commit is contained in:
XANTRONIX 2025-04-19 11:34:57 -04:00
parent a447cfe86d
commit 1cec726451

View file

@ -1,7 +1,7 @@
import enum import enum
import sqlite3 import sqlite3
from typing import Self from typing import Self, Optional, Union, Iterable, Callable
from xmet.config import Config from xmet.config import Config
@ -122,7 +122,7 @@ class Database():
def from_config(config: Config) -> Self: def from_config(config: Config) -> Self:
return Database.connect(config['database']['path']) return Database.connect(config['database']['path'])
def column_placeholders(self, table, obj) -> list: def column_placeholders(self, table: DatabaseTable, obj) -> list:
ret = list() ret = list()
for c in table.__columns__: for c in table.__columns__:
@ -133,7 +133,7 @@ class Database():
return ret return ret
def value_placeholders(self, table, obj) -> list: def value_placeholders(self, table: DatabaseTable, obj) -> list:
ci = getattr(table, '__columns_write__', None) ci = getattr(table, '__columns_write__', None)
ret = list() ret = list()
@ -151,7 +151,7 @@ class Database():
return ret return ret
def row_values(self, table, obj) -> dict: def row_values(self, table: DatabaseTable, obj) -> dict:
ret = dict() ret = dict()
vi = getattr(table, '__values_write__', None) vi = getattr(table, '__values_write__', None)
@ -216,13 +216,21 @@ class Database():
self.db.execute(sql, values) self.db.execute(sql, values)
def query_sql(self, table, sql, values): def query_sql(self,
table: DatabaseTable,
sql: str,
values: Union[dict, Iterable]):
cr = DatabaseTableCursor(table, self.db.cursor()) cr = DatabaseTableCursor(table, self.db.cursor())
cr.execute(sql, values) cr.execute(sql, values)
return cr return cr
def query(self, table, clauses=list(), values=None, order_by=list(), limit=None): def query(self,
table: DatabaseTable,
clauses: Iterable=list(),
values: Optional[Union[dict, Iterable]]=None,
order_by: Iterable=list(),
limit: Optional[int]=None):
selectors = getattr(table, '__columns_read__', None) selectors = getattr(table, '__columns_read__', None)
if selectors is None: if selectors is None:
@ -262,15 +270,23 @@ class Database():
return self.query_sql(table, sql, values) return self.query_sql(table, sql, values)
def get_many(self, table, values: dict=dict()): def get_many(self,
table: DatabaseTable,
values: Union[dict, Iterable]=dict()):
clauses = [f"{k} = :{k}" for k in values] clauses = [f"{k} = :{k}" for k in values]
return self.query(table, clauses, values=values) return self.query(table, clauses, values=values)
def get(self, table, values: dict=dict()): def get(self,
table: DatabaseTable,
values: Union[dict, Iterable]=dict()):
return self.get_many(table, values).fetchone() return self.get_many(table, values).fetchone()
def _call(self, table, fn: str, column: str, values: dict=dict()) -> int: def _call(self,
table: DatabaseTable,
fn: str,
column: str,
values: Union[dict, Iterable]=dict()) -> int:
sql = f"select {fn}({column}) as ret from {table.__table__}" sql = f"select {fn}({column}) as ret from {table.__table__}"
if len(values) > 0: if len(values) > 0:
@ -281,11 +297,19 @@ class Database():
return row[0] if row is not None else None return row[0] if row is not None else None
def min(self, table, column: str, values: dict=dict()) -> int: def min(self,
table: DatabaseTable,
column: str,
values: Union[dict, Iterable]=dict()) -> int:
return self._call(table, 'min', column, values) return self._call(table, 'min', column, values)
def max(self, table, column: str, values: dict=dict()) -> int: def max(self,
table: DatabaseTable,
column: str,
values: Union[dict, Iterable]=dict()) -> int:
return self._call(table, 'max', column, values) return self._call(table, 'max', column, values)
def count(self, table, values: dict=dict()) -> int: def count(self,
table: DatabaseTable,
values: Union[dict, Iterable]=dict()) -> int:
return self._call(table, 'count', table.__key__, values) return self._call(table, 'count', table.__key__, values)