Build config validation into Config.get() accessor
This commit is contained in:
parent
3a608a1636
commit
7653ebdd29
3 changed files with 24 additions and 46 deletions
|
@ -34,7 +34,7 @@ class ConfigFileException(ConfigException):
|
|||
", ".join(self.paths)
|
||||
)
|
||||
|
||||
class Config():
|
||||
class Config(configparser.ConfigParser):
|
||||
SEARCH_PATHS = [
|
||||
'./server.conf',
|
||||
'/etc/nntp-tiny/server.conf'
|
||||
|
@ -48,6 +48,7 @@ class Config():
|
|||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def load(path: Optional[str]=None):
|
||||
if path is None:
|
||||
path = Config.find()
|
||||
|
@ -55,7 +56,16 @@ class Config():
|
|||
if path is None:
|
||||
raise ConfigFileException(Config.SEARCH_PATHS)
|
||||
|
||||
parser = configparser.ConfigParser()
|
||||
parser.read(path)
|
||||
config = Config()
|
||||
config.read(path)
|
||||
|
||||
return parser
|
||||
return config
|
||||
|
||||
def get(self, section: str, option: str, *args, **kwargs):
|
||||
if not self.has_section(section):
|
||||
raise ConfigSectionException(section)
|
||||
|
||||
if not self.has_option(section, option):
|
||||
raise ConfigValueException(section, option)
|
||||
|
||||
return super().get(section, option, *args, **kwargs)
|
||||
|
|
|
@ -2,20 +2,12 @@ import os
|
|||
import signal
|
||||
import configparser
|
||||
|
||||
from nntp.tiny.config import (ConfigSectionException,
|
||||
ConfigValueException)
|
||||
from nntp.tiny.config import Config
|
||||
|
||||
class Daemon():
|
||||
def init(config: configparser.ConfigParser):
|
||||
if not config.has_section('daemon'):
|
||||
raise ConfigSectionException('daemon')
|
||||
|
||||
if not config.has_option('daemon', 'pidfile'):
|
||||
raise ConfigValueException('daemon', 'pidfile')
|
||||
|
||||
def init(config: Config):
|
||||
pidfile = config.get('daemon', 'pidfile')
|
||||
|
||||
pid = os.fork()
|
||||
pid = os.fork()
|
||||
|
||||
if pid > 0:
|
||||
exit(0)
|
||||
|
|
|
@ -5,10 +5,7 @@ import socket
|
|||
import selectors
|
||||
import ssl
|
||||
|
||||
from configparser import ConfigParser
|
||||
|
||||
from nntp.tiny.config import (
|
||||
ConfigException, ConfigSectionException, ConfigValueException)
|
||||
from nntp.tiny.config import Config, ConfigException
|
||||
from nntp.tiny.db import Database
|
||||
from nntp.tiny.newsgroup import Newsgroup
|
||||
from nntp.tiny.session import Session
|
||||
|
@ -19,36 +16,21 @@ class ServerCapability(enum.Flag):
|
|||
POST = enum.auto()
|
||||
|
||||
class Server():
|
||||
def __init__(self, config: ConfigParser):
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.capabilities = ServerCapability.NONE
|
||||
self.newsgroups = dict()
|
||||
self.sslctx = None
|
||||
|
||||
if config['listen'].get('tls', 'no') == 'yes':
|
||||
if not config.has_section('tls'):
|
||||
raise ConfigSectionException('tls')
|
||||
|
||||
if not config.has_option('tls', 'cert'):
|
||||
raise ConfigValueException('tls', 'cert')
|
||||
|
||||
if not config.has_option('tls', 'key'):
|
||||
raise ConfigValueException('tls', 'key')
|
||||
|
||||
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
self.sslctx.load_cert_chain(config['tls']['cert'],
|
||||
config['tls']['key'])
|
||||
self.sslctx.load_cert_chain(config.get('tls', 'cert'),
|
||||
config.get('tls', 'key'))
|
||||
|
||||
self._init_newsgroups()
|
||||
|
||||
def connect_to_db(self):
|
||||
if not self.config.has_section('database'):
|
||||
raise ConfigSectionException('database')
|
||||
|
||||
if not self.config.has_option('database', 'path'):
|
||||
raise ConfigValueException('database', 'path')
|
||||
|
||||
return Database.connect(self.config['database']['path'])
|
||||
return Database.connect(self.config.get('database', 'path'))
|
||||
|
||||
def _init_newsgroups(self):
|
||||
db = self.connect_to_db()
|
||||
|
@ -108,14 +90,8 @@ class Server():
|
|||
return True
|
||||
|
||||
def run(self):
|
||||
if not self.config.has_section('listen'):
|
||||
raise ConfigSectionException('listen')
|
||||
|
||||
if not self.config.has_option('listen', 'host'):
|
||||
raise ConfigValueException('listen', 'host')
|
||||
|
||||
hosts = re.split(r'\s*,\s*', self.config['listen']['host'])
|
||||
port = int(self.config['listen']['port'])
|
||||
hosts = re.split(r'\s*,\s*', self.config.get('listen', 'host'))
|
||||
port = int(self.config.get('listen', 'port'))
|
||||
|
||||
listeners = list()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue