Refactor to connect to Postgres databases
This commit is contained in:
parent
6fdd06b571
commit
5872b46752
11 changed files with 74 additions and 45 deletions
10
Dockerfile
10
Dockerfile
|
@ -8,6 +8,11 @@ RUN mkdir -p /var/lib/xenu-nntp
|
||||||
RUN mkdir -p /var/opt/xenu-nntp/bin
|
RUN mkdir -p /var/opt/xenu-nntp/bin
|
||||||
RUN mkdir -p /var/opt/xenu-nntp/lib/xenu_nntp
|
RUN mkdir -p /var/opt/xenu-nntp/lib/xenu_nntp
|
||||||
|
|
||||||
|
COPY requirements.txt /root
|
||||||
|
|
||||||
|
RUN apk add libpq
|
||||||
|
RUN pip3 install -r /root/requirements.txt
|
||||||
|
|
||||||
COPY <<EOF /etc/xenu-nntp/server.conf
|
COPY <<EOF /etc/xenu-nntp/server.conf
|
||||||
[daemon]
|
[daemon]
|
||||||
pidfile = /var/run/xenu-nntp/server.pid
|
pidfile = /var/run/xenu-nntp/server.pid
|
||||||
|
@ -22,7 +27,10 @@ cert = /etc/xenu-nntp/tls/tls.crt
|
||||||
key = /etc/xenu-nntp/tls/tls.key
|
key = /etc/xenu-nntp/tls/tls.key
|
||||||
|
|
||||||
[database]
|
[database]
|
||||||
path = /var/lib/xenu-nntp/db.sqlite3
|
host = postgres
|
||||||
|
port = 5432
|
||||||
|
user = postgres
|
||||||
|
dbname = xenu_nntp
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
COPY bin/xenu-nntp-* /var/opt/xenu-nntp/bin
|
COPY bin/xenu-nntp-* /var/opt/xenu-nntp/bin
|
||||||
|
|
|
@ -15,7 +15,7 @@ parser.add_argument('username', type=str, help='Username o
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = Config.load(args.config_file)
|
config = Config.load(args.config_file)
|
||||||
db = Database.connect(config.get('database', 'path'))
|
db = Database.from_config(config)
|
||||||
|
|
||||||
user = db.get(User, {'username': args.username})
|
user = db.get(User, {'username': args.username})
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ parser.add_argument('username', type=str, help='Username o
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = Config.load(args.config_file)
|
config = Config.load(args.config_file)
|
||||||
db = Database.connect(config.get('database', 'path'))
|
db = Database.from_config(config)
|
||||||
|
|
||||||
user = db.get(User, {'username': args.username})
|
user = db.get(User, {'username': args.username})
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ import os
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from xenu_nntp.config import Config
|
from xenu_nntp.config import Config
|
||||||
from xenu_nntp.db import Database
|
|
||||||
from xenu_nntp.server import Server
|
from xenu_nntp.server import Server
|
||||||
from xenu_nntp.daemon import Daemon
|
from xenu_nntp.daemon import Daemon
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import enum
|
import enum
|
||||||
import sqlite3
|
import psycopg
|
||||||
|
|
||||||
|
from xenu_nntp.config import Config
|
||||||
|
|
||||||
class DatabaseOrder(enum.Enum):
|
class DatabaseOrder(enum.Enum):
|
||||||
DEFAULT = 0
|
DEFAULT = 0
|
||||||
|
@ -46,7 +48,7 @@ class DatabaseTableCursor():
|
||||||
|
|
||||||
for name in self.table.columns:
|
for name in self.table.columns:
|
||||||
try:
|
try:
|
||||||
setattr(obj, name, row[name])
|
setattr(obj, name, getattr(row, name))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
setattr(obj, name, None)
|
setattr(obj, name, None)
|
||||||
|
|
||||||
|
@ -80,20 +82,38 @@ class Database():
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.db, name)
|
return getattr(self.db, name)
|
||||||
|
|
||||||
|
VALID_ARGS = (
|
||||||
|
'host', 'port', 'dbname', 'user', 'password'
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def connect(path):
|
def connect(info):
|
||||||
db = sqlite3.connect(path)
|
if isinstance(info, dict):
|
||||||
db.row_factory = sqlite3.Row
|
conninfo = ' '.join(f"{k}={str(info[k])}" for k in filter(lambda k: k in info, Database.VALID_ARGS))
|
||||||
|
else:
|
||||||
|
conninfo = info
|
||||||
|
|
||||||
|
db = psycopg.connect(conninfo, row_factory=psycopg.rows.namedtuple_row)
|
||||||
|
|
||||||
return Database(db)
|
return Database(db)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: Config):
|
||||||
|
info = dict()
|
||||||
|
|
||||||
|
for option in Database.VALID_ARGS:
|
||||||
|
if config.has_option('database', option):
|
||||||
|
info[option] = config.get('database', option)
|
||||||
|
|
||||||
|
return Database.connect(info)
|
||||||
|
|
||||||
def add(self, obj):
|
def add(self, obj):
|
||||||
table = type(obj)
|
table = type(obj)
|
||||||
sql = f"insert into {table.name} ("
|
sql = f"insert into {table.name} ("
|
||||||
sql += ", ".join([c for c in table.columns if c != table.key])
|
sql += ", ".join([c for c in table.columns if c != table.key])
|
||||||
sql += ') values ('
|
sql += ') values ('
|
||||||
sql += ", ".join(['?' for c in table.columns if c != table.key])
|
sql += ", ".join(['%s' for c in table.columns if c != table.key])
|
||||||
sql += ')'
|
sql += f") returning {table.key}"
|
||||||
|
|
||||||
fn = getattr(obj, '__values__', None)
|
fn = getattr(obj, '__values__', None)
|
||||||
|
|
||||||
|
@ -108,7 +128,7 @@ class Database():
|
||||||
|
|
||||||
cr = self.db.execute(sql, values)
|
cr = self.db.execute(sql, values)
|
||||||
|
|
||||||
setattr(obj, table.key, cr.lastrowid)
|
setattr(obj, table.key, cr.fetchone()[0])
|
||||||
|
|
||||||
def update(self, obj):
|
def update(self, obj):
|
||||||
if not obj.__dirty__:
|
if not obj.__dirty__:
|
||||||
|
@ -117,8 +137,8 @@ class Database():
|
||||||
dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0]
|
dirty = [k for k in obj.__dirty_columns__ if obj.__dirty_columns__[k] > 0]
|
||||||
table = type(obj)
|
table = type(obj)
|
||||||
sql = f"update {table.name} set "
|
sql = f"update {table.name} set "
|
||||||
sql += ", ".join([f"{k} = ?" for k in dirty])
|
sql += ", ".join([f"{k} = %s" for k in dirty])
|
||||||
sql += f" where {table.key} = ?"
|
sql += f" where {table.key} = %s"
|
||||||
|
|
||||||
values = [getattr(obj, k) for k in dirty]
|
values = [getattr(obj, k) for k in dirty]
|
||||||
values.append(getattr(obj, table.key))
|
values.append(getattr(obj, table.key))
|
||||||
|
@ -139,7 +159,7 @@ class Database():
|
||||||
|
|
||||||
if len(values) > 0:
|
if len(values) > 0:
|
||||||
sql += " where "
|
sql += " where "
|
||||||
sql += " and ".join([f"{table.name}.{k} = ?" for k in values])
|
sql += " and ".join([f"{table.name}.{k} = %s" for k in values])
|
||||||
|
|
||||||
if len(order_by) > 0:
|
if len(order_by) > 0:
|
||||||
sql += " order by"
|
sql += " order by"
|
||||||
|
@ -169,7 +189,7 @@ class Database():
|
||||||
|
|
||||||
if len(values) > 0:
|
if len(values) > 0:
|
||||||
sql += " where "
|
sql += " where "
|
||||||
sql += " and ".join([f"{k} = ?" for k in values])
|
sql += " and ".join([f"{k} = %s" for k in values])
|
||||||
|
|
||||||
row = self.db.execute(sql, list(values.values())).fetchone()
|
row = self.db.execute(sql, list(values.values())).fetchone()
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,8 @@ class MBoxReader():
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
line = line.replace('\x00', '')
|
||||||
|
|
||||||
self.line += 1
|
self.line += 1
|
||||||
|
|
||||||
self.buf.add(line)
|
self.buf.add(line)
|
||||||
|
|
|
@ -168,14 +168,14 @@ class Message(DatabaseTable):
|
||||||
# Defer parsing the message content until a specific header not already
|
# Defer parsing the message content until a specific header not already
|
||||||
# assigned to a dedcicated property, or the message body, is required.
|
# assigned to a dedcicated property, or the message body, is required.
|
||||||
#
|
#
|
||||||
message.content = row['content']
|
message.content = row.content
|
||||||
|
|
||||||
message.id = row['id']
|
message.id = row.id
|
||||||
message.created_on = row['created_on']
|
message.created_on = row.created_on
|
||||||
message.message_id = row['message_id']
|
message.message_id = row.message_id
|
||||||
message.reference_ids = row['reference_ids']
|
message.reference_ids = row.reference_ids
|
||||||
message.sender = row['sender']
|
message.sender = row.sender
|
||||||
message.subject = row['subject']
|
message.subject = row.subject
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
|
@ -13,12 +13,12 @@ class Newsgroup(DatabaseTable):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __from_row__(row):
|
def __from_row__(row):
|
||||||
newsgroup = Newsgroup()
|
newsgroup = Newsgroup()
|
||||||
newsgroup.id = row['id']
|
newsgroup.id = row.id
|
||||||
newsgroup.created_on = datetime.datetime.fromisoformat(row['created_on'])
|
newsgroup.created_on = row.created_on
|
||||||
newsgroup.created_by = row['created_by']
|
newsgroup.created_by = row.created_by
|
||||||
newsgroup.name = row['name']
|
newsgroup.name = row.name
|
||||||
newsgroup.description = row['description']
|
newsgroup.description = row.description
|
||||||
newsgroup.writable = row['writable']
|
newsgroup.writable = row.writable
|
||||||
|
|
||||||
return newsgroup
|
return newsgroup
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ class Server():
|
||||||
config.get('tls', 'key'))
|
config.get('tls', 'key'))
|
||||||
|
|
||||||
def connect_to_db(self):
|
def connect_to_db(self):
|
||||||
return Database.connect(self.config.get('database', 'path'))
|
return Database.from_config(self.config)
|
||||||
|
|
||||||
def listen(self, host: str, port: int, af: int):
|
def listen(self, host: str, port: int, af: int):
|
||||||
listener = socket.socket(af, socket.SOCK_STREAM)
|
listener = socket.socket(af, socket.SOCK_STREAM)
|
||||||
|
|
|
@ -130,7 +130,7 @@ class Session(Connection):
|
||||||
from
|
from
|
||||||
newsgroup_message
|
newsgroup_message
|
||||||
where
|
where
|
||||||
newsgroup_id = ?
|
newsgroup_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (newsgroup.id,))
|
cr = self.db.execute(sql, (newsgroup.id,))
|
||||||
|
@ -169,8 +169,8 @@ class Session(Connection):
|
||||||
from
|
from
|
||||||
newsgroup_message
|
newsgroup_message
|
||||||
where
|
where
|
||||||
newsgroup_id = ?
|
newsgroup_id = %s
|
||||||
and message_id < ?
|
and message_id < %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
|
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
|
||||||
|
@ -196,8 +196,8 @@ class Session(Connection):
|
||||||
from
|
from
|
||||||
newsgroup_message
|
newsgroup_message
|
||||||
where
|
where
|
||||||
message_id = ?
|
message_id = %s
|
||||||
and id > ?
|
and id > %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
|
cr = self.db.execute(sql, (self.newsgroup.id, self.article_id))
|
||||||
|
@ -221,13 +221,13 @@ class Session(Connection):
|
||||||
message
|
message
|
||||||
where
|
where
|
||||||
message.id = newsgroup_message.message_id
|
message.id = newsgroup_message.message_id
|
||||||
and newsgroup_message.newsgroup_id = ?
|
and newsgroup_message.newsgroup_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
values = [newsgroup.id]
|
values = [newsgroup.id]
|
||||||
|
|
||||||
if since is not None:
|
if since is not None:
|
||||||
sql += " and message.created_on >= ?"
|
sql += " and message.created_on >= %s"
|
||||||
values.append(since.isoformat())
|
values.append(since.isoformat())
|
||||||
|
|
||||||
cr = self.db.execute(sql, values)
|
cr = self.db.execute(sql, values)
|
||||||
|
@ -258,7 +258,7 @@ class Session(Connection):
|
||||||
from
|
from
|
||||||
newsgroup_message
|
newsgroup_message
|
||||||
where
|
where
|
||||||
newsgroup_id = ?
|
newsgroup_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(args) > 1:
|
if len(args) > 1:
|
||||||
|
@ -318,7 +318,7 @@ class Session(Connection):
|
||||||
message
|
message
|
||||||
where
|
where
|
||||||
message.id = newsgroup_message.message_id
|
message.id = newsgroup_message.message_id
|
||||||
and newsgroup_message.newsgroup_id = ?
|
and newsgroup_message.newsgroup_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (newsgroup.id,))
|
cr = self.db.execute(sql, (newsgroup.id,))
|
||||||
|
@ -471,8 +471,8 @@ class Session(Connection):
|
||||||
message
|
message
|
||||||
where
|
where
|
||||||
message.id = newsgroup_message.message_id
|
message.id = newsgroup_message.message_id
|
||||||
and newsgroup_message.newsgroup_id = ?
|
and newsgroup_message.newsgroup_id = %s
|
||||||
and message.created_on >= ?
|
and message.created_on >= %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for newsgroup in self.each_newsgroup():
|
for newsgroup in self.each_newsgroup():
|
||||||
|
@ -555,7 +555,7 @@ class Session(Connection):
|
||||||
message
|
message
|
||||||
where
|
where
|
||||||
message.id = newsgroup_message.message_id
|
message.id = newsgroup_message.message_id
|
||||||
and newsgroup_message.newsgroup_id = ?
|
and newsgroup_message.newsgroup_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sql += " and " + msgrange.where('newsgroup_message.message_id')
|
sql += " and " + msgrange.where('newsgroup_message.message_id')
|
||||||
|
@ -726,7 +726,7 @@ class Session(Connection):
|
||||||
sql = """
|
sql = """
|
||||||
insert into newsgroup_message (
|
insert into newsgroup_message (
|
||||||
newsgroup_id, message_id
|
newsgroup_id, message_id
|
||||||
) values (?, ?)
|
) values (%s, %s)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (newsgroup.id, message.id))
|
cr = self.db.execute(sql, (newsgroup.id, message.id))
|
||||||
|
@ -746,7 +746,7 @@ class Session(Connection):
|
||||||
from
|
from
|
||||||
message
|
message
|
||||||
where
|
where
|
||||||
message_id = ?
|
message_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = self.db.execute(sql, (message_id,))
|
cr = self.db.execute(sql, (message_id,))
|
||||||
|
|
|
@ -38,7 +38,7 @@ class User(DatabaseTable):
|
||||||
server_user_permission user_perm
|
server_user_permission user_perm
|
||||||
where
|
where
|
||||||
perm.id = user_perm.permission_id
|
perm.id = user_perm.permission_id
|
||||||
and user_perm.user_id = ?
|
and user_perm.user_id = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cr = db.execute(sql, (self.id,))
|
cr = db.execute(sql, (self.id,))
|
||||||
|
|
Loading…
Add table
Reference in a new issue