thrift update to 0.10.0

This commit is contained in:
mjames-upc 2018-06-19 09:19:08 -06:00
parent 3837f21015
commit 0ddbcd4bb0
25 changed files with 3878 additions and 2065 deletions

View file

@ -0,0 +1,55 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TProcessor, TMessageType, TException
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
class TMultiplexedProcessor(TProcessor):
    """Route multiplexed messages to the processor registered for a service.

    Incoming message names are expected to look like
    ``<serviceName><SEPARATOR><call>``; the service prefix selects the
    processor and the remainder is handed on as the message name.
    """

    def __init__(self):
        # serviceName -> processor, filled in by registerProcessor().
        self.services = {}

    def registerProcessor(self, serviceName, processor):
        """Register ``processor`` under ``serviceName`` (replaces any previous one)."""
        self.services[serviceName] = processor

    def process(self, iprot, oprot):
        """Read one message header from ``iprot`` and delegate to a service.

        Returns whatever the selected processor's ``process()`` returns.

        Raises:
            TException: if the message is neither CALL nor ONEWAY, if the
                message name contains no separator, or if no processor is
                registered for the service prefix.
        """
        (name, mtype, seqid) = iprot.readMessageBegin()
        if mtype not in (TMessageType.CALL, TMessageType.ONEWAY):
            raise TException("TMultiplex protocol only supports CALL & ONEWAY")

        split_at = name.find(TMultiplexedProtocol.SEPARATOR)
        if split_at < 0:
            raise TException(
                "Service name not found in message name: " + name +
                ". Did you forget to use TMultiplexProtocol in your client?")

        serviceName = name[:split_at]
        call = name[split_at + len(TMultiplexedProtocol.SEPARATOR):]
        if serviceName not in self.services:
            raise TException(
                "Service name not found: " + serviceName +
                ". Did you forget to call registerProcessor()?")

        # The header has already been consumed, so hand the target processor
        # a protocol that replays it with the service prefix removed.
        target = self.services[serviceName]
        return target.process(
            StoredMessageProtocol(iprot, (call, mtype, seqid)), oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol decorator that answers readMessageBegin() from a stored header."""

    def __init__(self, protocol, messageBegin):
        """Wrap ``protocol`` and remember ``messageBegin`` for later replay."""
        TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
        self.messageBegin = messageBegin

    def readMessageBegin(self):
        """Return the stored header instead of reading from the wrapped protocol."""
        return self.messageBegin

View file

@ -19,17 +19,18 @@
from os import path from os import path
from SCons.Builder import Builder from SCons.Builder import Builder
from six.moves import map
def scons_env(env, add=''): def scons_env(env, add=''):
opath = path.dirname(path.abspath('$TARGET')) opath = path.dirname(path.abspath('$TARGET'))
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
cppbuild = Builder(action=lstr) cppbuild = Builder(action=lstr)
env.Append(BUILDERS={'ThriftCpp': cppbuild}) env.Append(BUILDERS={'ThriftCpp': cppbuild})
def gen_cpp(env, dir, file): def gen_cpp(env, dir, file):
scons_env(env) scons_env(env)
suffixes = ['_types.h', '_types.cpp'] suffixes = ['_types.h', '_types.cpp']
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
return env.ThriftCpp(targets, dir + file + '.thrift') return env.ThriftCpp(targets, dir + file + '.thrift')

View file

@ -17,8 +17,8 @@
# under the License. # under the License.
# #
from protocol import TBinaryProtocol from .protocol import TBinaryProtocol
from transport import TTransport from .transport import TTransport
def serialize(thrift_object, def serialize(thrift_object,

188
thrift/TTornado.py Normal file
View file

@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
import logging
import socket
import struct
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
from io import BytesIO
from collections import deque
from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
logger = logging.getLogger(__name__)
class _Lock(object):
    """FIFO lock for coroutines, built on a queue of futures.

    Each acquirer appends its own future to the queue and waits on the
    future that was at the tail before it; releasing resolves the future at
    the head of the queue.
    """

    def __init__(self):
        self._waiters = deque()

    def acquired(self):
        """Return True while at least one holder/waiter is queued."""
        return bool(self._waiters)

    @gen.coroutine
    def acquire(self):
        """Queue up for the lock; yields a context manager that releases it."""
        # Capture the current tail *before* enqueueing ourselves.
        if self.acquired():
            blocker = self._waiters[-1]
        else:
            blocker = None
        own_future = concurrent.Future()
        self._waiters.append(own_future)
        if blocker:
            yield blocker

        raise gen.Return(self._lock_context())

    def release(self):
        """Resolve the oldest queued future, letting the next waiter proceed."""
        assert self.acquired(), 'Lock not aquired'
        oldest = self._waiters.popleft()
        oldest.set_result(None)

    @contextmanager
    def _lock_context(self):
        """Context manager that calls release() on exit, even on error."""
        try:
            yield
        finally:
            self.release()
class TTornadoStreamTransport(TTransportBase):
    """a framed, buffered transport over a Tornado stream"""

    def __init__(self, host, port, stream=None, io_loop=None):
        """Store connection parameters and set up the write buffer.

        Args:
            host: peer host, used by open() and in error messages.
            port: peer port, used by open() and in error messages.
            stream: an already-connected stream (servers pass one in);
                clients leave it None and call open().
            io_loop: loop to use; defaults to ``ioloop.IOLoop.current()``.
        """
        self.host = host
        self.port = port
        self.io_loop = io_loop or ioloop.IOLoop.current()
        self.__wbuf = BytesIO()
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap ``future`` with ``gen.with_timeout`` on this transport's loop."""
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Create a TCP socket, connect it and yield this transport.

        Raises:
            TTransportException: with type NOT_OPEN when connecting fails or
                the connect raises ``ioloop.TimeoutError``.
        """
        logger.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        except (socket.error, IOError, ioloop.TimeoutError) as e:
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Close the underlying stream without firing the close callback."""
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: callers must use readFrame().

        Raises:
            AssertionError: always.
        """
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time.  Raise explicitly rather than via ``assert`` so the
        # guard still fires when Python runs with -O (which strips asserts).
        raise AssertionError("you're doing it wrong")

    @contextmanager
    def io_exception_context(self):
        """Translate low-level I/O errors into TTransportException.

        ``socket.error``/``IOError`` become END_OF_FILE;
        ``iostream.StreamBufferFullError`` becomes UNKNOWN.
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one length-prefixed frame and yield its payload bytes.

        The frame header is 4 bytes, unpacked as a big-endian signed 32-bit
        length ('!i'); that many bytes are then read as the frame body.
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                frame_length, = struct.unpack('!i', frame_header)
                frame = yield self.stream.read_bytes(frame_length)
                raise gen.Return(frame)

    def write(self, buf):
        """Append ``buf`` to the in-memory write buffer (nothing is sent yet)."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one length-prefixed frame.

        Returns whatever ``self.stream.write()`` returns.
        """
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
    """TCP server that feeds framed requests to a Thrift processor."""

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """Remember the processor and protocol factories.

        ``oprot_factory`` falls back to ``iprot_factory`` when it is None;
        remaining arguments go to ``tcpserver.TCPServer``.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            oprot_factory = iprot_factory
        self._oprot_factory = oprot_factory

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Serve one client connection until its stream closes or errors."""
        host, port = address[:2]
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=self.io_loop)
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                try:
                    frame = yield trans.readFrame()
                except TTransportException as e:
                    # END_OF_FILE ends the loop; anything else is reported
                    # by the outer handler.
                    if e.type != TTransportException.END_OF_FILE:
                        raise
                    break
                # Each frame is decoded from its own in-memory transport.
                frame_trans = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(frame_trans)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logger.exception('thrift exception in handle_stream')
            trans.close()

        logger.info('client disconnected %s:%d', host, port)

View file

@ -20,138 +20,173 @@
import sys import sys
class TType: class TType(object):
STOP = 0 STOP = 0
VOID = 1 VOID = 1
BOOL = 2 BOOL = 2
BYTE = 3 BYTE = 3
I08 = 3 I08 = 3
DOUBLE = 4 DOUBLE = 4
I16 = 6 I16 = 6
I32 = 8 I32 = 8
I64 = 10 I64 = 10
STRING = 11 STRING = 11
UTF7 = 11 UTF7 = 11
STRUCT = 12 STRUCT = 12
MAP = 13 MAP = 13
SET = 14 SET = 14
LIST = 15 LIST = 15
UTF8 = 16 UTF8 = 16
UTF16 = 17 UTF16 = 17
_VALUES_TO_NAMES = ('STOP', _VALUES_TO_NAMES = (
'VOID', 'STOP',
'BOOL', 'VOID',
'BYTE', 'BOOL',
'DOUBLE', 'BYTE',
None, 'DOUBLE',
'I16', None,
None, 'I16',
'I32', None,
None, 'I32',
'I64', None,
'STRING', 'I64',
'STRUCT', 'STRING',
'MAP', 'STRUCT',
'SET', 'MAP',
'LIST', 'SET',
'UTF8', 'LIST',
'UTF16') 'UTF8',
'UTF16',
)
class TMessageType: class TMessageType(object):
CALL = 1 CALL = 1
REPLY = 2 REPLY = 2
EXCEPTION = 3 EXCEPTION = 3
ONEWAY = 4 ONEWAY = 4
class TProcessor: class TProcessor(object):
"""Base class for procsessor, which works on two streams.""" """Base class for procsessor, which works on two streams."""
def process(iprot, oprot): def process(iprot, oprot):
pass pass
class TException(Exception): class TException(Exception):
"""Base class for all thrift exceptions.""" """Base class for all thrift exceptions."""
# BaseException.message is deprecated in Python v[2.6,3.0) # BaseException.message is deprecated in Python v[2.6,3.0)
if (2, 6, 0) <= sys.version_info < (3, 0): if (2, 6, 0) <= sys.version_info < (3, 0):
def _get_message(self): def _get_message(self):
return self._message return self._message
def _set_message(self, message): def _set_message(self, message):
self._message = message self._message = message
message = property(_get_message, _set_message) message = property(_get_message, _set_message)
def __init__(self, message=None): def __init__(self, message=None):
Exception.__init__(self, message) Exception.__init__(self, message)
self.message = message self.message = message
class TApplicationException(TException): class TApplicationException(TException):
"""Application level thrift exceptions.""" """Application level thrift exceptions."""
UNKNOWN = 0 UNKNOWN = 0
UNKNOWN_METHOD = 1 UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE = 2 INVALID_MESSAGE_TYPE = 2
WRONG_METHOD_NAME = 3 WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4 BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5 MISSING_RESULT = 5
INTERNAL_ERROR = 6 INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7 PROTOCOL_ERROR = 7
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message) TException.__init__(self, message)
self.type = type self.type = type
def __str__(self): def __str__(self):
if self.message: if self.message:
return self.message return self.message
elif self.type == self.UNKNOWN_METHOD: elif self.type == self.UNKNOWN_METHOD:
return 'Unknown method' return 'Unknown method'
elif self.type == self.INVALID_MESSAGE_TYPE: elif self.type == self.INVALID_MESSAGE_TYPE:
return 'Invalid message type' return 'Invalid message type'
elif self.type == self.WRONG_METHOD_NAME: elif self.type == self.WRONG_METHOD_NAME:
return 'Wrong method name' return 'Wrong method name'
elif self.type == self.BAD_SEQUENCE_ID: elif self.type == self.BAD_SEQUENCE_ID:
return 'Bad sequence ID' return 'Bad sequence ID'
elif self.type == self.MISSING_RESULT: elif self.type == self.MISSING_RESULT:
return 'Missing result' return 'Missing result'
else: elif self.type == self.INTERNAL_ERROR:
return 'Default (unknown) TApplicationException' return 'Internal error'
elif self.type == self.PROTOCOL_ERROR:
def read(self, iprot): return 'Protocol error'
iprot.readStructBegin() elif self.type == self.INVALID_TRANSFORM:
while True: return 'Invalid transform'
(fname, ftype, fid) = iprot.readFieldBegin() elif self.type == self.INVALID_PROTOCOL:
if ftype == TType.STOP: return 'Invalid protocol'
break elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
if fid == 1: return 'Unsupported client type'
if ftype == TType.STRING:
self.message = iprot.readString()
else: else:
iprot.skip(ftype) return 'Default (unknown) TApplicationException'
elif fid == 2:
if ftype == TType.I32:
self.type = iprot.readI32()
else:
iprot.skip(ftype)
else:
iprot.skip(ftype)
iprot.readFieldEnd()
iprot.readStructEnd()
def write(self, oprot): def read(self, iprot):
oprot.writeStructBegin('TApplicationException') iprot.readStructBegin()
if self.message is not None: while True:
oprot.writeFieldBegin('message', TType.STRING, 1) (fname, ftype, fid) = iprot.readFieldBegin()
oprot.writeString(self.message) if ftype == TType.STOP:
oprot.writeFieldEnd() break
if self.type is not None: if fid == 1:
oprot.writeFieldBegin('type', TType.I32, 2) if ftype == TType.STRING:
oprot.writeI32(self.type) self.message = iprot.readString()
oprot.writeFieldEnd() else:
oprot.writeFieldStop() iprot.skip(ftype)
oprot.writeStructEnd() elif fid == 2:
if ftype == TType.I32:
self.type = iprot.readI32()
else:
iprot.skip(ftype)
else:
iprot.skip(ftype)
iprot.readFieldEnd()
iprot.readStructEnd()
def write(self, oprot):
oprot.writeStructBegin('TApplicationException')
if self.message is not None:
oprot.writeFieldBegin('message', TType.STRING, 1)
oprot.writeString(self.message)
oprot.writeFieldEnd()
if self.type is not None:
oprot.writeFieldBegin('type', TType.I32, 2)
oprot.writeI32(self.type)
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()
class TFrozenDict(dict):
    """A dictionary that is "frozen" like a frozenset

    The hash is computed once at construction time, so every mutating
    operation is rejected with TypeError to keep that hash valid.
    """

    def __init__(self, *args, **kwargs):
        """Build the dict from the usual ``dict`` arguments and cache its hash."""
        super(TFrozenDict, self).__init__(*args, **kwargs)
        # Sort the items so they will be in a consistent order.
        # XOR in the hash of the class so we don't collide with
        # the hash of a list of tuples.
        self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))

    def _reject_mutation(self, *args, **kwargs):
        """Raise TypeError; shared implementation for all mutating methods."""
        raise TypeError("Can't modify frozen TFreezableDict")

    def __setitem__(self, *args):
        raise TypeError("Can't modify frozen TFreezableDict")

    def __delitem__(self, *args):
        raise TypeError("Can't modify frozen TFreezableDict")

    # dict implements these without going through __setitem__/__delitem__,
    # so each one has to be blocked explicitly.
    clear = _reject_mutation
    pop = _reject_mutation
    popitem = _reject_mutation
    setdefault = _reject_mutation
    update = _reject_mutation
    __ior__ = _reject_mutation

    def __hash__(self):
        """Return the hash cached by ``__init__``."""
        return self.__hashval

40
thrift/compat.py Normal file
View file

@ -0,0 +1,40 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
if sys.version_info[0] != 2:

    from io import BytesIO as BufferIO  # noqa

    def binary_to_str(bin_val):
        """Decode ``bin_val`` from UTF-8 into a text string."""
        return bin_val.decode('utf8')

    def str_to_binary(str_val):
        """Encode the text string ``str_val`` as UTF-8 bytes."""
        return bytes(str_val, 'utf8')

else:

    from cStringIO import StringIO as BufferIO

    def binary_to_str(bin_val):
        """Return ``bin_val`` unchanged (no conversion on Python 2)."""
        return bin_val

    def str_to_binary(str_val):
        """Return ``str_val`` unchanged (no conversion on Python 2)."""
        return str_val

View file

@ -17,65 +17,66 @@
# under the License. # under the License.
# #
from thrift.Thrift import *
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport from thrift.transport import TTransport
try:
from thrift.protocol import fastbinary
except:
fastbinary = None
class TBase(object): class TBase(object):
__slots__ = [] __slots__ = ()
def __repr__(self): def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
for key in self.__slots__] return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
for attr in self.__slots__: for attr in self.__slots__:
my_val = getattr(self, attr) my_val = getattr(self, attr)
other_val = getattr(other, attr) other_val = getattr(other, attr)
if my_val != other_val: if my_val != other_val:
return False return False
return True return True
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def read(self, iprot): def read(self, iprot):
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None and self.thrift_spec is not None):
fastbinary is not None): iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
fastbinary.decode_binary(self, else:
iprot.trans, iprot.readStruct(self, self.thrift_spec)
(self.__class__, self.thrift_spec))
return
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot): def write(self, oprot):
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and if (oprot._fast_encode is not None and self.thrift_spec is not None):
self.thrift_spec is not None and oprot.trans.write(
fastbinary is not None): oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
oprot.trans.write( else:
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) oprot.writeStruct(self, self.thrift_spec)
return
oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(Exception): class TExceptionBase(TBase, Exception):
# old style class so python2.4 can raise exceptions derived from this pass
# This can't inherit from TBase because of that limitation.
__slots__ = []
__repr__ = TBase.__repr__.im_func
__eq__ = TBase.__eq__.im_func class TFrozenBase(TBase):
__ne__ = TBase.__ne__.im_func def __setitem__(self, *args):
read = TBase.read.im_func raise TypeError("Can't modify frozen struct")
write = TBase.write.im_func
def __delitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __hash__(self, *args):
return hash(self.__class__) ^ hash(self.__slots__)
@classmethod
def read(cls, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
(self.__class__, self.thrift_spec))
else:
return iprot.readStruct(cls, cls.thrift_spec, True)

View file

@ -17,244 +17,285 @@
# under the License. # under the License.
# #
from TProtocol import * from .TProtocol import TType, TProtocolBase, TProtocolException
from struct import pack, unpack from struct import pack, unpack
class TBinaryProtocol(TProtocolBase): class TBinaryProtocol(TProtocolBase):
"""Binary implementation of the Thrift protocol driver.""" """Binary implementation of the Thrift protocol driver."""
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
# positive, converting this into a long. If we hardcode the int value # positive, converting this into a long. If we hardcode the int value
# instead it'll stay in 32 bit-land. # instead it'll stay in 32 bit-land.
# VERSION_MASK = 0xffff0000 # VERSION_MASK = 0xffff0000
VERSION_MASK = -65536 VERSION_MASK = -65536
# VERSION_1 = 0x80010000 # VERSION_1 = 0x80010000
VERSION_1 = -2147418112 VERSION_1 = -2147418112
TYPE_MASK = 0x000000ff TYPE_MASK = 0x000000ff
def __init__(self, trans, strictRead=False, strictWrite=True): def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans) TProtocolBase.__init__(self, trans)
self.strictRead = strictRead self.strictRead = strictRead
self.strictWrite = strictWrite self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def writeMessageBegin(self, name, type, seqid): def _check_string_length(self, length):
if self.strictWrite: self._check_length(self.string_length_limit, length)
self.writeI32(TBinaryProtocol.VERSION_1 | type)
self.writeString(name)
self.writeI32(seqid)
else:
self.writeString(name)
self.writeByte(type)
self.writeI32(seqid)
def writeMessageEnd(self): def _check_container_length(self, length):
pass self._check_length(self.container_length_limit, length)
def writeStructBegin(self, name): def writeMessageBegin(self, name, type, seqid):
pass if self.strictWrite:
self.writeI32(TBinaryProtocol.VERSION_1 | type)
self.writeString(name)
self.writeI32(seqid)
else:
self.writeString(name)
self.writeByte(type)
self.writeI32(seqid)
def writeStructEnd(self): def writeMessageEnd(self):
pass pass
def writeFieldBegin(self, name, type, id): def writeStructBegin(self, name):
self.writeByte(type) pass
self.writeI16(id)
def writeFieldEnd(self): def writeStructEnd(self):
pass pass
def writeFieldStop(self): def writeFieldBegin(self, name, type, id):
self.writeByte(TType.STOP) self.writeByte(type)
self.writeI16(id)
def writeMapBegin(self, ktype, vtype, size): def writeFieldEnd(self):
self.writeByte(ktype) pass
self.writeByte(vtype)
self.writeI32(size)
def writeMapEnd(self): def writeFieldStop(self):
pass self.writeByte(TType.STOP)
def writeListBegin(self, etype, size): def writeMapBegin(self, ktype, vtype, size):
self.writeByte(etype) self.writeByte(ktype)
self.writeI32(size) self.writeByte(vtype)
self.writeI32(size)
def writeListEnd(self): def writeMapEnd(self):
pass pass
def writeSetBegin(self, etype, size): def writeListBegin(self, etype, size):
self.writeByte(etype) self.writeByte(etype)
self.writeI32(size) self.writeI32(size)
def writeSetEnd(self): def writeListEnd(self):
pass pass
def writeBool(self, bool): def writeSetBegin(self, etype, size):
if bool: self.writeByte(etype)
self.writeByte(1) self.writeI32(size)
else:
self.writeByte(0)
def writeByte(self, byte): def writeSetEnd(self):
buff = pack("!b", byte) pass
self.trans.write(buff)
def writeI16(self, i16): def writeBool(self, bool):
buff = pack("!h", i16) if bool:
self.trans.write(buff) self.writeByte(1)
else:
self.writeByte(0)
def writeI32(self, i32): def writeByte(self, byte):
buff = pack("!i", i32) buff = pack("!b", byte)
self.trans.write(buff) self.trans.write(buff)
def writeI64(self, i64): def writeI16(self, i16):
buff = pack("!q", i64) buff = pack("!h", i16)
self.trans.write(buff) self.trans.write(buff)
def writeDouble(self, dub): def writeI32(self, i32):
buff = pack("!d", dub) buff = pack("!i", i32)
self.trans.write(buff) self.trans.write(buff)
def writeString(self, str): def writeI64(self, i64):
self.writeI32(len(str)) buff = pack("!q", i64)
self.trans.write(str) self.trans.write(buff)
def readMessageBegin(self): def writeDouble(self, dub):
sz = self.readI32() buff = pack("!d", dub)
if sz < 0: self.trans.write(buff)
version = sz & TBinaryProtocol.VERSION_MASK
if version != TBinaryProtocol.VERSION_1:
raise TProtocolException(
type=TProtocolException.BAD_VERSION,
message='Bad version in readMessageBegin: %d' % (sz))
type = sz & TBinaryProtocol.TYPE_MASK
name = self.readString()
seqid = self.readI32()
else:
if self.strictRead:
raise TProtocolException(type=TProtocolException.BAD_VERSION,
message='No protocol version header')
name = self.trans.readAll(sz)
type = self.readByte()
seqid = self.readI32()
return (name, type, seqid)
def readMessageEnd(self): def writeBinary(self, str):
pass self.writeI32(len(str))
self.trans.write(str)
def readStructBegin(self): def readMessageBegin(self):
pass sz = self.readI32()
if sz < 0:
version = sz & TBinaryProtocol.VERSION_MASK
if version != TBinaryProtocol.VERSION_1:
raise TProtocolException(
type=TProtocolException.BAD_VERSION,
message='Bad version in readMessageBegin: %d' % (sz))
type = sz & TBinaryProtocol.TYPE_MASK
name = self.readString()
seqid = self.readI32()
else:
if self.strictRead:
raise TProtocolException(type=TProtocolException.BAD_VERSION,
message='No protocol version header')
name = self.trans.readAll(sz)
type = self.readByte()
seqid = self.readI32()
return (name, type, seqid)
def readStructEnd(self): def readMessageEnd(self):
pass pass
def readFieldBegin(self): def readStructBegin(self):
type = self.readByte() pass
if type == TType.STOP:
return (None, type, 0)
id = self.readI16()
return (None, type, id)
def readFieldEnd(self): def readStructEnd(self):
pass pass
def readMapBegin(self): def readFieldBegin(self):
ktype = self.readByte() type = self.readByte()
vtype = self.readByte() if type == TType.STOP:
size = self.readI32() return (None, type, 0)
return (ktype, vtype, size) id = self.readI16()
return (None, type, id)
def readMapEnd(self): def readFieldEnd(self):
pass pass
def readListBegin(self): def readMapBegin(self):
etype = self.readByte() ktype = self.readByte()
size = self.readI32() vtype = self.readByte()
return (etype, size) size = self.readI32()
self._check_container_length(size)
return (ktype, vtype, size)
def readListEnd(self): def readMapEnd(self):
pass pass
def readSetBegin(self): def readListBegin(self):
etype = self.readByte() etype = self.readByte()
size = self.readI32() size = self.readI32()
return (etype, size) self._check_container_length(size)
return (etype, size)
def readSetEnd(self): def readListEnd(self):
pass pass
def readBool(self): def readSetBegin(self):
byte = self.readByte() etype = self.readByte()
if byte == 0: size = self.readI32()
return False self._check_container_length(size)
return True return (etype, size)
def readByte(self): def readSetEnd(self):
buff = self.trans.readAll(1) pass
val, = unpack('!b', buff)
return val
def readI16(self): def readBool(self):
buff = self.trans.readAll(2) byte = self.readByte()
val, = unpack('!h', buff) if byte == 0:
return val return False
return True
def readI32(self): def readByte(self):
buff = self.trans.readAll(4) buff = self.trans.readAll(1)
val, = unpack('!i', buff) val, = unpack('!b', buff)
return val return val
def readI64(self): def readI16(self):
buff = self.trans.readAll(8) buff = self.trans.readAll(2)
val, = unpack('!q', buff) val, = unpack('!h', buff)
return val return val
def readDouble(self): def readI32(self):
buff = self.trans.readAll(8) buff = self.trans.readAll(4)
val, = unpack('!d', buff) val, = unpack('!i', buff)
return val return val
def readString(self): def readI64(self):
len = self.readI32() buff = self.trans.readAll(8)
str = self.trans.readAll(len) val, = unpack('!q', buff)
return str return val
def readDouble(self):
buff = self.trans.readAll(8)
val, = unpack('!d', buff)
return val
def readBinary(self):
size = self.readI32()
self._check_string_length(size)
s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory: class TBinaryProtocolFactory(object):
def __init__(self, strictRead=False, strictWrite=True): def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead self.strictRead = strictRead
self.strictWrite = strictWrite self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans): def getProtocol(self, trans):
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
return prot string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit)
return prot
class TBinaryProtocolAccelerated(TBinaryProtocol): class TBinaryProtocolAccelerated(TBinaryProtocol):
"""C-Accelerated version of TBinaryProtocol. """C-Accelerated version of TBinaryProtocol.
This class does not override any of TBinaryProtocol's methods, This class does not override any of TBinaryProtocol's methods,
but the generated code recognizes it directly and will call into but the generated code recognizes it directly and will call into
our C module to do the encoding, bypassing this object entirely. our C module to do the encoding, bypassing this object entirely.
We inherit from TBinaryProtocol so that the normal TBinaryProtocol We inherit from TBinaryProtocol so that the normal TBinaryProtocol
encoding can happen if the fastbinary module doesn't work for some encoding can happen if the fastbinary module doesn't work for some
reason. (TODO(dreiss): Make this happen sanely in more cases.) reason. (TODO(dreiss): Make this happen sanely in more cases.)
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use In order to take advantage of the C module, just use
TBinaryProtocolAccelerated instead of TBinaryProtocol. TBinaryProtocolAccelerated instead of TBinaryProtocol.
NOTE: This code was contributed by an external developer. NOTE: This code was contributed by an external developer.
The internal Thrift team has reviewed and tested it, The internal Thrift team has reviewed and tested it,
but we cannot guarantee that it is production-ready. but we cannot guarantee that it is production-ready.
Please feel free to report bugs and/or success stories Please feel free to report bugs and/or success stories
to the public mailing list. to the public mailing list.
""" """
pass pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_binary
self._fast_encode = fastbinary.encode_binary
class TBinaryProtocolAcceleratedFactory: class TBinaryProtocolAcceleratedFactory(object):
def getProtocol(self, trans): def __init__(self,
return TBinaryProtocolAccelerated(trans) string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
def getProtocol(self, trans):
return TBinaryProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)

View file

@ -17,9 +17,11 @@
# under the License. # under the License.
# #
from TProtocol import * from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
from struct import pack, unpack from struct import pack, unpack
from ..compat import binary_to_str, str_to_binary
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] __all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
CLEAR = 0 CLEAR = 0
@ -34,370 +36,437 @@ BOOL_READ = 8
def make_helper(v_from, container): def make_helper(v_from, container):
def helper(func): def helper(func):
def nested(self, *args, **kwargs): def nested(self, *args, **kwargs):
assert self.state in (v_from, container), (self.state, v_from, container) assert self.state in (v_from, container), (self.state, v_from, container)
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return nested return nested
return helper return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ) reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
    """Zig-zag encode the signed integer ``n`` of width ``bits``.

    ``checkIntegerLimits`` is called first with ``(n, bits)``; it presumably
    rejects values that do not fit in ``bits`` bits (defined in TProtocol,
    not visible here). The encoding interleaves negatives and positives:
    0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
    """
    checkIntegerLimits(n, bits)
    return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
    """Decode a zig-zag encoded non-negative integer back to a signed one.

    Inverse of makeZigZag: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    """
    return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as a varint.

    Emits 7 bits per byte, least-significant group first, with the high bit
    set on every byte except the last. The whole encoding is handed to
    ``trans.write`` in a single call.

    Raises:
        ValueError: if ``n`` is negative. Right-shifting a negative Python
            int never reaches zero, so the loop below would not terminate.
    """
    if n < 0:
        raise ValueError("varint value must be non-negative: %r" % (n,))
    out = bytearray()
    while True:
        if n & ~0x7f == 0:
            out.append(n)
            break
        # Low 7 bits of n plus the continuation flag.
        out.append((n & 0x7f) | 0x80)
        n = n >> 7
    trans.write(bytes(out))
def readVarint(trans):
    """Read one varint from ``trans`` and return it as an integer.

    Bytes are fetched one at a time with ``trans.readAll(1)``. Each byte
    contributes its low 7 bits, least-significant group first, and a clear
    high bit marks the final byte.
    """
    result = 0
    shift = 0
    while True:
        x = trans.readAll(1)
        byte = ord(x)
        result |= (byte & 0x7f) << shift
        if byte >> 7 == 0:
            return result
        shift += 7
class CompactType(object):
    """Wire type codes used by the compact protocol.

    TRUE and FALSE are separate codes: for struct fields the boolean value
    is carried in the field header's type nibble (see
    TCompactProtocol.writeBool / readFieldBegin).
    """
    STOP = 0x00
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C
# Thrift TType -> compact wire type code.
CTYPES = {
    TType.STOP: CompactType.STOP,
    TType.BOOL: CompactType.TRUE,  # used for collection
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Compact wire type code -> Thrift TType (inverse of CTYPES).
TTYPES = {}
for k, v in CTYPES.items():
    TTYPES[v] = k
# CTYPES maps BOOL only to TRUE; FALSE must decode to BOOL as well.
TTYPES[CompactType.FALSE] = TType.BOOL
# Do not leave the loop variables in the module namespace.
del k
del v
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Integers are written as zig-zag varints, field ids as deltas from the
    previous field where possible, and small collection sizes are packed
    into the header byte. ``self.state`` tracks where in a message the
    protocol is; most methods assert the state they expect.

    Args:
        trans: transport to read from / write to.
        string_length_limit: passed to ``_check_length`` for every binary
            length read; None is the default.
        container_length_limit: passed to ``_check_length`` for every
            list/set/map size read; None is the default.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_BITS = 0x07
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        self.__last_fid = 0
        self.__bool_fid = None
        self.__bool_value = None
        # Saved (state, last field id) for each enclosing struct.
        self.__structs = []
        # Saved state for each enclosing container.
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    # NOTE(review): _check_length is not defined in this class; presumably
    # inherited from TProtocolBase - confirm.
    def _check_string_length(self, length):
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        self._check_length(self.container_length_limit, length)

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        self.__writeVarint(seqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            # Short form: id delta and type share a single byte.
            self.__writeUByte(delta << 4 | type)
        else:
            # Long form: type byte followed by the full field id.
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # The header is deferred to writeBool, which encodes the value
            # in the header's type nibble.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            # Size nibble 0xf means "explicit varint size follows".
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            # An empty map is a single zero byte with no key/value types.
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        # Doubles are packed little-endian ('<d').
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    def readFieldBegin(self):
        assert self.state == FIELD_READ, self.state
        type = self.__readUByte()
        if type & 0x0f == TType.STOP:
            return (None, 0, 0)
        delta = type >> 4
        if delta == 0:
            # Long form: the full field id follows the type byte.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        type = type & 0x0f
        if type == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif type == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(type), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            # The first positional argument of TProtocolException is the
            # error type, so the message must be passed second.
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     "Length < 0")
        return result

    def readMessageBegin(self):
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        name = binary_to_str(self.__readBinary())
        return (name, type, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            # Size nibble 0xf: the real size follows as a varint.
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        types = 0
        if size > 0:
            # The key/value type byte is only present for non-empty maps.
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        if self.state == BOOL_READ:
            # Value was captured from the field header by readFieldBegin.
            return self.__bool_value == CompactType.TRUE
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(object):
    """Factory that builds TCompactProtocol instances.

    The two length limits given here are stored and handed to every
    protocol created by getProtocol().
    """

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocol bound to ``trans``."""
        return TCompactProtocol(trans,
                                self.string_length_limit,
                                self.container_length_limit)
class TCompactProtocolAccelerated(TCompactProtocol):
    """C-Accelerated version of TCompactProtocol.

    This class does not override any of TCompactProtocol's methods,
    but the generated code recognizes it directly and will call into
    our C module to do the encoding, bypassing this object entirely.
    We inherit from TCompactProtocol so that the normal TCompactProtocol
    encoding can happen if the fastbinary module doesn't work for some
    reason.
    To disable this behavior, pass fallback=False constructor argument.

    In order to take advantage of the C module, just use
    TCompactProtocolAccelerated instead of TCompactProtocol.
    """

    def __init__(self, *args, **kwargs):
        # 'fallback' is consumed here; everything else goes to the parent.
        fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            # Without the C module, fall back to pure Python unless the
            # caller asked for a hard failure.
            if not fallback:
                raise
        else:
            self._fast_decode = fastbinary.decode_compact
            self._fast_encode = fastbinary.encode_compact
