Final set of changes for v20 python code:

- Brought over all new thrift files... had to untar and unzip the thrift package in awips2-rpm
  - then go into /lib/py/ and run `python setup.py build`
  - then copy all of the files that get put in the subdirectory in /build

- Replaced DataAccessLayer.py with the current one from our v18.1.11 of python-awips
This commit is contained in:
Shay Carter 2023-08-23 13:46:12 -06:00
parent 018afefee6
commit 3c7bd9f0de
29 changed files with 4972 additions and 2175 deletions

View file

@ -1,24 +1,3 @@
# #
# This software was developed and / or modified by Raytheon Company,
# pursuant to Contract DG133W-05-CQ-1067 with the US Government.
#
# U.S. EXPORT CONTROLLED TECHNICAL DATA
# This software product contains export-restricted data whose
# export/transfer/disclosure is restricted by U.S. law. Dissemination
# to non-U.S. persons whether in the United States or abroad requires
# an export license or other authorization.
#
# Contractor Name: Raytheon Company
# Contractor Address: 6825 Pine Street, Suite 340
# Mail Stop B8
# Omaha, NE 68106
# 402.291.0100
#
# See the AWIPS II Master Rights File ("Master Rights File.pdf") for
# further licensing information.
# #
# #
# Published interface for awips.dataaccess package # Published interface for awips.dataaccess package
# #
@ -26,34 +5,29 @@
# SOFTWARE HISTORY # SOFTWARE HISTORY
# #
# Date Ticket# Engineer Description # Date Ticket# Engineer Description
# ------------ ---------- ----------- -------------------------- # ------------ ------- ---------- -------------------------
# 12/10/12 njensen Initial Creation. # 12/10/12 njensen Initial Creation.
# Feb 14, 2013 1614 bsteffen refactor data access framework # Feb 14, 2013 1614 bsteffen refactor data access framework to use single request.
# to use single request.
# 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args # 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args
# 05/29/13 2023 dgilling Hook up ThriftClientRouter. # 05/29/13 2023 dgilling Hook up ThriftClientRouter.
# 03/03/14 2673 bsteffen Add ability to query only ref times. # 03/03/14 2673 bsteffen Add ability to query only ref times.
# 07/22/14 3185 njensen Added optional/default args to newDataRequest # 07/22/14 3185 njensen Added optional/default args to newDataRequest
# 07/30/14 3185 njensen Renamed valid identifiers to optional # 07/30/14 3185 njensen Renamed valid identifiers to optional
# Apr 26, 2015 4259 njensen Updated for new JEP API # Apr 26, 2015 4259 njensen Updated for new JEP API
# Apr 13, 2016 5379 tgurney Add getIdentifierValues() # Apr 13, 2016 5379 tgurney Add getIdentifierValues(), getRequiredIdentifiers(),
# Jun 01, 2016 5587 tgurney Add new signatures for # and getOptionalIdentifiers()
# getRequiredIdentifiers() and # Oct 07, 2016 ---- mjames@ucar Added getForecastRun
# getOptionalIdentifiers()
# Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon # Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon
# Oct 11, 2018 ---- mjames@ucar Added getMetarObs() getSynopticObs()
# #
#
import sys import sys
import subprocess
import warnings import warnings
THRIFT_HOST = "edex" THRIFT_HOST = "edex"
USING_NATIVE_THRIFT = False USING_NATIVE_THRIFT = False
if 'jep' in sys.modules: if 'jep' in sys.modules:
# intentionally do not catch if this fails to import, we want it to # intentionally do not catch if this fails to import, we want it to
# be obvious that something is configured wrong when running from within # be obvious that something is configured wrong when running from within
@ -66,6 +40,147 @@ else:
USING_NATIVE_THRIFT = True USING_NATIVE_THRIFT = True
def getRadarProductIDs(availableParms):
    """
    Get only the numeric identifiers for NEXRAD3 products.

    Args:
        availableParms: Full list of radar parameters

    Returns:
        List of the parameters that parse as integers, each as a ``str``
    """
    productIDs = []
    for p in availableParms:
        try:
            # int() is only a "does this look like an integer?" probe; the
            # converted value is discarded and the original text is kept.
            int(p)
        except ValueError:
            continue
        productIDs.append(str(p))
    return productIDs
def getRadarProductNames(availableParms):
    """
    Get only the named identifiers for NEXRAD3 products.

    A parameter is treated as a name when it is longer than three
    characters.

    Args:
        availableParms: Full list of radar parameters

    Returns:
        List of filtered parameters
    """
    return [p for p in availableParms if len(p) > 3]
def getMetarObs(response):
    """
    Processes a DataAccessLayer "obs" response into a dictionary,
    with special consideration for multi-value parameters
    "presWeather", "skyCover", and "skyLayerBase".

    Records carrying a multi-value parameter are accumulated until the next
    record that carries neither; that record supplies the single-value
    parameters for its station and receives the accumulated lists. Only the
    first such record per stationName is kept.

    Args:
        response: DAL getGeometry() list

    Returns:
        A dictionary of METAR obs: one list per parameter, with one entry
        per retained station (``None`` where a single-value parameter is
        missing, a list of readings for each multi-value parameter)
    """
    from datetime import datetime
    single_val_params = ["timeObs", "stationName", "longitude", "latitude",
                         "temperature", "dewpoint", "windDir",
                         "windSpeed", "seaLevelPress"]
    multi_val_params = ["presWeather", "skyCover", "skyLayerBase"]
    params = single_val_params + multi_val_params
    seen_stations = set()
    pres_weather, sky_cov, sky_layer_base = [], [], []
    obs = {param: [] for param in params}

    for ob in response:
        avail_params = ob.getParameters()

        if "presWeather" in avail_params:
            pres_weather.append(ob.getString("presWeather"))
            continue
        if "skyCover" in avail_params and "skyLayerBase" in avail_params:
            sky_cov.append(ob.getString("skyCover"))
            sky_layer_base.append(ob.getNumber("skyLayerBase"))
            continue

        # If we already have a record for this stationName, skip
        # NOTE(review): multi-value readings gathered before a skipped
        # duplicate stay in the accumulators and end up attached to the
        # next new station -- confirm this is intended.
        station = ob.getString('stationName')
        if station in seen_stations:
            continue
        seen_stations.add(station)

        for param in single_val_params:
            if param not in avail_params:
                obs[param].append(None)
            elif param == 'timeObs':
                # Value is divided by 1000 before conversion (presumably
                # epoch milliseconds); fromtimestamp() yields a naive
                # local-time datetime.
                obs[param].append(datetime.fromtimestamp(ob.getNumber(param) / 1000.0))
            else:
                try:
                    obs[param].append(ob.getNumber(param))
                except TypeError:
                    # Not numeric: fall back to the string form.
                    obs[param].append(ob.getString(param))

        # Hand the accumulated multi-value readings to this station and
        # start fresh accumulators for the next one.
        obs['presWeather'].append(pres_weather)
        obs['skyCover'].append(sky_cov)
        obs['skyLayerBase'].append(sky_layer_base)
        pres_weather = []
        sky_cov = []
        sky_layer_base = []

    return obs
def getSynopticObs(response):
    """
    Processes a DataAccessLayer "sfcobs" response into a dictionary
    of available parameters.

    The parameter list is taken from the first record; only the first
    record seen for each stationId is kept.

    Args:
        response: DAL getGeometry() list

    Returns:
        A dictionary of synop obs (one list per parameter); an empty
        dictionary when the response contains no records
    """
    from datetime import datetime

    # An empty response has no first record to read the parameters from.
    if len(response) == 0:
        return {}

    params = response[0].getParameters()
    sfcobs = {param: [] for param in params}
    seen_stations = set()

    for sfcob in response:
        # If we already have a record for this stationId, skip
        station = sfcob.getString('stationId')
        if station in seen_stations:
            continue
        seen_stations.add(station)

        for param in params:
            if param == 'timeObs':
                # Value is divided by 1000 before conversion (presumably
                # epoch milliseconds); fromtimestamp() yields a naive
                # local-time datetime.
                sfcobs[param].append(datetime.fromtimestamp(sfcob.getNumber(param) / 1000.0))
            else:
                try:
                    sfcobs[param].append(sfcob.getNumber(param))
                except TypeError:
                    # Not numeric: fall back to the string form.
                    sfcobs[param].append(sfcob.getString(param))

    return sfcobs
def getForecastRun(cycle, times):
    """
    Get the latest forecast run (list of objects) from all
    cycles and times returned from DataAccessLayer "grid"
    response.

    A time belongs to the run when the first 19 characters of its string
    form equal the string form of ``cycle``.

    Args:
        cycle: Forecast cycle reference time
        times: All available times/cycles

    Returns:
        DataTime array for a single forecast run
    """
    # The cycle's string form does not change; convert it once.
    cycle_str = str(cycle)
    return [t for t in times if str(t)[:19] == cycle_str]
def getAvailableTimes(request, refTimeOnly=False): def getAvailableTimes(request, refTimeOnly=False):
""" """
@ -204,8 +319,9 @@ def getIdentifierValues(request, identifierKey):
""" """
return router.getIdentifierValues(request, identifierKey) return router.getIdentifierValues(request, identifierKey)
def newDataRequest(datatype=None, **kwargs): def newDataRequest(datatype=None, **kwargs):
"""" """
Creates a new instance of IDataRequest suitable for the runtime environment. Creates a new instance of IDataRequest suitable for the runtime environment.
All args are optional and exist solely for convenience. All args are optional and exist solely for convenience.
@ -215,13 +331,14 @@ def newDataRequest(datatype=None, **kwargs):
levels: a list of levels to set on the request levels: a list of levels to set on the request
locationNames: a list of locationNames to set on the request locationNames: a list of locationNames to set on the request
envelope: an envelope to limit the request envelope: an envelope to limit the request
**kwargs: any leftover kwargs will be set as identifiers kwargs: any leftover kwargs will be set as identifiers
Returns: Returns:
a new IDataRequest a new IDataRequest
""" """
return router.newDataRequest(datatype, **kwargs) return router.newDataRequest(datatype, **kwargs)
def getSupportedDatatypes(): def getSupportedDatatypes():
""" """
Gets the datatypes that are supported by the framework Gets the datatypes that are supported by the framework
@ -239,7 +356,7 @@ def changeEDEXHost(newHostName):
method will throw a TypeError. method will throw a TypeError.
Args: Args:
newHostHame: the EDEX host to connect to newHostName: the EDEX host to connect to
""" """
if USING_NATIVE_THRIFT: if USING_NATIVE_THRIFT:
global THRIFT_HOST global THRIFT_HOST
@ -249,6 +366,7 @@ def changeEDEXHost(newHostName):
else: else:
raise TypeError("Cannot call changeEDEXHost when using JepRouter.") raise TypeError("Cannot call changeEDEXHost when using JepRouter.")
def setLazyLoadGridLatLon(lazyLoadGridLatLon): def setLazyLoadGridLatLon(lazyLoadGridLatLon):
""" """
Provide a hint to the Data Access Framework indicating whether to load the Provide a hint to the Data Access Framework indicating whether to load the

View file

@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TProcessor, TMessageType
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
from thrift.protocol.TProtocol import TProtocolException
class TMultiplexedProcessor(TProcessor):
    """Processor that routes each call to a processor registered by service name.

    The service name is the part of the incoming message name before
    ``TMultiplexedProtocol.SEPARATOR``; the remainder is the method name
    handed to the selected processor.
    """

    def __init__(self):
        self.defaultProcessor = None
        self.services = {}

    def registerDefault(self, processor):
        """
        If a non-multiplexed processor connects to the server and wants to
        communicate, use the given processor to handle it. This mechanism
        allows servers to upgrade from non-multiplexed to multiplexed in a
        backwards-compatible way and still handle old clients.
        """
        self.defaultProcessor = processor

    def registerProcessor(self, serviceName, processor):
        """Register ``processor`` as the handler for ``serviceName``."""
        self.services[serviceName] = processor

    def on_message_begin(self, func):
        """Forward the message-begin callback to every registered processor."""
        for processor in self.services.values():
            processor.on_message_begin(func)

    def process(self, iprot, oprot):
        """Read the message header and dispatch to the matching processor.

        Raises:
            TProtocolException: (NOT_IMPLEMENTED) for message types other
                than CALL/ONEWAY, for a message name without a separator
                when no default processor is registered, or for an
                unregistered service name.
        """
        (name, msg_type, seqid) = iprot.readMessageBegin()
        if msg_type not in (TMessageType.CALL, TMessageType.ONEWAY):
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "TMultiplexedProtocol only supports CALL & ONEWAY")

        separator = TMultiplexedProtocol.SEPARATOR
        index = name.find(separator)
        if index < 0:
            # No service prefix: only a registered default can handle it.
            if not self.defaultProcessor:
                raise TProtocolException(
                    TProtocolException.NOT_IMPLEMENTED,
                    "Service name not found in message name: " + name + ". " +
                    "Did you forget to use TMultiplexedProtocol in your client?")
            unprefixed = StoredMessageProtocol(iprot, (name, msg_type, seqid))
            return self.defaultProcessor.process(unprefixed, oprot)

        serviceName = name[:index]
        call = name[index + len(separator):]
        if serviceName not in self.services:
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "Service name not found: " + serviceName + ". " +
                "Did you forget to call registerProcessor()?")

        # Replay the header with the service prefix stripped so the target
        # processor sees a plain method name.
        replay = StoredMessageProtocol(iprot, (call, msg_type, seqid))
        return self.services[serviceName].process(replay, oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol decorator whose readMessageBegin() returns a stored header."""

    def __init__(self, protocol, messageBegin):
        # ``protocol`` is not used here -- presumably it is consumed by the
        # TProtocolDecorator base when the instance is created; confirm.
        self.messageBegin = messageBegin

    def readMessageBegin(self):
        """Return the header given at construction instead of reading one."""
        return self.messageBegin

83
thrift/TRecursive.py Normal file
View file

@ -0,0 +1,83 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from thrift.Thrift import TType
TYPE_IDX = 1
SPEC_ARGS_IDX = 3
SPEC_ARGS_CLASS_REF_IDX = 0
SPEC_ARGS_THRIFT_SPEC_IDX = 1
def fix_spec(all_structs):
    """Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
    container_types = (TType.LIST, TType.SET)
    for struct_cls in all_structs:
        for field in struct_cls.thrift_spec:
            # Unused field ids are represented by None placeholders.
            if field is None:
                continue

            kind = field[TYPE_IDX]
            if kind == TType.STRUCT:
                # Fill the placeholder slot with the referenced class's spec.
                class_and_spec = field[SPEC_ARGS_IDX]
                referenced = class_and_spec[SPEC_ARGS_CLASS_REF_IDX].thrift_spec
                field[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = referenced
            elif kind in container_types:
                _fix_list_or_set(field[SPEC_ARGS_IDX])
            elif kind == TType.MAP:
                _fix_map(field[SPEC_ARGS_IDX])
def _fix_list_or_set(element_type):
    """Resolve struct references inside a list/set element description."""
    # For a list or set, the thrift_spec entry looks like,
    #     (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ),  # 1
    # so ``element_type`` will be,
    #     (TType.STRUCT, [RecList, None], False)
    kind = element_type[0]
    if kind == TType.STRUCT:
        class_and_spec = element_type[1]
        class_and_spec[1] = class_and_spec[0].thrift_spec
    elif kind in (TType.LIST, TType.SET):
        # Nested container: descend into its own element description.
        _fix_list_or_set(element_type[1])
    elif kind == TType.MAP:
        _fix_map(element_type[1])
def _fix_map(element_type):
    """Resolve struct references in the key and value of a map description."""
    # For a map of key -> value type, ``element_type`` will be,
    #     (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False)
    # which is just a normal struct definition.
    #
    # For a map of key -> list / set, ``element_type`` will be,
    #     (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
    # and we need to process the 3rd element as a list.
    #
    # For a map of key -> map, ``element_type`` will be,
    #     (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT,
    #      [RecMapMap, None], False), False)
    # and need to process 3rd element as a map.

    # Slots (0, 1) describe the key and slots (2, 3) the value; both get
    # the same treatment, key first.
    for type_pos, args_pos in ((0, 1), (2, 3)):
        kind = element_type[type_pos]
        if kind == TType.STRUCT:
            class_and_spec = element_type[args_pos]
            class_and_spec[1] = class_and_spec[0].thrift_spec
        elif kind in (TType.LIST, TType.SET):
            _fix_list_or_set(element_type[args_pos])
        elif kind == TType.MAP:
            _fix_map(element_type[args_pos])

View file

@ -19,6 +19,7 @@
from os import path from os import path
from SCons.Builder import Builder from SCons.Builder import Builder
from six.moves import map
def scons_env(env, add=''): def scons_env(env, add=''):
@ -31,5 +32,5 @@ def scons_env(env, add=''):
def gen_cpp(env, dir, file): def gen_cpp(env, dir, file):
scons_env(env) scons_env(env)
suffixes = ['_types.h', '_types.cpp'] suffixes = ['_types.h', '_types.cpp']
targets = ['gen-cpp/' + file + s for s in suffixes] targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
return env.ThriftCpp(targets, dir + file + '.thrift') return env.ThriftCpp(targets, dir + file + '.thrift')

188
thrift/TTornado.py Normal file
View file

@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
import logging
import socket
import struct
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
from io import BytesIO
from collections import deque
from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
logger = logging.getLogger(__name__)
class _Lock(object):
    """Coroutine lock implemented as a FIFO queue of futures.

    Every acquire() appends its own future; a caller that finds an earlier
    future queued waits on the most recent one before proceeding. release()
    removes the oldest future and resolves it.
    """

    def __init__(self):
        # One future per holder/waiter, oldest on the left.
        self._waiters = deque()

    def acquired(self):
        """Return True while at least one future is queued."""
        return len(self._waiters) > 0

    @gen.coroutine
    def acquire(self):
        """Queue up for the lock and return a context manager that releases it."""
        # Capture the current tail BEFORE appending our own future, so we
        # wait on our predecessor rather than on ourselves.
        blocker = self._waiters[-1] if self.acquired() else None
        future = concurrent.Future()
        self._waiters.append(future)
        if blocker:
            yield blocker

        raise gen.Return(self._lock_context())

    def release(self):
        """Pop the oldest queued future and resolve it with None."""
        assert self.acquired(), 'Lock not aquired'
        future = self._waiters.popleft()
        future.set_result(None)

    @contextmanager
    def _lock_context(self):
        """Context manager that calls release() on exit, even on error."""
        try:
            yield
        finally:
            self.release()
class TTornadoStreamTransport(TTransportBase):
    """a framed, buffered transport over a Tornado stream"""

    def __init__(self, host, port, stream=None, io_loop=None):
        self.host = host
        self.port = port
        # Fall back to the current IOLoop when none is supplied.
        self.io_loop = io_loop or ioloop.IOLoop.current()
        # Outgoing bytes accumulate here until flush() frames and sends them.
        self.__wbuf = BytesIO()
        # Guards readFrame() so only one frame read is in progress at a time.
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap ``future`` with gen.with_timeout on this transport's io_loop."""
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Create a TCP IOStream and connect to (host, port).

        Args:
            timeout: passed to with_timeout() when not None; otherwise the
                connect is awaited without a timeout.

        Returns:
            This transport (via gen.Return).

        Raises:
            TTransportException: NOT_OPEN when the connect fails with
                socket.error, IOError or ioloop.TimeoutError.
        """
        logger.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        except (socket.error, IOError, ioloop.TimeoutError) as e:
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Clear the close callback, then close the underlying stream."""
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: always fails; callers must use readFrame()."""
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time
        assert False, "you're doing it wrong"

    @contextmanager
    def io_exception_context(self):
        """Translate stream I/O errors raised in the body into TTransportException.

        socket.error / IOError become END_OF_FILE;
        iostream.StreamBufferFullError becomes UNKNOWN.
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one frame: a 4-byte length header followed by that many bytes.

        Returns:
            The frame payload (via gen.Return).
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                # '!i': network byte order, signed 32-bit length.
                frame_length, = struct.unpack('!i', frame_header)
                frame = yield self.stream.read_bytes(frame_length)
                raise gen.Return(frame)

    def write(self, buf):
        """Append ``buf`` to the pending write buffer; nothing is sent yet."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one length-prefixed frame.

        Returns:
            Whatever ``self.stream.write`` returns.
        """
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
    """TCPServer that feeds framed requests from each connection to a processor."""

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """Store the processor and protocol factories.

        Args:
            processor: object whose process(iprot, oprot) handles one request.
            iprot_factory: factory used to build the input protocol per frame.
            oprot_factory: factory for the output protocol; defaults to
                ``iprot_factory`` when None.
            *args, **kwargs: forwarded to tcpserver.TCPServer.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Serve one client connection until its stream closes or a read hits EOF."""
        host, port = address[:2]
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=self.io_loop)
        # One output protocol for the whole connection; a fresh input
        # protocol is built for every frame below.
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                try:
                    frame = yield trans.readFrame()
                except TTransportException as e:
                    # END_OF_FILE ends the loop quietly; anything else is
                    # re-raised to the outer handler.
                    if e.type == TTransportException.END_OF_FILE:
                        break
                    else:
                        raise
                tr = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(tr)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logger.exception('thrift exception in handle_stream')
            trans.close()

        logger.info('client disconnected %s:%d', host, port)

View file

@ -17,10 +17,8 @@
# under the License. # under the License.
# #
import sys
class TType(object):
class TType:
STOP = 0 STOP = 0
VOID = 1 VOID = 1
BOOL = 2 BOOL = 2
@ -39,7 +37,8 @@ class TType:
UTF8 = 16 UTF8 = 16
UTF16 = 17 UTF16 = 17
_VALUES_TO_NAMES = ('STOP', _VALUES_TO_NAMES = (
'STOP',
'VOID', 'VOID',
'BOOL', 'BOOL',
'BYTE', 'BYTE',
@ -56,38 +55,42 @@ class TType:
'SET', 'SET',
'LIST', 'LIST',
'UTF8', 'UTF8',
'UTF16') 'UTF16',
)
class TMessageType: class TMessageType(object):
CALL = 1 CALL = 1
REPLY = 2 REPLY = 2
EXCEPTION = 3 EXCEPTION = 3
ONEWAY = 4 ONEWAY = 4
class TProcessor: class TProcessor(object):
"""Base class for procsessor, which works on two streams.""" """Base class for processor, which works on two streams."""
def process(iprot, oprot): def process(self, iprot, oprot):
"""
Process a request. The normal behvaior is to have the
processor invoke the correct handler and then it is the
server's responsibility to write the response to oprot.
"""
pass
def on_message_begin(self, func):
"""
Install a callback that receives (name, type, seqid)
after the message header is read.
"""
pass pass
class TException(Exception): class TException(Exception):
"""Base class for all thrift exceptions.""" """Base class for all thrift exceptions."""
# BaseException.message is deprecated in Python v[2.6,3.0)
if (2, 6, 0) <= sys.version_info < (3, 0):
def _get_message(self):
return self._message
def _set_message(self, message):
self._message = message
message = property(_get_message, _set_message)
def __init__(self, message=None): def __init__(self, message=None):
Exception.__init__(self, message) Exception.__init__(self, message)
self.message = message super(TException, self).__setattr__("message", message)
class TApplicationException(TException): class TApplicationException(TException):
@ -101,6 +104,9 @@ class TApplicationException(TException):
MISSING_RESULT = 5 MISSING_RESULT = 5
INTERNAL_ERROR = 6 INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7 PROTOCOL_ERROR = 7
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message) TException.__init__(self, message)
@ -119,6 +125,16 @@ class TApplicationException(TException):
return 'Bad sequence ID' return 'Bad sequence ID'
elif self.type == self.MISSING_RESULT: elif self.type == self.MISSING_RESULT:
return 'Missing result' return 'Missing result'
elif self.type == self.INTERNAL_ERROR:
return 'Internal error'
elif self.type == self.PROTOCOL_ERROR:
return 'Protocol error'
elif self.type == self.INVALID_TRANSFORM:
return 'Invalid transform'
elif self.type == self.INVALID_PROTOCOL:
return 'Invalid protocol'
elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
return 'Unsupported client type'
else: else:
return 'Default (unknown) TApplicationException' return 'Default (unknown) TApplicationException'
@ -155,3 +171,23 @@ class TApplicationException(TException):
oprot.writeFieldEnd() oprot.writeFieldEnd()
oprot.writeFieldStop() oprot.writeFieldStop()
oprot.writeStructEnd() oprot.writeStructEnd()
class TFrozenDict(dict):
    """A dictionary that is "frozen" like a frozenset.

    The hash is computed once at construction, so every mutating operation
    is rejected with TypeError; otherwise the cached hash could stop
    matching the contents.
    """

    def __init__(self, *args, **kwargs):
        super(TFrozenDict, self).__init__(*args, **kwargs)
        # Sort the items so they will be in a consistent order.
        # XOR in the hash of the class so we don't collide with
        # the hash of a list of tuples.
        self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))

    def _reject_mutation(self, *args, **kwargs):
        """Raise TypeError for any attempt to modify the dictionary."""
        raise TypeError("Can't modify frozen TFreezableDict")

    # Block every in-place mutator, not only item assignment/deletion:
    # dict.update(), pop(), etc. do not go through __setitem__/__delitem__.
    __setitem__ = _reject_mutation
    __delitem__ = _reject_mutation
    __ior__ = _reject_mutation
    clear = _reject_mutation
    pop = _reject_mutation
    popitem = _reject_mutation
    setdefault = _reject_mutation
    update = _reject_mutation

    def __hash__(self):
        return self.__hashval

46
thrift/compat.py Normal file
View file

@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
if sys.version_info[0] != 2:

    from io import BytesIO as BufferIO  # noqa

    def binary_to_str(bin_val):
        """Decode a bytes value to text using UTF-8."""
        return bin_val.decode('utf8')

    def str_to_binary(str_val):
        """Encode a text value to bytes using UTF-8."""
        return bytes(str_val, 'utf8')

    def byte_index(bytes_val, i):
        """Return the integer value of the byte at index ``i``."""
        return bytes_val[i]

else:

    from cStringIO import StringIO as BufferIO

    def binary_to_str(bin_val):
        """Return the value unchanged (no conversion on Python 2)."""
        return bin_val

    def str_to_binary(str_val):
        """Return the value unchanged (no conversion on Python 2)."""
        return str_val

    def byte_index(bytes_val, i):
        """Return the integer value of the byte at index ``i``."""
        return ord(bytes_val[i])

View file

@ -17,22 +17,14 @@
# under the License. # under the License.
# #
from thrift.Thrift import *
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport from thrift.transport import TTransport
try:
from thrift.protocol import fastbinary
except:
fastbinary = None
class TBase(object): class TBase(object):
__slots__ = [] __slots__ = ()
def __repr__(self): def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
for key in self.__slots__]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other): def __eq__(self, other):
@ -49,33 +41,46 @@ class TBase(object):
return not (self == other) return not (self == other)
def read(self, iprot): def read(self, iprot):
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None and self.thrift_spec is not None):
fastbinary is not None): iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
fastbinary.decode_binary(self, else:
iprot.trans,
(self.__class__, self.thrift_spec))
return
iprot.readStruct(self, self.thrift_spec) iprot.readStruct(self, self.thrift_spec)
def write(self, oprot): def write(self, oprot):
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and if (oprot._fast_encode is not None and self.thrift_spec is not None):
self.thrift_spec is not None and
fastbinary is not None):
oprot.trans.write( oprot.trans.write(
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
return else:
oprot.writeStruct(self, self.thrift_spec) oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(Exception): class TExceptionBase(TBase, Exception):
# old style class so python2.4 can raise exceptions derived from this pass
# This can't inherit from TBase because of that limitation.
__slots__ = []
__repr__ = TBase.__repr__.__func__
__eq__ = TBase.__eq__.__func__ class TFrozenBase(TBase):
__ne__ = TBase.__ne__.__func__ def __setitem__(self, *args):
read = TBase.read.__func__ raise TypeError("Can't modify frozen struct")
write = TBase.write.__func__
def __delitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __hash__(self, *args):
return hash(self.__class__) ^ hash(self.__slots__)
@classmethod
def read(cls, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
[self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)
class TFrozenExceptionBase(TFrozenBase, TExceptionBase):
pass

View file

@ -17,7 +17,7 @@
# under the License. # under the License.
# #
from .TProtocol import * from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory
from struct import pack, unpack from struct import pack, unpack
@ -36,10 +36,18 @@ class TBinaryProtocol(TProtocolBase):
TYPE_MASK = 0x000000ff TYPE_MASK = 0x000000ff
def __init__(self, trans, strictRead=False, strictWrite=True): def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans) TProtocolBase.__init__(self, trans)
self.strictRead = strictRead self.strictRead = strictRead
self.strictWrite = strictWrite self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def writeMessageBegin(self, name, type, seqid): def writeMessageBegin(self, name, type, seqid):
if self.strictWrite: if self.strictWrite:
@ -118,7 +126,7 @@ class TBinaryProtocol(TProtocolBase):
buff = pack("!d", dub) buff = pack("!d", dub)
self.trans.write(buff) self.trans.write(buff)
def writeString(self, str): def writeBinary(self, str):
self.writeI32(len(str)) self.writeI32(len(str))
self.trans.write(str) self.trans.write(str)
@ -165,6 +173,7 @@ class TBinaryProtocol(TProtocolBase):
ktype = self.readByte() ktype = self.readByte()
vtype = self.readByte() vtype = self.readByte()
size = self.readI32() size = self.readI32()
self._check_container_length(size)
return (ktype, vtype, size) return (ktype, vtype, size)
def readMapEnd(self): def readMapEnd(self):
@ -173,6 +182,7 @@ class TBinaryProtocol(TProtocolBase):
def readListBegin(self): def readListBegin(self):
etype = self.readByte() etype = self.readByte()
size = self.readI32() size = self.readI32()
self._check_container_length(size)
return (etype, size) return (etype, size)
def readListEnd(self): def readListEnd(self):
@ -181,6 +191,7 @@ class TBinaryProtocol(TProtocolBase):
def readSetBegin(self): def readSetBegin(self):
etype = self.readByte() etype = self.readByte()
size = self.readI32() size = self.readI32()
self._check_container_length(size)
return (etype, size) return (etype, size)
def readSetEnd(self): def readSetEnd(self):
@ -204,10 +215,6 @@ class TBinaryProtocol(TProtocolBase):
def readI32(self): def readI32(self):
buff = self.trans.readAll(4) buff = self.trans.readAll(4)
try:
val, = unpack('!i', buff)
except TypeError:
#str does not support the buffer interface
val, = unpack('!i', buff) val, = unpack('!i', buff)
return val return val
@ -221,19 +228,24 @@ class TBinaryProtocol(TProtocolBase):
val, = unpack('!d', buff) val, = unpack('!d', buff)
return val return val
def readString(self): def readBinary(self):
len = self.readI32() size = self.readI32()
str = self.trans.readAll(len) self._check_string_length(size)
return str s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory: class TBinaryProtocolFactory(TProtocolFactory):
def __init__(self, strictRead=False, strictWrite=True): def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead self.strictRead = strictRead
self.strictWrite = strictWrite self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans): def getProtocol(self, trans):
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit)
return prot return prot
@ -246,6 +258,7 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
We inherit from TBinaryProtocol so that the normal TBinaryProtocol We inherit from TBinaryProtocol so that the normal TBinaryProtocol
encoding can happen if the fastbinary module doesn't work for some encoding can happen if the fastbinary module doesn't work for some
reason. (TODO(dreiss): Make this happen sanely in more cases.) reason. (TODO(dreiss): Make this happen sanely in more cases.)
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use In order to take advantage of the C module, just use
TBinaryProtocolAccelerated instead of TBinaryProtocol. TBinaryProtocolAccelerated instead of TBinaryProtocol.
@ -258,7 +271,31 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
""" """
pass pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_binary
self._fast_encode = fastbinary.encode_binary
class TBinaryProtocolAcceleratedFactory(TProtocolFactory):
def __init__(self,
string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
class TBinaryProtocolAcceleratedFactory:
def getProtocol(self, trans): def getProtocol(self, trans):
return TBinaryProtocolAccelerated(trans) return TBinaryProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)

View file

