mirror of
https://github.com/Unidata/python-awips.git
synced 2025-02-23 14:57:56 -05:00
thrift update to 0.10.0
This commit is contained in:
parent
3837f21015
commit
0ddbcd4bb0
25 changed files with 3878 additions and 2065 deletions
55
thrift/TMultiplexedProcessor.py
Normal file
55
thrift/TMultiplexedProcessor.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import TProcessor, TMessageType, TException
|
||||
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
|
||||
|
||||
|
||||
class TMultiplexedProcessor(TProcessor):
|
||||
def __init__(self):
|
||||
self.services = {}
|
||||
|
||||
def registerProcessor(self, serviceName, processor):
|
||||
self.services[serviceName] = processor
|
||||
|
||||
def process(self, iprot, oprot):
|
||||
(name, type, seqid) = iprot.readMessageBegin()
|
||||
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
|
||||
raise TException("TMultiplex protocol only supports CALL & ONEWAY")
|
||||
|
||||
index = name.find(TMultiplexedProtocol.SEPARATOR)
|
||||
if index < 0:
|
||||
raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
|
||||
|
||||
serviceName = name[0:index]
|
||||
call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
|
||||
if serviceName not in self.services:
|
||||
raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
|
||||
|
||||
standardMessage = (call, type, seqid)
|
||||
return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
|
||||
|
||||
|
||||
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
|
||||
def __init__(self, protocol, messageBegin):
|
||||
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
|
||||
self.messageBegin = messageBegin
|
||||
|
||||
def readMessageBegin(self):
|
||||
return self.messageBegin
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
from os import path
|
||||
from SCons.Builder import Builder
|
||||
from six.moves import map
|
||||
|
||||
|
||||
def scons_env(env, add=''):
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from protocol import TBinaryProtocol
|
||||
from transport import TTransport
|
||||
from .protocol import TBinaryProtocol
|
||||
from .transport import TTransport
|
||||
|
||||
|
||||
def serialize(thrift_object,
|
||||
|
|
188
thrift/TTornado.py
Normal file
188
thrift/TTornado.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
|
||||
|
||||
from io import BytesIO
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from tornado import gen, iostream, ioloop, tcpserver, concurrent
|
||||
|
||||
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _Lock(object):
|
||||
def __init__(self):
|
||||
self._waiters = deque()
|
||||
|
||||
def acquired(self):
|
||||
return len(self._waiters) > 0
|
||||
|
||||
@gen.coroutine
|
||||
def acquire(self):
|
||||
blocker = self._waiters[-1] if self.acquired() else None
|
||||
future = concurrent.Future()
|
||||
self._waiters.append(future)
|
||||
if blocker:
|
||||
yield blocker
|
||||
|
||||
raise gen.Return(self._lock_context())
|
||||
|
||||
def release(self):
|
||||
assert self.acquired(), 'Lock not aquired'
|
||||
future = self._waiters.popleft()
|
||||
future.set_result(None)
|
||||
|
||||
@contextmanager
|
||||
def _lock_context(self):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.release()
|
||||
|
||||
|
||||
class TTornadoStreamTransport(TTransportBase):
|
||||
"""a framed, buffered transport over a Tornado stream"""
|
||||
def __init__(self, host, port, stream=None, io_loop=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.io_loop = io_loop or ioloop.IOLoop.current()
|
||||
self.__wbuf = BytesIO()
|
||||
self._read_lock = _Lock()
|
||||
|
||||
# servers provide a ready-to-go stream
|
||||
self.stream = stream
|
||||
|
||||
def with_timeout(self, timeout, future):
|
||||
return gen.with_timeout(timeout, future, self.io_loop)
|
||||
|
||||
@gen.coroutine
|
||||
def open(self, timeout=None):
|
||||
logger.debug('socket connecting')
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
self.stream = iostream.IOStream(sock)
|
||||
|
||||
try:
|
||||
connect = self.stream.connect((self.host, self.port))
|
||||
if timeout is not None:
|
||||
yield self.with_timeout(timeout, connect)
|
||||
else:
|
||||
yield connect
|
||||
except (socket.error, IOError, ioloop.TimeoutError) as e:
|
||||
message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
|
||||
raise TTransportException(
|
||||
type=TTransportException.NOT_OPEN,
|
||||
message=message)
|
||||
|
||||
raise gen.Return(self)
|
||||
|
||||
def set_close_callback(self, callback):
|
||||
"""
|
||||
Should be called only after open() returns
|
||||
"""
|
||||
self.stream.set_close_callback(callback)
|
||||
|
||||
def close(self):
|
||||
# don't raise if we intend to close
|
||||
self.stream.set_close_callback(None)
|
||||
self.stream.close()
|
||||
|
||||
def read(self, _):
|
||||
# The generated code for Tornado shouldn't do individual reads -- only
|
||||
# frames at a time
|
||||
assert False, "you're doing it wrong"
|
||||
|
||||
@contextmanager
|
||||
def io_exception_context(self):
|
||||
try:
|
||||
yield
|
||||
except (socket.error, IOError) as e:
|
||||
raise TTransportException(
|
||||
type=TTransportException.END_OF_FILE,
|
||||
message=str(e))
|
||||
except iostream.StreamBufferFullError as e:
|
||||
raise TTransportException(
|
||||
type=TTransportException.UNKNOWN,
|
||||
message=str(e))
|
||||
|
||||
@gen.coroutine
|
||||
def readFrame(self):
|
||||
# IOStream processes reads one at a time
|
||||
with (yield self._read_lock.acquire()):
|
||||
with self.io_exception_context():
|
||||
frame_header = yield self.stream.read_bytes(4)
|
||||
if len(frame_header) == 0:
|
||||
raise iostream.StreamClosedError('Read zero bytes from stream')
|
||||
frame_length, = struct.unpack('!i', frame_header)
|
||||
frame = yield self.stream.read_bytes(frame_length)
|
||||
raise gen.Return(frame)
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
frame = self.__wbuf.getvalue()
|
||||
# reset wbuf before write/flush to preserve state on underlying failure
|
||||
frame_length = struct.pack('!i', len(frame))
|
||||
self.__wbuf = BytesIO()
|
||||
with self.io_exception_context():
|
||||
return self.stream.write(frame_length + frame)
|
||||
|
||||
|
||||
class TTornadoServer(tcpserver.TCPServer):
|
||||
def __init__(self, processor, iprot_factory, oprot_factory=None,
|
||||
*args, **kwargs):
|
||||
super(TTornadoServer, self).__init__(*args, **kwargs)
|
||||
|
||||
self._processor = processor
|
||||
self._iprot_factory = iprot_factory
|
||||
self._oprot_factory = (oprot_factory if oprot_factory is not None
|
||||
else iprot_factory)
|
||||
|
||||
@gen.coroutine
|
||||
def handle_stream(self, stream, address):
|
||||
host, port = address[:2]
|
||||
trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
|
||||
io_loop=self.io_loop)
|
||||
oprot = self._oprot_factory.getProtocol(trans)
|
||||
|
||||
try:
|
||||
while not trans.stream.closed():
|
||||
try:
|
||||
frame = yield trans.readFrame()
|
||||
except TTransportException as e:
|
||||
if e.type == TTransportException.END_OF_FILE:
|
||||
break
|
||||
else:
|
||||
raise
|
||||
tr = TMemoryBuffer(frame)
|
||||
iprot = self._iprot_factory.getProtocol(tr)
|
||||
yield self._processor.process(iprot, oprot)
|
||||
except Exception:
|
||||
logger.exception('thrift exception in handle_stream')
|
||||
trans.close()
|
||||
|
||||
logger.info('client disconnected %s:%d', host, port)
|
|
@ -20,7 +20,7 @@
|
|||
import sys
|
||||
|
||||
|
||||
class TType:
|
||||
class TType(object):
|
||||
STOP = 0
|
||||
VOID = 1
|
||||
BOOL = 2
|
||||
|
@ -39,7 +39,8 @@ class TType:
|
|||
UTF8 = 16
|
||||
UTF16 = 17
|
||||
|
||||
_VALUES_TO_NAMES = ('STOP',
|
||||
_VALUES_TO_NAMES = (
|
||||
'STOP',
|
||||
'VOID',
|
||||
'BOOL',
|
||||
'BYTE',
|
||||
|
@ -56,17 +57,18 @@ class TType:
|
|||
'SET',
|
||||
'LIST',
|
||||
'UTF8',
|
||||
'UTF16')
|
||||
'UTF16',
|
||||
)
|
||||
|
||||
|
||||
class TMessageType:
|
||||
class TMessageType(object):
|
||||
CALL = 1
|
||||
REPLY = 2
|
||||
EXCEPTION = 3
|
||||
ONEWAY = 4
|
||||
|
||||
|
||||
class TProcessor:
|
||||
class TProcessor(object):
|
||||
"""Base class for procsessor, which works on two streams."""
|
||||
|
||||
def process(iprot, oprot):
|
||||
|
@ -101,6 +103,9 @@ class TApplicationException(TException):
|
|||
MISSING_RESULT = 5
|
||||
INTERNAL_ERROR = 6
|
||||
PROTOCOL_ERROR = 7
|
||||
INVALID_TRANSFORM = 8
|
||||
INVALID_PROTOCOL = 9
|
||||
UNSUPPORTED_CLIENT_TYPE = 10
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
|
@ -119,6 +124,16 @@ class TApplicationException(TException):
|
|||
return 'Bad sequence ID'
|
||||
elif self.type == self.MISSING_RESULT:
|
||||
return 'Missing result'
|
||||
elif self.type == self.INTERNAL_ERROR:
|
||||
return 'Internal error'
|
||||
elif self.type == self.PROTOCOL_ERROR:
|
||||
return 'Protocol error'
|
||||
elif self.type == self.INVALID_TRANSFORM:
|
||||
return 'Invalid transform'
|
||||
elif self.type == self.INVALID_PROTOCOL:
|
||||
return 'Invalid protocol'
|
||||
elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
|
||||
return 'Unsupported client type'
|
||||
else:
|
||||
return 'Default (unknown) TApplicationException'
|
||||
|
||||
|
@ -155,3 +170,23 @@ class TApplicationException(TException):
|
|||
oprot.writeFieldEnd()
|
||||
oprot.writeFieldStop()
|
||||
oprot.writeStructEnd()
|
||||
|
||||
|
||||
class TFrozenDict(dict):
|
||||
"""A dictionary that is "frozen" like a frozenset"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TFrozenDict, self).__init__(*args, **kwargs)
|
||||
# Sort the items so they will be in a consistent order.
|
||||
# XOR in the hash of the class so we don't collide with
|
||||
# the hash of a list of tuples.
|
||||
self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
|
||||
|
||||
def __setitem__(self, *args):
|
||||
raise TypeError("Can't modify frozen TFreezableDict")
|
||||
|
||||
def __delitem__(self, *args):
|
||||
raise TypeError("Can't modify frozen TFreezableDict")
|
||||
|
||||
def __hash__(self):
|
||||
return self.__hashval
|
||||
|
|
40
thrift/compat.py
Normal file
40
thrift/compat.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
||||
from cStringIO import StringIO as BufferIO
|
||||
|
||||
def binary_to_str(bin_val):
|
||||
return bin_val
|
||||
|
||||
def str_to_binary(str_val):
|
||||
return str_val
|
||||
|
||||
else:
|
||||
|
||||
from io import BytesIO as BufferIO # noqa
|
||||
|
||||
def binary_to_str(bin_val):
|
||||
return bin_val.decode('utf8')
|
||||
|
||||
def str_to_binary(str_val):
|
||||
return bytes(str_val, 'utf8')
|
|
@ -17,22 +17,14 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import *
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.transport import TTransport
|
||||
|
||||
try:
|
||||
from thrift.protocol import fastbinary
|
||||
except:
|
||||
fastbinary = None
|
||||
|
||||
|
||||
class TBase(object):
|
||||
__slots__ = []
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, getattr(self, key))
|
||||
for key in self.__slots__]
|
||||
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -49,33 +41,42 @@ class TBase(object):
|
|||
return not (self == other)
|
||||
|
||||
def read(self, iprot):
|
||||
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
||||
if (iprot._fast_decode is not None and
|
||||
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
||||
self.thrift_spec is not None and
|
||||
fastbinary is not None):
|
||||
fastbinary.decode_binary(self,
|
||||
iprot.trans,
|
||||
(self.__class__, self.thrift_spec))
|
||||
return
|
||||
self.thrift_spec is not None):
|
||||
iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
|
||||
else:
|
||||
iprot.readStruct(self, self.thrift_spec)
|
||||
|
||||
def write(self, oprot):
|
||||
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
||||
self.thrift_spec is not None and
|
||||
fastbinary is not None):
|
||||
if (oprot._fast_encode is not None and self.thrift_spec is not None):
|
||||
oprot.trans.write(
|
||||
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
|
||||
return
|
||||
oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
|
||||
else:
|
||||
oprot.writeStruct(self, self.thrift_spec)
|
||||
|
||||
|
||||
class TExceptionBase(Exception):
|
||||
# old style class so python2.4 can raise exceptions derived from this
|
||||
# This can't inherit from TBase because of that limitation.
|
||||
__slots__ = []
|
||||
class TExceptionBase(TBase, Exception):
|
||||
pass
|
||||
|
||||
__repr__ = TBase.__repr__.im_func
|
||||
__eq__ = TBase.__eq__.im_func
|
||||
__ne__ = TBase.__ne__.im_func
|
||||
read = TBase.read.im_func
|
||||
write = TBase.write.im_func
|
||||
|
||||
class TFrozenBase(TBase):
|
||||
def __setitem__(self, *args):
|
||||
raise TypeError("Can't modify frozen struct")
|
||||
|
||||
def __delitem__(self, *args):
|
||||
raise TypeError("Can't modify frozen struct")
|
||||
|
||||
def __hash__(self, *args):
|
||||
return hash(self.__class__) ^ hash(self.__slots__)
|
||||
|
||||
@classmethod
|
||||
def read(cls, iprot):
|
||||
if (iprot._fast_decode is not None and
|
||||
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
||||
cls.thrift_spec is not None):
|
||||
self = cls()
|
||||
return iprot._fast_decode(None, iprot,
|
||||
(self.__class__, self.thrift_spec))
|
||||
else:
|
||||
return iprot.readStruct(cls, cls.thrift_spec, True)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from TProtocol import *
|
||||
from .TProtocol import TType, TProtocolBase, TProtocolException
|
||||
from struct import pack, unpack
|
||||
|
||||
|
||||
|
@ -36,10 +36,18 @@ class TBinaryProtocol(TProtocolBase):
|
|||
|
||||
TYPE_MASK = 0x000000ff
|
||||
|
||||
def __init__(self, trans, strictRead=False, strictWrite=True):
|
||||
def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
|
||||
TProtocolBase.__init__(self, trans)
|
||||
self.strictRead = strictRead
|
||||
self.strictWrite = strictWrite
|
||||
self.string_length_limit = kwargs.get('string_length_limit', None)
|
||||
self.container_length_limit = kwargs.get('container_length_limit', None)
|
||||
|
||||
def _check_string_length(self, length):
|
||||
self._check_length(self.string_length_limit, length)
|
||||
|
||||
def _check_container_length(self, length):
|
||||
self._check_length(self.container_length_limit, length)
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
if self.strictWrite:
|
||||
|
@ -118,7 +126,7 @@ class TBinaryProtocol(TProtocolBase):
|
|||
buff = pack("!d", dub)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeString(self, str):
|
||||
def writeBinary(self, str):
|
||||
self.writeI32(len(str))
|
||||
self.trans.write(str)
|
||||
|
||||
|
@ -165,6 +173,7 @@ class TBinaryProtocol(TProtocolBase):
|
|||
ktype = self.readByte()
|
||||
vtype = self.readByte()
|
||||
size = self.readI32()
|
||||
self._check_container_length(size)
|
||||
return (ktype, vtype, size)
|
||||
|
||||
def readMapEnd(self):
|
||||
|
@ -173,6 +182,7 @@ class TBinaryProtocol(TProtocolBase):
|
|||
def readListBegin(self):
|
||||
etype = self.readByte()
|
||||
size = self.readI32()
|
||||
self._check_container_length(size)
|
||||
return (etype, size)
|
||||
|
||||
def readListEnd(self):
|
||||
|
@ -181,6 +191,7 @@ class TBinaryProtocol(TProtocolBase):
|
|||
def readSetBegin(self):
|
||||
etype = self.readByte()
|
||||
size = self.readI32()
|
||||
self._check_container_length(size)
|
||||
return (etype, size)
|
||||
|
||||
def readSetEnd(self):
|
||||
|
@ -217,19 +228,24 @@ class TBinaryProtocol(TProtocolBase):
|
|||
val, = unpack('!d', buff)
|
||||
return val
|
||||
|
||||
def readString(self):
|
||||
len = self.readI32()
|
||||
str = self.trans.readAll(len)
|
||||
return str
|
||||
def readBinary(self):
|
||||
size = self.readI32()
|
||||
self._check_string_length(size)
|
||||
s = self.trans.readAll(size)
|
||||
return s
|
||||
|
||||
|
||||
class TBinaryProtocolFactory:
|
||||
def __init__(self, strictRead=False, strictWrite=True):
|
||||
class TBinaryProtocolFactory(object):
|
||||
def __init__(self, strictRead=False, strictWrite=True, **kwargs):
|
||||
self.strictRead = strictRead
|
||||
self.strictWrite = strictWrite
|
||||
self.string_length_limit = kwargs.get('string_length_limit', None)
|
||||
self.container_length_limit = kwargs.get('container_length_limit', None)
|
||||
|
||||
def getProtocol(self, trans):
|
||||
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
|
||||
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
|
||||
string_length_limit=self.string_length_limit,
|
||||
container_length_limit=self.container_length_limit)
|
||||
return prot
|
||||
|
||||
|
||||
|
@ -242,6 +258,7 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
|
|||
We inherit from TBinaryProtocol so that the normal TBinaryProtocol
|
||||
encoding can happen if the fastbinary module doesn't work for some
|
||||
reason. (TODO(dreiss): Make this happen sanely in more cases.)
|
||||
To disable this behavior, pass fallback=False constructor argument.
|
||||
|
||||
In order to take advantage of the C module, just use
|
||||
TBinaryProtocolAccelerated instead of TBinaryProtocol.
|
||||
|
@ -254,7 +271,31 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
|
|||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
fallback = kwargs.pop('fallback', True)
|
||||
super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
|
||||
try:
|
||||
from thrift.protocol import fastbinary
|
||||
except ImportError:
|
||||
if not fallback:
|
||||
raise
|
||||
else:
|
||||
self._fast_decode = fastbinary.decode_binary
|
||||
self._fast_encode = fastbinary.encode_binary
|
||||
|
||||
|
||||
class TBinaryProtocolAcceleratedFactory(object):
|
||||
def __init__(self,
|
||||
string_length_limit=None,
|
||||
container_length_limit=None,
|
||||
fallback=True):
|
||||
self.string_length_limit = string_length_limit
|
||||
self.container_length_limit = container_length_limit
|
||||
self._fallback = fallback
|
||||
|
||||
class TBinaryProtocolAcceleratedFactory:
|
||||
def getProtocol(self, trans):
|
||||
return TBinaryProtocolAccelerated(trans)
|
||||
return TBinaryProtocolAccelerated(
|
||||
trans,
|
||||
string_length_limit=self.string_length_limit,
|
||||
container_length_limit=self.container_length_limit,
|
||||
fallback=self._fallback)
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from TProtocol import *
|
||||
from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
|
||||
from struct import pack, unpack
|
||||
|
||||
from ..compat import binary_to_str, str_to_binary
|
||||
|
||||
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
|
||||
|
||||
CLEAR = 0
|
||||
|
@ -45,6 +47,7 @@ reader = make_helper(VALUE_READ, CONTAINER_READ)
|
|||
|
||||
|
||||
def makeZigZag(n, bits):
|
||||
checkIntegerLimits(n, bits)
|
||||
return (n << 1) ^ (n >> (bits - 1))
|
||||
|
||||
|
||||
|
@ -53,7 +56,7 @@ def fromZigZag(n):
|
|||
|
||||
|
||||
def writeVarint(trans, n):
|
||||
out = []
|
||||
out = bytearray()
|
||||
while True:
|
||||
if n & ~0x7f == 0:
|
||||
out.append(n)
|
||||
|
@ -61,7 +64,7 @@ def writeVarint(trans, n):
|
|||
else:
|
||||
out.append((n & 0xff) | 0x80)
|
||||
n = n >> 7
|
||||
trans.write(''.join(map(chr, out)))
|
||||
trans.write(bytes(out))
|
||||
|
||||
|
||||
def readVarint(trans):
|
||||
|
@ -76,7 +79,7 @@ def readVarint(trans):
|
|||
shift += 7
|
||||
|
||||
|
||||
class CompactType:
|
||||
class CompactType(object):
|
||||
STOP = 0x00
|
||||
TRUE = 0x01
|
||||
FALSE = 0x02
|
||||
|
@ -91,7 +94,8 @@ class CompactType:
|
|||
MAP = 0x0B
|
||||
STRUCT = 0x0C
|
||||
|
||||
CTYPES = {TType.STOP: CompactType.STOP,
|
||||
CTYPES = {
|
||||
TType.STOP: CompactType.STOP,
|
||||
TType.BOOL: CompactType.TRUE, # used for collection
|
||||
TType.BYTE: CompactType.BYTE,
|
||||
TType.I16: CompactType.I16,
|
||||
|
@ -102,8 +106,8 @@ CTYPES = {TType.STOP: CompactType.STOP,
|
|||
TType.STRUCT: CompactType.STRUCT,
|
||||
TType.LIST: CompactType.LIST,
|
||||
TType.SET: CompactType.SET,
|
||||
TType.MAP: CompactType.MAP
|
||||
}
|
||||
TType.MAP: CompactType.MAP,
|
||||
}
|
||||
|
||||
TTYPES = {}
|
||||
for k, v in CTYPES.items():
|
||||
|
@ -120,9 +124,12 @@ class TCompactProtocol(TProtocolBase):
|
|||
VERSION = 1
|
||||
VERSION_MASK = 0x1f
|
||||
TYPE_MASK = 0xe0
|
||||
TYPE_BITS = 0x07
|
||||
TYPE_SHIFT_AMOUNT = 5
|
||||
|
||||
def __init__(self, trans):
|
||||
def __init__(self, trans,
|
||||
string_length_limit=None,
|
||||
container_length_limit=None):
|
||||
TProtocolBase.__init__(self, trans)
|
||||
self.state = CLEAR
|
||||
self.__last_fid = 0
|
||||
|
@ -130,6 +137,14 @@ class TCompactProtocol(TProtocolBase):
|
|||
self.__bool_value = None
|
||||
self.__structs = []
|
||||
self.__containers = []
|
||||
self.string_length_limit = string_length_limit
|
||||
self.container_length_limit = container_length_limit
|
||||
|
||||
def _check_string_length(self, length):
|
||||
self._check_length(self.string_length_limit, length)
|
||||
|
||||
def _check_container_length(self, length):
|
||||
self._check_length(self.container_length_limit, length)
|
||||
|
||||
def __writeVarint(self, n):
|
||||
writeVarint(self.trans, n)
|
||||
|
@ -139,7 +154,7 @@ class TCompactProtocol(TProtocolBase):
|
|||
self.__writeUByte(self.PROTOCOL_ID)
|
||||
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
|
||||
self.__writeVarint(seqid)
|
||||
self.__writeString(name)
|
||||
self.__writeBinary(str_to_binary(name))
|
||||
self.state = VALUE_WRITE
|
||||
|
||||
def writeMessageEnd(self):
|
||||
|
@ -250,12 +265,12 @@ class TCompactProtocol(TProtocolBase):
|
|||
|
||||
@writer
|
||||
def writeDouble(self, dub):
|
||||
self.trans.write(pack('!d', dub))
|
||||
self.trans.write(pack('<d', dub))
|
||||
|
||||
def __writeString(self, s):
|
||||
def __writeBinary(self, s):
|
||||
self.__writeSize(len(s))
|
||||
self.trans.write(s)
|
||||
writeString = writer(__writeString)
|
||||
writeBinary = writer(__writeBinary)
|
||||
|
||||
def readFieldBegin(self):
|
||||
assert self.state == FIELD_READ, self.state
|
||||
|
@ -300,7 +315,7 @@ class TCompactProtocol(TProtocolBase):
|
|||
def __readSize(self):
|
||||
result = self.__readVarint()
|
||||
if result < 0:
|
||||
raise TException("Length < 0")
|
||||
raise TProtocolException("Length < 0")
|
||||
return result
|
||||
|
||||
def readMessageBegin(self):
|
||||
|
@ -310,13 +325,13 @@ class TCompactProtocol(TProtocolBase):
|
|||
raise TProtocolException(TProtocolException.BAD_VERSION,
|
||||
'Bad protocol id in the message: %d' % proto_id)
|
||||
ver_type = self.__readUByte()
|
||||
type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
|
||||
type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
|
||||
version = ver_type & self.VERSION_MASK
|
||||
if version != self.VERSION:
|
||||
raise TProtocolException(TProtocolException.BAD_VERSION,
|
||||
'Bad version: %d (expect %d)' % (version, self.VERSION))
|
||||
seqid = self.__readVarint()
|
||||
name = self.__readString()
|
||||
name = binary_to_str(self.__readBinary())
|
||||
return (name, type, seqid)
|
||||
|
||||
def readMessageEnd(self):
|
||||
|
@ -340,6 +355,7 @@ class TCompactProtocol(TProtocolBase):
|
|||
type = self.__getTType(size_type)
|
||||
if size == 15:
|
||||
size = self.__readSize()
|
||||
self._check_container_length(size)
|
||||
self.__containers.append(self.state)
|
||||
self.state = CONTAINER_READ
|
||||
return type, size
|
||||
|
@ -349,6 +365,7 @@ class TCompactProtocol(TProtocolBase):
|
|||
def readMapBegin(self):
|
||||
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
|
||||
size = self.__readSize()
|
||||
self._check_container_length(size)
|
||||
types = 0
|
||||
if size > 0:
|
||||
types = self.__readUByte()
|
||||
|
@ -383,21 +400,73 @@ class TCompactProtocol(TProtocolBase):
|
|||
@reader
|
||||
def readDouble(self):
|
||||
buff = self.trans.readAll(8)
|
||||
val, = unpack('!d', buff)
|
||||
val, = unpack('<d', buff)
|
||||
return val
|
||||
|
||||
def __readString(self):
|
||||
len = self.__readSize()
|
||||
return self.trans.readAll(len)
|
||||
readString = reader(__readString)
|
||||
def __readBinary(self):
|
||||
size = self.__readSize()
|
||||
self._check_string_length(size)
|
||||
return self.trans.readAll(size)
|
||||
readBinary = reader(__readBinary)
|
||||
|
||||
def __getTType(self, byte):
|
||||
return TTYPES[byte & 0x0f]
|
||||
|
||||
|
||||
class TCompactProtocolFactory:
|
||||
def __init__(self):
|
||||
pass
|
||||
class TCompactProtocolFactory(object):
|
||||
def __init__(self,
|
||||
string_length_limit=None,
|
||||
container_length_limit=None):
|
||||
self.string_length_limit = string_length_limit
|
||||
self.container_length_limit = container_length_limit
|
||||
|
||||
def getProtocol(self, trans):
|
||||
return TCompactProtocol(trans)
|
||||
return TCompactProtocol(trans,
|
||||
self.string_length_limit,
|
||||
self.container_length_limit)
|
||||
|
||||
|
||||
class TCompactProtocolAccelerated(TCompactProtocol):
|
||||
"""C-Accelerated version of TCompactProtocol.
|
||||
|
||||
This class does not override any of TCompactProtocol's methods,
|
||||
but the generated code recognizes it directly and will call into
|
||||
our C module to do the encoding, bypassing this object entirely.
|
||||
We inherit from TCompactProtocol so that the normal TCompactProtocol
|
||||
encoding can happen if the fastbinary module doesn't work for some
|
||||
reason.
|
||||
To disable this behavior, pass fallback=False constructor argument.
|
||||
|
||||
In order to take advantage of the C module, just use
|
||||
TCompactProtocolAccelerated instead of TCompactProtocol.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
fallback = kwargs.pop('fallback', True)
|
||||
super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
|
||||
try:
|
||||
from thrift.protocol import fastbinary
|
||||
except ImportError:
|
||||
if not fallback:
|
||||
raise
|
||||
else:
|
||||
self._fast_decode = fastbinary.decode_compact
|
||||
self._fast_encode = fastbinary.encode_compact
|
||||
|
||||
|
||||
class TCompactProtocolAcceleratedFactory(object):
|
||||
def __init__(self,
|
||||
string_length_limit=None,
|
||||
container_length_limit=None,
|
||||
fallback=True):
|
||||
self.string_length_limit = string_length_limit
|
||||
self.container_length_limit = container_length_limit
|
||||
self._fallback = fallback
|
||||
|
||||
def getProtocol(self, trans):
|
||||
return TCompactProtocolAccelerated(
|
||||
trans,
|
||||
string_length_limit=self.string_length_limit,
|
||||
container_length_limit=self.container_length_limit,
|
||||
fallback=self._fallback)
|
||||
|
|
677
thrift/protocol/TJSONProtocol.py
Normal file
677
thrift/protocol/TJSONProtocol.py
Normal file
|
@ -0,0 +1,677 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from .TProtocol import (TType, TProtocolBase, TProtocolException,
|
||||
checkIntegerLimits)
|
||||
import base64
|
||||
import math
|
||||
import sys
|
||||
|
||||
from ..compat import str_to_binary
|
||||
|
||||
|
||||
__all__ = ['TJSONProtocol',
|
||||
'TJSONProtocolFactory',
|
||||
'TSimpleJSONProtocol',
|
||||
'TSimpleJSONProtocolFactory']
|
||||
|
||||
VERSION = 1
|
||||
|
||||
COMMA = b','
|
||||
COLON = b':'
|
||||
LBRACE = b'{'
|
||||
RBRACE = b'}'
|
||||
LBRACKET = b'['
|
||||
RBRACKET = b']'
|
||||
QUOTE = b'"'
|
||||
BACKSLASH = b'\\'
|
||||
ZERO = b'0'
|
||||
|
||||
ESCSEQ0 = ord('\\')
|
||||
ESCSEQ1 = ord('u')
|
||||
ESCAPE_CHAR_VALS = {
|
||||
'"': '\\"',
|
||||
'\\': '\\\\',
|
||||
'\b': '\\b',
|
||||
'\f': '\\f',
|
||||
'\n': '\\n',
|
||||
'\r': '\\r',
|
||||
'\t': '\\t',
|
||||
# '/': '\\/',
|
||||
}
|
||||
ESCAPE_CHARS = {
|
||||
b'"': '"',
|
||||
b'\\': '\\',
|
||||
b'b': '\b',
|
||||
b'f': '\f',
|
||||
b'n': '\n',
|
||||
b'r': '\r',
|
||||
b't': '\t',
|
||||
b'/': '/',
|
||||
}
|
||||
NUMERIC_CHAR = b'+-.0123456789Ee'
|
||||
|
||||
CTYPES = {
|
||||
TType.BOOL: 'tf',
|
||||
TType.BYTE: 'i8',
|
||||
TType.I16: 'i16',
|
||||
TType.I32: 'i32',
|
||||
TType.I64: 'i64',
|
||||
TType.DOUBLE: 'dbl',
|
||||
TType.STRING: 'str',
|
||||
TType.STRUCT: 'rec',
|
||||
TType.LIST: 'lst',
|
||||
TType.SET: 'set',
|
||||
TType.MAP: 'map',
|
||||
}
|
||||
|
||||
JTYPES = {}
|
||||
for key in CTYPES.keys():
|
||||
JTYPES[CTYPES[key]] = key
|
||||
|
||||
|
||||
class JSONBaseContext(object):
|
||||
|
||||
def __init__(self, protocol):
|
||||
self.protocol = protocol
|
||||
self.first = True
|
||||
|
||||
def doIO(self, function):
|
||||
pass
|
||||
|
||||
def write(self):
|
||||
pass
|
||||
|
||||
def read(self):
|
||||
pass
|
||||
|
||||
def escapeNum(self):
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class JSONListContext(JSONBaseContext):
|
||||
|
||||
def doIO(self, function):
|
||||
if self.first is True:
|
||||
self.first = False
|
||||
else:
|
||||
function(COMMA)
|
||||
|
||||
def write(self):
|
||||
self.doIO(self.protocol.trans.write)
|
||||
|
||||
def read(self):
|
||||
self.doIO(self.protocol.readJSONSyntaxChar)
|
||||
|
||||
|
||||
class JSONPairContext(JSONBaseContext):
|
||||
|
||||
def __init__(self, protocol):
|
||||
super(JSONPairContext, self).__init__(protocol)
|
||||
self.colon = True
|
||||
|
||||
def doIO(self, function):
|
||||
if self.first:
|
||||
self.first = False
|
||||
self.colon = True
|
||||
else:
|
||||
function(COLON if self.colon else COMMA)
|
||||
self.colon = not self.colon
|
||||
|
||||
def write(self):
|
||||
self.doIO(self.protocol.trans.write)
|
||||
|
||||
def read(self):
|
||||
self.doIO(self.protocol.readJSONSyntaxChar)
|
||||
|
||||
def escapeNum(self):
|
||||
return self.colon
|
||||
|
||||
def __str__(self):
|
||||
return '%s, colon=%s' % (self.__class__.__name__, self.colon)
|
||||
|
||||
|
||||
class LookaheadReader():
|
||||
hasData = False
|
||||
data = ''
|
||||
|
||||
def __init__(self, protocol):
|
||||
self.protocol = protocol
|
||||
|
||||
def read(self):
|
||||
if self.hasData is True:
|
||||
self.hasData = False
|
||||
else:
|
||||
self.data = self.protocol.trans.read(1)
|
||||
return self.data
|
||||
|
||||
def peek(self):
|
||||
if self.hasData is False:
|
||||
self.data = self.protocol.trans.read(1)
|
||||
self.hasData = True
|
||||
return self.data
|
||||
|
||||
|
||||
class TJSONProtocolBase(TProtocolBase):
|
||||
|
||||
def __init__(self, trans):
|
||||
TProtocolBase.__init__(self, trans)
|
||||
self.resetWriteContext()
|
||||
self.resetReadContext()
|
||||
|
||||
# We don't have length limit implementation for JSON protocols
|
||||
@property
|
||||
def string_length_limit(senf):
|
||||
return None
|
||||
|
||||
@property
|
||||
def container_length_limit(senf):
|
||||
return None
|
||||
|
||||
def resetWriteContext(self):
|
||||
self.context = JSONBaseContext(self)
|
||||
self.contextStack = [self.context]
|
||||
|
||||
def resetReadContext(self):
|
||||
self.resetWriteContext()
|
||||
self.reader = LookaheadReader(self)
|
||||
|
||||
def pushContext(self, ctx):
|
||||
self.contextStack.append(ctx)
|
||||
self.context = ctx
|
||||
|
||||
def popContext(self):
|
||||
self.contextStack.pop()
|
||||
if self.contextStack:
|
||||
self.context = self.contextStack[-1]
|
||||
else:
|
||||
self.context = JSONBaseContext(self)
|
||||
|
||||
def writeJSONString(self, string):
|
||||
self.context.write()
|
||||
json_str = ['"']
|
||||
for s in string:
|
||||
escaped = ESCAPE_CHAR_VALS.get(s, s)
|
||||
json_str.append(escaped)
|
||||
json_str.append('"')
|
||||
self.trans.write(str_to_binary(''.join(json_str)))
|
||||
|
||||
def writeJSONNumber(self, number, formatter='{0}'):
|
||||
self.context.write()
|
||||
jsNumber = str(formatter.format(number)).encode('ascii')
|
||||
if self.context.escapeNum():
|
||||
self.trans.write(QUOTE)
|
||||
self.trans.write(jsNumber)
|
||||
self.trans.write(QUOTE)
|
||||
else:
|
||||
self.trans.write(jsNumber)
|
||||
|
||||
def writeJSONBase64(self, binary):
|
||||
self.context.write()
|
||||
self.trans.write(QUOTE)
|
||||
self.trans.write(base64.b64encode(binary))
|
||||
self.trans.write(QUOTE)
|
||||
|
||||
def writeJSONObjectStart(self):
|
||||
self.context.write()
|
||||
self.trans.write(LBRACE)
|
||||
self.pushContext(JSONPairContext(self))
|
||||
|
||||
def writeJSONObjectEnd(self):
|
||||
self.popContext()
|
||||
self.trans.write(RBRACE)
|
||||
|
||||
def writeJSONArrayStart(self):
|
||||
self.context.write()
|
||||
self.trans.write(LBRACKET)
|
||||
self.pushContext(JSONListContext(self))
|
||||
|
||||
def writeJSONArrayEnd(self):
|
||||
self.popContext()
|
||||
self.trans.write(RBRACKET)
|
||||
|
||||
def readJSONSyntaxChar(self, character):
|
||||
current = self.reader.read()
|
||||
if character != current:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Unexpected character: %s" % current)
|
||||
|
||||
def _isHighSurrogate(self, codeunit):
|
||||
return codeunit >= 0xd800 and codeunit <= 0xdbff
|
||||
|
||||
def _isLowSurrogate(self, codeunit):
|
||||
return codeunit >= 0xdc00 and codeunit <= 0xdfff
|
||||
|
||||
def _toChar(self, high, low=None):
|
||||
if not low:
|
||||
if sys.version_info[0] == 2:
|
||||
return ("\\u%04x" % high).decode('unicode-escape') \
|
||||
.encode('utf-8')
|
||||
else:
|
||||
return chr(high)
|
||||
else:
|
||||
codepoint = (1 << 16) + ((high & 0x3ff) << 10)
|
||||
codepoint += low & 0x3ff
|
||||
if sys.version_info[0] == 2:
|
||||
s = "\\U%08x" % codepoint
|
||||
return s.decode('unicode-escape').encode('utf-8')
|
||||
else:
|
||||
return chr(codepoint)
|
||||
|
||||
def readJSONString(self, skipContext):
|
||||
highSurrogate = None
|
||||
string = []
|
||||
if skipContext is False:
|
||||
self.context.read()
|
||||
self.readJSONSyntaxChar(QUOTE)
|
||||
while True:
|
||||
character = self.reader.read()
|
||||
if character == QUOTE:
|
||||
break
|
||||
if ord(character) == ESCSEQ0:
|
||||
character = self.reader.read()
|
||||
if ord(character) == ESCSEQ1:
|
||||
character = self.trans.read(4).decode('ascii')
|
||||
codeunit = int(character, 16)
|
||||
if self._isHighSurrogate(codeunit):
|
||||
if highSurrogate:
|
||||
raise TProtocolException(
|
||||
TProtocolException.INVALID_DATA,
|
||||
"Expected low surrogate char")
|
||||
highSurrogate = codeunit
|
||||
continue
|
||||
elif self._isLowSurrogate(codeunit):
|
||||
if not highSurrogate:
|
||||
raise TProtocolException(
|
||||
TProtocolException.INVALID_DATA,
|
||||
"Expected high surrogate char")
|
||||
character = self._toChar(highSurrogate, codeunit)
|
||||
highSurrogate = None
|
||||
else:
|
||||
character = self._toChar(codeunit)
|
||||
else:
|
||||
if character not in ESCAPE_CHARS:
|
||||
raise TProtocolException(
|
||||
TProtocolException.INVALID_DATA,
|
||||
"Expected control char")
|
||||
character = ESCAPE_CHARS[character]
|
||||
elif character in ESCAPE_CHAR_VALS:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Unescaped control char")
|
||||
elif sys.version_info[0] > 2:
|
||||
utf8_bytes = bytearray([ord(character)])
|
||||
while ord(self.reader.peek()) >= 0x80:
|
||||
utf8_bytes.append(ord(self.reader.read()))
|
||||
character = utf8_bytes.decode('utf8')
|
||||
string.append(character)
|
||||
|
||||
if highSurrogate:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Expected low surrogate char")
|
||||
return ''.join(string)
|
||||
|
||||
def isJSONNumeric(self, character):
|
||||
return (True if NUMERIC_CHAR.find(character) != - 1 else False)
|
||||
|
||||
def readJSONQuotes(self):
|
||||
if (self.context.escapeNum()):
|
||||
self.readJSONSyntaxChar(QUOTE)
|
||||
|
||||
def readJSONNumericChars(self):
|
||||
numeric = []
|
||||
while True:
|
||||
character = self.reader.peek()
|
||||
if self.isJSONNumeric(character) is False:
|
||||
break
|
||||
numeric.append(self.reader.read())
|
||||
return b''.join(numeric).decode('ascii')
|
||||
|
||||
def readJSONInteger(self):
|
||||
self.context.read()
|
||||
self.readJSONQuotes()
|
||||
numeric = self.readJSONNumericChars()
|
||||
self.readJSONQuotes()
|
||||
try:
|
||||
return int(numeric)
|
||||
except ValueError:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Bad data encounted in numeric data")
|
||||
|
||||
def readJSONDouble(self):
|
||||
self.context.read()
|
||||
if self.reader.peek() == QUOTE:
|
||||
string = self.readJSONString(True)
|
||||
try:
|
||||
double = float(string)
|
||||
if (self.context.escapeNum is False and
|
||||
not math.isinf(double) and
|
||||
not math.isnan(double)):
|
||||
raise TProtocolException(
|
||||
TProtocolException.INVALID_DATA,
|
||||
"Numeric data unexpectedly quoted")
|
||||
return double
|
||||
except ValueError:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Bad data encounted in numeric data")
|
||||
else:
|
||||
if self.context.escapeNum() is True:
|
||||
self.readJSONSyntaxChar(QUOTE)
|
||||
try:
|
||||
return float(self.readJSONNumericChars())
|
||||
except ValueError:
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"Bad data encounted in numeric data")
|
||||
|
||||
def readJSONBase64(self):
|
||||
string = self.readJSONString(False)
|
||||
size = len(string)
|
||||
m = size % 4
|
||||
# Force padding since b64encode method does not allow it
|
||||
if m != 0:
|
||||
for i in range(4 - m):
|
||||
string += '='
|
||||
return base64.b64decode(string)
|
||||
|
||||
def readJSONObjectStart(self):
|
||||
self.context.read()
|
||||
self.readJSONSyntaxChar(LBRACE)
|
||||
self.pushContext(JSONPairContext(self))
|
||||
|
||||
def readJSONObjectEnd(self):
|
||||
self.readJSONSyntaxChar(RBRACE)
|
||||
self.popContext()
|
||||
|
||||
def readJSONArrayStart(self):
|
||||
self.context.read()
|
||||
self.readJSONSyntaxChar(LBRACKET)
|
||||
self.pushContext(JSONListContext(self))
|
||||
|
||||
def readJSONArrayEnd(self):
|
||||
self.readJSONSyntaxChar(RBRACKET)
|
||||
self.popContext()
|
||||
|
||||
|
||||
class TJSONProtocol(TJSONProtocolBase):
|
||||
|
||||
def readMessageBegin(self):
|
||||
self.resetReadContext()
|
||||
self.readJSONArrayStart()
|
||||
if self.readJSONInteger() != VERSION:
|
||||
raise TProtocolException(TProtocolException.BAD_VERSION,
|
||||
"Message contained bad version.")
|
||||
name = self.readJSONString(False)
|
||||
typen = self.readJSONInteger()
|
||||
seqid = self.readJSONInteger()
|
||||
return (name, typen, seqid)
|
||||
|
||||
def readMessageEnd(self):
|
||||
self.readJSONArrayEnd()
|
||||
|
||||
def readStructBegin(self):
|
||||
self.readJSONObjectStart()
|
||||
|
||||
def readStructEnd(self):
|
||||
self.readJSONObjectEnd()
|
||||
|
||||
def readFieldBegin(self):
|
||||
character = self.reader.peek()
|
||||
ttype = 0
|
||||
id = 0
|
||||
if character == RBRACE:
|
||||
ttype = TType.STOP
|
||||
else:
|
||||
id = self.readJSONInteger()
|
||||
self.readJSONObjectStart()
|
||||
ttype = JTYPES[self.readJSONString(False)]
|
||||
return (None, ttype, id)
|
||||
|
||||
def readFieldEnd(self):
|
||||
self.readJSONObjectEnd()
|
||||
|
||||
def readMapBegin(self):
|
||||
self.readJSONArrayStart()
|
||||
keyType = JTYPES[self.readJSONString(False)]
|
||||
valueType = JTYPES[self.readJSONString(False)]
|
||||
size = self.readJSONInteger()
|
||||
self.readJSONObjectStart()
|
||||
return (keyType, valueType, size)
|
||||
|
||||
def readMapEnd(self):
|
||||
self.readJSONObjectEnd()
|
||||
self.readJSONArrayEnd()
|
||||
|
||||
def readCollectionBegin(self):
|
||||
self.readJSONArrayStart()
|
||||
elemType = JTYPES[self.readJSONString(False)]
|
||||
size = self.readJSONInteger()
|
||||
return (elemType, size)
|
||||
readListBegin = readCollectionBegin
|
||||
readSetBegin = readCollectionBegin
|
||||
|
||||
def readCollectionEnd(self):
|
||||
self.readJSONArrayEnd()
|
||||
readSetEnd = readCollectionEnd
|
||||
readListEnd = readCollectionEnd
|
||||
|
||||
def readBool(self):
|
||||
return (False if self.readJSONInteger() == 0 else True)
|
||||
|
||||
def readNumber(self):
|
||||
return self.readJSONInteger()
|
||||
readByte = readNumber
|
||||
readI16 = readNumber
|
||||
readI32 = readNumber
|
||||
readI64 = readNumber
|
||||
|
||||
def readDouble(self):
|
||||
return self.readJSONDouble()
|
||||
|
||||
def readString(self):
|
||||
return self.readJSONString(False)
|
||||
|
||||
def readBinary(self):
|
||||
return self.readJSONBase64()
|
||||
|
||||
def writeMessageBegin(self, name, request_type, seqid):
|
||||
self.resetWriteContext()
|
||||
self.writeJSONArrayStart()
|
||||
self.writeJSONNumber(VERSION)
|
||||
self.writeJSONString(name)
|
||||
self.writeJSONNumber(request_type)
|
||||
self.writeJSONNumber(seqid)
|
||||
|
||||
def writeMessageEnd(self):
|
||||
self.writeJSONArrayEnd()
|
||||
|
||||
def writeStructBegin(self, name):
|
||||
self.writeJSONObjectStart()
|
||||
|
||||
def writeStructEnd(self):
|
||||
self.writeJSONObjectEnd()
|
||||
|
||||
def writeFieldBegin(self, name, ttype, id):
|
||||
self.writeJSONNumber(id)
|
||||
self.writeJSONObjectStart()
|
||||
self.writeJSONString(CTYPES[ttype])
|
||||
|
||||
def writeFieldEnd(self):
|
||||
self.writeJSONObjectEnd()
|
||||
|
||||
def writeFieldStop(self):
|
||||
pass
|
||||
|
||||
def writeMapBegin(self, ktype, vtype, size):
|
||||
self.writeJSONArrayStart()
|
||||
self.writeJSONString(CTYPES[ktype])
|
||||
self.writeJSONString(CTYPES[vtype])
|
||||
self.writeJSONNumber(size)
|
||||
self.writeJSONObjectStart()
|
||||
|
||||
def writeMapEnd(self):
|
||||
self.writeJSONObjectEnd()
|
||||
self.writeJSONArrayEnd()
|
||||
|
||||
def writeListBegin(self, etype, size):
|
||||
self.writeJSONArrayStart()
|
||||
self.writeJSONString(CTYPES[etype])
|
||||
self.writeJSONNumber(size)
|
||||
|
||||
def writeListEnd(self):
|
||||
self.writeJSONArrayEnd()
|
||||
|
||||
def writeSetBegin(self, etype, size):
|
||||
self.writeJSONArrayStart()
|
||||
self.writeJSONString(CTYPES[etype])
|
||||
self.writeJSONNumber(size)
|
||||
|
||||
def writeSetEnd(self):
|
||||
self.writeJSONArrayEnd()
|
||||
|
||||
def writeBool(self, boolean):
|
||||
self.writeJSONNumber(1 if boolean is True else 0)
|
||||
|
||||
def writeByte(self, byte):
|
||||
checkIntegerLimits(byte, 8)
|
||||
self.writeJSONNumber(byte)
|
||||
|
||||
def writeI16(self, i16):
|
||||
checkIntegerLimits(i16, 16)
|
||||
self.writeJSONNumber(i16)
|
||||
|
||||
def writeI32(self, i32):
|
||||
checkIntegerLimits(i32, 32)
|
||||
self.writeJSONNumber(i32)
|
||||
|
||||
def writeI64(self, i64):
|
||||
checkIntegerLimits(i64, 64)
|
||||
self.writeJSONNumber(i64)
|
||||
|
||||
def writeDouble(self, dbl):
|
||||
# 17 significant digits should be just enough for any double precision
|
||||
# value.
|
||||
self.writeJSONNumber(dbl, '{0:.17g}')
|
||||
|
||||
def writeString(self, string):
|
||||
self.writeJSONString(string)
|
||||
|
||||
def writeBinary(self, binary):
|
||||
self.writeJSONBase64(binary)
|
||||
|
||||
|
||||
class TJSONProtocolFactory(object):
|
||||
def getProtocol(self, trans):
|
||||
return TJSONProtocol(trans)
|
||||
|
||||
@property
|
||||
def string_length_limit(senf):
|
||||
return None
|
||||
|
||||
@property
|
||||
def container_length_limit(senf):
|
||||
return None
|
||||
|
||||
|
||||
class TSimpleJSONProtocol(TJSONProtocolBase):
|
||||
"""Simple, readable, write-only JSON protocol.
|
||||
|
||||
Useful for interacting with scripting languages.
|
||||
"""
|
||||
|
||||
def readMessageBegin(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def readMessageEnd(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def readStructBegin(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def readStructEnd(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def writeMessageBegin(self, name, request_type, seqid):
|
||||
self.resetWriteContext()
|
||||
|
||||
def writeMessageEnd(self):
|
||||
pass
|
||||
|
||||
def writeStructBegin(self, name):
|
||||
self.writeJSONObjectStart()
|
||||
|
||||
def writeStructEnd(self):
|
||||
self.writeJSONObjectEnd()
|
||||
|
||||
def writeFieldBegin(self, name, ttype, fid):
|
||||
self.writeJSONString(name)
|
||||
|
||||
def writeFieldEnd(self):
|
||||
pass
|
||||
|
||||
def writeMapBegin(self, ktype, vtype, size):
|
||||
self.writeJSONObjectStart()
|
||||
|
||||
def writeMapEnd(self):
|
||||
self.writeJSONObjectEnd()
|
||||
|
||||
def _writeCollectionBegin(self, etype, size):
|
||||
self.writeJSONArrayStart()
|
||||
|
||||
def _writeCollectionEnd(self):
|
||||
self.writeJSONArrayEnd()
|
||||
writeListBegin = _writeCollectionBegin
|
||||
writeListEnd = _writeCollectionEnd
|
||||
writeSetBegin = _writeCollectionBegin
|
||||
writeSetEnd = _writeCollectionEnd
|
||||
|
||||
def writeByte(self, byte):
|
||||
checkIntegerLimits(byte, 8)
|
||||
self.writeJSONNumber(byte)
|
||||
|
||||
def writeI16(self, i16):
|
||||
checkIntegerLimits(i16, 16)
|
||||
self.writeJSONNumber(i16)
|
||||
|
||||
def writeI32(self, i32):
|
||||
checkIntegerLimits(i32, 32)
|
||||
self.writeJSONNumber(i32)
|
||||
|
||||
def writeI64(self, i64):
|
||||
checkIntegerLimits(i64, 64)
|
||||
self.writeJSONNumber(i64)
|
||||
|
||||
def writeBool(self, boolean):
|
||||
self.writeJSONNumber(1 if boolean is True else 0)
|
||||
|
||||
def writeDouble(self, dbl):
|
||||
self.writeJSONNumber(dbl)
|
||||
|
||||
def writeString(self, string):
|
||||
self.writeJSONString(string)
|
||||
|
||||
def writeBinary(self, binary):
|
||||
self.writeJSONBase64(binary)
|
||||
|
||||
|
||||
class TSimpleJSONProtocolFactory(object):
|
||||
|
||||
def getProtocol(self, trans):
|
||||
return TSimpleJSONProtocol(trans)
|
40
thrift/protocol/TMultiplexedProtocol.py
Normal file
40
thrift/protocol/TMultiplexedProtocol.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import TMessageType
|
||||
from thrift.protocol import TProtocolDecorator
|
||||
|
||||
SEPARATOR = ":"
|
||||
|
||||
|
||||
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
|
||||
def __init__(self, protocol, serviceName):
|
||||
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
|
||||
self.serviceName = serviceName
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
if (type == TMessageType.CALL or
|
||||
type == TMessageType.ONEWAY):
|
||||
self.protocol.writeMessageBegin(
|
||||
self.serviceName + SEPARATOR + name,
|
||||
type,
|
||||
seqid
|
||||
)
|
||||
else:
|
||||
self.protocol.writeMessageBegin(name, type, seqid)
|
|
@ -17,7 +17,14 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import *
|
||||
from thrift.Thrift import TException, TType, TFrozenDict
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
from ..compat import binary_to_str, str_to_binary
|
||||
|
||||
import six
|
||||
import sys
|
||||
from itertools import islice
|
||||
from six.moves import zip
|
||||
|
||||
|
||||
class TProtocolException(TException):
|
||||
|
@ -28,19 +35,32 @@ class TProtocolException(TException):
|
|||
NEGATIVE_SIZE = 2
|
||||
SIZE_LIMIT = 3
|
||||
BAD_VERSION = 4
|
||||
NOT_IMPLEMENTED = 5
|
||||
DEPTH_LIMIT = 6
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
self.type = type
|
||||
|
||||
|
||||
class TProtocolBase:
|
||||
class TProtocolBase(object):
|
||||
"""Base class for Thrift protocol driver."""
|
||||
|
||||
def __init__(self, trans):
|
||||
self.trans = trans
|
||||
self._fast_decode = None
|
||||
self._fast_encode = None
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
@staticmethod
|
||||
def _check_length(limit, length):
|
||||
if length < 0:
|
||||
raise TTransportException(TTransportException.NEGATIVE_SIZE,
|
||||
'Negative length: %d' % length)
|
||||
if limit is not None and length > limit:
|
||||
raise TTransportException(TTransportException.SIZE_LIMIT,
|
||||
'Length exceeded max allowed: %d' % limit)
|
||||
|
||||
def writeMessageBegin(self, name, ttype, seqid):
|
||||
pass
|
||||
|
||||
def writeMessageEnd(self):
|
||||
|
@ -52,7 +72,7 @@ class TProtocolBase:
|
|||
def writeStructEnd(self):
|
||||
pass
|
||||
|
||||
def writeFieldBegin(self, name, type, id):
|
||||
def writeFieldBegin(self, name, ttype, fid):
|
||||
pass
|
||||
|
||||
def writeFieldEnd(self):
|
||||
|
@ -79,7 +99,7 @@ class TProtocolBase:
|
|||
def writeSetEnd(self):
|
||||
pass
|
||||
|
||||
def writeBool(self, bool):
|
||||
def writeBool(self, bool_val):
|
||||
pass
|
||||
|
||||
def writeByte(self, byte):
|
||||
|
@ -97,9 +117,15 @@ class TProtocolBase:
|
|||
def writeDouble(self, dub):
|
||||
pass
|
||||
|
||||
def writeString(self, str):
|
||||
def writeString(self, str_val):
|
||||
self.writeBinary(str_to_binary(str_val))
|
||||
|
||||
def writeBinary(self, str_val):
|
||||
pass
|
||||
|
||||
def writeUtf8(self, str_val):
|
||||
self.writeString(str_val.encode('utf8'))
|
||||
|
||||
def readMessageBegin(self):
|
||||
pass
|
||||
|
||||
|
@ -155,46 +181,52 @@ class TProtocolBase:
|
|||
pass
|
||||
|
||||
def readString(self):
|
||||
return binary_to_str(self.readBinary())
|
||||
|
||||
def readBinary(self):
|
||||
pass
|
||||
|
||||
def skip(self, type):
|
||||
if type == TType.STOP:
|
||||
def readUtf8(self):
|
||||
return self.readString().decode('utf8')
|
||||
|
||||
def skip(self, ttype):
|
||||
if ttype == TType.STOP:
|
||||
return
|
||||
elif type == TType.BOOL:
|
||||
elif ttype == TType.BOOL:
|
||||
self.readBool()
|
||||
elif type == TType.BYTE:
|
||||
elif ttype == TType.BYTE:
|
||||
self.readByte()
|
||||
elif type == TType.I16:
|
||||
elif ttype == TType.I16:
|
||||
self.readI16()
|
||||
elif type == TType.I32:
|
||||
elif ttype == TType.I32:
|
||||
self.readI32()
|
||||
elif type == TType.I64:
|
||||
elif ttype == TType.I64:
|
||||
self.readI64()
|
||||
elif type == TType.DOUBLE:
|
||||
elif ttype == TType.DOUBLE:
|
||||
self.readDouble()
|
||||
elif type == TType.STRING:
|
||||
elif ttype == TType.STRING:
|
||||
self.readString()
|
||||
elif type == TType.STRUCT:
|
||||
elif ttype == TType.STRUCT:
|
||||
name = self.readStructBegin()
|
||||
while True:
|
||||
(name, type, id) = self.readFieldBegin()
|
||||
if type == TType.STOP:
|
||||
(name, ttype, id) = self.readFieldBegin()
|
||||
if ttype == TType.STOP:
|
||||
break
|
||||
self.skip(type)
|
||||
self.skip(ttype)
|
||||
self.readFieldEnd()
|
||||
self.readStructEnd()
|
||||
elif type == TType.MAP:
|
||||
elif ttype == TType.MAP:
|
||||
(ktype, vtype, size) = self.readMapBegin()
|
||||
for i in range(size):
|
||||
self.skip(ktype)
|
||||
self.skip(vtype)
|
||||
self.readMapEnd()
|
||||
elif type == TType.SET:
|
||||
elif ttype == TType.SET:
|
||||
(etype, size) = self.readSetBegin()
|
||||
for i in range(size):
|
||||
self.skip(etype)
|
||||
self.readSetEnd()
|
||||
elif type == TType.LIST:
|
||||
elif ttype == TType.LIST:
|
||||
(etype, size) = self.readListBegin()
|
||||
for i in range(size):
|
||||
self.skip(etype)
|
||||
|
@ -222,55 +254,47 @@ class TProtocolBase:
|
|||
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
|
||||
)
|
||||
|
||||
def _ttype_handlers(self, ttype, spec):
|
||||
if spec == 'BINARY':
|
||||
if ttype != TType.STRING:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid binary field type %d' % ttype)
|
||||
return ('readBinary', 'writeBinary', False)
|
||||
if sys.version_info[0] == 2 and spec == 'UTF8':
|
||||
if ttype != TType.STRING:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid string field type %d' % ttype)
|
||||
return ('readUtf8', 'writeUtf8', False)
|
||||
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
|
||||
|
||||
def _read_by_ttype(self, ttype, spec, espec):
|
||||
reader_name, _, is_container = self._ttype_handlers(ttype, spec)
|
||||
if reader_name is None:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid type %d' % (ttype))
|
||||
reader_func = getattr(self, reader_name)
|
||||
read = (lambda: reader_func(espec)) if is_container else reader_func
|
||||
while True:
|
||||
yield read()
|
||||
|
||||
def readFieldByTType(self, ttype, spec):
|
||||
try:
|
||||
(r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
|
||||
except IndexError:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid field type %d' % (ttype))
|
||||
if r_handler is None:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid field type %d' % (ttype))
|
||||
reader = getattr(self, r_handler)
|
||||
if not is_container:
|
||||
return reader()
|
||||
return reader(spec)
|
||||
return next(self._read_by_ttype(ttype, spec, spec))
|
||||
|
||||
def readContainerList(self, spec):
|
||||
results = []
|
||||
ttype, tspec = spec[0], spec[1]
|
||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
||||
reader = getattr(self, r_handler)
|
||||
ttype, tspec, is_immutable = spec
|
||||
(list_type, list_len) = self.readListBegin()
|
||||
if tspec is None:
|
||||
# list values are simple types
|
||||
for idx in xrange(list_len):
|
||||
results.append(reader())
|
||||
else:
|
||||
# this is like an inlined readFieldByTType
|
||||
container_reader = self._TTYPE_HANDLERS[list_type][0]
|
||||
val_reader = getattr(self, container_reader)
|
||||
for idx in xrange(list_len):
|
||||
val = val_reader(tspec)
|
||||
results.append(val)
|
||||
# TODO: compare types we just decoded with thrift_spec
|
||||
elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
|
||||
results = (tuple if is_immutable else list)(elems)
|
||||
self.readListEnd()
|
||||
return results
|
||||
|
||||
def readContainerSet(self, spec):
|
||||
results = set()
|
||||
ttype, tspec = spec[0], spec[1]
|
||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
||||
reader = getattr(self, r_handler)
|
||||
ttype, tspec, is_immutable = spec
|
||||
(set_type, set_len) = self.readSetBegin()
|
||||
if tspec is None:
|
||||
# set members are simple types
|
||||
for idx in xrange(set_len):
|
||||
results.add(reader())
|
||||
else:
|
||||
container_reader = self._TTYPE_HANDLERS[set_type][0]
|
||||
val_reader = getattr(self, container_reader)
|
||||
for idx in xrange(set_len):
|
||||
results.add(val_reader(tspec))
|
||||
# TODO: compare types we just decoded with thrift_spec
|
||||
elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
|
||||
results = (frozenset if is_immutable else set)(elems)
|
||||
self.readSetEnd()
|
||||
return results
|
||||
|
||||
|
@ -281,31 +305,20 @@ class TProtocolBase:
|
|||
return obj
|
||||
|
||||
def readContainerMap(self, spec):
|
||||
results = dict()
|
||||
key_ttype, key_spec = spec[0], spec[1]
|
||||
val_ttype, val_spec = spec[2], spec[3]
|
||||
ktype, kspec, vtype, vspec, is_immutable = spec
|
||||
(map_ktype, map_vtype, map_len) = self.readMapBegin()
|
||||
# TODO: compare types we just decoded with thrift_spec and
|
||||
# abort/skip if types disagree
|
||||
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
|
||||
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
|
||||
# list values are simple types
|
||||
for idx in xrange(map_len):
|
||||
if key_spec is None:
|
||||
k_val = key_reader()
|
||||
else:
|
||||
k_val = self.readFieldByTType(key_ttype, key_spec)
|
||||
if val_spec is None:
|
||||
v_val = val_reader()
|
||||
else:
|
||||
v_val = self.readFieldByTType(val_ttype, val_spec)
|
||||
# this raises a TypeError with unhashable keys types
|
||||
# i.e. this fails: d=dict(); d[[0,1]] = 2
|
||||
results[k_val] = v_val
|
||||
keys = self._read_by_ttype(ktype, spec, kspec)
|
||||
vals = self._read_by_ttype(vtype, spec, vspec)
|
||||
keyvals = islice(zip(keys, vals), map_len)
|
||||
results = (TFrozenDict if is_immutable else dict)(keyvals)
|
||||
self.readMapEnd()
|
||||
return results
|
||||
|
||||
def readStruct(self, obj, thrift_spec):
|
||||
def readStruct(self, obj, thrift_spec, is_immutable=False):
|
||||
if is_immutable:
|
||||
fields = {}
|
||||
self.readStructBegin()
|
||||
while True:
|
||||
(fname, ftype, fid) = self.readFieldBegin()
|
||||
|
@ -320,56 +333,40 @@ class TProtocolBase:
|
|||
fname = field[2]
|
||||
fspec = field[3]
|
||||
val = self.readFieldByTType(ftype, fspec)
|
||||
if is_immutable:
|
||||
fields[fname] = val
|
||||
else:
|
||||
setattr(obj, fname, val)
|
||||
else:
|
||||
self.skip(ftype)
|
||||
self.readFieldEnd()
|
||||
self.readStructEnd()
|
||||
if is_immutable:
|
||||
return obj(**fields)
|
||||
|
||||
def writeContainerStruct(self, val, spec):
|
||||
val.write(self)
|
||||
|
||||
def writeContainerList(self, val, spec):
|
||||
self.writeListBegin(spec[0], len(val))
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
||||
e_writer = getattr(self, w_handler)
|
||||
if not is_container:
|
||||
for elem in val:
|
||||
e_writer(elem)
|
||||
else:
|
||||
for elem in val:
|
||||
e_writer(elem, spec[1])
|
||||
ttype, tspec, _ = spec
|
||||
self.writeListBegin(ttype, len(val))
|
||||
for _ in self._write_by_ttype(ttype, val, spec, tspec):
|
||||
pass
|
||||
self.writeListEnd()
|
||||
|
||||
def writeContainerSet(self, val, spec):
|
||||
self.writeSetBegin(spec[0], len(val))
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
||||
e_writer = getattr(self, w_handler)
|
||||
if not is_container:
|
||||
for elem in val:
|
||||
e_writer(elem)
|
||||
else:
|
||||
for elem in val:
|
||||
e_writer(elem, spec[1])
|
||||
ttype, tspec, _ = spec
|
||||
self.writeSetBegin(ttype, len(val))
|
||||
for _ in self._write_by_ttype(ttype, val, spec, tspec):
|
||||
pass
|
||||
self.writeSetEnd()
|
||||
|
||||
def writeContainerMap(self, val, spec):
|
||||
k_type = spec[0]
|
||||
v_type = spec[2]
|
||||
ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
|
||||
ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
|
||||
k_writer = getattr(self, ktype_name)
|
||||
v_writer = getattr(self, vtype_name)
|
||||
self.writeMapBegin(k_type, v_type, len(val))
|
||||
for m_key, m_val in val.iteritems():
|
||||
if not k_is_container:
|
||||
k_writer(m_key)
|
||||
else:
|
||||
k_writer(m_key, spec[1])
|
||||
if not v_is_container:
|
||||
v_writer(m_val)
|
||||
else:
|
||||
v_writer(m_val, spec[3])
|
||||
ktype, kspec, vtype, vspec, _ = spec
|
||||
self.writeMapBegin(ktype, vtype, len(val))
|
||||
for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
|
||||
self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
|
||||
pass
|
||||
self.writeMapEnd()
|
||||
|
||||
def writeStruct(self, obj, thrift_spec):
|
||||
|
@ -385,22 +382,38 @@ class TProtocolBase:
|
|||
fid = field[0]
|
||||
ftype = field[1]
|
||||
fspec = field[3]
|
||||
# get the writer method for this value
|
||||
self.writeFieldBegin(fname, ftype, fid)
|
||||
self.writeFieldByTType(ftype, val, fspec)
|
||||
self.writeFieldEnd()
|
||||
self.writeFieldStop()
|
||||
self.writeStructEnd()
|
||||
|
||||
def _write_by_ttype(self, ttype, vals, spec, espec):
|
||||
_, writer_name, is_container = self._ttype_handlers(ttype, spec)
|
||||
writer_func = getattr(self, writer_name)
|
||||
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
|
||||
for v in vals:
|
||||
yield write(v)
|
||||
|
||||
def writeFieldByTType(self, ttype, val, spec):
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
|
||||
writer = getattr(self, w_handler)
|
||||
if is_container:
|
||||
writer(val, spec)
|
||||
else:
|
||||
writer(val)
|
||||
next(self._write_by_ttype(ttype, [val], spec, spec))
|
||||
|
||||
|
||||
class TProtocolFactory:
|
||||
def checkIntegerLimits(i, bits):
|
||||
if bits == 8 and (i < -128 or i > 127):
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"i8 requires -128 <= number <= 127")
|
||||
elif bits == 16 and (i < -32768 or i > 32767):
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"i16 requires -32768 <= number <= 32767")
|
||||
elif bits == 32 and (i < -2147483648 or i > 2147483647):
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"i32 requires -2147483648 <= number <= 2147483647")
|
||||
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
|
||||
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
|
||||
|
||||
|
||||
class TProtocolFactory(object):
|
||||
def getProtocol(self, trans):
|
||||
pass
|
||||
|
|
50
thrift/protocol/TProtocolDecorator.py
Normal file
50
thrift/protocol/TProtocolDecorator.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import types
|
||||
|
||||
from thrift.protocol.TProtocol import TProtocolBase
|
||||
|
||||
|
||||
class TProtocolDecorator():
|
||||
def __init__(self, protocol):
|
||||
TProtocolBase(protocol)
|
||||
self.protocol = protocol
|
||||
|
||||
def __getattr__(self, name):
|
||||
if hasattr(self.protocol, name):
|
||||
member = getattr(self.protocol, name)
|
||||
if type(member) in [
|
||||
types.MethodType,
|
||||
types.FunctionType,
|
||||
types.LambdaType,
|
||||
types.BuiltinFunctionType,
|
||||
types.BuiltinMethodType,
|
||||
]:
|
||||
return lambda *args, **kwargs: self._wrap(member, args, kwargs)
|
||||
else:
|
||||
return member
|
||||
raise AttributeError(name)
|
||||
|
||||
def _wrap(self, func, args, kwargs):
|
||||
if isinstance(func, types.MethodType):
|
||||
result = func(*args, **kwargs)
|
||||
else:
|
||||
result = func(self.protocol, *args, **kwargs)
|
||||
return result
|
|
@ -17,4 +17,5 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
|
||||
__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
|
||||
'TJSONProtocol', 'TProtocol']
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
import BaseHTTPServer
|
||||
from six.moves import BaseHTTPServer
|
||||
|
||||
from thrift.server import TServer
|
||||
from thrift.transport import TTransport
|
||||
|
|
|
@ -24,18 +24,22 @@ only from the main thread.
|
|||
The thread poool should be sized for concurrent tasks, not
|
||||
maximum connections
|
||||
"""
|
||||
import threading
|
||||
import socket
|
||||
import Queue
|
||||
import select
|
||||
import struct
|
||||
|
||||
import logging
|
||||
import select
|
||||
import socket
|
||||
import struct
|
||||
import threading
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from thrift.transport import TTransport
|
||||
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
||||
|
||||
__all__ = ['TNonblockingServer']
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Worker(threading.Thread):
|
||||
"""Worker is a small helper to process incoming connection."""
|
||||
|
@ -54,8 +58,8 @@ class Worker(threading.Thread):
|
|||
processor.process(iprot, oprot)
|
||||
callback(True, otrans.getvalue())
|
||||
except Exception:
|
||||
logging.exception("Exception while processing request")
|
||||
callback(False, '')
|
||||
logger.exception("Exception while processing request")
|
||||
callback(False, b'')
|
||||
|
||||
WAIT_LEN = 0
|
||||
WAIT_MESSAGE = 1
|
||||
|
@ -85,7 +89,7 @@ def socket_exception(func):
|
|||
return read
|
||||
|
||||
|
||||
class Connection:
|
||||
class Connection(object):
|
||||
"""Basic class is represented connection.
|
||||
|
||||
It can be in state:
|
||||
|
@ -102,7 +106,7 @@ class Connection:
|
|||
self.socket.setblocking(False)
|
||||
self.status = WAIT_LEN
|
||||
self.len = 0
|
||||
self.message = ''
|
||||
self.message = b''
|
||||
self.lock = threading.Lock()
|
||||
self.wake_up = wake_up
|
||||
|
||||
|
@ -116,21 +120,21 @@ class Connection:
|
|||
# if we read 0 bytes and self.message is empty, then
|
||||
# the client closed the connection
|
||||
if len(self.message) != 0:
|
||||
logging.error("can't read frame size from socket")
|
||||
logger.error("can't read frame size from socket")
|
||||
self.close()
|
||||
return
|
||||
self.message += read
|
||||
if len(self.message) == 4:
|
||||
self.len, = struct.unpack('!i', self.message)
|
||||
if self.len < 0:
|
||||
logging.error("negative frame size, it seems client "
|
||||
logger.error("negative frame size, it seems client "
|
||||
"doesn't use FramedTransport")
|
||||
self.close()
|
||||
elif self.len == 0:
|
||||
logging.error("empty frame, it's really strange")
|
||||
logger.error("empty frame, it's really strange")
|
||||
self.close()
|
||||
else:
|
||||
self.message = ''
|
||||
self.message = b''
|
||||
self.status = WAIT_MESSAGE
|
||||
|
||||
@socket_exception
|
||||
|
@ -145,7 +149,7 @@ class Connection:
|
|||
elif self.status == WAIT_MESSAGE:
|
||||
read = self.socket.recv(self.len - len(self.message))
|
||||
if len(read) == 0:
|
||||
logging.error("can't read frame from socket (get %d of "
|
||||
logger.error("can't read frame from socket (get %d of "
|
||||
"%d bytes)" % (len(self.message), self.len))
|
||||
self.close()
|
||||
return
|
||||
|
@ -160,7 +164,7 @@ class Connection:
|
|||
sent = self.socket.send(self.message)
|
||||
if sent == len(self.message):
|
||||
self.status = WAIT_LEN
|
||||
self.message = ''
|
||||
self.message = b''
|
||||
self.len = 0
|
||||
else:
|
||||
self.message = self.message[sent:]
|
||||
|
@ -183,10 +187,10 @@ class Connection:
|
|||
self.close()
|
||||
self.wake_up()
|
||||
return
|
||||
self.len = ''
|
||||
self.len = 0
|
||||
if len(message) == 0:
|
||||
# it was a oneway request, do not write answer
|
||||
self.message = ''
|
||||
self.message = b''
|
||||
self.status = WAIT_LEN
|
||||
else:
|
||||
self.message = struct.pack('!i', len(message)) + message
|
||||
|
@ -219,7 +223,7 @@ class Connection:
|
|||
self.socket.close()
|
||||
|
||||
|
||||
class TNonblockingServer:
|
||||
class TNonblockingServer(object):
|
||||
"""Non-blocking server."""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -234,7 +238,7 @@ class TNonblockingServer:
|
|||
self.out_protocol = outputProtocolFactory or self.in_protocol
|
||||
self.threads = int(threads)
|
||||
self.clients = {}
|
||||
self.tasks = Queue.Queue()
|
||||
self.tasks = queue.Queue()
|
||||
self._read, self._write = socket.socketpair()
|
||||
self.prepared = False
|
||||
self._stop = False
|
||||
|
@ -250,7 +254,7 @@ class TNonblockingServer:
|
|||
if self.prepared:
|
||||
return
|
||||
self.socket.listen()
|
||||
for _ in xrange(self.threads):
|
||||
for _ in range(self.threads):
|
||||
thread = Worker(self.tasks)
|
||||
thread.setDaemon(True)
|
||||
thread.start()
|
||||
|
@ -259,7 +263,7 @@ class TNonblockingServer:
|
|||
def wake_up(self):
|
||||
"""Wake up main thread.
|
||||
|
||||
The server usualy waits in select call in we should terminate one.
|
||||
The server usually waits in select call in we should terminate one.
|
||||
The simplest way is using socketpair.
|
||||
|
||||
Select always wait to read from the first socket of socketpair.
|
||||
|
@ -267,7 +271,7 @@ class TNonblockingServer:
|
|||
In this case, we can just write anything to the second socket from
|
||||
socketpair.
|
||||
"""
|
||||
self._write.send('1')
|
||||
self._write.send(b'1')
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server.
|
||||
|
@ -288,7 +292,7 @@ class TNonblockingServer:
|
|||
"""Does select on open connections."""
|
||||
readable = [self.socket.handle.fileno(), self._read.fileno()]
|
||||
writable = []
|
||||
for i, connection in self.clients.items():
|
||||
for i, connection in list(self.clients.items()):
|
||||
if connection.is_readable():
|
||||
readable.append(connection.fileno())
|
||||
if connection.is_writeable():
|
||||
|
@ -330,7 +334,7 @@ class TNonblockingServer:
|
|||
|
||||
def close(self):
|
||||
"""Closes the server."""
|
||||
for _ in xrange(self.threads):
|
||||
for _ in range(self.threads):
|
||||
self.tasks.put([None, None, None, None, None])
|
||||
self.socket.close()
|
||||
self.prepared = False
|
||||
|
|
|
@ -19,11 +19,14 @@
|
|||
|
||||
|
||||
import logging
|
||||
from multiprocessing import Process, Value, Condition, reduction
|
||||
|
||||
from TServer import TServer
|
||||
from multiprocessing import Process, Value, Condition
|
||||
|
||||
from .TServer import TServer
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TProcessPoolServer(TServer):
|
||||
"""Server with a fixed size pool of worker subprocesses to service requests
|
||||
|
@ -56,11 +59,13 @@ class TProcessPoolServer(TServer):
|
|||
while self.isRunning.value:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
if not client:
|
||||
continue
|
||||
self.serveClient(client)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
return 0
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
def serveClient(self, client):
|
||||
"""Process input/output from a client for as long as possible"""
|
||||
|
@ -72,10 +77,10 @@ class TProcessPoolServer(TServer):
|
|||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransportException, tx:
|
||||
except TTransportException:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
@ -95,8 +100,8 @@ class TProcessPoolServer(TServer):
|
|||
w.daemon = True
|
||||
w.start()
|
||||
self.workers.append(w)
|
||||
except Exception, x:
|
||||
logging.exception(x)
|
||||
except Exception as x:
|
||||
logger.exception(x)
|
||||
|
||||
# wait until the condition is set by stop()
|
||||
while True:
|
||||
|
@ -107,7 +112,7 @@ class TProcessPoolServer(TServer):
|
|||
except (SystemExit, KeyboardInterrupt):
|
||||
break
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
self.isRunning.value = False
|
||||
|
||||
|
|
|
@ -17,19 +17,18 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
import Queue
|
||||
from six.moves import queue
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from thrift.Thrift import TProcessor
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.transport import TTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TServer:
|
||||
|
||||
class TServer(object):
|
||||
"""Base interface for a server, which must have a serve() method.
|
||||
|
||||
Three constructors for all servers:
|
||||
|
@ -75,6 +74,8 @@ class TSimpleServer(TServer):
|
|||
self.serverTransport.listen()
|
||||
while True:
|
||||
client = self.serverTransport.accept()
|
||||
if not client:
|
||||
continue
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
|
@ -82,10 +83,10 @@ class TSimpleServer(TServer):
|
|||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
except TTransport.TTransportException:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
@ -103,13 +104,15 @@ class TThreadedServer(TServer):
|
|||
while True:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
if not client:
|
||||
continue
|
||||
t = threading.Thread(target=self.handle, args=(client,))
|
||||
t.setDaemon(self.daemon)
|
||||
t.start()
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
def handle(self, client):
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
|
@ -119,10 +122,10 @@ class TThreadedServer(TServer):
|
|||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
except TTransport.TTransportException:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
@ -133,7 +136,7 @@ class TThreadPoolServer(TServer):
|
|||
|
||||
def __init__(self, *args, **kwargs):
|
||||
TServer.__init__(self, *args)
|
||||
self.clients = Queue.Queue()
|
||||
self.clients = queue.Queue()
|
||||
self.threads = 10
|
||||
self.daemon = kwargs.get("daemon", False)
|
||||
|
||||
|
@ -147,8 +150,8 @@ class TThreadPoolServer(TServer):
|
|||
try:
|
||||
client = self.clients.get()
|
||||
self.serveClient(client)
|
||||
except Exception, x:
|
||||
logging.exception(x)
|
||||
except Exception as x:
|
||||
logger.exception(x)
|
||||
|
||||
def serveClient(self, client):
|
||||
"""Process input/output from a client for as long as possible"""
|
||||
|
@ -159,10 +162,10 @@ class TThreadPoolServer(TServer):
|
|||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
except TTransport.TTransportException:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
@ -175,16 +178,18 @@ class TThreadPoolServer(TServer):
|
|||
t.setDaemon(self.daemon)
|
||||
t.start()
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
# Pump the socket for clients
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
if not client:
|
||||
continue
|
||||
self.clients.put(client)
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
|
||||
class TForkingServer(TServer):
|
||||
|
@ -209,11 +214,13 @@ class TForkingServer(TServer):
|
|||
try:
|
||||
file.close()
|
||||
except IOError as e:
|
||||
logging.warning(e, exc_info=True)
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
client = self.serverTransport.accept()
|
||||
if not client:
|
||||
continue
|
||||
try:
|
||||
pid = os.fork()
|
||||
|
||||
|
@ -240,10 +247,10 @@ class TForkingServer(TServer):
|
|||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
except TTransport.TTransportException:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logger.exception(e)
|
||||
ecode = 1
|
||||
finally:
|
||||
try_close(itrans)
|
||||
|
@ -251,10 +258,10 @@ class TForkingServer(TServer):
|
|||
|
||||
os._exit(ecode)
|
||||
|
||||
except TTransport.TTransportException, tx:
|
||||
except TTransport.TTransportException:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
logger.exception(x)
|
||||
|
||||
def collect_children(self):
|
||||
while self.children:
|
||||
|
|
|
@ -17,17 +17,18 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
import httplib
|
||||
from io import BytesIO
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import urllib
|
||||
import urlparse
|
||||
import warnings
|
||||
import base64
|
||||
|
||||
from cStringIO import StringIO
|
||||
from six.moves import urllib
|
||||
from six.moves import http_client
|
||||
|
||||
from TTransport import *
|
||||
from .TTransport import TTransportBase
|
||||
import six
|
||||
|
||||
|
||||
class THttpClient(TTransportBase):
|
||||
|
@ -52,31 +53,64 @@ class THttpClient(TTransportBase):
|
|||
self.path = path
|
||||
self.scheme = 'http'
|
||||
else:
|
||||
parsed = urlparse.urlparse(uri_or_host)
|
||||
parsed = urllib.parse.urlparse(uri_or_host)
|
||||
self.scheme = parsed.scheme
|
||||
assert self.scheme in ('http', 'https')
|
||||
if self.scheme == 'http':
|
||||
self.port = parsed.port or httplib.HTTP_PORT
|
||||
self.port = parsed.port or http_client.HTTP_PORT
|
||||
elif self.scheme == 'https':
|
||||
self.port = parsed.port or httplib.HTTPS_PORT
|
||||
self.port = parsed.port or http_client.HTTPS_PORT
|
||||
self.host = parsed.hostname
|
||||
self.path = parsed.path
|
||||
if parsed.query:
|
||||
self.path += '?%s' % parsed.query
|
||||
self.__wbuf = StringIO()
|
||||
try:
|
||||
proxy = urllib.request.getproxies()[self.scheme]
|
||||
except KeyError:
|
||||
proxy = None
|
||||
else:
|
||||
if urllib.request.proxy_bypass(self.host):
|
||||
proxy = None
|
||||
if proxy:
|
||||
parsed = urllib.parse.urlparse(proxy)
|
||||
self.realhost = self.host
|
||||
self.realport = self.port
|
||||
self.host = parsed.hostname
|
||||
self.port = parsed.port
|
||||
self.proxy_auth = self.basic_proxy_auth_header(parsed)
|
||||
else:
|
||||
self.realhost = self.realport = self.proxy_auth = None
|
||||
self.__wbuf = BytesIO()
|
||||
self.__http = None
|
||||
self.__http_response = None
|
||||
self.__timeout = None
|
||||
self.__custom_headers = None
|
||||
|
||||
@staticmethod
|
||||
def basic_proxy_auth_header(proxy):
|
||||
if proxy is None or not proxy.username:
|
||||
return None
|
||||
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
|
||||
urllib.parse.unquote(proxy.password))
|
||||
cr = base64.b64encode(ap).strip()
|
||||
return "Basic " + cr
|
||||
|
||||
def using_proxy(self):
|
||||
return self.realhost is not None
|
||||
|
||||
def open(self):
|
||||
if self.scheme == 'http':
|
||||
self.__http = httplib.HTTP(self.host, self.port)
|
||||
else:
|
||||
self.__http = httplib.HTTPS(self.host, self.port)
|
||||
self.__http = http_client.HTTPConnection(self.host, self.port)
|
||||
elif self.scheme == 'https':
|
||||
self.__http = http_client.HTTPSConnection(self.host, self.port)
|
||||
if self.using_proxy():
|
||||
self.__http.set_tunnel(self.realhost, self.realport,
|
||||
{"Proxy-Authorization": self.proxy_auth})
|
||||
|
||||
def close(self):
|
||||
self.__http.close()
|
||||
self.__http = None
|
||||
self.__http_response = None
|
||||
|
||||
def isOpen(self):
|
||||
return self.__http is not None
|
||||
|
@ -94,7 +128,7 @@ class THttpClient(TTransportBase):
|
|||
self.__custom_headers = headers
|
||||
|
||||
def read(self, sz):
|
||||
return self.__http.file.read(sz)
|
||||
return self.__http_response.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
@ -103,7 +137,9 @@ class THttpClient(TTransportBase):
|
|||
def _f(*args, **kwargs):
|
||||
orig_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(args[0].__timeout)
|
||||
try:
|
||||
result = f(*args, **kwargs)
|
||||
finally:
|
||||
socket.setdefaulttimeout(orig_timeout)
|
||||
return result
|
||||
return _f
|
||||
|
@ -115,25 +151,31 @@ class THttpClient(TTransportBase):
|
|||
|
||||
# Pull data out of buffer
|
||||
data = self.__wbuf.getvalue()
|
||||
self.__wbuf = StringIO()
|
||||
self.__wbuf = BytesIO()
|
||||
|
||||
# HTTP request
|
||||
if self.using_proxy() and self.scheme == "http":
|
||||
# need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
|
||||
self.__http.putrequest('POST', "http://%s:%s%s" %
|
||||
(self.realhost, self.realport, self.path))
|
||||
else:
|
||||
self.__http.putrequest('POST', self.path)
|
||||
|
||||
# Write headers
|
||||
self.__http.putheader('Host', self.host)
|
||||
self.__http.putheader('Content-Type', 'application/x-thrift')
|
||||
self.__http.putheader('Content-Length', str(len(data)))
|
||||
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
|
||||
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
|
||||
|
||||
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
|
||||
user_agent = 'Python/THttpClient'
|
||||
script = os.path.basename(sys.argv[0])
|
||||
if script:
|
||||
user_agent = '%s (%s)' % (user_agent, urllib.quote(script))
|
||||
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
|
||||
self.__http.putheader('User-Agent', user_agent)
|
||||
|
||||
if self.__custom_headers:
|
||||
for key, val in self.__custom_headers.iteritems():
|
||||
for key, val in six.iteritems(self.__custom_headers):
|
||||
self.__http.putheader(key, val)
|
||||
|
||||
self.__http.endheaders()
|
||||
|
@ -142,7 +184,10 @@ class THttpClient(TTransportBase):
|
|||
self.__http.send(data)
|
||||
|
||||
# Get reply to flush the request
|
||||
self.code, self.message, self.headers = self.__http.getreply()
|
||||
self.__http_response = self.__http.getresponse()
|
||||
self.code = self.__http_response.status
|
||||
self.message = self.__http_response.reason
|
||||
self.headers = self.__http_response.msg
|
||||
|
||||
# Decorate if we know how to timeout
|
||||
if hasattr(socket, 'getdefaulttimeout'):
|
||||
|
|
|
@ -17,161 +17,341 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from .sslcompat import _match_hostname, _match_has_ipaddress
|
||||
from thrift.transport import TSocket
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
warnings.filterwarnings(
|
||||
'default', category=DeprecationWarning, module=__name__)
|
||||
|
||||
class TSSLSocket(TSocket.TSocket):
|
||||
|
||||
class TSSLBase(object):
|
||||
# SSLContext is not available for Python < 2.7.9
|
||||
_has_ssl_context = sys.hexversion >= 0x020709F0
|
||||
|
||||
# ciphers argument is not available for Python < 2.7.0
|
||||
_has_ciphers = sys.hexversion >= 0x020700F0
|
||||
|
||||
# For pythoon >= 2.7.9, use latest TLS that both client and server
|
||||
# supports.
|
||||
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
|
||||
# For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
|
||||
# unavailable.
|
||||
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
|
||||
ssl.PROTOCOL_TLSv1
|
||||
|
||||
def _init_context(self, ssl_version):
|
||||
if self._has_ssl_context:
|
||||
self._context = ssl.SSLContext(ssl_version)
|
||||
if self._context.protocol == ssl.PROTOCOL_SSLv23:
|
||||
self._context.options |= ssl.OP_NO_SSLv2
|
||||
self._context.options |= ssl.OP_NO_SSLv3
|
||||
else:
|
||||
self._context = None
|
||||
self._ssl_version = ssl_version
|
||||
|
||||
@property
|
||||
def _should_verify(self):
|
||||
if self._has_ssl_context:
|
||||
return self._context.verify_mode != ssl.CERT_NONE
|
||||
else:
|
||||
return self.cert_reqs != ssl.CERT_NONE
|
||||
|
||||
@property
|
||||
def ssl_version(self):
|
||||
if self._has_ssl_context:
|
||||
return self.ssl_context.protocol
|
||||
else:
|
||||
return self._ssl_version
|
||||
|
||||
@property
|
||||
def ssl_context(self):
|
||||
return self._context
|
||||
|
||||
SSL_VERSION = _default_protocol
|
||||
"""
|
||||
SSL implementation of client-side TSocket
|
||||
Default SSL version.
|
||||
For backword compatibility, it can be modified.
|
||||
Use __init__ keywoard argument "ssl_version" instead.
|
||||
"""
|
||||
|
||||
def _deprecated_arg(self, args, kwargs, pos, key):
|
||||
if len(args) <= pos:
|
||||
return
|
||||
real_pos = pos + 3
|
||||
warnings.warn(
|
||||
'%dth positional argument is deprecated.'
|
||||
'please use keyward argument insteand.'
|
||||
% real_pos, DeprecationWarning, stacklevel=3)
|
||||
|
||||
if key in kwargs:
|
||||
raise TypeError(
|
||||
'Duplicate argument: %dth argument and %s keyward argument.'
|
||||
% (real_pos, key))
|
||||
kwargs[key] = args[pos]
|
||||
|
||||
def _unix_socket_arg(self, host, port, args, kwargs):
|
||||
key = 'unix_socket'
|
||||
if host is None and port is None and len(args) == 1 and key not in kwargs:
|
||||
kwargs[key] = args[0]
|
||||
return True
|
||||
return False
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key == 'SSL_VERSION':
|
||||
warnings.warn(
|
||||
'SSL_VERSION is deprecated.'
|
||||
'please use ssl_version attribute instead.',
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.ssl_version
|
||||
|
||||
def __init__(self, server_side, host, ssl_opts):
|
||||
self._server_side = server_side
|
||||
if TSSLBase.SSL_VERSION != self._default_protocol:
|
||||
warnings.warn(
|
||||
'SSL_VERSION is deprecated.'
|
||||
'please use ssl_version keyward argument instead.',
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self._context = ssl_opts.pop('ssl_context', None)
|
||||
self._server_hostname = None
|
||||
if not self._server_side:
|
||||
self._server_hostname = ssl_opts.pop('server_hostname', host)
|
||||
if self._context:
|
||||
self._custom_context = True
|
||||
if ssl_opts:
|
||||
raise ValueError(
|
||||
'Incompatible arguments: ssl_context and %s'
|
||||
% ' '.join(ssl_opts.keys()))
|
||||
if not self._has_ssl_context:
|
||||
raise ValueError(
|
||||
'ssl_context is not available for this version of Python')
|
||||
else:
|
||||
self._custom_context = False
|
||||
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
|
||||
self._init_context(ssl_version)
|
||||
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
|
||||
self.ca_certs = ssl_opts.pop('ca_certs', None)
|
||||
self.keyfile = ssl_opts.pop('keyfile', None)
|
||||
self.certfile = ssl_opts.pop('certfile', None)
|
||||
self.ciphers = ssl_opts.pop('ciphers', None)
|
||||
|
||||
if ssl_opts:
|
||||
raise ValueError(
|
||||
'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
|
||||
|
||||
if self._should_verify:
|
||||
if not self.ca_certs:
|
||||
raise ValueError(
|
||||
'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
|
||||
if not os.access(self.ca_certs, os.R_OK):
|
||||
raise IOError('Certificate Authority ca_certs file "%s" '
|
||||
'is not readable, cannot validate SSL '
|
||||
'certificates.' % (self.ca_certs))
|
||||
|
||||
@property
|
||||
def certfile(self):
|
||||
return self._certfile
|
||||
|
||||
@certfile.setter
|
||||
def certfile(self, certfile):
|
||||
if self._server_side and not certfile:
|
||||
raise ValueError('certfile is needed for server-side')
|
||||
if certfile and not os.access(certfile, os.R_OK):
|
||||
raise IOError('No such certfile found: %s' % (certfile))
|
||||
self._certfile = certfile
|
||||
|
||||
def _wrap_socket(self, sock):
|
||||
if self._has_ssl_context:
|
||||
if not self._custom_context:
|
||||
self.ssl_context.verify_mode = self.cert_reqs
|
||||
if self.certfile:
|
||||
self.ssl_context.load_cert_chain(self.certfile,
|
||||
self.keyfile)
|
||||
if self.ciphers:
|
||||
self.ssl_context.set_ciphers(self.ciphers)
|
||||
if self.ca_certs:
|
||||
self.ssl_context.load_verify_locations(self.ca_certs)
|
||||
return self.ssl_context.wrap_socket(
|
||||
sock, server_side=self._server_side,
|
||||
server_hostname=self._server_hostname)
|
||||
else:
|
||||
ssl_opts = {
|
||||
'ssl_version': self._ssl_version,
|
||||
'server_side': self._server_side,
|
||||
'ca_certs': self.ca_certs,
|
||||
'keyfile': self.keyfile,
|
||||
'certfile': self.certfile,
|
||||
'cert_reqs': self.cert_reqs,
|
||||
}
|
||||
if self.ciphers:
|
||||
if self._has_ciphers:
|
||||
ssl_opts['ciphers'] = self.ciphers
|
||||
else:
|
||||
logger.warning(
|
||||
'ciphers is specified but ignored due to old Python version')
|
||||
return ssl.wrap_socket(sock, **ssl_opts)
|
||||
|
||||
|
||||
class TSSLSocket(TSocket.TSocket, TSSLBase):
|
||||
"""
|
||||
SSL implementation of TSocket
|
||||
|
||||
This class creates outbound sockets wrapped using the
|
||||
python standard ssl module for encrypted connections.
|
||||
|
||||
The protocol used is set using the class variable
|
||||
SSL_VERSION, which must be one of ssl.PROTOCOL_* and
|
||||
defaults to ssl.PROTOCOL_TLSv1 for greatest security.
|
||||
"""
|
||||
SSL_VERSION = ssl.PROTOCOL_TLSv1
|
||||
|
||||
def __init__(self,
|
||||
host='localhost',
|
||||
port=9090,
|
||||
validate=True,
|
||||
ca_certs=None,
|
||||
unix_socket=None):
|
||||
"""Create SSL TSocket
|
||||
# New signature
|
||||
# def __init__(self, host='localhost', port=9090, unix_socket=None,
|
||||
# **ssl_args):
|
||||
# Deprecated signature
|
||||
# def __init__(self, host='localhost', port=9090, validate=True,
|
||||
# ca_certs=None, keyfile=None, certfile=None,
|
||||
# unix_socket=None, ciphers=None):
|
||||
def __init__(self, host='localhost', port=9090, *args, **kwargs):
|
||||
"""Positional arguments: ``host``, ``port``, ``unix_socket``
|
||||
|
||||
@param validate: Set to False to disable SSL certificate validation
|
||||
@type validate: bool
|
||||
@param ca_certs: Filename to the Certificate Authority pem file, possibly a
|
||||
file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to
|
||||
the ssl_wrap function as the 'ca_certs' parameter.
|
||||
@type ca_certs: str
|
||||
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
|
||||
``ssl_version``, ``ca_certs``,
|
||||
``ciphers`` (Python 2.7.0 or later),
|
||||
``server_hostname`` (Python 2.7.9 or later)
|
||||
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
|
||||
|
||||
Raises an IOError exception if validate is True and the ca_certs file is
|
||||
None, not present or unreadable.
|
||||
Alternative keyword arguments: (Python 2.7.9 or later)
|
||||
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
|
||||
``server_hostname``: Passed to SSLContext.wrap_socket
|
||||
|
||||
Common keyword argument:
|
||||
``validate_callback`` (cert, hostname) -> None:
|
||||
Called after SSL handshake. Can raise when hostname does not
|
||||
match the cert.
|
||||
"""
|
||||
self.validate = validate
|
||||
self.is_valid = False
|
||||
self.peercert = None
|
||||
if not validate:
|
||||
self.cert_reqs = ssl.CERT_NONE
|
||||
else:
|
||||
self.cert_reqs = ssl.CERT_REQUIRED
|
||||
self.ca_certs = ca_certs
|
||||
if validate:
|
||||
if ca_certs is None or not os.access(ca_certs, os.R_OK):
|
||||
raise IOError('Certificate Authority ca_certs file "%s" '
|
||||
'is not readable, cannot validate SSL '
|
||||
'certificates.' % (ca_certs))
|
||||
|
||||
if args:
|
||||
if len(args) > 6:
|
||||
raise TypeError('Too many positional argument')
|
||||
if not self._unix_socket_arg(host, port, args, kwargs):
|
||||
self._deprecated_arg(args, kwargs, 0, 'validate')
|
||||
self._deprecated_arg(args, kwargs, 1, 'ca_certs')
|
||||
self._deprecated_arg(args, kwargs, 2, 'keyfile')
|
||||
self._deprecated_arg(args, kwargs, 3, 'certfile')
|
||||
self._deprecated_arg(args, kwargs, 4, 'unix_socket')
|
||||
self._deprecated_arg(args, kwargs, 5, 'ciphers')
|
||||
|
||||
validate = kwargs.pop('validate', None)
|
||||
if validate is not None:
|
||||
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
|
||||
warnings.warn(
|
||||
'validate is deprecated. please use cert_reqs=ssl.%s instead'
|
||||
% cert_reqs_name,
|
||||
DeprecationWarning, stacklevel=2)
|
||||
if 'cert_reqs' in kwargs:
|
||||
raise TypeError('Cannot specify both validate and cert_reqs')
|
||||
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
|
||||
|
||||
unix_socket = kwargs.pop('unix_socket', None)
|
||||
self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
|
||||
TSSLBase.__init__(self, False, host, kwargs)
|
||||
TSocket.TSocket.__init__(self, host, port, unix_socket)
|
||||
|
||||
@property
|
||||
def validate(self):
|
||||
warnings.warn('validate is deprecated. please use cert_reqs instead',
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return self.cert_reqs != ssl.CERT_NONE
|
||||
|
||||
@validate.setter
|
||||
def validate(self, value):
|
||||
warnings.warn('validate is deprecated. please use cert_reqs instead',
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
|
||||
|
||||
def _do_open(self, family, socktype):
|
||||
plain_sock = socket.socket(family, socktype)
|
||||
try:
|
||||
return self._wrap_socket(plain_sock)
|
||||
except Exception:
|
||||
plain_sock.close()
|
||||
msg = 'failed to initialize SSL'
|
||||
logger.exception(msg)
|
||||
raise TTransportException(TTransportException.NOT_OPEN, msg)
|
||||
|
||||
def open(self):
|
||||
super(TSSLSocket, self).open()
|
||||
if self._should_verify:
|
||||
self.peercert = self.handle.getpeercert()
|
||||
try:
|
||||
res0 = self._resolveAddr()
|
||||
for res in res0:
|
||||
sock_family, sock_type = res[0:2]
|
||||
ip_port = res[4]
|
||||
plain_sock = socket.socket(sock_family, sock_type)
|
||||
self.handle = ssl.wrap_socket(plain_sock,
|
||||
ssl_version=self.SSL_VERSION,
|
||||
do_handshake_on_connect=True,
|
||||
ca_certs=self.ca_certs,
|
||||
cert_reqs=self.cert_reqs)
|
||||
self.handle.settimeout(self._timeout)
|
||||
try:
|
||||
self.handle.connect(ip_port)
|
||||
except socket.error as e:
|
||||
if res is not res0[-1]:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
break
|
||||
except socket.error as e:
|
||||
if self._unix_socket:
|
||||
message = 'Could not connect to secure socket %s' % self._unix_socket
|
||||
else:
|
||||
message = 'Could not connect to %s:%d' % (self.host, self.port)
|
||||
raise TTransportException(type=TTransportException.NOT_OPEN,
|
||||
message=message)
|
||||
if self.validate:
|
||||
self._validate_cert()
|
||||
|
||||
def _validate_cert(self):
|
||||
"""internal method to validate the peer's SSL certificate, and to check the
|
||||
commonName of the certificate to ensure it matches the hostname we
|
||||
used to make this connection. Does not support subjectAltName records
|
||||
in certificates.
|
||||
|
||||
raises TTransportException if the certificate fails validation.
|
||||
"""
|
||||
cert = self.handle.getpeercert()
|
||||
self.peercert = cert
|
||||
if 'subject' not in cert:
|
||||
raise TTransportException(
|
||||
type=TTransportException.NOT_OPEN,
|
||||
message='No SSL certificate found from %s:%s' % (self.host, self.port))
|
||||
fields = cert['subject']
|
||||
for field in fields:
|
||||
# ensure structure we get back is what we expect
|
||||
if not isinstance(field, tuple):
|
||||
continue
|
||||
cert_pair = field[0]
|
||||
if len(cert_pair) < 2:
|
||||
continue
|
||||
cert_key, cert_value = cert_pair[0:2]
|
||||
if cert_key != 'commonName':
|
||||
continue
|
||||
certhost = cert_value
|
||||
if certhost == self.host:
|
||||
# success, cert commonName matches desired hostname
|
||||
self._validate_callback(self.peercert, self._server_hostname)
|
||||
self.is_valid = True
|
||||
return
|
||||
else:
|
||||
raise TTransportException(
|
||||
type=TTransportException.UNKNOWN,
|
||||
message='Hostname we connected to "%s" doesn\'t match certificate '
|
||||
'provided commonName "%s"' % (self.host, certhost))
|
||||
raise TTransportException(
|
||||
type=TTransportException.UNKNOWN,
|
||||
message='Could not validate SSL certificate from '
|
||||
'host "%s". Cert=%s' % (self.host, cert))
|
||||
except TTransportException:
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise TTransportException(TTransportException.UNKNOWN, str(ex))
|
||||
|
||||
|
||||
class TSSLServerSocket(TSocket.TServerSocket):
|
||||
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
|
||||
"""SSL implementation of TServerSocket
|
||||
|
||||
This uses the ssl module's wrap_socket() method to provide SSL
|
||||
negotiated encryption.
|
||||
"""
|
||||
SSL_VERSION = ssl.PROTOCOL_TLSv1
|
||||
|
||||
def __init__(self,
|
||||
host=None,
|
||||
port=9090,
|
||||
certfile='cert.pem',
|
||||
unix_socket=None):
|
||||
"""Initialize a TSSLServerSocket
|
||||
# New signature
|
||||
# def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
|
||||
# Deprecated signature
|
||||
# def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
|
||||
def __init__(self, host=None, port=9090, *args, **kwargs):
|
||||
"""Positional arguments: ``host``, ``port``, ``unix_socket``
|
||||
|
||||
@param certfile: filename of the server certificate, defaults to cert.pem
|
||||
@type certfile: str
|
||||
@param host: The hostname or IP to bind the listen socket to,
|
||||
i.e. 'localhost' for only allowing local network connections.
|
||||
Pass None to bind to all interfaces.
|
||||
@type host: str
|
||||
@param port: The port to listen on for inbound connections.
|
||||
@type port: int
|
||||
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
|
||||
``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
|
||||
See ssl.wrap_socket documentation.
|
||||
|
||||
Alternative keyword arguments: (Python 2.7.9 or later)
|
||||
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
|
||||
``server_hostname``: Passed to SSLContext.wrap_socket
|
||||
|
||||
Common keyword argument:
|
||||
``validate_callback`` (cert, hostname) -> None:
|
||||
Called after SSL handshake. Can raise when hostname does not
|
||||
match the cert.
|
||||
"""
|
||||
self.setCertfile(certfile)
|
||||
TSocket.TServerSocket.__init__(self, host, port)
|
||||
if args:
|
||||
if len(args) > 3:
|
||||
raise TypeError('Too many positional argument')
|
||||
if not self._unix_socket_arg(host, port, args, kwargs):
|
||||
self._deprecated_arg(args, kwargs, 0, 'certfile')
|
||||
self._deprecated_arg(args, kwargs, 1, 'unix_socket')
|
||||
self._deprecated_arg(args, kwargs, 2, 'ciphers')
|
||||
|
||||
if 'ssl_context' not in kwargs:
|
||||
# Preserve existing behaviors for default values
|
||||
if 'cert_reqs' not in kwargs:
|
||||
kwargs['cert_reqs'] = ssl.CERT_NONE
|
||||
if'certfile' not in kwargs:
|
||||
kwargs['certfile'] = 'cert.pem'
|
||||
|
||||
unix_socket = kwargs.pop('unix_socket', None)
|
||||
self._validate_callback = \
|
||||
kwargs.pop('validate_callback', _match_hostname)
|
||||
TSSLBase.__init__(self, True, None, kwargs)
|
||||
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
|
||||
if self._should_verify and not _match_has_ipaddress:
|
||||
raise ValueError('Need ipaddress and backports.ssl_match_hostname '
|
||||
'module to verify client certificate')
|
||||
|
||||
def setCertfile(self, certfile):
|
||||
"""Set or change the server certificate file used to wrap new connections.
|
||||
"""Set or change the server certificate file used to wrap new
|
||||
connections.
|
||||
|
||||
@param certfile: The filename of the server certificate,
|
||||
i.e. '/etc/certs/server.pem'
|
||||
|
@ -179,24 +359,38 @@ class TSSLServerSocket(TSocket.TServerSocket):
|
|||
|
||||
Raises an IOError exception if the certfile is not present or unreadable.
|
||||
"""
|
||||
if not os.access(certfile, os.R_OK):
|
||||
raise IOError('No such certfile found: %s' % (certfile))
|
||||
warnings.warn(
|
||||
'setCertfile is deprecated. please use certfile property instead.',
|
||||
DeprecationWarning, stacklevel=2)
|
||||
self.certfile = certfile
|
||||
|
||||
def accept(self):
|
||||
plain_client, addr = self.handle.accept()
|
||||
try:
|
||||
client = ssl.wrap_socket(plain_client, certfile=self.certfile,
|
||||
server_side=True, ssl_version=self.SSL_VERSION)
|
||||
except ssl.SSLError as ssl_exc:
|
||||
client = self._wrap_socket(plain_client)
|
||||
except ssl.SSLError:
|
||||
logger.exception('Error while accepting from %s', addr)
|
||||
# failed handshake/ssl wrap, close socket to client
|
||||
plain_client.close()
|
||||
# raise ssl_exc
|
||||
# raise
|
||||
# We can't raise the exception, because it kills most TServer derived
|
||||
# serve() methods.
|
||||
# Instead, return None, and let the TServer instance deal with it in
|
||||
# other exception handling. (but TSimpleServer dies anyway)
|
||||
return None
|
||||
|
||||
if self._should_verify:
|
||||
client.peercert = client.getpeercert()
|
||||
try:
|
||||
self._validate_callback(client.peercert, addr[0])
|
||||
client.is_valid = True
|
||||
except Exception:
|
||||
logger.warn('Failed to validate client certificate address: %s',
|
||||
addr[0], exc_info=True)
|
||||
client.close()
|
||||
plain_client.close()
|
||||
return None
|
||||
|
||||
result = TSocket.TSocket()
|
||||
result.setHandle(client)
|
||||
result.handle = client
|
||||
return result
|
||||
|
|
|
@ -18,11 +18,14 @@
|
|||
#
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from TTransport import *
|
||||
from .TTransport import TTransportBase, TTransportException, TServerTransportBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TSocketBase(TTransportBase):
|
||||
|
@ -33,7 +36,7 @@ class TSocketBase(TTransportBase):
|
|||
else:
|
||||
return socket.getaddrinfo(self.host,
|
||||
self.port,
|
||||
socket.AF_UNSPEC,
|
||||
self._socket_family,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
|
||||
|
@ -47,19 +50,21 @@ class TSocketBase(TTransportBase):
|
|||
class TSocket(TSocketBase):
|
||||
"""Socket implementation of TTransport base."""
|
||||
|
||||
def __init__(self, host='localhost', port=9090, unix_socket=None):
|
||||
def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
|
||||
"""Initialize a TSocket
|
||||
|
||||
@param host(str) The host to connect to.
|
||||
@param port(int) The (TCP) port to connect to.
|
||||
@param unix_socket(str) The filename of a unix socket to connect to.
|
||||
(host and port will be ignored.)
|
||||
@param socket_family(int) The socket family to use with this socket.
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.handle = None
|
||||
self._unix_socket = unix_socket
|
||||
self._timeout = None
|
||||
self._socket_family = socket_family
|
||||
|
||||
def setHandle(self, h):
|
||||
self.handle = h
|
||||
|
@ -76,32 +81,41 @@ class TSocket(TSocketBase):
|
|||
if self.handle is not None:
|
||||
self.handle.settimeout(self._timeout)
|
||||
|
||||
def _do_open(self, family, socktype):
|
||||
return socket.socket(family, socktype)
|
||||
|
||||
@property
|
||||
def _address(self):
|
||||
return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)
|
||||
|
||||
def open(self):
|
||||
if self.handle:
|
||||
raise TTransportException(TTransportException.ALREADY_OPEN)
|
||||
try:
|
||||
res0 = self._resolveAddr()
|
||||
for res in res0:
|
||||
self.handle = socket.socket(res[0], res[1])
|
||||
self.handle.settimeout(self._timeout)
|
||||
addrs = self._resolveAddr()
|
||||
except socket.gaierror:
|
||||
msg = 'failed to resolve sockaddr for ' + str(self._address)
|
||||
logger.exception(msg)
|
||||
raise TTransportException(TTransportException.NOT_OPEN, msg)
|
||||
for family, socktype, _, _, sockaddr in addrs:
|
||||
handle = self._do_open(family, socktype)
|
||||
handle.settimeout(self._timeout)
|
||||
try:
|
||||
self.handle.connect(res[4])
|
||||
except socket.error, e:
|
||||
if res is not res0[-1]:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
break
|
||||
except socket.error, e:
|
||||
if self._unix_socket:
|
||||
message = 'Could not connect to socket %s' % self._unix_socket
|
||||
else:
|
||||
message = 'Could not connect to %s:%d' % (self.host, self.port)
|
||||
raise TTransportException(type=TTransportException.NOT_OPEN,
|
||||
message=message)
|
||||
handle.connect(sockaddr)
|
||||
self.handle = handle
|
||||
return
|
||||
except socket.error:
|
||||
handle.close()
|
||||
logger.info('Could not connect to %s', sockaddr, exc_info=True)
|
||||
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
|
||||
addrs))
|
||||
logger.error(msg)
|
||||
raise TTransportException(TTransportException.NOT_OPEN, msg)
|
||||
|
||||
def read(self, sz):
|
||||
try:
|
||||
buff = self.handle.recv(sz)
|
||||
except socket.error, e:
|
||||
except socket.error as e:
|
||||
if (e.args[0] == errno.ECONNRESET and
|
||||
(sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
|
||||
# freebsd and Mach don't follow POSIX semantic of recv
|
||||
|
@ -139,16 +153,18 @@ class TSocket(TSocketBase):
|
|||
class TServerSocket(TSocketBase, TServerTransportBase):
|
||||
"""Socket implementation of TServerTransport base."""
|
||||
|
||||
def __init__(self, host=None, port=9090, unix_socket=None):
|
||||
def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._unix_socket = unix_socket
|
||||
self._socket_family = socket_family
|
||||
self.handle = None
|
||||
|
||||
def listen(self):
|
||||
res0 = self._resolveAddr()
|
||||
socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
|
||||
for res in res0:
|
||||
if res[0] is socket.AF_INET6 or res is res0[-1]:
|
||||
if res[0] is socket_family or res is res0[-1]:
|
||||
break
|
||||
|
||||
# We need remove the old unix socket if the file exists and
|
||||
|
@ -157,7 +173,7 @@ class TServerSocket(TSocketBase, TServerTransportBase):
|
|||
tmp = socket.socket(res[0], res[1])
|
||||
try:
|
||||
tmp.connect(res[4])
|
||||
except socket.error, err:
|
||||
except socket.error as err:
|
||||
eno, message = err.args
|
||||
if eno == errno.ECONNREFUSED:
|
||||
os.unlink(res[4])
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from cStringIO import StringIO
|
||||
from struct import pack, unpack
|
||||
from thrift.Thrift import TException
|
||||
from ..compat import BufferIO
|
||||
|
||||
|
||||
class TTransportException(TException):
|
||||
|
@ -30,13 +30,15 @@ class TTransportException(TException):
|
|||
ALREADY_OPEN = 2
|
||||
TIMED_OUT = 3
|
||||
END_OF_FILE = 4
|
||||
NEGATIVE_SIZE = 5
|
||||
SIZE_LIMIT = 6
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
self.type = type
|
||||
|
||||
|
||||
class TTransportBase:
|
||||
class TTransportBase(object):
|
||||
"""Base class for Thrift transport layer."""
|
||||
|
||||
def isOpen(self):
|
||||
|
@ -52,7 +54,7 @@ class TTransportBase:
|
|||
pass
|
||||
|
||||
def readAll(self, sz):
|
||||
buff = ''
|
||||
buff = b''
|
||||
have = 0
|
||||
while (have < sz):
|
||||
chunk = self.read(sz - have)
|
||||
|
@ -72,7 +74,7 @@ class TTransportBase:
|
|||
|
||||
|
||||
# This class should be thought of as an interface.
|
||||
class CReadableTransport:
|
||||
class CReadableTransport(object):
|
||||
"""base class for transports that are readable from C"""
|
||||
|
||||
# TODO(dreiss): Think about changing this interface to allow us to use
|
||||
|
@ -100,7 +102,7 @@ class CReadableTransport:
|
|||
pass
|
||||
|
||||
|
||||
class TServerTransportBase:
|
||||
class TServerTransportBase(object):
|
||||
"""Base class for Thrift server transports."""
|
||||
|
||||
def listen(self):
|
||||
|
@ -113,14 +115,14 @@ class TServerTransportBase:
|
|||
pass
|
||||
|
||||
|
||||
class TTransportFactoryBase:
|
||||
class TTransportFactoryBase(object):
|
||||
"""Base class for a Transport Factory"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
return trans
|
||||
|
||||
|
||||
class TBufferedTransportFactory:
|
||||
class TBufferedTransportFactory(object):
|
||||
"""Factory transport that builds buffered transports"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
|
@ -138,8 +140,9 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
|
|||
|
||||
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
|
||||
self.__trans = trans
|
||||
self.__wbuf = StringIO()
|
||||
self.__rbuf = StringIO("")
|
||||
self.__wbuf = BufferIO()
|
||||
# Pass string argument to initialize read buffer as cStringIO.InputType
|
||||
self.__rbuf = BufferIO(b'')
|
||||
self.__rbuf_size = rbuf_size
|
||||
|
||||
def isOpen(self):
|
||||
|
@ -155,17 +158,22 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
|
|||
ret = self.__rbuf.read(sz)
|
||||
if len(ret) != 0:
|
||||
return ret
|
||||
|
||||
self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size)))
|
||||
self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
|
||||
return self.__rbuf.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
try:
|
||||
self.__wbuf.write(buf)
|
||||
except Exception as e:
|
||||
# on exception reset wbuf so it doesn't contain a partial function call
|
||||
self.__wbuf = BufferIO()
|
||||
raise e
|
||||
self.__wbuf.getvalue()
|
||||
|
||||
def flush(self):
|
||||
out = self.__wbuf.getvalue()
|
||||
# reset wbuf before write/flush to preserve state on underlying failure
|
||||
self.__wbuf = StringIO()
|
||||
self.__wbuf = BufferIO()
|
||||
self.__trans.write(out)
|
||||
self.__trans.flush()
|
||||
|
||||
|
@ -184,12 +192,12 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
|
|||
if len(retstring) < reqlen:
|
||||
retstring += self.__trans.readAll(reqlen - len(retstring))
|
||||
|
||||
self.__rbuf = StringIO(retstring)
|
||||
self.__rbuf = BufferIO(retstring)
|
||||
return self.__rbuf
|
||||
|
||||
|
||||
class TMemoryBuffer(TTransportBase, CReadableTransport):
|
||||
"""Wraps a cStringIO object as a TTransport.
|
||||
"""Wraps a cBytesIO object as a TTransport.
|
||||
|
||||
NOTE: Unlike the C++ version of this class, you cannot write to it
|
||||
then immediately read from it. If you want to read from a
|
||||
|
@ -203,9 +211,9 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
|
|||
If value is set, this will be a transport for reading,
|
||||
otherwise, it is for writing"""
|
||||
if value is not None:
|
||||
self._buffer = StringIO(value)
|
||||
self._buffer = BufferIO(value)
|
||||
else:
|
||||
self._buffer = StringIO()
|
||||
self._buffer = BufferIO()
|
||||
|
||||
def isOpen(self):
|
||||
return not self._buffer.closed
|
||||
|
@ -238,7 +246,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
|
|||
raise EOFError()
|
||||
|
||||
|
||||
class TFramedTransportFactory:
|
||||
class TFramedTransportFactory(object):
|
||||
"""Factory transport that builds framed transports"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
|
@ -251,8 +259,8 @@ class TFramedTransport(TTransportBase, CReadableTransport):
|
|||
|
||||
def __init__(self, trans,):
|
||||
self.__trans = trans
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
self.__rbuf = BufferIO(b'')
|
||||
self.__wbuf = BufferIO()
|
||||
|
||||
def isOpen(self):
|
||||
return self.__trans.isOpen()
|
||||
|
@ -274,7 +282,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
|
|||
def readFrame(self):
|
||||
buff = self.__trans.readAll(4)
|
||||
sz, = unpack('!i', buff)
|
||||
self.__rbuf = StringIO(self.__trans.readAll(sz))
|
||||
self.__rbuf = BufferIO(self.__trans.readAll(sz))
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
@ -283,7 +291,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
|
|||
wout = self.__wbuf.getvalue()
|
||||
wsz = len(wout)
|
||||
# reset wbuf before write/flush to preserve state on underlying failure
|
||||
self.__wbuf = StringIO()
|
||||
self.__wbuf = BufferIO()
|
||||
# N.B.: Doing this string concatenation is WAY cheaper than making
|
||||
# two separate calls to the underlying socket object. Socket writes in
|
||||
# Python turn out to be REALLY expensive, but it seems to do a pretty
|
||||
|
@ -304,7 +312,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
|
|||
while len(prefix) < reqlen:
|
||||
self.readFrame()
|
||||
prefix += self.__rbuf.getvalue()
|
||||
self.__rbuf = StringIO(prefix)
|
||||
self.__rbuf = BufferIO(prefix)
|
||||
return self.__rbuf
|
||||
|
||||
|
||||
|
@ -328,3 +336,117 @@ class TFileObjectTransport(TTransportBase):
|
|||
|
||||
def flush(self):
|
||||
self.fileobj.flush()
|
||||
|
||||
|
||||
class TSaslClientTransport(TTransportBase, CReadableTransport):
|
||||
"""
|
||||
SASL transport
|
||||
"""
|
||||
|
||||
START = 1
|
||||
OK = 2
|
||||
BAD = 3
|
||||
ERROR = 4
|
||||
COMPLETE = 5
|
||||
|
||||
def __init__(self, transport, host, service, mechanism='GSSAPI',
|
||||
**sasl_kwargs):
|
||||
"""
|
||||
transport: an underlying transport to use, typically just a TSocket
|
||||
host: the name of the server, from a SASL perspective
|
||||
service: the name of the server's service, from a SASL perspective
|
||||
mechanism: the name of the preferred mechanism to use
|
||||
|
||||
All other kwargs will be passed to the puresasl.client.SASLClient
|
||||
constructor.
|
||||
"""
|
||||
|
||||
from puresasl.client import SASLClient
|
||||
|
||||
self.transport = transport
|
||||
self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
|
||||
|
||||
self.__wbuf = BufferIO()
|
||||
self.__rbuf = BufferIO(b'')
|
||||
|
||||
def open(self):
|
||||
if not self.transport.isOpen():
|
||||
self.transport.open()
|
||||
|
||||
self.send_sasl_msg(self.START, self.sasl.mechanism)
|
||||
self.send_sasl_msg(self.OK, self.sasl.process())
|
||||
|
||||
while True:
|
||||
status, challenge = self.recv_sasl_msg()
|
||||
if status == self.OK:
|
||||
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
|
||||
elif status == self.COMPLETE:
|
||||
if not self.sasl.complete:
|
||||
raise TTransportException(
|
||||
TTransportException.NOT_OPEN,
|
||||
"The server erroneously indicated "
|
||||
"that SASL negotiation was complete")
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise TTransportException(
|
||||
TTransportException.NOT_OPEN,
|
||||
"Bad SASL negotiation status: %d (%s)"
|
||||
% (status, challenge))
|
||||
|
||||
def send_sasl_msg(self, status, body):
|
||||
header = pack(">BI", status, len(body))
|
||||
self.transport.write(header + body)
|
||||
self.transport.flush()
|
||||
|
||||
def recv_sasl_msg(self):
|
||||
header = self.transport.readAll(5)
|
||||
status, length = unpack(">BI", header)
|
||||
if length > 0:
|
||||
payload = self.transport.readAll(length)
|
||||
else:
|
||||
payload = ""
|
||||
return status, payload
|
||||
|
||||
def write(self, data):
|
||||
self.__wbuf.write(data)
|
||||
|
||||
def flush(self):
|
||||
data = self.__wbuf.getvalue()
|
||||
encoded = self.sasl.wrap(data)
|
||||
self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
|
||||
self.transport.flush()
|
||||
self.__wbuf = BufferIO()
|
||||
|
||||
def read(self, sz):
|
||||
ret = self.__rbuf.read(sz)
|
||||
if len(ret) != 0:
|
||||
return ret
|
||||
|
||||
self._read_frame()
|
||||
return self.__rbuf.read(sz)
|
||||
|
||||
def _read_frame(self):
|
||||
header = self.transport.readAll(4)
|
||||
length, = unpack('!i', header)
|
||||
encoded = self.transport.readAll(length)
|
||||
self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
|
||||
|
||||
def close(self):
|
||||
self.sasl.dispose()
|
||||
self.transport.close()
|
||||
|
||||
# based on TFramedTransport
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
return self.__rbuf
|
||||
|
||||
def cstringio_refill(self, prefix, reqlen):
|
||||
# self.__rbuf will already be empty here because fastbinary doesn't
|
||||
# ask for a refill until the previous buffer is empty. Therefore,
|
||||
# we can start reading new frames immediately.
|
||||
while len(prefix) < reqlen:
|
||||
self._read_frame()
|
||||
prefix += self.__rbuf.getvalue()
|
||||
self.__rbuf = BufferIO(prefix)
|
||||
return self.__rbuf
|
||||
|
|
|
@ -17,14 +17,15 @@
|
|||
# under the License.
|
||||
#
|
||||
|
||||
from cStringIO import StringIO
|
||||
from io import BytesIO
|
||||
import struct
|
||||
|
||||
from zope.interface import implements, Interface, Attribute
|
||||
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
|
||||
from twisted.internet.protocol import ServerFactory, ClientFactory, \
|
||||
connectionDone
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.threads import deferToThread
|
||||
from twisted.protocols import basic
|
||||
from twisted.python import log
|
||||
from twisted.web import server, resource, http
|
||||
|
||||
from thrift.transport import TTransport
|
||||
|
@ -33,15 +34,15 @@ from thrift.transport import TTransport
|
|||
class TMessageSenderTransport(TTransport.TTransportBase):
|
||||
|
||||
def __init__(self):
|
||||
self.__wbuf = StringIO()
|
||||
self.__wbuf = BytesIO()
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
msg = self.__wbuf.getvalue()
|
||||
self.__wbuf = StringIO()
|
||||
self.sendMessage(msg)
|
||||
self.__wbuf = BytesIO()
|
||||
return self.sendMessage(msg)
|
||||
|
||||
def sendMessage(self, message):
|
||||
raise NotImplementedError
|
||||
|
@ -54,7 +55,7 @@ class TCallbackTransport(TMessageSenderTransport):
|
|||
self.func = func
|
||||
|
||||
def sendMessage(self, message):
|
||||
self.func(message)
|
||||
return self.func(message)
|
||||
|
||||
|
||||
class ThriftClientProtocol(basic.Int32StringReceiver):
|
||||
|
@ -81,11 +82,18 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
|
|||
self.started.callback(self.client)
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
for k, v in self.client._reqs.iteritems():
|
||||
# the called errbacks can add items to our client's _reqs,
|
||||
# so we need to use a tmp, and iterate until no more requests
|
||||
# are added during errbacks
|
||||
if self.client:
|
||||
tex = TTransport.TTransportException(
|
||||
type=TTransport.TTransportException.END_OF_FILE,
|
||||
message='Connection closed')
|
||||
message='Connection closed (%s)' % reason)
|
||||
while self.client._reqs:
|
||||
_, v = self.client._reqs.popitem()
|
||||
v.errback(tex)
|
||||
del self.client._reqs
|
||||
self.client = None
|
||||
|
||||
def stringReceived(self, frame):
|
||||
tr = TTransport.TMemoryBuffer(frame)
|
||||
|
@ -101,6 +109,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
|
|||
method(iprot, mtype, rseqid)
|
||||
|
||||
|
||||
class ThriftSASLClientProtocol(ThriftClientProtocol):
|
||||
|
||||
START = 1
|
||||
OK = 2
|
||||
BAD = 3
|
||||
ERROR = 4
|
||||
COMPLETE = 5
|
||||
|
||||
MAX_LENGTH = 2 ** 31 - 1
|
||||
|
||||
def __init__(self, client_class, iprot_factory, oprot_factory=None,
|
||||
host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
|
||||
"""
|
||||
host: the name of the server, from a SASL perspective
|
||||
service: the name of the server's service, from a SASL perspective
|
||||
mechanism: the name of the preferred mechanism to use
|
||||
|
||||
All other kwargs will be passed to the puresasl.client.SASLClient
|
||||
constructor.
|
||||
"""
|
||||
|
||||
from puresasl.client import SASLClient
|
||||
self.SASLCLient = SASLClient
|
||||
|
||||
ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
|
||||
|
||||
self._sasl_negotiation_deferred = None
|
||||
self._sasl_negotiation_status = None
|
||||
self.client = None
|
||||
|
||||
if host is not None:
|
||||
self.createSASLClient(host, service, mechanism, **sasl_kwargs)
|
||||
|
||||
def createSASLClient(self, host, service, mechanism, **kwargs):
|
||||
self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
|
||||
|
||||
def dispatch(self, msg):
|
||||
encoded = self.sasl.wrap(msg)
|
||||
len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
|
||||
ThriftClientProtocol.dispatch(self, len_and_encoded)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connectionMade(self):
|
||||
self._sendSASLMessage(self.START, self.sasl.mechanism)
|
||||
initial_message = yield deferToThread(self.sasl.process)
|
||||
self._sendSASLMessage(self.OK, initial_message)
|
||||
|
||||
while True:
|
||||
status, challenge = yield self._receiveSASLMessage()
|
||||
if status == self.OK:
|
||||
response = yield deferToThread(self.sasl.process, challenge)
|
||||
self._sendSASLMessage(self.OK, response)
|
||||
elif status == self.COMPLETE:
|
||||
if not self.sasl.complete:
|
||||
msg = "The server erroneously indicated that SASL " \
|
||||
"negotiation was complete"
|
||||
raise TTransport.TTransportException(msg, message=msg)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
|
||||
raise TTransport.TTransportException(msg, message=msg)
|
||||
|
||||
self._sasl_negotiation_deferred = None
|
||||
ThriftClientProtocol.connectionMade(self)
|
||||
|
||||
def _sendSASLMessage(self, status, body):
|
||||
if body is None:
|
||||
body = ""
|
||||
header = struct.pack(">BI", status, len(body))
|
||||
self.transport.write(header + body)
|
||||
|
||||
def _receiveSASLMessage(self):
|
||||
self._sasl_negotiation_deferred = defer.Deferred()
|
||||
self._sasl_negotiation_status = None
|
||||
return self._sasl_negotiation_deferred
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
if self.client:
|
||||
ThriftClientProtocol.connectionLost(self, reason)
|
||||
|
||||
def dataReceived(self, data):
|
||||
if self._sasl_negotiation_deferred:
|
||||
# we got a sasl challenge in the format (status, length, challenge)
|
||||
# save the status, let IntNStringReceiver piece the challenge data together
|
||||
self._sasl_negotiation_status, = struct.unpack("B", data[0])
|
||||
ThriftClientProtocol.dataReceived(self, data[1:])
|
||||
else:
|
||||
# normal frame, let IntNStringReceiver piece it together
|
||||
ThriftClientProtocol.dataReceived(self, data)
|
||||
|
||||
def stringReceived(self, frame):
|
||||
if self._sasl_negotiation_deferred:
|
||||
# the frame is just a SASL challenge
|
||||
response = (self._sasl_negotiation_status, frame)
|
||||
self._sasl_negotiation_deferred.callback(response)
|
||||
else:
|
||||
# there's a second 4 byte length prefix inside the frame
|
||||
decoded_frame = self.sasl.unwrap(frame[4:])
|
||||
ThriftClientProtocol.stringReceived(self, decoded_frame)
|
||||
|
||||
|
||||
class ThriftServerProtocol(basic.Int32StringReceiver):
|
||||
|
||||
MAX_LENGTH = 2 ** 31 - 1
|
||||
|
|
|
@ -24,8 +24,8 @@ data compression.
|
|||
|
||||
from __future__ import division
|
||||
import zlib
|
||||
from cStringIO import StringIO
|
||||
from TTransport import TTransportBase, CReadableTransport
|
||||
from .TTransport import TTransportBase, CReadableTransport
|
||||
from ..compat import BufferIO
|
||||
|
||||
|
||||
class TZlibTransportFactory(object):
|
||||
|
@ -88,8 +88,8 @@ class TZlibTransport(TTransportBase, CReadableTransport):
|
|||
"""
|
||||
self.__trans = trans
|
||||
self.compresslevel = compresslevel
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
self.__rbuf = BufferIO()
|
||||
self.__wbuf = BufferIO()
|
||||
self._init_zlib()
|
||||
self._init_stats()
|
||||
|
||||
|
@ -97,8 +97,8 @@ class TZlibTransport(TTransportBase, CReadableTransport):
|
|||
"""Internal method to initialize/reset the internal StringIO objects
|
||||
for read and write buffers.
|
||||
"""
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
self.__rbuf = BufferIO()
|
||||
self.__wbuf = BufferIO()
|
||||
|
||||
def _init_stats(self):
|
||||
"""Internal method to reset the internal statistics counters
|
||||
|
@ -203,7 +203,7 @@ class TZlibTransport(TTransportBase, CReadableTransport):
|
|||
self.bytes_in += len(zbuf)
|
||||
self.bytes_in_comp += len(buf)
|
||||
old = self.__rbuf.read()
|
||||
self.__rbuf = StringIO(old + buf)
|
||||
self.__rbuf = BufferIO(old + buf)
|
||||
if len(old) + len(buf) == 0:
|
||||
return False
|
||||
return True
|
||||
|
@ -228,7 +228,7 @@ class TZlibTransport(TTransportBase, CReadableTransport):
|
|||
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
|
||||
self.bytes_out_comp += len(ztail)
|
||||
if (len(zbuf) + len(ztail)) > 0:
|
||||
self.__wbuf = StringIO()
|
||||
self.__wbuf = BufferIO()
|
||||
self.__trans.write(zbuf + ztail)
|
||||
self.__trans.flush()
|
||||
|
||||
|
@ -244,5 +244,5 @@ class TZlibTransport(TTransportBase, CReadableTransport):
|
|||
retstring += self.read(self.DEFAULT_BUFFSIZE)
|
||||
while len(retstring) < reqlen:
|
||||
retstring += self.read(reqlen - len(retstring))
|
||||
self.__rbuf = StringIO(retstring)
|
||||
self.__rbuf = BufferIO(retstring)
|
||||
return self.__rbuf
|
||||
|
|
99
thrift/transport/sslcompat.py
Normal file
99
thrift/transport/sslcompat.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
#
|
||||
# licensed to the apache software foundation (asf) under one
|
||||
# or more contributor license agreements. see the notice file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. the asf licenses this file
|
||||
# to you under the apache license, version 2.0 (the
|
||||
# "license"); you may not use this file except in compliance
|
||||
# with the license. you may obtain a copy of the license at
|
||||
#
|
||||
# http://www.apache.org/licenses/license-2.0
|
||||
#
|
||||
# unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the license is distributed on an
|
||||
# "as is" basis, without warranties or conditions of any
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def legacy_validate_callback(self, cert, hostname):
|
||||
"""legacy method to validate the peer's SSL certificate, and to check
|
||||
the commonName of the certificate to ensure it matches the hostname we
|
||||
used to make this connection. Does not support subjectAltName records
|
||||
in certificates.
|
||||
|
||||
raises TTransportException if the certificate fails validation.
|
||||
"""
|
||||
if 'subject' not in cert:
|
||||
raise TTransportException(
|
||||
TTransportException.NOT_OPEN,
|
||||
'No SSL certificate found from %s:%s' % (self.host, self.port))
|
||||
fields = cert['subject']
|
||||
for field in fields:
|
||||
# ensure structure we get back is what we expect
|
||||
if not isinstance(field, tuple):
|
||||
continue
|
||||
cert_pair = field[0]
|
||||
if len(cert_pair) < 2:
|
||||
continue
|
||||
cert_key, cert_value = cert_pair[0:2]
|
||||
if cert_key != 'commonName':
|
||||
continue
|
||||
certhost = cert_value
|
||||
# this check should be performed by some sort of Access Manager
|
||||
if certhost == hostname:
|
||||
# success, cert commonName matches desired hostname
|
||||
return
|
||||
else:
|
||||
raise TTransportException(
|
||||
TTransportException.UNKNOWN,
|
||||
'Hostname we connected to "%s" doesn\'t match certificate '
|
||||
'provided commonName "%s"' % (self.host, certhost))
|
||||
raise TTransportException(
|
||||
TTransportException.UNKNOWN,
|
||||
'Could not validate SSL certificate from host "%s". Cert=%s'
|
||||
% (hostname, cert))
|
||||
|
||||
|
||||
def _optional_dependencies():
|
||||
try:
|
||||
import ipaddress # noqa
|
||||
logger.debug('ipaddress module is available')
|
||||
ipaddr = True
|
||||
except ImportError:
|
||||
logger.warn('ipaddress module is unavailable')
|
||||
ipaddr = False
|
||||
|
||||
if sys.hexversion < 0x030500F0:
|
||||
try:
|
||||
from backports.ssl_match_hostname import match_hostname, __version__ as ver
|
||||
ver = list(map(int, ver.split('.')))
|
||||
logger.debug('backports.ssl_match_hostname module is available')
|
||||
match = match_hostname
|
||||
if ver[0] * 10 + ver[1] >= 35:
|
||||
return ipaddr, match
|
||||
else:
|
||||
logger.warn('backports.ssl_match_hostname module is too old')
|
||||
ipaddr = False
|
||||
except ImportError:
|
||||
logger.warn('backports.ssl_match_hostname is unavailable')
|
||||
ipaddr = False
|
||||
try:
|
||||
from ssl import match_hostname
|
||||
logger.debug('ssl.match_hostname is available')
|
||||
match = match_hostname
|
||||
except ImportError:
|
||||
logger.warn('using legacy validation callback')
|
||||
match = legacy_validate_callback
|
||||
return ipaddr, match
|
||||
|
||||
_match_has_ipaddress, _match_hostname = _optional_dependencies()
|
Loading…
Add table
Reference in a new issue