class TCompactProtocolAcceleratedFactory(object):
    """Factory that builds TCompactProtocolAccelerated instances.

    Remembers the length limits and the ``fallback`` flag and forwards
    them, by keyword, to each protocol it creates.
    """

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self._fallback = fallback
        self.container_length_limit = container_length_limit
        self.string_length_limit = string_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocolAccelerated bound to ``trans``."""
        options = dict(
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
            fallback=self._fallback,
        )
        return TCompactProtocolAccelerated(trans, **options)

View file

@ -0,0 +1,677 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import (TType, TProtocolBase, TProtocolException,
checkIntegerLimits)
import base64
import math
import sys
from ..compat import str_to_binary
__all__ = [
    'TJSONProtocol',
    'TJSONProtocolFactory',
    'TSimpleJSONProtocol',
    'TSimpleJSONProtocolFactory',
]

# Protocol version; written and checked as the first message element.
VERSION = 1

# JSON syntax characters, as bytes because they go straight to the transport.
COMMA = b','
COLON = b':'
LBRACE = b'{'
RBRACE = b'}'
LBRACKET = b'['
RBRACKET = b']'
QUOTE = b'"'
BACKSLASH = b'\\'
ZERO = b'0'

# Code points used to recognise an escape and a \uXXXX escape when reading.
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')

# Character -> escaped form, applied when writing JSON strings.
ESCAPE_CHAR_VALS = {
    '"': '\\"',
    '\\': '\\\\',
    '\b': '\\b',
    '\f': '\\f',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    # '/': '\\/',
}

# Byte following a backslash -> decoded character, used when reading.
ESCAPE_CHARS = {
    b'"': '"',
    b'\\': '\\',
    b'b': '\b',
    b'f': '\f',
    b'n': '\n',
    b'r': '\r',
    b't': '\t',
    b'/': '/',
}

# Bytes accepted as part of a numeric literal.
NUMERIC_CHAR = b'+-.0123456789Ee'

# Thrift TType -> type name string written into the JSON.
CTYPES = {
    TType.BOOL: 'tf',
    TType.BYTE: 'i8',
    TType.I16: 'i16',
    TType.I32: 'i32',
    TType.I64: 'i64',
    TType.DOUBLE: 'dbl',
    TType.STRING: 'str',
    TType.STRUCT: 'rec',
    TType.LIST: 'lst',
    TType.SET: 'set',
    TType.MAP: 'map',
}

# Type name string -> Thrift TType (inverse of CTYPES).
JTYPES = {}
for key in CTYPES:
    JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
    """Default JSON context: emits and expects no separators.

    Subclasses override doIO/write/read to insert or consume the
    punctuation between elements, and escapeNum to say whether numbers
    must be quoted at the current position.
    """

    def __init__(self, protocol):
        self.protocol = protocol
        self.first = True

    def doIO(self, function):
        """No separator handling at this level."""

    def write(self):
        """Nothing to write before an element."""

    def read(self):
        """Nothing to consume before an element."""

    def escapeNum(self):
        """Numbers are never quoted in the base context."""
        return False

    def __str__(self):
        return type(self).__name__
class JSONListContext(JSONBaseContext):
    """Context for JSON arrays: a comma precedes every element but the first."""

    def doIO(self, function):
        """Call ``function(COMMA)`` unless this is the first element."""
        if self.first is True:
            self.first = False
            return
        function(COMMA)

    def write(self):
        """Write the separator (if any) to the transport."""
        emit = self.protocol.trans.write
        self.doIO(emit)

    def read(self):
        """Consume the separator (if any) from the input."""
        expect = self.protocol.readJSONSyntaxChar
        self.doIO(expect)
class JSONPairContext(JSONBaseContext):
    """Context for JSON objects: separators alternate between ':' and ','.

    ``colon`` records which separator comes next. escapeNum() returns it,
    so numbers in key position are quoted.
    """

    def __init__(self, protocol):
        super(JSONPairContext, self).__init__(protocol)
        self.colon = True

    def doIO(self, function):
        """Pass the next separator to ``function``, then flip ``colon``."""
        if self.first:
            # The first key has no separator in front of it.
            self.first = False
            self.colon = True
            return
        separator = COLON if self.colon else COMMA
        function(separator)
        self.colon = not self.colon

    def write(self):
        """Write the pending separator to the transport."""
        emit = self.protocol.trans.write
        self.doIO(emit)

    def read(self):
        """Consume the pending separator from the input."""
        expect = self.protocol.readJSONSyntaxChar
        self.doIO(expect)

    def escapeNum(self):
        """Return True when a number at this position must be quoted."""
        return self.colon

    def __str__(self):
        return '%s, colon=%s' % (self.__class__.__name__, self.colon)
class LookaheadReader():
    """One-element lookahead over ``protocol.trans.read(1)``.

    peek() fetches and buffers the next element without consuming it, and
    read() returns the buffered element if there is one, else a fresh one.
    """
    hasData = False
    data = ''

    def __init__(self, protocol):
        self.protocol = protocol

    def read(self):
        """Return the next element, consuming any buffered lookahead."""
        if self.hasData is True:
            self.hasData = False
            return self.data
        self.data = self.protocol.trans.read(1)
        return self.data

    def peek(self):
        """Return the next element without consuming it."""
        if self.hasData is not False:
            return self.data
        self.data = self.protocol.trans.read(1)
        self.hasData = True
        return self.data
class TJSONProtocolBase(TProtocolBase):
    """Shared JSON reading and writing primitives for the JSON protocols.

    A stack of context objects (JSONBaseContext and subclasses) decides
    which separator is written or expected before each element and
    whether numbers must be quoted. Input is read through a
    LookaheadReader so the next element can be inspected before it is
    consumed.
    """

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.resetWriteContext()
        self.resetReadContext()

    # We don't have length limit implementation for JSON protocols
    @property
    def string_length_limit(self):
        return None

    @property
    def container_length_limit(self):
        return None

    def resetWriteContext(self):
        """Discard all nested contexts and start from a base context."""
        self.context = JSONBaseContext(self)
        self.contextStack = [self.context]

    def resetReadContext(self):
        """Reset the context stack and drop any buffered lookahead."""
        self.resetWriteContext()
        self.reader = LookaheadReader(self)

    def pushContext(self, ctx):
        self.contextStack.append(ctx)
        self.context = ctx

    def popContext(self):
        self.contextStack.pop()
        if self.contextStack:
            self.context = self.contextStack[-1]
        else:
            self.context = JSONBaseContext(self)

    def writeJSONString(self, string):
        """Write ``string`` as a quoted JSON string with escapes applied."""
        self.context.write()
        json_str = ['"']
        for s in string:
            escaped = ESCAPE_CHAR_VALS.get(s, s)
            json_str.append(escaped)
        json_str.append('"')
        self.trans.write(str_to_binary(''.join(json_str)))

    def writeJSONNumber(self, number, formatter='{0}'):
        """Write ``number`` using ``formatter``, quoted if the context says so."""
        self.context.write()
        jsNumber = str(formatter.format(number)).encode('ascii')
        if self.context.escapeNum():
            self.trans.write(QUOTE)
            self.trans.write(jsNumber)
            self.trans.write(QUOTE)
        else:
            self.trans.write(jsNumber)

    def writeJSONBase64(self, binary):
        """Write ``binary`` as a quoted base64 string."""
        self.context.write()
        self.trans.write(QUOTE)
        self.trans.write(base64.b64encode(binary))
        self.trans.write(QUOTE)

    def writeJSONObjectStart(self):
        self.context.write()
        self.trans.write(LBRACE)
        self.pushContext(JSONPairContext(self))

    def writeJSONObjectEnd(self):
        self.popContext()
        self.trans.write(RBRACE)

    def writeJSONArrayStart(self):
        self.context.write()
        self.trans.write(LBRACKET)
        self.pushContext(JSONListContext(self))

    def writeJSONArrayEnd(self):
        self.popContext()
        self.trans.write(RBRACKET)

    def readJSONSyntaxChar(self, character):
        """Consume one element and raise unless it equals ``character``."""
        current = self.reader.read()
        if character != current:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Unexpected character: %s" % current)

    def _isHighSurrogate(self, codeunit):
        return codeunit >= 0xd800 and codeunit <= 0xdbff

    def _isLowSurrogate(self, codeunit):
        return codeunit >= 0xdc00 and codeunit <= 0xdfff

    def _toChar(self, high, low=None):
        """Convert a code unit, or a surrogate pair, to a character.

        On Python 2 the result is a UTF-8 encoded byte string; on Python 3
        it is a str.
        """
        if not low:
            if sys.version_info[0] == 2:
                return ("\\u%04x" % high).decode('unicode-escape') \
                                          .encode('utf-8')
            else:
                return chr(high)
        else:
            # Combine the surrogate pair into one supplementary code point.
            codepoint = (1 << 16) + ((high & 0x3ff) << 10)
            codepoint += low & 0x3ff
            if sys.version_info[0] == 2:
                s = "\\U%08x" % codepoint
                return s.decode('unicode-escape').encode('utf-8')
            else:
                return chr(codepoint)

    def readJSONString(self, skipContext):
        """Read a quoted JSON string and return its decoded contents.

        Args:
            skipContext: when False, the current context consumes its
                separator first; pass True if the caller already did so.

        Raises:
            TProtocolException: INVALID_DATA on an unknown escape, an
                unescaped character that must be escaped, or an unpaired
                surrogate.
        """
        highSurrogate = None
        string = []
        if skipContext is False:
            self.context.read()
        self.readJSONSyntaxChar(QUOTE)
        while True:
            character = self.reader.read()
            if character == QUOTE:
                break
            if ord(character) == ESCSEQ0:
                character = self.reader.read()
                if ord(character) == ESCSEQ1:
                    # \uXXXX escape: four hex digits follow.
                    character = self.trans.read(4).decode('ascii')
                    codeunit = int(character, 16)
                    if self._isHighSurrogate(codeunit):
                        if highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected low surrogate char")
                        highSurrogate = codeunit
                        continue
                    elif self._isLowSurrogate(codeunit):
                        if not highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected high surrogate char")
                        character = self._toChar(highSurrogate, codeunit)
                        highSurrogate = None
                    else:
                        character = self._toChar(codeunit)
                else:
                    if character not in ESCAPE_CHARS:
                        raise TProtocolException(
                            TProtocolException.INVALID_DATA,
                            "Expected control char")
                    character = ESCAPE_CHARS[character]
            elif character in ESCAPE_CHAR_VALS:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Unescaped control char")
            elif sys.version_info[0] > 2:
                # Python 3: gather the bytes of one UTF-8 sequence before
                # decoding it to a str.
                utf8_bytes = bytearray([ord(character)])
                while ord(self.reader.peek()) >= 0x80:
                    utf8_bytes.append(ord(self.reader.read()))
                character = utf8_bytes.decode('utf8')
            string.append(character)

        if highSurrogate:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Expected low surrogate char")
        return ''.join(string)

    def isJSONNumeric(self, character):
        """Return True if ``character`` occurs in NUMERIC_CHAR."""
        return NUMERIC_CHAR.find(character) != -1

    def readJSONQuotes(self):
        """Consume a quote if the context requires quoted numbers."""
        if (self.context.escapeNum()):
            self.readJSONSyntaxChar(QUOTE)

    def readJSONNumericChars(self):
        """Read the longest run of numeric characters and return it as str."""
        numeric = []
        while True:
            character = self.reader.peek()
            if self.isJSONNumeric(character) is False:
                break
            numeric.append(self.reader.read())
        return b''.join(numeric).decode('ascii')

    def readJSONInteger(self):
        """Read an integer, quoted or not as the context dictates."""
        self.context.read()
        self.readJSONQuotes()
        numeric = self.readJSONNumericChars()
        self.readJSONQuotes()
        try:
            return int(numeric)
        except ValueError:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Bad data encounted in numeric data")

    def readJSONDouble(self):
        """Read a double, accepting quoted and unquoted forms.

        A quoted value that is neither infinite nor NaN is rejected when
        the context does not call for quoted numbers.
        """
        self.context.read()
        if self.reader.peek() == QUOTE:
            string = self.readJSONString(True)
            try:
                double = float(string)
                # escapeNum must be called: comparing the bound method
                # itself to False is never true.
                if (self.context.escapeNum() is False and
                        not math.isinf(double) and
                        not math.isnan(double)):
                    raise TProtocolException(
                        TProtocolException.INVALID_DATA,
                        "Numeric data unexpectedly quoted")
                return double
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encounted in numeric data")
        else:
            if self.context.escapeNum() is True:
                self.readJSONSyntaxChar(QUOTE)
            try:
                return float(self.readJSONNumericChars())
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encounted in numeric data")

    def readJSONBase64(self):
        """Read a quoted base64 string and return the decoded bytes."""
        string = self.readJSONString(False)
        size = len(string)
        m = size % 4
        # Restore stripped padding so b64decode accepts the input.
        if m != 0:
            string += '=' * (4 - m)
        return base64.b64decode(string)

    def readJSONObjectStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACE)
        self.pushContext(JSONPairContext(self))

    def readJSONObjectEnd(self):
        self.readJSONSyntaxChar(RBRACE)
        self.popContext()

    def readJSONArrayStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACKET)
        self.pushContext(JSONListContext(self))

    def readJSONArrayEnd(self):
        self.readJSONSyntaxChar(RBRACKET)
        self.popContext()
class TJSONProtocol(TJSONProtocolBase):
    """Full JSON protocol, readable and writable.

    A message is a JSON array ``[VERSION, name, type, seqid, ...]``.
    Structs are objects keyed by field id, and each field value is wrapped
    in an object keyed by its type name from CTYPES. Lists and sets are
    arrays prefixed by element type and size; maps are arrays of key type,
    value type, size and an object holding the entries.
    """

    def readMessageBegin(self):
        """Read the message header and return ``(name, type, seqid)``.

        Raises:
            TProtocolException: BAD_VERSION if the version does not match.
        """
        self.resetReadContext()
        self.readJSONArrayStart()
        if self.readJSONInteger() != VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     "Message contained bad version.")
        name = self.readJSONString(False)
        typen = self.readJSONInteger()
        seqid = self.readJSONInteger()
        return (name, typen, seqid)

    def readMessageEnd(self):
        self.readJSONArrayEnd()

    def readStructBegin(self):
        self.readJSONObjectStart()

    def readStructEnd(self):
        self.readJSONObjectEnd()

    def readFieldBegin(self):
        """Return ``(None, ttype, field_id)``; ttype is STOP at struct end."""
        character = self.reader.peek()
        ttype = 0
        field_id = 0
        if character == RBRACE:
            ttype = TType.STOP
        else:
            field_id = self.readJSONInteger()
            self.readJSONObjectStart()
            ttype = JTYPES[self.readJSONString(False)]
        return (None, ttype, field_id)

    def readFieldEnd(self):
        self.readJSONObjectEnd()

    def readMapBegin(self):
        self.readJSONArrayStart()
        keyType = JTYPES[self.readJSONString(False)]
        valueType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        self.readJSONObjectStart()
        return (keyType, valueType, size)

    def readMapEnd(self):
        self.readJSONObjectEnd()
        self.readJSONArrayEnd()

    def readCollectionBegin(self):
        self.readJSONArrayStart()
        elemType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        return (elemType, size)
    readListBegin = readCollectionBegin
    readSetBegin = readCollectionBegin

    def readCollectionEnd(self):
        self.readJSONArrayEnd()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd

    def readBool(self):
        # Booleans travel as integers; anything non-zero is True.
        return self.readJSONInteger() != 0

    def readNumber(self):
        return self.readJSONInteger()
    readByte = readNumber
    readI16 = readNumber
    readI32 = readNumber
    readI64 = readNumber

    def readDouble(self):
        return self.readJSONDouble()

    def readString(self):
        return self.readJSONString(False)

    def readBinary(self):
        return self.readJSONBase64()

    def writeMessageBegin(self, name, request_type, seqid):
        self.resetWriteContext()
        self.writeJSONArrayStart()
        self.writeJSONNumber(VERSION)
        self.writeJSONString(name)
        self.writeJSONNumber(request_type)
        self.writeJSONNumber(seqid)

    def writeMessageEnd(self):
        self.writeJSONArrayEnd()

    def writeStructBegin(self, name):
        self.writeJSONObjectStart()

    def writeStructEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldBegin(self, name, ttype, id):
        # The field name is not written; fields are keyed by id.
        self.writeJSONNumber(id)
        self.writeJSONObjectStart()
        self.writeJSONString(CTYPES[ttype])

    def writeFieldEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldStop(self):
        pass

    def writeMapBegin(self, ktype, vtype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[ktype])
        self.writeJSONString(CTYPES[vtype])
        self.writeJSONNumber(size)
        self.writeJSONObjectStart()

    def writeMapEnd(self):
        self.writeJSONObjectEnd()
        self.writeJSONArrayEnd()

    def writeListBegin(self, etype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[etype])
        self.writeJSONNumber(size)

    def writeListEnd(self):
        self.writeJSONArrayEnd()

    def writeSetBegin(self, etype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[etype])
        self.writeJSONNumber(size)

    def writeSetEnd(self):
        self.writeJSONArrayEnd()

    def writeBool(self, boolean):
        # Use truthiness, not identity with True, so truthy non-bool
        # values (e.g. 1) are not written as 0.
        self.writeJSONNumber(1 if boolean else 0)

    def writeByte(self, byte):
        checkIntegerLimits(byte, 8)
        self.writeJSONNumber(byte)

    def writeI16(self, i16):
        checkIntegerLimits(i16, 16)
        self.writeJSONNumber(i16)

    def writeI32(self, i32):
        checkIntegerLimits(i32, 32)
        self.writeJSONNumber(i32)

    def writeI64(self, i64):
        checkIntegerLimits(i64, 64)
        self.writeJSONNumber(i64)

    def writeDouble(self, dbl):
        # 17 significant digits should be just enough for any double precision
        # value.
        self.writeJSONNumber(dbl, '{0:.17g}')

    def writeString(self, string):
        self.writeJSONString(string)

    def writeBinary(self, binary):
        self.writeJSONBase64(binary)
class TJSONProtocolFactory(object):
    """Factory that builds TJSONProtocol instances."""

    def getProtocol(self, trans):
        """Return a new TJSONProtocol bound to ``trans``."""
        return TJSONProtocol(trans)

    # The JSON protocols implement no length limits, so both limit
    # properties always report None.
    @property
    def string_length_limit(self):
        return None

    @property
    def container_length_limit(self):
        return None
class TSimpleJSONProtocol(TJSONProtocolBase):
    """Simple, readable, write-only JSON protocol.

    Useful for interacting with scripting languages.
    """

    # ---- reading: this protocol is write-only ----

    def _readNotSupported(self):
        """Shared body of the read entry points, which are not implemented."""
        raise NotImplementedError()

    readMessageBegin = _readNotSupported
    readMessageEnd = _readNotSupported
    readStructBegin = _readNotSupported
    readStructEnd = _readNotSupported

    # ---- message / struct / field framing ----

    def writeMessageBegin(self, name, request_type, seqid):
        """Reset the write context; none of the arguments are written."""
        self.resetWriteContext()

    def writeMessageEnd(self):
        """Write nothing at the end of a message."""
        return None

    def writeStructBegin(self, name):
        """Open a JSON object for the struct."""
        self.writeJSONObjectStart()

    def writeStructEnd(self):
        """Close the struct's JSON object."""
        self.writeJSONObjectEnd()

    def writeFieldBegin(self, name, ttype, fid):
        """Write the field name as a JSON string; type and id are ignored."""
        self.writeJSONString(name)

    def writeFieldEnd(self):
        """Write nothing at the end of a field."""
        return None

    # ---- containers ----

    def writeMapBegin(self, ktype, vtype, size):
        """Open a JSON object for the map; types and size are not written."""
        self.writeJSONObjectStart()

    def writeMapEnd(self):
        """Close the map's JSON object."""
        self.writeJSONObjectEnd()

    def _writeCollectionBegin(self, etype, size):
        """Open a JSON array; element type and size are not written."""
        self.writeJSONArrayStart()

    def _writeCollectionEnd(self):
        """Close the collection's JSON array."""
        self.writeJSONArrayEnd()

    # Lists and sets share the same array representation.
    writeListBegin = writeSetBegin = _writeCollectionBegin
    writeListEnd = writeSetEnd = _writeCollectionEnd

    # ---- scalars ----

    def _writeBoundedInt(self, value, bits):
        """Range-check ``value`` for a ``bits``-wide integer, then write it."""
        checkIntegerLimits(value, bits)
        self.writeJSONNumber(value)

    def writeByte(self, byte):
        """Write ``byte`` after the 8-bit range check."""
        self._writeBoundedInt(byte, 8)

    def writeI16(self, i16):
        """Write ``i16`` after the 16-bit range check."""
        self._writeBoundedInt(i16, 16)

    def writeI32(self, i32):
        """Write ``i32`` after the 32-bit range check."""
        self._writeBoundedInt(i32, 32)

    def writeI64(self, i64):
        """Write ``i64`` after the 64-bit range check."""
        self._writeBoundedInt(i64, 64)

    def writeBool(self, boolean):
        """Write 1 when ``boolean`` is exactly ``True``, otherwise 0."""
        flag = 1 if boolean is True else 0
        self.writeJSONNumber(flag)

    def writeDouble(self, dbl):
        """Write ``dbl`` as a JSON number with the default formatting."""
        self.writeJSONNumber(dbl)

    def writeString(self, string):
        """Write ``string`` as a JSON string."""
        self.writeJSONString(string)

    def writeBinary(self, binary):
        """Write ``binary`` through the base64 JSON writer."""
        self.writeJSONBase64(binary)
class TSimpleJSONProtocolFactory(object):
    """Factory that builds ``TSimpleJSONProtocol`` instances."""

    def getProtocol(self, trans):
        """Return a new ``TSimpleJSONProtocol`` bound to ``trans``."""
        protocol = TSimpleJSONProtocol(trans)
        return protocol

View file