@ -17,9 +17,11 @@
# under the License. # under the License.
# #
from .TProtocol import * from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits
from struct import pack, unpack from struct import pack, unpack
from ..compat import binary_to_str, str_to_binary
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] __all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
CLEAR = 0 CLEAR = 0
@ -40,11 +42,14 @@ def make_helper(v_from, container):
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
return nested return nested
return helper return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ) reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits): def makeZigZag(n, bits):
checkIntegerLimits(n, bits)
return (n << 1) ^ (n >> (bits - 1)) return (n << 1) ^ (n >> (bits - 1))
@ -53,7 +58,8 @@ def fromZigZag(n):
def writeVarint(trans, n): def writeVarint(trans, n):
out = [] assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!"
out = bytearray()
while True: while True:
if n & ~0x7f == 0: if n & ~0x7f == 0:
out.append(n) out.append(n)
@ -61,7 +67,7 @@ def writeVarint(trans, n):
else: else:
out.append((n & 0xff) | 0x80) out.append((n & 0xff) | 0x80)
n = n >> 7 n = n >> 7
trans.write(''.join(map(chr, out))) trans.write(bytes(out))
def readVarint(trans): def readVarint(trans):
@ -76,7 +82,7 @@ def readVarint(trans):
shift += 7 shift += 7
class CompactType: class CompactType(object):
STOP = 0x00 STOP = 0x00
TRUE = 0x01 TRUE = 0x01
FALSE = 0x02 FALSE = 0x02
@ -91,7 +97,9 @@ class CompactType:
MAP = 0x0B MAP = 0x0B
STRUCT = 0x0C STRUCT = 0x0C
CTYPES = {TType.STOP: CompactType.STOP,
CTYPES = {
TType.STOP: CompactType.STOP,
TType.BOOL: CompactType.TRUE, # used for collection TType.BOOL: CompactType.TRUE, # used for collection
TType.BYTE: CompactType.BYTE, TType.BYTE: CompactType.BYTE,
TType.I16: CompactType.I16, TType.I16: CompactType.I16,
@ -102,11 +110,11 @@ CTYPES = {TType.STOP: CompactType.STOP,
TType.STRUCT: CompactType.STRUCT, TType.STRUCT: CompactType.STRUCT,
TType.LIST: CompactType.LIST, TType.LIST: CompactType.LIST,
TType.SET: CompactType.SET, TType.SET: CompactType.SET,
TType.MAP: CompactType.MAP TType.MAP: CompactType.MAP,
} }
TTYPES = {} TTYPES = {}
for k, v in list(CTYPES.items()): for k, v in CTYPES.items():
TTYPES[v] = k TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL TTYPES[CompactType.FALSE] = TType.BOOL
del k del k
@ -120,9 +128,12 @@ class TCompactProtocol(TProtocolBase):
VERSION = 1 VERSION = 1
VERSION_MASK = 0x1f VERSION_MASK = 0x1f
TYPE_MASK = 0xe0 TYPE_MASK = 0xe0
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5 TYPE_SHIFT_AMOUNT = 5
def __init__(self, trans): def __init__(self, trans,
string_length_limit=None,
container_length_limit=None):
TProtocolBase.__init__(self, trans) TProtocolBase.__init__(self, trans)
self.state = CLEAR self.state = CLEAR
self.__last_fid = 0 self.__last_fid = 0
@ -130,6 +141,14 @@ class TCompactProtocol(TProtocolBase):
self.__bool_value = None self.__bool_value = None
self.__structs = [] self.__structs = []
self.__containers = [] self.__containers = []
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def __writeVarint(self, n): def __writeVarint(self, n):
writeVarint(self.trans, n) writeVarint(self.trans, n)
@ -138,8 +157,15 @@ class TCompactProtocol(TProtocolBase):
assert self.state == CLEAR assert self.state == CLEAR
self.__writeUByte(self.PROTOCOL_ID) self.__writeUByte(self.PROTOCOL_ID)
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
self.__writeVarint(seqid) # The sequence id is a signed 32-bit integer but the compact protocol
self.__writeString(name) # writes this out as a "var int" which is always positive, and attempting
# to write a negative number results in an infinite loop, so we may
# need to do some conversion here...
tseqid = seqid
if tseqid < 0:
tseqid = 2147483648 + (2147483648 + tseqid)
self.__writeVarint(tseqid)
self.__writeBinary(str_to_binary(name))
self.state = VALUE_WRITE self.state = VALUE_WRITE
def writeMessageEnd(self): def writeMessageEnd(self):
@ -250,12 +276,12 @@ class TCompactProtocol(TProtocolBase):
@writer @writer
def writeDouble(self, dub): def writeDouble(self, dub):
self.trans.write(pack('!d', dub)) self.trans.write(pack('<d', dub))
def __writeString(self, s): def __writeBinary(self, s):
self.__writeSize(len(s)) self.__writeSize(len(s))
self.trans.write(s) self.trans.write(s)
writeString = writer(__writeString) writeBinary = writer(__writeBinary)
def readFieldBegin(self): def readFieldBegin(self):
assert self.state == FIELD_READ, self.state assert self.state == FIELD_READ, self.state
@ -300,7 +326,7 @@ class TCompactProtocol(TProtocolBase):
def __readSize(self): def __readSize(self):
result = self.__readVarint() result = self.__readVarint()
if result < 0: if result < 0:
raise TException("Length < 0") raise TProtocolException("Length < 0")
return result return result
def readMessageBegin(self): def readMessageBegin(self):
@ -310,13 +336,17 @@ class TCompactProtocol(TProtocolBase):
raise TProtocolException(TProtocolException.BAD_VERSION, raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad protocol id in the message: %d' % proto_id) 'Bad protocol id in the message: %d' % proto_id)
ver_type = self.__readUByte() ver_type = self.__readUByte()
type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
version = ver_type & self.VERSION_MASK version = ver_type & self.VERSION_MASK
if version != self.VERSION: if version != self.VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION, raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad version: %d (expect %d)' % (version, self.VERSION)) 'Bad version: %d (expect %d)' % (version, self.VERSION))
seqid = self.__readVarint() seqid = self.__readVarint()
name = self.__readString() # the sequence is a compact "var int" which is treated as unsigned,
# however the sequence is actually signed...
if seqid > 2147483647:
seqid = -2147483648 - (2147483648 - seqid)
name = binary_to_str(self.__readBinary())
return (name, type, seqid) return (name, type, seqid)
def readMessageEnd(self): def readMessageEnd(self):
@ -340,6 +370,7 @@ class TCompactProtocol(TProtocolBase):
type = self.__getTType(size_type) type = self.__getTType(size_type)
if size == 15: if size == 15:
size = self.__readSize() size = self.__readSize()
self._check_container_length(size)
self.__containers.append(self.state) self.__containers.append(self.state)
self.state = CONTAINER_READ self.state = CONTAINER_READ
return type, size return type, size
@ -349,6 +380,7 @@ class TCompactProtocol(TProtocolBase):
def readMapBegin(self): def readMapBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size = self.__readSize() size = self.__readSize()
self._check_container_length(size)
types = 0 types = 0
if size > 0: if size > 0:
types = self.__readUByte() types = self.__readUByte()
@ -383,21 +415,73 @@ class TCompactProtocol(TProtocolBase):
@reader @reader
def readDouble(self): def readDouble(self):
buff = self.trans.readAll(8) buff = self.trans.readAll(8)
val, = unpack('!d', buff) val, = unpack('<d', buff)
return val return val
def __readString(self): def __readBinary(self):
len = self.__readSize() size = self.__readSize()
return self.trans.readAll(len) self._check_string_length(size)
readString = reader(__readString) return self.trans.readAll(size)
readBinary = reader(__readBinary)
def __getTType(self, byte): def __getTType(self, byte):
return TTYPES[byte & 0x0f] return TTYPES[byte & 0x0f]
class TCompactProtocolFactory: class TCompactProtocolFactory(TProtocolFactory):
def __init__(self): def __init__(self,
pass string_length_limit=None,
container_length_limit=None):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def getProtocol(self, trans): def getProtocol(self, trans):
return TCompactProtocol(trans) return TCompactProtocol(trans,
self.string_length_limit,
self.container_length_limit)
class TCompactProtocolAccelerated(TCompactProtocol):
    """TCompactProtocol variant that hooks up the C extension codec.

    No protocol method is overridden here.  Per the upstream note, generated
    code recognizes this class and calls into the C module for encoding,
    bypassing the object; the inherited pure-Python TCompactProtocol methods
    stay available when the extension cannot be imported.

    Pass ``fallback=False`` to the constructor to make a missing extension
    an error instead of silently falling back.
    """

    def __init__(self, *args, **kwargs):
        # 'fallback' is our own option; strip it before delegating upward.
        fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if fallback:
                # Extension unavailable: keep the pure-Python behaviour.
                return
            raise
        self._fast_decode = fastbinary.decode_compact
        self._fast_encode = fastbinary.encode_compact
class TCompactProtocolAcceleratedFactory(TProtocolFactory):
    """Factory that builds TCompactProtocolAccelerated instances.

    The limits and the fallback flag given here are forwarded unchanged to
    every protocol the factory creates.
    """

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self._fallback = fallback
        self.container_length_limit = container_length_limit
        self.string_length_limit = string_length_limit

    def getProtocol(self, trans):
        """Return a new accelerated compact protocol bound to *trans*."""
        options = {
            'string_length_limit': self.string_length_limit,
            'container_length_limit': self.container_length_limit,
            'fallback': self._fallback,
        }
        return TCompactProtocolAccelerated(trans, **options)

View file

@ -0,0 +1,232 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory
from thrift.Thrift import TApplicationException, TMessageType
from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
# Maps the subprotocol id reported by the header transport to the protocol
# class that THeaderProtocol instantiates for it (see _set_protocol).
PROTOCOLS_BY_ID = {
    THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
    THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
}
class THeaderProtocol(TProtocolBase):
    """A framed protocol with headers and payload transforms.

    THeaderProtocol frames other Thrift protocols and adds support for optional
    out-of-band headers. The currently supported subprotocols are
    TBinaryProtocol and TCompactProtocol. When used as a client, the
    subprotocol to frame can be chosen with the `default_protocol` parameter to
    the constructor.

    It's also possible to apply transforms to the encoded message payload. The
    only transform currently supported is to gzip.

    When used in a server, THeaderProtocol can accept messages from
    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
    includes framed and unframed transports and both TBinaryProtocol and
    TCompactProtocol. The server will respond in the appropriate dialect for
    the connected client. HTTP clients are not currently supported.

    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
    or TProcessPoolServer.

    See doc/specs/HeaderFormat.md for details of the wire format.
    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        """Wrap *transport* in a THeaderTransport (unless it already is one)
        and select the initial subprotocol.
        """
        # much of the actual work for THeaderProtocol happens down in
        # THeaderTransport since we need to do low-level shenanigans to detect
        # if the client is sending us headers or one of the headerless formats
        # we support. this wraps the real transport with the one that does all
        # the magic.
        if not isinstance(transport, THeaderTransport):
            transport = THeaderTransport(transport, allowed_client_types, default_protocol)
        super(THeaderProtocol, self).__init__(transport)
        self._set_protocol()

    # --- header/transform accessors: thin pass-throughs to the transport ---

    def get_headers(self):
        return self.trans.get_headers()

    def set_header(self, key, value):
        self.trans.set_header(key, value)

    def clear_headers(self):
        self.trans.clear_headers()

    def add_transform(self, transform_id):
        self.trans.add_transform(transform_id)

    # --- write side: forwarded to the currently selected subprotocol ---

    def writeMessageBegin(self, name, ttype, seqid):
        # Record the sequence id on the transport before delegating.
        self.trans.sequence_id = seqid
        return self._protocol.writeMessageBegin(name, ttype, seqid)

    def writeMessageEnd(self):
        return self._protocol.writeMessageEnd()

    def writeStructBegin(self, name):
        return self._protocol.writeStructBegin(name)

    def writeStructEnd(self):
        return self._protocol.writeStructEnd()

    def writeFieldBegin(self, name, ttype, fid):
        return self._protocol.writeFieldBegin(name, ttype, fid)

    def writeFieldEnd(self):
        return self._protocol.writeFieldEnd()

    def writeFieldStop(self):
        return self._protocol.writeFieldStop()

    def writeMapBegin(self, ktype, vtype, size):
        return self._protocol.writeMapBegin(ktype, vtype, size)

    def writeMapEnd(self):
        return self._protocol.writeMapEnd()

    def writeListBegin(self, etype, size):
        return self._protocol.writeListBegin(etype, size)

    def writeListEnd(self):
        return self._protocol.writeListEnd()

    def writeSetBegin(self, etype, size):
        return self._protocol.writeSetBegin(etype, size)

    def writeSetEnd(self):
        return self._protocol.writeSetEnd()

    def writeBool(self, bool_val):
        return self._protocol.writeBool(bool_val)

    def writeByte(self, byte):
        return self._protocol.writeByte(byte)

    def writeI16(self, i16):
        return self._protocol.writeI16(i16)

    def writeI32(self, i32):
        return self._protocol.writeI32(i32)

    def writeI64(self, i64):
        return self._protocol.writeI64(i64)

    def writeDouble(self, dub):
        return self._protocol.writeDouble(dub)

    def writeBinary(self, str_val):
        return self._protocol.writeBinary(str_val)

    def _set_protocol(self):
        """Instantiate the subprotocol matching ``self.trans.protocol_id``.

        Raises TApplicationException when the id has no entry in
        PROTOCOLS_BY_ID.  The subprotocol's fast encode/decode hooks are
        mirrored onto this object.
        """
        try:
            protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
        except KeyError:
            raise TApplicationException(
                TProtocolException.INVALID_PROTOCOL,
                "Unknown protocol requested.",
            )

        self._protocol = protocol_cls(self.trans)
        self._fast_encode = self._protocol._fast_encode
        self._fast_decode = self._protocol._fast_decode

    # --- read side ---

    def readMessageBegin(self):
        """Read the next frame, re-select the subprotocol, then delegate.

        If reading the frame or selecting the protocol raises
        TApplicationException, an EXCEPTION message (empty name, seqid 0)
        carrying that exception is written with the previously selected
        subprotocol and flushed.
        """
        try:
            self.trans.readFrame(0)
            self._set_protocol()
        except TApplicationException as exc:
            self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
            exc.write(self._protocol)
            self._protocol.writeMessageEnd()
            self.trans.flush()

        # NOTE(review): execution continues here even after the error reply
        # above was sent -- confirm that this fall-through is intended.
        return self._protocol.readMessageBegin()

    def readMessageEnd(self):
        return self._protocol.readMessageEnd()

    def readStructBegin(self):
        return self._protocol.readStructBegin()

    def readStructEnd(self):
        return self._protocol.readStructEnd()

    def readFieldBegin(self):
        return self._protocol.readFieldBegin()

    def readFieldEnd(self):
        return self._protocol.readFieldEnd()

    def readMapBegin(self):
        return self._protocol.readMapBegin()

    def readMapEnd(self):
        return self._protocol.readMapEnd()

    def readListBegin(self):
        return self._protocol.readListBegin()

    def readListEnd(self):
        return self._protocol.readListEnd()

    def readSetBegin(self):
        return self._protocol.readSetBegin()

    def readSetEnd(self):
        return self._protocol.readSetEnd()

    def readBool(self):
        return self._protocol.readBool()

    def readByte(self):
        return self._protocol.readByte()

    def readI16(self):
        return self._protocol.readI16()

    def readI32(self):
        return self._protocol.readI32()

    def readI64(self):
        return self._protocol.readI64()

    def readDouble(self):
        return self._protocol.readDouble()

    def readBinary(self):
        return self._protocol.readBinary()
class THeaderProtocolFactory(TProtocolFactory):
    """Factory that builds THeaderProtocol instances with fixed settings."""

    def __init__(
        self,
        allowed_client_types=(THeaderClientType.HEADERS,),
        default_protocol=THeaderSubprotocolID.BINARY,
    ):
        self.default_protocol = default_protocol
        self.allowed_client_types = allowed_client_types

    def getProtocol(self, trans):
        """Return a THeaderProtocol over *trans* using the stored settings."""
        return THeaderProtocol(
            trans,
            allowed_client_types=self.allowed_client_types,
            default_protocol=self.default_protocol,
        )

View file

@ -0,0 +1,677 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import (TType, TProtocolBase, TProtocolException,
TProtocolFactory, checkIntegerLimits)
import base64
import math
import sys
from ..compat import str_to_binary
__all__ = ['TJSONProtocol',
'TJSONProtocolFactory',
'TSimpleJSONProtocol',
'TSimpleJSONProtocolFactory']
# Protocol version written as the first element of every message array and
# checked by TJSONProtocol.readMessageBegin.
VERSION = 1

# JSON structural tokens, kept as bytes because they are written to and
# compared against raw transport data.
COMMA = b','
COLON = b':'
LBRACE = b'{'
RBRACE = b'}'
LBRACKET = b'['
RBRACKET = b']'
QUOTE = b'"'
BACKSLASH = b'\\'
ZERO = b'0'

# Code points of the two characters that introduce a \uXXXX escape; compared
# against ord() of the characters read in readJSONString.
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')

# Writer side: character -> escaped text emitted by writeJSONString.
# Keys are text (str).
ESCAPE_CHAR_VALS = {
    '"': '\\"',
    '\\': '\\\\',
    '\b': '\\b',
    '\f': '\\f',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    # '/': '\\/',
}

# Reader side: the byte following a backslash -> the character it stands for.
# Keys are bytes, unlike ESCAPE_CHAR_VALS above.
ESCAPE_CHARS = {
    b'"': '"',
    b'\\': '\\',
    b'b': '\b',
    b'f': '\f',
    b'n': '\n',
    b'r': '\r',
    b't': '\t',
    b'/': '/',
}

# Bytes accepted by isJSONNumeric while scanning a numeric literal.
NUMERIC_CHAR = b'+-.0123456789Ee'

# Thrift type -> short type tag written into the JSON stream.
CTYPES = {
    TType.BOOL: 'tf',
    TType.BYTE: 'i8',
    TType.I16: 'i16',
    TType.I32: 'i32',
    TType.I64: 'i64',
    TType.DOUBLE: 'dbl',
    TType.STRING: 'str',
    TType.STRUCT: 'rec',
    TType.LIST: 'lst',
    TType.SET: 'set',
    TType.MAP: 'map',
}

# Inverse of CTYPES (type tag -> Thrift type), used on the read path.
JTYPES = {}
for key in CTYPES.keys():
    JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
    """Neutral JSON context: no separators are handled and numbers are
    never quoted.  Subclasses override the hooks for lists and objects.
    """

    def __init__(self, protocol):
        self.first = True
        self.protocol = protocol

    def doIO(self, function):
        """Separator hook; the base context has nothing to do."""
        return None

    def write(self):
        """Called before a value is written; no-op here."""
        return None

    def read(self):
        """Called before a value is read; no-op here."""
        return None

    def escapeNum(self):
        """Tell whether numbers must be quoted in this context (never)."""
        return False

    def __str__(self):
        return type(self).__name__
class JSONListContext(JSONBaseContext):
    """Context for JSON arrays: a comma precedes every element but the first."""

    def doIO(self, function):
        """Apply *function* to COMMA unless this is the first element."""
        if self.first is True:
            self.first = False
            return
        function(COMMA)

    def write(self):
        emit = self.protocol.trans.write
        self.doIO(emit)

    def read(self):
        expect = self.protocol.readJSONSyntaxChar
        self.doIO(expect)
class JSONPairContext(JSONBaseContext):
    """Context for JSON objects: separators alternate between ':' and ','.

    ``colon`` is True when the next separator is a colon, i.e. when the
    value about to be handled is a key.
    """

    def __init__(self, protocol):
        super(JSONPairContext, self).__init__(protocol)
        self.colon = True

    def doIO(self, function):
        """Apply *function* to the separator due before the next token."""
        if self.first:
            # Nothing precedes the very first key.
            self.first = False
            self.colon = True
            return
        if self.colon:
            separator = COLON
        else:
            separator = COMMA
        function(separator)
        self.colon = not self.colon

    def write(self):
        emit = self.protocol.trans.write
        self.doIO(emit)

    def read(self):
        expect = self.protocol.readJSONSyntaxChar
        self.doIO(expect)

    def escapeNum(self):
        """Numbers are quoted exactly when ``colon`` is set."""
        return self.colon

    def __str__(self):
        return '%s, colon=%s' % (type(self).__name__, self.colon)
class LookaheadReader():
    """One-element lookahead buffer over ``protocol.trans``."""

    # Whether `data` holds an element that was peeked but not yet consumed.
    hasData = False
    # Most recently fetched element.
    data = ''

    def __init__(self, protocol):
        self.protocol = protocol

    def read(self):
        """Consume and return the next element, using the peeked one if any."""
        if self.hasData is not True:
            self.data = self.protocol.trans.read(1)
        else:
            self.hasData = False
        return self.data

    def peek(self):
        """Return the next element without consuming it."""
        if self.hasData is False:
            self.data = self.protocol.trans.read(1)
        self.hasData = True
        return self.data
class TJSONProtocolBase(TProtocolBase):
    """Shared JSON token reader/writer used by the JSON protocols.

    Tracks a stack of contexts (plain, list, pair) that decide which
    separator precedes the next token and whether numbers must be quoted,
    and provides the primitive read/write helpers built on top of it.
    """

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.resetWriteContext()
        self.resetReadContext()

    # We don't have length limit implementation for JSON protocols
    @property
    def string_length_limit(self):
        return None

    @property
    def container_length_limit(self):
        return None

    def resetWriteContext(self):
        """Start over with a single neutral context on the stack."""
        self.context = JSONBaseContext(self)
        self.contextStack = [self.context]

    def resetReadContext(self):
        """Reset the context stack and discard any lookahead."""
        self.resetWriteContext()
        self.reader = LookaheadReader(self)

    def pushContext(self, ctx):
        """Make *ctx* the current context."""
        self.contextStack.append(ctx)
        self.context = ctx

    def popContext(self):
        """Drop the current context and restore the enclosing one."""
        self.contextStack.pop()
        if self.contextStack:
            self.context = self.contextStack[-1]
        else:
            self.context = JSONBaseContext(self)

    def writeJSONString(self, string):
        """Write *string* as a quoted JSON string, escaping special chars."""
        self.context.write()
        json_str = ['"']
        json_str.extend(ESCAPE_CHAR_VALS.get(s, s) for s in string)
        json_str.append('"')
        self.trans.write(str_to_binary(''.join(json_str)))

    def writeJSONNumber(self, number, formatter='{0}'):
        """Write *number* rendered with *formatter*; quoted when the current
        context requires it.
        """
        self.context.write()
        jsNumber = str(formatter.format(number)).encode('ascii')
        if self.context.escapeNum():
            self.trans.write(QUOTE)
            self.trans.write(jsNumber)
            self.trans.write(QUOTE)
        else:
            self.trans.write(jsNumber)

    def writeJSONBase64(self, binary):
        """Write *binary* as a quoted base64 string."""
        self.context.write()
        self.trans.write(QUOTE)
        self.trans.write(base64.b64encode(binary))
        self.trans.write(QUOTE)

    def writeJSONObjectStart(self):
        self.context.write()
        self.trans.write(LBRACE)
        self.pushContext(JSONPairContext(self))

    def writeJSONObjectEnd(self):
        self.popContext()
        self.trans.write(RBRACE)

    def writeJSONArrayStart(self):
        self.context.write()
        self.trans.write(LBRACKET)
        self.pushContext(JSONListContext(self))

    def writeJSONArrayEnd(self):
        self.popContext()
        self.trans.write(RBRACKET)

    def readJSONSyntaxChar(self, character):
        """Consume one element and require it to equal *character*.

        Raises TProtocolException(INVALID_DATA) on mismatch.
        """
        current = self.reader.read()
        if character != current:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Unexpected character: %s" % current)

    def _isHighSurrogate(self, codeunit):
        return 0xd800 <= codeunit <= 0xdbff

    def _isLowSurrogate(self, codeunit):
        return 0xdc00 <= codeunit <= 0xdfff

    def _toChar(self, high, low=None):
        """Turn a code unit, or a high/low surrogate pair, into a character."""
        if not low:
            if sys.version_info[0] == 2:
                return ("\\u%04x" % high).decode('unicode-escape') \
                                         .encode('utf-8')
            else:
                return chr(high)
        else:
            codepoint = (1 << 16) + ((high & 0x3ff) << 10)
            codepoint += low & 0x3ff
            if sys.version_info[0] == 2:
                s = "\\U%08x" % codepoint
                return s.decode('unicode-escape').encode('utf-8')
            else:
                return chr(codepoint)

    def readJSONString(self, skipContext):
        """Read a quoted JSON string and return its decoded text.

        The context separator is consumed first unless *skipContext* is
        True.  Raises TProtocolException(INVALID_DATA) for unknown escapes
        and for unpaired surrogates.
        """
        highSurrogate = None
        string = []
        if skipContext is False:
            self.context.read()
        self.readJSONSyntaxChar(QUOTE)
        while True:
            character = self.reader.read()
            if character == QUOTE:
                break
            if ord(character) == ESCSEQ0:
                character = self.reader.read()
                if ord(character) == ESCSEQ1:
                    character = self.trans.read(4).decode('ascii')
                    codeunit = int(character, 16)
                    if self._isHighSurrogate(codeunit):
                        if highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected low surrogate char")
                        # Hold the high half until its low half arrives.
                        highSurrogate = codeunit
                        continue
                    elif self._isLowSurrogate(codeunit):
                        if not highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected high surrogate char")
                        character = self._toChar(highSurrogate, codeunit)
                        highSurrogate = None
                    else:
                        character = self._toChar(codeunit)
                else:
                    if character not in ESCAPE_CHARS:
                        raise TProtocolException(
                            TProtocolException.INVALID_DATA,
                            "Expected control char")
                    character = ESCAPE_CHARS[character]
            # NOTE(review): ESCAPE_CHAR_VALS is keyed by str while
            # `character` appears to be bytes on Python 3, so this check
            # looks like it can only fire on Python 2 -- confirm.
            elif character in ESCAPE_CHAR_VALS:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Unescaped control char")
            elif sys.version_info[0] > 2:
                # Gather the continuation bytes of a multi-byte UTF-8 char.
                utf8_bytes = bytearray([ord(character)])
                while ord(self.reader.peek()) >= 0x80:
                    utf8_bytes.append(ord(self.reader.read()))
                character = utf8_bytes.decode('utf8')
            string.append(character)
            # A pending high surrogate must be followed directly by its
            # low half, not by an ordinary character.
            if highSurrogate:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Expected low surrogate char")
        # The string must not end on a dangling high surrogate either.
        if highSurrogate:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Expected low surrogate char")
        return ''.join(string)

    def isJSONNumeric(self, character):
        """Return True when *character* may appear in a JSON number.

        An empty read (end of data) is explicitly not numeric: ``find`` of
        an empty value returns 0, which would otherwise keep
        readJSONNumericChars looping forever.
        """
        return len(character) > 0 and NUMERIC_CHAR.find(character) != -1

    def readJSONQuotes(self):
        """Consume a quote if the current context quotes its numbers."""
        if self.context.escapeNum():
            self.readJSONSyntaxChar(QUOTE)

    def readJSONNumericChars(self):
        """Read the longest run of numeric characters and return it as text."""
        numeric = []
        while True:
            character = self.reader.peek()
            if self.isJSONNumeric(character) is False:
                break
            numeric.append(self.reader.read())
        return b''.join(numeric).decode('ascii')

    def readJSONInteger(self):
        """Read an integer, quoted or not as the context dictates.

        Raises TProtocolException(INVALID_DATA) if the text is not an integer.
        """
        self.context.read()
        self.readJSONQuotes()
        numeric = self.readJSONNumericChars()
        self.readJSONQuotes()
        try:
            return int(numeric)
        except ValueError:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Bad data encounted in numeric data")

    def readJSONDouble(self):
        """Read a double, accepting the quoted form where it is legitimate.

        A quoted value is accepted when the context quotes numbers, or when
        it is NaN or an infinity; any other quoted number raises
        TProtocolException(INVALID_DATA), as does unparsable text.
        """
        self.context.read()
        if self.reader.peek() == QUOTE:
            string = self.readJSONString(True)
            try:
                double = float(string)
                # escapeNum must be *called*; comparing the bound method
                # itself to False never matched and disabled this check.
                if (self.context.escapeNum() is False and
                        not math.isinf(double) and
                        not math.isnan(double)):
                    raise TProtocolException(
                        TProtocolException.INVALID_DATA,
                        "Numeric data unexpectedly quoted")
                return double
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encounted in numeric data")
        else:
            if self.context.escapeNum() is True:
                self.readJSONSyntaxChar(QUOTE)
            try:
                return float(self.readJSONNumericChars())
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encounted in numeric data")

    def readJSONBase64(self):
        """Read a quoted base64 string and return the decoded bytes."""
        string = self.readJSONString(False)
        m = len(string) % 4
        # Restore any stripped padding: b64decode rejects input whose
        # length is not a multiple of four.
        if m != 0:
            string += '=' * (4 - m)
        return base64.b64decode(string)

    def readJSONObjectStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACE)
        self.pushContext(JSONPairContext(self))

    def readJSONObjectEnd(self):
        self.readJSONSyntaxChar(RBRACE)
        self.popContext()

    def readJSONArrayStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACKET)
        self.pushContext(JSONListContext(self))

    def readJSONArrayEnd(self):
        self.readJSONSyntaxChar(RBRACKET)
        self.popContext()
class TJSONProtocol(TJSONProtocolBase):
def readMessageBegin(self):
self.resetReadContext()
self.readJSONArrayStart()
if self.readJSONInteger() != VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
"Message contained bad version.")
name = self.readJSONString(False)
typen = self.readJSONInteger()
seqid = self.readJSONInteger()
return (name, typen, seqid)
def readMessageEnd(self):
self.readJSONArrayEnd()
def readStructBegin(self):
self.readJSONObjectStart()
def readStructEnd(self):
self.readJSONObjectEnd()
def readFieldBegin(self):
character = self.reader.peek()
ttype = 0
id = 0
if character == RBRACE:
ttype = TType.STOP
else:
id = self.readJSONInteger()
self.readJSONObjectStart()
ttype = JTYPES[self.readJSONString(False)]
return (None, ttype, id)
def readFieldEnd(self):
self.readJSONObjectEnd()
def readMapBegin(self):
self.readJSONArrayStart()
keyType = JTYPES[self.readJSONString(False)]
valueType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
self.readJSONObjectStart()
return (keyType, valueType, size)
def readMapEnd(self):
self.readJSONObjectEnd()
self.readJSONArrayEnd()
def readCollectionBegin(self):
self.readJSONArrayStart()
elemType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
return (elemType, size)
readListBegin = readCollectionBegin
readSetBegin = readCollectionBegin
def readCollectionEnd(self):
self.readJSONArrayEnd()
readSetEnd = readCollectionEnd
readListEnd = readCollectionEnd
def readBool(self):
return (False if self.readJSONInteger() == 0 else True)
def readNumber(self):
return self.readJSONInteger()
readByte = readNumber
readI16 = readNumber
readI32 = readNumber
readI64 = readNumber
def readDouble(self):
return self.readJSONDouble()
def readString(self):
return self.readJSONString(False)
def readBinary(self):
return self.readJSONBase64()
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
self.writeJSONArrayStart()
self.writeJSONNumber(VERSION)
self.writeJSONString(name)
self.writeJSONNumber(request_type)
self.writeJSONNumber(seqid)
def writeMessageEnd(self):
self.writeJSONArrayEnd()
def writeStructBegin(self, name):
self.writeJSONObjectStart()
def writeStructEnd(self):
self.writeJSONObjectEnd()
def writeFieldBegin(self, name, ttype, id):
self.writeJSONNumber(id)
self.writeJSONObjectStart()
self.writeJSONString(CTYPES[ttype])
def writeFieldEnd(self):
self.writeJSONObjectEnd()
def writeFieldStop(self):
pass
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[ktype])
self.writeJSONString(CTYPES[vtype])
self.writeJSONNumber(size)
self.writeJSONObjectStart()
def writeMapEnd(self):
self.writeJSONObjectEnd()
self.writeJSONArrayEnd()
def writeListBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeListEnd(self):
self.writeJSONArrayEnd()
def writeSetBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeSetEnd(self):
self.writeJSONArrayEnd()
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeByte(self, byte):
checkIntegerLimits(byte, 8)
self.writeJSONNumber(byte)
def writeI16(self, i16):
checkIntegerLimits(i16, 16)
self.writeJSONNumber(i16)
def writeI32(self, i32):
checkIntegerLimits(i32, 32)
self.writeJSONNumber(i32)
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
def writeDouble(self, dbl):
    """Write a double, passing an explicit 17-significant-digit format."""
    # 17 significant digits are enough to represent any double precision
    # value.
    double_format = '{0:.17g}'
    self.writeJSONNumber(dbl, double_format)
def writeString(self, string):
    """Write ``string`` by delegating to writeJSONString()."""
    self.writeJSONString(string)
def writeBinary(self, binary):
    """Write ``binary`` by delegating to writeJSONBase64()."""
    self.writeJSONBase64(binary)
class TJSONProtocolFactory(TProtocolFactory):
    """Factory that wraps a transport in a TJSONProtocol."""

    def getProtocol(self, trans):
        """Return a new TJSONProtocol bound to the transport ``trans``."""
        return TJSONProtocol(trans)

    @property
    def string_length_limit(self):
        """Always None: this factory configures no string length limit."""
        return None

    @property
    def container_length_limit(self):
        """Always None: this factory configures no container length limit."""
        return None
class TSimpleJSONProtocol(TJSONProtocolBase):
    """Simple, readable, write-only JSON protocol.

    Useful for interacting with scripting languages. None of the read
    entry points defined here are implemented; each raises
    NotImplementedError.
    """

    # ---- reading is unsupported -------------------------------------

    def readMessageBegin(self):
        """Unsupported: this protocol is write-only."""
        raise NotImplementedError()

    def readMessageEnd(self):
        """Unsupported: this protocol is write-only."""
        raise NotImplementedError()

    def readStructBegin(self):
        """Unsupported: this protocol is write-only."""
        raise NotImplementedError()

    def readStructEnd(self):
        """Unsupported: this protocol is write-only."""
        raise NotImplementedError()

    # ---- message / struct / field framing ---------------------------

    def writeMessageBegin(self, name, request_type, seqid):
        """Reset the write context; no message header is emitted."""
        self.resetWriteContext()

    def writeMessageEnd(self):
        """Do nothing: there is no message envelope to close."""

    def writeStructBegin(self, name):
        """Open a JSON object for a struct; ``name`` is not written."""
        self.writeJSONObjectStart()

    def writeStructEnd(self):
        """Close the JSON object opened for a struct."""
        self.writeJSONObjectEnd()

    def writeFieldBegin(self, name, ttype, fid):
        """Emit only the field name; the type and id are not written."""
        self.writeJSONString(name)

    def writeFieldEnd(self):
        """Do nothing: fields need no terminator here."""

    # ---- containers -------------------------------------------------

    def writeMapBegin(self, ktype, vtype, size):
        """Open a JSON object for a map; types and size are not written."""
        self.writeJSONObjectStart()

    def writeMapEnd(self):
        """Close the JSON object opened for a map."""
        self.writeJSONObjectEnd()

    def _writeCollectionBegin(self, etype, size):
        """Open a JSON array; element type and size are not written."""
        self.writeJSONArrayStart()

    def _writeCollectionEnd(self):
        """Close the JSON array opened for a list or set."""
        self.writeJSONArrayEnd()

    # Lists and sets are both written as plain JSON arrays.
    writeListBegin = writeSetBegin = _writeCollectionBegin
    writeListEnd = writeSetEnd = _writeCollectionEnd

    # ---- scalars ----------------------------------------------------

    def writeByte(self, byte):
        """Range-check ``byte`` as an 8-bit integer, then write it."""
        bit_width = 8
        checkIntegerLimits(byte, bit_width)
        self.writeJSONNumber(byte)

    def writeI16(self, i16):
        """Range-check ``i16`` as a 16-bit integer, then write it."""
        bit_width = 16
        checkIntegerLimits(i16, bit_width)
        self.writeJSONNumber(i16)

    def writeI32(self, i32):
        """Range-check ``i32`` as a 32-bit integer, then write it."""
        bit_width = 32
        checkIntegerLimits(i32, bit_width)
        self.writeJSONNumber(i32)

    def writeI64(self, i64):
        """Range-check ``i64`` as a 64-bit integer, then write it."""
        bit_width = 64
        checkIntegerLimits(i64, bit_width)
        self.writeJSONNumber(i64)

    def writeBool(self, boolean):
        """Write 1 only for the ``True`` object itself, otherwise 0."""
        flag = int(boolean is True)
        self.writeJSONNumber(flag)

    def writeDouble(self, dbl):
        """Write a double using writeJSONNumber's default formatting."""
        self.writeJSONNumber(dbl)

    def writeString(self, string):
        """Write ``string`` by delegating to writeJSONString()."""
        self.writeJSONString(string)

    def writeBinary(self, binary):
        """Write ``binary`` by delegating to writeJSONBase64()."""
        self.writeJSONBase64(binary)
class TSimpleJSONProtocolFactory(TProtocolFactory):
    """Factory that wraps a transport in a TSimpleJSONProtocol."""

    def getProtocol(self, trans):
        """Return a new TSimpleJSONProtocol bound to the transport ``trans``."""
        protocol = TSimpleJSONProtocol(trans)
        return protocol

View file

@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TMessageType
from thrift.protocol import TProtocolDecorator
SEPARATOR = ":"
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol decorator that prefixes outgoing request names with a service name.

    For CALL and ONEWAY messages the name written is
    ``serviceName + SEPARATOR + name``; every other message type is passed
    through unchanged.
    """

    def __init__(self, protocol, serviceName):
        """Remember the service name used as the message-name prefix.

        ``protocol`` is not stored by this initializer.
        """
        self.serviceName = serviceName

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header, qualifying the name for requests."""
        is_request = type in (TMessageType.CALL, TMessageType.ONEWAY)
        if is_request:
            name = self.serviceName + SEPARATOR + name
        super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid)

View file

@ -17,7 +17,14 @@
# under the License. # under the License.
# #
from thrift.Thrift import * from thrift.Thrift import TException, TType, TFrozenDict
from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
import sys
from itertools import islice
from six.moves import zip
class TProtocolException(TException): class TProtocolException(TException):
@ -28,19 +35,33 @@ class TProtocolException(TException):
NEGATIVE_SIZE = 2 NEGATIVE_SIZE = 2
SIZE_LIMIT = 3 SIZE_LIMIT = 3
BAD_VERSION = 4 BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
INVALID_PROTOCOL = 7
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message) TException.__init__(self, message)
self.type = type self.type = type
class TProtocolBase: class TProtocolBase(object):
"""Base class for Thrift protocol driver.""" """Base class for Thrift protocol driver."""
def __init__(self, trans): def __init__(self, trans):
self.trans = trans self.trans = trans
self._fast_decode = None
self._fast_encode = None
def writeMessageBegin(self, name, type, seqid): @staticmethod
def _check_length(limit, length):
if length < 0:
raise TTransportException(TTransportException.NEGATIVE_SIZE,
'Negative length: %d' % length)
if limit is not None and length > limit:
raise TTransportException(TTransportException.SIZE_LIMIT,
'Length exceeded max allowed: %d' % limit)
def writeMessageBegin(self, name, ttype, seqid):
pass pass
def writeMessageEnd(self): def writeMessageEnd(self):
@ -52,7 +73,7 @@ class TProtocolBase:
def writeStructEnd(self): def writeStructEnd(self):
pass pass
def writeFieldBegin(self, name, type, id): def writeFieldBegin(self, name, ttype, fid):
pass pass
def writeFieldEnd(self): def writeFieldEnd(self):
@ -79,7 +100,7 @@ class TProtocolBase:
def writeSetEnd(self): def writeSetEnd(self):
pass pass
def writeBool(self, bool): def writeBool(self, bool_val):
pass pass
def writeByte(self, byte): def writeByte(self, byte):
@ -97,9 +118,15 @@ class TProtocolBase:
def writeDouble(self, dub): def writeDouble(self, dub):
pass pass
def writeString(self, str): def writeString(self, str_val):
self.writeBinary(str_to_binary(str_val))
def writeBinary(self, str_val):
pass pass
def writeUtf8(self, str_val):
self.writeString(str_val.encode('utf8'))
def readMessageBegin(self): def readMessageBegin(self):
pass pass
@ -155,50 +182,58 @@ class TProtocolBase:
pass pass
def readString(self): def readString(self):
return binary_to_str(self.readBinary())
def readBinary(self):
pass pass
def skip(self, type): def readUtf8(self):
if type == TType.STOP: return self.readString().decode('utf8')
return
elif type == TType.BOOL: def skip(self, ttype):
if ttype == TType.BOOL:
self.readBool() self.readBool()
elif type == TType.BYTE: elif ttype == TType.BYTE:
self.readByte() self.readByte()
elif type == TType.I16: elif ttype == TType.I16:
self.readI16() self.readI16()
elif type == TType.I32: elif ttype == TType.I32:
self.readI32() self.readI32()
elif type == TType.I64: elif ttype == TType.I64:
self.readI64() self.readI64()
elif type == TType.DOUBLE: elif ttype == TType.DOUBLE:
self.readDouble() self.readDouble()
elif type == TType.STRING: elif ttype == TType.STRING:
self.readString() self.readString()
elif type == TType.STRUCT: elif ttype == TType.STRUCT:
name = self.readStructBegin() name = self.readStructBegin()
while True: while True:
(name, type, id) = self.readFieldBegin() (name, ttype, id) = self.readFieldBegin()
if type == TType.STOP: if ttype == TType.STOP:
break break
self.skip(type) self.skip(ttype)
self.readFieldEnd() self.readFieldEnd()
self.readStructEnd() self.readStructEnd()
elif type == TType.MAP: elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin() (ktype, vtype, size) = self.readMapBegin()
for i in range(size): for i in range(size):
self.skip(ktype) self.skip(ktype)
self.skip(vtype) self.skip(vtype)
self.readMapEnd() self.readMapEnd()
elif type == TType.SET: elif ttype == TType.SET:
(etype, size) = self.readSetBegin() (etype, size) = self.readSetBegin()
for i in range(size): for i in range(size):
self.skip(etype) self.skip(etype)
self.readSetEnd() self.readSetEnd()
elif type == TType.LIST: elif ttype == TType.LIST:
(etype, size) = self.readListBegin() (etype, size) = self.readListBegin()
for i in range(size): for i in range(size):
self.skip(etype) self.skip(etype)
self.readListEnd() self.readListEnd()
else:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"invalid TType")
# tuple of: ( 'reader method' name, 'writer_method' name, is_container bool ) # tuple of: ( 'reader method' name, 'writer_method' name, is_container bool )
_TTYPE_HANDLERS = ( _TTYPE_HANDLERS = (
@ -222,90 +257,77 @@ class TProtocolBase:
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
) )
def _ttype_handlers(self, ttype, spec):
if spec == 'BINARY':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid binary field type %d' % ttype)
return ('readBinary', 'writeBinary', False)
if sys.version_info[0] == 2 and spec == 'UTF8':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid string field type %d' % ttype)
return ('readUtf8', 'writeUtf8', False)
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
def _read_by_ttype(self, ttype, spec, espec):
reader_name, _, is_container = self._ttype_handlers(ttype, espec)
if reader_name is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid type %d' % (ttype))
reader_func = getattr(self, reader_name)
read = (lambda: reader_func(espec)) if is_container else reader_func
while True:
yield read()
def readFieldByTType(self, ttype, spec): def readFieldByTType(self, ttype, spec):
try: return next(self._read_by_ttype(ttype, spec, spec))
(r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
except IndexError:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid field type %d' % (ttype))
if r_handler is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid field type %d' % (ttype))
reader = getattr(self, r_handler)
if not is_container:
return reader()
return reader(spec)
def readContainerList(self, spec): def readContainerList(self, spec):
results = [] ttype, tspec, is_immutable = spec
ttype, tspec = spec[0], spec[1]
r_handler = self._TTYPE_HANDLERS[ttype][0]
reader = getattr(self, r_handler)
(list_type, list_len) = self.readListBegin() (list_type, list_len) = self.readListBegin()
if tspec is None: # TODO: compare types we just decoded with thrift_spec
# list values are simple types elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
for idx in range(list_len): results = (tuple if is_immutable else list)(elems)
results.append(reader())
else:
# this is like an inlined readFieldByTType
container_reader = self._TTYPE_HANDLERS[list_type][0]
val_reader = getattr(self, container_reader)
for idx in range(list_len):
val = val_reader(tspec)
results.append(val)
self.readListEnd() self.readListEnd()
return results return results
def readContainerSet(self, spec): def readContainerSet(self, spec):
results = set() ttype, tspec, is_immutable = spec
ttype, tspec = spec[0], spec[1]
r_handler = self._TTYPE_HANDLERS[ttype][0]
reader = getattr(self, r_handler)
(set_type, set_len) = self.readSetBegin() (set_type, set_len) = self.readSetBegin()
if tspec is None: # TODO: compare types we just decoded with thrift_spec
# set members are simple types elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
for idx in range(set_len): results = (frozenset if is_immutable else set)(elems)
results.add(reader())
else:
container_reader = self._TTYPE_HANDLERS[set_type][0]
val_reader = getattr(self, container_reader)
for idx in range(set_len):
results.add(val_reader(tspec))
self.readSetEnd() self.readSetEnd()
return results return results
def readContainerStruct(self, spec): def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec (obj_class, obj_spec) = spec
# If obj_class.read is a classmethod (e.g. in frozen structs),
# call it as such.
if getattr(obj_class.read, '__self__', None) is obj_class:
obj = obj_class.read(self)
else:
obj = obj_class() obj = obj_class()
obj.read(self) obj.read(self)
return obj return obj
def readContainerMap(self, spec): def readContainerMap(self, spec):
results = dict() ktype, kspec, vtype, vspec, is_immutable = spec
key_ttype, key_spec = spec[0], spec[1]
val_ttype, val_spec = spec[2], spec[3]
(map_ktype, map_vtype, map_len) = self.readMapBegin() (map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and # TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree # abort/skip if types disagree
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) keys = self._read_by_ttype(ktype, spec, kspec)
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) vals = self._read_by_ttype(vtype, spec, vspec)
# list values are simple types keyvals = islice(zip(keys, vals), map_len)
for idx in range(map_len): results = (TFrozenDict if is_immutable else dict)(keyvals)
if key_spec is None:
k_val = key_reader()
else:
k_val = self.readFieldByTType(key_ttype, key_spec)
if val_spec is None:
v_val = val_reader()
else:
v_val = self.readFieldByTType(val_ttype, val_spec)
# this raises a TypeError with unhashable keys types
# i.e. this fails: d=dict(); d[[0,1]] = 2
results[k_val] = v_val
self.readMapEnd() self.readMapEnd()
return results return results
def readStruct(self, obj, thrift_spec): def readStruct(self, obj, thrift_spec, is_immutable=False):
if is_immutable:
fields = {}
self.readStructBegin() self.readStructBegin()
while True: while True:
(fname, ftype, fid) = self.readFieldBegin() (fname, ftype, fid) = self.readFieldBegin()
@ -320,56 +342,40 @@ class TProtocolBase:
fname = field[2] fname = field[2]
fspec = field[3] fspec = field[3]
val = self.readFieldByTType(ftype, fspec) val = self.readFieldByTType(ftype, fspec)
if is_immutable:
fields[fname] = val
else:
setattr(obj, fname, val) setattr(obj, fname, val)
else: else:
self.skip(ftype) self.skip(ftype)
self.readFieldEnd() self.readFieldEnd()
self.readStructEnd() self.readStructEnd()
if is_immutable:
return obj(**fields)
def writeContainerStruct(self, val, spec): def writeContainerStruct(self, val, spec):
val.write(self) val.write(self)
def writeContainerList(self, val, spec): def writeContainerList(self, val, spec):
self.writeListBegin(spec[0], len(val)) ttype, tspec, _ = spec
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] self.writeListBegin(ttype, len(val))
e_writer = getattr(self, w_handler) for _ in self._write_by_ttype(ttype, val, spec, tspec):
if not is_container: pass
for elem in val:
e_writer(elem)
else:
for elem in val:
e_writer(elem, spec[1])
self.writeListEnd() self.writeListEnd()
def writeContainerSet(self, val, spec): def writeContainerSet(self, val, spec):
self.writeSetBegin(spec[0], len(val)) ttype, tspec, _ = spec
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] self.writeSetBegin(ttype, len(val))
e_writer = getattr(self, w_handler) for _ in self._write_by_ttype(ttype, val, spec, tspec):
if not is_container: pass
for elem in val:
e_writer(elem)
else:
for elem in val:
e_writer(elem, spec[1])
self.writeSetEnd() self.writeSetEnd()
def writeContainerMap(self, val, spec): def writeContainerMap(self, val, spec):
k_type = spec[0] ktype, kspec, vtype, vspec, _ = spec
v_type = spec[2] self.writeMapBegin(ktype, vtype, len(val))
ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
k_writer = getattr(self, ktype_name) pass
v_writer = getattr(self, vtype_name)
self.writeMapBegin(k_type, v_type, len(val))
for m_key, m_val in val.items():
if not k_is_container:
k_writer(m_key)
else:
k_writer(m_key, spec[1])
if not v_is_container:
v_writer(m_val)
else:
v_writer(m_val, spec[3])
self.writeMapEnd() self.writeMapEnd()
def writeStruct(self, obj, thrift_spec): def writeStruct(self, obj, thrift_spec):
@ -385,22 +391,38 @@ class TProtocolBase:
fid = field[0] fid = field[0]
ftype = field[1] ftype = field[1]
fspec = field[3] fspec = field[3]
# get the writer method for this value
self.writeFieldBegin(fname, ftype, fid) self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec) self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd() self.writeFieldEnd()
self.writeFieldStop() self.writeFieldStop()
self.writeStructEnd() self.writeStructEnd()
def _write_by_ttype(self, ttype, vals, spec, espec):
_, writer_name, is_container = self._ttype_handlers(ttype, espec)
writer_func = getattr(self, writer_name)
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
for v in vals:
yield write(v)
def writeFieldByTType(self, ttype, val, spec): def writeFieldByTType(self, ttype, val, spec):
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] next(self._write_by_ttype(ttype, [val], spec, spec))
writer = getattr(self, w_handler)
if is_container:
writer(val, spec)
else:
writer(val)
class TProtocolFactory: def checkIntegerLimits(i, bits):
if bits == 8 and (i < -128 or i > 127):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i8 requires -128 <= number <= 127")
elif bits == 16 and (i < -32768 or i > 32767):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i16 requires -32768 <= number <= 32767")
elif bits == 32 and (i < -2147483648 or i > 2147483647):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
def getProtocol(self, trans): def getProtocol(self, trans):
pass pass

