Refactor to connect to Postgres databases

This commit is contained in:
XANTRONIX Development 2025-01-02 23:45:31 -05:00
parent 6fdd06b571
commit 5872b46752
11 changed files with 74 additions and 45 deletions

View file

@ -8,6 +8,11 @@ RUN mkdir -p /var/lib/xenu-nntp
RUN mkdir -p /var/opt/xenu-nntp/bin RUN mkdir -p /var/opt/xenu-nntp/bin
RUN mkdir -p /var/opt/xenu-nntp/lib/xenu_nntp RUN mkdir -p /var/opt/xenu-nntp/lib/xenu_nntp
COPY requirements.txt /root
RUN apk add libpq
RUN pip3 install -r /root/requirements.txt
COPY <<EOF /etc/xenu-nntp/server.conf COPY <<EOF /etc/xenu-nntp/server.conf
[daemon] [daemon]
pidfile = /var/run/xenu-nntp/server.pid pidfile = /var/run/xenu-nntp/server.pid
@ -22,7 +27,10 @@ cert = /etc/xenu-nntp/tls/tls.crt
key = /etc/xenu-nntp/tls/tls.key key = /etc/xenu-nntp/tls/tls.key
[database] [database]
path = /var/lib/xenu-nntp/db.sqlite3 host = postgres
port = 5432
user = postgres
dbname = xenu_nntp
EOF EOF
COPY bin/xenu-nntp-* /var/opt/xenu-nntp/bin COPY bin/xenu-nntp-* /var/opt/xenu-nntp/bin

View file

@ -15,7 +15,7 @@ parser.add_argument('username', type=str, help='Username o
args = parser.parse_args() args = parser.parse_args()
config = Config.load(args.config_file) config = Config.load(args.config_file)
db = Database.connect(config.get('database', 'path')) db = Database.from_config(config)
user = db.get(User, {'username': args.username}) user = db.get(User, {'username': args.username})

View file

@ -17,7 +17,7 @@ parser.add_argument('username', type=str, help='Username o
args = parser.parse_args() args = parser.parse_args()
config = Config.load(args.config_file) config = Config.load(args.config_file)
db = Database.connect(config.get('database', 'path')) db = Database.from_config(config)
user = db.get(User, {'username': args.username}) user = db.get(User, {'username': args.username})

View file

@ -4,7 +4,6 @@ import os
import argparse import argparse
from xenu_nntp.config import Config from xenu_nntp.config import Config
from xenu_nntp.db import Database
from xenu_nntp.server import Server from xenu_nntp.server import Server
from xenu_nntp.daemon import Daemon from xenu_nntp.daemon import Daemon

View file

@ -1,5 +1,7 @@
import enum import enum
import sqlite3 import psycopg
from xenu_nntp.config import Config
class DatabaseOrder(enum.Enum): class DatabaseOrder(enum.Enum):
DEFAULT = 0 DEFAULT = 0
@ -46,7 +48,7 @@ class DatabaseTableCursor():
for name in self.table.columns: for name in self.table.columns:
try: try:
setattr(obj, name, row[name]) setattr(obj, name, getattr(row, name))
except IndexError: except IndexError:
setattr(obj, name, None) setattr(obj, name, None)
@ -80,20 +82,38 @@ class Database():
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.db, name) return getattr(self.db, name)
VALID_ARGS = (
'host', 'port', 'dbname', 'user', 'password'
)
@staticmethod @staticmethod
def connect(path): def connect(info):
db = sqlite3.connect(path) if isinstance(info, dict):
db.row_factory = sqlite3.Row conninfo = ' '.join(f"{k}={str(info[k])}" for k in filter(lambda k: k in info, Database.VALID_ARGS))
else:
conninfo = info
db = psycopg.connect(conninfo, row_factory=psycopg.rows.namedtuple_row)
return Database(db) return Database(db)
@staticmethod
def from_config(config: Config):
info = dict()
for option in Database.VALID_ARGS:
if config.has_option('database', option):
info[option] = config.get('database', option)
return Database.connect(info)
def add(self, obj): def add(self, obj):
table = type(obj) table = type(obj)
sql = f"insert into {table.name} (" sql = f"insert into {table.name} ("
sql += ", ".join([c for c in table.columns if c != table.key]) sql += ", ".join([c for c in table.columns if c != table.key])
sql += ') values (' sql += ') values ('
sql += ", ".join(['?' for c in table.columns if c != table.key]) sql += ", ".join(['%s' for c in table.columns if c != table.key])
sql += ')' sql += f") returning {table.key}"
fn = getattr(obj, '__values__', None) fn = getattr(obj, '__values__', None)
@ -108,7 +128,7 @@ class Database():
cr = self.db.execute(sql, values) cr = self.db.execute(sql, values)
setattr(obj, table.key, cr.lastrowid) setattr(obj, table.key, cr.fetchone()[0])
def update(self, obj): def update(self, obj):
if not obj.__dirty__: if not obj.__dirty__:
@ -117,8 +137,8 @@ class Database():
dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0] dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0]
table = type(obj) table = type(obj)
sql = f"update {table.name} set " sql = f"update {table.name} set "
sql += ", ".join([f"{k} = ?" for k in dirty]) sql += ", ".join([f"{k} = %s" for k in dirty])
sql += f" where {table.key} = ?" sql += f" where {table.key} = %s"
values = [getattr(obj, k) for k in dirty] values = [getattr(obj, k) for k in dirty]
values.append(getattr(obj, table.key)) values.append(getattr(obj, table.key))
@ -139,7 +159,7 @@ class Database():
if len(values) > 0: if len(values) > 0:
sql += " where " sql += " where "
sql += " and ".join([f"{table.name}.{k} = ?" for k in values]) sql += " and ".join([f"{table.name}.{k} = %s" for k in values])
if len(order_by) > 0: if len(order_by) > 0:
sql += " order by" sql += " order by"
@ -169,7 +189,7 @@ class Database():
if len(values) > 0: if len(values) > 0:
sql += " where " sql += " where "
sql += " and ".join([f"{k} = ?" for k in values]) sql += " and ".join([f"{k} = %s" for k in values])
row = self.db.execute(sql, list(values.values())).fetchone() row = self.db.execute(sql, list(values.values())).fetchone()

