mirror of
https://github.com/Unidata/python-awips.git
synced 2025-02-23 14:57:56 -05:00
added thrift
This commit is contained in:
parent
7194175bcb
commit
230171f59b
22 changed files with 4825 additions and 0 deletions
35
thrift/TSCons.py
Normal file
35
thrift/TSCons.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from os import path
|
||||
from SCons.Builder import Builder
|
||||
|
||||
|
||||
def scons_env(env, add=''):
|
||||
opath = path.dirname(path.abspath('$TARGET'))
|
||||
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
|
||||
cppbuild = Builder(action=lstr)
|
||||
env.Append(BUILDERS={'ThriftCpp': cppbuild})
|
||||
|
||||
|
||||
def gen_cpp(env, dir, file):
|
||||
scons_env(env)
|
||||
suffixes = ['_types.h', '_types.cpp']
|
||||
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
|
||||
return env.ThriftCpp(targets, dir + file + '.thrift')
|
38
thrift/TSerialization.py
Normal file
38
thrift/TSerialization.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from protocol import TBinaryProtocol
|
||||
from transport import TTransport
|
||||
|
||||
|
||||
def serialize(thrift_object,
|
||||
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
|
||||
transport = TTransport.TMemoryBuffer()
|
||||
protocol = protocol_factory.getProtocol(transport)
|
||||
thrift_object.write(protocol)
|
||||
return transport.getvalue()
|
||||
|
||||
|
||||
def deserialize(base,
|
||||
buf,
|
||||
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
|
||||
transport = TTransport.TMemoryBuffer(buf)
|
||||
protocol = protocol_factory.getProtocol(transport)
|
||||
base.read(protocol)
|
||||
return base
|
157
thrift/Thrift.py
Normal file
157
thrift/Thrift.py
Normal file
|
@ -0,0 +1,157 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
class TType:
|
||||
STOP = 0
|
||||
VOID = 1
|
||||
BOOL = 2
|
||||
BYTE = 3
|
||||
I08 = 3
|
||||
DOUBLE = 4
|
||||
I16 = 6
|
||||
I32 = 8
|
||||
I64 = 10
|
||||
STRING = 11
|
||||
UTF7 = 11
|
||||
STRUCT = 12
|
||||
MAP = 13
|
||||
SET = 14
|
||||
LIST = 15
|
||||
UTF8 = 16
|
||||
UTF16 = 17
|
||||
|
||||
_VALUES_TO_NAMES = ('STOP',
|
||||
'VOID',
|
||||
'BOOL',
|
||||
'BYTE',
|
||||
'DOUBLE',
|
||||
None,
|
||||
'I16',
|
||||
None,
|
||||
'I32',
|
||||
None,
|
||||
'I64',
|
||||
'STRING',
|
||||
'STRUCT',
|
||||
'MAP',
|
||||
'SET',
|
||||
'LIST',
|
||||
'UTF8',
|
||||
'UTF16')
|
||||
|
||||
|
||||
class TMessageType:
|
||||
CALL = 1
|
||||
REPLY = 2
|
||||
EXCEPTION = 3
|
||||
ONEWAY = 4
|
||||
|
||||
|
||||
class TProcessor:
|
||||
"""Base class for procsessor, which works on two streams."""
|
||||
|
||||
def process(iprot, oprot):
|
||||
pass
|
||||
|
||||
|
||||
class TException(Exception):
|
||||
"""Base class for all thrift exceptions."""
|
||||
|
||||
# BaseException.message is deprecated in Python v[2.6,3.0)
|
||||
if (2, 6, 0) <= sys.version_info < (3, 0):
|
||||
def _get_message(self):
|
||||
return self._message
|
||||
|
||||
def _set_message(self, message):
|
||||
self._message = message
|
||||
message = property(_get_message, _set_message)
|
||||
|
||||
def __init__(self, message=None):
|
||||
Exception.__init__(self, message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class TApplicationException(TException):
|
||||
"""Application level thrift exceptions."""
|
||||
|
||||
UNKNOWN = 0
|
||||
UNKNOWN_METHOD = 1
|
||||
INVALID_MESSAGE_TYPE = 2
|
||||
WRONG_METHOD_NAME = 3
|
||||
BAD_SEQUENCE_ID = 4
|
||||
MISSING_RESULT = 5
|
||||
INTERNAL_ERROR = 6
|
||||
PROTOCOL_ERROR = 7
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
self.type = type
|
||||
|
||||
def __str__(self):
|
||||
if self.message:
|
||||
return self.message
|
||||
elif self.type == self.UNKNOWN_METHOD:
|
||||
return 'Unknown method'
|
||||
elif self.type == self.INVALID_MESSAGE_TYPE:
|
||||
return 'Invalid message type'
|
||||
elif self.type == self.WRONG_METHOD_NAME:
|
||||
return 'Wrong method name'
|
||||
elif self.type == self.BAD_SEQUENCE_ID:
|
||||
return 'Bad sequence ID'
|
||||
elif self.type == self.MISSING_RESULT:
|
||||
return 'Missing result'
|
||||
else:
|
||||
return 'Default (unknown) TApplicationException'
|
||||
|
||||
def read(self, iprot):
|
||||
iprot.readStructBegin()
|
||||
while True:
|
||||
(fname, ftype, fid) = iprot.readFieldBegin()
|
||||
if ftype == TType.STOP:
|
||||
break
|
||||
if fid == 1:
|
||||
if ftype == TType.STRING:
|
||||
self.message = iprot.readString()
|
||||
else:
|
||||
iprot.skip(ftype)
|
||||
elif fid == 2:
|
||||
if ftype == TType.I32:
|
||||
self.type = iprot.readI32()
|
||||
else:
|
||||
iprot.skip(ftype)
|
||||
else:
|
||||
iprot.skip(ftype)
|
||||
iprot.readFieldEnd()
|
||||
iprot.readStructEnd()
|
||||
|
||||
def write(self, oprot):
|
||||
oprot.writeStructBegin('TApplicationException')
|
||||
if self.message is not None:
|
||||
oprot.writeFieldBegin('message', TType.STRING, 1)
|
||||
oprot.writeString(self.message)
|
||||
oprot.writeFieldEnd()
|
||||
if self.type is not None:
|
||||
oprot.writeFieldBegin('type', TType.I32, 2)
|
||||
oprot.writeI32(self.type)
|
||||
oprot.writeFieldEnd()
|
||||
oprot.writeFieldStop()
|
||||
oprot.writeStructEnd()
|
20
thrift/__init__.py
Normal file
20
thrift/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
__all__ = ['Thrift', 'TSCons']
|
81
thrift/protocol/TBase.py
Normal file
81
thrift/protocol/TBase.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import *
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.transport import TTransport
|
||||
|
||||
try:
|
||||
from thrift.protocol import fastbinary
|
||||
except:
|
||||
fastbinary = None
|
||||
|
||||
|
||||
class TBase(object):
|
||||
__slots__ = []
|
||||
|
||||
def __repr__(self):
|
||||
L = ['%s=%r' % (key, getattr(self, key))
|
||||
for key in self.__slots__]
|
||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
for attr in self.__slots__:
|
||||
my_val = getattr(self, attr)
|
||||
other_val = getattr(other, attr)
|
||||
if my_val != other_val:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def read(self, iprot):
|
||||
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
||||
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
||||
self.thrift_spec is not None and
|
||||
fastbinary is not None):
|
||||
fastbinary.decode_binary(self,
|
||||
iprot.trans,
|
||||
(self.__class__, self.thrift_spec))
|
||||
return
|
||||
iprot.readStruct(self, self.thrift_spec)
|
||||
|
||||
def write(self, oprot):
|
||||
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
||||
self.thrift_spec is not None and
|
||||
fastbinary is not None):
|
||||
oprot.trans.write(
|
||||
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
|
||||
return
|
||||
oprot.writeStruct(self, self.thrift_spec)
|
||||
|
||||
|
||||
class TExceptionBase(Exception):
|
||||
# old style class so python2.4 can raise exceptions derived from this
|
||||
# This can't inherit from TBase because of that limitation.
|
||||
__slots__ = []
|
||||
|
||||
__repr__ = TBase.__repr__.im_func
|
||||
__eq__ = TBase.__eq__.im_func
|
||||
__ne__ = TBase.__ne__.im_func
|
||||
read = TBase.read.im_func
|
||||
write = TBase.write.im_func
|
260
thrift/protocol/TBinaryProtocol.py
Normal file
260
thrift/protocol/TBinaryProtocol.py
Normal file
|
@ -0,0 +1,260 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from TProtocol import *
|
||||
from struct import pack, unpack
|
||||
|
||||
|
||||
class TBinaryProtocol(TProtocolBase):
|
||||
"""Binary implementation of the Thrift protocol driver."""
|
||||
|
||||
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
|
||||
# positive, converting this into a long. If we hardcode the int value
|
||||
# instead it'll stay in 32 bit-land.
|
||||
|
||||
# VERSION_MASK = 0xffff0000
|
||||
VERSION_MASK = -65536
|
||||
|
||||
# VERSION_1 = 0x80010000
|
||||
VERSION_1 = -2147418112
|
||||
|
||||
TYPE_MASK = 0x000000ff
|
||||
|
||||
def __init__(self, trans, strictRead=False, strictWrite=True):
|
||||
TProtocolBase.__init__(self, trans)
|
||||
self.strictRead = strictRead
|
||||
self.strictWrite = strictWrite
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
if self.strictWrite:
|
||||
self.writeI32(TBinaryProtocol.VERSION_1 | type)
|
||||
self.writeString(name)
|
||||
self.writeI32(seqid)
|
||||
else:
|
||||
self.writeString(name)
|
||||
self.writeByte(type)
|
||||
self.writeI32(seqid)
|
||||
|
||||
def writeMessageEnd(self):
|
||||
pass
|
||||
|
||||
def writeStructBegin(self, name):
|
||||
pass
|
||||
|
||||
def writeStructEnd(self):
|
||||
pass
|
||||
|
||||
def writeFieldBegin(self, name, type, id):
|
||||
self.writeByte(type)
|
||||
self.writeI16(id)
|
||||
|
||||
def writeFieldEnd(self):
|
||||
pass
|
||||
|
||||
def writeFieldStop(self):
|
||||
self.writeByte(TType.STOP)
|
||||
|
||||
def writeMapBegin(self, ktype, vtype, size):
|
||||
self.writeByte(ktype)
|
||||
self.writeByte(vtype)
|
||||
self.writeI32(size)
|
||||
|
||||
def writeMapEnd(self):
|
||||
pass
|
||||
|
||||
def writeListBegin(self, etype, size):
|
||||
self.writeByte(etype)
|
||||
self.writeI32(size)
|
||||
|
||||
def writeListEnd(self):
|
||||
pass
|
||||
|
||||
def writeSetBegin(self, etype, size):
|
||||
self.writeByte(etype)
|
||||
self.writeI32(size)
|
||||
|
||||
def writeSetEnd(self):
|
||||
pass
|
||||
|
||||
def writeBool(self, bool):
|
||||
if bool:
|
||||
self.writeByte(1)
|
||||
else:
|
||||
self.writeByte(0)
|
||||
|
||||
def writeByte(self, byte):
|
||||
buff = pack("!b", byte)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeI16(self, i16):
|
||||
buff = pack("!h", i16)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeI32(self, i32):
|
||||
buff = pack("!i", i32)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeI64(self, i64):
|
||||
buff = pack("!q", i64)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeDouble(self, dub):
|
||||
buff = pack("!d", dub)
|
||||
self.trans.write(buff)
|
||||
|
||||
def writeString(self, str):
|
||||
self.writeI32(len(str))
|
||||
self.trans.write(str)
|
||||
|
||||
def readMessageBegin(self):
|
||||
sz = self.readI32()
|
||||
if sz < 0:
|
||||
version = sz & TBinaryProtocol.VERSION_MASK
|
||||
if version != TBinaryProtocol.VERSION_1:
|
||||
raise TProtocolException(
|
||||
type=TProtocolException.BAD_VERSION,
|
||||
message='Bad version in readMessageBegin: %d' % (sz))
|
||||
type = sz & TBinaryProtocol.TYPE_MASK
|
||||
name = self.readString()
|
||||
seqid = self.readI32()
|
||||
else:
|
||||
if self.strictRead:
|
||||
raise TProtocolException(type=TProtocolException.BAD_VERSION,
|
||||
message='No protocol version header')
|
||||
name = self.trans.readAll(sz)
|
||||
type = self.readByte()
|
||||
seqid = self.readI32()
|
||||
return (name, type, seqid)
|
||||
|
||||
def readMessageEnd(self):
|
||||
pass
|
||||
|
||||
def readStructBegin(self):
|
||||
pass
|
||||
|
||||
def readStructEnd(self):
|
||||
pass
|
||||
|
||||
def readFieldBegin(self):
|
||||
type = self.readByte()
|
||||
if type == TType.STOP:
|
||||
return (None, type, 0)
|
||||
id = self.readI16()
|
||||
return (None, type, id)
|
||||
|
||||
def readFieldEnd(self):
|
||||
pass
|
||||
|
||||
def readMapBegin(self):
|
||||
ktype = self.readByte()
|
||||
vtype = self.readByte()
|
||||
size = self.readI32()
|
||||
return (ktype, vtype, size)
|
||||
|
||||
def readMapEnd(self):
|
||||
pass
|
||||
|
||||
def readListBegin(self):
|
||||
etype = self.readByte()
|
||||
size = self.readI32()
|
||||
return (etype, size)
|
||||
|
||||
def readListEnd(self):
|
||||
pass
|
||||
|
||||
def readSetBegin(self):
|
||||
etype = self.readByte()
|
||||
size = self.readI32()
|
||||
return (etype, size)
|
||||
|
||||
def readSetEnd(self):
|
||||
pass
|
||||
|
||||
def readBool(self):
|
||||
byte = self.readByte()
|
||||
if byte == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def readByte(self):
|
||||
buff = self.trans.readAll(1)
|
||||
val, = unpack('!b', buff)
|
||||
return val
|
||||
|
||||
def readI16(self):
|
||||
buff = self.trans.readAll(2)
|
||||
val, = unpack('!h', buff)
|
||||
return val
|
||||
|
||||
def readI32(self):
|
||||
buff = self.trans.readAll(4)
|
||||
val, = unpack('!i', buff)
|
||||
return val
|
||||
|
||||
def readI64(self):
|
||||
buff = self.trans.readAll(8)
|
||||
val, = unpack('!q', buff)
|
||||
return val
|
||||
|
||||
def readDouble(self):
|
||||
buff = self.trans.readAll(8)
|
||||
val, = unpack('!d', buff)
|
||||
return val
|
||||
|
||||
def readString(self):
|
||||
len = self.readI32()
|
||||
str = self.trans.readAll(len)
|
||||
return str
|
||||
|
||||
|
||||
class TBinaryProtocolFactory:
|
||||
def __init__(self, strictRead=False, strictWrite=True):
|
||||
self.strictRead = strictRead
|
||||
self.strictWrite = strictWrite
|
||||
|
||||
def getProtocol(self, trans):
|
||||
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
|
||||
return prot
|
||||
|
||||
|
||||
class TBinaryProtocolAccelerated(TBinaryProtocol):
|
||||
"""C-Accelerated version of TBinaryProtocol.
|
||||
|
||||
This class does not override any of TBinaryProtocol's methods,
|
||||
but the generated code recognizes it directly and will call into
|
||||
our C module to do the encoding, bypassing this object entirely.
|
||||
We inherit from TBinaryProtocol so that the normal TBinaryProtocol
|
||||
encoding can happen if the fastbinary module doesn't work for some
|
||||
reason. (TODO(dreiss): Make this happen sanely in more cases.)
|
||||
|
||||
In order to take advantage of the C module, just use
|
||||
TBinaryProtocolAccelerated instead of TBinaryProtocol.
|
||||
|
||||
NOTE: This code was contributed by an external developer.
|
||||
The internal Thrift team has reviewed and tested it,
|
||||
but we cannot guarantee that it is production-ready.
|
||||
Please feel free to report bugs and/or success stories
|
||||
to the public mailing list.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TBinaryProtocolAcceleratedFactory:
|
||||
def getProtocol(self, trans):
|
||||
return TBinaryProtocolAccelerated(trans)
|
403
thrift/protocol/TCompactProtocol.py
Normal file
403
thrift/protocol/TCompactProtocol.py
Normal file
|
@ -0,0 +1,403 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from TProtocol import *
|
||||
from struct import pack, unpack
|
||||
|
||||
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
|
||||
|
||||
CLEAR = 0
|
||||
FIELD_WRITE = 1
|
||||
VALUE_WRITE = 2
|
||||
CONTAINER_WRITE = 3
|
||||
BOOL_WRITE = 4
|
||||
FIELD_READ = 5
|
||||
CONTAINER_READ = 6
|
||||
VALUE_READ = 7
|
||||
BOOL_READ = 8
|
||||
|
||||
|
||||
def make_helper(v_from, container):
|
||||
def helper(func):
|
||||
def nested(self, *args, **kwargs):
|
||||
assert self.state in (v_from, container), (self.state, v_from, container)
|
||||
return func(self, *args, **kwargs)
|
||||
return nested
|
||||
return helper
|
||||
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
|
||||
reader = make_helper(VALUE_READ, CONTAINER_READ)
|
||||
|
||||
|
||||
def makeZigZag(n, bits):
|
||||
return (n << 1) ^ (n >> (bits - 1))
|
||||
|
||||
|
||||
def fromZigZag(n):
|
||||
return (n >> 1) ^ -(n & 1)
|
||||
|
||||
|
||||
def writeVarint(trans, n):
|
||||
out = []
|
||||
while True:
|
||||
if n & ~0x7f == 0:
|
||||
out.append(n)
|
||||
break
|
||||
else:
|
||||
out.append((n & 0xff) | 0x80)
|
||||
n = n >> 7
|
||||
trans.write(''.join(map(chr, out)))
|
||||
|
||||
|
||||
def readVarint(trans):
|
||||
result = 0
|
||||
shift = 0
|
||||
while True:
|
||||
x = trans.readAll(1)
|
||||
byte = ord(x)
|
||||
result |= (byte & 0x7f) << shift
|
||||
if byte >> 7 == 0:
|
||||
return result
|
||||
shift += 7
|
||||
|
||||
|
||||
class CompactType:
|
||||
STOP = 0x00
|
||||
TRUE = 0x01
|
||||
FALSE = 0x02
|
||||
BYTE = 0x03
|
||||
I16 = 0x04
|
||||
I32 = 0x05
|
||||
I64 = 0x06
|
||||
DOUBLE = 0x07
|
||||
BINARY = 0x08
|
||||
LIST = 0x09
|
||||
SET = 0x0A
|
||||
MAP = 0x0B
|
||||
STRUCT = 0x0C
|
||||
|
||||
CTYPES = {TType.STOP: CompactType.STOP,
|
||||
TType.BOOL: CompactType.TRUE, # used for collection
|
||||
TType.BYTE: CompactType.BYTE,
|
||||
TType.I16: CompactType.I16,
|
||||
TType.I32: CompactType.I32,
|
||||
TType.I64: CompactType.I64,
|
||||
TType.DOUBLE: CompactType.DOUBLE,
|
||||
TType.STRING: CompactType.BINARY,
|
||||
TType.STRUCT: CompactType.STRUCT,
|
||||
TType.LIST: CompactType.LIST,
|
||||
TType.SET: CompactType.SET,
|
||||
TType.MAP: CompactType.MAP
|
||||
}
|
||||
|
||||
TTYPES = {}
|
||||
for k, v in CTYPES.items():
|
||||
TTYPES[v] = k
|
||||
TTYPES[CompactType.FALSE] = TType.BOOL
|
||||
del k
|
||||
del v
|
||||
|
||||
|
||||
class TCompactProtocol(TProtocolBase):
|
||||
"""Compact implementation of the Thrift protocol driver."""
|
||||
|
||||
PROTOCOL_ID = 0x82
|
||||
VERSION = 1
|
||||
VERSION_MASK = 0x1f
|
||||
TYPE_MASK = 0xe0
|
||||
TYPE_SHIFT_AMOUNT = 5
|
||||
|
||||
def __init__(self, trans):
|
||||
TProtocolBase.__init__(self, trans)
|
||||
self.state = CLEAR
|
||||
self.__last_fid = 0
|
||||
self.__bool_fid = None
|
||||
self.__bool_value = None
|
||||
self.__structs = []
|
||||
self.__containers = []
|
||||
|
||||
def __writeVarint(self, n):
|
||||
writeVarint(self.trans, n)
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
assert self.state == CLEAR
|
||||
self.__writeUByte(self.PROTOCOL_ID)
|
||||
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
|
||||
self.__writeVarint(seqid)
|
||||
self.__writeString(name)
|
||||
self.state = VALUE_WRITE
|
||||
|
||||
def writeMessageEnd(self):
|
||||
assert self.state == VALUE_WRITE
|
||||
self.state = CLEAR
|
||||
|
||||
def writeStructBegin(self, name):
|
||||
assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
|
||||
self.__structs.append((self.state, self.__last_fid))
|
||||
self.state = FIELD_WRITE
|
||||
self.__last_fid = 0
|
||||
|
||||
def writeStructEnd(self):
|
||||
assert self.state == FIELD_WRITE
|
||||
self.state, self.__last_fid = self.__structs.pop()
|
||||
|
||||
def writeFieldStop(self):
|
||||
self.__writeByte(0)
|
||||
|
||||
def __writeFieldHeader(self, type, fid):
|
||||
delta = fid - self.__last_fid
|
||||
if 0 < delta <= 15:
|
||||
self.__writeUByte(delta << 4 | type)
|
||||
else:
|
||||
self.__writeByte(type)
|
||||
self.__writeI16(fid)
|
||||
self.__last_fid = fid
|
||||
|
||||
def writeFieldBegin(self, name, type, fid):
|
||||
assert self.state == FIELD_WRITE, self.state
|
||||
if type == TType.BOOL:
|
||||
self.state = BOOL_WRITE
|
||||
self.__bool_fid = fid
|
||||
else:
|
||||
self.state = VALUE_WRITE
|
||||
self.__writeFieldHeader(CTYPES[type], fid)
|
||||
|
||||
def writeFieldEnd(self):
|
||||
assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
|
||||
self.state = FIELD_WRITE
|
||||
|
||||
def __writeUByte(self, byte):
|
||||
self.trans.write(pack('!B', byte))
|
||||
|
||||
def __writeByte(self, byte):
|
||||
self.trans.write(pack('!b', byte))
|
||||
|
||||
def __writeI16(self, i16):
|
||||
self.__writeVarint(makeZigZag(i16, 16))
|
||||
|
||||
def __writeSize(self, i32):
|
||||
self.__writeVarint(i32)
|
||||
|
||||
def writeCollectionBegin(self, etype, size):
|
||||
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
|
||||
if size <= 14:
|
||||
self.__writeUByte(size << 4 | CTYPES[etype])
|
||||
else:
|
||||
self.__writeUByte(0xf0 | CTYPES[etype])
|
||||
self.__writeSize(size)
|
||||
self.__containers.append(self.state)
|
||||
self.state = CONTAINER_WRITE
|
||||
writeSetBegin = writeCollectionBegin
|
||||
writeListBegin = writeCollectionBegin
|
||||
|
||||
def writeMapBegin(self, ktype, vtype, size):
|
||||
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
|
||||
if size == 0:
|
||||
self.__writeByte(0)
|
||||
else:
|
||||
self.__writeSize(size)
|
||||
self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
|
||||
self.__containers.append(self.state)
|
||||
self.state = CONTAINER_WRITE
|
||||
|
||||
def writeCollectionEnd(self):
|
||||
assert self.state == CONTAINER_WRITE, self.state
|
||||
self.state = self.__containers.pop()
|
||||
writeMapEnd = writeCollectionEnd
|
||||
writeSetEnd = writeCollectionEnd
|
||||
writeListEnd = writeCollectionEnd
|
||||
|
||||
def writeBool(self, bool):
|
||||
if self.state == BOOL_WRITE:
|
||||
if bool:
|
||||
ctype = CompactType.TRUE
|
||||
else:
|
||||
ctype = CompactType.FALSE
|
||||
self.__writeFieldHeader(ctype, self.__bool_fid)
|
||||
elif self.state == CONTAINER_WRITE:
|
||||
if bool:
|
||||
self.__writeByte(CompactType.TRUE)
|
||||
else:
|
||||
self.__writeByte(CompactType.FALSE)
|
||||
else:
|
||||
raise AssertionError("Invalid state in compact protocol")
|
||||
|
||||
writeByte = writer(__writeByte)
|
||||
writeI16 = writer(__writeI16)
|
||||
|
||||
@writer
|
||||
def writeI32(self, i32):
|
||||
self.__writeVarint(makeZigZag(i32, 32))
|
||||
|
||||
@writer
|
||||
def writeI64(self, i64):
|
||||
self.__writeVarint(makeZigZag(i64, 64))
|
||||
|
||||
@writer
|
||||
def writeDouble(self, dub):
|
||||
self.trans.write(pack('!d', dub))
|
||||
|
||||
def __writeString(self, s):
|
||||
self.__writeSize(len(s))
|
||||
self.trans.write(s)
|
||||
writeString = writer(__writeString)
|
||||
|
||||
def readFieldBegin(self):
|
||||
assert self.state == FIELD_READ, self.state
|
||||
type = self.__readUByte()
|
||||
if type & 0x0f == TType.STOP:
|
||||
return (None, 0, 0)
|
||||
delta = type >> 4
|
||||
if delta == 0:
|
||||
fid = self.__readI16()
|
||||
else:
|
||||
fid = self.__last_fid + delta
|
||||
self.__last_fid = fid
|
||||
type = type & 0x0f
|
||||
if type == CompactType.TRUE:
|
||||
self.state = BOOL_READ
|
||||
self.__bool_value = True
|
||||
elif type == CompactType.FALSE:
|
||||
self.state = BOOL_READ
|
||||
self.__bool_value = False
|
||||
else:
|
||||
self.state = VALUE_READ
|
||||
return (None, self.__getTType(type), fid)
|
||||
|
||||
def readFieldEnd(self):
|
||||
assert self.state in (VALUE_READ, BOOL_READ), self.state
|
||||
self.state = FIELD_READ
|
||||
|
||||
def __readUByte(self):
|
||||
result, = unpack('!B', self.trans.readAll(1))
|
||||
return result
|
||||
|
||||
def __readByte(self):
|
||||
result, = unpack('!b', self.trans.readAll(1))
|
||||
return result
|
||||
|
||||
def __readVarint(self):
|
||||
return readVarint(self.trans)
|
||||
|
||||
def __readZigZag(self):
|
||||
return fromZigZag(self.__readVarint())
|
||||
|
||||
def __readSize(self):
|
||||
result = self.__readVarint()
|
||||
if result < 0:
|
||||
raise TException("Length < 0")
|
||||
return result
|
||||
|
||||
def readMessageBegin(self):
|
||||
assert self.state == CLEAR
|
||||
proto_id = self.__readUByte()
|
||||
if proto_id != self.PROTOCOL_ID:
|
||||
raise TProtocolException(TProtocolException.BAD_VERSION,
|
||||
'Bad protocol id in the message: %d' % proto_id)
|
||||
ver_type = self.__readUByte()
|
||||
type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
|
||||
version = ver_type & self.VERSION_MASK
|
||||
if version != self.VERSION:
|
||||
raise TProtocolException(TProtocolException.BAD_VERSION,
|
||||
'Bad version: %d (expect %d)' % (version, self.VERSION))
|
||||
seqid = self.__readVarint()
|
||||
name = self.__readString()
|
||||
return (name, type, seqid)
|
||||
|
||||
def readMessageEnd(self):
|
||||
assert self.state == CLEAR
|
||||
assert len(self.__structs) == 0
|
||||
|
||||
def readStructBegin(self):
|
||||
assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
|
||||
self.__structs.append((self.state, self.__last_fid))
|
||||
self.state = FIELD_READ
|
||||
self.__last_fid = 0
|
||||
|
||||
def readStructEnd(self):
|
||||
assert self.state == FIELD_READ
|
||||
self.state, self.__last_fid = self.__structs.pop()
|
||||
|
||||
def readCollectionBegin(self):
|
||||
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
|
||||
size_type = self.__readUByte()
|
||||
size = size_type >> 4
|
||||
type = self.__getTType(size_type)
|
||||
if size == 15:
|
||||
size = self.__readSize()
|
||||
self.__containers.append(self.state)
|
||||
self.state = CONTAINER_READ
|
||||
return type, size
|
||||
readSetBegin = readCollectionBegin
|
||||
readListBegin = readCollectionBegin
|
||||
|
||||
def readMapBegin(self):
|
||||
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
|
||||
size = self.__readSize()
|
||||
types = 0
|
||||
if size > 0:
|
||||
types = self.__readUByte()
|
||||
vtype = self.__getTType(types)
|
||||
ktype = self.__getTType(types >> 4)
|
||||
self.__containers.append(self.state)
|
||||
self.state = CONTAINER_READ
|
||||
return (ktype, vtype, size)
|
||||
|
||||
def readCollectionEnd(self):
|
||||
assert self.state == CONTAINER_READ, self.state
|
||||
self.state = self.__containers.pop()
|
||||
readSetEnd = readCollectionEnd
|
||||
readListEnd = readCollectionEnd
|
||||
readMapEnd = readCollectionEnd
|
||||
|
||||
def readBool(self):
|
||||
if self.state == BOOL_READ:
|
||||
return self.__bool_value == CompactType.TRUE
|
||||
elif self.state == CONTAINER_READ:
|
||||
return self.__readByte() == CompactType.TRUE
|
||||
else:
|
||||
raise AssertionError("Invalid state in compact protocol: %d" %
|
||||
self.state)
|
||||
|
||||
readByte = reader(__readByte)
|
||||
__readI16 = __readZigZag
|
||||
readI16 = reader(__readZigZag)
|
||||
readI32 = reader(__readZigZag)
|
||||
readI64 = reader(__readZigZag)
|
||||
|
||||
@reader
|
||||
def readDouble(self):
|
||||
buff = self.trans.readAll(8)
|
||||
val, = unpack('!d', buff)
|
||||
return val
|
||||
|
||||
def __readString(self):
|
||||
len = self.__readSize()
|
||||
return self.trans.readAll(len)
|
||||
readString = reader(__readString)
|
||||
|
||||
def __getTType(self, byte):
|
||||
return TTYPES[byte & 0x0f]
|
||||
|
||||
|
||||
class TCompactProtocolFactory:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def getProtocol(self, trans):
|
||||
return TCompactProtocol(trans)
|
406
thrift/protocol/TProtocol.py
Normal file
406
thrift/protocol/TProtocol.py
Normal file
|
@ -0,0 +1,406 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from thrift.Thrift import *
|
||||
|
||||
|
||||
class TProtocolException(TException):
|
||||
"""Custom Protocol Exception class"""
|
||||
|
||||
UNKNOWN = 0
|
||||
INVALID_DATA = 1
|
||||
NEGATIVE_SIZE = 2
|
||||
SIZE_LIMIT = 3
|
||||
BAD_VERSION = 4
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
self.type = type
|
||||
|
||||
|
||||
class TProtocolBase:
|
||||
"""Base class for Thrift protocol driver."""
|
||||
|
||||
def __init__(self, trans):
|
||||
self.trans = trans
|
||||
|
||||
def writeMessageBegin(self, name, type, seqid):
|
||||
pass
|
||||
|
||||
def writeMessageEnd(self):
|
||||
pass
|
||||
|
||||
def writeStructBegin(self, name):
|
||||
pass
|
||||
|
||||
def writeStructEnd(self):
|
||||
pass
|
||||
|
||||
def writeFieldBegin(self, name, type, id):
|
||||
pass
|
||||
|
||||
def writeFieldEnd(self):
|
||||
pass
|
||||
|
||||
def writeFieldStop(self):
|
||||
pass
|
||||
|
||||
def writeMapBegin(self, ktype, vtype, size):
|
||||
pass
|
||||
|
||||
def writeMapEnd(self):
|
||||
pass
|
||||
|
||||
def writeListBegin(self, etype, size):
|
||||
pass
|
||||
|
||||
def writeListEnd(self):
|
||||
pass
|
||||
|
||||
def writeSetBegin(self, etype, size):
|
||||
pass
|
||||
|
||||
def writeSetEnd(self):
|
||||
pass
|
||||
|
||||
def writeBool(self, bool):
|
||||
pass
|
||||
|
||||
def writeByte(self, byte):
|
||||
pass
|
||||
|
||||
def writeI16(self, i16):
|
||||
pass
|
||||
|
||||
def writeI32(self, i32):
|
||||
pass
|
||||
|
||||
def writeI64(self, i64):
|
||||
pass
|
||||
|
||||
def writeDouble(self, dub):
|
||||
pass
|
||||
|
||||
def writeString(self, str):
|
||||
pass
|
||||
|
||||
def readMessageBegin(self):
|
||||
pass
|
||||
|
||||
def readMessageEnd(self):
|
||||
pass
|
||||
|
||||
def readStructBegin(self):
|
||||
pass
|
||||
|
||||
def readStructEnd(self):
|
||||
pass
|
||||
|
||||
def readFieldBegin(self):
|
||||
pass
|
||||
|
||||
def readFieldEnd(self):
|
||||
pass
|
||||
|
||||
def readMapBegin(self):
|
||||
pass
|
||||
|
||||
def readMapEnd(self):
|
||||
pass
|
||||
|
||||
def readListBegin(self):
|
||||
pass
|
||||
|
||||
def readListEnd(self):
|
||||
pass
|
||||
|
||||
def readSetBegin(self):
|
||||
pass
|
||||
|
||||
def readSetEnd(self):
|
||||
pass
|
||||
|
||||
def readBool(self):
|
||||
pass
|
||||
|
||||
def readByte(self):
|
||||
pass
|
||||
|
||||
def readI16(self):
|
||||
pass
|
||||
|
||||
def readI32(self):
|
||||
pass
|
||||
|
||||
def readI64(self):
|
||||
pass
|
||||
|
||||
def readDouble(self):
|
||||
pass
|
||||
|
||||
def readString(self):
|
||||
pass
|
||||
|
||||
def skip(self, type):
|
||||
if type == TType.STOP:
|
||||
return
|
||||
elif type == TType.BOOL:
|
||||
self.readBool()
|
||||
elif type == TType.BYTE:
|
||||
self.readByte()
|
||||
elif type == TType.I16:
|
||||
self.readI16()
|
||||
elif type == TType.I32:
|
||||
self.readI32()
|
||||
elif type == TType.I64:
|
||||
self.readI64()
|
||||
elif type == TType.DOUBLE:
|
||||
self.readDouble()
|
||||
elif type == TType.STRING:
|
||||
self.readString()
|
||||
elif type == TType.STRUCT:
|
||||
name = self.readStructBegin()
|
||||
while True:
|
||||
(name, type, id) = self.readFieldBegin()
|
||||
if type == TType.STOP:
|
||||
break
|
||||
self.skip(type)
|
||||
self.readFieldEnd()
|
||||
self.readStructEnd()
|
||||
elif type == TType.MAP:
|
||||
(ktype, vtype, size) = self.readMapBegin()
|
||||
for i in range(size):
|
||||
self.skip(ktype)
|
||||
self.skip(vtype)
|
||||
self.readMapEnd()
|
||||
elif type == TType.SET:
|
||||
(etype, size) = self.readSetBegin()
|
||||
for i in range(size):
|
||||
self.skip(etype)
|
||||
self.readSetEnd()
|
||||
elif type == TType.LIST:
|
||||
(etype, size) = self.readListBegin()
|
||||
for i in range(size):
|
||||
self.skip(etype)
|
||||
self.readListEnd()
|
||||
|
||||
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
|
||||
_TTYPE_HANDLERS = (
|
||||
(None, None, False), # 0 TType.STOP
|
||||
(None, None, False), # 1 TType.VOID # TODO: handle void?
|
||||
('readBool', 'writeBool', False), # 2 TType.BOOL
|
||||
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
|
||||
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
|
||||
(None, None, False), # 5 undefined
|
||||
('readI16', 'writeI16', False), # 6 TType.I16
|
||||
(None, None, False), # 7 undefined
|
||||
('readI32', 'writeI32', False), # 8 TType.I32
|
||||
(None, None, False), # 9 undefined
|
||||
('readI64', 'writeI64', False), # 10 TType.I64
|
||||
('readString', 'writeString', False), # 11 TType.STRING and UTF7
|
||||
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
|
||||
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
|
||||
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
|
||||
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
|
||||
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
|
||||
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
|
||||
)
|
||||
|
||||
def readFieldByTType(self, ttype, spec):
|
||||
try:
|
||||
(r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
|
||||
except IndexError:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid field type %d' % (ttype))
|
||||
if r_handler is None:
|
||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||
message='Invalid field type %d' % (ttype))
|
||||
reader = getattr(self, r_handler)
|
||||
if not is_container:
|
||||
return reader()
|
||||
return reader(spec)
|
||||
|
||||
def readContainerList(self, spec):
|
||||
results = []
|
||||
ttype, tspec = spec[0], spec[1]
|
||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
||||
reader = getattr(self, r_handler)
|
||||
(list_type, list_len) = self.readListBegin()
|
||||
if tspec is None:
|
||||
# list values are simple types
|
||||
for idx in xrange(list_len):
|
||||
results.append(reader())
|
||||
else:
|
||||
# this is like an inlined readFieldByTType
|
||||
container_reader = self._TTYPE_HANDLERS[list_type][0]
|
||||
val_reader = getattr(self, container_reader)
|
||||
for idx in xrange(list_len):
|
||||
val = val_reader(tspec)
|
||||
results.append(val)
|
||||
self.readListEnd()
|
||||
return results
|
||||
|
||||
def readContainerSet(self, spec):
|
||||
results = set()
|
||||
ttype, tspec = spec[0], spec[1]
|
||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
||||
reader = getattr(self, r_handler)
|
||||
(set_type, set_len) = self.readSetBegin()
|
||||
if tspec is None:
|
||||
# set members are simple types
|
||||
for idx in xrange(set_len):
|
||||
results.add(reader())
|
||||
else:
|
||||
container_reader = self._TTYPE_HANDLERS[set_type][0]
|
||||
val_reader = getattr(self, container_reader)
|
||||
for idx in xrange(set_len):
|
||||
results.add(val_reader(tspec))
|
||||
self.readSetEnd()
|
||||
return results
|
||||
|
||||
def readContainerStruct(self, spec):
|
||||
(obj_class, obj_spec) = spec
|
||||
obj = obj_class()
|
||||
obj.read(self)
|
||||
return obj
|
||||
|
||||
def readContainerMap(self, spec):
|
||||
results = dict()
|
||||
key_ttype, key_spec = spec[0], spec[1]
|
||||
val_ttype, val_spec = spec[2], spec[3]
|
||||
(map_ktype, map_vtype, map_len) = self.readMapBegin()
|
||||
# TODO: compare types we just decoded with thrift_spec and
|
||||
# abort/skip if types disagree
|
||||
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
|
||||
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
|
||||
# list values are simple types
|
||||
for idx in xrange(map_len):
|
||||
if key_spec is None:
|
||||
k_val = key_reader()
|
||||
else:
|
||||
k_val = self.readFieldByTType(key_ttype, key_spec)
|
||||
if val_spec is None:
|
||||
v_val = val_reader()
|
||||
else:
|
||||
v_val = self.readFieldByTType(val_ttype, val_spec)
|
||||
# this raises a TypeError with unhashable keys types
|
||||
# i.e. this fails: d=dict(); d[[0,1]] = 2
|
||||
results[k_val] = v_val
|
||||
self.readMapEnd()
|
||||
return results
|
||||
|
||||
def readStruct(self, obj, thrift_spec):
|
||||
self.readStructBegin()
|
||||
while True:
|
||||
(fname, ftype, fid) = self.readFieldBegin()
|
||||
if ftype == TType.STOP:
|
||||
break
|
||||
try:
|
||||
field = thrift_spec[fid]
|
||||
except IndexError:
|
||||
self.skip(ftype)
|
||||
else:
|
||||
if field is not None and ftype == field[1]:
|
||||
fname = field[2]
|
||||
fspec = field[3]
|
||||
val = self.readFieldByTType(ftype, fspec)
|
||||
setattr(obj, fname, val)
|
||||
else:
|
||||
self.skip(ftype)
|
||||
self.readFieldEnd()
|
||||
self.readStructEnd()
|
||||
|
||||
def writeContainerStruct(self, val, spec):
|
||||
val.write(self)
|
||||
|
||||
def writeContainerList(self, val, spec):
|
||||
self.writeListBegin(spec[0], len(val))
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
||||
e_writer = getattr(self, w_handler)
|
||||
if not is_container:
|
||||
for elem in val:
|
||||
e_writer(elem)
|
||||
else:
|
||||
for elem in val:
|
||||
e_writer(elem, spec[1])
|
||||
self.writeListEnd()
|
||||
|
||||
def writeContainerSet(self, val, spec):
|
||||
self.writeSetBegin(spec[0], len(val))
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
||||
e_writer = getattr(self, w_handler)
|
||||
if not is_container:
|
||||
for elem in val:
|
||||
e_writer(elem)
|
||||
else:
|
||||
for elem in val:
|
||||
e_writer(elem, spec[1])
|
||||
self.writeSetEnd()
|
||||
|
||||
def writeContainerMap(self, val, spec):
|
||||
k_type = spec[0]
|
||||
v_type = spec[2]
|
||||
ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
|
||||
ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
|
||||
k_writer = getattr(self, ktype_name)
|
||||
v_writer = getattr(self, vtype_name)
|
||||
self.writeMapBegin(k_type, v_type, len(val))
|
||||
for m_key, m_val in val.iteritems():
|
||||
if not k_is_container:
|
||||
k_writer(m_key)
|
||||
else:
|
||||
k_writer(m_key, spec[1])
|
||||
if not v_is_container:
|
||||
v_writer(m_val)
|
||||
else:
|
||||
v_writer(m_val, spec[3])
|
||||
self.writeMapEnd()
|
||||
|
||||
def writeStruct(self, obj, thrift_spec):
|
||||
self.writeStructBegin(obj.__class__.__name__)
|
||||
for field in thrift_spec:
|
||||
if field is None:
|
||||
continue
|
||||
fname = field[2]
|
||||
val = getattr(obj, fname)
|
||||
if val is None:
|
||||
# skip writing out unset fields
|
||||
continue
|
||||
fid = field[0]
|
||||
ftype = field[1]
|
||||
fspec = field[3]
|
||||
# get the writer method for this value
|
||||
self.writeFieldBegin(fname, ftype, fid)
|
||||
self.writeFieldByTType(ftype, val, fspec)
|
||||
self.writeFieldEnd()
|
||||
self.writeFieldStop()
|
||||
self.writeStructEnd()
|
||||
|
||||
def writeFieldByTType(self, ttype, val, spec):
|
||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
|
||||
writer = getattr(self, w_handler)
|
||||
if is_container:
|
||||
writer(val, spec)
|
||||
else:
|
||||
writer(val)
|
||||
|
||||
|
||||
class TProtocolFactory:
|
||||
def getProtocol(self, trans):
|
||||
pass
|
20
thrift/protocol/__init__.py
Normal file
20
thrift/protocol/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
|
1219
thrift/protocol/fastbinary.c
Normal file
1219
thrift/protocol/fastbinary.c
Normal file
File diff suppressed because it is too large
Load diff
87
thrift/server/THttpServer.py
Normal file
87
thrift/server/THttpServer.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import BaseHTTPServer
|
||||
|
||||
from thrift.server import TServer
|
||||
from thrift.transport import TTransport
|
||||
|
||||
|
||||
class ResponseException(Exception):
|
||||
"""Allows handlers to override the HTTP response
|
||||
|
||||
Normally, THttpServer always sends a 200 response. If a handler wants
|
||||
to override this behavior (e.g., to simulate a misconfigured or
|
||||
overloaded web server during testing), it can raise a ResponseException.
|
||||
The function passed to the constructor will be called with the
|
||||
RequestHandler as its only argument.
|
||||
"""
|
||||
def __init__(self, handler):
|
||||
self.handler = handler
|
||||
|
||||
|
||||
class THttpServer(TServer.TServer):
|
||||
"""A simple HTTP-based Thrift server
|
||||
|
||||
This class is not very performant, but it is useful (for example) for
|
||||
acting as a mock version of an Apache-based PHP Thrift endpoint.
|
||||
"""
|
||||
def __init__(self,
|
||||
processor,
|
||||
server_address,
|
||||
inputProtocolFactory,
|
||||
outputProtocolFactory=None,
|
||||
server_class=BaseHTTPServer.HTTPServer):
|
||||
"""Set up protocol factories and HTTP server.
|
||||
|
||||
See BaseHTTPServer for server_address.
|
||||
See TServer for protocol factories.
|
||||
"""
|
||||
if outputProtocolFactory is None:
|
||||
outputProtocolFactory = inputProtocolFactory
|
||||
|
||||
TServer.TServer.__init__(self, processor, None, None, None,
|
||||
inputProtocolFactory, outputProtocolFactory)
|
||||
|
||||
thttpserver = self
|
||||
|
||||
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
# Don't care about the request path.
|
||||
itrans = TTransport.TFileObjectTransport(self.rfile)
|
||||
otrans = TTransport.TFileObjectTransport(self.wfile)
|
||||
itrans = TTransport.TBufferedTransport(
|
||||
itrans, int(self.headers['Content-Length']))
|
||||
otrans = TTransport.TMemoryBuffer()
|
||||
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
|
||||
try:
|
||||
thttpserver.processor.process(iprot, oprot)
|
||||
except ResponseException as exn:
|
||||
exn.handler(self)
|
||||
else:
|
||||
self.send_response(200)
|
||||
self.send_header("content-type", "application/x-thrift")
|
||||
self.end_headers()
|
||||
self.wfile.write(otrans.getvalue())
|
||||
|
||||
self.httpd = server_class(server_address, RequestHander)
|
||||
|
||||
def serve(self):
|
||||
self.httpd.serve_forever()
|
346
thrift/server/TNonblockingServer.py
Normal file
346
thrift/server/TNonblockingServer.py
Normal file
|
@ -0,0 +1,346 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
"""Implementation of non-blocking server.
|
||||
|
||||
The main idea of the server is to receive and send requests
|
||||
only from the main thread.
|
||||
|
||||
The thread poool should be sized for concurrent tasks, not
|
||||
maximum connections
|
||||
"""
|
||||
import threading
|
||||
import socket
|
||||
import Queue
|
||||
import select
|
||||
import struct
|
||||
import logging
|
||||
|
||||
from thrift.transport import TTransport
|
||||
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
||||
|
||||
__all__ = ['TNonblockingServer']
|
||||
|
||||
|
||||
class Worker(threading.Thread):
|
||||
"""Worker is a small helper to process incoming connection."""
|
||||
|
||||
def __init__(self, queue):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue
|
||||
|
||||
def run(self):
|
||||
"""Process queries from task queue, stop if processor is None."""
|
||||
while True:
|
||||
try:
|
||||
processor, iprot, oprot, otrans, callback = self.queue.get()
|
||||
if processor is None:
|
||||
break
|
||||
processor.process(iprot, oprot)
|
||||
callback(True, otrans.getvalue())
|
||||
except Exception:
|
||||
logging.exception("Exception while processing request")
|
||||
callback(False, '')
|
||||
|
||||
WAIT_LEN = 0
|
||||
WAIT_MESSAGE = 1
|
||||
WAIT_PROCESS = 2
|
||||
SEND_ANSWER = 3
|
||||
CLOSED = 4
|
||||
|
||||
|
||||
def locked(func):
|
||||
"""Decorator which locks self.lock."""
|
||||
def nested(self, *args, **kwargs):
|
||||
self.lock.acquire()
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
self.lock.release()
|
||||
return nested
|
||||
|
||||
|
||||
def socket_exception(func):
|
||||
"""Decorator close object on socket.error."""
|
||||
def read(self, *args, **kwargs):
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
except socket.error:
|
||||
self.close()
|
||||
return read
|
||||
|
||||
|
||||
class Connection:
|
||||
"""Basic class is represented connection.
|
||||
|
||||
It can be in state:
|
||||
WAIT_LEN --- connection is reading request len.
|
||||
WAIT_MESSAGE --- connection is reading request.
|
||||
WAIT_PROCESS --- connection has just read whole request and
|
||||
waits for call ready routine.
|
||||
SEND_ANSWER --- connection is sending answer string (including length
|
||||
of answer).
|
||||
CLOSED --- socket was closed and connection should be deleted.
|
||||
"""
|
||||
def __init__(self, new_socket, wake_up):
|
||||
self.socket = new_socket
|
||||
self.socket.setblocking(False)
|
||||
self.status = WAIT_LEN
|
||||
self.len = 0
|
||||
self.message = ''
|
||||
self.lock = threading.Lock()
|
||||
self.wake_up = wake_up
|
||||
|
||||
def _read_len(self):
|
||||
"""Reads length of request.
|
||||
|
||||
It's a safer alternative to self.socket.recv(4)
|
||||
"""
|
||||
read = self.socket.recv(4 - len(self.message))
|
||||
if len(read) == 0:
|
||||
# if we read 0 bytes and self.message is empty, then
|
||||
# the client closed the connection
|
||||
if len(self.message) != 0:
|
||||
logging.error("can't read frame size from socket")
|
||||
self.close()
|
||||
return
|
||||
self.message += read
|
||||
if len(self.message) == 4:
|
||||
self.len, = struct.unpack('!i', self.message)
|
||||
if self.len < 0:
|
||||
logging.error("negative frame size, it seems client "
|
||||
"doesn't use FramedTransport")
|
||||
self.close()
|
||||
elif self.len == 0:
|
||||
logging.error("empty frame, it's really strange")
|
||||
self.close()
|
||||
else:
|
||||
self.message = ''
|
||||
self.status = WAIT_MESSAGE
|
||||
|
||||
@socket_exception
|
||||
def read(self):
|
||||
"""Reads data from stream and switch state."""
|
||||
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
|
||||
if self.status == WAIT_LEN:
|
||||
self._read_len()
|
||||
# go back to the main loop here for simplicity instead of
|
||||
# falling through, even though there is a good chance that
|
||||
# the message is already available
|
||||
elif self.status == WAIT_MESSAGE:
|
||||
read = self.socket.recv(self.len - len(self.message))
|
||||
if len(read) == 0:
|
||||
logging.error("can't read frame from socket (get %d of "
|
||||
"%d bytes)" % (len(self.message), self.len))
|
||||
self.close()
|
||||
return
|
||||
self.message += read
|
||||
if len(self.message) == self.len:
|
||||
self.status = WAIT_PROCESS
|
||||
|
||||
@socket_exception
|
||||
def write(self):
|
||||
"""Writes data from socket and switch state."""
|
||||
assert self.status == SEND_ANSWER
|
||||
sent = self.socket.send(self.message)
|
||||
if sent == len(self.message):
|
||||
self.status = WAIT_LEN
|
||||
self.message = ''
|
||||
self.len = 0
|
||||
else:
|
||||
self.message = self.message[sent:]
|
||||
|
||||
@locked
|
||||
def ready(self, all_ok, message):
|
||||
"""Callback function for switching state and waking up main thread.
|
||||
|
||||
This function is the only function witch can be called asynchronous.
|
||||
|
||||
The ready can switch Connection to three states:
|
||||
WAIT_LEN if request was oneway.
|
||||
SEND_ANSWER if request was processed in normal way.
|
||||
CLOSED if request throws unexpected exception.
|
||||
|
||||
The one wakes up main thread.
|
||||
"""
|
||||
assert self.status == WAIT_PROCESS
|
||||
if not all_ok:
|
||||
self.close()
|
||||
self.wake_up()
|
||||
return
|
||||
self.len = ''
|
||||
if len(message) == 0:
|
||||
# it was a oneway request, do not write answer
|
||||
self.message = ''
|
||||
self.status = WAIT_LEN
|
||||
else:
|
||||
self.message = struct.pack('!i', len(message)) + message
|
||||
self.status = SEND_ANSWER
|
||||
self.wake_up()
|
||||
|
||||
@locked
|
||||
def is_writeable(self):
|
||||
"""Return True if connection should be added to write list of select"""
|
||||
return self.status == SEND_ANSWER
|
||||
|
||||
# it's not necessary, but...
|
||||
@locked
|
||||
def is_readable(self):
|
||||
"""Return True if connection should be added to read list of select"""
|
||||
return self.status in (WAIT_LEN, WAIT_MESSAGE)
|
||||
|
||||
@locked
|
||||
def is_closed(self):
|
||||
"""Returns True if connection is closed."""
|
||||
return self.status == CLOSED
|
||||
|
||||
def fileno(self):
|
||||
"""Returns the file descriptor of the associated socket."""
|
||||
return self.socket.fileno()
|
||||
|
||||
def close(self):
|
||||
"""Closes connection"""
|
||||
self.status = CLOSED
|
||||
self.socket.close()
|
||||
|
||||
|
||||
class TNonblockingServer:
|
||||
"""Non-blocking server."""
|
||||
|
||||
def __init__(self,
|
||||
processor,
|
||||
lsocket,
|
||||
inputProtocolFactory=None,
|
||||
outputProtocolFactory=None,
|
||||
threads=10):
|
||||
self.processor = processor
|
||||
self.socket = lsocket
|
||||
self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
|
||||
self.out_protocol = outputProtocolFactory or self.in_protocol
|
||||
self.threads = int(threads)
|
||||
self.clients = {}
|
||||
self.tasks = Queue.Queue()
|
||||
self._read, self._write = socket.socketpair()
|
||||
self.prepared = False
|
||||
self._stop = False
|
||||
|
||||
def setNumThreads(self, num):
|
||||
"""Set the number of worker threads that should be created."""
|
||||
# implement ThreadPool interface
|
||||
assert not self.prepared, "Can't change number of threads after start"
|
||||
self.threads = num
|
||||
|
||||
def prepare(self):
|
||||
"""Prepares server for serve requests."""
|
||||
if self.prepared:
|
||||
return
|
||||
self.socket.listen()
|
||||
for _ in xrange(self.threads):
|
||||
thread = Worker(self.tasks)
|
||||
thread.setDaemon(True)
|
||||
thread.start()
|
||||
self.prepared = True
|
||||
|
||||
def wake_up(self):
|
||||
"""Wake up main thread.
|
||||
|
||||
The server usualy waits in select call in we should terminate one.
|
||||
The simplest way is using socketpair.
|
||||
|
||||
Select always wait to read from the first socket of socketpair.
|
||||
|
||||
In this case, we can just write anything to the second socket from
|
||||
socketpair.
|
||||
"""
|
||||
self._write.send('1')
|
||||
|
||||
def stop(self):
|
||||
"""Stop the server.
|
||||
|
||||
This method causes the serve() method to return. stop() may be invoked
|
||||
from within your handler, or from another thread.
|
||||
|
||||
After stop() is called, serve() will return but the server will still
|
||||
be listening on the socket. serve() may then be called again to resume
|
||||
processing requests. Alternatively, close() may be called after
|
||||
serve() returns to close the server socket and shutdown all worker
|
||||
threads.
|
||||
"""
|
||||
self._stop = True
|
||||
self.wake_up()
|
||||
|
||||
def _select(self):
|
||||
"""Does select on open connections."""
|
||||
readable = [self.socket.handle.fileno(), self._read.fileno()]
|
||||
writable = []
|
||||
for i, connection in self.clients.items():
|
||||
if connection.is_readable():
|
||||
readable.append(connection.fileno())
|
||||
if connection.is_writeable():
|
||||
writable.append(connection.fileno())
|
||||
if connection.is_closed():
|
||||
del self.clients[i]
|
||||
return select.select(readable, writable, readable)
|
||||
|
||||
def handle(self):
|
||||
"""Handle requests.
|
||||
|
||||
WARNING! You must call prepare() BEFORE calling handle()
|
||||
"""
|
||||
assert self.prepared, "You have to call prepare before handle"
|
||||
rset, wset, xset = self._select()
|
||||
for readable in rset:
|
||||
if readable == self._read.fileno():
|
||||
# don't care i just need to clean readable flag
|
||||
self._read.recv(1024)
|
||||
elif readable == self.socket.handle.fileno():
|
||||
client = self.socket.accept().handle
|
||||
self.clients[client.fileno()] = Connection(client,
|
||||
self.wake_up)
|
||||
else:
|
||||
connection = self.clients[readable]
|
||||
connection.read()
|
||||
if connection.status == WAIT_PROCESS:
|
||||
itransport = TTransport.TMemoryBuffer(connection.message)
|
||||
otransport = TTransport.TMemoryBuffer()
|
||||
iprot = self.in_protocol.getProtocol(itransport)
|
||||
oprot = self.out_protocol.getProtocol(otransport)
|
||||
self.tasks.put([self.processor, iprot, oprot,
|
||||
otransport, connection.ready])
|
||||
for writeable in wset:
|
||||
self.clients[writeable].write()
|
||||
for oob in xset:
|
||||
self.clients[oob].close()
|
||||
del self.clients[oob]
|
||||
|
||||
def close(self):
|
||||
"""Closes the server."""
|
||||
for _ in xrange(self.threads):
|
||||
self.tasks.put([None, None, None, None, None])
|
||||
self.socket.close()
|
||||
self.prepared = False
|
||||
|
||||
def serve(self):
|
||||
"""Serve requests.
|
||||
|
||||
Serve requests forever, or until stop() is called.
|
||||
"""
|
||||
self._stop = False
|
||||
self.prepare()
|
||||
while not self._stop:
|
||||
self.handle()
|
118
thrift/server/TProcessPoolServer.py
Normal file
118
thrift/server/TProcessPoolServer.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
|
||||
import logging
|
||||
from multiprocessing import Process, Value, Condition, reduction
|
||||
|
||||
from TServer import TServer
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
|
||||
class TProcessPoolServer(TServer):
|
||||
"""Server with a fixed size pool of worker subprocesses to service requests
|
||||
|
||||
Note that if you need shared state between the handlers - it's up to you!
|
||||
Written by Dvir Volk, doat.com
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
TServer.__init__(self, *args)
|
||||
self.numWorkers = 10
|
||||
self.workers = []
|
||||
self.isRunning = Value('b', False)
|
||||
self.stopCondition = Condition()
|
||||
self.postForkCallback = None
|
||||
|
||||
def setPostForkCallback(self, callback):
|
||||
if not callable(callback):
|
||||
raise TypeError("This is not a callback!")
|
||||
self.postForkCallback = callback
|
||||
|
||||
def setNumWorkers(self, num):
|
||||
"""Set the number of worker threads that should be created"""
|
||||
self.numWorkers = num
|
||||
|
||||
def workerProcess(self):
|
||||
"""Loop getting clients from the shared queue and process them"""
|
||||
if self.postForkCallback:
|
||||
self.postForkCallback()
|
||||
|
||||
while self.isRunning.value:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
self.serveClient(client)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
return 0
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
def serveClient(self, client):
|
||||
"""Process input/output from a client for as long as possible"""
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||
|
||||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransportException, tx:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
||||
def serve(self):
|
||||
"""Start workers and put into queue"""
|
||||
# this is a shared state that can tell the workers to exit when False
|
||||
self.isRunning.value = True
|
||||
|
||||
# first bind and listen to the port
|
||||
self.serverTransport.listen()
|
||||
|
||||
# fork the children
|
||||
for i in range(self.numWorkers):
|
||||
try:
|
||||
w = Process(target=self.workerProcess)
|
||||
w.daemon = True
|
||||
w.start()
|
||||
self.workers.append(w)
|
||||
except Exception, x:
|
||||
logging.exception(x)
|
||||
|
||||
# wait until the condition is set by stop()
|
||||
while True:
|
||||
self.stopCondition.acquire()
|
||||
try:
|
||||
self.stopCondition.wait()
|
||||
break
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
break
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
self.isRunning.value = False
|
||||
|
||||
def stop(self):
|
||||
self.isRunning.value = False
|
||||
self.stopCondition.acquire()
|
||||
self.stopCondition.notify()
|
||||
self.stopCondition.release()
|
269
thrift/server/TServer.py
Normal file
269
thrift/server/TServer.py
Normal file
|
@ -0,0 +1,269 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import Queue
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
from thrift.Thrift import TProcessor
|
||||
from thrift.protocol import TBinaryProtocol
|
||||
from thrift.transport import TTransport
|
||||
|
||||
|
||||
class TServer:
|
||||
"""Base interface for a server, which must have a serve() method.
|
||||
|
||||
Three constructors for all servers:
|
||||
1) (processor, serverTransport)
|
||||
2) (processor, serverTransport, transportFactory, protocolFactory)
|
||||
3) (processor, serverTransport,
|
||||
inputTransportFactory, outputTransportFactory,
|
||||
inputProtocolFactory, outputProtocolFactory)
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
if (len(args) == 2):
|
||||
self.__initArgs__(args[0], args[1],
|
||||
TTransport.TTransportFactoryBase(),
|
||||
TTransport.TTransportFactoryBase(),
|
||||
TBinaryProtocol.TBinaryProtocolFactory(),
|
||||
TBinaryProtocol.TBinaryProtocolFactory())
|
||||
elif (len(args) == 4):
|
||||
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
|
||||
elif (len(args) == 6):
|
||||
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
|
||||
|
||||
def __initArgs__(self, processor, serverTransport,
|
||||
inputTransportFactory, outputTransportFactory,
|
||||
inputProtocolFactory, outputProtocolFactory):
|
||||
self.processor = processor
|
||||
self.serverTransport = serverTransport
|
||||
self.inputTransportFactory = inputTransportFactory
|
||||
self.outputTransportFactory = outputTransportFactory
|
||||
self.inputProtocolFactory = inputProtocolFactory
|
||||
self.outputProtocolFactory = outputProtocolFactory
|
||||
|
||||
def serve(self):
|
||||
pass
|
||||
|
||||
|
||||
class TSimpleServer(TServer):
|
||||
"""Simple single-threaded server that just pumps around one transport."""
|
||||
|
||||
def __init__(self, *args):
|
||||
TServer.__init__(self, *args)
|
||||
|
||||
def serve(self):
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
client = self.serverTransport.accept()
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
||||
|
||||
class TThreadedServer(TServer):
|
||||
"""Threaded server that spawns a new thread per each connection."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
TServer.__init__(self, *args)
|
||||
self.daemon = kwargs.get("daemon", False)
|
||||
|
||||
def serve(self):
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
t = threading.Thread(target=self.handle, args=(client,))
|
||||
t.setDaemon(self.daemon)
|
||||
t.start()
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
def handle(self, client):
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
||||
|
||||
class TThreadPoolServer(TServer):
|
||||
"""Server with a fixed size pool of threads which service requests."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
TServer.__init__(self, *args)
|
||||
self.clients = Queue.Queue()
|
||||
self.threads = 10
|
||||
self.daemon = kwargs.get("daemon", False)
|
||||
|
||||
def setNumThreads(self, num):
|
||||
"""Set the number of worker threads that should be created"""
|
||||
self.threads = num
|
||||
|
||||
def serveThread(self):
|
||||
"""Loop around getting clients from the shared queue and process them."""
|
||||
while True:
|
||||
try:
|
||||
client = self.clients.get()
|
||||
self.serveClient(client)
|
||||
except Exception, x:
|
||||
logging.exception(x)
|
||||
|
||||
def serveClient(self, client):
|
||||
"""Process input/output from a client for as long as possible"""
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
itrans.close()
|
||||
otrans.close()
|
||||
|
||||
def serve(self):
|
||||
"""Start a fixed number of worker threads and put client into a queue"""
|
||||
for i in range(self.threads):
|
||||
try:
|
||||
t = threading.Thread(target=self.serveThread)
|
||||
t.setDaemon(self.daemon)
|
||||
t.start()
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
# Pump the socket for clients
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
try:
|
||||
client = self.serverTransport.accept()
|
||||
self.clients.put(client)
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
|
||||
class TForkingServer(TServer):
|
||||
"""A Thrift server that forks a new process for each request
|
||||
|
||||
This is more scalable than the threaded server as it does not cause
|
||||
GIL contention.
|
||||
|
||||
Note that this has different semantics from the threading server.
|
||||
Specifically, updates to shared variables will no longer be shared.
|
||||
It will also not work on windows.
|
||||
|
||||
This code is heavily inspired by SocketServer.ForkingMixIn in the
|
||||
Python stdlib.
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
TServer.__init__(self, *args)
|
||||
self.children = []
|
||||
|
||||
def serve(self):
|
||||
def try_close(file):
|
||||
try:
|
||||
file.close()
|
||||
except IOError as e:
|
||||
logging.warning(e, exc_info=True)
|
||||
|
||||
self.serverTransport.listen()
|
||||
while True:
|
||||
client = self.serverTransport.accept()
|
||||
try:
|
||||
pid = os.fork()
|
||||
|
||||
if pid: # parent
|
||||
# add before collect, otherwise you race w/ waitpid
|
||||
self.children.append(pid)
|
||||
self.collect_children()
|
||||
|
||||
# Parent must close socket or the connection may not get
|
||||
# closed promptly
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
try_close(itrans)
|
||||
try_close(otrans)
|
||||
else:
|
||||
itrans = self.inputTransportFactory.getTransport(client)
|
||||
otrans = self.outputTransportFactory.getTransport(client)
|
||||
|
||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||
|
||||
ecode = 0
|
||||
try:
|
||||
try:
|
||||
while True:
|
||||
self.processor.process(iprot, oprot)
|
||||
except TTransport.TTransportException, tx:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
ecode = 1
|
||||
finally:
|
||||
try_close(itrans)
|
||||
try_close(otrans)
|
||||
|
||||
os._exit(ecode)
|
||||
|
||||
except TTransport.TTransportException, tx:
|
||||
pass
|
||||
except Exception as x:
|
||||
logging.exception(x)
|
||||
|
||||
def collect_children(self):
|
||||
while self.children:
|
||||
try:
|
||||
pid, status = os.waitpid(0, os.WNOHANG)
|
||||
except os.error:
|
||||
pid = None
|
||||
|
||||
if pid:
|
||||
self.children.remove(pid)
|
||||
else:
|
||||
break
|
20
thrift/server/__init__.py
Normal file
20
thrift/server/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
__all__ = ['TServer', 'TNonblockingServer']
|
149
thrift/transport/THttpClient.py
Normal file
149
thrift/transport/THttpClient.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import httplib
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import urllib
|
||||
import urlparse
|
||||
import warnings
|
||||
|
||||
from cStringIO import StringIO
|
||||
|
||||
from TTransport import *
|
||||
|
||||
|
||||
class THttpClient(TTransportBase):
|
||||
"""Http implementation of TTransport base."""
|
||||
|
||||
def __init__(self, uri_or_host, port=None, path=None):
|
||||
"""THttpClient supports two different types constructor parameters.
|
||||
|
||||
THttpClient(host, port, path) - deprecated
|
||||
THttpClient(uri)
|
||||
|
||||
Only the second supports https.
|
||||
"""
|
||||
if port is not None:
|
||||
warnings.warn(
|
||||
"Please use the THttpClient('http://host:port/path') syntax",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
self.host = uri_or_host
|
||||
self.port = port
|
||||
assert path
|
||||
self.path = path
|
||||
self.scheme = 'http'
|
||||
else:
|
||||
parsed = urlparse.urlparse(uri_or_host)
|
||||
self.scheme = parsed.scheme
|
||||
assert self.scheme in ('http', 'https')
|
||||
if self.scheme == 'http':
|
||||
self.port = parsed.port or httplib.HTTP_PORT
|
||||
elif self.scheme == 'https':
|
||||
self.port = parsed.port or httplib.HTTPS_PORT
|
||||
self.host = parsed.hostname
|
||||
self.path = parsed.path
|
||||
if parsed.query:
|
||||
self.path += '?%s' % parsed.query
|
||||
self.__wbuf = StringIO()
|
||||
self.__http = None
|
||||
self.__timeout = None
|
||||
self.__custom_headers = None
|
||||
|
||||
def open(self):
|
||||
if self.scheme == 'http':
|
||||
self.__http = httplib.HTTP(self.host, self.port)
|
||||
else:
|
||||
self.__http = httplib.HTTPS(self.host, self.port)
|
||||
|
||||
def close(self):
|
||||
self.__http.close()
|
||||
self.__http = None
|
||||
|
||||
def isOpen(self):
|
||||
return self.__http is not None
|
||||
|
||||
def setTimeout(self, ms):
|
||||
if not hasattr(socket, 'getdefaulttimeout'):
|
||||
raise NotImplementedError
|
||||
|
||||
if ms is None:
|
||||
self.__timeout = None
|
||||
else:
|
||||
self.__timeout = ms / 1000.0
|
||||
|
||||
def setCustomHeaders(self, headers):
|
||||
self.__custom_headers = headers
|
||||
|
||||
def read(self, sz):
|
||||
return self.__http.file.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def __withTimeout(f):
|
||||
def _f(*args, **kwargs):
|
||||
orig_timeout = socket.getdefaulttimeout()
|
||||
socket.setdefaulttimeout(args[0].__timeout)
|
||||
result = f(*args, **kwargs)
|
||||
socket.setdefaulttimeout(orig_timeout)
|
||||
return result
|
||||
return _f
|
||||
|
||||
def flush(self):
|
||||
if self.isOpen():
|
||||
self.close()
|
||||
self.open()
|
||||
|
||||
# Pull data out of buffer
|
||||
data = self.__wbuf.getvalue()
|
||||
self.__wbuf = StringIO()
|
||||
|
||||
# HTTP request
|
||||
self.__http.putrequest('POST', self.path)
|
||||
|
||||
# Write headers
|
||||
self.__http.putheader('Host', self.host)
|
||||
self.__http.putheader('Content-Type', 'application/x-thrift')
|
||||
self.__http.putheader('Content-Length', str(len(data)))
|
||||
|
||||
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
|
||||
user_agent = 'Python/THttpClient'
|
||||
script = os.path.basename(sys.argv[0])
|
||||
if script:
|
||||
user_agent = '%s (%s)' % (user_agent, urllib.quote(script))
|
||||
self.__http.putheader('User-Agent', user_agent)
|
||||
|
||||
if self.__custom_headers:
|
||||
for key, val in self.__custom_headers.iteritems():
|
||||
self.__http.putheader(key, val)
|
||||
|
||||
self.__http.endheaders()
|
||||
|
||||
# Write payload
|
||||
self.__http.send(data)
|
||||
|
||||
# Get reply to flush the request
|
||||
self.code, self.message, self.headers = self.__http.getreply()
|
||||
|
||||
# Decorate if we know how to timeout
|
||||
if hasattr(socket, 'getdefaulttimeout'):
|
||||
flush = __withTimeout(flush)
|
202
thrift/transport/TSSLSocket.py
Normal file
202
thrift/transport/TSSLSocket.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
from thrift.transport import TSocket
|
||||
from thrift.transport.TTransport import TTransportException
|
||||
|
||||
|
||||
class TSSLSocket(TSocket.TSocket):
|
||||
"""
|
||||
SSL implementation of client-side TSocket
|
||||
|
||||
This class creates outbound sockets wrapped using the
|
||||
python standard ssl module for encrypted connections.
|
||||
|
||||
The protocol used is set using the class variable
|
||||
SSL_VERSION, which must be one of ssl.PROTOCOL_* and
|
||||
defaults to ssl.PROTOCOL_TLSv1 for greatest security.
|
||||
"""
|
||||
SSL_VERSION = ssl.PROTOCOL_TLSv1
|
||||
|
||||
def __init__(self,
|
||||
host='localhost',
|
||||
port=9090,
|
||||
validate=True,
|
||||
ca_certs=None,
|
||||
unix_socket=None):
|
||||
"""Create SSL TSocket
|
||||
|
||||
@param validate: Set to False to disable SSL certificate validation
|
||||
@type validate: bool
|
||||
@param ca_certs: Filename to the Certificate Authority pem file, possibly a
|
||||
file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to
|
||||
the ssl_wrap function as the 'ca_certs' parameter.
|
||||
@type ca_certs: str
|
||||
|
||||
Raises an IOError exception if validate is True and the ca_certs file is
|
||||
None, not present or unreadable.
|
||||
"""
|
||||
self.validate = validate
|
||||
self.is_valid = False
|
||||
self.peercert = None
|
||||
if not validate:
|
||||
self.cert_reqs = ssl.CERT_NONE
|
||||
else:
|
||||
self.cert_reqs = ssl.CERT_REQUIRED
|
||||
self.ca_certs = ca_certs
|
||||
if validate:
|
||||
if ca_certs is None or not os.access(ca_certs, os.R_OK):
|
||||
raise IOError('Certificate Authority ca_certs file "%s" '
|
||||
'is not readable, cannot validate SSL '
|
||||
'certificates.' % (ca_certs))
|
||||
TSocket.TSocket.__init__(self, host, port, unix_socket)
|
||||
|
||||
def open(self):
|
||||
try:
|
||||
res0 = self._resolveAddr()
|
||||
for res in res0:
|
||||
sock_family, sock_type = res[0:2]
|
||||
ip_port = res[4]
|
||||
plain_sock = socket.socket(sock_family, sock_type)
|
||||
self.handle = ssl.wrap_socket(plain_sock,
|
||||
ssl_version=self.SSL_VERSION,
|
||||
do_handshake_on_connect=True,
|
||||
ca_certs=self.ca_certs,
|
||||
cert_reqs=self.cert_reqs)
|
||||
self.handle.settimeout(self._timeout)
|
||||
try:
|
||||
self.handle.connect(ip_port)
|
||||
except socket.error as e:
|
||||
if res is not res0[-1]:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
break
|
||||
except socket.error as e:
|
||||
if self._unix_socket:
|
||||
message = 'Could not connect to secure socket %s' % self._unix_socket
|
||||
else:
|
||||
message = 'Could not connect to %s:%d' % (self.host, self.port)
|
||||
raise TTransportException(type=TTransportException.NOT_OPEN,
|
||||
message=message)
|
||||
if self.validate:
|
||||
self._validate_cert()
|
||||
|
||||
def _validate_cert(self):
|
||||
"""internal method to validate the peer's SSL certificate, and to check the
|
||||
commonName of the certificate to ensure it matches the hostname we
|
||||
used to make this connection. Does not support subjectAltName records
|
||||
in certificates.
|
||||
|
||||
raises TTransportException if the certificate fails validation.
|
||||
"""
|
||||
cert = self.handle.getpeercert()
|
||||
self.peercert = cert
|
||||
if 'subject' not in cert:
|
||||
raise TTransportException(
|
||||
type=TTransportException.NOT_OPEN,
|
||||
message='No SSL certificate found from %s:%s' % (self.host, self.port))
|
||||
fields = cert['subject']
|
||||
for field in fields:
|
||||
# ensure structure we get back is what we expect
|
||||
if not isinstance(field, tuple):
|
||||
continue
|
||||
cert_pair = field[0]
|
||||
if len(cert_pair) < 2:
|
||||
continue
|
||||
cert_key, cert_value = cert_pair[0:2]
|
||||
if cert_key != 'commonName':
|
||||
continue
|
||||
certhost = cert_value
|
||||
if certhost == self.host:
|
||||
# success, cert commonName matches desired hostname
|
||||
self.is_valid = True
|
||||
return
|
||||
else:
|
||||
raise TTransportException(
|
||||
type=TTransportException.UNKNOWN,
|
||||
message='Hostname we connected to "%s" doesn\'t match certificate '
|
||||
'provided commonName "%s"' % (self.host, certhost))
|
||||
raise TTransportException(
|
||||
type=TTransportException.UNKNOWN,
|
||||
message='Could not validate SSL certificate from '
|
||||
'host "%s". Cert=%s' % (self.host, cert))
|
||||
|
||||
|
||||
class TSSLServerSocket(TSocket.TServerSocket):
|
||||
"""SSL implementation of TServerSocket
|
||||
|
||||
This uses the ssl module's wrap_socket() method to provide SSL
|
||||
negotiated encryption.
|
||||
"""
|
||||
SSL_VERSION = ssl.PROTOCOL_TLSv1
|
||||
|
||||
def __init__(self,
|
||||
host=None,
|
||||
port=9090,
|
||||
certfile='cert.pem',
|
||||
unix_socket=None):
|
||||
"""Initialize a TSSLServerSocket
|
||||
|
||||
@param certfile: filename of the server certificate, defaults to cert.pem
|
||||
@type certfile: str
|
||||
@param host: The hostname or IP to bind the listen socket to,
|
||||
i.e. 'localhost' for only allowing local network connections.
|
||||
Pass None to bind to all interfaces.
|
||||
@type host: str
|
||||
@param port: The port to listen on for inbound connections.
|
||||
@type port: int
|
||||
"""
|
||||
self.setCertfile(certfile)
|
||||
TSocket.TServerSocket.__init__(self, host, port)
|
||||
|
||||
def setCertfile(self, certfile):
|
||||
"""Set or change the server certificate file used to wrap new connections.
|
||||
|
||||
@param certfile: The filename of the server certificate,
|
||||
i.e. '/etc/certs/server.pem'
|
||||
@type certfile: str
|
||||
|
||||
Raises an IOError exception if the certfile is not present or unreadable.
|
||||
"""
|
||||
if not os.access(certfile, os.R_OK):
|
||||
raise IOError('No such certfile found: %s' % (certfile))
|
||||
self.certfile = certfile
|
||||
|
||||
def accept(self):
|
||||
plain_client, addr = self.handle.accept()
|
||||
try:
|
||||
client = ssl.wrap_socket(plain_client, certfile=self.certfile,
|
||||
server_side=True, ssl_version=self.SSL_VERSION)
|
||||
except ssl.SSLError as ssl_exc:
|
||||
# failed handshake/ssl wrap, close socket to client
|
||||
plain_client.close()
|
||||
# raise ssl_exc
|
||||
# We can't raise the exception, because it kills most TServer derived
|
||||
# serve() methods.
|
||||
# Instead, return None, and let the TServer instance deal with it in
|
||||
# other exception handling. (but TSimpleServer dies anyway)
|
||||
return None
|
||||
result = TSocket.TSocket()
|
||||
result.setHandle(client)
|
||||
return result
|
176
thrift/transport/TSocket.py
Normal file
176
thrift/transport/TSocket.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import errno
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from TTransport import *
|
||||
|
||||
|
||||
class TSocketBase(TTransportBase):
|
||||
def _resolveAddr(self):
|
||||
if self._unix_socket is not None:
|
||||
return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
|
||||
self._unix_socket)]
|
||||
else:
|
||||
return socket.getaddrinfo(self.host,
|
||||
self.port,
|
||||
socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM,
|
||||
0,
|
||||
socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
|
||||
|
||||
def close(self):
|
||||
if self.handle:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
|
||||
class TSocket(TSocketBase):
|
||||
"""Socket implementation of TTransport base."""
|
||||
|
||||
def __init__(self, host='localhost', port=9090, unix_socket=None):
|
||||
"""Initialize a TSocket
|
||||
|
||||
@param host(str) The host to connect to.
|
||||
@param port(int) The (TCP) port to connect to.
|
||||
@param unix_socket(str) The filename of a unix socket to connect to.
|
||||
(host and port will be ignored.)
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.handle = None
|
||||
self._unix_socket = unix_socket
|
||||
self._timeout = None
|
||||
|
||||
def setHandle(self, h):
|
||||
self.handle = h
|
||||
|
||||
def isOpen(self):
|
||||
return self.handle is not None
|
||||
|
||||
def setTimeout(self, ms):
|
||||
if ms is None:
|
||||
self._timeout = None
|
||||
else:
|
||||
self._timeout = ms / 1000.0
|
||||
|
||||
if self.handle is not None:
|
||||
self.handle.settimeout(self._timeout)
|
||||
|
||||
def open(self):
|
||||
try:
|
||||
res0 = self._resolveAddr()
|
||||
for res in res0:
|
||||
self.handle = socket.socket(res[0], res[1])
|
||||
self.handle.settimeout(self._timeout)
|
||||
try:
|
||||
self.handle.connect(res[4])
|
||||
except socket.error, e:
|
||||
if res is not res0[-1]:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
break
|
||||
except socket.error, e:
|
||||
if self._unix_socket:
|
||||
message = 'Could not connect to socket %s' % self._unix_socket
|
||||
else:
|
||||
message = 'Could not connect to %s:%d' % (self.host, self.port)
|
||||
raise TTransportException(type=TTransportException.NOT_OPEN,
|
||||
message=message)
|
||||
|
||||
def read(self, sz):
|
||||
try:
|
||||
buff = self.handle.recv(sz)
|
||||
except socket.error, e:
|
||||
if (e.args[0] == errno.ECONNRESET and
|
||||
(sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
|
||||
# freebsd and Mach don't follow POSIX semantic of recv
|
||||
# and fail with ECONNRESET if peer performed shutdown.
|
||||
# See corresponding comment and code in TSocket::read()
|
||||
# in lib/cpp/src/transport/TSocket.cpp.
|
||||
self.close()
|
||||
# Trigger the check to raise the END_OF_FILE exception below.
|
||||
buff = ''
|
||||
else:
|
||||
raise
|
||||
if len(buff) == 0:
|
||||
raise TTransportException(type=TTransportException.END_OF_FILE,
|
||||
message='TSocket read 0 bytes')
|
||||
return buff
|
||||
|
||||
def write(self, buff):
|
||||
if not self.handle:
|
||||
raise TTransportException(type=TTransportException.NOT_OPEN,
|
||||
message='Transport not open')
|
||||
sent = 0
|
||||
have = len(buff)
|
||||
while sent < have:
|
||||
plus = self.handle.send(buff)
|
||||
if plus == 0:
|
||||
raise TTransportException(type=TTransportException.END_OF_FILE,
|
||||
message='TSocket sent 0 bytes')
|
||||
sent += plus
|
||||
buff = buff[plus:]
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
class TServerSocket(TSocketBase, TServerTransportBase):
|
||||
"""Socket implementation of TServerTransport base."""
|
||||
|
||||
def __init__(self, host=None, port=9090, unix_socket=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self._unix_socket = unix_socket
|
||||
self.handle = None
|
||||
|
||||
def listen(self):
|
||||
res0 = self._resolveAddr()
|
||||
for res in res0:
|
||||
if res[0] is socket.AF_INET6 or res is res0[-1]:
|
||||
break
|
||||
|
||||
# We need remove the old unix socket if the file exists and
|
||||
# nobody is listening on it.
|
||||
if self._unix_socket:
|
||||
tmp = socket.socket(res[0], res[1])
|
||||
try:
|
||||
tmp.connect(res[4])
|
||||
except socket.error, err:
|
||||
eno, message = err.args
|
||||
if eno == errno.ECONNREFUSED:
|
||||
os.unlink(res[4])
|
||||
|
||||
self.handle = socket.socket(res[0], res[1])
|
||||
self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
if hasattr(self.handle, 'settimeout'):
|
||||
self.handle.settimeout(None)
|
||||
self.handle.bind(res[4])
|
||||
self.handle.listen(128)
|
||||
|
||||
def accept(self):
|
||||
client, addr = self.handle.accept()
|
||||
result = TSocket()
|
||||
result.setHandle(client)
|
||||
return result
|
330
thrift/transport/TTransport.py
Normal file
330
thrift/transport/TTransport.py
Normal file
|
@ -0,0 +1,330 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from cStringIO import StringIO
|
||||
from struct import pack, unpack
|
||||
from thrift.Thrift import TException
|
||||
|
||||
|
||||
class TTransportException(TException):
|
||||
"""Custom Transport Exception class"""
|
||||
|
||||
UNKNOWN = 0
|
||||
NOT_OPEN = 1
|
||||
ALREADY_OPEN = 2
|
||||
TIMED_OUT = 3
|
||||
END_OF_FILE = 4
|
||||
|
||||
def __init__(self, type=UNKNOWN, message=None):
|
||||
TException.__init__(self, message)
|
||||
self.type = type
|
||||
|
||||
|
||||
class TTransportBase:
|
||||
"""Base class for Thrift transport layer."""
|
||||
|
||||
def isOpen(self):
|
||||
pass
|
||||
|
||||
def open(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def read(self, sz):
|
||||
pass
|
||||
|
||||
def readAll(self, sz):
|
||||
buff = ''
|
||||
have = 0
|
||||
while (have < sz):
|
||||
chunk = self.read(sz - have)
|
||||
have += len(chunk)
|
||||
buff += chunk
|
||||
|
||||
if len(chunk) == 0:
|
||||
raise EOFError()
|
||||
|
||||
return buff
|
||||
|
||||
def write(self, buf):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
# This class should be thought of as an interface.
|
||||
class CReadableTransport:
|
||||
"""base class for transports that are readable from C"""
|
||||
|
||||
# TODO(dreiss): Think about changing this interface to allow us to use
|
||||
# a (Python, not c) StringIO instead, because it allows
|
||||
# you to write after reading.
|
||||
|
||||
# NOTE: This is a classic class, so properties will NOT work
|
||||
# correctly for setting.
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
"""A cStringIO buffer that contains the current chunk we are reading."""
|
||||
pass
|
||||
|
||||
def cstringio_refill(self, partialread, reqlen):
|
||||
"""Refills cstringio_buf.
|
||||
|
||||
Returns the currently used buffer (which can but need not be the same as
|
||||
the old cstringio_buf). partialread is what the C code has read from the
|
||||
buffer, and should be inserted into the buffer before any more reads. The
|
||||
return value must be a new, not borrowed reference. Something along the
|
||||
lines of self._buf should be fine.
|
||||
|
||||
If reqlen bytes can't be read, throw EOFError.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TServerTransportBase:
|
||||
"""Base class for Thrift server transports."""
|
||||
|
||||
def listen(self):
|
||||
pass
|
||||
|
||||
def accept(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class TTransportFactoryBase:
|
||||
"""Base class for a Transport Factory"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
return trans
|
||||
|
||||
|
||||
class TBufferedTransportFactory:
|
||||
"""Factory transport that builds buffered transports"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
buffered = TBufferedTransport(trans)
|
||||
return buffered
|
||||
|
||||
|
||||
class TBufferedTransport(TTransportBase, CReadableTransport):
|
||||
"""Class that wraps another transport and buffers its I/O.
|
||||
|
||||
The implementation uses a (configurable) fixed-size read buffer
|
||||
but buffers all writes until a flush is performed.
|
||||
"""
|
||||
DEFAULT_BUFFER = 4096
|
||||
|
||||
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
|
||||
self.__trans = trans
|
||||
self.__wbuf = StringIO()
|
||||
self.__rbuf = StringIO("")
|
||||
self.__rbuf_size = rbuf_size
|
||||
|
||||
def isOpen(self):
|
||||
return self.__trans.isOpen()
|
||||
|
||||
def open(self):
|
||||
return self.__trans.open()
|
||||
|
||||
def close(self):
|
||||
return self.__trans.close()
|
||||
|
||||
def read(self, sz):
|
||||
ret = self.__rbuf.read(sz)
|
||||
if len(ret) != 0:
|
||||
return ret
|
||||
|
||||
self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size)))
|
||||
return self.__rbuf.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
out = self.__wbuf.getvalue()
|
||||
# reset wbuf before write/flush to preserve state on underlying failure
|
||||
self.__wbuf = StringIO()
|
||||
self.__trans.write(out)
|
||||
self.__trans.flush()
|
||||
|
||||
# Implement the CReadableTransport interface.
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
return self.__rbuf
|
||||
|
||||
def cstringio_refill(self, partialread, reqlen):
|
||||
retstring = partialread
|
||||
if reqlen < self.__rbuf_size:
|
||||
# try to make a read of as much as we can.
|
||||
retstring += self.__trans.read(self.__rbuf_size)
|
||||
|
||||
# but make sure we do read reqlen bytes.
|
||||
if len(retstring) < reqlen:
|
||||
retstring += self.__trans.readAll(reqlen - len(retstring))
|
||||
|
||||
self.__rbuf = StringIO(retstring)
|
||||
return self.__rbuf
|
||||
|
||||
|
||||
class TMemoryBuffer(TTransportBase, CReadableTransport):
|
||||
"""Wraps a cStringIO object as a TTransport.
|
||||
|
||||
NOTE: Unlike the C++ version of this class, you cannot write to it
|
||||
then immediately read from it. If you want to read from a
|
||||
TMemoryBuffer, you must either pass a string to the constructor.
|
||||
TODO(dreiss): Make this work like the C++ version.
|
||||
"""
|
||||
|
||||
def __init__(self, value=None):
|
||||
"""value -- a value to read from for stringio
|
||||
|
||||
If value is set, this will be a transport for reading,
|
||||
otherwise, it is for writing"""
|
||||
if value is not None:
|
||||
self._buffer = StringIO(value)
|
||||
else:
|
||||
self._buffer = StringIO()
|
||||
|
||||
def isOpen(self):
|
||||
return not self._buffer.closed
|
||||
|
||||
def open(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self._buffer.close()
|
||||
|
||||
def read(self, sz):
|
||||
return self._buffer.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
self._buffer.write(buf)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def getvalue(self):
|
||||
return self._buffer.getvalue()
|
||||
|
||||
# Implement the CReadableTransport interface.
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
return self._buffer
|
||||
|
||||
def cstringio_refill(self, partialread, reqlen):
|
||||
# only one shot at reading...
|
||||
raise EOFError()
|
||||
|
||||
|
||||
class TFramedTransportFactory:
|
||||
"""Factory transport that builds framed transports"""
|
||||
|
||||
def getTransport(self, trans):
|
||||
framed = TFramedTransport(trans)
|
||||
return framed
|
||||
|
||||
|
||||
class TFramedTransport(TTransportBase, CReadableTransport):
|
||||
"""Class that wraps another transport and frames its I/O when writing."""
|
||||
|
||||
def __init__(self, trans,):
|
||||
self.__trans = trans
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
|
||||
def isOpen(self):
|
||||
return self.__trans.isOpen()
|
||||
|
||||
def open(self):
|
||||
return self.__trans.open()
|
||||
|
||||
def close(self):
|
||||
return self.__trans.close()
|
||||
|
||||
def read(self, sz):
|
||||
ret = self.__rbuf.read(sz)
|
||||
if len(ret) != 0:
|
||||
return ret
|
||||
|
||||
self.readFrame()
|
||||
return self.__rbuf.read(sz)
|
||||
|
||||
def readFrame(self):
|
||||
buff = self.__trans.readAll(4)
|
||||
sz, = unpack('!i', buff)
|
||||
self.__rbuf = StringIO(self.__trans.readAll(sz))
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
wout = self.__wbuf.getvalue()
|
||||
wsz = len(wout)
|
||||
# reset wbuf before write/flush to preserve state on underlying failure
|
||||
self.__wbuf = StringIO()
|
||||
# N.B.: Doing this string concatenation is WAY cheaper than making
|
||||
# two separate calls to the underlying socket object. Socket writes in
|
||||
# Python turn out to be REALLY expensive, but it seems to do a pretty
|
||||
# good job of managing string buffer operations without excessive copies
|
||||
buf = pack("!i", wsz) + wout
|
||||
self.__trans.write(buf)
|
||||
self.__trans.flush()
|
||||
|
||||
# Implement the CReadableTransport interface.
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
return self.__rbuf
|
||||
|
||||
def cstringio_refill(self, prefix, reqlen):
|
||||
# self.__rbuf will already be empty here because fastbinary doesn't
|
||||
# ask for a refill until the previous buffer is empty. Therefore,
|
||||
# we can start reading new frames immediately.
|
||||
while len(prefix) < reqlen:
|
||||
self.readFrame()
|
||||
prefix += self.__rbuf.getvalue()
|
||||
self.__rbuf = StringIO(prefix)
|
||||
return self.__rbuf
|
||||
|
||||
|
||||
class TFileObjectTransport(TTransportBase):
|
||||
"""Wraps a file-like object to make it work as a Thrift transport."""
|
||||
|
||||
def __init__(self, fileobj):
|
||||
self.fileobj = fileobj
|
||||
|
||||
def isOpen(self):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
self.fileobj.close()
|
||||
|
||||
def read(self, sz):
|
||||
return self.fileobj.read(sz)
|
||||
|
||||
def write(self, buf):
|
||||
self.fileobj.write(buf)
|
||||
|
||||
def flush(self):
|
||||
self.fileobj.flush()
|
221
thrift/transport/TTwisted.py
Normal file
221
thrift/transport/TTwisted.py
Normal file
|
@ -0,0 +1,221 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
from cStringIO import StringIO
|
||||
|
||||
from zope.interface import implements, Interface, Attribute
|
||||
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
|
||||
connectionDone
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols import basic
|
||||
from twisted.python import log
|
||||
from twisted.web import server, resource, http
|
||||
|
||||
from thrift.transport import TTransport
|
||||
|
||||
|
||||
class TMessageSenderTransport(TTransport.TTransportBase):
|
||||
|
||||
def __init__(self):
|
||||
self.__wbuf = StringIO()
|
||||
|
||||
def write(self, buf):
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
msg = self.__wbuf.getvalue()
|
||||
self.__wbuf = StringIO()
|
||||
self.sendMessage(msg)
|
||||
|
||||
def sendMessage(self, message):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TCallbackTransport(TMessageSenderTransport):
|
||||
|
||||
def __init__(self, func):
|
||||
TMessageSenderTransport.__init__(self)
|
||||
self.func = func
|
||||
|
||||
def sendMessage(self, message):
|
||||
self.func(message)
|
||||
|
||||
|
||||
class ThriftClientProtocol(basic.Int32StringReceiver):
|
||||
|
||||
MAX_LENGTH = 2 ** 31 - 1
|
||||
|
||||
def __init__(self, client_class, iprot_factory, oprot_factory=None):
|
||||
self._client_class = client_class
|
||||
self._iprot_factory = iprot_factory
|
||||
if oprot_factory is None:
|
||||
self._oprot_factory = iprot_factory
|
||||
else:
|
||||
self._oprot_factory = oprot_factory
|
||||
|
||||
self.recv_map = {}
|
||||
self.started = defer.Deferred()
|
||||
|
||||
def dispatch(self, msg):
|
||||
self.sendString(msg)
|
||||
|
||||
def connectionMade(self):
|
||||
tmo = TCallbackTransport(self.dispatch)
|
||||
self.client = self._client_class(tmo, self._oprot_factory)
|
||||
self.started.callback(self.client)
|
||||
|
||||
def connectionLost(self, reason=connectionDone):
|
||||
for k, v in self.client._reqs.iteritems():
|
||||
tex = TTransport.TTransportException(
|
||||
type=TTransport.TTransportException.END_OF_FILE,
|
||||
message='Connection closed')
|
||||
v.errback(tex)
|
||||
|
||||
def stringReceived(self, frame):
|
||||
tr = TTransport.TMemoryBuffer(frame)
|
||||
iprot = self._iprot_factory.getProtocol(tr)
|
||||
(fname, mtype, rseqid) = iprot.readMessageBegin()
|
||||
|
||||
try:
|
||||
method = self.recv_map[fname]
|
||||
except KeyError:
|
||||
method = getattr(self.client, 'recv_' + fname)
|
||||
self.recv_map[fname] = method
|
||||
|
||||
method(iprot, mtype, rseqid)
|
||||
|
||||
|
||||
class ThriftServerProtocol(basic.Int32StringReceiver):
|
||||
|
||||
MAX_LENGTH = 2 ** 31 - 1
|
||||
|
||||
def dispatch(self, msg):
|
||||
self.sendString(msg)
|
||||
|
||||
def processError(self, error):
|
||||
self.transport.loseConnection()
|
||||
|
||||
def processOk(self, _, tmo):
|
||||
msg = tmo.getvalue()
|
||||
|
||||
if len(msg) > 0:
|
||||
self.dispatch(msg)
|
||||
|
||||
def stringReceived(self, frame):
|
||||
tmi = TTransport.TMemoryBuffer(frame)
|
||||
tmo = TTransport.TMemoryBuffer()
|
||||
|
||||
iprot = self.factory.iprot_factory.getProtocol(tmi)
|
||||
oprot = self.factory.oprot_factory.getProtocol(tmo)
|
||||
|
||||
d = self.factory.processor.process(iprot, oprot)
|
||||
d.addCallbacks(self.processOk, self.processError,
|
||||
callbackArgs=(tmo,))
|
||||
|
||||
|
||||
class IThriftServerFactory(Interface):
|
||||
|
||||
processor = Attribute("Thrift processor")
|
||||
|
||||
iprot_factory = Attribute("Input protocol factory")
|
||||
|
||||
oprot_factory = Attribute("Output protocol factory")
|
||||
|
||||
|
||||
class IThriftClientFactory(Interface):
|
||||
|
||||
client_class = Attribute("Thrift client class")
|
||||
|
||||
iprot_factory = Attribute("Input protocol factory")
|
||||
|
||||
oprot_factory = Attribute("Output protocol factory")
|
||||
|
||||
|
||||
class ThriftServerFactory(ServerFactory):
|
||||
|
||||
implements(IThriftServerFactory)
|
||||
|
||||
protocol = ThriftServerProtocol
|
||||
|
||||
def __init__(self, processor, iprot_factory, oprot_factory=None):
|
||||
self.processor = processor
|
||||
self.iprot_factory = iprot_factory
|
||||
if oprot_factory is None:
|
||||
self.oprot_factory = iprot_factory
|
||||
else:
|
||||
self.oprot_factory = oprot_factory
|
||||
|
||||
|
||||
class ThriftClientFactory(ClientFactory):
|
||||
|
||||
implements(IThriftClientFactory)
|
||||
|
||||
protocol = ThriftClientProtocol
|
||||
|
||||
def __init__(self, client_class, iprot_factory, oprot_factory=None):
|
||||
self.client_class = client_class
|
||||
self.iprot_factory = iprot_factory
|
||||
if oprot_factory is None:
|
||||
self.oprot_factory = iprot_factory
|
||||
else:
|
||||
self.oprot_factory = oprot_factory
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self.protocol(self.client_class, self.iprot_factory,
|
||||
self.oprot_factory)
|
||||
p.factory = self
|
||||
return p
|
||||
|
||||
|
||||
class ThriftResource(resource.Resource):
|
||||
|
||||
allowedMethods = ('POST',)
|
||||
|
||||
def __init__(self, processor, inputProtocolFactory,
|
||||
outputProtocolFactory=None):
|
||||
resource.Resource.__init__(self)
|
||||
self.inputProtocolFactory = inputProtocolFactory
|
||||
if outputProtocolFactory is None:
|
||||
self.outputProtocolFactory = inputProtocolFactory
|
||||
else:
|
||||
self.outputProtocolFactory = outputProtocolFactory
|
||||
self.processor = processor
|
||||
|
||||
def getChild(self, path, request):
|
||||
return self
|
||||
|
||||
def _cbProcess(self, _, request, tmo):
|
||||
msg = tmo.getvalue()
|
||||
request.setResponseCode(http.OK)
|
||||
request.setHeader("content-type", "application/x-thrift")
|
||||
request.write(msg)
|
||||
request.finish()
|
||||
|
||||
def render_POST(self, request):
|
||||
request.content.seek(0, 0)
|
||||
data = request.content.read()
|
||||
tmi = TTransport.TMemoryBuffer(data)
|
||||
tmo = TTransport.TMemoryBuffer()
|
||||
|
||||
iprot = self.inputProtocolFactory.getProtocol(tmi)
|
||||
oprot = self.outputProtocolFactory.getProtocol(tmo)
|
||||
|
||||
d = self.processor.process(iprot, oprot)
|
||||
d.addCallback(self._cbProcess, request, tmo)
|
||||
return server.NOT_DONE_YET
|
248
thrift/transport/TZlibTransport.py
Normal file
248
thrift/transport/TZlibTransport.py
Normal file
|
@ -0,0 +1,248 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
"""TZlibTransport provides a compressed transport and transport factory
|
||||
class, using the python standard library zlib module to implement
|
||||
data compression.
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
import zlib
|
||||
from cStringIO import StringIO
|
||||
from TTransport import TTransportBase, CReadableTransport
|
||||
|
||||
|
||||
class TZlibTransportFactory(object):
|
||||
"""Factory transport that builds zlib compressed transports.
|
||||
|
||||
This factory caches the last single client/transport that it was passed
|
||||
and returns the same TZlibTransport object that was created.
|
||||
|
||||
This caching means the TServer class will get the _same_ transport
|
||||
object for both input and output transports from this factory.
|
||||
(For non-threaded scenarios only, since the cache only holds one object)
|
||||
|
||||
The purpose of this caching is to allocate only one TZlibTransport where
|
||||
only one is really needed (since it must have separate read/write buffers),
|
||||
and makes the statistics from getCompSavings() and getCompRatio()
|
||||
easier to understand.
|
||||
"""
|
||||
# class scoped cache of last transport given and zlibtransport returned
|
||||
_last_trans = None
|
||||
_last_z = None
|
||||
|
||||
def getTransport(self, trans, compresslevel=9):
|
||||
"""Wrap a transport, trans, with the TZlibTransport
|
||||
compressed transport class, returning a new
|
||||
transport to the caller.
|
||||
|
||||
@param compresslevel: The zlib compression level, ranging
|
||||
from 0 (no compression) to 9 (best compression). Defaults to 9.
|
||||
@type compresslevel: int
|
||||
|
||||
This method returns a TZlibTransport which wraps the
|
||||
passed C{trans} TTransport derived instance.
|
||||
"""
|
||||
if trans == self._last_trans:
|
||||
return self._last_z
|
||||
ztrans = TZlibTransport(trans, compresslevel)
|
||||
self._last_trans = trans
|
||||
self._last_z = ztrans
|
||||
return ztrans
|
||||
|
||||
|
||||
class TZlibTransport(TTransportBase, CReadableTransport):
|
||||
"""Class that wraps a transport with zlib, compressing writes
|
||||
and decompresses reads, using the python standard
|
||||
library zlib module.
|
||||
"""
|
||||
# Read buffer size for the python fastbinary C extension,
|
||||
# the TBinaryProtocolAccelerated class.
|
||||
DEFAULT_BUFFSIZE = 4096
|
||||
|
||||
def __init__(self, trans, compresslevel=9):
|
||||
"""Create a new TZlibTransport, wrapping C{trans}, another
|
||||
TTransport derived object.
|
||||
|
||||
@param trans: A thrift transport object, i.e. a TSocket() object.
|
||||
@type trans: TTransport
|
||||
@param compresslevel: The zlib compression level, ranging
|
||||
from 0 (no compression) to 9 (best compression). Default is 9.
|
||||
@type compresslevel: int
|
||||
"""
|
||||
self.__trans = trans
|
||||
self.compresslevel = compresslevel
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
self._init_zlib()
|
||||
self._init_stats()
|
||||
|
||||
def _reinit_buffers(self):
|
||||
"""Internal method to initialize/reset the internal StringIO objects
|
||||
for read and write buffers.
|
||||
"""
|
||||
self.__rbuf = StringIO()
|
||||
self.__wbuf = StringIO()
|
||||
|
||||
def _init_stats(self):
|
||||
"""Internal method to reset the internal statistics counters
|
||||
for compression ratios and bandwidth savings.
|
||||
"""
|
||||
self.bytes_in = 0
|
||||
self.bytes_out = 0
|
||||
self.bytes_in_comp = 0
|
||||
self.bytes_out_comp = 0
|
||||
|
||||
def _init_zlib(self):
|
||||
"""Internal method for setting up the zlib compression and
|
||||
decompression objects.
|
||||
"""
|
||||
self._zcomp_read = zlib.decompressobj()
|
||||
self._zcomp_write = zlib.compressobj(self.compresslevel)
|
||||
|
||||
def getCompRatio(self):
|
||||
"""Get the current measured compression ratios (in,out) from
|
||||
this transport.
|
||||
|
||||
Returns a tuple of:
|
||||
(inbound_compression_ratio, outbound_compression_ratio)
|
||||
|
||||
The compression ratios are computed as:
|
||||
compressed / uncompressed
|
||||
|
||||
E.g., data that compresses by 10x will have a ratio of: 0.10
|
||||
and data that compresses to half of ts original size will
|
||||
have a ratio of 0.5
|
||||
|
||||
None is returned if no bytes have yet been processed in
|
||||
a particular direction.
|
||||
"""
|
||||
r_percent, w_percent = (None, None)
|
||||
if self.bytes_in > 0:
|
||||
r_percent = self.bytes_in_comp / self.bytes_in
|
||||
if self.bytes_out > 0:
|
||||
w_percent = self.bytes_out_comp / self.bytes_out
|
||||
return (r_percent, w_percent)
|
||||
|
||||
def getCompSavings(self):
|
||||
"""Get the current count of saved bytes due to data
|
||||
compression.
|
||||
|
||||
Returns a tuple of:
|
||||
(inbound_saved_bytes, outbound_saved_bytes)
|
||||
|
||||
Note: if compression is actually expanding your
|
||||
data (only likely with very tiny thrift objects), then
|
||||
the values returned will be negative.
|
||||
"""
|
||||
r_saved = self.bytes_in - self.bytes_in_comp
|
||||
w_saved = self.bytes_out - self.bytes_out_comp
|
||||
return (r_saved, w_saved)
|
||||
|
||||
def isOpen(self):
|
||||
"""Return the underlying transport's open status"""
|
||||
return self.__trans.isOpen()
|
||||
|
||||
def open(self):
|
||||
"""Open the underlying transport"""
|
||||
self._init_stats()
|
||||
return self.__trans.open()
|
||||
|
||||
def listen(self):
|
||||
"""Invoke the underlying transport's listen() method"""
|
||||
self.__trans.listen()
|
||||
|
||||
def accept(self):
|
||||
"""Accept connections on the underlying transport"""
|
||||
return self.__trans.accept()
|
||||
|
||||
def close(self):
|
||||
"""Close the underlying transport,"""
|
||||
self._reinit_buffers()
|
||||
self._init_zlib()
|
||||
return self.__trans.close()
|
||||
|
||||
def read(self, sz):
|
||||
"""Read up to sz bytes from the decompressed bytes buffer, and
|
||||
read from the underlying transport if the decompression
|
||||
buffer is empty.
|
||||
"""
|
||||
ret = self.__rbuf.read(sz)
|
||||
if len(ret) > 0:
|
||||
return ret
|
||||
# keep reading from transport until something comes back
|
||||
while True:
|
||||
if self.readComp(sz):
|
||||
break
|
||||
ret = self.__rbuf.read(sz)
|
||||
return ret
|
||||
|
||||
def readComp(self, sz):
|
||||
"""Read compressed data from the underlying transport, then
|
||||
decompress it and append it to the internal StringIO read buffer
|
||||
"""
|
||||
zbuf = self.__trans.read(sz)
|
||||
zbuf = self._zcomp_read.unconsumed_tail + zbuf
|
||||
buf = self._zcomp_read.decompress(zbuf)
|
||||
self.bytes_in += len(zbuf)
|
||||
self.bytes_in_comp += len(buf)
|
||||
old = self.__rbuf.read()
|
||||
self.__rbuf = StringIO(old + buf)
|
||||
if len(old) + len(buf) == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
def write(self, buf):
|
||||
"""Write some bytes, putting them into the internal write
|
||||
buffer for eventual compression.
|
||||
"""
|
||||
self.__wbuf.write(buf)
|
||||
|
||||
def flush(self):
|
||||
"""Flush any queued up data in the write buffer and ensure the
|
||||
compression buffer is flushed out to the underlying transport
|
||||
"""
|
||||
wout = self.__wbuf.getvalue()
|
||||
if len(wout) > 0:
|
||||
zbuf = self._zcomp_write.compress(wout)
|
||||
self.bytes_out += len(wout)
|
||||
self.bytes_out_comp += len(zbuf)
|
||||
else:
|
||||
zbuf = ''
|
||||
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
|
||||
self.bytes_out_comp += len(ztail)
|
||||
if (len(zbuf) + len(ztail)) > 0:
|
||||
self.__wbuf = StringIO()
|
||||
self.__trans.write(zbuf + ztail)
|
||||
self.__trans.flush()
|
||||
|
||||
@property
|
||||
def cstringio_buf(self):
|
||||
"""Implement the CReadableTransport interface"""
|
||||
return self.__rbuf
|
||||
|
||||
def cstringio_refill(self, partialread, reqlen):
|
||||
"""Implement the CReadableTransport interface for refill"""
|
||||
retstring = partialread
|
||||
if reqlen < self.DEFAULT_BUFFSIZE:
|
||||
retstring += self.read(self.DEFAULT_BUFFSIZE)
|
||||
while len(retstring) < reqlen:
|
||||
retstring += self.read(reqlen - len(retstring))
|
||||
self.__rbuf = StringIO(retstring)
|
||||
return self.__rbuf
|
20
thrift/transport/__init__.py
Normal file
20
thrift/transport/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport']
|
Loading…
Add table
Reference in a new issue