View file

@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
class TProtocolDecorator(object):
    """Base class for decorators that wrap an existing protocol instance.

    Instantiating a subclass does not create an instance of that subclass
    directly. Instead a new class named ``Decorated<ProtocolClassName>`` is
    built that inherits from both the decorator subclass and the wrapped
    protocol's class, with the wrapped protocol's instance ``__dict__``
    supplied as the class namespace. A bare instance of that class is
    returned.
    """

    def __new__(cls, protocol, *args, **kwargs):
        wrapped_cls = protocol.__class__
        decorated_name = 'Decorated' + wrapped_cls.__name__
        # The decorator comes first in the bases so its overrides win.
        decorated_cls = type(decorated_name,
                             (cls, wrapped_cls),
                             protocol.__dict__)
        return object.__new__(decorated_cls)

View file

@ -17,4 +17,5 @@
# under the License. # under the License.
# #
__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] __all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
'TJSONProtocol', 'TProtocol', 'TProtocolDecorator']

Binary file not shown.

View file

@ -17,8 +17,11 @@
# under the License. # under the License.
# #
import http.server import ssl
from six.moves import BaseHTTPServer
from thrift.Thrift import TMessageType
from thrift.server import TServer from thrift.server import TServer
from thrift.transport import TTransport from thrift.transport import TTransport
@ -30,7 +33,9 @@ class ResponseException(Exception):
to override this behavior (e.g., to simulate a misconfigured or to override this behavior (e.g., to simulate a misconfigured or
overloaded web server during testing), it can raise a ResponseException. overloaded web server during testing), it can raise a ResponseException.
The function passed to the constructor will be called with the The function passed to the constructor will be called with the
RequestHandler as its only argument. RequestHandler as its only argument. Note that this is irrelevant
for ONEWAY requests, as the HTTP response must be sent before the
RPC is processed.
""" """
def __init__(self, handler): def __init__(self, handler):
self.handler = handler self.handler = handler
@ -41,17 +46,26 @@ class THttpServer(TServer.TServer):
This class is not very performant, but it is useful (for example) for This class is not very performant, but it is useful (for example) for
acting as a mock version of an Apache-based PHP Thrift endpoint. acting as a mock version of an Apache-based PHP Thrift endpoint.
Also important to note the HTTP implementation pretty much violates the
transport/protocol/processor/server layering, by performing the transport
functions here. This means things like oneway handling are oddly exposed.
""" """
def __init__(self, def __init__(self,
processor, processor,
server_address, server_address,
inputProtocolFactory, inputProtocolFactory,
outputProtocolFactory=None, outputProtocolFactory=None,
server_class=http.server.HTTPServer): server_class=BaseHTTPServer.HTTPServer,
"""Set up protocol factories and HTTP server. **kwargs):
"""Set up protocol factories and HTTP (or HTTPS) server.
See BaseHTTPServer for server_address. See BaseHTTPServer for server_address.
See TServer for protocol factories. See TServer for protocol factories.
To make a secure server, provide the named arguments:
* cafile - to validate clients [optional]
* cert_file - the server cert
* key_file - the server's key
""" """
if outputProtocolFactory is None: if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory outputProtocolFactory = inputProtocolFactory
@ -60,28 +74,58 @@ class THttpServer(TServer.TServer):
inputProtocolFactory, outputProtocolFactory) inputProtocolFactory, outputProtocolFactory)
thttpserver = self thttpserver = self
self._replied = None
class RequestHander(http.server.BaseHTTPRequestHandler): class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
def do_POST(self): def do_POST(self):
# Don't care about the request path. # Don't care about the request path.
itrans = TTransport.TFileObjectTransport(self.rfile) thttpserver._replied = False
otrans = TTransport.TFileObjectTransport(self.wfile) iftrans = TTransport.TFileObjectTransport(self.rfile)
itrans = TTransport.TBufferedTransport( itrans = TTransport.TBufferedTransport(
itrans, int(self.headers['Content-Length'])) iftrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer() otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
try: try:
thttpserver.processor.on_message_begin(self.on_begin)
thttpserver.processor.process(iprot, oprot) thttpserver.processor.process(iprot, oprot)
except ResponseException as exn: except ResponseException as exn:
exn.handler(self) exn.handler(self)
else: else:
if not thttpserver._replied:
# If the request was ONEWAY we would have replied already
data = otrans.getvalue()
self.send_response(200) self.send_response(200)
self.send_header("content-type", "application/x-thrift") self.send_header("Content-Length", len(data))
self.send_header("Content-Type", "application/x-thrift")
self.end_headers() self.end_headers()
self.wfile.write(otrans.getvalue()) self.wfile.write(data)
def on_begin(self, name, type, seqid):
    """
    Inspect the message header.

    For a ONEWAY message the HTTP 200 response is sent immediately, before
    the call is processed, and the enclosing server is flagged as having
    replied. Any other message type is left untouched.
    """
    if type != TMessageType.ONEWAY:
        return
    self.send_response(200)
    self.send_header("Content-Type", "application/x-thrift")
    self.end_headers()
    thttpserver._replied = True
self.httpd = server_class(server_address, RequestHander) self.httpd = server_class(server_address, RequestHander)
if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
context = ssl.create_default_context(cafile=kwargs.get('cafile'))
context.check_hostname = False
context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
def serve(self): def serve(self):
self.httpd.serve_forever() self.httpd.serve_forever()
def shutdown(self):
    """Stop the server by closing the HTTP server's listening socket."""
    listening_socket = self.httpd.socket
    listening_socket.close()
    # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!

View file

@ -24,18 +24,23 @@ only from the main thread.
The thread pool should be sized for concurrent tasks, not The thread pool should be sized for concurrent tasks, not
maximum connections maximum connections
""" """
import threading
import socket
import queue
import select
import struct
import logging import logging
import select
import socket
import struct
import threading
from collections import deque
from six.moves import queue
from thrift.transport import TTransport from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer'] __all__ = ['TNonblockingServer']
logger = logging.getLogger(__name__)
class Worker(threading.Thread): class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection.""" """Worker is a small helper to process incoming connection."""
@ -54,8 +59,9 @@ class Worker(threading.Thread):
processor.process(iprot, oprot) processor.process(iprot, oprot)
callback(True, otrans.getvalue()) callback(True, otrans.getvalue())
except Exception: except Exception:
logging.exception("Exception while processing request") logger.exception("Exception while processing request", exc_info=True)
callback(False, '') callback(False, b'')
WAIT_LEN = 0 WAIT_LEN = 0
WAIT_MESSAGE = 1 WAIT_MESSAGE = 1
@ -81,11 +87,24 @@ def socket_exception(func):
try: try:
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
except socket.error: except socket.error:
logger.debug('ignoring socket exception', exc_info=True)
self.close() self.close()
return read return read
class Connection: class Message(object):
def __init__(self, offset, len_, header):
self.offset = offset
self.len = len_
self.buffer = None
self.is_header = header
@property
def end(self):
return self.offset + self.len
class Connection(object):
"""Basic class is represented connection. """Basic class is represented connection.
It can be in state: It can be in state:
@ -102,68 +121,60 @@ class Connection:
self.socket.setblocking(False) self.socket.setblocking(False)
self.status = WAIT_LEN self.status = WAIT_LEN
self.len = 0 self.len = 0
self.message = '' self.received = deque()
self._reading = Message(0, 4, True)
self._rbuf = b''
self._wbuf = b''
self.lock = threading.Lock() self.lock = threading.Lock()
self.wake_up = wake_up self.wake_up = wake_up
self.remaining = False
def _read_len(self):
"""Reads length of request.
It's a safer alternative to self.socket.recv(4)
"""
read = self.socket.recv(4 - len(self.message))
if len(read) == 0:
# if we read 0 bytes and self.message is empty, then
# the client closed the connection
if len(self.message) != 0:
logging.error("can't read frame size from socket")
self.close()
return
self.message += read
if len(self.message) == 4:
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
logging.error("negative frame size, it seems client "
"doesn't use FramedTransport")
self.close()
elif self.len == 0:
logging.error("empty frame, it's really strange")
self.close()
else:
self.message = ''
self.status = WAIT_MESSAGE
@socket_exception @socket_exception
def read(self): def read(self):
"""Reads data from stream and switch state.""" """Reads data from stream and switch state."""
assert self.status in (WAIT_LEN, WAIT_MESSAGE) assert self.status in (WAIT_LEN, WAIT_MESSAGE)
if self.status == WAIT_LEN: assert not self.received
self._read_len() buf_size = 8192
# go back to the main loop here for simplicity instead of first = True
# falling through, even though there is a good chance that done = False
# the message is already available while not done:
elif self.status == WAIT_MESSAGE: read = self.socket.recv(buf_size)
read = self.socket.recv(self.len - len(self.message)) rlen = len(read)
if len(read) == 0: done = rlen < buf_size
logging.error("can't read frame from socket (get %d of " self._rbuf += read
"%d bytes)" % (len(self.message), self.len)) if first and rlen == 0:
if self.status != WAIT_LEN or self._rbuf:
logger.error('could not read frame from socket')
else:
logger.debug('read zero length. client might have disconnected')
self.close() self.close()
return while len(self._rbuf) >= self._reading.end:
self.message += read if self._reading.is_header:
if len(self.message) == self.len: mlen, = struct.unpack('!i', self._rbuf[:4])
self._reading = Message(self._reading.end, mlen, False)
self.status = WAIT_MESSAGE
else:
self._reading.buffer = self._rbuf
self.received.append(self._reading)
self._rbuf = self._rbuf[self._reading.end:]
self._reading = Message(0, 4, True)
first = False
if self.received:
self.status = WAIT_PROCESS self.status = WAIT_PROCESS
break
self.remaining = not done
@socket_exception @socket_exception
def write(self): def write(self):
"""Writes data from socket and switch state.""" """Writes data from socket and switch state."""
assert self.status == SEND_ANSWER assert self.status == SEND_ANSWER
sent = self.socket.send(self.message) sent = self.socket.send(self._wbuf)
if sent == len(self.message): if sent == len(self._wbuf):
self.status = WAIT_LEN self.status = WAIT_LEN
self.message = '' self._wbuf = b''
self.len = 0 self.len = 0
else: else:
self.message = self.message[sent:] self._wbuf = self._wbuf[sent:]
@locked @locked
def ready(self, all_ok, message): def ready(self, all_ok, message):
@ -183,13 +194,13 @@ class Connection:
self.close() self.close()
self.wake_up() self.wake_up()
return return
self.len = '' self.len = 0
if len(message) == 0: if len(message) == 0:
# it was a oneway request, do not write answer # it was a oneway request, do not write answer
self.message = '' self._wbuf = b''
self.status = WAIT_LEN self.status = WAIT_LEN
else: else:
self.message = struct.pack('!i', len(message)) + message self._wbuf = struct.pack('!i', len(message)) + message
self.status = SEND_ANSWER self.status = SEND_ANSWER
self.wake_up() self.wake_up()
@ -219,7 +230,7 @@ class Connection:
self.socket.close() self.socket.close()
class TNonblockingServer: class TNonblockingServer(object):
"""Non-blocking server.""" """Non-blocking server."""
def __init__(self, def __init__(self,
@ -259,7 +270,7 @@ class TNonblockingServer:
def wake_up(self): def wake_up(self):
"""Wake up main thread. """Wake up main thread.
The server usualy waits in select call in we should terminate one. The server usually waits in select call in we should terminate one.
The simplest way is using socketpair. The simplest way is using socketpair.
Select always wait to read from the first socket of socketpair. Select always wait to read from the first socket of socketpair.
@ -267,7 +278,7 @@ class TNonblockingServer:
In this case, we can just write anything to the second socket from In this case, we can just write anything to the second socket from
socketpair. socketpair.
""" """
self._write.send('1') self._write.send(b'1')
def stop(self): def stop(self):
"""Stop the server. """Stop the server.
@ -288,14 +299,20 @@ class TNonblockingServer:
"""Does select on open connections.""" """Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()] readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = [] writable = []
remaining = []
for i, connection in list(self.clients.items()): for i, connection in list(self.clients.items()):
if connection.is_readable(): if connection.is_readable():
readable.append(connection.fileno()) readable.append(connection.fileno())
if connection.remaining or connection.received:
remaining.append(connection.fileno())
if connection.is_writeable(): if connection.is_writeable():
writable.append(connection.fileno()) writable.append(connection.fileno())
if connection.is_closed(): if connection.is_closed():
del self.clients[i] del self.clients[i]
return select.select(readable, writable, readable) if remaining:
return remaining, [], [], False
else:
return select.select(readable, writable, readable) + (True,)
def handle(self): def handle(self):
"""Handle requests. """Handle requests.
@ -303,20 +320,27 @@ class TNonblockingServer:
WARNING! You must call prepare() BEFORE calling handle() WARNING! You must call prepare() BEFORE calling handle()
""" """
assert self.prepared, "You have to call prepare before handle" assert self.prepared, "You have to call prepare before handle"
rset, wset, xset = self._select() rset, wset, xset, selected = self._select()
for readable in rset: for readable in rset:
if readable == self._read.fileno(): if readable == self._read.fileno():
# don't care i just need to clean readable flag # don't care i just need to clean readable flag
self._read.recv(1024) self._read.recv(1024)
elif readable == self.socket.handle.fileno(): elif readable == self.socket.handle.fileno():
client = self.socket.accept().handle try:
self.clients[client.fileno()] = Connection(client, client = self.socket.accept()
if client:
self.clients[client.handle.fileno()] = Connection(client.handle,
self.wake_up) self.wake_up)
except socket.error:
logger.debug('error while accepting', exc_info=True)
else: else:
connection = self.clients[readable] connection = self.clients[readable]
if selected:
connection.read() connection.read()
if connection.status == WAIT_PROCESS: if connection.received:
itransport = TTransport.TMemoryBuffer(connection.message) connection.status = WAIT_PROCESS
msg = connection.received.popleft()
itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
otransport = TTransport.TMemoryBuffer() otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport) iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport) oprot = self.out_protocol.getProtocol(otransport)