@ -0,0 +1,40 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TMessageType
from thrift.protocol import TProtocolDecorator
SEPARATOR = ":"
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol decorator that prefixes outgoing request names with a service.

    For CALL and ONEWAY messages the name written to the wrapped protocol is
    ``serviceName + SEPARATOR + name``; every other message type is passed
    through with its name unchanged.
    """

    def __init__(self, protocol, serviceName):
        """Wrap ``protocol`` and remember the ``serviceName`` prefix."""
        TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
        self.serviceName = serviceName

    def writeMessageBegin(self, name, type, seqid):
        """Forward the message header, prefixing the name for requests."""
        is_request = type == TMessageType.CALL or type == TMessageType.ONEWAY
        if is_request:
            wire_name = self.serviceName + SEPARATOR + name
        else:
            wire_name = name
        self.protocol.writeMessageBegin(wire_name, type, seqid)

View file

@ -17,390 +17,403 @@
# under the License. # under the License.
# #
from thrift.Thrift import * from thrift.Thrift import TException, TType, TFrozenDict
from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
import sys
from itertools import islice
from six.moves import zip
class TProtocolException(TException): class TProtocolException(TException):
"""Custom Protocol Exception class""" """Custom Protocol Exception class"""
UNKNOWN = 0 UNKNOWN = 0
INVALID_DATA = 1 INVALID_DATA = 1
NEGATIVE_SIZE = 2 NEGATIVE_SIZE = 2
SIZE_LIMIT = 3 SIZE_LIMIT = 3
BAD_VERSION = 4 BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message) TException.__init__(self, message)
self.type = type self.type = type
class TProtocolBase: class TProtocolBase(object):
"""Base class for Thrift protocol driver.""" """Base class for Thrift protocol driver."""
def __init__(self, trans): def __init__(self, trans):
self.trans = trans self.trans = trans
self._fast_decode = None
self._fast_encode = None
def writeMessageBegin(self, name, type, seqid): @staticmethod
pass def _check_length(limit, length):
if length < 0:
raise TTransportException(TTransportException.NEGATIVE_SIZE,
'Negative length: %d' % length)
if limit is not None and length > limit:
raise TTransportException(TTransportException.SIZE_LIMIT,
'Length exceeded max allowed: %d' % limit)
def writeMessageEnd(self): def writeMessageBegin(self, name, ttype, seqid):
pass pass
def writeStructBegin(self, name): def writeMessageEnd(self):
pass pass
def writeStructEnd(self): def writeStructBegin(self, name):
pass pass
def writeFieldBegin(self, name, type, id): def writeStructEnd(self):
pass pass
def writeFieldEnd(self): def writeFieldBegin(self, name, ttype, fid):
pass pass
def writeFieldStop(self): def writeFieldEnd(self):
pass pass
def writeMapBegin(self, ktype, vtype, size): def writeFieldStop(self):
pass pass
def writeMapEnd(self): def writeMapBegin(self, ktype, vtype, size):
pass pass
def writeListBegin(self, etype, size): def writeMapEnd(self):
pass pass
def writeListEnd(self): def writeListBegin(self, etype, size):
pass pass
def writeSetBegin(self, etype, size): def writeListEnd(self):
pass pass
def writeSetEnd(self): def writeSetBegin(self, etype, size):
pass pass
def writeBool(self, bool): def writeSetEnd(self):
pass pass
def writeByte(self, byte): def writeBool(self, bool_val):
pass pass
def writeI16(self, i16): def writeByte(self, byte):
pass pass
def writeI32(self, i32): def writeI16(self, i16):
pass pass
def writeI64(self, i64): def writeI32(self, i32):
pass pass
def writeDouble(self, dub): def writeI64(self, i64):
pass pass
def writeString(self, str): def writeDouble(self, dub):
pass pass
def readMessageBegin(self): def writeString(self, str_val):
pass self.writeBinary(str_to_binary(str_val))
def readMessageEnd(self): def writeBinary(self, str_val):
pass pass
def readStructBegin(self): def writeUtf8(self, str_val):
pass self.writeString(str_val.encode('utf8'))
def readStructEnd(self): def readMessageBegin(self):
pass pass
def readFieldBegin(self): def readMessageEnd(self):
pass pass
def readFieldEnd(self): def readStructBegin(self):
pass pass
def readMapBegin(self): def readStructEnd(self):
pass pass
def readMapEnd(self): def readFieldBegin(self):
pass pass
def readListBegin(self): def readFieldEnd(self):
pass pass
def readListEnd(self): def readMapBegin(self):
pass pass
def readSetBegin(self): def readMapEnd(self):
pass pass
def readSetEnd(self): def readListBegin(self):
pass pass
def readBool(self): def readListEnd(self):
pass pass
def readByte(self): def readSetBegin(self):
pass pass
def readI16(self): def readSetEnd(self):
pass pass
def readI32(self): def readBool(self):
pass pass
def readI64(self): def readByte(self):
pass pass
def readDouble(self): def readI16(self):
pass pass
def readString(self): def readI32(self):
pass pass
def skip(self, type): def readI64(self):
if type == TType.STOP: pass
return
elif type == TType.BOOL:
self.readBool()
elif type == TType.BYTE:
self.readByte()
elif type == TType.I16:
self.readI16()
elif type == TType.I32:
self.readI32()
elif type == TType.I64:
self.readI64()
elif type == TType.DOUBLE:
self.readDouble()
elif type == TType.STRING:
self.readString()
elif type == TType.STRUCT:
name = self.readStructBegin()
while True:
(name, type, id) = self.readFieldBegin()
if type == TType.STOP:
break
self.skip(type)
self.readFieldEnd()
self.readStructEnd()
elif type == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
elif type == TType.SET:
(etype, size) = self.readSetBegin()
for i in range(size):
self.skip(etype)
self.readSetEnd()
elif type == TType.LIST:
(etype, size) = self.readListBegin()
for i in range(size):
self.skip(etype)
self.readListEnd()
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) def readDouble(self):
_TTYPE_HANDLERS = ( pass
(None, None, False), # 0 TType.STOP
(None, None, False), # 1 TType.VOID # TODO: handle void?
('readBool', 'writeBool', False), # 2 TType.BOOL
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
(None, None, False), # 5 undefined
('readI16', 'writeI16', False), # 6 TType.I16
(None, None, False), # 7 undefined
('readI32', 'writeI32', False), # 8 TType.I32
(None, None, False), # 9 undefined
('readI64', 'writeI64', False), # 10 TType.I64
('readString', 'writeString', False), # 11 TType.STRING and UTF7
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
)
def readFieldByTType(self, ttype, spec): def readString(self):
try: return binary_to_str(self.readBinary())
(r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
except IndexError:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid field type %d' % (ttype))
if r_handler is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid field type %d' % (ttype))
reader = getattr(self, r_handler)
if not is_container:
return reader()
return reader(spec)
def readContainerList(self, spec): def readBinary(self):
results = [] pass
ttype, tspec = spec[0], spec[1]
r_handler = self._TTYPE_HANDLERS[ttype][0]
reader = getattr(self, r_handler)
(list_type, list_len) = self.readListBegin()
if tspec is None:
# list values are simple types
for idx in xrange(list_len):
results.append(reader())
else:
# this is like an inlined readFieldByTType
container_reader = self._TTYPE_HANDLERS[list_type][0]
val_reader = getattr(self, container_reader)
for idx in xrange(list_len):
val = val_reader(tspec)
results.append(val)
self.readListEnd()
return results
def readContainerSet(self, spec): def readUtf8(self):
results = set() return self.readString().decode('utf8')
ttype, tspec = spec[0], spec[1]
r_handler = self._TTYPE_HANDLERS[ttype][0]
reader = getattr(self, r_handler)
(set_type, set_len) = self.readSetBegin()
if tspec is None:
# set members are simple types
for idx in xrange(set_len):
results.add(reader())
else:
container_reader = self._TTYPE_HANDLERS[set_type][0]
val_reader = getattr(self, container_reader)
for idx in xrange(set_len):
results.add(val_reader(tspec))
self.readSetEnd()
return results
def readContainerStruct(self, spec): def skip(self, ttype):
(obj_class, obj_spec) = spec if ttype == TType.STOP:
obj = obj_class() return
obj.read(self) elif ttype == TType.BOOL:
return obj self.readBool()
elif ttype == TType.BYTE:
self.readByte()
elif ttype == TType.I16:
self.readI16()
elif ttype == TType.I32:
self.readI32()
elif ttype == TType.I64:
self.readI64()
elif ttype == TType.DOUBLE:
self.readDouble()
elif ttype == TType.STRING:
self.readString()
elif ttype == TType.STRUCT:
name = self.readStructBegin()
while True:
(name, ttype, id) = self.readFieldBegin()
if ttype == TType.STOP:
break
self.skip(ttype)
self.readFieldEnd()
self.readStructEnd()
elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
elif ttype == TType.SET:
(etype, size) = self.readSetBegin()
for i in range(size):
self.skip(etype)
self.readSetEnd()
elif ttype == TType.LIST:
(etype, size) = self.readListBegin()
for i in range(size):
self.skip(etype)
self.readListEnd()
def readContainerMap(self, spec): # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
results = dict() _TTYPE_HANDLERS = (
key_ttype, key_spec = spec[0], spec[1] (None, None, False), # 0 TType.STOP
val_ttype, val_spec = spec[2], spec[3] (None, None, False), # 1 TType.VOID # TODO: handle void?
(map_ktype, map_vtype, map_len) = self.readMapBegin() ('readBool', 'writeBool', False), # 2 TType.BOOL
# TODO: compare types we just decoded with thrift_spec and ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
# abort/skip if types disagree ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) (None, None, False), # 5 undefined
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) ('readI16', 'writeI16', False), # 6 TType.I16
# list values are simple types (None, None, False), # 7 undefined
for idx in xrange(map_len): ('readI32', 'writeI32', False), # 8 TType.I32
if key_spec is None: (None, None, False), # 9 undefined
k_val = key_reader() ('readI64', 'writeI64', False), # 10 TType.I64
else: ('readString', 'writeString', False), # 11 TType.STRING and UTF7
k_val = self.readFieldByTType(key_ttype, key_spec) ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
if val_spec is None: ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
v_val = val_reader() ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
else: ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
v_val = self.readFieldByTType(val_ttype, val_spec) (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
# this raises a TypeError with unhashable keys types (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
# i.e. this fails: d=dict(); d[[0,1]] = 2 )
results[k_val] = v_val
self.readMapEnd()
return results
def readStruct(self, obj, thrift_spec): def _ttype_handlers(self, ttype, spec):
self.readStructBegin() if spec == 'BINARY':
while True: if ttype != TType.STRING:
(fname, ftype, fid) = self.readFieldBegin() raise TProtocolException(type=TProtocolException.INVALID_DATA,
if ftype == TType.STOP: message='Invalid binary field type %d' % ttype)
break return ('readBinary', 'writeBinary', False)
try: if sys.version_info[0] == 2 and spec == 'UTF8':
field = thrift_spec[fid] if ttype != TType.STRING:
except IndexError: raise TProtocolException(type=TProtocolException.INVALID_DATA,
self.skip(ftype) message='Invalid string field type %d' % ttype)
else: return ('readUtf8', 'writeUtf8', False)
if field is not None and ftype == field[1]: return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
def writeContainerStruct(self, val, spec): def _read_by_ttype(self, ttype, spec, espec):
val.write(self) reader_name, _, is_container = self._ttype_handlers(ttype, spec)
if reader_name is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid type %d' % (ttype))
reader_func = getattr(self, reader_name)
read = (lambda: reader_func(espec)) if is_container else reader_func
while True:
yield read()
def writeContainerList(self, val, spec): def readFieldByTType(self, ttype, spec):
self.writeListBegin(spec[0], len(val)) return next(self._read_by_ttype(ttype, spec, spec))
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
e_writer = getattr(self, w_handler)
if not is_container:
for elem in val:
e_writer(elem)
else:
for elem in val:
e_writer(elem, spec[1])
self.writeListEnd()
def writeContainerSet(self, val, spec): def readContainerList(self, spec):
self.writeSetBegin(spec[0], len(val)) ttype, tspec, is_immutable = spec
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] (list_type, list_len) = self.readListBegin()
e_writer = getattr(self, w_handler) # TODO: compare types we just decoded with thrift_spec
if not is_container: elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
for elem in val: results = (tuple if is_immutable else list)(elems)
e_writer(elem) self.readListEnd()
else: return results
for elem in val:
e_writer(elem, spec[1])
self.writeSetEnd()
def writeContainerMap(self, val, spec): def readContainerSet(self, spec):
k_type = spec[0] ttype, tspec, is_immutable = spec
v_type = spec[2] (set_type, set_len) = self.readSetBegin()
ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] # TODO: compare types we just decoded with thrift_spec
ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
k_writer = getattr(self, ktype_name) results = (frozenset if is_immutable else set)(elems)
v_writer = getattr(self, vtype_name) self.readSetEnd()
self.writeMapBegin(k_type, v_type, len(val)) return results
for m_key, m_val in val.iteritems():
if not k_is_container:
k_writer(m_key)
else:
k_writer(m_key, spec[1])
if not v_is_container:
v_writer(m_val)
else:
v_writer(m_val, spec[3])
self.writeMapEnd()
def writeStruct(self, obj, thrift_spec): def readContainerStruct(self, spec):
self.writeStructBegin(obj.__class__.__name__) (obj_class, obj_spec) = spec
for field in thrift_spec: obj = obj_class()
if field is None: obj.read(self)
continue return obj
fname = field[2]
val = getattr(obj, fname)
if val is None:
# skip writing out unset fields
continue
fid = field[0]
ftype = field[1]
fspec = field[3]
# get the writer method for this value
self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd()
self.writeFieldStop()
self.writeStructEnd()
def writeFieldByTType(self, ttype, val, spec): def readContainerMap(self, spec):
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] ktype, kspec, vtype, vspec, is_immutable = spec
writer = getattr(self, w_handler) (map_ktype, map_vtype, map_len) = self.readMapBegin()
if is_container: # TODO: compare types we just decoded with thrift_spec and
writer(val, spec) # abort/skip if types disagree
else: keys = self._read_by_ttype(ktype, spec, kspec)
writer(val) vals = self._read_by_ttype(vtype, spec, vspec)
keyvals = islice(zip(keys, vals), map_len)
results = (TFrozenDict if is_immutable else dict)(keyvals)
self.readMapEnd()
return results
def readStruct(self, obj, thrift_spec, is_immutable=False):
if is_immutable:
fields = {}
self.readStructBegin()
while True:
(fname, ftype, fid) = self.readFieldBegin()
if ftype == TType.STOP:
break
try:
field = thrift_spec[fid]
except IndexError:
self.skip(ftype)
else:
if field is not None and ftype == field[1]:
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
if is_immutable:
fields[fname] = val
else:
setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
if is_immutable:
return obj(**fields)
def writeContainerStruct(self, val, spec):
val.write(self)
def writeContainerList(self, val, spec):
ttype, tspec, _ = spec
self.writeListBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeListEnd()
def writeContainerSet(self, val, spec):
ttype, tspec, _ = spec
self.writeSetBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeSetEnd()
def writeContainerMap(self, val, spec):
ktype, kspec, vtype, vspec, _ = spec
self.writeMapBegin(ktype, vtype, len(val))
for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
pass
self.writeMapEnd()
def writeStruct(self, obj, thrift_spec):
self.writeStructBegin(obj.__class__.__name__)
for field in thrift_spec:
if field is None:
continue
fname = field[2]
val = getattr(obj, fname)
if val is None:
# skip writing out unset fields
continue
fid = field[0]
ftype = field[1]
fspec = field[3]
self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd()
self.writeFieldStop()
self.writeStructEnd()
def _write_by_ttype(self, ttype, vals, spec, espec):
_, writer_name, is_container = self._ttype_handlers(ttype, spec)
writer_func = getattr(self, writer_name)
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
for v in vals:
yield write(v)
def writeFieldByTType(self, ttype, val, spec):
next(self._write_by_ttype(ttype, [val], spec, spec))
class TProtocolFactory: def checkIntegerLimits(i, bits):
def getProtocol(self, trans): if bits == 8 and (i < -128 or i > 127):
pass raise TProtocolException(TProtocolException.INVALID_DATA,
"i8 requires -128 <= number <= 127")
elif bits == 16 and (i < -32768 or i > 32767):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i16 requires -32768 <= number <= 32767")
elif bits == 32 and (i < -2147483648 or i > 2147483647):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
def getProtocol(self, trans):
pass

View file

@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import types
from thrift.protocol.TProtocol import TProtocolBase
class TProtocolDecorator(object):
    """Wrapper that forwards unknown attribute access to another protocol.

    Attributes not found on the decorator are looked up on ``self.protocol``.
    Callable members are returned as a thin wrapper that routes the call
    through ``_wrap``; all other members are returned as-is. Subclasses
    override individual methods to change their behaviour.
    """

    # Member types that are returned wrapped rather than as plain values.
    _CALLABLE_TYPES = (
        types.MethodType,
        types.FunctionType,
        types.LambdaType,
        types.BuiltinFunctionType,
        types.BuiltinMethodType,
    )

    def __init__(self, protocol):
        """Store ``protocol`` as the object that calls are delegated to."""
        self.protocol = protocol

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapped protocol.

        Raises:
            AttributeError: if the wrapped protocol has no such attribute, or
                if ``protocol`` itself has not been set on this instance.
        """
        # __getattr__ only runs when normal lookup fails. If 'protocol' is
        # the missing name (instance created without __init__, e.g. by copy
        # or unpickling), touching self.protocol below would recurse forever.
        if name == 'protocol':
            raise AttributeError(name)
        if hasattr(self.protocol, name):
            member = getattr(self.protocol, name)
            if isinstance(member, self._CALLABLE_TYPES):
                return lambda *args, **kwargs: self._wrap(member, args, kwargs)
            return member
        raise AttributeError(name)

    def _wrap(self, func, args, kwargs):
        """Invoke ``func`` with ``args``/``kwargs`` and return its result.

        Bound methods are called directly; any other callable receives the
        wrapped protocol as an extra first argument.
        """
        # NOTE(review): the non-method branch also fires for static methods
        # and bound builtin methods, which then get an unexpected leading
        # argument -- confirm whether that is intended.
        if isinstance(func, types.MethodType):
            return func(*args, **kwargs)
        return func(self.protocol, *args, **kwargs)

View file

@ -17,4 +17,5 @@
# under the License. # under the License.
# #
__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] __all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
'TJSONProtocol', 'TProtocol']

View file

@ -17,71 +17,71 @@
# under the License. # under the License.
# #
import BaseHTTPServer from six.moves import BaseHTTPServer
from thrift.server import TServer from thrift.server import TServer
from thrift.transport import TTransport from thrift.transport import TTransport
class ResponseException(Exception): class ResponseException(Exception):
"""Allows handlers to override the HTTP response """Allows handlers to override the HTTP response
Normally, THttpServer always sends a 200 response. If a handler wants Normally, THttpServer always sends a 200 response. If a handler wants
to override this behavior (e.g., to simulate a misconfigured or to override this behavior (e.g., to simulate a misconfigured or
overloaded web server during testing), it can raise a ResponseException. overloaded web server during testing), it can raise a ResponseException.
The function passed to the constructor will be called with the The function passed to the constructor will be called with the
RequestHandler as its only argument. RequestHandler as its only argument.
""" """
def __init__(self, handler): def __init__(self, handler):
self.handler = handler self.handler = handler
class THttpServer(TServer.TServer): class THttpServer(TServer.TServer):
"""A simple HTTP-based Thrift server """A simple HTTP-based Thrift server
This class is not very performant, but it is useful (for example) for This class is not very performant, but it is useful (for example) for
acting as a mock version of an Apache-based PHP Thrift endpoint. acting as a mock version of an Apache-based PHP Thrift endpoint.
"""
def __init__(self,
processor,
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
server_class=BaseHTTPServer.HTTPServer):
"""Set up protocol factories and HTTP server.
See BaseHTTPServer for server_address.
See TServer for protocol factories.
""" """
if outputProtocolFactory is None: def __init__(self,
outputProtocolFactory = inputProtocolFactory processor,
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
server_class=BaseHTTPServer.HTTPServer):
"""Set up protocol factories and HTTP server.
TServer.TServer.__init__(self, processor, None, None, None, See BaseHTTPServer for server_address.
inputProtocolFactory, outputProtocolFactory) See TServer for protocol factories.
"""
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
thttpserver = self TServer.TServer.__init__(self, processor, None, None, None,
inputProtocolFactory, outputProtocolFactory)
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): thttpserver = self
def do_POST(self):
# Don't care about the request path.
itrans = TTransport.TFileObjectTransport(self.rfile)
otrans = TTransport.TFileObjectTransport(self.wfile)
itrans = TTransport.TBufferedTransport(
itrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
try:
thttpserver.processor.process(iprot, oprot)
except ResponseException as exn:
exn.handler(self)
else:
self.send_response(200)
self.send_header("content-type", "application/x-thrift")
self.end_headers()
self.wfile.write(otrans.getvalue())
self.httpd = server_class(server_address, RequestHander) class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
def do_POST(self):
# Don't care about the request path.
itrans = TTransport.TFileObjectTransport(self.rfile)
otrans = TTransport.TFileObjectTransport(self.wfile)
itrans = TTransport.TBufferedTransport(
itrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
try:
thttpserver.processor.process(iprot, oprot)
except ResponseException as exn:
exn.handler(self)
else:
self.send_response(200)
self.send_header("content-type", "application/x-thrift")
self.end_headers()
self.wfile.write(otrans.getvalue())
def serve(self): self.httpd = server_class(server_address, RequestHander)
self.httpd.serve_forever()
def serve(self):
self.httpd.serve_forever()

View file

@ -24,18 +24,22 @@ only from the main thread.
The thread pool should be sized for concurrent tasks, not The thread pool should be sized for concurrent tasks, not
maximum connections maximum connections
""" """
import threading
import socket
import Queue
import select
import struct
import logging import logging
import select
import socket
import struct
import threading
from six.moves import queue
from thrift.transport import TTransport from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer'] __all__ = ['TNonblockingServer']
logger = logging.getLogger(__name__)
class Worker(threading.Thread): class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection.""" """Worker is a small helper to process incoming connection."""
@ -54,8 +58,8 @@ class Worker(threading.Thread):
processor.process(iprot, oprot) processor.process(iprot, oprot)
callback(True, otrans.getvalue()) callback(True, otrans.getvalue())
except Exception: except Exception:
logging.exception("Exception while processing request") logger.exception("Exception while processing request")
callback(False, '') callback(False, b'')
WAIT_LEN = 0 WAIT_LEN = 0
WAIT_MESSAGE = 1 WAIT_MESSAGE = 1
@ -85,7 +89,7 @@ def socket_exception(func):
return read return read
class Connection: class Connection(object):
"""Basic class is represented connection. """Basic class is represented connection.
It can be in state: It can be in state:
@ -102,7 +106,7 @@ class Connection:
self.socket.setblocking(False) self.socket.setblocking(False)
self.status = WAIT_LEN self.status = WAIT_LEN
self.len = 0 self.len = 0
self.message = '' self.message = b''
self.lock = threading.Lock() self.lock = threading.Lock()
self.wake_up = wake_up self.wake_up = wake_up
@ -116,21 +120,21 @@ class Connection:
# if we read 0 bytes and self.message is empty, then # if we read 0 bytes and self.message is empty, then
# the client closed the connection # the client closed the connection
if len(self.message) != 0: if len(self.message) != 0:
logging.error("can't read frame size from socket") logger.error("can't read frame size from socket")
self.close() self.close()
return return
self.message += read self.message += read
if len(self.message) == 4: if len(self.message) == 4:
self.len, = struct.unpack('!i', self.message) self.len, = struct.unpack('!i', self.message)
if self.len < 0: if self.len < 0:
logging.error("negative frame size, it seems client " logger.error("negative frame size, it seems client "
"doesn't use FramedTransport") "doesn't use FramedTransport")
self.close() self.close()
elif self.len == 0: elif self.len == 0:
logging.error("empty frame, it's really strange") logger.error("empty frame, it's really strange")
self.close() self.close()
else: else:
self.message = '' self.message = b''
self.status = WAIT_MESSAGE self.status = WAIT_MESSAGE
@socket_exception @socket_exception
@ -145,8 +149,8 @@ class Connection:
elif self.status == WAIT_MESSAGE: elif self.status == WAIT_MESSAGE:
read = self.socket.recv(self.len - len(self.message)) read = self.socket.recv(self.len - len(self.message))
if len(read) == 0: if len(read) == 0:
logging.error("can't read frame from socket (get %d of " logger.error("can't read frame from socket (get %d of "
"%d bytes)" % (len(self.message), self.len)) "%d bytes)" % (len(self.message), self.len))
self.close() self.close()
return return
self.message += read self.message += read
@ -160,7 +164,7 @@ class Connection:
sent = self.socket.send(self.message) sent = self.socket.send(self.message)
if sent == len(self.message): if sent == len(self.message):
self.status = WAIT_LEN self.status = WAIT_LEN
self.message = '' self.message = b''
self.len = 0 self.len = 0
else: else:
self.message = self.message[sent:] self.message = self.message[sent:]
@ -183,10 +187,10 @@ class Connection:
self.close() self.close()
self.wake_up() self.wake_up()
return return
self.len = '' self.len = 0
if len(message) == 0: if len(message) == 0:
# it was a oneway request, do not write answer # it was a oneway request, do not write answer
self.message = '' self.message = b''
self.status = WAIT_LEN self.status = WAIT_LEN
else: else:
self.message = struct.pack('!i', len(message)) + message self.message = struct.pack('!i', len(message)) + message
@ -219,7 +223,7 @@ class Connection:
self.socket.close() self.socket.close()
class TNonblockingServer: class TNonblockingServer(object):
"""Non-blocking server.""" """Non-blocking server."""
def __init__(self, def __init__(self,
@ -234,7 +238,7 @@ class TNonblockingServer:
self.out_protocol = outputProtocolFactory or self.in_protocol self.out_protocol = outputProtocolFactory or self.in_protocol
self.threads = int(threads) self.threads = int(threads)
self.clients = {} self.clients = {}
self.tasks = Queue.Queue() self.tasks = queue.Queue()
self._read, self._write = socket.socketpair() self._read, self._write = socket.socketpair()
self.prepared = False self.prepared = False
self._stop = False self._stop = False
@ -250,7 +254,7 @@ class TNonblockingServer:
if self.prepared: if self.prepared:
return return
self.socket.listen() self.socket.listen()
for _ in xrange(self.threads): for _ in range(self.threads):
thread = Worker(self.tasks) thread = Worker(self.tasks)
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
@ -259,7 +263,7 @@ class TNonblockingServer:
def wake_up(self): def wake_up(self):
"""Wake up main thread. """Wake up main thread.
The server usualy waits in select call in we should terminate one. The server usually waits in select call in we should terminate one.
The simplest way is using socketpair. The simplest way is using socketpair.
Select always wait to read from the first socket of socketpair. Select always wait to read from the first socket of socketpair.
@ -267,7 +271,7 @@ class TNonblockingServer:
In this case, we can just write anything to the second socket from In this case, we can just write anything to the second socket from
socketpair. socketpair.
""" """
self._write.send('1') self._write.send(b'1')
def stop(self): def stop(self):
"""Stop the server. """Stop the server.
@ -288,7 +292,7 @@ class TNonblockingServer:
"""Does select on open connections.""" """Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()] readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = [] writable = []
for i, connection in self.clients.items(): for i, connection in list(self.clients.items()):
if connection.is_readable(): if connection.is_readable():
readable.append(connection.fileno()) readable.append(connection.fileno())
if connection.is_writeable(): if connection.is_writeable():
@ -330,7 +334,7 @@ class TNonblockingServer:
def close(self): def close(self):
"""Closes the server.""" """Closes the server."""
for _ in xrange(self.threads): for _ in range(self.threads):
self.tasks.put([None, None, None, None, None]) self.tasks.put([None, None, None, None, None])
self.socket.close() self.socket.close()
self.prepared = False self.prepared = False

View file

@ -19,11 +19,14 @@
import logging import logging
from multiprocessing import Process, Value, Condition, reduction
from TServer import TServer from multiprocessing import Process, Value, Condition
from .TServer import TServer
from thrift.transport.TTransport import TTransportException from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
class TProcessPoolServer(TServer): class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests """Server with a fixed size pool of worker subprocesses to service requests
@ -56,11 +59,13 @@ class TProcessPoolServer(TServer):
while self.isRunning.value: while self.isRunning.value:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
self.serveClient(client) self.serveClient(client)
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
return 0 return 0
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
def serveClient(self, client): def serveClient(self, client):
"""Process input/output from a client for as long as possible""" """Process input/output from a client for as long as possible"""
@ -72,10 +77,10 @@ class TProcessPoolServer(TServer):
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransportException, tx: except TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
otrans.close() otrans.close()
@ -95,8 +100,8 @@ class TProcessPoolServer(TServer):
w.daemon = True w.daemon = True
w.start() w.start()
self.workers.append(w) self.workers.append(w)
except Exception, x: except Exception as x:
logging.exception(x) logger.exception(x)
# wait until the condition is set by stop() # wait until the condition is set by stop()
while True: while True:
@ -107,7 +112,7 @@ class TProcessPoolServer(TServer):
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
break break
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
self.isRunning.value = False self.isRunning.value = False

View file

@ -17,253 +17,260 @@
# under the License. # under the License.
# #
import Queue from six.moves import queue
import logging import logging
import os import os
import sys
import threading import threading
import traceback
from thrift.Thrift import TProcessor
from thrift.protocol import TBinaryProtocol from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport from thrift.transport import TTransport
logger = logging.getLogger(__name__)
class TServer:
"""Base interface for a server, which must have a serve() method.
Three constructors for all servers: class TServer(object):
1) (processor, serverTransport) """Base interface for a server, which must have a serve() method.
2) (processor, serverTransport, transportFactory, protocolFactory)
3) (processor, serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory)
"""
def __init__(self, *args):
if (len(args) == 2):
self.__initArgs__(args[0], args[1],
TTransport.TTransportFactoryBase(),
TTransport.TTransportFactoryBase(),
TBinaryProtocol.TBinaryProtocolFactory(),
TBinaryProtocol.TBinaryProtocolFactory())
elif (len(args) == 4):
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
elif (len(args) == 6):
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
def __initArgs__(self, processor, serverTransport, Three constructors for all servers:
inputTransportFactory, outputTransportFactory, 1) (processor, serverTransport)
inputProtocolFactory, outputProtocolFactory): 2) (processor, serverTransport, transportFactory, protocolFactory)
self.processor = processor 3) (processor, serverTransport,
self.serverTransport = serverTransport inputTransportFactory, outputTransportFactory,
self.inputTransportFactory = inputTransportFactory inputProtocolFactory, outputProtocolFactory)
self.outputTransportFactory = outputTransportFactory """
self.inputProtocolFactory = inputProtocolFactory def __init__(self, *args):
self.outputProtocolFactory = outputProtocolFactory if (len(args) == 2):
self.__initArgs__(args[0], args[1],
TTransport.TTransportFactoryBase(),
TTransport.TTransportFactoryBase(),
TBinaryProtocol.TBinaryProtocolFactory(),
TBinaryProtocol.TBinaryProtocolFactory())
elif (len(args) == 4):
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
elif (len(args) == 6):
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
def serve(self): def __initArgs__(self, processor, serverTransport,
pass inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory):
self.processor = processor
self.serverTransport = serverTransport
self.inputTransportFactory = inputTransportFactory
self.outputTransportFactory = outputTransportFactory
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
def serve(self):
pass
class TSimpleServer(TServer): class TSimpleServer(TServer):
"""Simple single-threaded server that just pumps around one transport.""" """Simple single-threaded server that just pumps around one transport."""
def __init__(self, *args): def __init__(self, *args):
TServer.__init__(self, *args) TServer.__init__(self, *args)
def serve(self): def serve(self):
self.serverTransport.listen() self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True: while True:
self.processor.process(iprot, oprot) client = self.serverTransport.accept()
except TTransport.TTransportException, tx: if not client:
pass continue
except Exception as x: itrans = self.inputTransportFactory.getTransport(client)
logging.exception(x) otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close() itrans.close()
otrans.close() otrans.close()
class TThreadedServer(TServer): class TThreadedServer(TServer):
"""Threaded server that spawns a new thread per each connection.""" """Threaded server that spawns a new thread per each connection."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
TServer.__init__(self, *args) TServer.__init__(self, *args)
self.daemon = kwargs.get("daemon", False) self.daemon = kwargs.get("daemon", False)
def serve(self): def serve(self):
self.serverTransport.listen() self.serverTransport.listen()
while True: while True:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
t = threading.Thread(target=self.handle, args=(client,)) if not client:
t.setDaemon(self.daemon) continue
t.start() t = threading.Thread(target=self.handle, args=(client,))
except KeyboardInterrupt: t.setDaemon(self.daemon)
raise t.start()
except Exception as x: except KeyboardInterrupt:
logging.exception(x) raise
except Exception as x:
logger.exception(x)
def handle(self, client): def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client) otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException, tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
otrans.close() otrans.close()
class TThreadPoolServer(TServer): class TThreadPoolServer(TServer):
"""Server with a fixed size pool of threads which service requests.""" """Server with a fixed size pool of threads which service requests."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
TServer.__init__(self, *args) TServer.__init__(self, *args)
self.clients = Queue.Queue() self.clients = queue.Queue()
self.threads = 10 self.threads = 10
self.daemon = kwargs.get("daemon", False) self.daemon = kwargs.get("daemon", False)
def setNumThreads(self, num): def setNumThreads(self, num):
"""Set the number of worker threads that should be created""" """Set the number of worker threads that should be created"""
self.threads = num self.threads = num
def serveThread(self): def serveThread(self):
"""Loop around getting clients from the shared queue and process them.""" """Loop around getting clients from the shared queue and process them."""
while True: while True:
try: try:
client = self.clients.get() client = self.clients.get()
self.serveClient(client) self.serveClient(client)
except Exception, x: except Exception as x:
logging.exception(x) logger.exception(x)
def serveClient(self, client): def serveClient(self, client):
"""Process input/output from a client for as long as possible""" """Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client) otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException, tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
otrans.close() otrans.close()
def serve(self): def serve(self):
"""Start a fixed number of worker threads and put client into a queue""" """Start a fixed number of worker threads and put client into a queue"""
for i in range(self.threads): for i in range(self.threads):
try: try:
t = threading.Thread(target=self.serveThread) t = threading.Thread(target=self.serveThread)
t.setDaemon(self.daemon) t.setDaemon(self.daemon)
t.start() t.start()
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
# Pump the socket for clients # Pump the socket for clients
self.serverTransport.listen() self.serverTransport.listen()
while True: while True:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
self.clients.put(client) if not client:
except Exception as x: continue
logging.exception(x) self.clients.put(client)
except Exception as x:
logger.exception(x)
class TForkingServer(TServer): class TForkingServer(TServer):
"""A Thrift server that forks a new process for each request """A Thrift server that forks a new process for each request
This is more scalable than the threaded server as it does not cause This is more scalable than the threaded server as it does not cause
GIL contention. GIL contention.
Note that this has different semantics from the threading server. Note that this has different semantics from the threading server.
Specifically, updates to shared variables will no longer be shared. Specifically, updates to shared variables will no longer be shared.
It will also not work on windows. It will also not work on windows.
This code is heavily inspired by SocketServer.ForkingMixIn in the This code is heavily inspired by SocketServer.ForkingMixIn in the
Python stdlib. Python stdlib.
""" """
def __init__(self, *args): def __init__(self, *args):
TServer.__init__(self, *args) TServer.__init__(self, *args)
self.children = [] self.children = []
def serve(self): def serve(self):
def try_close(file): def try_close(file):
try:
file.close()
except IOError as e:
logging.warning(e, exc_info=True)
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
try:
pid = os.fork()
if pid: # parent
# add before collect, otherwise you race w/ waitpid
self.children.append(pid)
self.collect_children()
# Parent must close socket or the connection may not get
# closed promptly
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
try_close(itrans)
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
try: try:
while True: file.close()
self.processor.process(iprot, oprot) except IOError as e:
except TTransport.TTransportException, tx: logger.warning(e, exc_info=True)
pass
except Exception as e:
logging.exception(e)
ecode = 1
finally:
try_close(itrans)
try_close(otrans)
os._exit(ecode) self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
if not client:
continue
try:
pid = os.fork()
except TTransport.TTransportException, tx: if pid: # parent
pass # add before collect, otherwise you race w/ waitpid
except Exception as x: self.children.append(pid)
logging.exception(x) self.collect_children()
def collect_children(self): # Parent must close socket or the connection may not get
while self.children: # closed promptly
try: itrans = self.inputTransportFactory.getTransport(client)
pid, status = os.waitpid(0, os.WNOHANG) otrans = self.outputTransportFactory.getTransport(client)
except os.error: try_close(itrans)
pid = None try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
if pid: iprot = self.inputProtocolFactory.getProtocol(itrans)
self.children.remove(pid) oprot = self.outputProtocolFactory.getProtocol(otrans)
else:
break ecode = 0
try:
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as e:
logger.exception(e)
ecode = 1
finally:
try_close(itrans)
try_close(otrans)
os._exit(ecode)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
def collect_children(self):
while self.children:
try:
pid, status = os.waitpid(0, os.WNOHANG)
except os.error:
pid = None
if pid:
self.children.remove(pid)
else:
break

View file