View file

@ -75,6 +75,8 @@ class MBoxReader():
return ret return ret
line = line.replace('\x00', '')
self.line += 1 self.line += 1
self.buf.add(line) self.buf.add(line)

View file

@ -168,14 +168,14 @@ class Message(DatabaseTable):
# Defer parsing the message content until a specific header not already # Defer parsing the message content until a specific header not already
# assigned to a dedcicated property, or the message body, is required. # assigned to a dedcicated property, or the message body, is required.
# #
message.content = row['content'] message.content = row.content
message.id = row['id'] message.id = row.id
message.created_on = row['created_on'] message.created_on = row.created_on
message.message_id = row['message_id'] message.message_id = row.message_id
message.reference_ids = row['reference_ids'] message.reference_ids = row.reference_ids
message.sender = row['sender'] message.sender = row.sender
message.subject = row['subject'] message.subject = row.subject
return message return message

View file

@ -13,12 +13,12 @@ class Newsgroup(DatabaseTable):
@staticmethod @staticmethod
def __from_row__(row): def __from_row__(row):
newsgroup = Newsgroup() newsgroup = Newsgroup()
newsgroup.id = row['id'] newsgroup.id = row.id
newsgroup.created_on = datetime.datetime.fromisoformat(row['created_on']) newsgroup.created_on = row.created_on
newsgroup.created_by = row['created_by'] newsgroup.created_by = row.created_by
newsgroup.name = row['name'] newsgroup.name = row.name
newsgroup.description = row['description'] newsgroup.description = row.description
newsgroup.writable = row['writable'] newsgroup.writable = row.writable
return newsgroup return newsgroup

View file

