Preserve case of header keys

This commit is contained in:
XANTRONIX Development 2024-11-11 17:20:53 -05:00
parent a4ccaae4ae
commit fa101660a8

View file

@ -3,7 +3,7 @@ import enum
import datetime import datetime
from email.utils import parsedate_to_datetime from email.utils import parsedate_to_datetime
from email.header import decode_header from email.header import decode_header, Header
from nntp.tiny.db import DatabaseTable from nntp.tiny.db import DatabaseTable
@ -54,6 +54,7 @@ class Message(DatabaseTable):
__slots__ = ( __slots__ = (
'_cache', '_cache',
'_headers', '_headers',
'_headers_lc',
'_body', '_body',
'_key', '_key',
'id', 'id',
@ -81,6 +82,7 @@ class Message(DatabaseTable):
def __init__(self): def __init__(self):
self._cache = dict() self._cache = dict()
self._headers = None self._headers = None
self._headers_lc = None
self._body = None self._body = None
self._key = None self._key = None
self.id = None self.id = None
@ -133,11 +135,26 @@ class Message(DatabaseTable):
return self._body return self._body
def header(self, key: str): def _header_set(self, key: str, value: str):
self._headers[key] = value
self._headers_lc[key.lower()] = value
def _header_append(self, key: str, value: str):
if key not in self._headers:
self._headers[key] = value
self._headers_lc[key.lower()] = value
else:
self._headers[key] += value
self._headers_lc[key.lower()] += value
def header(self, key: str, default=None):
if self._headers is None: if self._headers is None:
self.read(self.content) self.read(self.content)
return self.headers.get(key.lower()) value = self._headers_lc.get(key.lower(), default)
if value is not None:
return decode(value)
@property @property
def created_on(self): def created_on(self):
@ -156,10 +173,10 @@ class Message(DatabaseTable):
@created_on.setter @created_on.setter
def created_on(self, value): def created_on(self, value):
if self._headers is not None: if self._headers is None:
self._headers['date'] = str(value) self._cache['created_on'] = str(value)
elif value is not None:
self._cache['created_on'] = str(value) self._header_set('Date', Header(str(value)).encode())
@property @property
def message_id(self) -> str: def message_id(self) -> str:
@ -172,8 +189,8 @@ class Message(DatabaseTable):
def message_id(self, value): def message_id(self, value):
if self._headers is None: if self._headers is None:
self._cache['message_id'] = value self._cache['message_id'] = value
else: elif value is not None:
self.headers['message-id'] = value self._header_set('Message-ID', Header(value).encode())
@property @property
def parent_id(self) -> str: def parent_id(self) -> str:
@ -186,36 +203,36 @@ class Message(DatabaseTable):
def parent_id(self, value): def parent_id(self, value):
if self._headers is None: if self._headers is None:
self._cache['parent_id'] = value self._cache['parent_id'] = value
else: elif value is not None:
self.headers['references'] = value self._header_set('References', Header(value).encode())
@property @property
def sender(self) -> str: def sender(self) -> str:
if self._headers is None: if self._headers is None:
return self._cache.get('sender') return self._cache.get('sender')
return self.headers.get('from', 'Unknown') return self.header('From', 'Unknown')
@sender.setter @sender.setter
def sender(self, value): def sender(self, value):
if self._headers is None: if self._headers is None:
self._cache['sender'] = value self._cache['sender'] = value
else: elif value is not None:
self.headers['from'] = value self._header_set('From', Header(value).encode())
@property @property
def subject(self) -> str: def subject(self) -> str:
if self._headers is None: if self._headers is None:
return self._cache.get('subject', '(no subject)') return self._cache.get('subject', '(no subject)')
return self.headers.get('subject', '(no subject)') return self.header('subject', '(no subject)')
@subject.setter @subject.setter
def subject(self, value): def subject(self, value):
if self._headers is None: if self._headers is None:
self._cache['subject'] = value self._cache['subject'] = value
else: elif value is not None:
self.headers['subject'] = value self._header_set('Subject', Header(value).encode())
def is_first_line(self): def is_first_line(self):
return len(self.headers) == 1 and (self._body == '' or self._body is None) return len(self.headers) == 1 and (self._body == '' or self._body is None)
@ -225,21 +242,23 @@ class Message(DatabaseTable):
self.content += self.line self.content += self.line
if self.state is MessageState.EMPTY: if self.state is MessageState.EMPTY:
self.state = MessageState.HEADER self.state = MessageState.HEADER
self._headers = dict() self._headers = dict()
self._headers_lc = dict()
if self.state is MessageState.HEADER: if self.state is MessageState.HEADER:
if line == '\n' or line == '\r\n': if line == '\n' or line == '\r\n':
self.state = MessageState.BODY self.state = MessageState.BODY
elif line[0] == ' ' or line[0] == '\t': elif line[0] == ' ' or line[0] == '\t':
self._headers[self._key] += ' ' + decode(line.strip()) self._header_append(self._key, ' ' + line.strip())
else: else:
match = self.RE_HEADER.match(line) match = self.RE_HEADER.match(line)
if match: if match:
self._key = match[1].lower() self._key = match[1]
self._header_append(self._key, match[2].rstrip())
self._headers[self._key] = decode(match[2].rstrip())
elif self.state is MessageState.BODY: elif self.state is MessageState.BODY:
if self._body is None: if self._body is None:
self._body = '' self._body = ''