Fix issue with early default value evaluation

This commit is contained in:
XANTRONIX 2025-10-31 13:02:11 -04:00
parent 6d1b47ebee
commit 8327bfa259
2 changed files with 26 additions and 8 deletions

View file

@ -4,6 +4,7 @@ import sqlite3
from typing import Self, Optional, Union, Iterable, Callable from typing import Self, Optional, Union, Iterable, Callable
from xmet.config import Config from xmet.config import Config
from xmet.util import default
class DatabaseOrder(enum.Enum): class DatabaseOrder(enum.Enum):
DEFAULT = 0 DEFAULT = 0
@ -227,12 +228,15 @@ class Database():
def query(self, def query(self,
table: DatabaseTable, table: DatabaseTable,
clauses: Iterable=list(), clauses: Optional[Iterable]=None,
values: Optional[Union[dict, Iterable]]=None, values: Optional[Union[dict, Iterable]]=None,
order_by: Iterable=list(), order_by: Optional[Iterable]=None,
limit: Optional[int]=None): limit: Optional[int]=None):
selectors = getattr(table, '__columns_read__', None) selectors = getattr(table, '__columns_read__', None)
clauses = default(clauses, list())
order_by = default(order_by, list())
if selectors is None: if selectors is None:
columns = table.__columns__ columns = table.__columns__
else: else:
@ -272,21 +276,26 @@ class Database():
def get_many(self, def get_many(self,
table: DatabaseTable, table: DatabaseTable,
values: Union[dict, Iterable]=dict()): values: Optional[Union[dict, Iterable]]=None):
values = default(values, dict())
clauses = [f"{k} = :{k}" for k in values] clauses = [f"{k} = :{k}" for k in values]
return self.query(table, clauses, values=values) return self.query(table, clauses, values=values)
def get(self, def get(self,
table: DatabaseTable, table: DatabaseTable,
values: Union[dict, Iterable]=dict()): values: Optional[Union[dict, Iterable]]=None):
values = default(values, dict())
return self.get_many(table, values).fetchone() return self.get_many(table, values).fetchone()
def _call(self, def _call(self,
table: DatabaseTable, table: DatabaseTable,
fn: str, fn: str,
column: str, column: str,
values: Union[dict, Iterable]=dict()) -> int: values: Optional[Union[dict, Iterable]]=None) -> int:
values = default(values, dict())
sql = f"select {fn}({column}) as ret from {table.__table__}" sql = f"select {fn}({column}) as ret from {table.__table__}"
if len(values) > 0: if len(values) > 0:
@ -300,16 +309,22 @@ class Database():
def min(self, def min(self,
table: DatabaseTable, table: DatabaseTable,
column: str, column: str,
values: Union[dict, Iterable]=dict()) -> int: values: Optional[Union[dict, Iterable]]=None) -> int:
values = default(values, dict())
return self._call(table, 'min', column, values) return self._call(table, 'min', column, values)
def max(self, def max(self,
table: DatabaseTable, table: DatabaseTable,
column: str, column: str,
values: Union[dict, Iterable]=dict()) -> int: values: Optional[Union[dict, Iterable]]=None) -> int:
values = default(values, dict())
return self._call(table, 'max', column, values) return self._call(table, 'max', column, values)
def count(self, def count(self,
table: DatabaseTable, table: DatabaseTable,
values: Union[dict, Iterable]=dict()) -> int: values: Optional[Union[dict, Iterable]]=None) -> int:
values = default(values, dict())
return self._call(table, 'count', table.__key__, values) return self._call(table, 'count', table.__key__, values)

View file

@ -3,6 +3,9 @@ import io
CHUNK_SIZE = 4096 CHUNK_SIZE = 4096
CHUNK_STRIP = "\x01\x03\x0a\x20" CHUNK_STRIP = "\x01\x03\x0a\x20"
def default(a, b):
return b if a is None else a
def each_chunk(fh: io.TextIOBase, sep: str, strip=None): def each_chunk(fh: io.TextIOBase, sep: str, strip=None):
buf = '' buf = ''