@ -21,7 +21,7 @@ class Server():
config.get('tls', 'key')) config.get('tls', 'key'))
def connect_to_db(self): def connect_to_db(self):
return Database.connect(self.config.get('database', 'path')) return Database.from_config(self.config)
def listen(self, host: str, port: int, af: int): def listen(self, host: str, port: int, af: int):
listener = socket.socket(af, socket.SOCK_STREAM) listener = socket.socket(af, socket.SOCK_STREAM)

View file

@ -130,7 +130,7 @@ class Session(Connection):
from from
newsgroup_message newsgroup_message
where where
newsgroup_id = ? newsgroup_id = %s
""" """
cr = self.db.execute(sql, (newsgroup.id,)) cr = self.db.execute(sql, (newsgroup.id,))
@ -169,8 +169,8 @@ class Session(Connection):
from from
newsgroup_message newsgroup_message
where where
newsgroup_id = ? newsgroup_id = %s
and message_id < ? and message_id < %s
""" """
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id)) cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
@ -196,8 +196,8 @@ class Session(Connection):
from from
newsgroup_message newsgroup_message
where where
message_id = ? message_id = %s
and id > ? and id > %s
""" """
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id)) cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
@ -221,13 +221,13 @@ class Session(Connection):
message message
where where
message.id = newsgroup_message.message_id message.id = newsgroup_message.message_id
and newsgroup_message.newsgroup_id = ? and newsgroup_message.newsgroup_id = %s
""" """
values = [newsgroup.id] values = [newsgroup.id]
if since is not None: if since is not None:
sql += " and message.created_on >= ?" sql += " and message.created_on >= %s"
values.append(since.isoformat()) values.append(since.isoformat())
cr = self.db.execute(sql, values) cr = self.db.execute(sql, values)
@ -258,7 +258,7 @@ class Session(Connection):
from from
newsgroup_message newsgroup_message
where where
newsgroup_id = ? newsgroup_id = %s
""" """
if len(args) > 1: if len(args) > 1:
@ -318,7 +318,7 @@ class Session(Connection):
message message
where where
message.id = newsgroup_message.message_id message.id = newsgroup_message.message_id
and newsgroup_message.newsgroup_id = ? and newsgroup_message.newsgroup_id = %s
""" """
cr = self.db.execute(sql, (newsgroup.id,)) cr = self.db.execute(sql, (newsgroup.id,))
@ -471,8 +471,8 @@ class Session(Connection):
message message
where where
message.id = newsgroup_message.message_id message.id = newsgroup_message.message_id
and newsgroup_message.newsgroup_id = ? and newsgroup_message.newsgroup_id = %s
and message.created_on >= ? and message.created_on >= %s
""" """
for newsgroup in self.each_newsgroup(): for newsgroup in self.each_newsgroup():
@ -555,7 +555,7 @@ class Session(Connection):
message message
where where
message.id = newsgroup_message.message_id message.id = newsgroup_message.message_id
and newsgroup_message.newsgroup_id = ? and newsgroup_message.newsgroup_id = %s
""" """
sql += " and " + msgrange.where('newsgroup_message.message_id') sql += " and " + msgrange.where('newsgroup_message.message_id')
@ -726,7 +726,7 @@ class Session(Connection):
sql = """ sql = """
insert into newsgroup_message ( insert into newsgroup_message (
newsgroup_id, message_id newsgroup_id, message_id
) values (?, ?) ) values (%s, %s)
""" """
cr = self.db.execute(sql, (newsgroup.id, message.id)) cr = self.db.execute(sql, (newsgroup.id, message.id))
@ -746,7 +746,7 @@ class Session(Connection):
from from
message message
where where
message_id = ? message_id = %s
""" """
cr = self.db.execute(sql, (message_id,)) cr = self.db.execute(sql, (message_id,))

View file

@ -38,7 +38,7 @@ class User(DatabaseTable):
server_user_permission user_perm server_user_permission user_perm
where where
perm.id = user_perm.permission_id perm.id = user_perm.permission_id
and user_perm.user_id = ? and user_perm.user_id = %s
""" """
cr = db.execute(sql, (self.id,)) cr = db.execute(sql, (self.id,))