View file

@ -19,11 +19,13 @@
import logging import logging
from multiprocessing import Process, Value, Condition, reduction
from multiprocessing import Process, Value, Condition
from .TServer import TServer from .TServer import TServer
from thrift.transport.TTransport import TTransportException from thrift.transport.TTransport import TTransportException
import collections.abc
logger = logging.getLogger(__name__)
class TProcessPoolServer(TServer): class TProcessPoolServer(TServer):
@ -41,7 +43,7 @@ class TProcessPoolServer(TServer):
self.postForkCallback = None self.postForkCallback = None
def setPostForkCallback(self, callback): def setPostForkCallback(self, callback):
if not isinstance(callback, collections.abc.Callable): if not callable(callback):
raise TypeError("This is not a callback!") raise TypeError("This is not a callback!")
self.postForkCallback = callback self.postForkCallback = callback
@ -57,11 +59,13 @@ class TProcessPoolServer(TServer):
while self.isRunning.value: while self.isRunning.value:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
self.serveClient(client) self.serveClient(client)
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
return 0 return 0
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
def serveClient(self, client): def serveClient(self, client):
"""Process input/output from a client for as long as possible""" """Process input/output from a client for as long as possible"""
@ -73,10 +77,10 @@ class TProcessPoolServer(TServer):
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransportException as tx: except TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
otrans.close() otrans.close()
@ -97,7 +101,7 @@ class TProcessPoolServer(TServer):
w.start() w.start()
self.workers.append(w) self.workers.append(w)
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
# wait until the condition is set by stop() # wait until the condition is set by stop()
while True: while True:
@ -108,7 +112,7 @@ class TProcessPoolServer(TServer):
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
break break
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
self.isRunning.value = False self.isRunning.value = False

View file

