From 0eae832acf2229775b860801a5c07714a5d0870c Mon Sep 17 00:00:00 2001 From: XANTRONIX Development Date: Fri, 6 Dec 2024 11:14:38 -0500 Subject: [PATCH] Do not cache newsgroups in Server --- lib/nntp/tiny/server.py | 18 ++++------------- lib/nntp/tiny/session.py | 42 ++++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/lib/nntp/tiny/server.py b/lib/nntp/tiny/server.py index 472b00d..a4bcc6c 100644 --- a/lib/nntp/tiny/server.py +++ b/lib/nntp/tiny/server.py @@ -5,11 +5,10 @@ import socket import selectors import ssl -from nntp.tiny.config import Config, ConfigException -from nntp.tiny.db import Database -from nntp.tiny.host import Host -from nntp.tiny.newsgroup import Newsgroup -from nntp.tiny.session import Session +from nntp.tiny.config import Config, ConfigException +from nntp.tiny.db import Database +from nntp.tiny.host import Host +from nntp.tiny.session import Session class ServerCapability(enum.Flag): NONE = 0 @@ -20,7 +19,6 @@ class Server(): def __init__(self, config: Config): self.config = config self.capabilities = ServerCapability.NONE - self.newsgroups = dict() self.sslctx = None if config.section('listen').get('tls', 'no') == 'yes': @@ -28,17 +26,9 @@ class Server(): self.sslctx.load_cert_chain(config.get('tls', 'cert'), config.get('tls', 'key')) - self._init_newsgroups() - def connect_to_db(self): return Database.connect(self.config.get('database', 'path')) - def _init_newsgroups(self): - db = self.connect_to_db() - - for newsgroup in db.query(Newsgroup).each(): - self.newsgroups[newsgroup.name.casefold()] = newsgroup - def listen(self, host: str, port: int, af: int): listener = socket.socket(af, socket.SOCK_STREAM) listener.bind((host, port)) diff --git a/lib/nntp/tiny/session.py b/lib/nntp/tiny/session.py index f954777..9214a7d 100644 --- a/lib/nntp/tiny/session.py +++ b/lib/nntp/tiny/session.py @@ -47,6 +47,15 @@ class Session(Connection): super().__init__(sock) + def newsgroup(self, name: str): + return self.db.get(Newsgroup, {'name': name}) + + def each_newsgroup(self): + cr = self.db.query(Newsgroup) + + for newsgroup in cr.each(): + yield newsgroup + def respond(self, code: ResponseCode, message: str=None, body=None): response = Response(code, message, body) @@ -106,10 +115,10 @@ class Session(Connection): return self.respond(ResponseCode.NNTP_POST_PROHIBITED) def _cmd_group(self, name: str): - if name not in self.server.newsgroups: - return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND) + newsgroup = self.newsgroup(name) - newsgroup = self.server.newsgroups[name] + if newsgroup is None: + return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND) sql = """ select @@ -236,7 +245,7 @@ class Session(Connection): if len(args) == 0 and newsgroup is None: return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_SELECTED) elif len(args) > 0: - newsgroup = self.server.newsgroups.get(args[0]) + newsgroup = self.newsgroup(args[0]) if newsgroup is None: return self.respond(ResponseCode.NNTP_NEWSGROUP_NOT_FOUND) @@ -290,9 +299,7 @@ class Session(Connection): def _cmd_list_newsgroups(self): self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) - for name in self.server.newsgroups: - newsgroup = self.server.newsgroups[name] - + for newsgroup in self.each_newsgroup(): self.print("%s %s" % ( newsgroup.name, newsgroup.description @@ -325,8 +332,7 @@ class Session(Connection): self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) - for name in self.server.newsgroups: - newsgroup = self.server.newsgroups[name] + for newsgroup in self.each_newsgroup(): last_active = self._newsgroup_last_active(newsgroup) if now - last_active < datetime.timedelta(days=1): @@ -337,9 +343,7 @@ class Session(Connection): def _cmd_list_active_times(self): self.respond(ResponseCode.NNTP_INFORMATION_FOLLOWS) - for name in self.server.newsgroups: - newsgroup = self.server.newsgroups[name] - + for newsgroup in self.each_newsgroup(): self.print("%s %d %s" % ( name, newsgroup.created_on.timestamp(), @@ -469,10 +473,8 @@ class Session(Connection): and message.created_on >= ? """ - for name in self.server.newsgroups: - if fnmatch.fnmatch(name, wildmat): - newsgroup = self.server.newsgroups[name] - + for newsgroup in self.each_newsgroup(): + if fnmatch.fnmatch(newsgroup.name, wildmat): cr = self.db.execute(sql, (newsgroup.id, timestamp.isoformat())) while True: @@ -500,9 +502,7 @@ class Session(Connection): self.respond(ResponseCode.NNTP_GROUPS_NEW_FOLLOW) - for name in self.server.newsgroups: - newsgroup = self.server.newsgroups[name] - + for newsgroup in self.each_newsgroup(): self.print_newsgroup_summary(newsgroup, timestamp) return self.end() @@ -703,12 +703,12 @@ class Session(Connection): newsgroups = list() for name in names: - newsgroup = self.server.newsgroups.get(name) + newsgroup = self.newsgroup(name) if newsgroup is None or not newsgroup.writable: return ResponseCode.NNTP_POST_PROHIBITED - newsgroups.append(self.server.newsgroups[name]) + newsgroups.append(newsgroup) if len(newsgroups) == 0: return ResponseCode.NNTP_POST_FAILED