Do not cache newsgroups in Server

This commit is contained in:
XANTRONIX Development 2024-12-06 11:14:38 -05:00
parent ecdbcbbbba
commit 0eae832acf
2 changed files with 25 additions and 35 deletions

View file

@ -5,11 +5,10 @@ import socket
import selectors import selectors
import ssl import ssl
from nntp.tiny.config import Config, ConfigException from nntp.tiny.config import Config, ConfigException
from nntp.tiny.db import Database from nntp.tiny.db import Database
from nntp.tiny.host import Host from nntp.tiny.host import Host
from nntp.tiny.newsgroup import Newsgroup from nntp.tiny.session import Session
from nntp.tiny.session import Session
class ServerCapability(enum.Flag): class ServerCapability(enum.Flag):
NONE = 0 NONE = 0
@ -20,7 +19,6 @@ class Server():
def __init__(self, config: Config): def __init__(self, config: Config):
self.config = config self.config = config
self.capabilities = ServerCapability.NONE self.capabilities = ServerCapability.NONE
self.newsgroups = dict()
self.sslctx = None self.sslctx = None
if config.section('listen').get('tls', 'no') == 'yes': if config.section('listen').get('tls', 'no') == 'yes':
@ -28,17 +26,9 @@ class Server():
self.sslctx.load_cert_chain(config.get('tls', 'cert'), self.sslctx.load_cert_chain(config.get('tls', 'cert'),
config.get('tls', 'key')) config.get('tls', 'key'))
self._init_newsgroups()
def connect_to_db(self): def connect_to_db(self):
return Database.connect(self.config.get('database', 'path')) return Database.connect(self.config.get('database', 'path'))
def _init_newsgroups(self):
db = self.connect_to_db()
for newsgroup in db.query(Newsgroup).each():
self.newsgroups[newsgroup.name.casefold()] = newsgroup
def listen(self, host: str, port: int, af: int): def listen(self, host: str, port: int, af: int):
listener = socket.socket(af, socket.SOCK_STREAM) listener = socket.socket(af, socket.SOCK_STREAM)
listener.bind((host, port)) listener.bind((host, port))

View file

@ -47,6 +47,15 @@ class Session(Connection):
super().__init__(sock) super().__init__(sock)
def newsgroup(self, name: str):
return self.db.get(Newsgroup, {'name': name})
def each_newsgroup(self):
cr = self.db.query(Newsgroup)
for newsgroup in cr.each():
yield newsgroup
def respond(self, code: ResponseCode, message: str=None, body=None): def respond(self, code: ResponseCode, message: str=None, body=None):
response = Response(code, message, body) response = Response(code, message, body)
@ -106,10 +115,10 @@ class Session(Connection):
return self.respond(ResponseCode.NNTP_POST_PROHIBITED) return self.respond(ResponseCode.NNTP_POST_PROHIBITED)
def _cmd_group(self, name: str): def _cmd_group(self, name: str):
if name not in self.server.newsgroups: newsgroup = self.newsgroup(name)
return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND)
newsgroup = self.server.newsgroups[name] if newsgroup is None:
return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND)
sql = """ sql = """
select select
@ -236,7 +245,7 @@ class Session(Connection):
if len(args) == 0 and newsgroup is None: if len(args) == 0 and newsgroup is None:
return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_SELECTED) return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_SELECTED)
elif len(args) > 0: elif len(args) > 0:
newsgroup = self.server.newsgroups.get(args[0]) newsgroup = self.newsgroup(args[0])
if newsgroup is None: if newsgroup is None:
return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND) return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND)
@ -290,9 +299,7 @@ class Session(Connection):
def _cmd_list_newsgroups(self): def _cmd_list_newsgroups(self):
self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS)
for name in self.server.newsgroups: for newsgroup in self.each_newsgroup():
newsgroup = self.server.newsgroups[name]
self.print("%s %s" % ( self.print("%s %s" % (
newsgroup.name, newsgroup.name,
newsgroup.description newsgroup.description
@ -325,8 +332,7 @@ class Session(Connection):
self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS)
for name in self.server.newsgroups: for newsgroup in self.each_newsgroup():
newsgroup = self.server.newsgroups[name]
last_active = self._newsgroup_last_active(newsgroup) last_active = self._newsgroup_last_active(newsgroup)
if now - last_active < datetime.timedelta(days=1): if now - last_active < datetime.timedelta(days=1):
@ -337,9 +343,7 @@ class Session(Connection):
def _cmd_list_active_times(self): def _cmd_list_active_times(self):
self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS)
for name in self.server.newsgroups: for newsgroup in self.each_newsgroup():
newsgroup = self.server.newsgroups[name]
self.print("%s %d %s" % ( self.print("%s %d %s" % (
name, name,
newsgroup.created_on.timestamp(), newsgroup.created_on.timestamp(),
@ -469,10 +473,8 @@ class Session(Connection):
and message.created_on >= ? and message.created_on >= ?
""" """
for name in self.server.newsgroups: for newsgroup in self.each_newsgroup():
if fnmatch.fnmatch(name, wildmat): if fnmatch.fnmatch(newsgroup.name, wildmat):
newsgroup = self.server.newsgroups[name]
cr = self.db.execute(sql, (newsgroup.id, timestamp.isoformat())) cr = self.db.execute(sql, (newsgroup.id, timestamp.isoformat()))
while True: while True:
@ -500,9 +502,7 @@ class Session(Connection):
self.respond(ResponseCode.NNTP_GROUPS_NEW_FOLLOW) self.respond(ResponseCode.NNTP_GROUPS_NEW_FOLLOW)
for name in self.server.newsgroups: for newsgroup in self.each_newsgroup():
newsgroup = self.server.newsgroups[name]
self.print_newsgroup_summary(newsgroup, timestamp) self.print_newsgroup_summary(newsgroup, timestamp)
return self.end() return self.end()
@ -703,12 +703,12 @@ class Session(Connection):
newsgroups = list() newsgroups = list()
for name in names: for name in names:
newsgroup = self.server.newsgroups.get(name) newsgroup = self.newsgroup(name)
if newsgroup is None or not newsgroup.writable: if newsgroup is None or not newsgroup.writable:
return ResponseCode.NNTP_POST_PROHIBITED return ResponseCode.NNTP_POST_PROHIBITED
newsgroups.append(self.server.newsgroups[name]) newsgroups.append(newsgroup)
if len(newsgroups) == 0: if len(newsgroups) == 0:
return ResponseCode.NNTP_POST_FAILED return ResponseCode.NNTP_POST_FAILED