@ -17,19 +17,19 @@
# under the License. # under the License.
# #
import queue from six.moves import queue
import logging import logging
import os import os
import sys
import threading import threading
import traceback
from thrift.Thrift import TProcessor
from thrift.protocol import TBinaryProtocol from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport from thrift.transport import TTransport
logger = logging.getLogger(__name__)
class TServer:
class TServer(object):
"""Base interface for a server, which must have a serve() method. """Base interface for a server, which must have a serve() method.
Three constructors for all servers: Three constructors for all servers:
@ -61,6 +61,12 @@ class TServer:
self.inputProtocolFactory = inputProtocolFactory self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory self.outputProtocolFactory = outputProtocolFactory
input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
raise ValueError("THeaderProtocol servers require that both the input and "
"output protocols are THeaderProtocol.")
def serve(self): def serve(self):
pass pass
@ -75,19 +81,32 @@ class TSimpleServer(TServer):
self.serverTransport.listen() self.serverTransport.listen()
while True: while True:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for
# input and output so that the response is in the same dialect that
# the server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException as tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
if otrans:
otrans.close() otrans.close()
@ -103,28 +122,40 @@ class TThreadedServer(TServer):
while True: while True:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
t = threading.Thread(target=self.handle, args=(client,)) t = threading.Thread(target=self.handle, args=(client,))
t.setDaemon(self.daemon) t.setDaemon(self.daemon)
t.start() t.start()
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
def handle(self, client): def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException as tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
if otrans:
otrans.close() otrans.close()
@ -148,23 +179,33 @@ class TThreadPoolServer(TServer):
client = self.clients.get() client = self.clients.get()
self.serveClient(client) self.serveClient(client)
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
def serveClient(self, client): def serveClient(self, client):
"""Process input/output from a client for as long as possible""" """Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException as tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
itrans.close() itrans.close()
if otrans:
otrans.close() otrans.close()
def serve(self): def serve(self):
@ -175,16 +216,18 @@ class TThreadPoolServer(TServer):
t.setDaemon(self.daemon) t.setDaemon(self.daemon)
t.start() t.start()
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
# Pump the socket for clients # Pump the socket for clients
self.serverTransport.listen() self.serverTransport.listen()
while True: while True:
try: try:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
self.clients.put(client) self.clients.put(client)
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
class TForkingServer(TServer): class TForkingServer(TServer):
@ -209,11 +252,13 @@ class TForkingServer(TServer):
try: try:
file.close() file.close()
except IOError as e: except IOError as e:
logging.warning(e, exc_info=True) logger.warning(e, exc_info=True)
self.serverTransport.listen() self.serverTransport.listen()
while True: while True:
client = self.serverTransport.accept() client = self.serverTransport.accept()
if not client:
continue
try: try:
pid = os.fork() pid = os.fork()
@ -230,9 +275,17 @@ class TForkingServer(TServer):
try_close(otrans) try_close(otrans)
else: else:
itrans = self.inputTransportFactory.getTransport(client) itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans) iprot = self.inputProtocolFactory.getProtocol(itrans)
# for THeaderProtocol, we must use the same protocol
# instance for input and output so that the response is in
# the same dialect that the server detected the request was
# in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans) oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0 ecode = 0
@ -240,21 +293,22 @@ class TForkingServer(TServer):
try: try:
while True: while True:
self.processor.process(iprot, oprot) self.processor.process(iprot, oprot)
except TTransport.TTransportException as tx: except TTransport.TTransportException:
pass pass
except Exception as e: except Exception as e:
logging.exception(e) logger.exception(e)
ecode = 1 ecode = 1
finally: finally:
try_close(itrans) try_close(itrans)
if otrans:
try_close(otrans) try_close(otrans)
os._exit(ecode) os._exit(ecode)
except TTransport.TTransportException as tx: except TTransport.TTransportException:
pass pass
except Exception as x: except Exception as x:
logging.exception(x) logger.exception(x)
def collect_children(self): def collect_children(self):
while self.children: while self.children:

View file

@ -0,0 +1,352 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import struct
import zlib
from thrift.compat import BufferIO, byte_index
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
CReadableTransport,
TMemoryBuffer,
TTransportBase,
TTransportException,
)
# Pre-compiled network-byte-order codecs for the fixed-width frame fields:
# an unsigned 16-bit integer and a signed 32-bit integer.
U16 = struct.Struct("!H")
I32 = struct.Struct("!i")

# Value carried in the upper 16 bits of the first word of a header-format
# frame; used to tell header frames apart from plain framed messages.
HEADER_MAGIC = 0x0FFF

# Upper bound for any configurable frame size (2**30 - 1 bytes).
HARD_MAX_FRAME_SIZE = (1 << 30) - 1
class THeaderClientType(object):
    """Wire formats a peer can be detected as using on a header transport."""

    # Length-prefixed frame that starts with the header-format preamble.
    HEADERS = 0
    # Length-prefixed frame containing a plain binary-protocol message.
    FRAMED_BINARY = 1
    # Binary-protocol message sent without a length prefix.
    UNFRAMED_BINARY = 2
    # Length-prefixed frame containing a plain compact-protocol message.
    FRAMED_COMPACT = 3
    # Compact-protocol message sent without a length prefix.
    UNFRAMED_COMPACT = 4
class THeaderSubprotocolID(object):
    """Identifiers for the protocol carried inside a header-format frame."""

    # Payload is encoded with the binary protocol.
    BINARY = 0
    # Payload is encoded with the compact protocol.
    COMPACT = 2
class TInfoHeaderType(object):
    """Type tags for the info blocks found in the header section of a frame."""

    # Block holding a count followed by that many key/value string pairs.
    KEY_VALUE = 1
class THeaderTransformID(object):
    """Identifiers of payload transforms that a frame can declare."""

    # Payload is compressed with zlib.
    ZLIB = 1
# Functions that undo a transform on an incoming payload, keyed by transform ID.
READ_TRANSFORMS_BY_ID = {THeaderTransformID.ZLIB: zlib.decompress}

# Functions that apply a transform to an outgoing payload, keyed by transform ID.
WRITE_TRANSFORMS_BY_ID = {THeaderTransformID.ZLIB: zlib.compress}
def _readString(trans):
    """Read a varint-length-prefixed byte string from ``trans``.

    Raises TTransportException (NEGATIVE_SIZE) when the decoded length is
    negative.
    """
    length = readVarint(trans)
    if length >= 0:
        return trans.read(length)
    raise TTransportException(
        TTransportException.NEGATIVE_SIZE,
        "Negative length"
    )
def _writeString(trans, value):
    """Write ``value`` to ``trans`` as a varint length followed by the bytes."""
    byte_count = len(value)
    writeVarint(trans, byte_count)
    trans.write(value)
class THeaderTransport(TTransportBase, CReadableTransport):
    """Transport that speaks the header framing and auto-detects simpler peers.

    Incoming messages are inspected to decide whether the peer is sending
    header-format frames, framed binary/compact messages or unframed
    binary/compact messages; replies are written back in the detected format.
    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        """Wrap ``transport``.

        ``allowed_client_types`` is the collection of THeaderClientType values
        a peer is permitted to use; ``default_protocol`` is the
        THeaderSubprotocolID used until a header frame says otherwise.
        """
        self._transport = transport
        self._client_type = THeaderClientType.HEADERS
        self._allowed_client_types = allowed_client_types

        self._read_buffer = BufferIO(b"")
        self._read_headers = {}

        self._write_buffer = BufferIO()
        self._write_headers = {}
        self._write_transforms = []

        self.flags = 0
        self.sequence_id = 0
        self._protocol_id = default_protocol

        self._max_frame_size = HARD_MAX_FRAME_SIZE

    def isOpen(self):
        """Return whether the wrapped transport is open."""
        return self._transport.isOpen()

    def open(self):
        """Open the wrapped transport."""
        return self._transport.open()

    def close(self):
        """Close the wrapped transport."""
        return self._transport.close()

    def get_headers(self):
        """Return the key/value headers parsed from the last header frame."""
        return self._read_headers

    def set_header(self, key, value):
        """Queue a header for the next flushed header frame.

        Raises ValueError unless both ``key`` and ``value`` are bytes.
        """
        if not isinstance(key, bytes):
            raise ValueError("header names must be bytes")
        if not isinstance(value, bytes):
            raise ValueError("header values must be bytes")
        self._write_headers[key] = value

    def clear_headers(self):
        """Drop all headers queued for writing."""
        self._write_headers.clear()

    def add_transform(self, transform_id):
        """Apply ``transform_id`` to outgoing header-frame payloads.

        Raises ValueError for a transform with no registered write function.
        """
        if transform_id not in WRITE_TRANSFORMS_BY_ID:
            raise ValueError("unknown transform")
        self._write_transforms.append(transform_id)

    def set_max_frame_size(self, size):
        """Set the frame size limit; must be > 0 and < HARD_MAX_FRAME_SIZE."""
        if not 0 < size < HARD_MAX_FRAME_SIZE:
            raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
        self._max_frame_size = size

    @property
    def protocol_id(self):
        """Sub-protocol ID implied by the currently detected client type."""
        if self._client_type == THeaderClientType.HEADERS:
            return self._protocol_id
        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
            return THeaderSubprotocolID.BINARY
        elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
            return THeaderSubprotocolID.COMPACT
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Protocol ID not know for client type %d" % self._client_type,
            )

    def read(self, sz):
        """Read ``sz`` bytes, pulling in the next frame when the buffer is empty."""
        # if there are bytes left in the buffer, produce those first.
        bytes_read = self._read_buffer.read(sz)
        bytes_left_to_read = sz - len(bytes_read)
        if bytes_left_to_read == 0:
            return bytes_read

        # if we've determined this is an unframed client, just pass the read
        # through to the underlying transport until we're reset again at the
        # beginning of the next message.
        if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
            return bytes_read + self._transport.read(bytes_left_to_read)

        # we're empty and (maybe) framed. fill the buffers with the next frame.
        self.readFrame(bytes_left_to_read)
        return bytes_read + self._read_buffer.read(bytes_left_to_read)

    def _set_client_type(self, client_type):
        """Record the detected client type, rejecting types that are not allowed."""
        if client_type not in self._allowed_client_types:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Client type %d not allowed by server." % client_type,
            )
        self._client_type = client_type

    def readFrame(self, req_sz):
        """Read the start of the next message and detect the peer's format.

        Fills the read buffer with either the first bytes of an unframed
        message or the (decoded) contents of a whole frame. Raises
        TTransportException for oversized, negative-length or undetectable
        frames, or for a client type that is not allowed.
        """
        # the first word could either be the length field of a framed message
        # or the first bytes of an unframed message.
        first_word = self._transport.readAll(I32.size)
        frame_size, = I32.unpack(first_word)
        is_unframed = False
        if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
            is_unframed = True
        elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
              byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
            self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
            is_unframed = True

        if is_unframed:
            bytes_left_to_read = req_sz - I32.size
            if bytes_left_to_read > 0:
                rest = self._transport.read(bytes_left_to_read)
            else:
                rest = b""
            self._read_buffer = BufferIO(first_word + rest)
            return

        # ok, we're still here so we're framed.
        # A negative length can only come from a corrupt or hostile peer;
        # report it as a transport error rather than letting it surface
        # later as a struct.error.
        if frame_size < 0:
            raise TTransportException(
                TTransportException.NEGATIVE_SIZE,
                "Negative frame size.",
            )
        if frame_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Frame was too large.",
            )
        read_buffer = BufferIO(self._transport.readAll(frame_size))

        # the next word is either going to be the version field of a
        # binary/compact protocol message or the magic value + flags of a
        # header protocol message.
        second_word = read_buffer.read(I32.size)
        if len(second_word) < I32.size:
            # the frame is too short to contain the word we detect on.
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Could not detect client transport type.",
            )
        version, = I32.unpack(second_word)
        read_buffer.seek(0)
        if version >> 16 == HEADER_MAGIC:
            self._set_client_type(THeaderClientType.HEADERS)
            self._read_buffer = self._parse_header_format(read_buffer)
        elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
            self._set_client_type(THeaderClientType.FRAMED_BINARY)
            self._read_buffer = read_buffer
        elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
              byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
            self._set_client_type(THeaderClientType.FRAMED_COMPACT)
            self._read_buffer = read_buffer
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Could not detect client transport type.",
            )

    def _parse_header_format(self, buffer):
        """Parse a header-format frame and return a buffer over its payload.

        Updates ``flags``, ``sequence_id``, the protocol ID and the read
        headers from the frame, then undoes the declared transforms on the
        payload.
        """
        # make BufferIO look like TTransport for varint helpers
        buffer_transport = TMemoryBuffer()
        buffer_transport._buffer = buffer

        buffer.read(2)  # discard the magic bytes
        self.flags, = U16.unpack(buffer.read(U16.size))
        self.sequence_id, = I32.unpack(buffer.read(I32.size))

        # the header length is stored as a count of 4-byte words.
        header_length = U16.unpack(buffer.read(U16.size))[0] * 4
        end_of_headers = buffer.tell() + header_length
        if end_of_headers > len(buffer.getvalue()):
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Header size is larger than whole frame.",
            )

        self._protocol_id = readVarint(buffer_transport)

        transforms = []
        transform_count = readVarint(buffer_transport)
        for _ in range(transform_count):
            transform_id = readVarint(buffer_transport)
            if transform_id not in READ_TRANSFORMS_BY_ID:
                raise TApplicationException(
                    TApplicationException.INVALID_TRANSFORM,
                    "Unknown transform: %d" % transform_id,
                )
            transforms.append(transform_id)
        # transforms are undone in the opposite order to the listed one.
        transforms.reverse()

        headers = {}
        while buffer.tell() < end_of_headers:
            header_type = readVarint(buffer_transport)
            if header_type == TInfoHeaderType.KEY_VALUE:
                count = readVarint(buffer_transport)
                for _ in range(count):
                    key = _readString(buffer_transport)
                    value = _readString(buffer_transport)
                    headers[key] = value
            else:
                break  # ignore unknown headers
        self._read_headers = headers

        # skip padding / anything we didn't understand
        buffer.seek(end_of_headers)

        payload = buffer.read()
        for transform_id in transforms:
            transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
            payload = transform_fn(payload)
        return BufferIO(payload)

    def write(self, buf):
        """Buffer ``buf`` until the next flush()."""
        self._write_buffer.write(buf)

    def flush(self):
        """Send the buffered message in the format of the detected client type.

        Raises TTransportException if the client type is unknown or the
        resulting frame exceeds the configured maximum size.
        """
        payload = self._write_buffer.getvalue()
        self._write_buffer = BufferIO()

        buffer = BufferIO()
        if self._client_type == THeaderClientType.HEADERS:
            for transform_id in self._write_transforms:
                transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
                payload = transform_fn(payload)

            headers = BufferIO()
            writeVarint(headers, self._protocol_id)
            writeVarint(headers, len(self._write_transforms))
            for transform_id in self._write_transforms:
                writeVarint(headers, transform_id)
            if self._write_headers:
                writeVarint(headers, TInfoHeaderType.KEY_VALUE)
                writeVarint(headers, len(self._write_headers))
                for key, value in self._write_headers.items():
                    _writeString(headers, key)
                    _writeString(headers, value)
                # queued headers are sent once, then forgotten.
                self._write_headers = {}
            # pad the header section to a multiple of 4 bytes, since its
            # length is written as a count of 4-byte words.
            padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
            headers.write(b"\x00" * padding_needed)
            header_bytes = headers.getvalue()

            # 10 = magic (2) + flags (2) + sequence id (4) + header length (2)
            buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
            buffer.write(U16.pack(HEADER_MAGIC))
            buffer.write(U16.pack(self.flags))
            buffer.write(I32.pack(self.sequence_id))
            buffer.write(U16.pack(len(header_bytes) // 4))
            buffer.write(header_bytes)
            buffer.write(payload)
        elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
            buffer.write(I32.pack(len(payload)))
            buffer.write(payload)
        elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
            buffer.write(payload)
        else:
            raise TTransportException(
                TTransportException.INVALID_CLIENT_TYPE,
                "Unknown client type.",
            )

        # the frame length field doesn't count towards the frame payload size
        frame_bytes = buffer.getvalue()
        frame_payload_size = len(frame_bytes) - 4
        if frame_payload_size > self._max_frame_size:
            raise TTransportException(
                TTransportException.SIZE_LIMIT,
                "Attempting to send frame that is too large.",
            )

        self._transport.write(frame_bytes)
        self._transport.flush()

    @property
    def cstringio_buf(self):
        """The current read buffer."""
        return self._read_buffer

    def cstringio_refill(self, partialread, reqlen):
        """Extend ``partialread`` to ``reqlen`` bytes and return it as the new read buffer."""
        result = bytearray(partialread)
        while len(result) < reqlen:
            result += self.read(reqlen - len(result))
        self._read_buffer = BufferIO(result)
        return self._read_buffer

View file

@ -17,33 +17,37 @@
# under the License. # under the License.
# #
import http.client from io import BytesIO
import os import os
import socket import ssl
import sys import sys
import urllib.request, urllib.parse, urllib.error
import urllib.parse
import warnings import warnings
import base64
from io import StringIO from six.moves import urllib
from six.moves import http_client
from .TTransport import * from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase): class THttpClient(TTransportBase):
"""Http implementation of TTransport base.""" """Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None): def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None):
"""THttpClient supports two different types constructor parameters. """THttpClient supports two different types of construction:
THttpClient(host, port, path) - deprecated THttpClient(host, port, path) - deprecated
THttpClient(uri) THttpClient(uri, [port=<n>, path=<s>, cafile=<filename>, cert_file=<filename>, key_file=<filename>, ssl_context=<context>])
Only the second supports https. Only the second supports https. To properly authenticate against the server,
provide the client's identity by specifying cert_file and key_file. To properly
authenticate the server, specify either cafile or ssl_context with a CA defined.
NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile.
""" """
if port is not None: if port is not None:
warnings.warn( warnings.warn(
"Please use the THttpClient('http://host:port/path') syntax", "Please use the THttpClient('http{s}://host:port/path') constructor",
DeprecationWarning, DeprecationWarning,
stacklevel=2) stacklevel=2)
self.host = uri_or_host self.host = uri_or_host
@ -56,35 +60,73 @@ class THttpClient(TTransportBase):
self.scheme = parsed.scheme self.scheme = parsed.scheme
assert self.scheme in ('http', 'https') assert self.scheme in ('http', 'https')
if self.scheme == 'http': if self.scheme == 'http':
self.port = parsed.port or http.client.HTTP_PORT self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https': elif self.scheme == 'https':
self.port = parsed.port or http.client.HTTPS_PORT self.port = parsed.port or http_client.HTTPS_PORT
self.certfile = cert_file
self.keyfile = key_file
self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context
self.host = parsed.hostname self.host = parsed.hostname
self.path = parsed.path self.path = parsed.path
if parsed.query: if parsed.query:
self.path += '?%s' % parsed.query self.path += '?%s' % parsed.query
self.__wbuf = StringIO() try:
proxy = urllib.request.getproxies()[self.scheme]
except KeyError:
proxy = None
else:
if urllib.request.proxy_bypass(self.host):
proxy = None
if proxy:
parsed = urllib.parse.urlparse(proxy)
self.realhost = self.host
self.realport = self.port
self.host = parsed.hostname
self.port = parsed.port
self.proxy_auth = self.basic_proxy_auth_header(parsed)
else:
self.realhost = self.realport = self.proxy_auth = None
self.__wbuf = BytesIO()
self.__http = None self.__http = None
self.__http_response = None
self.__timeout = None self.__timeout = None
self.__custom_headers = None self.__custom_headers = None
@staticmethod
def basic_proxy_auth_header(proxy):
if proxy is None or not proxy.username:
return None
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password))
cr = base64.b64encode(ap).strip()
return "Basic " + cr
def using_proxy(self):
return self.realhost is not None
def open(self): def open(self):
if self.scheme == 'http': if self.scheme == 'http':
self.__http = http.client.HTTP(self.host, self.port) self.__http = http_client.HTTPConnection(self.host, self.port,
else: timeout=self.__timeout)
self.__http = http.client.HTTPS(self.host, self.port) elif self.scheme == 'https':
self.__http = http_client.HTTPSConnection(self.host, self.port,
key_file=self.keyfile,
cert_file=self.certfile,
timeout=self.__timeout,
context=self.context)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
def close(self): def close(self):
self.__http.close() self.__http.close()
self.__http = None self.__http = None
self.__http_response = None
def isOpen(self): def isOpen(self):
return self.__http is not None return self.__http is not None
def setTimeout(self, ms): def setTimeout(self, ms):
if not hasattr(socket, 'getdefaulttimeout'):
raise NotImplementedError
if ms is None: if ms is None:
self.__timeout = None self.__timeout = None
else: else:
@ -94,20 +136,11 @@ class THttpClient(TTransportBase):
self.__custom_headers = headers self.__custom_headers = headers
def read(self, sz): def read(self, sz):
return self.__http.file.read(sz) return self.__http_response.read(sz)
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
def __withTimeout(f):
def _f(*args, **kwargs):
orig_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(args[0].__timeout)
result = f(*args, **kwargs)
socket.setdefaulttimeout(orig_timeout)
return result
return _f
def flush(self): def flush(self):
if self.isOpen(): if self.isOpen():
self.close() self.close()
@ -115,15 +148,21 @@ class THttpClient(TTransportBase):
# Pull data out of buffer # Pull data out of buffer
data = self.__wbuf.getvalue() data = self.__wbuf.getvalue()
self.__wbuf = StringIO() self.__wbuf = BytesIO()
# HTTP request # HTTP request
if self.using_proxy() and self.scheme == "http":
# need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
self.__http.putrequest('POST', "http://%s:%s%s" %
(self.realhost, self.realport, self.path))
else:
self.__http.putrequest('POST', self.path) self.__http.putrequest('POST', self.path)
# Write headers # Write headers
self.__http.putheader('Host', self.host)
self.__http.putheader('Content-Type', 'application/x-thrift') self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader('Content-Length', str(len(data))) self.__http.putheader('Content-Length', str(len(data)))
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
user_agent = 'Python/THttpClient' user_agent = 'Python/THttpClient'
@ -133,7 +172,7 @@ class THttpClient(TTransportBase):
self.__http.putheader('User-Agent', user_agent) self.__http.putheader('User-Agent', user_agent)
if self.__custom_headers: if self.__custom_headers:
for key, val in self.__custom_headers.items(): for key, val in six.iteritems(self.__custom_headers):
self.__http.putheader(key, val) self.__http.putheader(key, val)
self.__http.endheaders() self.__http.endheaders()
@ -142,8 +181,11 @@ class THttpClient(TTransportBase):
self.__http.send(data) self.__http.send(data)
# Get reply to flush the request # Get reply to flush the request
self.code, self.message, self.headers = self.__http.getreply() self.__http_response = self.__http.getresponse()
self.code = self.__http_response.status
self.message = self.__http_response.reason
self.headers = self.__http_response.msg
# Decorate if we know how to timeout # Saves the cookie sent by the server response
if hasattr(socket, 'getdefaulttimeout'): if 'Set-Cookie' in self.headers:
flush = __withTimeout(flush) self.__http.putheader('Cookie', self.headers['Set-Cookie'])

View file

@ -17,161 +17,353 @@
# under the License. # under the License.
# #
import logging
import os import os
import socket import socket
import ssl import ssl
import sys
import warnings
from .sslcompat import _match_hostname, _match_has_ipaddress
from thrift.transport import TSocket from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
warnings.filterwarnings(
'default', category=DeprecationWarning, module=__name__)
class TSSLSocket(TSocket.TSocket):
class TSSLBase(object):
# SSLContext is not available for Python < 2.7.9
_has_ssl_context = sys.hexversion >= 0x020709F0
# ciphers argument is not available for Python < 2.7.0
_has_ciphers = sys.hexversion >= 0x020700F0
# For python >= 2.7.9, use the latest TLS version that both client and
# server support.
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
# For python < 2.7.9, use TLS 1.0 since neither TLSv1_X nor OP_NO_SSLvX is
# available.
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
ssl.PROTOCOL_TLSv1
def _init_context(self, ssl_version):
if self._has_ssl_context:
self._context = ssl.SSLContext(ssl_version)
if self._context.protocol == ssl.PROTOCOL_SSLv23:
self._context.options |= ssl.OP_NO_SSLv2
self._context.options |= ssl.OP_NO_SSLv3
else:
self._context = None
self._ssl_version = ssl_version
@property
def _should_verify(self):
if self._has_ssl_context:
return self._context.verify_mode != ssl.CERT_NONE
else:
return self.cert_reqs != ssl.CERT_NONE
@property
def ssl_version(self):
if self._has_ssl_context:
return self.ssl_context.protocol
else:
return self._ssl_version
# Read-only accessor for the object stored in self._context. That attribute
# is assigned in _init_context() (an ssl.SSLContext, or None when contexts
# are unsupported) or in __init__() from the 'ssl_context' option.
@property
def ssl_context(self):
return self._context
SSL_VERSION = _default_protocol
""" """
SSL implementation of client-side TSocket Default SSL version.
For backwards compatibility, it can be modified.
Use __init__ keyword argument "ssl_version" instead.
"""
def _deprecated_arg(self, args, kwargs, pos, key):
if len(args) <= pos:
return
real_pos = pos + 3
warnings.warn(
'%dth positional argument is deprecated.'
'please use keyword argument instead.'
% real_pos, DeprecationWarning, stacklevel=3)
if key in kwargs:
raise TypeError(
'Duplicate argument: %dth argument and %s keyword argument.'
% (real_pos, key))
kwargs[key] = args[pos]
def _unix_socket_arg(self, host, port, args, kwargs):
key = 'unix_socket'
if host is None and port is None and len(args) == 1 and key not in kwargs:
kwargs[key] = args[0]
return True
return False
def __getattr__(self, key):
if key == 'SSL_VERSION':
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version attribute instead.',
DeprecationWarning, stacklevel=2)
return self.ssl_version
def __init__(self, server_side, host, ssl_opts):
self._server_side = server_side
if TSSLBase.SSL_VERSION != self._default_protocol:
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version keyword argument instead.',
DeprecationWarning, stacklevel=2)
self._context = ssl_opts.pop('ssl_context', None)
self._server_hostname = None
if not self._server_side:
self._server_hostname = ssl_opts.pop('server_hostname', host)
if self._context:
self._custom_context = True
if ssl_opts:
raise ValueError(
'Incompatible arguments: ssl_context and %s'
% ' '.join(ssl_opts.keys()))
if not self._has_ssl_context:
raise ValueError(
'ssl_context is not available for this version of Python')
else:
self._custom_context = False
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
self._init_context(ssl_version)
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
self.ca_certs = ssl_opts.pop('ca_certs', None)
self.keyfile = ssl_opts.pop('keyfile', None)
self.certfile = ssl_opts.pop('certfile', None)
self.ciphers = ssl_opts.pop('ciphers', None)
if ssl_opts:
raise ValueError(
'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
if self._should_verify:
if not self.ca_certs:
raise ValueError(
'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
if not os.access(self.ca_certs, os.R_OK):
raise IOError('Certificate Authority ca_certs file "%s" '
'is not readable, cannot validate SSL '
'certificates.' % (self.ca_certs))
@property
def certfile(self):
return self._certfile
@certfile.setter
def certfile(self, certfile):
if self._server_side and not certfile:
raise ValueError('certfile is needed for server-side')
if certfile and not os.access(certfile, os.R_OK):
raise IOError('No such certfile found: %s' % (certfile))
self._certfile = certfile
# Wrap the plain socket `sock` with SSL and return the wrapped socket.
# Two paths are selected by _has_ssl_context: SSLContext.wrap_socket() on the
# stored context, or the module-level ssl.wrap_socket() with keyword options.
# NOTE(review): indentation is lost in this listing; the verify_mode, cert
# chain, cipher and CA setup below presumably apply only when the context was
# not supplied by the caller (_custom_context is False) -- confirm against the
# real file.
def _wrap_socket(self, sock):
if self._has_ssl_context:
if not self._custom_context:
# Copy the options collected in __init__ onto the context.
self.ssl_context.verify_mode = self.cert_reqs
if self.certfile:
self.ssl_context.load_cert_chain(self.certfile,
self.keyfile)
if self.ciphers:
self.ssl_context.set_ciphers(self.ciphers)
if self.ca_certs:
self.ssl_context.load_verify_locations(self.ca_certs)
# _server_hostname is None on the server side (see __init__).
return self.ssl_context.wrap_socket(
sock, server_side=self._server_side,
server_hostname=self._server_hostname)
else:
# No SSLContext support: pass everything as ssl.wrap_socket() keywords.
ssl_opts = {
'ssl_version': self._ssl_version,
'server_side': self._server_side,
'ca_certs': self.ca_certs,
'keyfile': self.keyfile,
'certfile': self.certfile,
'cert_reqs': self.cert_reqs,
}
if self.ciphers:
# 'ciphers' is only forwarded when _has_ciphers says it is supported;
# otherwise the setting is dropped with a warning.
if self._has_ciphers:
ssl_opts['ciphers'] = self.ciphers
else:
logger.warning(
'ciphers is specified but ignored due to old Python version')
return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
"""
SSL implementation of TSocket
This class creates outbound sockets wrapped using the This class creates outbound sockets wrapped using the
python standard ssl module for encrypted connections. python standard ssl module for encrypted connections.
The protocol used is set using the class variable
SSL_VERSION, which must be one of ssl.PROTOCOL_* and
defaults to ssl.PROTOCOL_TLSv1 for greatest security.
""" """
SSL_VERSION = ssl.PROTOCOL_TLSv1
def __init__(self, # New signature
host='localhost', # def __init__(self, host='localhost', port=9090, unix_socket=None,
port=9090, # **ssl_args):
validate=True, # Deprecated signature
ca_certs=None, # def __init__(self, host='localhost', port=9090, validate=True,
unix_socket=None): # ca_certs=None, keyfile=None, certfile=None,
"""Create SSL TSocket # unix_socket=None, ciphers=None):
def __init__(self, host='localhost', port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
@param validate: Set to False to disable SSL certificate validation Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
@type validate: bool ``ssl_version``, ``ca_certs``,
@param ca_certs: Filename to the Certificate Authority pem file, possibly a ``ciphers`` (Python 2.7.0 or later),
file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to ``server_hostname`` (Python 2.7.9 or later)
the ssl_wrap function as the 'ca_certs' parameter. Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
@type ca_certs: str
Raises an IOError exception if validate is True and the ca_certs file is Alternative keyword arguments: (Python 2.7.9 or later)
None, not present or unreadable. ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
``socket_keepalive`` enable TCP keepalive, default off.
""" """
self.validate = validate
self.is_valid = False self.is_valid = False
self.peercert = None self.peercert = None
if not validate:
self.cert_reqs = ssl.CERT_NONE if args:
else: if len(args) > 6:
self.cert_reqs = ssl.CERT_REQUIRED raise TypeError('Too many positional argument')
self.ca_certs = ca_certs if not self._unix_socket_arg(host, port, args, kwargs):
if validate: self._deprecated_arg(args, kwargs, 0, 'validate')
if ca_certs is None or not os.access(ca_certs, os.R_OK): self._deprecated_arg(args, kwargs, 1, 'ca_certs')
raise IOError('Certificate Authority ca_certs file "%s" ' self._deprecated_arg(args, kwargs, 2, 'keyfile')
'is not readable, cannot validate SSL ' self._deprecated_arg(args, kwargs, 3, 'certfile')
'certificates.' % (ca_certs)) self._deprecated_arg(args, kwargs, 4, 'unix_socket')
TSocket.TSocket.__init__(self, host, port, unix_socket) self._deprecated_arg(args, kwargs, 5, 'ciphers')
validate = kwargs.pop('validate', None)
if validate is not None:
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
warnings.warn(
'validate is deprecated. please use cert_reqs=ssl.%s instead'
% cert_reqs_name,
DeprecationWarning, stacklevel=2)
if 'cert_reqs' in kwargs:
raise TypeError('Cannot specify both validate and cert_reqs')
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
unix_socket = kwargs.pop('unix_socket', None)
socket_keepalive = kwargs.pop('socket_keepalive', False)
self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, False, host, kwargs)
TSocket.TSocket.__init__(self, host, port, unix_socket,
socket_keepalive=socket_keepalive)
def close(self):
    """Shut down the SSL layer, then close the underlying socket.

    The SSL shutdown is best effort: it gets a very short timeout and any
    SSL or socket error during it is ignored. When no handle exists (the
    socket was never opened, or is already closed) the SSL shutdown is
    skipped instead of failing with AttributeError on ``None``.
    """
    if self.handle is not None:
        try:
            self.handle.settimeout(0.001)
            self.handle = self.handle.unwrap()
        except (ssl.SSLError, socket.error, OSError):
            # could not complete shutdown in a reasonable amount of time. bail.
            pass
    TSocket.TSocket.close(self)
@property
def validate(self):
    """Deprecated boolean view of ``cert_reqs``.

    Warns, then reports whether ``cert_reqs`` is anything other than
    ``ssl.CERT_NONE``.
    """
    warnings.warn('validate is deprecated. please use cert_reqs instead',
                  DeprecationWarning, stacklevel=2)
    verification_off = self.cert_reqs == ssl.CERT_NONE
    return not verification_off

@validate.setter
def validate(self, value):
    """Deprecated setter: a truthy value selects ``ssl.CERT_REQUIRED``,
    a falsy one ``ssl.CERT_NONE``."""
    warnings.warn('validate is deprecated. please use cert_reqs instead',
                  DeprecationWarning, stacklevel=2)
    if value:
        self.cert_reqs = ssl.CERT_REQUIRED
    else:
        self.cert_reqs = ssl.CERT_NONE
def _do_open(self, family, socktype):
    """Create a socket of the given family and type and wrap it with SSL.

    Returns the wrapped socket. If wrapping fails for any reason the plain
    socket is closed, the failure is logged, and a TTransportException of
    type NOT_OPEN carrying the original exception is raised.
    """
    raw_sock = socket.socket(family, socktype)
    try:
        wrapped = self._wrap_socket(raw_sock)
    except Exception as ex:
        # Do not leak the descriptor of the socket that never got wrapped.
        raw_sock.close()
        msg = 'failed to initialize SSL'
        logger.exception(msg)
        raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)
    return wrapped
def open(self): def open(self):
super(TSSLSocket, self).open()
if self._should_verify:
self.peercert = self.handle.getpeercert()
try: try:
res0 = self._resolveAddr() self._validate_callback(self.peercert, self._server_hostname)
for res in res0:
sock_family, sock_type = res[0:2]
ip_port = res[4]
plain_sock = socket.socket(sock_family, sock_type)
self.handle = ssl.wrap_socket(plain_sock,
ssl_version=self.SSL_VERSION,
do_handshake_on_connect=True,
ca_certs=self.ca_certs,
cert_reqs=self.cert_reqs)
self.handle.settimeout(self._timeout)
try:
self.handle.connect(ip_port)
except socket.error as e:
if res is not res0[-1]:
continue
else:
raise e
break
except socket.error as e:
if self._unix_socket:
message = 'Could not connect to secure socket %s' % self._unix_socket
else:
message = 'Could not connect to %s:%d' % (self.host, self.port)
raise TTransportException(type=TTransportException.NOT_OPEN,
message=message)
if self.validate:
self._validate_cert()
def _validate_cert(self):
"""internal method to validate the peer's SSL certificate, and to check the
commonName of the certificate to ensure it matches the hostname we
used to make this connection. Does not support subjectAltName records
in certificates.
raises TTransportException if the certificate fails validation.
"""
cert = self.handle.getpeercert()
self.peercert = cert
if 'subject' not in cert:
raise TTransportException(
type=TTransportException.NOT_OPEN,
message='No SSL certificate found from %s:%s' % (self.host, self.port))
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
if not isinstance(field, tuple):
continue
cert_pair = field[0]
if len(cert_pair) < 2:
continue
cert_key, cert_value = cert_pair[0:2]
if cert_key != 'commonName':
continue
certhost = cert_value
if certhost == self.host:
# success, cert commonName matches desired hostname
self.is_valid = True self.is_valid = True
return except TTransportException:
else: raise
raise TTransportException( except Exception as ex:
type=TTransportException.UNKNOWN, raise TTransportException(message=str(ex), inner=ex)
message='Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
type=TTransportException.UNKNOWN,
message='Could not validate SSL certificate from '
'host "%s". Cert=%s' % (self.host, cert))
class TSSLServerSocket(TSocket.TServerSocket): class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
"""SSL implementation of TServerSocket """SSL implementation of TServerSocket
This uses the ssl module's wrap_socket() method to provide SSL This uses the ssl module's wrap_socket() method to provide SSL
negotiated encryption. negotiated encryption.
""" """
SSL_VERSION = ssl.PROTOCOL_TLSv1
def __init__(self, # New signature
host=None, # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
port=9090, # Deprecated signature
certfile='cert.pem', # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
unix_socket=None): def __init__(self, host=None, port=9090, *args, **kwargs):
"""Initialize a TSSLServerSocket """Positional arguments: ``host``, ``port``, ``unix_socket``
@param certfile: filename of the server certificate, defaults to cert.pem Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
@type certfile: str ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
@param host: The hostname or IP to bind the listen socket to, See ssl.wrap_socket documentation.
i.e. 'localhost' for only allowing local network connections.
Pass None to bind to all interfaces. Alternative keyword arguments: (Python 2.7.9 or later)
@type host: str ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
@param port: The port to listen on for inbound connections. ``server_hostname``: Passed to SSLContext.wrap_socket
@type port: int
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
""" """
self.setCertfile(certfile) if args:
TSocket.TServerSocket.__init__(self, host, port) if len(args) > 3:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'certfile')
self._deprecated_arg(args, kwargs, 1, 'unix_socket')
self._deprecated_arg(args, kwargs, 2, 'ciphers')
if 'ssl_context' not in kwargs:
# Preserve existing behaviors for default values
if 'cert_reqs' not in kwargs:
kwargs['cert_reqs'] = ssl.CERT_NONE
if'certfile' not in kwargs:
kwargs['certfile'] = 'cert.pem'
unix_socket = kwargs.pop('unix_socket', None)
self._validate_callback = \
kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, True, None, kwargs)
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
if self._should_verify and not _match_has_ipaddress:
raise ValueError('Need ipaddress and backports.ssl_match_hostname '
'module to verify client certificate')
def setCertfile(self, certfile): def setCertfile(self, certfile):
"""Set or change the server certificate file used to wrap new connections. """Set or change the server certificate file used to wrap new
connections.
@param certfile: The filename of the server certificate, @param certfile: The filename of the server certificate,
i.e. '/etc/certs/server.pem' i.e. '/etc/certs/server.pem'
@ -179,24 +371,38 @@ class TSSLServerSocket(TSocket.TServerSocket):
Raises an IOError exception if the certfile is not present or unreadable. Raises an IOError exception if the certfile is not present or unreadable.
""" """
if not os.access(certfile, os.R_OK): warnings.warn(
raise IOError('No such certfile found: %s' % (certfile)) 'setCertfile is deprecated. please use certfile property instead.',
DeprecationWarning, stacklevel=2)
self.certfile = certfile self.certfile = certfile
def accept(self): def accept(self):
plain_client, addr = self.handle.accept() plain_client, addr = self.handle.accept()
try: try:
client = ssl.wrap_socket(plain_client, certfile=self.certfile, client = self._wrap_socket(plain_client)
server_side=True, ssl_version=self.SSL_VERSION) except (ssl.SSLError, socket.error, OSError):
except ssl.SSLError as ssl_exc: logger.exception('Error while accepting from %s', addr)
# failed handshake/ssl wrap, close socket to client # failed handshake/ssl wrap, close socket to client
plain_client.close() plain_client.close()
# raise ssl_exc # raise
# We can't raise the exception, because it kills most TServer derived # We can't raise the exception, because it kills most TServer derived
# serve() methods. # serve() methods.
# Instead, return None, and let the TServer instance deal with it in # Instead, return None, and let the TServer instance deal with it in
# other exception handling. (but TSimpleServer dies anyway) # other exception handling. (but TSimpleServer dies anyway)
return None return None
if self._should_verify:
client.peercert = client.getpeercert()
try:
self._validate_callback(client.peercert, addr[0])
client.is_valid = True
except Exception:
logger.warn('Failed to validate client certificate address: %s',
addr[0], exc_info=True)
client.close()
plain_client.close()
return None
result = TSocket.TSocket() result = TSocket.TSocket()
result.setHandle(client) result.handle = client
return result return result

