Build config validation into Config.get() accessor
This commit is contained in:
parent
3a608a1636
commit
7653ebdd29
3 changed files with 24 additions and 46 deletions
|
@ -34,7 +34,7 @@ class ConfigFileException(ConfigException):
|
||||||
", ".join(self.paths)
|
", ".join(self.paths)
|
||||||
)
|
)
|
||||||
|
|
||||||
class Config():
|
class Config(configparser.ConfigParser):
|
||||||
SEARCH_PATHS = [
|
SEARCH_PATHS = [
|
||||||
'./server.conf',
|
'./server.conf',
|
||||||
'/etc/nntp-tiny/server.conf'
|
'/etc/nntp-tiny/server.conf'
|
||||||
|
@ -48,6 +48,7 @@ class Config():
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def load(path: Optional[str]=None):
|
def load(path: Optional[str]=None):
|
||||||
if path is None:
|
if path is None:
|
||||||
path = Config.find()
|
path = Config.find()
|
||||||
|
@ -55,7 +56,16 @@ class Config():
|
||||||
if path is None:
|
if path is None:
|
||||||
raise ConfigFileException(Config.SEARCH_PATHS)
|
raise ConfigFileException(Config.SEARCH_PATHS)
|
||||||
|
|
||||||
parser = configparser.ConfigParser()
|
config = Config()
|
||||||
parser.read(path)
|
config.read(path)
|
||||||
|
|
||||||
return parser
|
return config
|
||||||
|
|
||||||
|
def get(self, section: str, option: str, *args, **kwargs):
|
||||||
|
if not self.has_section(section):
|
||||||
|
raise ConfigSectionException(section)
|
||||||
|
|
||||||
|
if not self.has_option(section, option):
|
||||||
|
raise ConfigValueException(section, option)
|
||||||
|
|
||||||
|
return super().get(section, option, *args, **kwargs)
|
||||||
|
|
|
@ -2,20 +2,12 @@ import os
|
||||||
import signal
|
import signal
|
||||||
import configparser
|
import configparser
|
||||||
|
|
||||||
from nntp.tiny.config import (ConfigSectionException,
|
from nntp.tiny.config import Config
|
||||||
ConfigValueException)
|
|
||||||
|
|
||||||
class Daemon():
|
class Daemon():
|
||||||
def init(config: configparser.ConfigParser):
|
def init(config: Config):
|
||||||
if not config.has_section('daemon'):
|
|
||||||
raise ConfigSectionException('daemon')
|
|
||||||
|
|
||||||
if not config.has_option('daemon', 'pidfile'):
|
|
||||||
raise ConfigValueException('daemon', 'pidfile')
|
|
||||||
|
|
||||||
pidfile = config.get('daemon', 'pidfile')
|
pidfile = config.get('daemon', 'pidfile')
|
||||||
|
pid = os.fork()
|
||||||
pid = os.fork()
|
|
||||||
|
|
||||||
if pid > 0:
|
if pid > 0:
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
|
@ -5,10 +5,7 @@ import socket
|
||||||
import selectors
|
import selectors
|
||||||
import ssl
|
import ssl
|
||||||
|
|
||||||
from configparser import ConfigParser
|
from nntp.tiny.config import Config, ConfigException
|
||||||
|
|
||||||
from nntp.tiny.config import (
|
|
||||||
ConfigException, ConfigSectionException, ConfigValueException)
|
|
||||||
from nntp.tiny.db import Database
|
from nntp.tiny.db import Database
|
||||||
from nntp.tiny.newsgroup import Newsgroup
|
from nntp.tiny.newsgroup import Newsgroup
|
||||||
from nntp.tiny.session import Session
|
from nntp.tiny.session import Session
|
||||||
|
@ -19,36 +16,21 @@ class ServerCapability(enum.Flag):
|
||||||
POST = enum.auto()
|
POST = enum.auto()
|
||||||
|
|
||||||
class Server():
|
class Server():
|
||||||
def __init__(self, config: ConfigParser):
|
def __init__(self, config: Config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.capabilities = ServerCapability.NONE
|
self.capabilities = ServerCapability.NONE
|
||||||
self.newsgroups = dict()
|
self.newsgroups = dict()
|
||||||
self.sslctx = None
|
self.sslctx = None
|
||||||
|
|
||||||
if config['listen'].get('tls', 'no') == 'yes':
|
if config['listen'].get('tls', 'no') == 'yes':
|
||||||
if not config.has_section('tls'):
|
|
||||||
raise ConfigSectionException('tls')
|
|
||||||
|
|
||||||
if not config.has_option('tls', 'cert'):
|
|
||||||
raise ConfigValueException('tls', 'cert')
|
|
||||||
|
|
||||||
if not config.has_option('tls', 'key'):
|
|
||||||
raise ConfigValueException('tls', 'key')
|
|
||||||
|
|
||||||
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
self.sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
self.sslctx.load_cert_chain(config['tls']['cert'],
|
self.sslctx.load_cert_chain(config.get('tls', 'cert'),
|
||||||
config['tls']['key'])
|
config.get('tls', 'key'))
|
||||||
|
|
||||||
self._init_newsgroups()
|
self._init_newsgroups()
|
||||||
|
|
||||||
def connect_to_db(self):
|
def connect_to_db(self):
|
||||||
if not self.config.has_section('database'):
|
return Database.connect(self.config.get('database', 'path'))
|
||||||
raise ConfigSectionException('database')
|
|
||||||
|
|
||||||
if not self.config.has_option('database', 'path'):
|
|
||||||
raise ConfigValueException('database', 'path')
|
|
||||||
|
|
||||||
return Database.connect(self.config['database']['path'])
|
|
||||||
|
|
||||||
def _init_newsgroups(self):
|
def _init_newsgroups(self):
|
||||||
db = self.connect_to_db()
|
db = self.connect_to_db()
|
||||||
|
@ -108,14 +90,8 @@ class Server():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if not self.config.has_section('listen'):
|
hosts = re.split(r'\s*,\s*', self.config.get('listen', 'host'))
|
||||||
raise ConfigSectionException('listen')
|
port = int(self.config.get('listen', 'port'))
|
||||||
|
|
||||||
if not self.config.has_option('listen', 'host'):
|
|
||||||
raise ConfigValueException('listen', 'host')
|
|
||||||
|
|
||||||
hosts = re.split(r'\s*,\s*', self.config['listen']['host'])
|
|
||||||
port = int(self.config['listen']['port'])
|
|
||||||
|
|
||||||
listeners = list()
|
listeners = list()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue