Build config validation into Config.get() accessor

This commit is contained in:
XANTRONIX Development 2024-12-04 11:48:56 -05:00
parent 3a608a1636
commit 7653ebdd29
3 changed files with 24 additions and 46 deletions

View file

@ -34,7 +34,7 @@ class ConfigFileException(ConfigException):
", ".join(self.paths)
)
class Config():
class Config(configparser.ConfigParser):
SEARCH_PATHS = [
'./server.conf',
'/etc/nntp-tiny/server.conf'
@ -48,6 +48,7 @@ class Config():
return None
@staticmethod
def load(path: Optional[str]=None):
if path is None:
path = Config.find()
@ -55,7 +56,16 @@ class Config():
if path is None:
raise ConfigFileException(Config.SEARCH_PATHS)
parser = configparser.ConfigParser()
parser.read(path)
config = Config()
config.read(path)
return parser
return config
def get(self, section: str, option: str, *args, **kwargs):
if not self.has_section(section):
raise ConfigSectionException(section)
if not self.has_option(section, option):
raise ConfigValueException(section, option)
return super().get(section, option, *args, **kwargs)

View file

@ -2,20 +2,12 @@ import os
import signal
import configparser
from nntp.tiny.config import (ConfigSectionException,
ConfigValueException)
from nntp.tiny.config import Config
class Daemon():
def init(config: configparser.ConfigParser):
if not config.has_section('daemon'):
raise ConfigSectionException('daemon')
if not config.has_option('daemon', 'pidfile'):
raise ConfigValueException('daemon', 'pidfile')
def init(config: Config):
pidfile = config.get('daemon', 'pidfile')
pid = os.fork()
pid = os.fork()
if pid > 0:
exit(0)

View file

@ -5,10 +5,7 @@ import socket
import selectors
import ssl
from configparser import ConfigParser
from nntp.tiny.config import (
ConfigException, ConfigSectionException, ConfigValueException)
from nntp.tiny.config import Config, ConfigException
from nntp.tiny.db import Database
from nntp.tiny.newsgroup import Newsgroup
from nntp.tiny.session import Session
@ -19,36 +16,21 @@ class ServerCapability(enum.Flag):
POST = enum.auto()
class Server():
def __init__(self, config: ConfigParser):
def __init__(self, config: Config):
self.config = config
self.capabilities = ServerCapability.NONE
self.newsgroups = dict()
self.sslctx = None
if config['listen'].get('tls', 'no') == 'yes':
if not config.has_section('tls'):
raise ConfigSectionException('tls')
if not config.has_option('tls', 'cert'):
raise ConfigValueException('tls', 'cert')
if not config.has_option('tls', 'key'):
raise ConfigValueException('tls', 'key')
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
self.sslctx.load_cert_chain(config['tls']['cert'],
config['tls']['key'])
self.sslctx.load_cert_chain(config.get('tls', 'cert'),
config.get('tls', 'key'))
self._init_newsgroups()
def connect_to_db(self):
if not self.config.has_section('database'):
raise ConfigSectionException('database')
if not self.config.has_option('database', 'path'):
raise ConfigValueException('database', 'path')
return Database.connect(self.config['database']['path'])
return Database.connect(self.config.get('database', 'path'))
def _init_newsgroups(self):
db = self.connect_to_db()
@ -108,14 +90,8 @@ class Server():
return True
def run(self):
if not self.config.has_section('listen'):
raise ConfigSectionException('listen')
if not self.config.has_option('listen', 'host'):
raise ConfigValueException('listen', 'host')
hosts = re.split(r'\s*,\s*', self.config['listen']['host'])
port = int(self.config['listen']['port'])
hosts = re.split(r'\s*,\s*', self.config.get('listen', 'host'))
port = int(self.config.get('listen', 'port'))
listeners = list()