View file

@ -18,11 +18,14 @@
# #
import errno import errno
import logging
import os import os
import socket import socket
import sys import sys
from .TTransport import * from .TTransport import TTransportBase, TTransportException, TServerTransportBase
logger = logging.getLogger(__name__)
class TSocketBase(TTransportBase): class TSocketBase(TTransportBase):
@ -33,10 +36,10 @@ class TSocketBase(TTransportBase):
else: else:
return socket.getaddrinfo(self.host, return socket.getaddrinfo(self.host,
self.port, self.port,
socket.AF_UNSPEC, self._socket_family,
socket.SOCK_STREAM, socket.SOCK_STREAM,
0, 0,
socket.AI_PASSIVE | socket.AI_ADDRCONFIG) socket.AI_PASSIVE)
def close(self): def close(self):
if self.handle: if self.handle:
@ -47,25 +50,55 @@ class TSocketBase(TTransportBase):
class TSocket(TSocketBase): class TSocket(TSocketBase):
"""Socket implementation of TTransport base.""" """Socket implementation of TTransport base."""
def __init__(self, host='localhost', port=9090, unix_socket=None): def __init__(self, host='localhost', port=9090, unix_socket=None,
socket_family=socket.AF_UNSPEC,
socket_keepalive=False):
"""Initialize a TSocket """Initialize a TSocket
@param host(str) The host to connect to. @param host(str) The host to connect to.
@param port(int) The (TCP) port to connect to. @param port(int) The (TCP) port to connect to.
@param unix_socket(str) The filename of a unix socket to connect to. @param unix_socket(str) The filename of a unix socket to connect to.
(host and port will be ignored.) (host and port will be ignored.)
@param socket_family(int) The socket family to use with this socket.
@param socket_keepalive(bool) enable TCP keepalive, default off.
""" """
self.host = host self.host = host
self.port = port self.port = port
self.handle = None self.handle = None
self._unix_socket = unix_socket self._unix_socket = unix_socket
self._timeout = None self._timeout = None
self._socket_family = socket_family
self._socket_keepalive = socket_keepalive
def setHandle(self, h): def setHandle(self, h):
self.handle = h self.handle = h
def isOpen(self): def isOpen(self):
return self.handle is not None if self.handle is None:
return False
# this lets us cheaply see if the other end of the socket is still
# connected. if disconnected, we'll get EOF back (expressed as zero
# bytes of data) otherwise we'll get one byte or an error indicating
# we'd have to block for data.
#
# note that we're not doing this with socket.MSG_DONTWAIT because 1)
# it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
# when timeout is non-zero.
original_timeout = self.handle.gettimeout()
try:
self.handle.settimeout(0)
try:
peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
except (socket.error, OSError) as exc: # on modern python this is just BlockingIOError
if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
return True
return False
finally:
self.handle.settimeout(original_timeout)
# the length will be zero if we got EOF (indicating connection closed)
return len(peeked_bytes) == 1
def setTimeout(self, ms): def setTimeout(self, ms):
if ms is None: if ms is None:
@ -76,27 +109,41 @@ class TSocket(TSocketBase):
if self.handle is not None: if self.handle is not None:
self.handle.settimeout(self._timeout) self.handle.settimeout(self._timeout)
# Create and return a new, unconnected socket for the given address family
# and socket type. open() calls this once per resolved address.
def _do_open(self, family, socktype):
return socket.socket(family, socktype)
@property
def _address(self):
return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)
def open(self): def open(self):
if self.handle:
raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
try: try:
res0 = self._resolveAddr() addrs = self._resolveAddr()
for res in res0: except socket.gaierror as gai:
self.handle = socket.socket(res[0], res[1]) msg = 'failed to resolve sockaddr for ' + str(self._address)
self.handle.settimeout(self._timeout) logger.exception(msg)
raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
for family, socktype, _, _, sockaddr in addrs:
handle = self._do_open(family, socktype)
# TCP_KEEPALIVE
if self._socket_keepalive:
handle.setsockopt(socket.IPPROTO_TCP, socket.SO_KEEPALIVE, 1)
handle.settimeout(self._timeout)
try: try:
self.handle.connect(res[4]) handle.connect(sockaddr)
except socket.error as e: self.handle = handle
if res is not res0[-1]: return
continue except socket.error:
else: handle.close()
raise e logger.info('Could not connect to %s', sockaddr, exc_info=True)
break msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
except socket.error as e: addrs))
if self._unix_socket: logger.error(msg)
message = 'Could not connect to socket %s' % self._unix_socket raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)
else:
message = 'Could not connect to %s:%d' % (self.host, self.port)
raise TTransportException(type=TTransportException.NOT_OPEN,
message=message)
def read(self, sz): def read(self, sz):
try: try:
@ -111,8 +158,10 @@ class TSocket(TSocketBase):
self.close() self.close()
# Trigger the check to raise the END_OF_FILE exception below. # Trigger the check to raise the END_OF_FILE exception below.
buff = '' buff = ''
elif e.args[0] == errno.ETIMEDOUT:
raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
else: else:
raise raise TTransportException(message="unexpected exception", inner=e)
if len(buff) == 0: if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE, raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes') message='TSocket read 0 bytes')
@ -125,12 +174,15 @@ class TSocket(TSocketBase):
sent = 0 sent = 0
have = len(buff) have = len(buff)
while sent < have: while sent < have:
try:
plus = self.handle.send(buff) plus = self.handle.send(buff)
if plus == 0: if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE, raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes') message='TSocket sent 0 bytes')
sent += plus sent += plus
buff = buff[plus:] buff = buff[plus:]
except socket.error as e:
raise TTransportException(message="unexpected exception", inner=e)
def flush(self): def flush(self):
pass pass
@ -139,16 +191,27 @@ class TSocket(TSocketBase):
class TServerSocket(TSocketBase, TServerTransportBase): class TServerSocket(TSocketBase, TServerTransportBase):
"""Socket implementation of TServerTransport base.""" """Socket implementation of TServerTransport base."""
def __init__(self, host=None, port=9090, unix_socket=None): def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
self.host = host self.host = host
self.port = port self.port = port
self._unix_socket = unix_socket self._unix_socket = unix_socket
self._socket_family = socket_family
self.handle = None self.handle = None
self._backlog = 128
def setBacklog(self, backlog=None):
    """Set the backlog value passed to ``listen()`` on the server socket.

    The value is stored only while no handle exists. Once a handle has been
    created the call has no effect and a warning is logged instead.
    """
    if not self.handle:
        self._backlog = backlog
    else:
        # We can't update the backlog when it is already listening, since
        # the handle has been created.
        # Logger.warn is a deprecated alias (removed in Python 3.13).
        logger.warning('You have to set backlog before listen.')
def listen(self): def listen(self):
res0 = self._resolveAddr() res0 = self._resolveAddr()
socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
for res in res0: for res in res0:
if res[0] is socket.AF_INET6 or res is res0[-1]: if res[0] is socket_family or res is res0[-1]:
break break
# We need remove the old unix socket if the file exists and # We need remove the old unix socket if the file exists and
@ -167,7 +230,7 @@ class TServerSocket(TSocketBase, TServerTransportBase):
if hasattr(self.handle, 'settimeout'): if hasattr(self.handle, 'settimeout'):
self.handle.settimeout(None) self.handle.settimeout(None)
self.handle.bind(res[4]) self.handle.bind(res[4])
self.handle.listen(128) self.handle.listen(self._backlog)
def accept(self): def accept(self):
client, addr = self.handle.accept() client, addr = self.handle.accept()

View file