@ -17,133 +17,178 @@
# under the License. # under the License.
# #
import httplib from io import BytesIO
import os import os
import socket import socket
import sys import sys
import urllib
import urlparse
import warnings import warnings
import base64
from cStringIO import StringIO from six.moves import urllib
from six.moves import http_client
from TTransport import * from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase): class THttpClient(TTransportBase):
"""Http implementation of TTransport base.""" """Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None): def __init__(self, uri_or_host, port=None, path=None):
"""THttpClient supports two different types constructor parameters. """THttpClient supports two different types constructor parameters.
THttpClient(host, port, path) - deprecated THttpClient(host, port, path) - deprecated
THttpClient(uri) THttpClient(uri)
Only the second supports https. Only the second supports https.
""" """
if port is not None: if port is not None:
warnings.warn( warnings.warn(
"Please use the THttpClient('http://host:port/path') syntax", "Please use the THttpClient('http://host:port/path') syntax",
DeprecationWarning, DeprecationWarning,
stacklevel=2) stacklevel=2)
self.host = uri_or_host self.host = uri_or_host
self.port = port self.port = port
assert path assert path
self.path = path self.path = path
self.scheme = 'http' self.scheme = 'http'
else: else:
parsed = urlparse.urlparse(uri_or_host) parsed = urllib.parse.urlparse(uri_or_host)
self.scheme = parsed.scheme self.scheme = parsed.scheme
assert self.scheme in ('http', 'https') assert self.scheme in ('http', 'https')
if self.scheme == 'http': if self.scheme == 'http':
self.port = parsed.port or httplib.HTTP_PORT self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https': elif self.scheme == 'https':
self.port = parsed.port or httplib.HTTPS_PORT self.port = parsed.port or http_client.HTTPS_PORT
self.host = parsed.hostname self.host = parsed.hostname
self.path = parsed.path self.path = parsed.path
if parsed.query: if parsed.query:
self.path += '?%s' % parsed.query self.path += '?%s' % parsed.query
self.__wbuf = StringIO() try:
self.__http = None proxy = urllib.request.getproxies()[self.scheme]
self.__timeout = None except KeyError:
self.__custom_headers = None proxy = None
else:
if urllib.request.proxy_bypass(self.host):
proxy = None
if proxy:
parsed = urllib.parse.urlparse(proxy)
self.realhost = self.host
self.realport = self.port
self.host = parsed.hostname
self.port = parsed.port
self.proxy_auth = self.basic_proxy_auth_header(parsed)
else:
self.realhost = self.realport = self.proxy_auth = None
self.__wbuf = BytesIO()
self.__http = None
self.__http_response = None
self.__timeout = None
self.__custom_headers = None
def open(self): @staticmethod
if self.scheme == 'http': def basic_proxy_auth_header(proxy):
self.__http = httplib.HTTP(self.host, self.port) if proxy is None or not proxy.username:
else: return None
self.__http = httplib.HTTPS(self.host, self.port) ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password))
cr = base64.b64encode(ap).strip()
return "Basic " + cr
def close(self): def using_proxy(self):
self.__http.close() return self.realhost is not None
self.__http = None
def isOpen(self): def open(self):
return self.__http is not None if self.scheme == 'http':
self.__http = http_client.HTTPConnection(self.host, self.port)
elif self.scheme == 'https':
self.__http = http_client.HTTPSConnection(self.host, self.port)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
def setTimeout(self, ms): def close(self):
if not hasattr(socket, 'getdefaulttimeout'): self.__http.close()
raise NotImplementedError self.__http = None
self.__http_response = None
if ms is None: def isOpen(self):
self.__timeout = None return self.__http is not None
else:
self.__timeout = ms / 1000.0
def setCustomHeaders(self, headers): def setTimeout(self, ms):
self.__custom_headers = headers if not hasattr(socket, 'getdefaulttimeout'):
raise NotImplementedError
def read(self, sz): if ms is None:
return self.__http.file.read(sz) self.__timeout = None
else:
self.__timeout = ms / 1000.0
def write(self, buf): def setCustomHeaders(self, headers):
self.__wbuf.write(buf) self.__custom_headers = headers
def __withTimeout(f): def read(self, sz):
def _f(*args, **kwargs): return self.__http_response.read(sz)
orig_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(args[0].__timeout)
result = f(*args, **kwargs)
socket.setdefaulttimeout(orig_timeout)
return result
return _f
def flush(self): def write(self, buf):
if self.isOpen(): self.__wbuf.write(buf)
self.close()
self.open()
# Pull data out of buffer def __withTimeout(f):
data = self.__wbuf.getvalue() def _f(*args, **kwargs):
self.__wbuf = StringIO() orig_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(args[0].__timeout)
try:
result = f(*args, **kwargs)
finally:
socket.setdefaulttimeout(orig_timeout)
return result
return _f
# HTTP request def flush(self):
self.__http.putrequest('POST', self.path) if self.isOpen():
self.close()
self.open()
# Write headers # Pull data out of buffer
self.__http.putheader('Host', self.host) data = self.__wbuf.getvalue()
self.__http.putheader('Content-Type', 'application/x-thrift') self.__wbuf = BytesIO()
self.__http.putheader('Content-Length', str(len(data)))
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: # HTTP request
user_agent = 'Python/THttpClient' if self.using_proxy() and self.scheme == "http":
script = os.path.basename(sys.argv[0]) # need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
if script: self.__http.putrequest('POST', "http://%s:%s%s" %
user_agent = '%s (%s)' % (user_agent, urllib.quote(script)) (self.realhost, self.realport, self.path))
self.__http.putheader('User-Agent', user_agent) else:
self.__http.putrequest('POST', self.path)
if self.__custom_headers: # Write headers
for key, val in self.__custom_headers.iteritems(): self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader(key, val) self.__http.putheader('Content-Length', str(len(data)))
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
self.__http.endheaders() if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
user_agent = 'Python/THttpClient'
script = os.path.basename(sys.argv[0])
if script:
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
self.__http.putheader('User-Agent', user_agent)
# Write payload if self.__custom_headers:
self.__http.send(data) for key, val in six.iteritems(self.__custom_headers):
self.__http.putheader(key, val)
# Get reply to flush the request self.__http.endheaders()
self.code, self.message, self.headers = self.__http.getreply()
# Decorate if we know how to timeout # Write payload
if hasattr(socket, 'getdefaulttimeout'): self.__http.send(data)
flush = __withTimeout(flush)
# Get reply to flush the request
self.__http_response = self.__http.getresponse()
self.code = self.__http_response.status
self.message = self.__http_response.reason
self.headers = self.__http_response.msg
# Decorate if we know how to timeout
if hasattr(socket, 'getdefaulttimeout'):
flush = __withTimeout(flush)

View file

@ -17,186 +17,380 @@
# under the License. # under the License.
# #
import logging
import os import os
import socket import socket
import ssl import ssl
import sys
import warnings
from .sslcompat import _match_hostname, _match_has_ipaddress
from thrift.transport import TSocket from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
warnings.filterwarnings(
'default', category=DeprecationWarning, module=__name__)
class TSSLSocket(TSocket.TSocket):
"""
SSL implementation of client-side TSocket
This class creates outbound sockets wrapped using the class TSSLBase(object):
python standard ssl module for encrypted connections. # SSLContext is not available for Python < 2.7.9
_has_ssl_context = sys.hexversion >= 0x020709F0
The protocol used is set using the class variable # ciphers argument is not available for Python < 2.7.0
SSL_VERSION, which must be one of ssl.PROTOCOL_* and _has_ciphers = sys.hexversion >= 0x020700F0
defaults to ssl.PROTOCOL_TLSv1 for greatest security.
"""
SSL_VERSION = ssl.PROTOCOL_TLSv1
def __init__(self, # For pythoon >= 2.7.9, use latest TLS that both client and server
host='localhost', # supports.
port=9090, # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
validate=True, # For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
ca_certs=None, # unavailable.
unix_socket=None): _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
"""Create SSL TSocket ssl.PROTOCOL_TLSv1
@param validate: Set to False to disable SSL certificate validation def _init_context(self, ssl_version):
@type validate: bool if self._has_ssl_context:
@param ca_certs: Filename to the Certificate Authority pem file, possibly a self._context = ssl.SSLContext(ssl_version)
file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to if self._context.protocol == ssl.PROTOCOL_SSLv23:
the ssl_wrap function as the 'ca_certs' parameter. self._context.options |= ssl.OP_NO_SSLv2
@type ca_certs: str self._context.options |= ssl.OP_NO_SSLv3
else:
self._context = None
self._ssl_version = ssl_version
Raises an IOError exception if validate is True and the ca_certs file is @property
None, not present or unreadable. def _should_verify(self):
if self._has_ssl_context:
return self._context.verify_mode != ssl.CERT_NONE
else:
return self.cert_reqs != ssl.CERT_NONE
@property
def ssl_version(self):
if self._has_ssl_context:
return self.ssl_context.protocol
else:
return self._ssl_version
@property
def ssl_context(self):
return self._context
SSL_VERSION = _default_protocol
""" """
self.validate = validate Default SSL version.
self.is_valid = False For backword compatibility, it can be modified.
self.peercert = None Use __init__ keywoard argument "ssl_version" instead.
if not validate: """
self.cert_reqs = ssl.CERT_NONE
else:
self.cert_reqs = ssl.CERT_REQUIRED
self.ca_certs = ca_certs
if validate:
if ca_certs is None or not os.access(ca_certs, os.R_OK):
raise IOError('Certificate Authority ca_certs file "%s" '
'is not readable, cannot validate SSL '
'certificates.' % (ca_certs))
TSocket.TSocket.__init__(self, host, port, unix_socket)
def open(self): def _deprecated_arg(self, args, kwargs, pos, key):
try: if len(args) <= pos:
res0 = self._resolveAddr() return
for res in res0: real_pos = pos + 3
sock_family, sock_type = res[0:2] warnings.warn(
ip_port = res[4] '%dth positional argument is deprecated.'
plain_sock = socket.socket(sock_family, sock_type) 'please use keyward argument insteand.'
self.handle = ssl.wrap_socket(plain_sock, % real_pos, DeprecationWarning, stacklevel=3)
ssl_version=self.SSL_VERSION,
do_handshake_on_connect=True, if key in kwargs:
ca_certs=self.ca_certs, raise TypeError(
cert_reqs=self.cert_reqs) 'Duplicate argument: %dth argument and %s keyward argument.'
self.handle.settimeout(self._timeout) % (real_pos, key))
kwargs[key] = args[pos]
def _unix_socket_arg(self, host, port, args, kwargs):
    """Treat a lone positional argument as ``unix_socket``.

    When ``host`` and ``port`` are both ``None``, exactly one extra
    positional argument was given and ``kwargs`` has no ``unix_socket``
    entry yet, that argument is stored under ``kwargs['unix_socket']``.

    Returns True when ``kwargs`` was updated this way, False otherwise.
    """
    option = 'unix_socket'
    # Negated form of the acceptance test; operands are checked in the
    # same left-to-right order so short-circuiting is unchanged.
    if host is not None or port is not None or len(args) != 1 or option in kwargs:
        return False
    kwargs[option] = args[0]
    return True
def __getattr__(self, key):
    """Resolve the deprecated ``SSL_VERSION`` attribute.

    Python calls this hook only after normal attribute lookup has failed.
    A lookup of ``SSL_VERSION`` emits a DeprecationWarning and returns
    ``self.ssl_version``.

    Raises:
        AttributeError: for every other name, so that ``hasattr()`` and
            ``getattr(obj, name, default)`` report a missing attribute
            instead of silently receiving ``None``.
    """
    if key == 'SSL_VERSION':
        warnings.warn(
            'SSL_VERSION is deprecated.'
            'please use ssl_version attribute instead.',
            DeprecationWarning, stacklevel=2)
        return self.ssl_version
    # Falling off the end would return None and make every unknown
    # attribute appear to exist; signal the miss explicitly instead.
    raise AttributeError(
        '%r object has no attribute %r' % (type(self).__name__, key))
def __init__(self, server_side, host, ssl_opts):
    """Consume SSL options and prepare the socket-wrapping configuration.

    Args:
        server_side: True when this object wraps server-side sockets.
        host: peer host; used as the default ``server_hostname`` for
            client-side sockets.
        ssl_opts: dict of SSL options. Recognised keys are popped from
            it: ``ssl_context``, ``server_hostname`` (client side only),
            ``ssl_version``, ``cert_reqs``, ``ca_certs``, ``keyfile``,
            ``certfile`` and ``ciphers``.

    Raises:
        ValueError: if ``ssl_context`` is combined with other options, if
            ``ssl_context`` is given but SSLContext is unavailable, if
            unknown options remain, or if verification is enabled
            without ``ca_certs``.
        IOError: if the ``ca_certs`` file is not readable.
    """
    self._server_side = server_side
    # A changed class-level SSL_VERSION means the caller still uses the
    # deprecated way of selecting the protocol.
    if TSSLBase.SSL_VERSION != self._default_protocol:
        warnings.warn(
            'SSL_VERSION is deprecated.'
            'please use ssl_version keyward argument instead.',
            DeprecationWarning, stacklevel=2)
    self._context = ssl_opts.pop('ssl_context', None)
    self._server_hostname = None
    if not self._server_side:
        self._server_hostname = ssl_opts.pop('server_hostname', host)
    if self._context:
        # A caller-supplied context carries all settings itself, so no
        # other SSL option may accompany it.
        self._custom_context = True
        if ssl_opts:
            raise ValueError(
                'Incompatible arguments: ssl_context and %s'
                % ' '.join(ssl_opts.keys()))
        if not self._has_ssl_context:
            raise ValueError(
                'ssl_context is not available for this version of Python')
    else:
        self._custom_context = False
        ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
        self._init_context(ssl_version)
        self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
        self.ca_certs = ssl_opts.pop('ca_certs', None)
        self.keyfile = ssl_opts.pop('keyfile', None)
        self.certfile = ssl_opts.pop('certfile', None)
        self.ciphers = ssl_opts.pop('ciphers', None)

        if ssl_opts:
            # Format the leftover names into one message; passing them as
            # a second exception argument rendered the error as a tuple.
            raise ValueError(
                'Unknown keyword arguments: %s' % ' '.join(ssl_opts.keys()))

        if self._should_verify:
            if not self.ca_certs:
                raise ValueError(
                    'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
            if not os.access(self.ca_certs, os.R_OK):
                raise IOError('Certificate Authority ca_certs file "%s" '
                              'is not readable, cannot validate SSL '
                              'certificates.' % (self.ca_certs))
@property
def certfile(self):
    """Return the stored certificate file path."""
    return self._certfile

@certfile.setter
def certfile(self, certfile):
    """Validate and store the certificate file path.

    Raises:
        ValueError: on a server-side object when no certfile is given.
        IOError: when a certfile is given but cannot be read.
    """
    required = self._server_side
    if required and not certfile:
        raise ValueError('certfile is needed for server-side')
    # An empty/None path is allowed here (client side); only a concrete
    # path has to be readable.
    unreadable = bool(certfile) and not os.access(certfile, os.R_OK)
    if unreadable:
        raise IOError('No such certfile found: %s' % certfile)
    self._certfile = certfile
def _wrap_socket(self, sock):
    """Wrap a plain socket with SSL using this object's configuration.

    Without SSLContext support, ``ssl.wrap_socket`` is called with keyword
    options built from this object's attributes. With SSLContext support,
    a context that was not supplied by the caller is first configured from
    ``cert_reqs``, ``certfile``/``keyfile``, ``ciphers`` and ``ca_certs``;
    a caller-supplied context is used untouched.

    Returns the wrapped socket.
    """
    if not self._has_ssl_context:
        # Legacy path: no SSLContext, pass everything to ssl.wrap_socket.
        legacy_opts = dict(
            ssl_version=self._ssl_version,
            server_side=self._server_side,
            ca_certs=self.ca_certs,
            keyfile=self.keyfile,
            certfile=self.certfile,
            cert_reqs=self.cert_reqs,
        )
        if self.ciphers:
            if self._has_ciphers:
                legacy_opts['ciphers'] = self.ciphers
            else:
                logger.warning(
                    'ciphers is specified but ignored due to old Python version')
        return ssl.wrap_socket(sock, **legacy_opts)

    context = self.ssl_context
    if not self._custom_context:
        # Only a context created by this class is configured here.
        context.verify_mode = self.cert_reqs
        if self.certfile:
            context.load_cert_chain(self.certfile, self.keyfile)
        if self.ciphers:
            context.set_ciphers(self.ciphers)
        if self.ca_certs:
            context.load_verify_locations(self.ca_certs)
    return context.wrap_socket(
        sock,
        server_side=self._server_side,
        server_hostname=self._server_hostname)
class TSSLSocket(TSocket.TSocket, TSSLBase):
"""
SSL implementation of TSocket
This class creates outbound sockets wrapped using the
python standard ssl module for encrypted connections.
"""
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None,
# **ssl_args):
# Deprecated signature
# def __init__(self, host='localhost', port=9090, validate=True,
# ca_certs=None, keyfile=None, certfile=None,
# unix_socket=None, ciphers=None):
def __init__(self, host='localhost', port=9090, *args, **kwargs):
    """Create a client-side SSL socket transport.

    Positional arguments: ``host``, ``port``, ``unix_socket``

    Keyword arguments passed on to ssl.wrap_socket (see its
    documentation): ``keyfile``, ``certfile``, ``cert_reqs``,
    ``ssl_version``, ``ca_certs``, ``ciphers`` (Python 2.7.0 or later),
    ``server_hostname`` (Python 2.7.9 or later)

    Alternative keyword arguments (Python 2.7.9 or later):
      ``ssl_context``: ssl.SSLContext to be used for
                       SSLContext.wrap_socket
      ``server_hostname``: passed to SSLContext.wrap_socket

    Common keyword argument:
      ``validate_callback`` (cert, hostname) -> None:
          Called after the SSL handshake.  Can raise when the hostname
          does not match the cert.

    The deprecated positional form (``validate``, ``ca_certs``,
    ``keyfile``, ``certfile``, ``unix_socket``, ``ciphers``) is still
    accepted; more than six extra positionals raise TypeError.
    """
    self.is_valid = False
    self.peercert = None

    if args:
        if len(args) > 6:
            raise TypeError('Too many positional argument')
        # NOTE(review): _unix_socket_arg is defined outside this block;
        # presumably it recognises the new-style (host, port, unix_socket)
        # call.  Only when it declines is args[0] treated as the
        # deprecated ``validate`` flag.
        if not self._unix_socket_arg(host, port, args, kwargs):
            self._deprecated_arg(args, kwargs, 0, 'validate')
        remaining_legacy = ('ca_certs', 'keyfile', 'certfile',
                            'unix_socket', 'ciphers')
        for position, name in enumerate(remaining_legacy, 1):
            self._deprecated_arg(args, kwargs, position, name)

    validate = kwargs.pop('validate', None)
    if validate is not None:
        # Translate the deprecated boolean into the cert_reqs constant.
        if validate:
            cert_reqs_name = 'CERT_REQUIRED'
        else:
            cert_reqs_name = 'CERT_NONE'
        warnings.warn(
            'validate is deprecated. please use cert_reqs=ssl.%s instead'
            % cert_reqs_name,
            DeprecationWarning, stacklevel=2)
        if 'cert_reqs' in kwargs:
            raise TypeError('Cannot specify both validate and cert_reqs')
        if validate:
            kwargs['cert_reqs'] = ssl.CERT_REQUIRED
        else:
            kwargs['cert_reqs'] = ssl.CERT_NONE

    # These two keys are consumed here; whatever remains in kwargs is
    # handed to TSSLBase.
    unix_socket = kwargs.pop('unix_socket', None)
    self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
    TSSLBase.__init__(self, False, host, kwargs)
    TSocket.TSocket.__init__(self, host, port, unix_socket)
@property
def validate(self):
    """Deprecated view of ``cert_reqs``: True unless it is ssl.CERT_NONE.

    Emits a DeprecationWarning on every access.
    """
    warnings.warn('validate is deprecated. please use cert_reqs instead',
                  category=DeprecationWarning, stacklevel=2)
    verifying = self.cert_reqs != ssl.CERT_NONE
    return verifying

@validate.setter
def validate(self, value):
    """Deprecated setter: a truthy value selects ssl.CERT_REQUIRED,
    a falsy one ssl.CERT_NONE.  Emits a DeprecationWarning.
    """
    warnings.warn('validate is deprecated. please use cert_reqs instead',
                  category=DeprecationWarning, stacklevel=2)
    if value:
        self.cert_reqs = ssl.CERT_REQUIRED
    else:
        self.cert_reqs = ssl.CERT_NONE
def _do_open(self, family, socktype):
plain_sock = socket.socket(family, socktype)
try: try:
self.handle.connect(ip_port) return self._wrap_socket(plain_sock)
except socket.error as e: except Exception:
if res is not res0[-1]: plain_sock.close()
continue msg = 'failed to initialize SSL'
else: logger.exception(msg)
raise e raise TTransportException(TTransportException.NOT_OPEN, msg)
break
except socket.error as e:
if self._unix_socket:
message = 'Could not connect to secure socket %s' % self._unix_socket
else:
message = 'Could not connect to %s:%d' % (self.host, self.port)
raise TTransportException(type=TTransportException.NOT_OPEN,
message=message)
if self.validate:
self._validate_cert()
def _validate_cert(self): def open(self):
"""internal method to validate the peer's SSL certificate, and to check the super(TSSLSocket, self).open()
commonName of the certificate to ensure it matches the hostname we if self._should_verify:
used to make this connection. Does not support subjectAltName records self.peercert = self.handle.getpeercert()
in certificates. try:
self._validate_callback(self.peercert, self._server_hostname)
self.is_valid = True
except TTransportException:
raise
except Exception as ex:
raise TTransportException(TTransportException.UNKNOWN, str(ex))
raises TTransportException if the certificate fails validation.
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
"""SSL implementation of TServerSocket
This uses the ssl module's wrap_socket() method to provide SSL
negotiated encryption.
""" """
cert = self.handle.getpeercert()
self.peercert = cert
if 'subject' not in cert:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message='No SSL certificate found from %s:%s' % (self.host, self.port))
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
if not isinstance(field, tuple):
continue
cert_pair = field[0]
if len(cert_pair) < 2:
continue
cert_key, cert_value = cert_pair[0:2]
if cert_key != 'commonName':
continue
certhost = cert_value
if certhost == self.host:
# success, cert commonName matches desired hostname
self.is_valid = True
return
else:
raise TTransportException(
type=TTransportException.UNKNOWN,
message='Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
type=TTransportException.UNKNOWN,
message='Could not validate SSL certificate from '
'host "%s". Cert=%s' % (self.host, cert))
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
# Deprecated signature
# def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
def __init__(self, host=None, port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
class TSSLServerSocket(TSocket.TServerSocket): Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
"""SSL implementation of TServerSocket ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
See ssl.wrap_socket documentation.
This uses the ssl module's wrap_socket() method to provide SSL Alternative keyword arguments: (Python 2.7.9 or later)
negotiated encryption. ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
""" ``server_hostname``: Passed to SSLContext.wrap_socket
SSL_VERSION = ssl.PROTOCOL_TLSv1
def __init__(self, Common keyword argument:
host=None, ``validate_callback`` (cert, hostname) -> None:
port=9090, Called after SSL handshake. Can raise when hostname does not
certfile='cert.pem', match the cert.
unix_socket=None): """
"""Initialize a TSSLServerSocket if args:
if len(args) > 3:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'certfile')
self._deprecated_arg(args, kwargs, 1, 'unix_socket')
self._deprecated_arg(args, kwargs, 2, 'ciphers')
@param certfile: filename of the server certificate, defaults to cert.pem if 'ssl_context' not in kwargs:
@type certfile: str # Preserve existing behaviors for default values
@param host: The hostname or IP to bind the listen socket to, if 'cert_reqs' not in kwargs:
i.e. 'localhost' for only allowing local network connections. kwargs['cert_reqs'] = ssl.CERT_NONE
Pass None to bind to all interfaces. if'certfile' not in kwargs:
@type host: str kwargs['certfile'] = 'cert.pem'
@param port: The port to listen on for inbound connections.
@type port: int
"""
self.setCertfile(certfile)
TSocket.TServerSocket.__init__(self, host, port)
def setCertfile(self, certfile): unix_socket = kwargs.pop('unix_socket', None)
"""Set or change the server certificate file used to wrap new connections. self._validate_callback = \
kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, True, None, kwargs)
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
if self._should_verify and not _match_has_ipaddress:
raise ValueError('Need ipaddress and backports.ssl_match_hostname '
'module to verify client certificate')
@param certfile: The filename of the server certificate, def setCertfile(self, certfile):
i.e. '/etc/certs/server.pem' """Set or change the server certificate file used to wrap new
@type certfile: str connections.
Raises an IOError exception if the certfile is not present or unreadable. @param certfile: The filename of the server certificate,
""" i.e. '/etc/certs/server.pem'
if not os.access(certfile, os.R_OK): @type certfile: str
raise IOError('No such certfile found: %s' % (certfile))
self.certfile = certfile
def accept(self): Raises an IOError exception if the certfile is not present or unreadable.
plain_client, addr = self.handle.accept() """
try: warnings.warn(
client = ssl.wrap_socket(plain_client, certfile=self.certfile, 'setCertfile is deprecated. please use certfile property instead.',
server_side=True, ssl_version=self.SSL_VERSION) DeprecationWarning, stacklevel=2)
except ssl.SSLError as ssl_exc: self.certfile = certfile
# failed handshake/ssl wrap, close socket to client
plain_client.close() def accept(self):
# raise ssl_exc plain_client, addr = self.handle.accept()
# We can't raise the exception, because it kills most TServer derived try:
# serve() methods. client = self._wrap_socket(plain_client)
# Instead, return None, and let the TServer instance deal with it in except ssl.SSLError:
# other exception handling. (but TSimpleServer dies anyway) logger.exception('Error while accepting from %s', addr)
return None # failed handshake/ssl wrap, close socket to client
result = TSocket.TSocket() plain_client.close()
result.setHandle(client) # raise
return result # We can't raise the exception, because it kills most TServer derived
# serve() methods.
# Instead, return None, and let the TServer instance deal with it in
# other exception handling. (but TSimpleServer dies anyway)
return None
if self._should_verify:
client.peercert = client.getpeercert()
try:
self._validate_callback(client.peercert, addr[0])
client.is_valid = True
except Exception:
logger.warn('Failed to validate client certificate address: %s',
addr[0], exc_info=True)
client.close()
plain_client.close()
return None
result = TSocket.TSocket()
result.handle = client
return result

View file

