From 1688218c95fd63615293d57cf9cc1a6ffd1f8371 Mon Sep 17 00:00:00 2001 From: XANTRONIX Industrial Date: Fri, 14 Feb 2025 14:52:50 -0500 Subject: [PATCH] Drastic improvement in database layer --- lib/nexrad/db.py | 89 ++++++++++++++++++++++++++++++++++++--------- lib/nexrad/storm.py | 53 ++++----------------------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/lib/nexrad/db.py b/lib/nexrad/db.py index 95161af..2856b9a 100644 --- a/lib/nexrad/db.py +++ b/lib/nexrad/db.py @@ -108,31 +108,77 @@ class Database(): return Database(db) def column_placeholders(self, table, obj) -> list: - return ['?' for c in table.__columns__ if c != table.__key__] + ret = list() + + for c in table.__columns__: + v = getattr(obj, c) + + if v is not None: + ret.append(c) + + return ret + + def value_placeholders(self, table, obj) -> list: + ci = getattr(table, '__columns_insert__', None) + + if ci is None: + return [f':{c}' for c in table.__columns__] + else: + ret = list() + + for c in table.__columns__: + v = getattr(obj, c, None) + + if v is None: + continue + + if c in ci: + ret.append(ci[c]) + else: + ret.append(f':{c}') + + return ret + + def row_values(self, table, obj) -> dict: + ret = dict() + vi = getattr(table, '__values_insert__', None) + + if vi is None: + for c in table.__columns__: + v = getattr(obj, c) + + if v is not None: + ret[c] = v + else: + for c in table.__columns__: + v = getattr(obj, c) + + if v is not None: + if c in vi: + ret.update(vi[c](v)) + else: + ret[c] = v + + return ret def add(self, obj): - fn = getattr(obj, '__insert__') + fn = getattr(obj, '__insert__', None) if fn is not None: return fn(self.db) table = type(obj) sql = f"insert into {table.__table__} (" - sql += ", ".join([c for c in table.__columns__ if c != table.__key__]) - sql += ') values (' sql += ", ".join(self.column_placeholders(table, obj)) + sql += ') values (' + sql += ", ".join(self.value_placeholders(table, obj)) sql += f") returning {table.__key__}" - fn = getattr(obj, '__values__', None) + values = self.row_values(table, obj) - if fn is not None: - values = fn(obj) - else: - values = list() + print(f"Got values {values}") - for column in table.__columns__: - if column not in table.__constructors__ and column != table.__key__: - values.append(getattr(obj, column, None)) + print(sql) cr = self.db.execute(sql, values) @@ -145,11 +191,20 @@ class Database(): dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0] table = type(obj) sql = f"update {table.__table__} set " - sql += ", ".join([f"{k} = ?" for k in dirty]) - sql += f" where {table.__key__} = ?" + sql += ", ".join([f"{k} = :{k}" for k in dirty]) + sql += f" where {table.__key__} = :{table.__key__}" - values = [getattr(obj, k) for k in dirty] - values.append(getattr(obj, table.__key__)) + values = { + table.__key__: getattr(obj, table.__key__) + } + + vi = getattr(table, '__values_insert__', None) + + for k in dirty: + if vi is not None and k in vi: + values[k] = vi[k](getattr(obj, k)) + else: + values[k] = getattr(obj, k) self.db.execute(sql, values) @@ -160,7 +215,7 @@ class Database(): return cr def query(self, table, values=dict(), order_by=list()): - selectors = getattr(table, '__columns_select__') + selectors = getattr(table, '__columns_select__', None) if selectors is None: columns = table.__columns__ diff --git a/lib/nexrad/storm.py b/lib/nexrad/storm.py index 6b820a0..962e7c2 100644 --- a/lib/nexrad/storm.py +++ b/lib/nexrad/storm.py @@ -103,8 +103,13 @@ class StormReport(DatabaseTable): } __columns_insert__ = { - 'coord_start': lambda v: 'null' if v is None else f'MakePoint({v.lon}, {v.lat}, {COORD_SYSTEM})', - 'coord_end': lambda v: 'null' if v is None else f'MakePoint({v.lon}, {v.lat}, {COORD_SYSTEM})' + 'coord_start': 'MakePoint(:coord_start_lon, :coord_start_lat, {crs})'.format(crs=COORD_SYSTEM), + 'coord_end': 'MakePoint(:coord_end_lon, :coord_end_lat, {crs})'.format(crs=COORD_SYSTEM) + } + + __values_insert__ = { + 'coord_start': lambda v: {'coord_start_lon': v.lon, 'coord_start_lat': v.lat}, + 'coord_end': lambda v: {'coord_end_lon': v.lon, 'coord_end_lat': v.lat} } id: int @@ -149,50 +154,6 @@ class StormReport(DatabaseTable): return report - def __insert__(self, db): - columns = [ - 'id', 'timestamp_start', 'timestamp_end', 'episode_id', - 'state', 'event_type', 'wfo', 'locale_start', 'locale_end', - 'tornado_f_rating' - ] - - bindings = ['?' for c in columns] - values = [ - self.id, - self.timestamp_start.isoformat(), - self.timestamp_end.isoformat(), - self.episode_id, - self.state, - self.event_type, - self.wfo, - self.locale_start, - self.locale_end, - self.tornado_f_rating - ] - - if self.coord_start and self.coord_end: - columns.extend(['coord_start', 'coord_end']) - - bindings.extend([ - 'MakePoint(?, ?, {crs})'.format(crs=COORD_SYSTEM), - 'MakePoint(?, ?, {crs})'.format(crs=COORD_SYSTEM) - ]) - - values.extend([ - self.coord_start.lon, - self.coord_start.lat, - self.coord_end.lon, - self.coord_end.lat - ]) - - sql = "insert into nexrad_storm_report (" - sql += ', '.join(columns) - sql += ") values (" - sql += ', '.join(bindings) - sql += ")" - - db.execute(sql, values) - @staticmethod def from_csv_row(row: dict): report = StormReport()