@ -17,9 +17,9 @@
# under the License. # under the License.
# #
from six import BytesIO
from struct import pack, unpack from struct import pack, unpack
from thrift.Thrift import TException from thrift.Thrift import TException
from ..compat import BufferIO
class TTransportException(TException): class TTransportException(TException):
@ -30,13 +30,17 @@ class TTransportException(TException):
ALREADY_OPEN = 2 ALREADY_OPEN = 2
TIMED_OUT = 3 TIMED_OUT = 3
END_OF_FILE = 4 END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
INVALID_CLIENT_TYPE = 7
def __init__(self, type=UNKNOWN, message=None): def __init__(self, type=UNKNOWN, message=None, inner=None):
TException.__init__(self, message) TException.__init__(self, message)
self.type = type self.type = type
self.inner = inner
class TTransportBase: class TTransportBase(object):
"""Base class for Thrift transport layer.""" """Base class for Thrift transport layer."""
def isOpen(self): def isOpen(self):
@ -56,10 +60,11 @@ class TTransportBase:
have = 0 have = 0
while (have < sz): while (have < sz):
chunk = self.read(sz - have) chunk = self.read(sz - have)
have += len(chunk) chunkLen = len(chunk)
have += chunkLen
buff += chunk buff += chunk
if len(chunk) == 0: if chunkLen == 0:
raise EOFError() raise EOFError()
return buff return buff
@ -72,7 +77,7 @@ class TTransportBase:
# This class should be thought of as an interface. # This class should be thought of as an interface.
class CReadableTransport: class CReadableTransport(object):
"""base class for transports that are readable from C""" """base class for transports that are readable from C"""
# TODO(dreiss): Think about changing this interface to allow us to use # TODO(dreiss): Think about changing this interface to allow us to use
@ -100,7 +105,7 @@ class CReadableTransport:
pass pass
class TServerTransportBase: class TServerTransportBase(object):
"""Base class for Thrift server transports.""" """Base class for Thrift server transports."""
def listen(self): def listen(self):
@ -113,14 +118,14 @@ class TServerTransportBase:
pass pass
class TTransportFactoryBase: class TTransportFactoryBase(object):
"""Base class for a Transport Factory""" """Base class for a Transport Factory"""
def getTransport(self, trans): def getTransport(self, trans):
return trans return trans
class TBufferedTransportFactory: class TBufferedTransportFactory(object):
"""Factory transport that builds buffered transports""" """Factory transport that builds buffered transports"""
def getTransport(self, trans): def getTransport(self, trans):
@ -138,8 +143,9 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
self.__trans = trans self.__trans = trans
self.__wbuf = BytesIO() self.__wbuf = BufferIO()
self.__rbuf = BytesIO("") # Pass string argument to initialize read buffer as cStringIO.InputType
self.__rbuf = BufferIO(b'')
self.__rbuf_size = rbuf_size self.__rbuf_size = rbuf_size
def isOpen(self): def isOpen(self):
@ -155,17 +161,21 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
ret = self.__rbuf.read(sz) ret = self.__rbuf.read(sz)
if len(ret) != 0: if len(ret) != 0:
return ret return ret
self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
self.__rbuf = BytesIO(self.__trans.read(max(sz, self.__rbuf_size)))
return self.__rbuf.read(sz) return self.__rbuf.read(sz)
def write(self, buf): def write(self, buf):
try:
self.__wbuf.write(buf) self.__wbuf.write(buf)
except Exception as e:
# on exception reset wbuf so it doesn't contain a partial function call
self.__wbuf = BufferIO()
raise e
def flush(self): def flush(self):
out = self.__wbuf.getvalue() out = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure # reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BytesIO() self.__wbuf = BufferIO()
self.__trans.write(out) self.__trans.write(out)
self.__trans.flush() self.__trans.flush()
@ -184,12 +194,12 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
if len(retstring) < reqlen: if len(retstring) < reqlen:
retstring += self.__trans.readAll(reqlen - len(retstring)) retstring += self.__trans.readAll(reqlen - len(retstring))
self.__rbuf = BytesIO(retstring) self.__rbuf = BufferIO(retstring)
return self.__rbuf return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport): class TMemoryBuffer(TTransportBase, CReadableTransport):
"""Wraps a cStringIO object as a TTransport. """Wraps a cBytesIO object as a TTransport.
NOTE: Unlike the C++ version of this class, you cannot write to it NOTE: Unlike the C++ version of this class, you cannot write to it
then immediately read from it. If you want to read from a then immediately read from it. If you want to read from a
@ -197,15 +207,17 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
TODO(dreiss): Make this work like the C++ version. TODO(dreiss): Make this work like the C++ version.
""" """
def __init__(self, value=None): def __init__(self, value=None, offset=0):
"""value -- a value to read from for stringio """value -- a value to read from for stringio
If value is set, this will be a transport for reading, If value is set, this will be a transport for reading,
otherwise, it is for writing""" otherwise, it is for writing"""
if value is not None: if value is not None:
self._buffer = BytesIO(value) self._buffer = BufferIO(value)
else: else:
self._buffer = BytesIO() self._buffer = BufferIO()
if offset:
self._buffer.seek(offset)
def isOpen(self): def isOpen(self):
return not self._buffer.closed return not self._buffer.closed
@ -220,10 +232,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
return self._buffer.read(sz) return self._buffer.read(sz)
def write(self, buf): def write(self, buf):
try:
self._buffer.write(buf) self._buffer.write(buf)
except TypeError:
self._buffer.write(buf.encode('cp437'))
def flush(self): def flush(self):
pass pass
@ -241,7 +250,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
raise EOFError() raise EOFError()
class TFramedTransportFactory: class TFramedTransportFactory(object):
"""Factory transport that builds framed transports""" """Factory transport that builds framed transports"""
def getTransport(self, trans): def getTransport(self, trans):
@ -254,8 +263,8 @@ class TFramedTransport(TTransportBase, CReadableTransport):
def __init__(self, trans,): def __init__(self, trans,):
self.__trans = trans self.__trans = trans
self.__rbuf = BytesIO() self.__rbuf = BufferIO(b'')
self.__wbuf = BytesIO() self.__wbuf = BufferIO()
def isOpen(self): def isOpen(self):
return self.__trans.isOpen() return self.__trans.isOpen()
@ -277,7 +286,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
def readFrame(self): def readFrame(self):
buff = self.__trans.readAll(4) buff = self.__trans.readAll(4)
sz, = unpack('!i', buff) sz, = unpack('!i', buff)
self.__rbuf = BytesIO(self.__trans.readAll(sz)) self.__rbuf = BufferIO(self.__trans.readAll(sz))
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
@ -286,7 +295,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
wout = self.__wbuf.getvalue() wout = self.__wbuf.getvalue()
wsz = len(wout) wsz = len(wout)
# reset wbuf before write/flush to preserve state on underlying failure # reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BytesIO() self.__wbuf = BufferIO()
# N.B.: Doing this string concatenation is WAY cheaper than making # N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in # two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty # Python turn out to be REALLY expensive, but it seems to do a pretty
@ -307,7 +316,7 @@ class TFramedTransport(TTransportBase, CReadableTransport):
while len(prefix) < reqlen: while len(prefix) < reqlen:
self.readFrame() self.readFrame()
prefix += self.__rbuf.getvalue() prefix += self.__rbuf.getvalue()
self.__rbuf = BytesIO(prefix) self.__rbuf = BufferIO(prefix)
return self.__rbuf return self.__rbuf
@ -331,3 +340,117 @@ class TFileObjectTransport(TTransportBase):
def flush(self): def flush(self):
self.fileobj.flush() self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport

    Client-side transport that performs a SASL negotiation over an
    underlying transport in open(), and afterwards exchanges
    length-prefixed frames whose payloads are passed through the SASL
    client's wrap() / unwrap().
    """

    # Status byte values carried in the negotiation message header.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        # Imported lazily so puresasl is only required when SASL is used.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """Open the underlying transport if needed and run the negotiation.

        Sends START with the mechanism name, then OK with the client's
        initial response, and keeps answering OK challenges until the server
        reports COMPLETE.

        Raises TTransportException (NOT_OPEN) if the server reports COMPLETE
        before the SASL client considers the exchange complete, or if any
        status other than OK / COMPLETE is received.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        """Write one negotiation message and flush the underlying transport.

        The message is a 1-byte status, a 4-byte big-endian unsigned length
        and then the body. A body of None (no response to send) is sent as
        an empty body instead of failing on len(None).
        """
        if body is None:
            body = b""
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation message; return (status, payload).

        The payload is always bytes: an empty body yields b"" so callers get
        the same type as for a non-empty body.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = b""
        return status, payload

    def write(self, data):
        """Buffer data until flush() is called."""
        self.__wbuf.write(data)

    def flush(self):
        """Wrap the buffered data and send it as one length-prefixed frame."""
        data = self.__wbuf.getvalue()
        # reset wbuf before write/flush so a failure below does not leave
        # stale bytes that would be prepended to the next message
        self.__wbuf = BufferIO()
        encoded = self.sasl.wrap(data)
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()

    def read(self, sz):
        """Return up to sz bytes, reading a new frame if the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one frame (4-byte length + payload) and unwrap it into rbuf."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client and close the underlying transport."""
        try:
            self.sasl.dispose()
        finally:
            # always release the underlying transport, even if dispose fails
            self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf

View file

@ -17,14 +17,15 @@
# under the License. # under the License.
# #
from io import StringIO from io import BytesIO
import struct
from zope.interface import implements, Interface, Attribute from zope.interface import implementer, Interface, Attribute
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone connectionDone
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic from twisted.protocols import basic
from twisted.python import log
from twisted.web import server, resource, http from twisted.web import server, resource, http
from thrift.transport import TTransport from thrift.transport import TTransport
@ -33,15 +34,15 @@ from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase): class TMessageSenderTransport(TTransport.TTransportBase):
def __init__(self): def __init__(self):
self.__wbuf = StringIO() self.__wbuf = BytesIO()
def write(self, buf): def write(self, buf):
self.__wbuf.write(buf) self.__wbuf.write(buf)
def flush(self): def flush(self):
msg = self.__wbuf.getvalue() msg = self.__wbuf.getvalue()
self.__wbuf = StringIO() self.__wbuf = BytesIO()
self.sendMessage(msg) return self.sendMessage(msg)
def sendMessage(self, message): def sendMessage(self, message):
raise NotImplementedError raise NotImplementedError
@ -54,7 +55,7 @@ class TCallbackTransport(TMessageSenderTransport):
self.func = func self.func = func
def sendMessage(self, message): def sendMessage(self, message):
self.func(message) return self.func(message)
class ThriftClientProtocol(basic.Int32StringReceiver): class ThriftClientProtocol(basic.Int32StringReceiver):
@ -81,11 +82,18 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
self.started.callback(self.client) self.started.callback(self.client)
def connectionLost(self, reason=connectionDone): def connectionLost(self, reason=connectionDone):
for k, v in self.client._reqs.items(): # the called errbacks can add items to our client's _reqs,
# so we need to use a tmp, and iterate until no more requests
# are added during errbacks
if self.client:
tex = TTransport.TTransportException( tex = TTransport.TTransportException(
type=TTransport.TTransportException.END_OF_FILE, type=TTransport.TTransportException.END_OF_FILE,
message='Connection closed') message='Connection closed (%s)' % reason)
while self.client._reqs:
_, v = self.client._reqs.popitem()
v.errback(tex) v.errback(tex)
del self.client._reqs
self.client = None
def stringReceived(self, frame): def stringReceived(self, frame):
tr = TTransport.TMemoryBuffer(frame) tr = TTransport.TMemoryBuffer(frame)
@ -101,6 +109,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
method(iprot, mtype, rseqid) method(iprot, mtype, rseqid)
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that negotiates SASL before normal traffic.

    While a negotiation deferred is pending, incoming data is treated as
    SASL messages (status byte + length-prefixed challenge). Afterwards
    outgoing messages are wrapped and incoming frames unwrapped by the
    SASL client.
    """

    # Status byte values carried in the negotiation message header.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        # Imported lazily so puresasl is only required when SASL is used.
        from puresasl.client import SASLClient
        # createSASLClient() reads self.SASLClient; the original stored the
        # class only under the misspelled name below, so that call failed.
        self.SASLClient = SASLClient
        # misspelled alias kept for anything that relied on the old name
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the SASL client used for negotiation and wrap/unwrap."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap msg, prefix it with its 4-byte length, and dispatch it."""
        encoded = self.sasl.wrap(msg)
        # both parts are bytes, so they must be joined with a bytes separator
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL negotiation, then hand over to the base protocol.

        Raises TTransportException (NOT_OPEN) if the server reports COMPLETE
        before the SASL client considers the exchange complete, or if any
        status other than OK / COMPLETE is received.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # first argument is the numeric type code, not the text
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.NOT_OPEN,
                        message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(
                    type=TTransport.TTransportException.NOT_OPEN,
                    message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, 4-byte length, body.

        body may be None (sent as an empty body) or text such as the
        mechanism name (sent ASCII-encoded); the header is bytes, so the
        body must be bytes before the two are concatenated.
        """
        if body is None:
            body = b""
        elif isinstance(body, str):
            body = body.encode('ascii')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred that fires with the next (status, challenge)."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Forward connection loss to the base class once a client exists."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the status byte during negotiation, then defer to the base."""
        if self._sasl_negotiation_deferred:
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            # NOTE(review): assumes each chunk received during negotiation
            # starts with the status byte - confirm for split messages.
            # Slice rather than index: indexing bytes yields an int, which
            # struct.unpack rejects.
            self._sasl_negotiation_status, = struct.unpack("B", data[0:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Deliver a negotiation challenge or an unwrapped Thrift frame."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
class ThriftServerProtocol(basic.Int32StringReceiver): class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1 MAX_LENGTH = 2 ** 31 - 1
@ -147,10 +257,9 @@ class IThriftClientFactory(Interface):
oprot_factory = Attribute("Output protocol factory") oprot_factory = Attribute("Output protocol factory")
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory): class ThriftServerFactory(ServerFactory):
implements(IThriftServerFactory)
protocol = ThriftServerProtocol protocol = ThriftServerProtocol
def __init__(self, processor, iprot_factory, oprot_factory=None): def __init__(self, processor, iprot_factory, oprot_factory=None):
@ -162,10 +271,9 @@ class ThriftServerFactory(ServerFactory):
self.oprot_factory = oprot_factory self.oprot_factory = oprot_factory
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory): class ThriftClientFactory(ClientFactory):
implements(IThriftClientFactory)
protocol = ThriftClientProtocol protocol = ThriftClientProtocol
def __init__(self, client_class, iprot_factory, oprot_factory=None): def __init__(self, client_class, iprot_factory, oprot_factory=None):

View file

@ -22,10 +22,10 @@ class, using the python standard library zlib module to implement
data compression. data compression.
""" """
from __future__ import division
import zlib import zlib
from io import StringIO
from .TTransport import TTransportBase, CReadableTransport from .TTransport import TTransportBase, CReadableTransport
from ..compat import BufferIO
class TZlibTransportFactory(object): class TZlibTransportFactory(object):
@ -88,8 +88,8 @@ class TZlibTransport(TTransportBase, CReadableTransport):
""" """
self.__trans = trans self.__trans = trans
self.compresslevel = compresslevel self.compresslevel = compresslevel
self.__rbuf = StringIO() self.__rbuf = BufferIO()
self.__wbuf = StringIO() self.__wbuf = BufferIO()
self._init_zlib() self._init_zlib()
self._init_stats() self._init_stats()
@ -97,8 +97,8 @@ class TZlibTransport(TTransportBase, CReadableTransport):
"""Internal method to initialize/reset the internal StringIO objects """Internal method to initialize/reset the internal StringIO objects
for read and write buffers. for read and write buffers.
""" """
self.__rbuf = StringIO() self.__rbuf = BufferIO()
self.__wbuf = StringIO() self.__wbuf = BufferIO()
def _init_stats(self): def _init_stats(self):
"""Internal method to reset the internal statistics counters """Internal method to reset the internal statistics counters
@ -203,7 +203,7 @@ class TZlibTransport(TTransportBase, CReadableTransport):
self.bytes_in += len(zbuf) self.bytes_in += len(zbuf)
self.bytes_in_comp += len(buf) self.bytes_in_comp += len(buf)
old = self.__rbuf.read() old = self.__rbuf.read()
self.__rbuf = StringIO(old + buf) self.__rbuf = BufferIO(old + buf)
if len(old) + len(buf) == 0: if len(old) + len(buf) == 0:
return False return False
return True return True
@ -228,7 +228,7 @@ class TZlibTransport(TTransportBase, CReadableTransport):
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
self.bytes_out_comp += len(ztail) self.bytes_out_comp += len(ztail)
if (len(zbuf) + len(ztail)) > 0: if (len(zbuf) + len(ztail)) > 0:
self.__wbuf = StringIO() self.__wbuf = BufferIO()
self.__trans.write(zbuf + ztail) self.__trans.write(zbuf + ztail)
self.__trans.flush() self.__trans.flush()
@ -244,5 +244,5 @@ class TZlibTransport(TTransportBase, CReadableTransport):
retstring += self.read(self.DEFAULT_BUFFSIZE) retstring += self.read(self.DEFAULT_BUFFSIZE)
while len(retstring) < reqlen: while len(retstring) < reqlen:
retstring += self.read(reqlen - len(retstring)) retstring += self.read(reqlen - len(retstring))
self.__rbuf = StringIO(retstring) self.__rbuf = BufferIO(retstring)
return self.__rbuf return self.__rbuf

View file

@ -0,0 +1,100 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import sys
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
def legacy_validate_callback(cert, hostname):
    """legacy method to validate the peer's SSL certificate, and to check
    the commonName of the certificate to ensure it matches the hostname we
    used to make this connection. Does not support subjectAltName records
    in certificates.

    cert: mapping describing the peer certificate; its 'subject' entry is
        iterated and the first element of each tuple entry is read as a
        (key, value) pair.
    hostname: the hostname used to make the connection.

    Returns None when the first commonName found equals hostname.

    raises TTransportException if the certificate fails validation.
    """
    if 'subject' not in cert:
        raise TTransportException(
            TTransportException.NOT_OPEN,
            'No SSL certificate found from %s' % hostname)
    fields = cert['subject']
    for field in fields:
        # ensure structure we get back is what we expect; an empty entry
        # has no first element, so skip it instead of failing on field[0]
        if not isinstance(field, tuple) or not field:
            continue
        cert_pair = field[0]
        # skip anything that is not a (key, value, ...) sequence rather than
        # letting len() raise on an unexpected element type
        if not isinstance(cert_pair, (tuple, list)) or len(cert_pair) < 2:
            continue
        cert_key, cert_value = cert_pair[0:2]
        if cert_key != 'commonName':
            continue
        certhost = cert_value
        # this check should be performed by some sort of Access Manager
        if certhost == hostname:
            # success, cert commonName matches desired hostname
            return
        else:
            raise TTransportException(
                TTransportException.UNKNOWN,
                'Hostname we connected to "%s" doesn\'t match certificate '
                'provided commonName "%s"' % (hostname, certhost))
    # no commonName entry was found in the subject at all
    raise TTransportException(
        TTransportException.UNKNOWN,
        'Could not validate SSL certificate from host "%s". Cert=%s'
        % (hostname, cert))
def _optional_dependencies():
    """Probe optional modules and pick a hostname-matching callable.

    Returns a tuple (ipaddr, match):
      ipaddr -- True if the ipaddress module imported, unless it was reset
          to False because backports.ssl_match_hostname was too old or
          unavailable on an interpreter older than 3.5.
      match -- backports.ssl_match_hostname.match_hostname (old interpreters
          with a recent enough backport), else ssl.match_hostname, else
          legacy_validate_callback when neither can be imported.
    """
    try:
        import ipaddress  # noqa
        logger.debug('ipaddress module is available')
        ipaddr = True
    except ImportError:
        # Logger.warn is a deprecated alias that newer Python releases
        # no longer provide; Logger.warning is the supported spelling.
        logger.warning('ipaddress module is unavailable')
        ipaddr = False

    if sys.hexversion < 0x030500F0:
        try:
            from backports.ssl_match_hostname import match_hostname, __version__ as ver
            ver = list(map(int, ver.split('.')))
            logger.debug('backports.ssl_match_hostname module is available')
            match = match_hostname
            # major * 10 + minor >= 35, i.e. backport version 3.5 or newer
            if ver[0] * 10 + ver[1] >= 35:
                return ipaddr, match
            else:
                logger.warning('backports.ssl_match_hostname module is too old')
                ipaddr = False
        except ImportError:
            logger.warning('backports.ssl_match_hostname is unavailable')
            ipaddr = False
    try:
        from ssl import match_hostname
        logger.debug('ssl.match_hostname is available')
        match = match_hostname
    except ImportError:
        logger.warning('using legacy validation callback')
        match = legacy_validate_callback
    return ipaddr, match
# Probe the optional dependencies once at import time and keep the results:
# the ipaddress-availability flag and the hostname-matching callable.
_match_has_ipaddress, _match_hostname = _optional_dependencies()