@ -18,159 +18,175 @@
# #
import errno import errno
import logging
import os import os
import socket import socket
import sys import sys
from TTransport import * from .TTransport import TTransportBase, TTransportException, TServerTransportBase
logger = logging.getLogger(__name__)
class TSocketBase(TTransportBase): class TSocketBase(TTransportBase):
def _resolveAddr(self): def _resolveAddr(self):
if self._unix_socket is not None: if self._unix_socket is not None:
return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
self._unix_socket)] self._unix_socket)]
else: else:
return socket.getaddrinfo(self.host, return socket.getaddrinfo(self.host,
self.port, self.port,
socket.AF_UNSPEC, self._socket_family,
socket.SOCK_STREAM, socket.SOCK_STREAM,
0, 0,
socket.AI_PASSIVE | socket.AI_ADDRCONFIG) socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
def close(self): def close(self):
if self.handle: if self.handle:
self.handle.close() self.handle.close()
self.handle = None self.handle = None
class TSocket(TSocketBase): class TSocket(TSocketBase):
"""Socket implementation of TTransport base.""" """Socket implementation of TTransport base."""
def __init__(self, host='localhost', port=9090, unix_socket=None): def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
"""Initialize a TSocket """Initialize a TSocket
@param host(str) The host to connect to. @param host(str) The host to connect to.
@param port(int) The (TCP) port to connect to. @param port(int) The (TCP) port to connect to.
@param unix_socket(str) The filename of a unix socket to connect to. @param unix_socket(str) The filename of a unix socket to connect to.
(host and port will be ignored.) (host and port will be ignored.)
""" @param socket_family(int) The socket family to use with this socket.
self.host = host """
self.port = port self.host = host
self.handle = None self.port = port
self._unix_socket = unix_socket self.handle = None
self._timeout = None self._unix_socket = unix_socket
self._timeout = None
self._socket_family = socket_family
def setHandle(self, h): def setHandle(self, h):
self.handle = h self.handle = h
def isOpen(self): def isOpen(self):
return self.handle is not None return self.handle is not None
def setTimeout(self, ms): def setTimeout(self, ms):
if ms is None: if ms is None:
self._timeout = None self._timeout = None
else: else:
self._timeout = ms / 1000.0 self._timeout = ms / 1000.0
if self.handle is not None: if self.handle is not None:
self.handle.settimeout(self._timeout) self.handle.settimeout(self._timeout)
def open(self): def _do_open(self, family, socktype):
try: return socket.socket(family, socktype)
res0 = self._resolveAddr()
for res in res0: @property
self.handle = socket.socket(res[0], res[1]) def _address(self):
self.handle.settimeout(self._timeout) return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)
def open(self):
if self.handle:
raise TTransportException(TTransportException.ALREADY_OPEN)
try: try:
self.handle.connect(res[4]) addrs = self._resolveAddr()
except socket.error, e: except socket.gaierror:
if res is not res0[-1]: msg = 'failed to resolve sockaddr for ' + str(self._address)
continue logger.exception(msg)
else: raise TTransportException(TTransportException.NOT_OPEN, msg)
raise e for family, socktype, _, _, sockaddr in addrs:
break handle = self._do_open(family, socktype)
except socket.error, e: handle.settimeout(self._timeout)
if self._unix_socket: try:
message = 'Could not connect to socket %s' % self._unix_socket handle.connect(sockaddr)
else: self.handle = handle
message = 'Could not connect to %s:%d' % (self.host, self.port) return
raise TTransportException(type=TTransportException.NOT_OPEN, except socket.error:
message=message) handle.close()
logger.info('Could not connect to %s', sockaddr, exc_info=True)
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
addrs))
logger.error(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
def read(self, sz): def read(self, sz):
try: try:
buff = self.handle.recv(sz) buff = self.handle.recv(sz)
except socket.error, e: except socket.error as e:
if (e.args[0] == errno.ECONNRESET and if (e.args[0] == errno.ECONNRESET and
(sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
# freebsd and Mach don't follow POSIX semantic of recv # freebsd and Mach don't follow POSIX semantic of recv
# and fail with ECONNRESET if peer performed shutdown. # and fail with ECONNRESET if peer performed shutdown.
# See corresponding comment and code in TSocket::read() # See corresponding comment and code in TSocket::read()
# in lib/cpp/src/transport/TSocket.cpp. # in lib/cpp/src/transport/TSocket.cpp.
self.close() self.close()
# Trigger the check to raise the END_OF_FILE exception below. # Trigger the check to raise the END_OF_FILE exception below.
buff = '' buff = ''
else: else:
raise raise
if len(buff) == 0: if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE, raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes') message='TSocket read 0 bytes')
return buff return buff
def write(self, buff): def write(self, buff):
if not self.handle: if not self.handle:
raise TTransportException(type=TTransportException.NOT_OPEN, raise TTransportException(type=TTransportException.NOT_OPEN,
message='Transport not open') message='Transport not open')
sent = 0 sent = 0
have = len(buff) have = len(buff)
while sent < have: while sent < have:
plus = self.handle.send(buff) plus = self.handle.send(buff)
if plus == 0: if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE, raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes') message='TSocket sent 0 bytes')
sent += plus sent += plus
buff = buff[plus:] buff = buff[plus:]
def flush(self): def flush(self):
pass pass
class TServerSocket(TSocketBase, TServerTransportBase): class TServerSocket(TSocketBase, TServerTransportBase):
"""Socket implementation of TServerTransport base.""" """Socket implementation of TServerTransport base."""
def __init__(self, host=None, port=9090, unix_socket=None): def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
self.host = host self.host = host
self.port = port self.port = port
self._unix_socket = unix_socket self._unix_socket = unix_socket
self.handle = None self._socket_family = socket_family
self.handle = None
def listen(self): def listen(self):
res0 = self._resolveAddr() res0 = self._resolveAddr()
for res in res0: socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
if res[0] is socket.AF_INET6 or res is res0[-1]: for res in res0:
break if res[0] is socket_family or res is res0[-1]:
break
# We need remove the old unix socket if the file exists and # We need remove the old unix socket if the file exists and
# nobody is listening on it. # nobody is listening on it.
if self._unix_socket: if self._unix_socket:
tmp = socket.socket(res[0], res[1]) tmp = socket.socket(res[0], res[1])
try: try:
tmp.connect(res[4]) tmp.connect(res[4])
except socket.error, err: except socket.error as err:
eno, message = err.args eno, message = err.args
if eno == errno.ECONNREFUSED: if eno == errno.ECONNREFUSED:
os.unlink(res[4]) os.unlink(res[4])
self.handle = socket.socket(res[0], res[1]) self.handle = socket.socket(res[0], res[1])
self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(self.handle, 'settimeout'): if hasattr(self.handle, 'settimeout'):
self.handle.settimeout(None) self.handle.settimeout(None)
self.handle.bind(res[4]) self.handle.bind(res[4])
self.handle.listen(128) self.handle.listen(128)
def accept(self): def accept(self):
client, addr = self.handle.accept() client, addr = self.handle.accept()
result = TSocket() result = TSocket()
result.setHandle(client) result.setHandle(client)
return result return result

View file

@ -17,314 +17,436 @@
# under the License. # under the License.
# #
from cStringIO import StringIO
from struct import pack, unpack from struct import pack, unpack
from thrift.Thrift import TException from thrift.Thrift import TException
from ..compat import BufferIO
class TTransportException(TException): class TTransportException(TException):
"""Custom Transport Exception class""" """Custom Transport Exception class"""
UNKNOWN = 0 UNKNOWN = 0
NOT_OPEN = 1 NOT_OPEN = 1
ALREADY_OPEN = 2 ALREADY_OPEN = 2
TIMED_OUT = 3 TIMED_OUT = 3
END_OF_FILE = 4 END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message) TException.__init__(self, message)
self.type = type self.type = type
class TTransportBase: class TTransportBase(object):
"""Base class for Thrift transport layer.""" """Base class for Thrift transport layer."""
def isOpen(self): def isOpen(self):
pass pass
def open(self): def open(self):
pass pass
def close(self): def close(self):
pass pass
def read(self, sz): def read(self, sz):
pass pass
def readAll(self, sz): def readAll(self, sz):
buff = '' buff = b''
have = 0 have = 0
while (have < sz): while (have < sz):
chunk = self.read(sz - have) chunk = self.read(sz - have)
have += len(chunk) have += len(chunk)
buff += chunk buff += chunk
if len(chunk) == 0: if len(chunk) == 0:
raise EOFError() raise EOFError()
return buff return buff
def write(self, buf): def write(self, buf):
pass pass
def flush(self): def flush(self):
pass pass
# This class should be thought of as an interface. # This class should be thought of as an interface.
class CReadableTransport: class CReadableTransport(object):
"""base class for transports that are readable from C""" """base class for transports that are readable from C"""
# TODO(dreiss): Think about changing this interface to allow us to use # TODO(dreiss): Think about changing this interface to allow us to use
# a (Python, not c) StringIO instead, because it allows # a (Python, not c) StringIO instead, because it allows
# you to write after reading. # you to write after reading.
# NOTE: This is a classic class, so properties will NOT work # NOTE: This is a classic class, so properties will NOT work
# correctly for setting. # correctly for setting.
@property @property
def cstringio_buf(self): def cstringio_buf(self):
"""A cStringIO buffer that contains the current chunk we are reading.""" """A cStringIO buffer that contains the current chunk we are reading."""
pass pass
def cstringio_refill(self, partialread, reqlen): def cstringio_refill(self, partialread, reqlen):
"""Refills cstringio_buf. """Refills cstringio_buf.
Returns the currently used buffer (which can but need not be the same as Returns the currently used buffer (which can but need not be the same as
the old cstringio_buf). partialread is what the C code has read from the the old cstringio_buf). partialread is what the C code has read from the
buffer, and should be inserted into the buffer before any more reads. The buffer, and should be inserted into the buffer before any more reads. The
return value must be a new, not borrowed reference. Something along the return value must be a new, not borrowed reference. Something along the
lines of self._buf should be fine. lines of self._buf should be fine.
If reqlen bytes can't be read, throw EOFError. If reqlen bytes can't be read, throw EOFError.
""" """
pass pass
class TServerTransportBase: class TServerTransportBase(object):
"""Base class for Thrift server transports.""" """Base class for Thrift server transports."""
def listen(self): def listen(self):
pass pass
def accept(self): def accept(self):
pass pass
def close(self): def close(self):
pass pass
class TTransportFactoryBase: class TTransportFactoryBase(object):
"""Base class for a Transport Factory""" """Base class for a Transport Factory"""
def getTransport(self, trans): def getTransport(self, trans):
return trans return trans
class TBufferedTransportFactory: class TBufferedTransportFactory(object):
"""Factory transport that builds buffered transports""" """Factory transport that builds buffered transports"""
def getTransport(self, trans): def getTransport(self, trans):
buffered = TBufferedTransport(trans) buffered = TBufferedTransport(trans)
return buffered return buffered
class TBufferedTransport(TTransportBase, CReadableTransport): class TBufferedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and buffers its I/O. """Class that wraps another transport and buffers its I/O.
The implementation uses a (configurable) fixed-size read buffer The implementation uses a (configurable) fixed-size read buffer
but buffers all writes until a flush is performed. but buffers all writes until a flush is performed.
""" """
DEFAULT_BUFFER = 4096 DEFAULT_BUFFER = 4096
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
self.__trans = trans self.__trans = trans
self.__wbuf = StringIO() self.__wbuf = BufferIO()
self.__rbuf = StringIO("") # Pass string argument to initialize read buffer as cStringIO.InputType
self.__rbuf_size = rbuf_size self.__rbuf = BufferIO(b'')
self.__rbuf_size = rbuf_size
def isOpen(self): def isOpen(self):
return self.__trans.isOpen() return self.__trans.isOpen()
def open(self): def open(self):
return self.__trans.open() return self.__trans.open()
def close(self): def close(self):
return self.__trans.close() return self.__trans.close()
def read(self, sz): def read(self, sz):
ret = self.__rbuf.read(sz) ret = self.__rbuf.read(sz)
if len(ret) != 0: if len(ret) != 0:
return ret return ret
self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
return self.__rbuf.read(sz)
self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) def write(self, buf):
return self.__rbuf.read(sz) try:
self.__wbuf.write(buf)
except Exception as e:
# on exception reset wbuf so it doesn't contain a partial function call
self.__wbuf = BufferIO()
raise e
self.__wbuf.getvalue()
def write(self, buf): def flush(self):
self.__wbuf.write(buf) out = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BufferIO()
self.__trans.write(out)
self.__trans.flush()
def flush(self): # Implement the CReadableTransport interface.
out = self.__wbuf.getvalue() @property
# reset wbuf before write/flush to preserve state on underlying failure def cstringio_buf(self):
self.__wbuf = StringIO() return self.__rbuf
self.__trans.write(out)
self.__trans.flush()
# Implement the CReadableTransport interface. def cstringio_refill(self, partialread, reqlen):
@property retstring = partialread
def cstringio_buf(self): if reqlen < self.__rbuf_size:
return self.__rbuf # try to make a read of as much as we can.
retstring += self.__trans.read(self.__rbuf_size)
def cstringio_refill(self, partialread, reqlen): # but make sure we do read reqlen bytes.
retstring = partialread if len(retstring) < reqlen:
if reqlen < self.__rbuf_size: retstring += self.__trans.readAll(reqlen - len(retstring))
# try to make a read of as much as we can.
retstring += self.__trans.read(self.__rbuf_size)
# but make sure we do read reqlen bytes. self.__rbuf = BufferIO(retstring)
if len(retstring) < reqlen: return self.__rbuf
retstring += self.__trans.readAll(reqlen - len(retstring))
self.__rbuf = StringIO(retstring)
return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport): class TMemoryBuffer(TTransportBase, CReadableTransport):
"""Wraps a cStringIO object as a TTransport. """Wraps a cBytesIO object as a TTransport.
NOTE: Unlike the C++ version of this class, you cannot write to it NOTE: Unlike the C++ version of this class, you cannot write to it
then immediately read from it. If you want to read from a then immediately read from it. If you want to read from a
TMemoryBuffer, you must either pass a string to the constructor. TMemoryBuffer, you must either pass a string to the constructor.
TODO(dreiss): Make this work like the C++ version. TODO(dreiss): Make this work like the C++ version.
""" """
def __init__(self, value=None): def __init__(self, value=None):
"""value -- a value to read from for stringio """value -- a value to read from for stringio
If value is set, this will be a transport for reading, If value is set, this will be a transport for reading,
otherwise, it is for writing""" otherwise, it is for writing"""
if value is not None: if value is not None:
self._buffer = StringIO(value) self._buffer = BufferIO(value)
else: else:
self._buffer = StringIO() self._buffer = BufferIO()
def isOpen(self): def isOpen(self):
return not self._buffer.closed return not self._buffer.closed
def open(self): def open(self):
pass pass
def close(self): def close(self):
self._buffer.close() self._buffer.close()
def read(self, sz): def read(self, sz):
return self._buffer.read(sz) return self._buffer.read(sz)
def write(self, buf): def write(self, buf):
self._buffer.write(buf) self._buffer.write(buf)
def flush(self): def flush(self):
pass pass
def getvalue(self): def getvalue(self):
return self._buffer.getvalue() return self._buffer.getvalue()
# Implement the CReadableTransport interface. # Implement the CReadableTransport interface.
@property @property
def cstringio_buf(self): def cstringio_buf(self):
return self._buffer return self._buffer
def cstringio_refill(self, partialread, reqlen): def cstringio_refill(self, partialread, reqlen):
# only one shot at reading... # only one shot at reading...
raise EOFError() raise EOFError()
class TFramedTransportFactory: class TFramedTransportFactory(object):
"""Factory transport that builds framed transports""" """Factory transport that builds framed transports"""
def getTransport(self, trans): def getTransport(self, trans):
framed = TFramedTransport(trans) framed = TFramedTransport(trans)
return framed return framed
class TFramedTransport(TTransportBase, CReadableTransport): class TFramedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and frames its I/O when writing.""" """Class that wraps another transport and frames its I/O when writing."""
def __init__(self, trans,): def __init__(self, trans,):
self.__trans = trans self.__trans = trans
self.__rbuf = StringIO() self.__rbuf = BufferIO(b'')
self.__wbuf = StringIO() self.__wbuf = BufferIO()
def isOpen(self): def isOpen(self):
return self.__trans.isOpen() return self.__trans.isOpen()
def open(self): def open(self):
return self.__trans.open() return self.__trans.open()
def close(self): def close(self):
return self.__trans.close() return self.__trans.close()
def read(self, sz): def read(self, sz):
ret = self.__rbuf.read(sz) ret = self.__rbuf.read(sz)
if len(ret) != 0: if len(ret) != 0:
return ret return ret
self.readFrame() self.readFrame()
return self.__rbuf.read(sz) return self.__rbuf.read(sz)
def readFrame(self): def readFrame(self):
buff = self.__trans.readAll(4) buff = self.__trans.readAll(4)
sz, = unpack('!i', buff) sz, = unpack('!i', buff)
self.__rbuf = StringIO(self.__trans.readAll(sz)) self.__rbuf = BufferIO(self.__trans.readAll(sz))
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
def flush(self): def flush(self):
wout = self.__wbuf.getvalue() wout = self.__wbuf.getvalue()
wsz = len(wout) wsz = len(wout)
# reset wbuf before write/flush to preserve state on underlying failure # reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = StringIO() self.__wbuf = BufferIO()
# N.B.: Doing this string concatenation is WAY cheaper than making # N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in # two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty # Python turn out to be REALLY expensive, but it seems to do a pretty
# good job of managing string buffer operations without excessive copies # good job of managing string buffer operations without excessive copies
buf = pack("!i", wsz) + wout buf = pack("!i", wsz) + wout
self.__trans.write(buf) self.__trans.write(buf)
self.__trans.flush() self.__trans.flush()
# Implement the CReadableTransport interface. # Implement the CReadableTransport interface.
@property @property
def cstringio_buf(self): def cstringio_buf(self):
return self.__rbuf return self.__rbuf
def cstringio_refill(self, prefix, reqlen): def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't # self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore, # ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately. # we can start reading new frames immediately.
while len(prefix) < reqlen: while len(prefix) < reqlen:
self.readFrame() self.readFrame()
prefix += self.__rbuf.getvalue() prefix += self.__rbuf.getvalue()
self.__rbuf = StringIO(prefix) self.__rbuf = BufferIO(prefix)
return self.__rbuf return self.__rbuf
class TFileObjectTransport(TTransportBase): class TFileObjectTransport(TTransportBase):
"""Wraps a file-like object to make it work as a Thrift transport.""" """Wraps a file-like object to make it work as a Thrift transport."""
def __init__(self, fileobj): def __init__(self, fileobj):
self.fileobj = fileobj self.fileobj = fileobj
def isOpen(self): def isOpen(self):
return True return True
def close(self): def close(self):
self.fileobj.close() self.fileobj.close()
def read(self, sz): def read(self, sz):
return self.fileobj.read(sz) return self.fileobj.read(sz)
def write(self, buf): def write(self, buf):
self.fileobj.write(buf) self.fileobj.write(buf)
def flush(self): def flush(self):
self.fileobj.flush() self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
"""
SASL transport
"""
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
def __init__(self, transport, host, service, mechanism='GSSAPI',
             **sasl_kwargs):
    """Set up a SASL client on top of an existing transport.

    transport -- underlying transport to use, typically just a TSocket
    host -- the name of the server, from a SASL perspective
    service -- the name of the server's service, from a SASL perspective
    mechanism -- the name of the preferred mechanism to use

    Any further keyword arguments are forwarded unchanged to the
    puresasl.client.SASLClient constructor.
    """
    # Function-scope import: puresasl is only needed once a SASL
    # transport is actually constructed.
    from puresasl.client import SASLClient

    self.transport = transport
    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

    # Separate read and write buffers; both start out empty.
    self.__rbuf = BufferIO(b'')
    self.__wbuf = BufferIO()
def open(self):
if not self.transport.isOpen():
self.transport.open()
self.send_sasl_msg(self.START, self.sasl.mechanism)
self.send_sasl_msg(self.OK, self.sasl.process())
while True:
status, challenge = self.recv_sasl_msg()
if status == self.OK:
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
elif status == self.COMPLETE:
if not self.sasl.complete:
raise TTransportException(
TTransportException.NOT_OPEN,
"The server erroneously indicated "
"that SASL negotiation was complete")
else:
break
else:
raise TTransportException(
TTransportException.NOT_OPEN,
"Bad SASL negotiation status: %d (%s)"
% (status, challenge))
def send_sasl_msg(self, status, body):
header = pack(">BI", status, len(body))
self.transport.write(header + body)
self.transport.flush()
def recv_sasl_msg(self):
header = self.transport.readAll(5)
status, length = unpack(">BI", header)
if length > 0:
payload = self.transport.readAll(length)
else:
payload = ""
return status, payload
def write(self, data):
self.__wbuf.write(data)
def flush(self):
data = self.__wbuf.getvalue()
encoded = self.sasl.wrap(data)
self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
self.transport.flush()
self.__wbuf = BufferIO()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self._read_frame()
return self.__rbuf.read(sz)
def _read_frame(self):
header = self.transport.readAll(4)
length, = unpack('!i', header)
encoded = self.transport.readAll(length)
self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
def close(self):
self.sasl.dispose()
self.transport.close()
# based on TFramedTransport
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self._read_frame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf

View file

@ -17,14 +17,15 @@
# under the License. # under the License.
# #
from cStringIO import StringIO from io import BytesIO
import struct
from zope.interface import implements, Interface, Attribute from zope.interface import implements, Interface, Attribute
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone connectionDone
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic from twisted.protocols import basic
from twisted.python import log
from twisted.web import server, resource, http from twisted.web import server, resource, http
from thrift.transport import TTransport from thrift.transport import TTransport
@ -33,15 +34,15 @@ from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase): class TMessageSenderTransport(TTransport.TTransportBase):
def __init__(self): def __init__(self):
self.__wbuf = StringIO() self.__wbuf = BytesIO()
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
def flush(self): def flush(self):
msg = self.__wbuf.getvalue() msg = self.__wbuf.getvalue()
self.__wbuf = StringIO() self.__wbuf = BytesIO()
self.sendMessage(msg) return self.sendMessage(msg)
def sendMessage(self, message): def sendMessage(self, message):
raise NotImplementedError raise NotImplementedError
@ -54,7 +55,7 @@ class TCallbackTransport(TMessageSenderTransport):
self.func = func self.func = func
def sendMessage(self, message): def sendMessage(self, message):
self.func(message) return self.func(message)
class ThriftClientProtocol(basic.Int32StringReceiver): class ThriftClientProtocol(basic.Int32StringReceiver):
@ -81,11 +82,18 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
self.started.callback(self.client) self.started.callback(self.client)
def connectionLost(self, reason=connectionDone): def connectionLost(self, reason=connectionDone):
for k, v in self.client._reqs.iteritems(): # the called errbacks can add items to our client's _reqs,
# so we need to use a tmp, and iterate until no more requests
# are added during errbacks
if self.client:
tex = TTransport.TTransportException( tex = TTransport.TTransportException(
type=TTransport.TTransportException.END_OF_FILE, type=TTransport.TTransportException.END_OF_FILE,
message='Connection closed') message='Connection closed (%s)' % reason)
v.errback(tex) while self.client._reqs:
_, v = self.client._reqs.popitem()
v.errback(tex)
del self.client._reqs
self.client = None
def stringReceived(self, frame): def stringReceived(self, frame):
tr = TTransport.TMemoryBuffer(frame) tr = TTransport.TMemoryBuffer(frame)
@ -101,6 +109,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
method(iprot, mtype, rseqid) method(iprot, mtype, rseqid)
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol for Twisted that negotiates SASL on connect.

    ``connectionMade`` runs the SASL handshake and only then hands over to
    ``ThriftClientProtocol.connectionMade``. Once negotiation is finished,
    outgoing messages are SASL-wrapped in ``dispatch`` and incoming frames
    are unwrapped in ``stringReceived``.
    """

    # Status codes carried in the first byte of every negotiation message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """
        # Imported lazily so puresasl is only required when SASL is used.
        from puresasl.client import SASLClient
        # createSASLClient() reads self.SASLClient; the attribute used to be
        # stored only under the misspelled name below, which made every
        # construction with a host fail with AttributeError.
        self.SASLClient = SASLClient
        self.SASLCLient = SASLClient  # misspelled legacy name, kept as alias

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        # Deferred awaiting the next negotiation message; None once the
        # handshake is finished (or before it starts).
        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the SASL client used for negotiation and wrapping."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, prefix its length and send it."""
        encoded = self.sasl.wrap(msg)
        # Plain concatenation works for byte strings on Python 2 and 3;
        # ''.join() rejects bytes on Python 3.
        len_and_encoded = struct.pack('!i', len(encoded)) + encoded
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL handshake, then start the Thrift client.

        Raises TTransportException (NOT_OPEN) when the server reports
        completion before the client side is complete, or answers with any
        status other than OK or COMPLETE.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process is run in a thread so it cannot block the reactor.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # The first argument is the numeric exception type, not
                    # the message text.
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.NOT_OPEN,
                        message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(
                    type=TTransport.TTransportException.NOT_OPEN,
                    message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, 4-byte length, body."""
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            # e.g. the mechanism name; the header is a byte string and a
            # text body cannot be concatenated to it on Python 3.
            body = body.encode('ascii')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred that fires with the next ``(status, challenge)``."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Forward the connection loss only if a client was ever created."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the status byte during negotiation, then delegate framing."""
        if self._sasl_negotiation_deferred:
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            # Slice rather than index: indexing bytes yields an int on
            # Python 3, which struct.unpack does not accept.
            # NOTE(review): this treats every chunk received during
            # negotiation as starting with a status byte - confirm the
            # behavior when a challenge is split across several chunks.
            self._sasl_negotiation_status, = struct.unpack("B", data[0:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Route a complete frame to the negotiation or to the Thrift client."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
class ThriftServerProtocol(basic.Int32StringReceiver): class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1 MAX_LENGTH = 2 ** 31 - 1
@ -126,7 +236,7 @@ class ThriftServerProtocol(basic.Int32StringReceiver):
d = self.factory.processor.process(iprot, oprot) d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError, d.addCallbacks(self.processOk, self.processError,
callbackArgs=(tmo,)) callbackArgs=(tmo,))
class IThriftServerFactory(Interface): class IThriftServerFactory(Interface):
@ -178,7 +288,7 @@ class ThriftClientFactory(ClientFactory):
def buildProtocol(self, addr): def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory, p = self.protocol(self.client_class, self.iprot_factory,
self.oprot_factory) self.oprot_factory)
p.factory = self p.factory = self
return p return p
@ -188,7 +298,7 @@ class ThriftResource(resource.Resource):
allowedMethods = ('POST',) allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory, def __init__(self, processor, inputProtocolFactory,
outputProtocolFactory=None): outputProtocolFactory=None):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None: if outputProtocolFactory is None:

View file

@ -24,225 +24,225 @@ data compression.
from __future__ import division from __future__ import division
import zlib import zlib
from cStringIO import StringIO from .TTransport import TTransportBase, CReadableTransport
from TTransport import TTransportBase, CReadableTransport from ..compat import BufferIO
class TZlibTransportFactory(object): class TZlibTransportFactory(object):
"""Factory transport that builds zlib compressed transports. """Factory transport that builds zlib compressed transports.
This factory caches the last single client/transport that it was passed This factory caches the last single client/transport that it was passed
and returns the same TZlibTransport object that was created. and returns the same TZlibTransport object that was created.
This caching means the TServer class will get the _same_ transport This caching means the TServer class will get the _same_ transport
object for both input and output transports from this factory. object for both input and output transports from this factory.
(For non-threaded scenarios only, since the cache only holds one object) (For non-threaded scenarios only, since the cache only holds one object)
The purpose of this caching is to allocate only one TZlibTransport where The purpose of this caching is to allocate only one TZlibTransport where
only one is really needed (since it must have separate read/write buffers), only one is really needed (since it must have separate read/write buffers),
and makes the statistics from getCompSavings() and getCompRatio() and makes the statistics from getCompSavings() and getCompRatio()
easier to understand. easier to understand.
"""
# class scoped cache of last transport given and zlibtransport returned
_last_trans = None
_last_z = None
def getTransport(self, trans, compresslevel=9):
"""Wrap a transport, trans, with the TZlibTransport
compressed transport class, returning a new
transport to the caller.
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Defaults to 9.
@type compresslevel: int
This method returns a TZlibTransport which wraps the
passed C{trans} TTransport derived instance.
""" """
if trans == self._last_trans: # class scoped cache of last transport given and zlibtransport returned
return self._last_z _last_trans = None
ztrans = TZlibTransport(trans, compresslevel) _last_z = None
self._last_trans = trans
self._last_z = ztrans def getTransport(self, trans, compresslevel=9):
return ztrans """Wrap a transport, trans, with the TZlibTransport
compressed transport class, returning a new
transport to the caller.
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Defaults to 9.
@type compresslevel: int
This method returns a TZlibTransport which wraps the
passed C{trans} TTransport derived instance.
"""
if trans == self._last_trans:
return self._last_z
ztrans = TZlibTransport(trans, compresslevel)
self._last_trans = trans
self._last_z = ztrans
return ztrans
class TZlibTransport(TTransportBase, CReadableTransport): class TZlibTransport(TTransportBase, CReadableTransport):
"""Class that wraps a transport with zlib, compressing writes """Class that wraps a transport with zlib, compressing writes
and decompresses reads, using the python standard and decompresses reads, using the python standard
library zlib module. library zlib module.
"""
# Read buffer size for the python fastbinary C extension,
# the TBinaryProtocolAccelerated class.
DEFAULT_BUFFSIZE = 4096
def __init__(self, trans, compresslevel=9):
"""Create a new TZlibTransport, wrapping C{trans}, another
TTransport derived object.
@param trans: A thrift transport object, i.e. a TSocket() object.
@type trans: TTransport
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Default is 9.
@type compresslevel: int
""" """
self.__trans = trans # Read buffer size for the python fastbinary C extension,
self.compresslevel = compresslevel # the TBinaryProtocolAccelerated class.
self.__rbuf = StringIO() DEFAULT_BUFFSIZE = 4096
self.__wbuf = StringIO()
self._init_zlib()
self._init_stats()
def _reinit_buffers(self): def __init__(self, trans, compresslevel=9):
"""Internal method to initialize/reset the internal StringIO objects """Create a new TZlibTransport, wrapping C{trans}, another
for read and write buffers. TTransport derived object.
"""
self.__rbuf = StringIO()
self.__wbuf = StringIO()
def _init_stats(self): @param trans: A thrift transport object, i.e. a TSocket() object.
"""Internal method to reset the internal statistics counters @type trans: TTransport
for compression ratios and bandwidth savings. @param compresslevel: The zlib compression level, ranging
""" from 0 (no compression) to 9 (best compression). Default is 9.
self.bytes_in = 0 @type compresslevel: int
self.bytes_out = 0 """
self.bytes_in_comp = 0 self.__trans = trans
self.bytes_out_comp = 0 self.compresslevel = compresslevel
self.__rbuf = BufferIO()
self.__wbuf = BufferIO()
self._init_zlib()
self._init_stats()
def _init_zlib(self): def _reinit_buffers(self):
"""Internal method for setting up the zlib compression and """Internal method to initialize/reset the internal StringIO objects
decompression objects. for read and write buffers.
""" """
self._zcomp_read = zlib.decompressobj() self.__rbuf = BufferIO()
self._zcomp_write = zlib.compressobj(self.compresslevel) self.__wbuf = BufferIO()
def getCompRatio(self): def _init_stats(self):
"""Get the current measured compression ratios (in,out) from """Internal method to reset the internal statistics counters
this transport. for compression ratios and bandwidth savings.
"""
self.bytes_in = 0
self.bytes_out = 0
self.bytes_in_comp = 0
self.bytes_out_comp = 0
Returns a tuple of: def _init_zlib(self):
(inbound_compression_ratio, outbound_compression_ratio) """Internal method for setting up the zlib compression and
decompression objects.
"""
self._zcomp_read = zlib.decompressobj()
self._zcomp_write = zlib.compressobj(self.compresslevel)
The compression ratios are computed as: def getCompRatio(self):
compressed / uncompressed """Get the current measured compression ratios (in,out) from
this transport.
E.g., data that compresses by 10x will have a ratio of: 0.10 Returns a tuple of:
and data that compresses to half of ts original size will (inbound_compression_ratio, outbound_compression_ratio)
have a ratio of 0.5
None is returned if no bytes have yet been processed in The compression ratios are computed as:
a particular direction. compressed / uncompressed
"""
r_percent, w_percent = (None, None)
if self.bytes_in > 0:
r_percent = self.bytes_in_comp / self.bytes_in
if self.bytes_out > 0:
w_percent = self.bytes_out_comp / self.bytes_out
return (r_percent, w_percent)
def getCompSavings(self): E.g., data that compresses by 10x will have a ratio of: 0.10
"""Get the current count of saved bytes due to data and data that compresses to half of ts original size will
compression. have a ratio of 0.5
Returns a tuple of: None is returned if no bytes have yet been processed in
(inbound_saved_bytes, outbound_saved_bytes) a particular direction.
"""
r_percent, w_percent = (None, None)
if self.bytes_in > 0:
r_percent = self.bytes_in_comp / self.bytes_in
if self.bytes_out > 0:
w_percent = self.bytes_out_comp / self.bytes_out
return (r_percent, w_percent)
Note: if compression is actually expanding your def getCompSavings(self):
data (only likely with very tiny thrift objects), then """Get the current count of saved bytes due to data
the values returned will be negative. compression.
"""
r_saved = self.bytes_in - self.bytes_in_comp
w_saved = self.bytes_out - self.bytes_out_comp
return (r_saved, w_saved)
def isOpen(self): Returns a tuple of:
"""Return the underlying transport's open status""" (inbound_saved_bytes, outbound_saved_bytes)
return self.__trans.isOpen()
def open(self): Note: if compression is actually expanding your
"""Open the underlying transport""" data (only likely with very tiny thrift objects), then
self._init_stats() the values returned will be negative.
return self.__trans.open() """
r_saved = self.bytes_in - self.bytes_in_comp
w_saved = self.bytes_out - self.bytes_out_comp
return (r_saved, w_saved)
def listen(self): def isOpen(self):
"""Invoke the underlying transport's listen() method""" """Return the underlying transport's open status"""
self.__trans.listen() return self.__trans.isOpen()
def accept(self): def open(self):
"""Accept connections on the underlying transport""" """Open the underlying transport"""
return self.__trans.accept() self._init_stats()
return self.__trans.open()
def close(self): def listen(self):
"""Close the underlying transport,""" """Invoke the underlying transport's listen() method"""
self._reinit_buffers() self.__trans.listen()
self._init_zlib()
return self.__trans.close()
def read(self, sz): def accept(self):
"""Read up to sz bytes from the decompressed bytes buffer, and """Accept connections on the underlying transport"""
read from the underlying transport if the decompression return self.__trans.accept()
buffer is empty.
"""
ret = self.__rbuf.read(sz)
if len(ret) > 0:
return ret
# keep reading from transport until something comes back
while True:
if self.readComp(sz):
break
ret = self.__rbuf.read(sz)
return ret
def readComp(self, sz): def close(self):
"""Read compressed data from the underlying transport, then """Close the underlying transport,"""
decompress it and append it to the internal StringIO read buffer self._reinit_buffers()
""" self._init_zlib()
zbuf = self.__trans.read(sz) return self.__trans.close()
zbuf = self._zcomp_read.unconsumed_tail + zbuf
buf = self._zcomp_read.decompress(zbuf)
self.bytes_in += len(zbuf)
self.bytes_in_comp += len(buf)
old = self.__rbuf.read()
self.__rbuf = StringIO(old + buf)
if len(old) + len(buf) == 0:
return False
return True
def write(self, buf): def read(self, sz):
"""Write some bytes, putting them into the internal write """Read up to sz bytes from the decompressed bytes buffer, and
buffer for eventual compression. read from the underlying transport if the decompression
""" buffer is empty.
self.__wbuf.write(buf) """
ret = self.__rbuf.read(sz)
if len(ret) > 0:
return ret
# keep reading from transport until something comes back
while True:
if self.readComp(sz):
break
ret = self.__rbuf.read(sz)
return ret
def flush(self): def readComp(self, sz):
"""Flush any queued up data in the write buffer and ensure the """Read compressed data from the underlying transport, then
compression buffer is flushed out to the underlying transport decompress it and append it to the internal StringIO read buffer
""" """
wout = self.__wbuf.getvalue() zbuf = self.__trans.read(sz)
if len(wout) > 0: zbuf = self._zcomp_read.unconsumed_tail + zbuf
zbuf = self._zcomp_write.compress(wout) buf = self._zcomp_read.decompress(zbuf)
self.bytes_out += len(wout) self.bytes_in += len(zbuf)
self.bytes_out_comp += len(zbuf) self.bytes_in_comp += len(buf)
else: old = self.__rbuf.read()
zbuf = '' self.__rbuf = BufferIO(old + buf)
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) if len(old) + len(buf) == 0:
self.bytes_out_comp += len(ztail) return False
if (len(zbuf) + len(ztail)) > 0: return True
self.__wbuf = StringIO()
self.__trans.write(zbuf + ztail)
self.__trans.flush()
@property def write(self, buf):
def cstringio_buf(self): """Write some bytes, putting them into the internal write
"""Implement the CReadableTransport interface""" buffer for eventual compression.
return self.__rbuf """
self.__wbuf.write(buf)
def cstringio_refill(self, partialread, reqlen): def flush(self):
"""Implement the CReadableTransport interface for refill""" """Flush any queued up data in the write buffer and ensure the
retstring = partialread compression buffer is flushed out to the underlying transport
if reqlen < self.DEFAULT_BUFFSIZE: """
retstring += self.read(self.DEFAULT_BUFFSIZE) wout = self.__wbuf.getvalue()
while len(retstring) < reqlen: if len(wout) > 0:
retstring += self.read(reqlen - len(retstring)) zbuf = self._zcomp_write.compress(wout)
self.__rbuf = StringIO(retstring) self.bytes_out += len(wout)
return self.__rbuf self.bytes_out_comp += len(zbuf)
else:
zbuf = ''
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
self.bytes_out_comp += len(ztail)
if (len(zbuf) + len(ztail)) > 0:
self.__wbuf = BufferIO()
self.__trans.write(zbuf + ztail)
self.__trans.flush()
@property
def cstringio_buf(self):
"""Implement the CReadableTransport interface"""
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
"""Implement the CReadableTransport interface for refill"""
retstring = partialread
if reqlen < self.DEFAULT_BUFFSIZE:
retstring += self.read(self.DEFAULT_BUFFSIZE)
while len(retstring) < reqlen:
retstring += self.read(reqlen - len(retstring))
self.__rbuf = BufferIO(retstring)
return self.__rbuf

View file

@ -0,0 +1,99 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import sys
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
def legacy_validate_callback(self, cert, hostname):
    """Legacy check that the peer certificate's commonName equals hostname.

    Walks the entries of ``cert['subject']`` and compares the first
    ``commonName`` value found with ``hostname``. subjectAltName records
    are not consulted.

    Returns None when the commonName matches.

    Raises TTransportException when the certificate has no ``subject``,
    when the first commonName differs from ``hostname``, or when no
    commonName entry is present at all.
    """
    if 'subject' not in cert:
        raise TTransportException(
            TTransportException.NOT_OPEN,
            'No SSL certificate found from %s:%s' % (self.host, self.port))

    for entry in cert['subject']:
        # skip anything that does not have the expected tuple structure
        if not isinstance(entry, tuple):
            continue
        pair = entry[0]
        if len(pair) < 2:
            continue
        key, value = pair[0:2]
        if key != 'commonName':
            continue
        # this check should be performed by some sort of Access Manager
        if value != hostname:
            raise TTransportException(
                TTransportException.UNKNOWN,
                'Hostname we connected to "%s" doesn\'t match certificate '
                'provided commonName "%s"' % (self.host, value))
        # success, cert commonName matches desired hostname
        return

    # no commonName entry was found in the subject
    raise TTransportException(
        TTransportException.UNKNOWN,
        'Could not validate SSL certificate from host "%s".  Cert=%s'
        % (hostname, cert))
def _optional_dependencies():
    """Probe for optional modules used for SSL hostname matching.

    Returns a tuple ``(ipaddr, match)``:
      ipaddr: True only when the ``ipaddress`` module imports and the
        backports lookup (when attempted) did not fall through.
      match: the hostname matching callable - the backports
        ``match_hostname`` when it is recent enough, otherwise
        ``ssl.match_hostname``, otherwise ``legacy_validate_callback``.
    """
    try:
        import ipaddress  # noqa
        logger.debug('ipaddress module is available')
        ipaddr = True
    except ImportError:
        # logger.warning: the ``warn`` alias is deprecated and no longer
        # exists on recent Python versions.
        logger.warning('ipaddress module is unavailable')
        ipaddr = False

    # Interpreters older than 3.5.0 final try the backport first.
    if sys.hexversion < 0x030500F0:
        try:
            from backports.ssl_match_hostname import match_hostname, __version__ as ver
            ver = list(map(int, ver.split('.')))
            logger.debug('backports.ssl_match_hostname module is available')
            match = match_hostname
            # major * 10 + minor >= 35, i.e. backport version 3.5 or newer
            if ver[0] * 10 + ver[1] >= 35:
                return ipaddr, match
            else:
                logger.warning('backports.ssl_match_hostname module is too old')
                ipaddr = False
        except ImportError:
            logger.warning('backports.ssl_match_hostname is unavailable')
            ipaddr = False

    try:
        from ssl import match_hostname
        logger.debug('ssl.match_hostname is available')
        match = match_hostname
    except ImportError:
        logger.warning('using legacy validation callback')
        match = legacy_validate_callback
    return ipaddr, match
_match_has_ipaddress, _match_hostname = _optional_dependencies()