diff --git a/awips/dataaccess/DataAccessLayer.py b/awips/dataaccess/DataAccessLayer.py index 6899cb9..58943ce 100644 --- a/awips/dataaccess/DataAccessLayer.py +++ b/awips/dataaccess/DataAccessLayer.py @@ -1,59 +1,33 @@ -# # -# This software was developed and / or modified by Raytheon Company, -# pursuant to Contract DG133W-05-CQ-1067 with the US Government. -# -# U.S. EXPORT CONTROLLED TECHNICAL DATA -# This software product contains export-restricted data whose -# export/transfer/disclosure is restricted by U.S. law. Dissemination -# to non-U.S. persons whether in the United States or abroad requires -# an export license or other authorization. -# -# Contractor Name: Raytheon Company -# Contractor Address: 6825 Pine Street, Suite 340 -# Mail Stop B8 -# Omaha, NE 68106 -# 402.291.0100 -# -# See the AWIPS II Master Rights File ("Master Rights File.pdf") for -# further licensing information. -# # - - # # Published interface for awips.dataaccess package # # # SOFTWARE HISTORY # -# Date Ticket# Engineer Description -# ------------ ---------- ----------- -------------------------- -# 12/10/12 njensen Initial Creation. -# Feb 14, 2013 1614 bsteffen refactor data access framework -# to use single request. -# 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args -# 05/29/13 2023 dgilling Hook up ThriftClientRouter. -# 03/03/14 2673 bsteffen Add ability to query only ref times. -# 07/22/14 3185 njensen Added optional/default args to newDataRequest -# 07/30/14 3185 njensen Renamed valid identifiers to optional -# Apr 26, 2015 4259 njensen Updated for new JEP API -# Apr 13, 2016 5379 tgurney Add getIdentifierValues() -# Jun 01, 2016 5587 tgurney Add new signatures for -# getRequiredIdentifiers() and -# getOptionalIdentifiers() -# Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon +# Date Ticket# Engineer Description +# ------------ ------- ---------- ------------------------- +# 12/10/12 njensen Initial Creation. +# Feb 14, 2013 1614 bsteffen refactor data access framework to use single request. +# 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args +# 05/29/13 2023 dgilling Hook up ThriftClientRouter. +# 03/03/14 2673 bsteffen Add ability to query only ref times. +# 07/22/14 3185 njensen Added optional/default args to newDataRequest +# 07/30/14 3185 njensen Renamed valid identifiers to optional +# Apr 26, 2015 4259 njensen Updated for new JEP API +# Apr 13, 2016 5379 tgurney Add getIdentifierValues(), getRequiredIdentifiers(), +# and getOptionalIdentifiers() +# Oct 07, 2016 ---- mjames@ucar Added getForecastRun +# Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon +# Oct 11, 2018 ---- mjames@ucar Added getMetarObs() getSynopticObs() # -# - import sys -import subprocess import warnings THRIFT_HOST = "edex" USING_NATIVE_THRIFT = False - if 'jep' in sys.modules: # intentionally do not catch if this fails to import, we want it to # be obvious that something is configured wrong when running from within @@ -66,6 +40,147 @@ else: USING_NATIVE_THRIFT = True +def getRadarProductIDs(availableParms): + """ + Get only the numeric idetifiers for NEXRAD3 products. + + Args: + availableParms: Full list of radar parameters + + Returns: + List of filtered parameters + """ + productIDs = [] + for p in list(availableParms): + try: + if isinstance(int(p), int): + productIDs.append(str(p)) + except ValueError: + pass + + return productIDs + + +def getRadarProductNames(availableParms): + """ + Get only the named idetifiers for NEXRAD3 products. 
+ + Args: + availableParms: Full list of radar parameters + + Returns: + List of filtered parameters + """ + productNames = [] + for p in list(availableParms): + if len(p) > 3: + productNames.append(p) + + return productNames + + +def getMetarObs(response): + """ + Processes a DataAccessLayer "obs" response into a dictionary, + with special consideration for multi-value parameters + "presWeather", "skyCover", and "skyLayerBase". + + Args: + response: DAL getGeometry() list + + Returns: + A dictionary of METAR obs + """ + from datetime import datetime + single_val_params = ["timeObs", "stationName", "longitude", "latitude", + "temperature", "dewpoint", "windDir", + "windSpeed", "seaLevelPress"] + multi_val_params = ["presWeather", "skyCover", "skyLayerBase"] + params = single_val_params + multi_val_params + station_names, pres_weather, sky_cov, sky_layer_base = [], [], [], [] + obs = dict({params: [] for params in params}) + for ob in response: + avail_params = ob.getParameters() + if "presWeather" in avail_params: + pres_weather.append(ob.getString("presWeather")) + elif "skyCover" in avail_params and "skyLayerBase" in avail_params: + sky_cov.append(ob.getString("skyCover")) + sky_layer_base.append(ob.getNumber("skyLayerBase")) + else: + # If we already have a record for this stationName, skip + if ob.getString('stationName') not in station_names: + station_names.append(ob.getString('stationName')) + for param in single_val_params: + if param in avail_params: + if param == 'timeObs': + obs[param].append(datetime.fromtimestamp(ob.getNumber(param) / 1000.0)) + else: + try: + obs[param].append(ob.getNumber(param)) + except TypeError: + obs[param].append(ob.getString(param)) + else: + obs[param].append(None) + + obs['presWeather'].append(pres_weather) + obs['skyCover'].append(sky_cov) + obs['skyLayerBase'].append(sky_layer_base) + pres_weather = [] + sky_cov = [] + sky_layer_base = [] + return obs + + +def getSynopticObs(response): + """ + Processes a DataAccessLayer "sfcobs" response into a dictionary + of available parameters. + + Args: + response: DAL getGeometry() list + + Returns: + A dictionary of synop obs + """ + from datetime import datetime + station_names = [] + params = response[0].getParameters() + sfcobs = dict({params: [] for params in params}) + for sfcob in response: + # If we already have a record for this stationId, skip + if sfcob.getString('stationId') not in station_names: + station_names.append(sfcob.getString('stationId')) + for param in params: + if param == 'timeObs': + sfcobs[param].append(datetime.fromtimestamp(sfcob.getNumber(param) / 1000.0)) + else: + try: + sfcobs[param].append(sfcob.getNumber(param)) + except TypeError: + sfcobs[param].append(sfcob.getString(param)) + + return sfcobs + + +def getForecastRun(cycle, times): + """ + Get the latest forecast run (list of objects) from all + all cycles and times returned from DataAccessLayer "grid" + response. 
+ + Args: + cycle: Forecast cycle reference time + times: All available times/cycles + + Returns: + DataTime array for a single forecast run + """ + fcstRun = [] + for t in times: + if str(t)[:19] == str(cycle): + fcstRun.append(t) + return fcstRun + def getAvailableTimes(request, refTimeOnly=False): """ @@ -74,7 +189,7 @@ def getAvailableTimes(request, refTimeOnly=False): Args: request: the IDataRequest to get data for refTimeOnly: optional, use True if only unique refTimes should be - returned (without a forecastHr) + returned (without a forecastHr) Returns: a list of DataTimes @@ -91,7 +206,7 @@ def getGridData(request, times=[]): Args: request: the IDataRequest to get data for times: a list of DataTimes, a TimeRange, or None if the data is time - agnostic + agnostic Returns: a list of IGridData @@ -108,10 +223,10 @@ def getGeometryData(request, times=[]): Args: request: the IDataRequest to get data for times: a list of DataTimes, a TimeRange, or None if the data is time - agnostic + agnostic Returns: - a list of IGeometryData + a list of IGeometryData """ return router.getGeometryData(request, times) @@ -204,8 +319,9 @@ def getIdentifierValues(request, identifierKey): """ return router.getIdentifierValues(request, identifierKey) + def newDataRequest(datatype=None, **kwargs): - """" + """ Creates a new instance of IDataRequest suitable for the runtime environment. All args are optional and exist solely for convenience. @@ -215,13 +331,14 @@ def newDataRequest(datatype=None, **kwargs): levels: a list of levels to set on the request locationNames: a list of locationNames to set on the request envelope: an envelope to limit the request - **kwargs: any leftover kwargs will be set as identifiers + kwargs: any leftover kwargs will be set as identifiers Returns: a new IDataRequest """ return router.newDataRequest(datatype, **kwargs) + def getSupportedDatatypes(): """ Gets the datatypes that are supported by the framework @@ -239,7 +356,7 @@ def changeEDEXHost(newHostName): method will throw a TypeError. Args: - newHostHame: the EDEX host to connect to + newHostName: the EDEX host to connect to """ if USING_NATIVE_THRIFT: global THRIFT_HOST @@ -249,6 +366,7 @@ def changeEDEXHost(newHostName): else: raise TypeError("Cannot call changeEDEXHost when using JepRouter.") + def setLazyLoadGridLatLon(lazyLoadGridLatLon): """ Provide a hint to the Data Access Framework indicating whether to load the @@ -261,7 +379,7 @@ def setLazyLoadGridLatLon(lazyLoadGridLatLon): set to False if it is guaranteed that all lat/lon information is needed and it would be better to get any performance overhead for generating the lat/lon data out of the way during the initial request. - + Args: lazyLoadGridLatLon: Boolean value indicating whether to lazy load. diff --git a/thrift/TMultiplexedProcessor.py b/thrift/TMultiplexedProcessor.py new file mode 100644 index 0000000..ff88430 --- /dev/null +++ b/thrift/TMultiplexedProcessor.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
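Usage note (not part of the patch): the new getMetarObs() helper in DataAccessLayer.py above post-processes a getGeometryData() response into per-parameter lists. A minimal sketch, assuming the Unidata cloud EDEX host and station KBJC purely for illustration:

    from awips.dataaccess import DataAccessLayer

    DataAccessLayer.changeEDEXHost("edex-cloud.unidata.ucar.edu")  # assumed host
    request = DataAccessLayer.newDataRequest("obs")
    single = ["timeObs", "stationName", "longitude", "latitude", "temperature",
              "dewpoint", "windDir", "windSpeed", "seaLevelPress"]
    multi = ["presWeather", "skyCover", "skyLayerBase"]
    request.setParameters(*(single + multi))
    request.setLocationNames("KBJC")                               # assumed station

    times = DataAccessLayer.getAvailableTimes(request)
    response = DataAccessLayer.getGeometryData(request, [times[-1]])

    # One list per parameter; multi-value params are nested per station.
    obs = DataAccessLayer.getMetarObs(response)
    print(obs["stationName"], obs["temperature"], obs["skyCover"])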
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import TProcessor, TMessageType +from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol +from thrift.protocol.TProtocol import TProtocolException + + +class TMultiplexedProcessor(TProcessor): + def __init__(self): + self.defaultProcessor = None + self.services = {} + + def registerDefault(self, processor): + """ + If a non-multiplexed processor connects to the server and wants to + communicate, use the given processor to handle it. This mechanism + allows servers to upgrade from non-multiplexed to multiplexed in a + backwards-compatible way and still handle old clients. + """ + self.defaultProcessor = processor + + def registerProcessor(self, serviceName, processor): + self.services[serviceName] = processor + + def on_message_begin(self, func): + for key in self.services.keys(): + self.services[key].on_message_begin(func) + + def process(self, iprot, oprot): + (name, type, seqid) = iprot.readMessageBegin() + if type != TMessageType.CALL and type != TMessageType.ONEWAY: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "TMultiplexedProtocol only supports CALL & ONEWAY") + + index = name.find(TMultiplexedProtocol.SEPARATOR) + if index < 0: + if self.defaultProcessor: + return self.defaultProcessor.process( + StoredMessageProtocol(iprot, (name, type, seqid)), oprot) + else: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "Service name not found in message name: " + name + ". " + + "Did you forget to use TMultiplexedProtocol in your client?") + + serviceName = name[0:index] + call = name[index + len(TMultiplexedProtocol.SEPARATOR):] + if serviceName not in self.services: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "Service name not found: " + serviceName + ". " + + "Did you forget to call registerProcessor()?") + + standardMessage = (call, type, seqid) + return self.services[serviceName].process( + StoredMessageProtocol(iprot, standardMessage), oprot) + + +class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, messageBegin): + self.messageBegin = messageBegin + + def readMessageBegin(self): + return self.messageBegin diff --git a/thrift/TRecursive.py b/thrift/TRecursive.py new file mode 100644 index 0000000..abf202c --- /dev/null +++ b/thrift/TRecursive.py @@ -0,0 +1,83 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
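Usage note (not part of the patch): TMultiplexedProcessor above is the server-side half of Thrift service multiplexing; the client wraps its protocol in TMultiplexedProtocol with the same service name. A minimal sketch, assuming the companion TMultiplexedProtocol module is available alongside this file; the service names are hypothetical and TProcessor() stands in for a generated Service.Processor(handler):

    from thrift.Thrift import TProcessor
    from thrift.TMultiplexedProcessor import TMultiplexedProcessor
    from thrift.protocol import TBinaryProtocol, TMultiplexedProtocol
    from thrift.transport import TTransport

    # Server side: one processor per service, registered under a unique name.
    multiplexed = TMultiplexedProcessor()
    multiplexed.registerProcessor("Calculator", TProcessor())  # hypothetical name
    multiplexed.registerProcessor("Scanner", TProcessor())     # hypothetical name
    # `multiplexed` is then handed to a TServer like any single-service processor.

    # Client side: the wrapper prefixes "Calculator:" onto each method name so
    # process() above can route the call to the matching registered processor.
    transport = TTransport.TMemoryBuffer()                      # placeholder transport
    base = TBinaryProtocol.TBinaryProtocol(transport)
    calc_protocol = TMultiplexedProtocol.TMultiplexedProtocol(base, "Calculator")
    # calc_client = Calculator.Client(calc_protocol)            # generated stub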
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from thrift.Thrift import TType + +TYPE_IDX = 1 +SPEC_ARGS_IDX = 3 +SPEC_ARGS_CLASS_REF_IDX = 0 +SPEC_ARGS_THRIFT_SPEC_IDX = 1 + + +def fix_spec(all_structs): + """Wire up recursive references for all TStruct definitions inside of each thrift_spec.""" + for struc in all_structs: + spec = struc.thrift_spec + for thrift_spec in spec: + if thrift_spec is None: + continue + elif thrift_spec[TYPE_IDX] == TType.STRUCT: + other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec + thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other + elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET): + _fix_list_or_set(thrift_spec[SPEC_ARGS_IDX]) + elif thrift_spec[TYPE_IDX] == TType.MAP: + _fix_map(thrift_spec[SPEC_ARGS_IDX]) + + +def _fix_list_or_set(element_type): + # For a list or set, the thrift_spec entry looks like, + # (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1 + # so ``element_type`` will be, + # (TType.STRUCT, [RecList, None], False) + if element_type[0] == TType.STRUCT: + element_type[1][1] = element_type[1][0].thrift_spec + elif element_type[0] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[1]) + elif element_type[0] == TType.MAP: + _fix_map(element_type[1]) + + +def _fix_map(element_type): + # For a map of key -> value type, ``element_type`` will be, + # (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, ) + # which is just a normal struct definition. + # + # For a map of key -> list / set, ``element_type`` will be, + # (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False) + # and we need to process the 3rd element as a list. + # + # For a map of key -> map, ``element_type`` will be, + # (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT, + # [RecMapMap, None], False), False) + # and need to process 3rd element as a map. + + # Is the map key a struct? + if element_type[0] == TType.STRUCT: + element_type[1][1] = element_type[1][0].thrift_spec + elif element_type[0] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[1]) + elif element_type[0] == TType.MAP: + _fix_map(element_type[1]) + + # Is the map value a struct? 
+ if element_type[2] == TType.STRUCT: + element_type[3][1] = element_type[3][0].thrift_spec + elif element_type[2] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[3]) + elif element_type[2] == TType.MAP: + _fix_map(element_type[3]) diff --git a/thrift/TSCons.py b/thrift/TSCons.py index d3176ed..bc67d70 100644 --- a/thrift/TSCons.py +++ b/thrift/TSCons.py @@ -19,17 +19,18 @@ from os import path from SCons.Builder import Builder +from six.moves import map def scons_env(env, add=''): - opath = path.dirname(path.abspath('$TARGET')) - lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action=lstr) - env.Append(BUILDERS={'ThriftCpp': cppbuild}) + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) def gen_cpp(env, dir, file): - scons_env(env) - suffixes = ['_types.h', '_types.cpp'] - targets = ['gen-cpp/' + file + s for s in suffixes] - return env.ThriftCpp(targets, dir + file + '.thrift') + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/thrift/TTornado.py b/thrift/TTornado.py new file mode 100644 index 0000000..5eff11d --- /dev/null +++ b/thrift/TTornado.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
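Usage note (not part of the patch): TRecursive.fix_spec() above is called once by generated modules after all struct classes exist, so that self-referential thrift_spec entries can point back at the right class. A hand-written stand-in for that generated pattern:

    from thrift.Thrift import TType
    from thrift.TRecursive import fix_spec

    class RecList(object):
        thrift_spec = None  # filled in below, as generated code does

    RecList.thrift_spec = (
        None,  # field id 0 is unused
        (1, TType.LIST, 'items', (TType.STRUCT, [RecList, None], False), None,),
    )

    all_structs = [RecList]
    fix_spec(all_structs)  # wires RecList.thrift_spec into the nested [RecList, None] slot
    del all_structs
    assert RecList.thrift_spec[1][3][1][1] is RecList.thrift_spec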
+# + +from __future__ import absolute_import +import logging +import socket +import struct + +from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer + +from io import BytesIO +from collections import deque +from contextlib import contextmanager +from tornado import gen, iostream, ioloop, tcpserver, concurrent + +__all__ = ['TTornadoServer', 'TTornadoStreamTransport'] + +logger = logging.getLogger(__name__) + + +class _Lock(object): + def __init__(self): + self._waiters = deque() + + def acquired(self): + return len(self._waiters) > 0 + + @gen.coroutine + def acquire(self): + blocker = self._waiters[-1] if self.acquired() else None + future = concurrent.Future() + self._waiters.append(future) + if blocker: + yield blocker + + raise gen.Return(self._lock_context()) + + def release(self): + assert self.acquired(), 'Lock not aquired' + future = self._waiters.popleft() + future.set_result(None) + + @contextmanager + def _lock_context(self): + try: + yield + finally: + self.release() + + +class TTornadoStreamTransport(TTransportBase): + """a framed, buffered transport over a Tornado stream""" + def __init__(self, host, port, stream=None, io_loop=None): + self.host = host + self.port = port + self.io_loop = io_loop or ioloop.IOLoop.current() + self.__wbuf = BytesIO() + self._read_lock = _Lock() + + # servers provide a ready-to-go stream + self.stream = stream + + def with_timeout(self, timeout, future): + return gen.with_timeout(timeout, future, self.io_loop) + + @gen.coroutine + def open(self, timeout=None): + logger.debug('socket connecting') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.stream = iostream.IOStream(sock) + + try: + connect = self.stream.connect((self.host, self.port)) + if timeout is not None: + yield self.with_timeout(timeout, connect) + else: + yield connect + except (socket.error, IOError, ioloop.TimeoutError) as e: + message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e) + raise TTransportException( + type=TTransportException.NOT_OPEN, + message=message) + + raise gen.Return(self) + + def set_close_callback(self, callback): + """ + Should be called only after open() returns + """ + self.stream.set_close_callback(callback) + + def close(self): + # don't raise if we intend to close + self.stream.set_close_callback(None) + self.stream.close() + + def read(self, _): + # The generated code for Tornado shouldn't do individual reads -- only + # frames at a time + assert False, "you're doing it wrong" + + @contextmanager + def io_exception_context(self): + try: + yield + except (socket.error, IOError) as e: + raise TTransportException( + type=TTransportException.END_OF_FILE, + message=str(e)) + except iostream.StreamBufferFullError as e: + raise TTransportException( + type=TTransportException.UNKNOWN, + message=str(e)) + + @gen.coroutine + def readFrame(self): + # IOStream processes reads one at a time + with (yield self._read_lock.acquire()): + with self.io_exception_context(): + frame_header = yield self.stream.read_bytes(4) + if len(frame_header) == 0: + raise iostream.StreamClosedError('Read zero bytes from stream') + frame_length, = struct.unpack('!i', frame_header) + frame = yield self.stream.read_bytes(frame_length) + raise gen.Return(frame) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + frame = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + frame_length = struct.pack('!i', len(frame)) + self.__wbuf = BytesIO() + with 
self.io_exception_context(): + return self.stream.write(frame_length + frame) + + +class TTornadoServer(tcpserver.TCPServer): + def __init__(self, processor, iprot_factory, oprot_factory=None, + *args, **kwargs): + super(TTornadoServer, self).__init__(*args, **kwargs) + + self._processor = processor + self._iprot_factory = iprot_factory + self._oprot_factory = (oprot_factory if oprot_factory is not None + else iprot_factory) + + @gen.coroutine + def handle_stream(self, stream, address): + host, port = address[:2] + trans = TTornadoStreamTransport(host=host, port=port, stream=stream, + io_loop=self.io_loop) + oprot = self._oprot_factory.getProtocol(trans) + + try: + while not trans.stream.closed(): + try: + frame = yield trans.readFrame() + except TTransportException as e: + if e.type == TTransportException.END_OF_FILE: + break + else: + raise + tr = TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + yield self._processor.process(iprot, oprot) + except Exception: + logger.exception('thrift exception in handle_stream') + trans.close() + + logger.info('client disconnected %s:%d', host, port) diff --git a/thrift/Thrift.py b/thrift/Thrift.py index 707a8cc..81fe8cf 100644 --- a/thrift/Thrift.py +++ b/thrift/Thrift.py @@ -17,141 +17,177 @@ # under the License. # -import sys + +class TType(object): + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + + _VALUES_TO_NAMES = ( + 'STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16', + ) -class TType: - STOP = 0 - VOID = 1 - BOOL = 2 - BYTE = 3 - I08 = 3 - DOUBLE = 4 - I16 = 6 - I32 = 8 - I64 = 10 - STRING = 11 - UTF7 = 11 - STRUCT = 12 - MAP = 13 - SET = 14 - LIST = 15 - UTF8 = 16 - UTF16 = 17 - - _VALUES_TO_NAMES = ('STOP', - 'VOID', - 'BOOL', - 'BYTE', - 'DOUBLE', - None, - 'I16', - None, - 'I32', - None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16') +class TMessageType(object): + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 -class TMessageType: - CALL = 1 - REPLY = 2 - EXCEPTION = 3 - ONEWAY = 4 +class TProcessor(object): + """Base class for processor, which works on two streams.""" + def process(self, iprot, oprot): + """ + Process a request. The normal behvaior is to have the + processor invoke the correct handler and then it is the + server's responsibility to write the response to oprot. + """ + pass -class TProcessor: - """Base class for procsessor, which works on two streams.""" - - def process(iprot, oprot): - pass + def on_message_begin(self, func): + """ + Install a callback that receives (name, type, seqid) + after the message header is read. 
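Usage note (not part of the patch): the TTornadoServer added above is started like any Tornado TCPServer. A minimal sketch, assuming a Tornado version compatible with the io_loop usage in handle_stream and an arbitrary port; TProcessor() again stands in for a generated processor:

    from tornado import ioloop
    from thrift.Thrift import TProcessor
    from thrift.TTornado import TTornadoServer
    from thrift.protocol import TBinaryProtocol

    server = TTornadoServer(TProcessor(), TBinaryProtocol.TBinaryProtocolFactory())
    server.listen(9090)                  # port is an assumption
    ioloop.IOLoop.current().start()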
+ """ + pass class TException(Exception): - """Base class for all thrift exceptions.""" + """Base class for all thrift exceptions.""" - # BaseException.message is deprecated in Python v[2.6,3.0) - if (2, 6, 0) <= sys.version_info < (3, 0): - def _get_message(self): - return self._message - - def _set_message(self, message): - self._message = message - message = property(_get_message, _set_message) - - def __init__(self, message=None): - Exception.__init__(self, message) - self.message = message + def __init__(self, message=None): + Exception.__init__(self, message) + super(TException, self).__setattr__("message", message) class TApplicationException(TException): - """Application level thrift exceptions.""" + """Application level thrift exceptions.""" - UNKNOWN = 0 - UNKNOWN_METHOD = 1 - INVALID_MESSAGE_TYPE = 2 - WRONG_METHOD_NAME = 3 - BAD_SEQUENCE_ID = 4 - MISSING_RESULT = 5 - INTERNAL_ERROR = 6 - PROTOCOL_ERROR = 7 + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type - def __str__(self): - if self.message: - return self.message - elif self.type == self.UNKNOWN_METHOD: - return 'Unknown method' - elif self.type == self.INVALID_MESSAGE_TYPE: - return 'Invalid message type' - elif self.type == self.WRONG_METHOD_NAME: - return 'Wrong method name' - elif self.type == self.BAD_SEQUENCE_ID: - return 'Bad sequence ID' - elif self.type == self.MISSING_RESULT: - return 'Missing result' - else: - return 'Default (unknown) TApplicationException' - - def read(self, iprot): - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.message = iprot.readString() + def __str__(self): + if self.message: + return self.message + elif self.type == self.UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == self.INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == self.WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == self.BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == self.MISSING_RESULT: + return 'Missing result' + elif self.type == self.INTERNAL_ERROR: + return 'Internal error' + elif self.type == self.PROTOCOL_ERROR: + return 'Protocol error' + elif self.type == self.INVALID_TRANSFORM: + return 'Invalid transform' + elif self.type == self.INVALID_PROTOCOL: + return 'Invalid protocol' + elif self.type == self.UNSUPPORTED_CLIENT_TYPE: + return 'Unsupported client type' else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.type = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() + return 'Default (unknown) TApplicationException' - def write(self, oprot): - oprot.writeStructBegin('TApplicationException') - if self.message is not None: - oprot.writeFieldBegin('message', TType.STRING, 1) - oprot.writeString(self.message) - oprot.writeFieldEnd() - if self.type is not None: - oprot.writeFieldBegin('type', TType.I32, 2) - oprot.writeI32(self.type) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() + def read(self, iprot): + 
iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message is not None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type is not None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + +class TFrozenDict(dict): + """A dictionary that is "frozen" like a frozenset""" + + def __init__(self, *args, **kwargs): + super(TFrozenDict, self).__init__(*args, **kwargs) + # Sort the items so they will be in a consistent order. + # XOR in the hash of the class so we don't collide with + # the hash of a list of tuples. + self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) + + def __setitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __hash__(self): + return self.__hashval diff --git a/thrift/compat.py b/thrift/compat.py new file mode 100644 index 0000000..0e8271d --- /dev/null +++ b/thrift/compat.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +if sys.version_info[0] == 2: + + from cStringIO import StringIO as BufferIO + + def binary_to_str(bin_val): + return bin_val + + def str_to_binary(str_val): + return str_val + + def byte_index(bytes_val, i): + return ord(bytes_val[i]) + +else: + + from io import BytesIO as BufferIO # noqa + + def binary_to_str(bin_val): + return bin_val.decode('utf8') + + def str_to_binary(str_val): + return bytes(str_val, 'utf8') + + def byte_index(bytes_val, i): + return bytes_val[i] diff --git a/thrift/protocol/TBase.py b/thrift/protocol/TBase.py index 6cd6c28..6c6ef18 100644 --- a/thrift/protocol/TBase.py +++ b/thrift/protocol/TBase.py @@ -17,65 +17,70 @@ # under the License. 
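Usage note (not part of the patch): the TFrozenDict added above is a hashable, read-only dict, so const map values can be used as dictionary keys. For example:

    from thrift.Thrift import TFrozenDict

    headers = TFrozenDict({"retries": 3, "region": "west"})  # sample contents
    profiles = {headers: "profile-a"}   # usable as a key because it hashes
    try:
        headers["retries"] = 5          # any mutation raises TypeError
    except TypeError as exc:
        print(exc)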
# -from thrift.Thrift import * -from thrift.protocol import TBinaryProtocol from thrift.transport import TTransport -try: - from thrift.protocol import fastbinary -except: - fastbinary = None - class TBase(object): - __slots__ = [] + __slots__ = () - def __repr__(self): - L = ['%s=%r' % (key, getattr(self, key)) - for key in self.__slots__] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - for attr in self.__slots__: - my_val = getattr(self, attr) - other_val = getattr(other, attr) - if my_val != other_val: - return False - return True + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True - def __ne__(self, other): - return not (self == other) + def __ne__(self, other): + return not (self == other) - def read(self, iprot): - if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and - isinstance(iprot.trans, TTransport.CReadableTransport) and - self.thrift_spec is not None and - fastbinary is not None): - fastbinary.decode_binary(self, - iprot.trans, - (self.__class__, self.thrift_spec)) - return - iprot.readStruct(self, self.thrift_spec) + def read(self, iprot): + if (iprot._fast_decode is not None and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None): + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + else: + iprot.readStruct(self, self.thrift_spec) - def write(self, oprot): - if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and - self.thrift_spec is not None and - fastbinary is not None): - oprot.trans.write( - fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStruct(self, self.thrift_spec) + def write(self, oprot): + if (oprot._fast_encode is not None and self.thrift_spec is not None): + oprot.trans.write( + oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + else: + oprot.writeStruct(self, self.thrift_spec) -class TExceptionBase(Exception): - # old style class so python2.4 can raise exceptions derived from this - # This can't inherit from TBase because of that limitation. 
- __slots__ = [] +class TExceptionBase(TBase, Exception): + pass - __repr__ = TBase.__repr__.__func__ - __eq__ = TBase.__eq__.__func__ - __ne__ = TBase.__ne__.__func__ - read = TBase.read.__func__ - write = TBase.write.__func__ + +class TFrozenBase(TBase): + def __setitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __hash__(self, *args): + return hash(self.__class__) ^ hash(self.__slots__) + + @classmethod + def read(cls, iprot): + if (iprot._fast_decode is not None and + isinstance(iprot.trans, TTransport.CReadableTransport) and + cls.thrift_spec is not None): + self = cls() + return iprot._fast_decode(None, iprot, + [self.__class__, self.thrift_spec]) + else: + return iprot.readStruct(cls, cls.thrift_spec, True) + + +class TFrozenExceptionBase(TFrozenBase, TExceptionBase): + pass diff --git a/thrift/protocol/TBinaryProtocol.py b/thrift/protocol/TBinaryProtocol.py index dbcb1e9..6b2facc 100644 --- a/thrift/protocol/TBinaryProtocol.py +++ b/thrift/protocol/TBinaryProtocol.py @@ -17,248 +17,285 @@ # under the License. # -from .TProtocol import * +from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory from struct import pack, unpack class TBinaryProtocol(TProtocolBase): - """Binary implementation of the Thrift protocol driver.""" + """Binary implementation of the Thrift protocol driver.""" - # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be - # positive, converting this into a long. If we hardcode the int value - # instead it'll stay in 32 bit-land. + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. 
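Usage note (not part of the patch): TBase.read/write above now dispatch through the protocol's _fast_decode/_fast_encode hooks and otherwise fall back to readStruct/writeStruct. At the primitive level the same round-trip looks like this, using an in-memory transport and the writeBinary/readBinary methods added further down in this hunk:

    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol

    write_buf = TTransport.TMemoryBuffer()
    oprot = TBinaryProtocol.TBinaryProtocol(write_buf)
    oprot.writeI32(42)
    oprot.writeBinary(b"edex")
    payload = write_buf.getvalue()

    read_buf = TTransport.TMemoryBuffer(payload)
    iprot = TBinaryProtocol.TBinaryProtocol(read_buf)
    assert iprot.readI32() == 42
    assert iprot.readBinary() == b"edex"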
- # VERSION_MASK = 0xffff0000 - VERSION_MASK = -65536 + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 - # VERSION_1 = 0x80010000 - VERSION_1 = -2147418112 + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 - TYPE_MASK = 0x000000ff + TYPE_MASK = 0x000000ff - def __init__(self, trans, strictRead=False, strictWrite=True): - TProtocolBase.__init__(self, trans) - self.strictRead = strictRead - self.strictWrite = strictWrite + def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) - def writeMessageBegin(self, name, type, seqid): - if self.strictWrite: - self.writeI32(TBinaryProtocol.VERSION_1 | type) - self.writeString(name) - self.writeI32(seqid) - else: - self.writeString(name) - self.writeByte(type) - self.writeI32(seqid) + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) - def writeMessageEnd(self): - pass + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) - def writeStructBegin(self, name): - pass + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) - def writeStructEnd(self): - pass + def writeMessageEnd(self): + pass - def writeFieldBegin(self, name, type, id): - self.writeByte(type) - self.writeI16(id) + def writeStructBegin(self, name): + pass - def writeFieldEnd(self): - pass + def writeStructEnd(self): + pass - def writeFieldStop(self): - self.writeByte(TType.STOP) + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) - def writeMapBegin(self, ktype, vtype, size): - self.writeByte(ktype) - self.writeByte(vtype) - self.writeI32(size) + def writeFieldEnd(self): + pass - def writeMapEnd(self): - pass + def writeFieldStop(self): + self.writeByte(TType.STOP) - def writeListBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) - def writeListEnd(self): - pass + def writeMapEnd(self): + pass - def writeSetBegin(self, etype, size): - self.writeByte(etype) - self.writeI32(size) + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) - def writeSetEnd(self): - pass + def writeListEnd(self): + pass - def writeBool(self, bool): - if bool: - self.writeByte(1) - else: - self.writeByte(0) + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) - def writeByte(self, byte): - buff = pack("!b", byte) - self.trans.write(buff) + def writeSetEnd(self): + pass - def writeI16(self, i16): - buff = pack("!h", i16) - self.trans.write(buff) + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) - def writeI32(self, i32): - buff = pack("!i", i32) - self.trans.write(buff) + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) - def writeI64(self, i64): - buff = pack("!q", i64) - self.trans.write(buff) + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) - def writeDouble(self, dub): - buff = pack("!d", dub) - self.trans.write(buff) + def 
writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) - def writeString(self, str): - self.writeI32(len(str)) - self.trans.write(str) + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) - def readMessageBegin(self): - sz = self.readI32() - if sz < 0: - version = sz & TBinaryProtocol.VERSION_MASK - if version != TBinaryProtocol.VERSION_1: - raise TProtocolException( - type=TProtocolException.BAD_VERSION, - message='Bad version in readMessageBegin: %d' % (sz)) - type = sz & TBinaryProtocol.TYPE_MASK - name = self.readString() - seqid = self.readI32() - else: - if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, - message='No protocol version header') - name = self.trans.readAll(sz) - type = self.readByte() - seqid = self.readI32() - return (name, type, seqid) + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) - def readMessageEnd(self): - pass + def writeBinary(self, str): + self.writeI32(len(str)) + self.trans.write(str) - def readStructBegin(self): - pass + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) - def readStructEnd(self): - pass + def readMessageEnd(self): + pass - def readFieldBegin(self): - type = self.readByte() - if type == TType.STOP: - return (None, type, 0) - id = self.readI16() - return (None, type, id) + def readStructBegin(self): + pass - def readFieldEnd(self): - pass + def readStructEnd(self): + pass - def readMapBegin(self): - ktype = self.readByte() - vtype = self.readByte() - size = self.readI32() - return (ktype, vtype, size) + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) - def readMapEnd(self): - pass + def readFieldEnd(self): + pass - def readListBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (ktype, vtype, size) - def readListEnd(self): - pass + def readMapEnd(self): + pass - def readSetBegin(self): - etype = self.readByte() - size = self.readI32() - return (etype, size) + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) - def readSetEnd(self): - pass + def readListEnd(self): + pass - def readBool(self): - byte = self.readByte() - if byte == 0: - return False - return True + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) - def readByte(self): - buff = self.trans.readAll(1) - val, = unpack('!b', buff) - return val + def readSetEnd(self): + pass - def readI16(self): - buff = self.trans.readAll(2) - val, = unpack('!h', buff) - return val + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True - def readI32(self): - buff 
= self.trans.readAll(4) - try: - val, = unpack('!i', buff) - except TypeError: - #str does not support the buffer interface - val, = unpack('!i', buff) - return val + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val - def readI64(self): - buff = self.trans.readAll(8) - val, = unpack('!q', buff) - return val + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val - def readString(self): - len = self.readI32() - str = self.trans.readAll(len) - return str + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readBinary(self): + size = self.readI32() + self._check_string_length(size) + s = self.trans.readAll(size) + return s -class TBinaryProtocolFactory: - def __init__(self, strictRead=False, strictWrite=True): - self.strictRead = strictRead - self.strictWrite = strictWrite +class TBinaryProtocolFactory(TProtocolFactory): + def __init__(self, strictRead=False, strictWrite=True, **kwargs): + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) - def getProtocol(self, trans): - prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) - return prot + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) + return prot class TBinaryProtocolAccelerated(TBinaryProtocol): - """C-Accelerated version of TBinaryProtocol. + """C-Accelerated version of TBinaryProtocol. - This class does not override any of TBinaryProtocol's methods, - but the generated code recognizes it directly and will call into - our C module to do the encoding, bypassing this object entirely. - We inherit from TBinaryProtocol so that the normal TBinaryProtocol - encoding can happen if the fastbinary module doesn't work for some - reason. (TODO(dreiss): Make this happen sanely in more cases.) + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + To disable this behavior, pass fallback=False constructor argument. - In order to take advantage of the C module, just use - TBinaryProtocolAccelerated instead of TBinaryProtocol. + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. - NOTE: This code was contributed by an external developer. - The internal Thrift team has reviewed and tested it, - but we cannot guarantee that it is production-ready. - Please feel free to report bugs and/or success stories - to the public mailing list. - """ - pass + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. 
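Usage note (not part of the patch): the string_length_limit / container_length_limit kwargs threaded through TBinaryProtocolFactory above let a reader reject oversized payloads. A small sketch with an arbitrary 16-byte cap, assuming the matching _check_length support lands in TProtocol.py elsewhere in this change:

    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol
    from thrift.protocol.TProtocol import TProtocolException

    buf = TTransport.TMemoryBuffer()
    TBinaryProtocol.TBinaryProtocol(buf).writeBinary(b"x" * 64)   # 64-byte string

    factory = TBinaryProtocol.TBinaryProtocolFactory(string_length_limit=16)
    limited = factory.getProtocol(TTransport.TMemoryBuffer(buf.getvalue()))
    try:
        limited.readBinary()
    except TProtocolException as exc:
        print("rejected oversized string:", exc)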
+ Please feel free to report bugs and/or success stories + to the public mailing list. + """ + pass + + def __init__(self, *args, **kwargs): + fallback = kwargs.pop('fallback', True) + super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs) + try: + from thrift.protocol import fastbinary + except ImportError: + if not fallback: + raise + else: + self._fast_decode = fastbinary.decode_binary + self._fast_encode = fastbinary.encode_binary -class TBinaryProtocolAcceleratedFactory: - def getProtocol(self, trans): - return TBinaryProtocolAccelerated(trans) +class TBinaryProtocolAcceleratedFactory(TProtocolFactory): + def __init__(self, + string_length_limit=None, + container_length_limit=None, + fallback=True): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + self._fallback = fallback + + def getProtocol(self, trans): + return TBinaryProtocolAccelerated( + trans, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit, + fallback=self._fallback) diff --git a/thrift/protocol/TCompactProtocol.py b/thrift/protocol/TCompactProtocol.py index a3385e1..700e792 100644 --- a/thrift/protocol/TCompactProtocol.py +++ b/thrift/protocol/TCompactProtocol.py @@ -17,9 +17,11 @@ # under the License. # -from .TProtocol import * +from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits from struct import pack, unpack +from ..compat import binary_to_str, str_to_binary + __all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] CLEAR = 0 @@ -34,370 +36,452 @@ BOOL_READ = 8 def make_helper(v_from, container): - def helper(func): - def nested(self, *args, **kwargs): - assert self.state in (v_from, container), (self.state, v_from, container) - return func(self, *args, **kwargs) - return nested - return helper + def helper(func): + def nested(self, *args, **kwargs): + assert self.state in (v_from, container), (self.state, v_from, container) + return func(self, *args, **kwargs) + return nested + return helper + + writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) reader = make_helper(VALUE_READ, CONTAINER_READ) def makeZigZag(n, bits): - return (n << 1) ^ (n >> (bits - 1)) + checkIntegerLimits(n, bits) + return (n << 1) ^ (n >> (bits - 1)) def fromZigZag(n): - return (n >> 1) ^ -(n & 1) + return (n >> 1) ^ -(n & 1) def writeVarint(trans, n): - out = [] - while True: - if n & ~0x7f == 0: - out.append(n) - break - else: - out.append((n & 0xff) | 0x80) - n = n >> 7 - trans.write(''.join(map(chr, out))) + assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!" 
+ out = bytearray() + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + trans.write(bytes(out)) def readVarint(trans): - result = 0 - shift = 0 - while True: - x = trans.readAll(1) - byte = ord(x) - result |= (byte & 0x7f) << shift - if byte >> 7 == 0: - return result - shift += 7 + result = 0 + shift = 0 + while True: + x = trans.readAll(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 -class CompactType: - STOP = 0x00 - TRUE = 0x01 - FALSE = 0x02 - BYTE = 0x03 - I16 = 0x04 - I32 = 0x05 - I64 = 0x06 - DOUBLE = 0x07 - BINARY = 0x08 - LIST = 0x09 - SET = 0x0A - MAP = 0x0B - STRUCT = 0x0C +class CompactType(object): + STOP = 0x00 + TRUE = 0x01 + FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C -CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection - TType.BYTE: CompactType.BYTE, - TType.I16: CompactType.I16, - TType.I32: CompactType.I32, - TType.I64: CompactType.I64, - TType.DOUBLE: CompactType.DOUBLE, - TType.STRING: CompactType.BINARY, - TType.STRUCT: CompactType.STRUCT, - TType.LIST: CompactType.LIST, - TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP - } + +CTYPES = { + TType.STOP: CompactType.STOP, + TType.BOOL: CompactType.TRUE, # used for collection + TType.BYTE: CompactType.BYTE, + TType.I16: CompactType.I16, + TType.I32: CompactType.I32, + TType.I64: CompactType.I64, + TType.DOUBLE: CompactType.DOUBLE, + TType.STRING: CompactType.BINARY, + TType.STRUCT: CompactType.STRUCT, + TType.LIST: CompactType.LIST, + TType.SET: CompactType.SET, + TType.MAP: CompactType.MAP, +} TTYPES = {} -for k, v in list(CTYPES.items()): - TTYPES[v] = k +for k, v in CTYPES.items(): + TTYPES[v] = k TTYPES[CompactType.FALSE] = TType.BOOL del k del v class TCompactProtocol(TProtocolBase): - """Compact implementation of the Thrift protocol driver.""" + """Compact implementation of the Thrift protocol driver.""" - PROTOCOL_ID = 0x82 - VERSION = 1 - VERSION_MASK = 0x1f - TYPE_MASK = 0xe0 - TYPE_SHIFT_AMOUNT = 5 + PROTOCOL_ID = 0x82 + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xe0 + TYPE_BITS = 0x07 + TYPE_SHIFT_AMOUNT = 5 - def __init__(self, trans): - TProtocolBase.__init__(self, trans) - self.state = CLEAR - self.__last_fid = 0 - self.__bool_fid = None - self.__bool_value = None - self.__structs = [] - self.__containers = [] + def __init__(self, trans, + string_length_limit=None, + container_length_limit=None): + TProtocolBase.__init__(self, trans) + self.state = CLEAR + self.__last_fid = 0 + self.__bool_fid = None + self.__bool_value = None + self.__structs = [] + self.__containers = [] + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit - def __writeVarint(self, n): - writeVarint(self.trans, n) + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) - def writeMessageBegin(self, name, type, seqid): - assert self.state == CLEAR - self.__writeUByte(self.PROTOCOL_ID) - self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) - self.__writeVarint(seqid) - self.__writeString(name) - self.state = VALUE_WRITE + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) - def writeMessageEnd(self): - assert self.state == VALUE_WRITE - self.state = CLEAR + def __writeVarint(self, n): + 
writeVarint(self.trans, n) - def writeStructBegin(self, name): - assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_WRITE - self.__last_fid = 0 + def writeMessageBegin(self, name, type, seqid): + assert self.state == CLEAR + self.__writeUByte(self.PROTOCOL_ID) + self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) + # The sequence id is a signed 32-bit integer but the compact protocol + # writes this out as a "var int" which is always positive, and attempting + # to write a negative number results in an infinite loop, so we may + # need to do some conversion here... + tseqid = seqid + if tseqid < 0: + tseqid = 2147483648 + (2147483648 + tseqid) + self.__writeVarint(tseqid) + self.__writeBinary(str_to_binary(name)) + self.state = VALUE_WRITE - def writeStructEnd(self): - assert self.state == FIELD_WRITE - self.state, self.__last_fid = self.__structs.pop() + def writeMessageEnd(self): + assert self.state == VALUE_WRITE + self.state = CLEAR - def writeFieldStop(self): - self.__writeByte(0) + def writeStructBegin(self, name): + assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_WRITE + self.__last_fid = 0 - def __writeFieldHeader(self, type, fid): - delta = fid - self.__last_fid - if 0 < delta <= 15: - self.__writeUByte(delta << 4 | type) - else: - self.__writeByte(type) - self.__writeI16(fid) - self.__last_fid = fid + def writeStructEnd(self): + assert self.state == FIELD_WRITE + self.state, self.__last_fid = self.__structs.pop() - def writeFieldBegin(self, name, type, fid): - assert self.state == FIELD_WRITE, self.state - if type == TType.BOOL: - self.state = BOOL_WRITE - self.__bool_fid = fid - else: - self.state = VALUE_WRITE - self.__writeFieldHeader(CTYPES[type], fid) + def writeFieldStop(self): + self.__writeByte(0) - def writeFieldEnd(self): - assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state - self.state = FIELD_WRITE + def __writeFieldHeader(self, type, fid): + delta = fid - self.__last_fid + if 0 < delta <= 15: + self.__writeUByte(delta << 4 | type) + else: + self.__writeByte(type) + self.__writeI16(fid) + self.__last_fid = fid - def __writeUByte(self, byte): - self.trans.write(pack('!B', byte)) + def writeFieldBegin(self, name, type, fid): + assert self.state == FIELD_WRITE, self.state + if type == TType.BOOL: + self.state = BOOL_WRITE + self.__bool_fid = fid + else: + self.state = VALUE_WRITE + self.__writeFieldHeader(CTYPES[type], fid) - def __writeByte(self, byte): - self.trans.write(pack('!b', byte)) + def writeFieldEnd(self): + assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state + self.state = FIELD_WRITE - def __writeI16(self, i16): - self.__writeVarint(makeZigZag(i16, 16)) + def __writeUByte(self, byte): + self.trans.write(pack('!B', byte)) - def __writeSize(self, i32): - self.__writeVarint(i32) + def __writeByte(self, byte): + self.trans.write(pack('!b', byte)) - def writeCollectionBegin(self, etype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size <= 14: - self.__writeUByte(size << 4 | CTYPES[etype]) - else: - self.__writeUByte(0xf0 | CTYPES[etype]) - self.__writeSize(size) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE - writeSetBegin = writeCollectionBegin - writeListBegin = writeCollectionBegin + def __writeI16(self, i16): + self.__writeVarint(makeZigZag(i16, 16)) - def writeMapBegin(self, ktype, 
vtype, size): - assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state - if size == 0: - self.__writeByte(0) - else: - self.__writeSize(size) - self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) - self.__containers.append(self.state) - self.state = CONTAINER_WRITE + def __writeSize(self, i32): + self.__writeVarint(i32) - def writeCollectionEnd(self): - assert self.state == CONTAINER_WRITE, self.state - self.state = self.__containers.pop() - writeMapEnd = writeCollectionEnd - writeSetEnd = writeCollectionEnd - writeListEnd = writeCollectionEnd + def writeCollectionBegin(self, etype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size <= 14: + self.__writeUByte(size << 4 | CTYPES[etype]) + else: + self.__writeUByte(0xf0 | CTYPES[etype]) + self.__writeSize(size) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + writeSetBegin = writeCollectionBegin + writeListBegin = writeCollectionBegin - def writeBool(self, bool): - if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) - elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) - else: - raise AssertionError("Invalid state in compact protocol") + def writeMapBegin(self, ktype, vtype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size == 0: + self.__writeByte(0) + else: + self.__writeSize(size) + self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE - writeByte = writer(__writeByte) - writeI16 = writer(__writeI16) + def writeCollectionEnd(self): + assert self.state == CONTAINER_WRITE, self.state + self.state = self.__containers.pop() + writeMapEnd = writeCollectionEnd + writeSetEnd = writeCollectionEnd + writeListEnd = writeCollectionEnd - @writer - def writeI32(self, i32): - self.__writeVarint(makeZigZag(i32, 32)) + def writeBool(self, bool): + if self.state == BOOL_WRITE: + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) + elif self.state == CONTAINER_WRITE: + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) + else: + raise AssertionError("Invalid state in compact protocol") - @writer - def writeI64(self, i64): - self.__writeVarint(makeZigZag(i64, 64)) + writeByte = writer(__writeByte) + writeI16 = writer(__writeI16) - @writer - def writeDouble(self, dub): - self.trans.write(pack('!d', dub)) + @writer + def writeI32(self, i32): + self.__writeVarint(makeZigZag(i32, 32)) - def __writeString(self, s): - self.__writeSize(len(s)) - self.trans.write(s) - writeString = writer(__writeString) + @writer + def writeI64(self, i64): + self.__writeVarint(makeZigZag(i64, 64)) - def readFieldBegin(self): - assert self.state == FIELD_READ, self.state - type = self.__readUByte() - if type & 0x0f == TType.STOP: - return (None, 0, 0) - delta = type >> 4 - if delta == 0: - fid = self.__readI16() - else: - fid = self.__last_fid + delta - self.__last_fid = fid - type = type & 0x0f - if type == CompactType.TRUE: - self.state = BOOL_READ - self.__bool_value = True - elif type == CompactType.FALSE: - self.state = BOOL_READ - self.__bool_value = False - else: - self.state = VALUE_READ - return (None, self.__getTType(type), fid) + @writer + def writeDouble(self, dub): + self.trans.write(pack('> 4 + if delta == 0: + fid = 
self.__readI16() + else: + fid = self.__last_fid + delta + self.__last_fid = fid + type = type & 0x0f + if type == CompactType.TRUE: + self.state = BOOL_READ + self.__bool_value = True + elif type == CompactType.FALSE: + self.state = BOOL_READ + self.__bool_value = False + else: + self.state = VALUE_READ + return (None, self.__getTType(type), fid) - def __readByte(self): - result, = unpack('!b', self.trans.readAll(1)) - return result + def readFieldEnd(self): + assert self.state in (VALUE_READ, BOOL_READ), self.state + self.state = FIELD_READ - def __readVarint(self): - return readVarint(self.trans) + def __readUByte(self): + result, = unpack('!B', self.trans.readAll(1)) + return result - def __readZigZag(self): - return fromZigZag(self.__readVarint()) + def __readByte(self): + result, = unpack('!b', self.trans.readAll(1)) + return result - def __readSize(self): - result = self.__readVarint() - if result < 0: - raise TException("Length < 0") - return result + def __readVarint(self): + return readVarint(self.trans) - def readMessageBegin(self): - assert self.state == CLEAR - proto_id = self.__readUByte() - if proto_id != self.PROTOCOL_ID: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad protocol id in the message: %d' % proto_id) - ver_type = self.__readUByte() - type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT - version = ver_type & self.VERSION_MASK - if version != self.VERSION: - raise TProtocolException(TProtocolException.BAD_VERSION, - 'Bad version: %d (expect %d)' % (version, self.VERSION)) - seqid = self.__readVarint() - name = self.__readString() - return (name, type, seqid) + def __readZigZag(self): + return fromZigZag(self.__readVarint()) - def readMessageEnd(self): - assert self.state == CLEAR - assert len(self.__structs) == 0 + def __readSize(self): + result = self.__readVarint() + if result < 0: + raise TProtocolException("Length < 0") + return result - def readStructBegin(self): - assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state - self.__structs.append((self.state, self.__last_fid)) - self.state = FIELD_READ - self.__last_fid = 0 + def readMessageBegin(self): + assert self.state == CLEAR + proto_id = self.__readUByte() + if proto_id != self.PROTOCOL_ID: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad protocol id in the message: %d' % proto_id) + ver_type = self.__readUByte() + type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS + version = ver_type & self.VERSION_MASK + if version != self.VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad version: %d (expect %d)' % (version, self.VERSION)) + seqid = self.__readVarint() + # the sequence is a compact "var int" which is treaded as unsigned, + # however the sequence is actually signed... 
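A minimal sketch of the unsigned-to-signed fold that the next lines apply to seqid (illustrative only; the helper name and sample values are not part of the patch):

    # The compact varint is decoded as an unsigned 32-bit value; folding it
    # back into the signed range is plain two's-complement arithmetic.
    def fold_to_signed_i32(value):
        if value > 2147483647:
            value -= 4294967296  # same result as -2147483648 - (2147483648 - value)
        return value

    assert fold_to_signed_i32(4294967295) == -1
    assert fold_to_signed_i32(42) == 42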
+ if seqid > 2147483647: + seqid = -2147483648 - (2147483648 - seqid) + name = binary_to_str(self.__readBinary()) + return (name, type, seqid) - def readStructEnd(self): - assert self.state == FIELD_READ - self.state, self.__last_fid = self.__structs.pop() + def readMessageEnd(self): + assert self.state == CLEAR + assert len(self.__structs) == 0 - def readCollectionBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size_type = self.__readUByte() - size = size_type >> 4 - type = self.__getTType(size_type) - if size == 15: - size = self.__readSize() - self.__containers.append(self.state) - self.state = CONTAINER_READ - return type, size - readSetBegin = readCollectionBegin - readListBegin = readCollectionBegin + def readStructBegin(self): + assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_READ + self.__last_fid = 0 - def readMapBegin(self): - assert self.state in (VALUE_READ, CONTAINER_READ), self.state - size = self.__readSize() - types = 0 - if size > 0: - types = self.__readUByte() - vtype = self.__getTType(types) - ktype = self.__getTType(types >> 4) - self.__containers.append(self.state) - self.state = CONTAINER_READ - return (ktype, vtype, size) + def readStructEnd(self): + assert self.state == FIELD_READ + self.state, self.__last_fid = self.__structs.pop() - def readCollectionEnd(self): - assert self.state == CONTAINER_READ, self.state - self.state = self.__containers.pop() - readSetEnd = readCollectionEnd - readListEnd = readCollectionEnd - readMapEnd = readCollectionEnd + def readCollectionBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size_type = self.__readUByte() + size = size_type >> 4 + type = self.__getTType(size_type) + if size == 15: + size = self.__readSize() + self._check_container_length(size) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return type, size + readSetBegin = readCollectionBegin + readListBegin = readCollectionBegin - def readBool(self): - if self.state == BOOL_READ: - return self.__bool_value == CompactType.TRUE - elif self.state == CONTAINER_READ: - return self.__readByte() == CompactType.TRUE - else: - raise AssertionError("Invalid state in compact protocol: %d" % - self.state) + def readMapBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size = self.__readSize() + self._check_container_length(size) + types = 0 + if size > 0: + types = self.__readUByte() + vtype = self.__getTType(types) + ktype = self.__getTType(types >> 4) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return (ktype, vtype, size) - readByte = reader(__readByte) - __readI16 = __readZigZag - readI16 = reader(__readZigZag) - readI32 = reader(__readZigZag) - readI64 = reader(__readZigZag) + def readCollectionEnd(self): + assert self.state == CONTAINER_READ, self.state + self.state = self.__containers.pop() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + readMapEnd = readCollectionEnd - @reader - def readDouble(self): - buff = self.trans.readAll(8) - val, = unpack('!d', buff) - return val + def readBool(self): + if self.state == BOOL_READ: + return self.__bool_value == CompactType.TRUE + elif self.state == CONTAINER_READ: + return self.__readByte() == CompactType.TRUE + else: + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) - def __readString(self): - len = self.__readSize() - return self.trans.readAll(len) - 
readString = reader(__readString) + readByte = reader(__readByte) + __readI16 = __readZigZag + readI16 = reader(__readZigZag) + readI32 = reader(__readZigZag) + readI64 = reader(__readZigZag) - def __getTType(self, byte): - return TTYPES[byte & 0x0f] + @reader + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('= 0xd800 and codeunit <= 0xdbff + + def _isLowSurrogate(self, codeunit): + return codeunit >= 0xdc00 and codeunit <= 0xdfff + + def _toChar(self, high, low=None): + if not low: + if sys.version_info[0] == 2: + return ("\\u%04x" % high).decode('unicode-escape') \ + .encode('utf-8') + else: + return chr(high) + else: + codepoint = (1 << 16) + ((high & 0x3ff) << 10) + codepoint += low & 0x3ff + if sys.version_info[0] == 2: + s = "\\U%08x" % codepoint + return s.decode('unicode-escape').encode('utf-8') + else: + return chr(codepoint) + + def readJSONString(self, skipContext): + highSurrogate = None + string = [] + if skipContext is False: + self.context.read() + self.readJSONSyntaxChar(QUOTE) + while True: + character = self.reader.read() + if character == QUOTE: + break + if ord(character) == ESCSEQ0: + character = self.reader.read() + if ord(character) == ESCSEQ1: + character = self.trans.read(4).decode('ascii') + codeunit = int(character, 16) + if self._isHighSurrogate(codeunit): + if highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected low surrogate char") + highSurrogate = codeunit + continue + elif self._isLowSurrogate(codeunit): + if not highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected high surrogate char") + character = self._toChar(highSurrogate, codeunit) + highSurrogate = None + else: + character = self._toChar(codeunit) + else: + if character not in ESCAPE_CHARS: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected control char") + character = ESCAPE_CHARS[character] + elif character in ESCAPE_CHAR_VALS: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Unescaped control char") + elif sys.version_info[0] > 2: + utf8_bytes = bytearray([ord(character)]) + while ord(self.reader.peek()) >= 0x80: + utf8_bytes.append(ord(self.reader.read())) + character = utf8_bytes.decode('utf8') + string.append(character) + + if highSurrogate: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Expected low surrogate char") + return ''.join(string) + + def isJSONNumeric(self, character): + return (True if NUMERIC_CHAR.find(character) != - 1 else False) + + def readJSONQuotes(self): + if (self.context.escapeNum()): + self.readJSONSyntaxChar(QUOTE) + + def readJSONNumericChars(self): + numeric = [] + while True: + character = self.reader.peek() + if self.isJSONNumeric(character) is False: + break + numeric.append(self.reader.read()) + return b''.join(numeric).decode('ascii') + + def readJSONInteger(self): + self.context.read() + self.readJSONQuotes() + numeric = self.readJSONNumericChars() + self.readJSONQuotes() + try: + return int(numeric) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONDouble(self): + self.context.read() + if self.reader.peek() == QUOTE: + string = self.readJSONString(True) + try: + double = float(string) + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted") + return double + except ValueError: + raise 
TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + else: + if self.context.escapeNum() is True: + self.readJSONSyntaxChar(QUOTE) + try: + return float(self.readJSONNumericChars()) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONBase64(self): + string = self.readJSONString(False) + size = len(string) + m = size % 4 + # Force padding since b64encode method does not allow it + if m != 0: + for i in range(4 - m): + string += '=' + return base64.b64decode(string) + + def readJSONObjectStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACE) + self.pushContext(JSONPairContext(self)) + + def readJSONObjectEnd(self): + self.readJSONSyntaxChar(RBRACE) + self.popContext() + + def readJSONArrayStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACKET) + self.pushContext(JSONListContext(self)) + + def readJSONArrayEnd(self): + self.readJSONSyntaxChar(RBRACKET) + self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + + def readMessageBegin(self): + self.resetReadContext() + self.readJSONArrayStart() + if self.readJSONInteger() != VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version.") + name = self.readJSONString(False) + typen = self.readJSONInteger() + seqid = self.readJSONInteger() + return (name, typen, seqid) + + def readMessageEnd(self): + self.readJSONArrayEnd() + + def readStructBegin(self): + self.readJSONObjectStart() + + def readStructEnd(self): + self.readJSONObjectEnd() + + def readFieldBegin(self): + character = self.reader.peek() + ttype = 0 + id = 0 + if character == RBRACE: + ttype = TType.STOP + else: + id = self.readJSONInteger() + self.readJSONObjectStart() + ttype = JTYPES[self.readJSONString(False)] + return (None, ttype, id) + + def readFieldEnd(self): + self.readJSONObjectEnd() + + def readMapBegin(self): + self.readJSONArrayStart() + keyType = JTYPES[self.readJSONString(False)] + valueType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + self.readJSONObjectStart() + return (keyType, valueType, size) + + def readMapEnd(self): + self.readJSONObjectEnd() + self.readJSONArrayEnd() + + def readCollectionBegin(self): + self.readJSONArrayStart() + elemType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + return (elemType, size) + readListBegin = readCollectionBegin + readSetBegin = readCollectionBegin + + def readCollectionEnd(self): + self.readJSONArrayEnd() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + + def readBool(self): + return (False if self.readJSONInteger() == 0 else True) + + def readNumber(self): + return self.readJSONInteger() + readByte = readNumber + readI16 = readNumber + readI32 = readNumber + readI64 = readNumber + + def readDouble(self): + return self.readJSONDouble() + + def readString(self): + return self.readJSONString(False) + + def readBinary(self): + return self.readJSONBase64() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + self.writeJSONArrayStart() + self.writeJSONNumber(VERSION) + self.writeJSONString(name) + self.writeJSONNumber(request_type) + self.writeJSONNumber(seqid) + + def writeMessageEnd(self): + self.writeJSONArrayEnd() + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, id): + self.writeJSONNumber(id) + 
self.writeJSONObjectStart() + self.writeJSONString(CTYPES[ttype]) + + def writeFieldEnd(self): + self.writeJSONObjectEnd() + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[ktype]) + self.writeJSONString(CTYPES[vtype]) + self.writeJSONNumber(size) + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + self.writeJSONArrayEnd() + + def writeListBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeListEnd(self): + self.writeJSONArrayEnd() + + def writeSetBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeSetEnd(self): + self.writeJSONArrayEnd() + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeByte(self, byte): + checkIntegerLimits(byte, 8) + self.writeJSONNumber(byte) + + def writeI16(self, i16): + checkIntegerLimits(i16, 16) + self.writeJSONNumber(i16) + + def writeI32(self, i32): + checkIntegerLimits(i32, 32) + self.writeJSONNumber(i32) + + def writeI64(self, i64): + checkIntegerLimits(i64, 64) + self.writeJSONNumber(i64) + + def writeDouble(self, dbl): + # 17 significant digits should be just enough for any double precision + # value. + self.writeJSONNumber(dbl, '{0:.17g}') + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TJSONProtocolFactory(TProtocolFactory): + def getProtocol(self, trans): + return TJSONProtocol(trans) + + @property + def string_length_limit(senf): + return None + + @property + def container_length_limit(senf): + return None + + +class TSimpleJSONProtocol(TJSONProtocolBase): + """Simple, readable, write-only JSON protocol. + + Useful for interacting with scripting languages. 
+ """ + + def readMessageBegin(self): + raise NotImplementedError() + + def readMessageEnd(self): + raise NotImplementedError() + + def readStructBegin(self): + raise NotImplementedError() + + def readStructEnd(self): + raise NotImplementedError() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, fid): + self.writeJSONString(name) + + def writeFieldEnd(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + + def _writeCollectionBegin(self, etype, size): + self.writeJSONArrayStart() + + def _writeCollectionEnd(self): + self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin + writeListEnd = _writeCollectionEnd + writeSetBegin = _writeCollectionBegin + writeSetEnd = _writeCollectionEnd + + def writeByte(self, byte): + checkIntegerLimits(byte, 8) + self.writeJSONNumber(byte) + + def writeI16(self, i16): + checkIntegerLimits(i16, 16) + self.writeJSONNumber(i16) + + def writeI32(self, i32): + checkIntegerLimits(i32, 32) + self.writeJSONNumber(i32) + + def writeI64(self, i64): + checkIntegerLimits(i64, 64) + self.writeJSONNumber(i64) + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(TProtocolFactory): + + def getProtocol(self, trans): + return TSimpleJSONProtocol(trans) diff --git a/thrift/protocol/TMultiplexedProtocol.py b/thrift/protocol/TMultiplexedProtocol.py new file mode 100644 index 0000000..0f8390f --- /dev/null +++ b/thrift/protocol/TMultiplexedProtocol.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from thrift.Thrift import TMessageType +from thrift.protocol import TProtocolDecorator + +SEPARATOR = ":" + + +class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, serviceName): + self.serviceName = serviceName + + def writeMessageBegin(self, name, type, seqid): + if (type == TMessageType.CALL or + type == TMessageType.ONEWAY): + super(TMultiplexedProtocol, self).writeMessageBegin( + self.serviceName + SEPARATOR + name, + type, + seqid + ) + else: + super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid) diff --git a/thrift/protocol/TProtocol.py b/thrift/protocol/TProtocol.py index 56d323a..339a283 100644 --- a/thrift/protocol/TProtocol.py +++ b/thrift/protocol/TProtocol.py @@ -17,390 +17,412 @@ # under the License. # -from thrift.Thrift import * +from thrift.Thrift import TException, TType, TFrozenDict +from thrift.transport.TTransport import TTransportException +from ..compat import binary_to_str, str_to_binary + +import six +import sys +from itertools import islice +from six.moves import zip class TProtocolException(TException): - """Custom Protocol Exception class""" + """Custom Protocol Exception class""" - UNKNOWN = 0 - INVALID_DATA = 1 - NEGATIVE_SIZE = 2 - SIZE_LIMIT = 3 - BAD_VERSION = 4 + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 + INVALID_PROTOCOL = 7 - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type -class TProtocolBase: - """Base class for Thrift protocol driver.""" +class TProtocolBase(object): + """Base class for Thrift protocol driver.""" - def __init__(self, trans): - self.trans = trans + def __init__(self, trans): + self.trans = trans + self._fast_decode = None + self._fast_encode = None - def writeMessageBegin(self, name, type, seqid): - pass + @staticmethod + def _check_length(limit, length): + if length < 0: + raise TTransportException(TTransportException.NEGATIVE_SIZE, + 'Negative length: %d' % length) + if limit is not None and length > limit: + raise TTransportException(TTransportException.SIZE_LIMIT, + 'Length exceeded max allowed: %d' % limit) - def writeMessageEnd(self): - pass + def writeMessageBegin(self, name, ttype, seqid): + pass - def writeStructBegin(self, name): - pass + def writeMessageEnd(self): + pass - def writeStructEnd(self): - pass + def writeStructBegin(self, name): + pass - def writeFieldBegin(self, name, type, id): - pass + def writeStructEnd(self): + pass - def writeFieldEnd(self): - pass + def writeFieldBegin(self, name, ttype, fid): + pass - def writeFieldStop(self): - pass + def writeFieldEnd(self): + pass - def writeMapBegin(self, ktype, vtype, size): - pass + def writeFieldStop(self): + pass - def writeMapEnd(self): - pass + def writeMapBegin(self, ktype, vtype, size): + pass - def writeListBegin(self, etype, size): - pass + def writeMapEnd(self): + pass - def writeListEnd(self): - pass + def writeListBegin(self, etype, size): + pass - def writeSetBegin(self, etype, size): - pass + def writeListEnd(self): + pass - def writeSetEnd(self): - pass + def writeSetBegin(self, etype, size): + pass - def writeBool(self, bool): - pass + def writeSetEnd(self): + pass - def writeByte(self, byte): - pass + def writeBool(self, bool_val): + pass - def writeI16(self, i16): - pass + def writeByte(self, byte): + pass - def writeI32(self, i32): - pass 
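A client-side usage sketch for the new TMultiplexedProtocol added earlier in this patch; the wrapped protocol, transport, and service name are placeholders chosen for illustration:

    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol
    from thrift.protocol.TMultiplexedProtocol import TMultiplexedProtocol

    trans = TTransport.TMemoryBuffer()
    proto = TBinaryProtocol.TBinaryProtocol(trans)
    # CALL/ONEWAY message names are prefixed as "Calculator:<method>",
    # which lets a multiplexing server route them to the right service.
    calc_proto = TMultiplexedProtocol(proto, "Calculator")
    # calc_client = Calculator.Client(calc_proto)  # generated client (placeholder)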
+ def writeI16(self, i16): + pass - def writeI64(self, i64): - pass + def writeI32(self, i32): + pass - def writeDouble(self, dub): - pass + def writeI64(self, i64): + pass - def writeString(self, str): - pass + def writeDouble(self, dub): + pass - def readMessageBegin(self): - pass + def writeString(self, str_val): + self.writeBinary(str_to_binary(str_val)) - def readMessageEnd(self): - pass + def writeBinary(self, str_val): + pass - def readStructBegin(self): - pass + def writeUtf8(self, str_val): + self.writeString(str_val.encode('utf8')) - def readStructEnd(self): - pass + def readMessageBegin(self): + pass - def readFieldBegin(self): - pass + def readMessageEnd(self): + pass - def readFieldEnd(self): - pass + def readStructBegin(self): + pass - def readMapBegin(self): - pass + def readStructEnd(self): + pass - def readMapEnd(self): - pass + def readFieldBegin(self): + pass - def readListBegin(self): - pass + def readFieldEnd(self): + pass - def readListEnd(self): - pass + def readMapBegin(self): + pass - def readSetBegin(self): - pass + def readMapEnd(self): + pass - def readSetEnd(self): - pass + def readListBegin(self): + pass - def readBool(self): - pass + def readListEnd(self): + pass - def readByte(self): - pass + def readSetBegin(self): + pass - def readI16(self): - pass + def readSetEnd(self): + pass - def readI32(self): - pass + def readBool(self): + pass - def readI64(self): - pass + def readByte(self): + pass - def readDouble(self): - pass + def readI16(self): + pass - def readString(self): - pass + def readI32(self): + pass - def skip(self, type): - if type == TType.STOP: - return - elif type == TType.BOOL: - self.readBool() - elif type == TType.BYTE: - self.readByte() - elif type == TType.I16: - self.readI16() - elif type == TType.I32: - self.readI32() - elif type == TType.I64: - self.readI64() - elif type == TType.DOUBLE: - self.readDouble() - elif type == TType.STRING: - self.readString() - elif type == TType.STRUCT: - name = self.readStructBegin() - while True: - (name, type, id) = self.readFieldBegin() - if type == TType.STOP: - break - self.skip(type) - self.readFieldEnd() - self.readStructEnd() - elif type == TType.MAP: - (ktype, vtype, size) = self.readMapBegin() - for i in range(size): - self.skip(ktype) - self.skip(vtype) - self.readMapEnd() - elif type == TType.SET: - (etype, size) = self.readSetBegin() - for i in range(size): - self.skip(etype) - self.readSetEnd() - elif type == TType.LIST: - (etype, size) = self.readListBegin() - for i in range(size): - self.skip(etype) - self.readListEnd() + def readI64(self): + pass - # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) - _TTYPE_HANDLERS = ( - (None, None, False), # 0 TType.STOP - (None, None, False), # 1 TType.VOID # TODO: handle void? 
- ('readBool', 'writeBool', False), # 2 TType.BOOL - ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE - (None, None, False), # 5 undefined - ('readI16', 'writeI16', False), # 6 TType.I16 - (None, None, False), # 7 undefined - ('readI32', 'writeI32', False), # 8 TType.I32 - (None, None, False), # 9 undefined - ('readI64', 'writeI64', False), # 10 TType.I64 - ('readString', 'writeString', False), # 11 TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET - ('readContainerList', 'writeContainerList', True), # 15 TType.LIST - (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? - (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? - ) + def readDouble(self): + pass - def readFieldByTType(self, ttype, spec): - try: - (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] - except IndexError: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - if r_handler is None: - raise TProtocolException(type=TProtocolException.INVALID_DATA, - message='Invalid field type %d' % (ttype)) - reader = getattr(self, r_handler) - if not is_container: - return reader() - return reader(spec) + def readString(self): + return binary_to_str(self.readBinary()) - def readContainerList(self, spec): - results = [] - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (list_type, list_len) = self.readListBegin() - if tspec is None: - # list values are simple types - for idx in range(list_len): - results.append(reader()) - else: - # this is like an inlined readFieldByTType - container_reader = self._TTYPE_HANDLERS[list_type][0] - val_reader = getattr(self, container_reader) - for idx in range(list_len): - val = val_reader(tspec) - results.append(val) - self.readListEnd() - return results + def readBinary(self): + pass - def readContainerSet(self, spec): - results = set() - ttype, tspec = spec[0], spec[1] - r_handler = self._TTYPE_HANDLERS[ttype][0] - reader = getattr(self, r_handler) - (set_type, set_len) = self.readSetBegin() - if tspec is None: - # set members are simple types - for idx in range(set_len): - results.add(reader()) - else: - container_reader = self._TTYPE_HANDLERS[set_type][0] - val_reader = getattr(self, container_reader) - for idx in range(set_len): - results.add(val_reader(tspec)) - self.readSetEnd() - return results + def readUtf8(self): + return self.readString().decode('utf8') - def readContainerStruct(self, spec): - (obj_class, obj_spec) = spec - obj = obj_class() - obj.read(self) - return obj - - def readContainerMap(self, spec): - results = dict() - key_ttype, key_spec = spec[0], spec[1] - val_ttype, val_spec = spec[2], spec[3] - (map_ktype, map_vtype, map_len) = self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and - # abort/skip if types disagree - key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) - val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) - # list values are simple types - for idx in range(map_len): - if key_spec is None: - k_val = key_reader() - else: - k_val = self.readFieldByTType(key_ttype, key_spec) - if val_spec is None: - v_val = val_reader() - else: - v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with 
unhashable keys types - # i.e. this fails: d=dict(); d[[0,1]] = 2 - results[k_val] = v_val - self.readMapEnd() - return results - - def readStruct(self, obj, thrift_spec): - self.readStructBegin() - while True: - (fname, ftype, fid) = self.readFieldBegin() - if ftype == TType.STOP: - break - try: - field = thrift_spec[fid] - except IndexError: - self.skip(ftype) - else: - if field is not None and ftype == field[1]: - fname = field[2] - fspec = field[3] - val = self.readFieldByTType(ftype, fspec) - setattr(obj, fname, val) + def skip(self, ttype): + if ttype == TType.BOOL: + self.readBool() + elif ttype == TType.BYTE: + self.readByte() + elif ttype == TType.I16: + self.readI16() + elif ttype == TType.I32: + self.readI32() + elif ttype == TType.I64: + self.readI64() + elif ttype == TType.DOUBLE: + self.readDouble() + elif ttype == TType.STRING: + self.readString() + elif ttype == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, ttype, id) = self.readFieldBegin() + if ttype == TType.STOP: + break + self.skip(ttype) + self.readFieldEnd() + self.readStructEnd() + elif ttype == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in range(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif ttype == TType.SET: + (etype, size) = self.readSetBegin() + for i in range(size): + self.skip(etype) + self.readSetEnd() + elif ttype == TType.LIST: + (etype, size) = self.readListBegin() + for i in range(size): + self.skip(etype) + self.readListEnd() else: - self.skip(ftype) - self.readFieldEnd() - self.readStructEnd() + raise TProtocolException( + TProtocolException.INVALID_DATA, + "invalid TType") - def writeContainerStruct(self, val, spec): - val.write(self) + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? 
+ ) - def writeContainerList(self, val, spec): - self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeListEnd() + def _ttype_handlers(self, ttype, spec): + if spec == 'BINARY': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid binary field type %d' % ttype) + return ('readBinary', 'writeBinary', False) + if sys.version_info[0] == 2 and spec == 'UTF8': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid string field type %d' % ttype) + return ('readUtf8', 'writeUtf8', False) + return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) - def writeContainerSet(self, val, spec): - self.writeSetBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] - e_writer = getattr(self, w_handler) - if not is_container: - for elem in val: - e_writer(elem) - else: - for elem in val: - e_writer(elem, spec[1]) - self.writeSetEnd() + def _read_by_ttype(self, ttype, spec, espec): + reader_name, _, is_container = self._ttype_handlers(ttype, espec) + if reader_name is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid type %d' % (ttype)) + reader_func = getattr(self, reader_name) + read = (lambda: reader_func(espec)) if is_container else reader_func + while True: + yield read() - def writeContainerMap(self, val, spec): - k_type = spec[0] - v_type = spec[2] - ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] - ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] - k_writer = getattr(self, ktype_name) - v_writer = getattr(self, vtype_name) - self.writeMapBegin(k_type, v_type, len(val)) - for m_key, m_val in val.items(): - if not k_is_container: - k_writer(m_key) - else: - k_writer(m_key, spec[1]) - if not v_is_container: - v_writer(m_val) - else: - v_writer(m_val, spec[3]) - self.writeMapEnd() + def readFieldByTType(self, ttype, spec): + return next(self._read_by_ttype(ttype, spec, spec)) - def writeStruct(self, obj, thrift_spec): - self.writeStructBegin(obj.__class__.__name__) - for field in thrift_spec: - if field is None: - continue - fname = field[2] - val = getattr(obj, fname) - if val is None: - # skip writing out unset fields - continue - fid = field[0] - ftype = field[1] - fspec = field[3] - # get the writer method for this value - self.writeFieldBegin(fname, ftype, fid) - self.writeFieldByTType(ftype, val, fspec) - self.writeFieldEnd() - self.writeFieldStop() - self.writeStructEnd() + def readContainerList(self, spec): + ttype, tspec, is_immutable = spec + (list_type, list_len) = self.readListBegin() + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) + results = (tuple if is_immutable else list)(elems) + self.readListEnd() + return results - def writeFieldByTType(self, ttype, val, spec): - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] - writer = getattr(self, w_handler) - if is_container: - writer(val, spec) - else: - writer(val) + def readContainerSet(self, spec): + ttype, tspec, is_immutable = spec + (set_type, set_len) = self.readSetBegin() + # TODO: compare types we just decoded with thrift_spec + elems = 
islice(self._read_by_ttype(ttype, spec, tspec), set_len) + results = (frozenset if is_immutable else set)(elems) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + + # If obj_class.read is a classmethod (e.g. in frozen structs), + # call it as such. + if getattr(obj_class.read, '__self__', None) is obj_class: + obj = obj_class.read(self) + else: + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + ktype, kspec, vtype, vspec, is_immutable = spec + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree + keys = self._read_by_ttype(ktype, spec, kspec) + vals = self._read_by_ttype(vtype, spec, vspec) + keyvals = islice(zip(keys, vals), map_len) + results = (TFrozenDict if is_immutable else dict)(keyvals) + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec, is_immutable=False): + if is_immutable: + fields = {} + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + if is_immutable: + fields[fname] = val + else: + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + if is_immutable: + return obj(**fields) + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + ttype, tspec, _ = spec + self.writeListBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeListEnd() + + def writeContainerSet(self, val, spec): + ttype, tspec, _ = spec + self.writeSetBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + ktype, kspec, vtype, vspec, _ = spec + self.writeMapBegin(ktype, vtype, len(val)) + for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec), + self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)): + pass + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def _write_by_ttype(self, ttype, vals, spec, espec): + _, writer_name, is_container = self._ttype_handlers(ttype, espec) + writer_func = getattr(self, writer_name) + write = (lambda v: writer_func(v, espec)) if is_container else writer_func + for v in vals: + yield write(v) + + def writeFieldByTType(self, ttype, val, spec): + next(self._write_by_ttype(ttype, [val], spec, spec)) -class TProtocolFactory: - def getProtocol(self, trans): - pass +def checkIntegerLimits(i, bits): + if bits == 8 and (i < -128 or i > 127): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i8 requires -128 <= number <= 127") + elif bits == 16 and (i < -32768 or i > 32767): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i16 requires -32768 <= number <= 
32767") + elif bits == 32 and (i < -2147483648 or i > 2147483647): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i32 requires -2147483648 <= number <= 2147483647") + elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i64 requires -9223372036854775808 <= number <= 9223372036854775807") + + +class TProtocolFactory(object): + def getProtocol(self, trans): + pass diff --git a/thrift/protocol/TProtocolDecorator.py b/thrift/protocol/TProtocolDecorator.py new file mode 100644 index 0000000..f5546c7 --- /dev/null +++ b/thrift/protocol/TProtocolDecorator.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +class TProtocolDecorator(object): + def __new__(cls, protocol, *args, **kwargs): + decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]), + (cls, protocol.__class__), + protocol.__dict__) + return object.__new__(decorated_cls) diff --git a/thrift/protocol/__init__.py b/thrift/protocol/__init__.py index d53359b..06647a2 100644 --- a/thrift/protocol/__init__.py +++ b/thrift/protocol/__init__.py @@ -17,4 +17,5 @@ # under the License. # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', + 'TJSONProtocol', 'TProtocol', 'TProtocolDecorator'] diff --git a/thrift/protocol/fastbinary.cpython-38-darwin.so b/thrift/protocol/fastbinary.cpython-38-darwin.so new file mode 100755 index 0000000..bbfd5f1 Binary files /dev/null and b/thrift/protocol/fastbinary.cpython-38-darwin.so differ diff --git a/thrift/server/THttpServer.py b/thrift/server/THttpServer.py index f6d1ff5..47e817d 100644 --- a/thrift/server/THttpServer.py +++ b/thrift/server/THttpServer.py @@ -17,71 +17,115 @@ # under the License. # -import http.server +import ssl +from six.moves import BaseHTTPServer + +from thrift.Thrift import TMessageType from thrift.server import TServer from thrift.transport import TTransport class ResponseException(Exception): - """Allows handlers to override the HTTP response + """Allows handlers to override the HTTP response - Normally, THttpServer always sends a 200 response. If a handler wants - to override this behavior (e.g., to simulate a misconfigured or - overloaded web server during testing), it can raise a ResponseException. - The function passed to the constructor will be called with the - RequestHandler as its only argument. - """ - def __init__(self, handler): - self.handler = handler + Normally, THttpServer always sends a 200 response. If a handler wants + to override this behavior (e.g., to simulate a misconfigured or + overloaded web server during testing), it can raise a ResponseException. 
+ The function passed to the constructor will be called with the + RequestHandler as its only argument. Note that this is irrelevant + for ONEWAY requests, as the HTTP response must be sent before the + RPC is processed. + """ + def __init__(self, handler): + self.handler = handler class THttpServer(TServer.TServer): - """A simple HTTP-based Thrift server + """A simple HTTP-based Thrift server - This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint. - """ - def __init__(self, - processor, - server_address, - inputProtocolFactory, - outputProtocolFactory=None, - server_class=http.server.HTTPServer): - """Set up protocol factories and HTTP server. - - See BaseHTTPServer for server_address. - See TServer for protocol factories. + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint. + Also important to note the HTTP implementation pretty much violates the + transport/protocol/processor/server layering, by performing the transport + functions here. This means things like oneway handling are oddly exposed. """ - if outputProtocolFactory is None: - outputProtocolFactory = inputProtocolFactory + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer, + **kwargs): + """Set up protocol factories and HTTP (or HTTPS) server. - TServer.TServer.__init__(self, processor, None, None, None, - inputProtocolFactory, outputProtocolFactory) + See BaseHTTPServer for server_address. + See TServer for protocol factories. - thttpserver = self + To make a secure server, provide the named arguments: + * cafile - to validate clients [optional] + * cert_file - the server cert + * key_file - the server's key + """ + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory - class RequestHander(http.server.BaseHTTPRequestHandler): - def do_POST(self): - # Don't care about the request path. - itrans = TTransport.TFileObjectTransport(self.rfile) - otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport( - itrans, int(self.headers['Content-Length'])) - otrans = TTransport.TMemoryBuffer() - iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) - oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) - try: - thttpserver.processor.process(iprot, oprot) - except ResponseException as exn: - exn.handler(self) - else: - self.send_response(200) - self.send_header("content-type", "application/x-thrift") - self.end_headers() - self.wfile.write(otrans.getvalue()) + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) - self.httpd = server_class(server_address, RequestHander) + thttpserver = self + self._replied = None - def serve(self): - self.httpd.serve_forever() + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. 
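A minimal sketch of constructing the THttpServer defined above, including the optional TLS keyword arguments its docstring describes; the processor, address, and certificate paths are placeholders, not values taken from the patch:

    from thrift.protocol import TJSONProtocol
    from thrift.server import THttpServer

    server = THttpServer.THttpServer(
        processor,                             # a generated TProcessor (placeholder)
        ("0.0.0.0", 9090),
        TJSONProtocol.TJSONProtocolFactory(),
        # cert_file="server.crt", key_file="server.key", cafile="ca.crt",
    )
    server.serve()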
+ thttpserver._replied = False + iftrans = TTransport.TFileObjectTransport(self.rfile) + itrans = TTransport.TBufferedTransport( + iftrans, int(self.headers['Content-Length'])) + otrans = TTransport.TMemoryBuffer() + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + try: + thttpserver.processor.on_message_begin(self.on_begin) + thttpserver.processor.process(iprot, oprot) + except ResponseException as exn: + exn.handler(self) + else: + if not thttpserver._replied: + # If the request was ONEWAY we would have replied already + data = otrans.getvalue() + self.send_response(200) + self.send_header("Content-Length", len(data)) + self.send_header("Content-Type", "application/x-thrift") + self.end_headers() + self.wfile.write(data) + + def on_begin(self, name, type, seqid): + """ + Inspect the message header. + + This allows us to post an immediate transport response + if the request is a ONEWAY message type. + """ + if type == TMessageType.ONEWAY: + self.send_response(200) + self.send_header("Content-Type", "application/x-thrift") + self.end_headers() + thttpserver._replied = True + + self.httpd = server_class(server_address, RequestHander) + + if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')): + context = ssl.create_default_context(cafile=kwargs.get('cafile')) + context.check_hostname = False + context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file')) + context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE + self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True) + + def serve(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.socket.close() + # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly! diff --git a/thrift/server/TNonblockingServer.py b/thrift/server/TNonblockingServer.py index 764c9ae..f62d486 100644 --- a/thrift/server/TNonblockingServer.py +++ b/thrift/server/TNonblockingServer.py @@ -24,18 +24,23 @@ only from the main thread. The thread poool should be sized for concurrent tasks, not maximum connections """ -import threading -import socket -import queue -import select -import struct + import logging +import select +import socket +import struct +import threading + +from collections import deque +from six.moves import queue from thrift.transport import TTransport from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory __all__ = ['TNonblockingServer'] +logger = logging.getLogger(__name__) + class Worker(threading.Thread): """Worker is a small helper to process incoming connection.""" @@ -54,8 +59,9 @@ class Worker(threading.Thread): processor.process(iprot, oprot) callback(True, otrans.getvalue()) except Exception: - logging.exception("Exception while processing request") - callback(False, '') + logger.exception("Exception while processing request", exc_info=True) + callback(False, b'') + WAIT_LEN = 0 WAIT_MESSAGE = 1 @@ -81,11 +87,24 @@ def socket_exception(func): try: return func(self, *args, **kwargs) except socket.error: + logger.debug('ignoring socket exception', exc_info=True) self.close() return read -class Connection: +class Message(object): + def __init__(self, offset, len_, header): + self.offset = offset + self.len = len_ + self.buffer = None + self.is_header = header + + @property + def end(self): + return self.offset + self.len + + +class Connection(object): """Basic class is represented connection. 
It can be in state: @@ -102,68 +121,60 @@ class Connection: self.socket.setblocking(False) self.status = WAIT_LEN self.len = 0 - self.message = '' + self.received = deque() + self._reading = Message(0, 4, True) + self._rbuf = b'' + self._wbuf = b'' self.lock = threading.Lock() self.wake_up = wake_up - - def _read_len(self): - """Reads length of request. - - It's a safer alternative to self.socket.recv(4) - """ - read = self.socket.recv(4 - len(self.message)) - if len(read) == 0: - # if we read 0 bytes and self.message is empty, then - # the client closed the connection - if len(self.message) != 0: - logging.error("can't read frame size from socket") - self.close() - return - self.message += read - if len(self.message) == 4: - self.len, = struct.unpack('!i', self.message) - if self.len < 0: - logging.error("negative frame size, it seems client " - "doesn't use FramedTransport") - self.close() - elif self.len == 0: - logging.error("empty frame, it's really strange") - self.close() - else: - self.message = '' - self.status = WAIT_MESSAGE + self.remaining = False @socket_exception def read(self): """Reads data from stream and switch state.""" assert self.status in (WAIT_LEN, WAIT_MESSAGE) - if self.status == WAIT_LEN: - self._read_len() - # go back to the main loop here for simplicity instead of - # falling through, even though there is a good chance that - # the message is already available - elif self.status == WAIT_MESSAGE: - read = self.socket.recv(self.len - len(self.message)) - if len(read) == 0: - logging.error("can't read frame from socket (get %d of " - "%d bytes)" % (len(self.message), self.len)) + assert not self.received + buf_size = 8192 + first = True + done = False + while not done: + read = self.socket.recv(buf_size) + rlen = len(read) + done = rlen < buf_size + self._rbuf += read + if first and rlen == 0: + if self.status != WAIT_LEN or self._rbuf: + logger.error('could not read frame from socket') + else: + logger.debug('read zero length. 
client might have disconnected') self.close() - return - self.message += read - if len(self.message) == self.len: + while len(self._rbuf) >= self._reading.end: + if self._reading.is_header: + mlen, = struct.unpack('!i', self._rbuf[:4]) + self._reading = Message(self._reading.end, mlen, False) + self.status = WAIT_MESSAGE + else: + self._reading.buffer = self._rbuf + self.received.append(self._reading) + self._rbuf = self._rbuf[self._reading.end:] + self._reading = Message(0, 4, True) + first = False + if self.received: self.status = WAIT_PROCESS + break + self.remaining = not done @socket_exception def write(self): """Writes data from socket and switch state.""" assert self.status == SEND_ANSWER - sent = self.socket.send(self.message) - if sent == len(self.message): + sent = self.socket.send(self._wbuf) + if sent == len(self._wbuf): self.status = WAIT_LEN - self.message = '' + self._wbuf = b'' self.len = 0 else: - self.message = self.message[sent:] + self._wbuf = self._wbuf[sent:] @locked def ready(self, all_ok, message): @@ -183,13 +194,13 @@ class Connection: self.close() self.wake_up() return - self.len = '' + self.len = 0 if len(message) == 0: # it was a oneway request, do not write answer - self.message = '' + self._wbuf = b'' self.status = WAIT_LEN else: - self.message = struct.pack('!i', len(message)) + message + self._wbuf = struct.pack('!i', len(message)) + message self.status = SEND_ANSWER self.wake_up() @@ -219,7 +230,7 @@ class Connection: self.socket.close() -class TNonblockingServer: +class TNonblockingServer(object): """Non-blocking server.""" def __init__(self, @@ -259,7 +270,7 @@ class TNonblockingServer: def wake_up(self): """Wake up main thread. - The server usualy waits in select call in we should terminate one. + The server usually waits in select call in we should terminate one. The simplest way is using socketpair. Select always wait to read from the first socket of socketpair. @@ -267,7 +278,7 @@ class TNonblockingServer: In this case, we can just write anything to the second socket from socketpair. """ - self._write.send('1') + self._write.send(b'1') def stop(self): """Stop the server. @@ -288,14 +299,20 @@ class TNonblockingServer: """Does select on open connections.""" readable = [self.socket.handle.fileno(), self._read.fileno()] writable = [] + remaining = [] for i, connection in list(self.clients.items()): if connection.is_readable(): readable.append(connection.fileno()) + if connection.remaining or connection.received: + remaining.append(connection.fileno()) if connection.is_writeable(): writable.append(connection.fileno()) if connection.is_closed(): del self.clients[i] - return select.select(readable, writable, readable) + if remaining: + return remaining, [], [], False + else: + return select.select(readable, writable, readable) + (True,) def handle(self): """Handle requests. @@ -303,20 +320,27 @@ class TNonblockingServer: WARNING! 
You must call prepare() BEFORE calling handle() """ assert self.prepared, "You have to call prepare before handle" - rset, wset, xset = self._select() + rset, wset, xset, selected = self._select() for readable in rset: if readable == self._read.fileno(): # don't care i just need to clean readable flag self._read.recv(1024) elif readable == self.socket.handle.fileno(): - client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, - self.wake_up) + try: + client = self.socket.accept() + if client: + self.clients[client.handle.fileno()] = Connection(client.handle, + self.wake_up) + except socket.error: + logger.debug('error while accepting', exc_info=True) else: connection = self.clients[readable] - connection.read() - if connection.status == WAIT_PROCESS: - itransport = TTransport.TMemoryBuffer(connection.message) + if selected: + connection.read() + if connection.received: + connection.status = WAIT_PROCESS + msg = connection.received.popleft() + itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset) otransport = TTransport.TMemoryBuffer() iprot = self.in_protocol.getProtocol(itransport) oprot = self.out_protocol.getProtocol(otransport) diff --git a/thrift/server/TProcessPoolServer.py b/thrift/server/TProcessPoolServer.py index e8b2c9c..fe6dc81 100644 --- a/thrift/server/TProcessPoolServer.py +++ b/thrift/server/TProcessPoolServer.py @@ -19,11 +19,13 @@ import logging -from multiprocessing import Process, Value, Condition, reduction + +from multiprocessing import Process, Value, Condition from .TServer import TServer from thrift.transport.TTransport import TTransportException -import collections.abc + +logger = logging.getLogger(__name__) class TProcessPoolServer(TServer): @@ -41,7 +43,7 @@ class TProcessPoolServer(TServer): self.postForkCallback = None def setPostForkCallback(self, callback): - if not isinstance(callback, collections.abc.Callable): + if not callable(callback): raise TypeError("This is not a callback!") self.postForkCallback = callback @@ -57,11 +59,13 @@ class TProcessPoolServer(TServer): while self.isRunning.value: try: client = self.serverTransport.accept() + if not client: + continue self.serveClient(client) except (KeyboardInterrupt, SystemExit): return 0 except Exception as x: - logging.exception(x) + logger.exception(x) def serveClient(self, client): """Process input/output from a client for as long as possible""" @@ -73,10 +77,10 @@ class TProcessPoolServer(TServer): try: while True: self.processor.process(iprot, oprot) - except TTransportException as tx: + except TTransportException: pass except Exception as x: - logging.exception(x) + logger.exception(x) itrans.close() otrans.close() @@ -97,7 +101,7 @@ class TProcessPoolServer(TServer): w.start() self.workers.append(w) except Exception as x: - logging.exception(x) + logger.exception(x) # wait until the condition is set by stop() while True: @@ -108,7 +112,7 @@ class TProcessPoolServer(TServer): except (SystemExit, KeyboardInterrupt): break except Exception as x: - logging.exception(x) + logger.exception(x) self.isRunning.value = False diff --git a/thrift/server/TServer.py b/thrift/server/TServer.py index 9e340f4..df2a7bb 100644 --- a/thrift/server/TServer.py +++ b/thrift/server/TServer.py @@ -17,253 +17,307 @@ # under the License. 
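A brief sketch of the post-fork hook guarded by the callable() check in the TProcessPoolServer hunk earlier in this patch; re-seeding the RNG is just one plausible use, and the processor and transport arguments are placeholders:

    import random
    from thrift.server.TProcessPoolServer import TProcessPoolServer

    def reseed():
        # Give each forked worker its own RNG state.
        random.seed()

    server = TProcessPoolServer(processor, server_transport)  # placeholders
    server.setPostForkCallback(reseed)
    # server.setPostForkCallback(42)  # would raise TypeError("This is not a callback!")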
# -import queue +from six.moves import queue import logging import os -import sys import threading -import traceback -from thrift.Thrift import TProcessor from thrift.protocol import TBinaryProtocol +from thrift.protocol.THeaderProtocol import THeaderProtocolFactory from thrift.transport import TTransport +logger = logging.getLogger(__name__) -class TServer: - """Base interface for a server, which must have a serve() method. - Three constructors for all servers: - 1) (processor, serverTransport) - 2) (processor, serverTransport, transportFactory, protocolFactory) - 3) (processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory) - """ - def __init__(self, *args): - if (len(args) == 2): - self.__initArgs__(args[0], args[1], - TTransport.TTransportFactoryBase(), - TTransport.TTransportFactoryBase(), - TBinaryProtocol.TBinaryProtocolFactory(), - TBinaryProtocol.TBinaryProtocolFactory()) - elif (len(args) == 4): - self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) - elif (len(args) == 6): - self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) +class TServer(object): + """Base interface for a server, which must have a serve() method. - def __initArgs__(self, processor, serverTransport, - inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory): - self.processor = processor - self.serverTransport = serverTransport - self.inputTransportFactory = inputTransportFactory - self.outputTransportFactory = outputTransportFactory - self.inputProtocolFactory = inputProtocolFactory - self.outputProtocolFactory = outputProtocolFactory + Three constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory) + """ + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) - def serve(self): - pass + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory) + output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory) + if any((input_is_header, output_is_header)) and input_is_header != output_is_header: + raise ValueError("THeaderProtocol servers require that both the input and " + "output protocols are THeaderProtocol.") + + def serve(self): + pass class TSimpleServer(TServer): - """Simple single-threaded server that just pumps around one transport.""" + """Simple single-threaded server that just pumps around one transport.""" - def __init__(self, *args): - TServer.__init__(self, *args) + def __init__(self, *args): + 
TServer.__init__(self, *args) - def serve(self): - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: + def serve(self): + self.serverTransport.listen() while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logging.exception(x) + client = self.serverTransport.accept() + if not client: + continue - itrans.close() - otrans.close() + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + + # for THeaderProtocol, we must use the same protocol instance for + # input and output so that the response is in the same dialect that + # the server detected the request was in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + if otrans: + otrans.close() class TThreadedServer(TServer): - """Threaded server that spawns a new thread per each connection.""" + """Threaded server that spawns a new thread per each connection.""" - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.daemon = kwargs.get("daemon", False) + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.daemon = kwargs.get("daemon", False) - def serve(self): - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - t = threading.Thread(target=self.handle, args=(client,)) - t.setDaemon(self.daemon) - t.start() - except KeyboardInterrupt: - raise - except Exception as x: - logging.exception(x) + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + t = threading.Thread(target=self.handle, args=(client,)) + t.setDaemon(self.daemon) + t.start() + except KeyboardInterrupt: + raise + except Exception as x: + logger.exception(x) - def handle(self, client): - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logging.exception(x) + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) - itrans.close() - otrans.close() + # for THeaderProtocol, we must use the same protocol instance for input + # and output so that the response is in the same dialect that the + # server detected the request was in. 
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + if otrans: + otrans.close() class TThreadPoolServer(TServer): - """Server with a fixed size pool of threads which service requests.""" + """Server with a fixed size pool of threads which service requests.""" - def __init__(self, *args, **kwargs): - TServer.__init__(self, *args) - self.clients = queue.Queue() - self.threads = 10 - self.daemon = kwargs.get("daemon", False) + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.clients = queue.Queue() + self.threads = 10 + self.daemon = kwargs.get("daemon", False) - def setNumThreads(self, num): - """Set the number of worker threads that should be created""" - self.threads = num + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num - def serveThread(self): - """Loop around getting clients from the shared queue and process them.""" - while True: - try: - client = self.clients.get() - self.serveClient(client) - except Exception as x: - logging.exception(x) + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: + try: + client = self.clients.get() + self.serveClient(client) + except Exception as x: + logger.exception(x) - def serveClient(self, client): - """Process input/output from a client for as long as possible""" - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as x: - logging.exception(x) + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) - itrans.close() - otrans.close() + # for THeaderProtocol, we must use the same protocol instance for input + # and output so that the response is in the same dialect that the + # server detected the request was in. 
+ if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - for i in range(self.threads): - try: - t = threading.Thread(target=self.serveThread) - t.setDaemon(self.daemon) - t.start() - except Exception as x: - logging.exception(x) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) - # Pump the socket for clients - self.serverTransport.listen() - while True: - try: - client = self.serverTransport.accept() - self.clients.put(client) - except Exception as x: - logging.exception(x) + itrans.close() + if otrans: + otrans.close() + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target=self.serveThread) + t.setDaemon(self.daemon) + t.start() + except Exception as x: + logger.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + self.clients.put(client) + except Exception as x: + logger.exception(x) class TForkingServer(TServer): - """A Thrift server that forks a new process for each request + """A Thrift server that forks a new process for each request - This is more scalable than the threaded server as it does not cause - GIL contention. + This is more scalable than the threaded server as it does not cause + GIL contention. - Note that this has different semantics from the threading server. - Specifically, updates to shared variables will no longer be shared. - It will also not work on windows. + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. - This code is heavily inspired by SocketServer.ForkingMixIn in the - Python stdlib. - """ - def __init__(self, *args): - TServer.__init__(self, *args) - self.children = [] + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. 
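# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): running the
# fixed-size TThreadPoolServer shown above. setNumThreads() must be called
# before serve(), which starts the workers and then begins accepting clients.
# "ExampleService" / "ExampleHandler" are hypothetical generated-code names.
from thrift.protocol import TBinaryProtocol
from thrift.server import TServer
from thrift.transport import TSocket, TTransport

transport = TSocket.TServerSocket(port=9090)
tfactory = TTransport.TBufferedTransportFactory()
pfactory = TBinaryProtocol.TBinaryProtocolFactory()

# processor = ExampleService.Processor(ExampleHandler())     # generated code, assumed
# server = TServer.TThreadPoolServer(processor, transport, tfactory, pfactory, daemon=True)
# server.setNumThreads(16)   # defaults to 10 worker threads
# server.serve()
# ---------------------------------------------------------------------------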
+ """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] - def serve(self): - def try_close(file): - try: - file.close() - except IOError as e: - logging.warning(e, exc_info=True) - - self.serverTransport.listen() - while True: - client = self.serverTransport.accept() - try: - pid = os.fork() - - if pid: # parent - # add before collect, otherwise you race w/ waitpid - self.children.append(pid) - self.collect_children() - - # Parent must close socket or the connection may not get - # closed promptly - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - try_close(itrans) - try_close(otrans) - else: - itrans = self.inputTransportFactory.getTransport(client) - otrans = self.outputTransportFactory.getTransport(client) - - iprot = self.inputProtocolFactory.getProtocol(itrans) - oprot = self.outputProtocolFactory.getProtocol(otrans) - - ecode = 0 - try: + def serve(self): + def try_close(file): try: - while True: - self.processor.process(iprot, oprot) - except TTransport.TTransportException as tx: - pass - except Exception as e: - logging.exception(e) - ecode = 1 - finally: - try_close(itrans) - try_close(otrans) + file.close() + except IOError as e: + logger.warning(e, exc_info=True) - os._exit(ecode) + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + if not client: + continue + try: + pid = os.fork() - except TTransport.TTransportException as tx: - pass - except Exception as x: - logging.exception(x) + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() - def collect_children(self): - while self.children: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except os.error: - pid = None + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) - if pid: - self.children.remove(pid) - else: - break + # for THeaderProtocol, we must use the same protocol + # instance for input and output so that the response is in + # the same dialect that the server detected the request was + # in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as e: + logger.exception(e) + ecode = 1 + finally: + try_close(itrans) + if otrans: + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break diff --git a/thrift/transport/THeaderTransport.py b/thrift/transport/THeaderTransport.py new file mode 100644 index 0000000..7c9827b --- /dev/null +++ b/thrift/transport/THeaderTransport.py @@ -0,0 +1,352 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import struct +import zlib + +from thrift.compat import BufferIO, byte_index +from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint +from thrift.Thrift import TApplicationException +from thrift.transport.TTransport import ( + CReadableTransport, + TMemoryBuffer, + TTransportBase, + TTransportException, +) + + +U16 = struct.Struct("!H") +I32 = struct.Struct("!i") +HEADER_MAGIC = 0x0FFF +HARD_MAX_FRAME_SIZE = 0x3FFFFFFF + + +class THeaderClientType(object): + HEADERS = 0x00 + + FRAMED_BINARY = 0x01 + UNFRAMED_BINARY = 0x02 + + FRAMED_COMPACT = 0x03 + UNFRAMED_COMPACT = 0x04 + + +class THeaderSubprotocolID(object): + BINARY = 0x00 + COMPACT = 0x02 + + +class TInfoHeaderType(object): + KEY_VALUE = 0x01 + + +class THeaderTransformID(object): + ZLIB = 0x01 + + +READ_TRANSFORMS_BY_ID = { + THeaderTransformID.ZLIB: zlib.decompress, +} + + +WRITE_TRANSFORMS_BY_ID = { + THeaderTransformID.ZLIB: zlib.compress, +} + + +def _readString(trans): + size = readVarint(trans) + if size < 0: + raise TTransportException( + TTransportException.NEGATIVE_SIZE, + "Negative length" + ) + return trans.read(size) + + +def _writeString(trans, value): + writeVarint(trans, len(value)) + trans.write(value) + + +class THeaderTransport(TTransportBase, CReadableTransport): + def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY): + self._transport = transport + self._client_type = THeaderClientType.HEADERS + self._allowed_client_types = allowed_client_types + + self._read_buffer = BufferIO(b"") + self._read_headers = {} + + self._write_buffer = BufferIO() + self._write_headers = {} + self._write_transforms = [] + + self.flags = 0 + self.sequence_id = 0 + self._protocol_id = default_protocol + self._max_frame_size = HARD_MAX_FRAME_SIZE + + def isOpen(self): + return self._transport.isOpen() + + def open(self): + return self._transport.open() + + def close(self): + return self._transport.close() + + def get_headers(self): + return self._read_headers + + def set_header(self, key, value): + if not isinstance(key, bytes): + raise ValueError("header names must be bytes") + if not isinstance(value, bytes): + raise ValueError("header values must be bytes") + self._write_headers[key] = value + + def clear_headers(self): + self._write_headers.clear() + + def add_transform(self, transform_id): + if transform_id not in WRITE_TRANSFORMS_BY_ID: + raise ValueError("unknown transform") + self._write_transforms.append(transform_id) + + def set_max_frame_size(self, size): + if not 0 < size < HARD_MAX_FRAME_SIZE: + raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE) + self._max_frame_size = size + + @property + def protocol_id(self): + if self._client_type == THeaderClientType.HEADERS: + return 
self._protocol_id + elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY): + return THeaderSubprotocolID.BINARY + elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT): + return THeaderSubprotocolID.COMPACT + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Protocol ID not know for client type %d" % self._client_type, + ) + + def read(self, sz): + # if there are bytes left in the buffer, produce those first. + bytes_read = self._read_buffer.read(sz) + bytes_left_to_read = sz - len(bytes_read) + if bytes_left_to_read == 0: + return bytes_read + + # if we've determined this is an unframed client, just pass the read + # through to the underlying transport until we're reset again at the + # beginning of the next message. + if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT): + return bytes_read + self._transport.read(bytes_left_to_read) + + # we're empty and (maybe) framed. fill the buffers with the next frame. + self.readFrame(bytes_left_to_read) + return bytes_read + self._read_buffer.read(bytes_left_to_read) + + def _set_client_type(self, client_type): + if client_type not in self._allowed_client_types: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Client type %d not allowed by server." % client_type, + ) + self._client_type = client_type + + def readFrame(self, req_sz): + # the first word could either be the length field of a framed message + # or the first bytes of an unframed message. + first_word = self._transport.readAll(I32.size) + frame_size, = I32.unpack(first_word) + is_unframed = False + if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1: + self._set_client_type(THeaderClientType.UNFRAMED_BINARY) + is_unframed = True + elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and + byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION): + self._set_client_type(THeaderClientType.UNFRAMED_COMPACT) + is_unframed = True + + if is_unframed: + bytes_left_to_read = req_sz - I32.size + if bytes_left_to_read > 0: + rest = self._transport.read(bytes_left_to_read) + else: + rest = b"" + self._read_buffer = BufferIO(first_word + rest) + return + + # ok, we're still here so we're framed. + if frame_size > self._max_frame_size: + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Frame was too large.", + ) + read_buffer = BufferIO(self._transport.readAll(frame_size)) + + # the next word is either going to be the version field of a + # binary/compact protocol message or the magic value + flags of a + # header protocol message. 
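# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): client-side use of
# the THeaderTransport added in this file. It wraps a plain TSocket, restricts
# the accepted wire formats, and attaches per-frame key/value headers (both
# must be bytes, as enforced by set_header()). Host/port are placeholders.
from thrift.transport import TSocket
from thrift.transport.THeaderTransport import (THeaderClientType,
                                               THeaderTransformID,
                                               THeaderTransport)

sock = TSocket.TSocket('localhost', 9090)
header_trans = THeaderTransport(
    sock,
    allowed_client_types=(THeaderClientType.HEADERS,
                          THeaderClientType.FRAMED_BINARY))
header_trans.set_header(b'client_id', b'example')     # sent with every frame
header_trans.add_transform(THeaderTransformID.ZLIB)   # zlib-compress the payload
# header_trans.open(); ...; header_trans.flush()      # flush() emits one header frame
# ---------------------------------------------------------------------------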
+ second_word = read_buffer.read(I32.size) + version, = I32.unpack(second_word) + read_buffer.seek(0) + if version >> 16 == HEADER_MAGIC: + self._set_client_type(THeaderClientType.HEADERS) + self._read_buffer = self._parse_header_format(read_buffer) + elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1: + self._set_client_type(THeaderClientType.FRAMED_BINARY) + self._read_buffer = read_buffer + elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and + byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION): + self._set_client_type(THeaderClientType.FRAMED_COMPACT) + self._read_buffer = read_buffer + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Could not detect client transport type.", + ) + + def _parse_header_format(self, buffer): + # make BufferIO look like TTransport for varint helpers + buffer_transport = TMemoryBuffer() + buffer_transport._buffer = buffer + + buffer.read(2) # discard the magic bytes + self.flags, = U16.unpack(buffer.read(U16.size)) + self.sequence_id, = I32.unpack(buffer.read(I32.size)) + + header_length = U16.unpack(buffer.read(U16.size))[0] * 4 + end_of_headers = buffer.tell() + header_length + if end_of_headers > len(buffer.getvalue()): + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Header size is larger than whole frame.", + ) + + self._protocol_id = readVarint(buffer_transport) + + transforms = [] + transform_count = readVarint(buffer_transport) + for _ in range(transform_count): + transform_id = readVarint(buffer_transport) + if transform_id not in READ_TRANSFORMS_BY_ID: + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Unknown transform: %d" % transform_id, + ) + transforms.append(transform_id) + transforms.reverse() + + headers = {} + while buffer.tell() < end_of_headers: + header_type = readVarint(buffer_transport) + if header_type == TInfoHeaderType.KEY_VALUE: + count = readVarint(buffer_transport) + for _ in range(count): + key = _readString(buffer_transport) + value = _readString(buffer_transport) + headers[key] = value + else: + break # ignore unknown headers + self._read_headers = headers + + # skip padding / anything we didn't understand + buffer.seek(end_of_headers) + + payload = buffer.read() + for transform_id in transforms: + transform_fn = READ_TRANSFORMS_BY_ID[transform_id] + payload = transform_fn(payload) + return BufferIO(payload) + + def write(self, buf): + self._write_buffer.write(buf) + + def flush(self): + payload = self._write_buffer.getvalue() + self._write_buffer = BufferIO() + + buffer = BufferIO() + if self._client_type == THeaderClientType.HEADERS: + for transform_id in self._write_transforms: + transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id] + payload = transform_fn(payload) + + headers = BufferIO() + writeVarint(headers, self._protocol_id) + writeVarint(headers, len(self._write_transforms)) + for transform_id in self._write_transforms: + writeVarint(headers, transform_id) + if self._write_headers: + writeVarint(headers, TInfoHeaderType.KEY_VALUE) + writeVarint(headers, len(self._write_headers)) + for key, value in self._write_headers.items(): + _writeString(headers, key) + _writeString(headers, value) + self._write_headers = {} + padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4 + headers.write(b"\x00" * padding_needed) + header_bytes = headers.getvalue() + + buffer.write(I32.pack(10 + len(header_bytes) + len(payload))) + buffer.write(U16.pack(HEADER_MAGIC)) + 
buffer.write(U16.pack(self.flags)) + buffer.write(I32.pack(self.sequence_id)) + buffer.write(U16.pack(len(header_bytes) // 4)) + buffer.write(header_bytes) + buffer.write(payload) + elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT): + buffer.write(I32.pack(len(payload))) + buffer.write(payload) + elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT): + buffer.write(payload) + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Unknown client type.", + ) + + # the frame length field doesn't count towards the frame payload size + frame_bytes = buffer.getvalue() + frame_payload_size = len(frame_bytes) - 4 + if frame_payload_size > self._max_frame_size: + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Attempting to send frame that is too large.", + ) + + self._transport.write(frame_bytes) + self._transport.flush() + + @property + def cstringio_buf(self): + return self._read_buffer + + def cstringio_refill(self, partialread, reqlen): + result = bytearray(partialread) + while len(result) < reqlen: + result += self.read(reqlen - len(result)) + self._read_buffer = BufferIO(result) + return self._read_buffer diff --git a/thrift/transport/THttpClient.py b/thrift/transport/THttpClient.py index 20be338..212da3a 100644 --- a/thrift/transport/THttpClient.py +++ b/thrift/transport/THttpClient.py @@ -17,133 +17,175 @@ # under the License. # -import http.client +from io import BytesIO import os -import socket +import ssl import sys -import urllib.request, urllib.parse, urllib.error -import urllib.parse import warnings +import base64 -from io import StringIO +from six.moves import urllib +from six.moves import http_client -from .TTransport import * +from .TTransport import TTransportBase +import six class THttpClient(TTransportBase): - """Http implementation of TTransport base.""" + """Http implementation of TTransport base.""" - def __init__(self, uri_or_host, port=None, path=None): - """THttpClient supports two different types constructor parameters. + def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + """THttpClient supports two different types of construction: - THttpClient(host, port, path) - deprecated - THttpClient(uri) + THttpClient(host, port, path) - deprecated + THttpClient(uri, [port=, path=, cafile=, cert_file=, key_file=, ssl_context=]) - Only the second supports https. - """ - if port is not None: - warnings.warn( - "Please use the THttpClient('http://host:port/path') syntax", - DeprecationWarning, - stacklevel=2) - self.host = uri_or_host - self.port = port - assert path - self.path = path - self.scheme = 'http' - else: - parsed = urllib.parse.urlparse(uri_or_host) - self.scheme = parsed.scheme - assert self.scheme in ('http', 'https') - if self.scheme == 'http': - self.port = parsed.port or http.client.HTTP_PORT - elif self.scheme == 'https': - self.port = parsed.port or http.client.HTTPS_PORT - self.host = parsed.hostname - self.path = parsed.path - if parsed.query: - self.path += '?%s' % parsed.query - self.__wbuf = StringIO() - self.__http = None - self.__timeout = None - self.__custom_headers = None + Only the second supports https. To properly authenticate against the server, + provide the client's identity by specifying cert_file and key_file. To properly + authenticate the server, specify either cafile or ssl_context with a CA defined. 
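# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): an HTTPS Thrift
# client using the keyword arguments added to THttpClient. The URI, CA bundle
# path, and "ExampleService" client are placeholders/assumptions.
from thrift.protocol import TBinaryProtocol
from thrift.transport import THttpClient, TTransport

http_trans = THttpClient.THttpClient('https://thrift.example.com:9090/api',
                                     cafile='/etc/ssl/certs/ca-bundle.crt')
http_trans.setCustomHeaders({'X-Request-Source': 'example-client'})
transport = TTransport.TBufferedTransport(http_trans)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
# client = ExampleService.Client(protocol)   # generated code, assumed
# transport.open(); client.ping(); transport.close()
# ---------------------------------------------------------------------------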
+ NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile. + """ + if port is not None: + warnings.warn( + "Please use the THttpClient('http{s}://host:port/path') constructor", + DeprecationWarning, + stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urllib.parse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or http_client.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or http_client.HTTPS_PORT + self.certfile = cert_file + self.keyfile = key_file + self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context + self.host = parsed.hostname + self.path = parsed.path + if parsed.query: + self.path += '?%s' % parsed.query + try: + proxy = urllib.request.getproxies()[self.scheme] + except KeyError: + proxy = None + else: + if urllib.request.proxy_bypass(self.host): + proxy = None + if proxy: + parsed = urllib.parse.urlparse(proxy) + self.realhost = self.host + self.realport = self.port + self.host = parsed.hostname + self.port = parsed.port + self.proxy_auth = self.basic_proxy_auth_header(parsed) + else: + self.realhost = self.realport = self.proxy_auth = None + self.__wbuf = BytesIO() + self.__http = None + self.__http_response = None + self.__timeout = None + self.__custom_headers = None - def open(self): - if self.scheme == 'http': - self.__http = http.client.HTTP(self.host, self.port) - else: - self.__http = http.client.HTTPS(self.host, self.port) + @staticmethod + def basic_proxy_auth_header(proxy): + if proxy is None or not proxy.username: + return None + ap = "%s:%s" % (urllib.parse.unquote(proxy.username), + urllib.parse.unquote(proxy.password)) + cr = base64.b64encode(ap).strip() + return "Basic " + cr - def close(self): - self.__http.close() - self.__http = None + def using_proxy(self): + return self.realhost is not None - def isOpen(self): - return self.__http is not None + def open(self): + if self.scheme == 'http': + self.__http = http_client.HTTPConnection(self.host, self.port, + timeout=self.__timeout) + elif self.scheme == 'https': + self.__http = http_client.HTTPSConnection(self.host, self.port, + key_file=self.keyfile, + cert_file=self.certfile, + timeout=self.__timeout, + context=self.context) + if self.using_proxy(): + self.__http.set_tunnel(self.realhost, self.realport, + {"Proxy-Authorization": self.proxy_auth}) - def setTimeout(self, ms): - if not hasattr(socket, 'getdefaulttimeout'): - raise NotImplementedError + def close(self): + self.__http.close() + self.__http = None + self.__http_response = None - if ms is None: - self.__timeout = None - else: - self.__timeout = ms / 1000.0 + def isOpen(self): + return self.__http is not None - def setCustomHeaders(self, headers): - self.__custom_headers = headers + def setTimeout(self, ms): + if ms is None: + self.__timeout = None + else: + self.__timeout = ms / 1000.0 - def read(self, sz): - return self.__http.file.read(sz) + def setCustomHeaders(self, headers): + self.__custom_headers = headers - def write(self, buf): - self.__wbuf.write(buf) + def read(self, sz): + return self.__http_response.read(sz) - def __withTimeout(f): - def _f(*args, **kwargs): - orig_timeout = socket.getdefaulttimeout() - socket.setdefaulttimeout(args[0].__timeout) - result = f(*args, **kwargs) - socket.setdefaulttimeout(orig_timeout) - return result - return _f + def write(self, 
buf): + self.__wbuf.write(buf) - def flush(self): - if self.isOpen(): - self.close() - self.open() + def flush(self): + if self.isOpen(): + self.close() + self.open() - # Pull data out of buffer - data = self.__wbuf.getvalue() - self.__wbuf = StringIO() + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = BytesIO() - # HTTP request - self.__http.putrequest('POST', self.path) + # HTTP request + if self.using_proxy() and self.scheme == "http": + # need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel) + self.__http.putrequest('POST', "http://%s:%s%s" % + (self.realhost, self.realport, self.path)) + else: + self.__http.putrequest('POST', self.path) - # Write headers - self.__http.putheader('Host', self.host) - self.__http.putheader('Content-Type', 'application/x-thrift') - self.__http.putheader('Content-Length', str(len(data))) + # Write headers + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None: + self.__http.putheader("Proxy-Authorization", self.proxy_auth) - if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: - user_agent = 'Python/THttpClient' - script = os.path.basename(sys.argv[0]) - if script: - user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) - self.__http.putheader('User-Agent', user_agent) + if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: + user_agent = 'Python/THttpClient' + script = os.path.basename(sys.argv[0]) + if script: + user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) + self.__http.putheader('User-Agent', user_agent) - if self.__custom_headers: - for key, val in self.__custom_headers.items(): - self.__http.putheader(key, val) + if self.__custom_headers: + for key, val in six.iteritems(self.__custom_headers): + self.__http.putheader(key, val) - self.__http.endheaders() + self.__http.endheaders() - # Write payload - self.__http.send(data) + # Write payload + self.__http.send(data) - # Get reply to flush the request - self.code, self.message, self.headers = self.__http.getreply() + # Get reply to flush the request + self.__http_response = self.__http.getresponse() + self.code = self.__http_response.status + self.message = self.__http_response.reason + self.headers = self.__http_response.msg - # Decorate if we know how to timeout - if hasattr(socket, 'getdefaulttimeout'): - flush = __withTimeout(flush) + # Saves the cookie sent by the server response + if 'Set-Cookie' in self.headers: + self.__http.putheader('Cookie', self.headers['Set-Cookie']) diff --git a/thrift/transport/TSSLSocket.py b/thrift/transport/TSSLSocket.py index e0ff4f9..5b3ae59 100644 --- a/thrift/transport/TSSLSocket.py +++ b/thrift/transport/TSSLSocket.py @@ -17,186 +17,392 @@ # under the License. # +import logging import os import socket import ssl +import sys +import warnings +from .sslcompat import _match_hostname, _match_has_ipaddress from thrift.transport import TSocket from thrift.transport.TTransport import TTransportException +logger = logging.getLogger(__name__) +warnings.filterwarnings( + 'default', category=DeprecationWarning, module=__name__) -class TSSLSocket(TSocket.TSocket): - """ - SSL implementation of client-side TSocket - This class creates outbound sockets wrapped using the - python standard ssl module for encrypted connections. 
+class TSSLBase(object): + # SSLContext is not available for Python < 2.7.9 + _has_ssl_context = sys.hexversion >= 0x020709F0 - The protocol used is set using the class variable - SSL_VERSION, which must be one of ssl.PROTOCOL_* and - defaults to ssl.PROTOCOL_TLSv1 for greatest security. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 + # ciphers argument is not available for Python < 2.7.0 + _has_ciphers = sys.hexversion >= 0x020700F0 - def __init__(self, - host='localhost', - port=9090, - validate=True, - ca_certs=None, - unix_socket=None): - """Create SSL TSocket + # For python >= 2.7.9, use latest TLS that both client and server + # supports. + # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3. + # For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is + # unavailable. + _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \ + ssl.PROTOCOL_TLSv1 - @param validate: Set to False to disable SSL certificate validation - @type validate: bool - @param ca_certs: Filename to the Certificate Authority pem file, possibly a - file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to - the ssl_wrap function as the 'ca_certs' parameter. - @type ca_certs: str + def _init_context(self, ssl_version): + if self._has_ssl_context: + self._context = ssl.SSLContext(ssl_version) + if self._context.protocol == ssl.PROTOCOL_SSLv23: + self._context.options |= ssl.OP_NO_SSLv2 + self._context.options |= ssl.OP_NO_SSLv3 + else: + self._context = None + self._ssl_version = ssl_version - Raises an IOError exception if validate is True and the ca_certs file is - None, not present or unreadable. + @property + def _should_verify(self): + if self._has_ssl_context: + return self._context.verify_mode != ssl.CERT_NONE + else: + return self.cert_reqs != ssl.CERT_NONE + + @property + def ssl_version(self): + if self._has_ssl_context: + return self.ssl_context.protocol + else: + return self._ssl_version + + @property + def ssl_context(self): + return self._context + + SSL_VERSION = _default_protocol """ - self.validate = validate - self.is_valid = False - self.peercert = None - if not validate: - self.cert_reqs = ssl.CERT_NONE - else: - self.cert_reqs = ssl.CERT_REQUIRED - self.ca_certs = ca_certs - if validate: - if ca_certs is None or not os.access(ca_certs, os.R_OK): - raise IOError('Certificate Authority ca_certs file "%s" ' - 'is not readable, cannot validate SSL ' - 'certificates.' % (ca_certs)) - TSocket.TSocket.__init__(self, host, port, unix_socket) + Default SSL version. + For backwards compatibility, it can be modified. + Use __init__ keyword argument "ssl_version" instead. + """ - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - sock_family, sock_type = res[0:2] - ip_port = res[4] - plain_sock = socket.socket(sock_family, sock_type) - self.handle = ssl.wrap_socket(plain_sock, - ssl_version=self.SSL_VERSION, - do_handshake_on_connect=True, - ca_certs=self.ca_certs, - cert_reqs=self.cert_reqs) - self.handle.settimeout(self._timeout) + def _deprecated_arg(self, args, kwargs, pos, key): + if len(args) <= pos: + return + real_pos = pos + 3 + warnings.warn( + '%dth positional argument is deprecated.' + 'please use keyword argument instead.' + % real_pos, DeprecationWarning, stacklevel=3) + + if key in kwargs: + raise TypeError( + 'Duplicate argument: %dth argument and %s keyword argument.' 
+ % (real_pos, key)) + kwargs[key] = args[pos] + + def _unix_socket_arg(self, host, port, args, kwargs): + key = 'unix_socket' + if host is None and port is None and len(args) == 1 and key not in kwargs: + kwargs[key] = args[0] + return True + return False + + def __getattr__(self, key): + if key == 'SSL_VERSION': + warnings.warn( + 'SSL_VERSION is deprecated.' + 'please use ssl_version attribute instead.', + DeprecationWarning, stacklevel=2) + return self.ssl_version + + def __init__(self, server_side, host, ssl_opts): + self._server_side = server_side + if TSSLBase.SSL_VERSION != self._default_protocol: + warnings.warn( + 'SSL_VERSION is deprecated.' + 'please use ssl_version keyword argument instead.', + DeprecationWarning, stacklevel=2) + self._context = ssl_opts.pop('ssl_context', None) + self._server_hostname = None + if not self._server_side: + self._server_hostname = ssl_opts.pop('server_hostname', host) + if self._context: + self._custom_context = True + if ssl_opts: + raise ValueError( + 'Incompatible arguments: ssl_context and %s' + % ' '.join(ssl_opts.keys())) + if not self._has_ssl_context: + raise ValueError( + 'ssl_context is not available for this version of Python') + else: + self._custom_context = False + ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION) + self._init_context(ssl_version) + self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED) + self.ca_certs = ssl_opts.pop('ca_certs', None) + self.keyfile = ssl_opts.pop('keyfile', None) + self.certfile = ssl_opts.pop('certfile', None) + self.ciphers = ssl_opts.pop('ciphers', None) + + if ssl_opts: + raise ValueError( + 'Unknown keyword arguments: ', ' '.join(ssl_opts.keys())) + + if self._should_verify: + if not self.ca_certs: + raise ValueError( + 'ca_certs is needed when cert_reqs is not ssl.CERT_NONE') + if not os.access(self.ca_certs, os.R_OK): + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (self.ca_certs)) + + @property + def certfile(self): + return self._certfile + + @certfile.setter + def certfile(self, certfile): + if self._server_side and not certfile: + raise ValueError('certfile is needed for server-side') + if certfile and not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % (certfile)) + self._certfile = certfile + + def _wrap_socket(self, sock): + if self._has_ssl_context: + if not self._custom_context: + self.ssl_context.verify_mode = self.cert_reqs + if self.certfile: + self.ssl_context.load_cert_chain(self.certfile, + self.keyfile) + if self.ciphers: + self.ssl_context.set_ciphers(self.ciphers) + if self.ca_certs: + self.ssl_context.load_verify_locations(self.ca_certs) + return self.ssl_context.wrap_socket( + sock, server_side=self._server_side, + server_hostname=self._server_hostname) + else: + ssl_opts = { + 'ssl_version': self._ssl_version, + 'server_side': self._server_side, + 'ca_certs': self.ca_certs, + 'keyfile': self.keyfile, + 'certfile': self.certfile, + 'cert_reqs': self.cert_reqs, + } + if self.ciphers: + if self._has_ciphers: + ssl_opts['ciphers'] = self.ciphers + else: + logger.warning( + 'ciphers is specified but ignored due to old Python version') + return ssl.wrap_socket(sock, **ssl_opts) + + +class TSSLSocket(TSocket.TSocket, TSSLBase): + """ + SSL implementation of TSocket + + This class creates outbound sockets wrapped using the + python standard ssl module for encrypted connections. 
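# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): a client TSSLSocket
# using the new keyword-only SSL options instead of the deprecated positional
# validate/ca_certs form. Hostname and CA bundle path are placeholders.
import ssl

from thrift.protocol import TBinaryProtocol
from thrift.transport import TSSLSocket, TTransport

sock = TSSLSocket.TSSLSocket('thrift.example.com', 9090,
                             cert_reqs=ssl.CERT_REQUIRED,
                             ca_certs='/etc/ssl/certs/ca-bundle.crt')
transport = TTransport.TBufferedTransport(sock)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
# client = ExampleService.Client(protocol)   # generated code, assumed
# transport.open()   # TLS handshake + validate_callback run during open()
# ---------------------------------------------------------------------------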
+ """ + + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, + # **ssl_args): + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, + # ca_certs=None, keyfile=None, certfile=None, + # unix_socket=None, ciphers=None): + def __init__(self, host='localhost', port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` + + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, + ``ssl_version``, ``ca_certs``, + ``ciphers`` (Python 2.7.0 or later), + ``server_hostname`` (Python 2.7.9 or later) + Passed to ssl.wrap_socket. See ssl.wrap_socket documentation. + + Alternative keyword arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket + + Common keyword argument: + ``validate_callback`` (cert, hostname) -> None: + Called after SSL handshake. Can raise when hostname does not + match the cert. + ``socket_keepalive`` enable TCP keepalive, default off. + """ + self.is_valid = False + self.peercert = None + + if args: + if len(args) > 6: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'validate') + self._deprecated_arg(args, kwargs, 1, 'ca_certs') + self._deprecated_arg(args, kwargs, 2, 'keyfile') + self._deprecated_arg(args, kwargs, 3, 'certfile') + self._deprecated_arg(args, kwargs, 4, 'unix_socket') + self._deprecated_arg(args, kwargs, 5, 'ciphers') + + validate = kwargs.pop('validate', None) + if validate is not None: + cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE' + warnings.warn( + 'validate is deprecated. please use cert_reqs=ssl.%s instead' + % cert_reqs_name, + DeprecationWarning, stacklevel=2) + if 'cert_reqs' in kwargs: + raise TypeError('Cannot specify both validate and cert_reqs') + kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE + + unix_socket = kwargs.pop('unix_socket', None) + socket_keepalive = kwargs.pop('socket_keepalive', False) + self._validate_callback = kwargs.pop('validate_callback', _match_hostname) + TSSLBase.__init__(self, False, host, kwargs) + TSocket.TSocket.__init__(self, host, port, unix_socket, + socket_keepalive=socket_keepalive) + + def close(self): try: - self.handle.connect(ip_port) - except socket.error as e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error as e: - if self._unix_socket: - message = 'Could not connect to secure socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, - message=message) - if self.validate: - self._validate_cert() + self.handle.settimeout(0.001) + self.handle = self.handle.unwrap() + except (ssl.SSLError, socket.error, OSError): + # could not complete shutdown in a reasonable amount of time. bail. + pass + TSocket.TSocket.close(self) - def _validate_cert(self): - """internal method to validate the peer's SSL certificate, and to check the - commonName of the certificate to ensure it matches the hostname we - used to make this connection. Does not support subjectAltName records - in certificates. + @property + def validate(self): + warnings.warn('validate is deprecated. 
please use cert_reqs instead', + DeprecationWarning, stacklevel=2) + return self.cert_reqs != ssl.CERT_NONE - raises TTransportException if the certificate fails validation. + @validate.setter + def validate(self, value): + warnings.warn('validate is deprecated. please use cert_reqs instead', + DeprecationWarning, stacklevel=2) + self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE + + def _do_open(self, family, socktype): + plain_sock = socket.socket(family, socktype) + try: + return self._wrap_socket(plain_sock) + except Exception as ex: + plain_sock.close() + msg = 'failed to initialize SSL' + logger.exception(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex) + + def open(self): + super(TSSLSocket, self).open() + if self._should_verify: + self.peercert = self.handle.getpeercert() + try: + self._validate_callback(self.peercert, self._server_hostname) + self.is_valid = True + except TTransportException: + raise + except Exception as ex: + raise TTransportException(message=str(ex), inner=ex) + + +class TSSLServerSocket(TSocket.TServerSocket, TSSLBase): + """SSL implementation of TServerSocket + + This uses the ssl module's wrap_socket() method to provide SSL + negotiated encryption. """ - cert = self.handle.getpeercert() - self.peercert = cert - if 'subject' not in cert: - raise TTransportException( - type=TTransportException.NOT_OPEN, - message='No SSL certificate found from %s:%s' % (self.host, self.port)) - fields = cert['subject'] - for field in fields: - # ensure structure we get back is what we expect - if not isinstance(field, tuple): - continue - cert_pair = field[0] - if len(cert_pair) < 2: - continue - cert_key, cert_value = cert_pair[0:2] - if cert_key != 'commonName': - continue - certhost = cert_value - if certhost == self.host: - # success, cert commonName matches desired hostname - self.is_valid = True - return - else: - raise TTransportException( - type=TTransportException.UNKNOWN, - message='Hostname we connected to "%s" doesn\'t match certificate ' - 'provided commonName "%s"' % (self.host, certhost)) - raise TTransportException( - type=TTransportException.UNKNOWN, - message='Could not validate SSL certificate from ' - 'host "%s". Cert=%s' % (self.host, cert)) + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + def __init__(self, host=None, port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` -class TSSLServerSocket(TSocket.TServerSocket): - """SSL implementation of TServerSocket + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, + ``ca_certs``, ``ciphers`` (Python 2.7.0 or later) + See ssl.wrap_socket documentation. - This uses the ssl module's wrap_socket() method to provide SSL - negotiated encryption. - """ - SSL_VERSION = ssl.PROTOCOL_TLSv1 + Alternative keyword arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket - def __init__(self, - host=None, - port=9090, - certfile='cert.pem', - unix_socket=None): - """Initialize a TSSLServerSocket + Common keyword argument: + ``validate_callback`` (cert, hostname) -> None: + Called after SSL handshake. Can raise when hostname does not + match the cert. 
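# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): server-side SSL with
# the keyword-argument form. Certificate/key paths are placeholders. Unless
# ssl_context is given, cert_reqs defaults to ssl.CERT_NONE (no client certs);
# pass cert_reqs/ca_certs to require and verify them.
from thrift.transport import TSSLSocket

server_sock = TSSLSocket.TSSLServerSocket(host='0.0.0.0', port=9090,
                                          certfile='/etc/certs/server.pem',
                                          keyfile='/etc/certs/server.key')
# server_sock.listen()
# client = server_sock.accept()   # returns None if the TLS handshake fails
# ---------------------------------------------------------------------------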
+ """ + if args: + if len(args) > 3: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'certfile') + self._deprecated_arg(args, kwargs, 1, 'unix_socket') + self._deprecated_arg(args, kwargs, 2, 'ciphers') - @param certfile: filename of the server certificate, defaults to cert.pem - @type certfile: str - @param host: The hostname or IP to bind the listen socket to, - i.e. 'localhost' for only allowing local network connections. - Pass None to bind to all interfaces. - @type host: str - @param port: The port to listen on for inbound connections. - @type port: int - """ - self.setCertfile(certfile) - TSocket.TServerSocket.__init__(self, host, port) + if 'ssl_context' not in kwargs: + # Preserve existing behaviors for default values + if 'cert_reqs' not in kwargs: + kwargs['cert_reqs'] = ssl.CERT_NONE + if'certfile' not in kwargs: + kwargs['certfile'] = 'cert.pem' - def setCertfile(self, certfile): - """Set or change the server certificate file used to wrap new connections. + unix_socket = kwargs.pop('unix_socket', None) + self._validate_callback = \ + kwargs.pop('validate_callback', _match_hostname) + TSSLBase.__init__(self, True, None, kwargs) + TSocket.TServerSocket.__init__(self, host, port, unix_socket) + if self._should_verify and not _match_has_ipaddress: + raise ValueError('Need ipaddress and backports.ssl_match_hostname ' + 'module to verify client certificate') - @param certfile: The filename of the server certificate, - i.e. '/etc/certs/server.pem' - @type certfile: str + def setCertfile(self, certfile): + """Set or change the server certificate file used to wrap new + connections. - Raises an IOError exception if the certfile is not present or unreadable. - """ - if not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % (certfile)) - self.certfile = certfile + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' + @type certfile: str - def accept(self): - plain_client, addr = self.handle.accept() - try: - client = ssl.wrap_socket(plain_client, certfile=self.certfile, - server_side=True, ssl_version=self.SSL_VERSION) - except ssl.SSLError as ssl_exc: - # failed handshake/ssl wrap, close socket to client - plain_client.close() - # raise ssl_exc - # We can't raise the exception, because it kills most TServer derived - # serve() methods. - # Instead, return None, and let the TServer instance deal with it in - # other exception handling. (but TSimpleServer dies anyway) - return None - result = TSocket.TSocket() - result.setHandle(client) - return result + Raises an IOError exception if the certfile is not present or unreadable. + """ + warnings.warn( + 'setCertfile is deprecated. please use certfile property instead.', + DeprecationWarning, stacklevel=2) + self.certfile = certfile + + def accept(self): + plain_client, addr = self.handle.accept() + try: + client = self._wrap_socket(plain_client) + except (ssl.SSLError, socket.error, OSError): + logger.exception('Error while accepting from %s', addr) + # failed handshake/ssl wrap, close socket to client + plain_client.close() + # raise + # We can't raise the exception, because it kills most TServer derived + # serve() methods. + # Instead, return None, and let the TServer instance deal with it in + # other exception handling. 
(but TSimpleServer dies anyway) + return None + + if self._should_verify: + client.peercert = client.getpeercert() + try: + self._validate_callback(client.peercert, addr[0]) + client.is_valid = True + except Exception: + logger.warn('Failed to validate client certificate address: %s', + addr[0], exc_info=True) + client.close() + plain_client.close() + return None + + result = TSocket.TSocket() + result.handle = client + return result diff --git a/thrift/transport/TSocket.py b/thrift/transport/TSocket.py index 82ce568..3c7a3ca 100644 --- a/thrift/transport/TSocket.py +++ b/thrift/transport/TSocket.py @@ -18,159 +18,222 @@ # import errno +import logging import os import socket import sys -from .TTransport import * +from .TTransport import TTransportBase, TTransportException, TServerTransportBase + +logger = logging.getLogger(__name__) class TSocketBase(TTransportBase): - def _resolveAddr(self): - if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, - self._unix_socket)] - else: - return socket.getaddrinfo(self.host, - self.port, - socket.AF_UNSPEC, - socket.SOCK_STREAM, - 0, - socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] + else: + return socket.getaddrinfo(self.host, + self.port, + self._socket_family, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE) - def close(self): - if self.handle: - self.handle.close() - self.handle = None + def close(self): + if self.handle: + self.handle.close() + self.handle = None class TSocket(TSocketBase): - """Socket implementation of TTransport base.""" + """Socket implementation of TTransport base.""" - def __init__(self, host='localhost', port=9090, unix_socket=None): - """Initialize a TSocket + def __init__(self, host='localhost', port=9090, unix_socket=None, + socket_family=socket.AF_UNSPEC, + socket_keepalive=False): + """Initialize a TSocket - @param host(str) The host to connect to. - @param port(int) The (TCP) port to connect to. - @param unix_socket(str) The filename of a unix socket to connect to. - (host and port will be ignored.) - """ - self.host = host - self.port = port - self.handle = None - self._unix_socket = unix_socket - self._timeout = None + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + @param socket_family(int) The socket family to use with this socket. + @param socket_keepalive(bool) enable TCP keepalive, default off. + """ + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + self._socket_family = socket_family + self._socket_keepalive = socket_keepalive - def setHandle(self, h): - self.handle = h + def setHandle(self, h): + self.handle = h - def isOpen(self): - return self.handle is not None + def isOpen(self): + if self.handle is None: + return False - def setTimeout(self, ms): - if ms is None: - self._timeout = None - else: - self._timeout = ms / 1000.0 - - if self.handle is not None: - self.handle.settimeout(self._timeout) - - def open(self): - try: - res0 = self._resolveAddr() - for res in res0: - self.handle = socket.socket(res[0], res[1]) - self.handle.settimeout(self._timeout) + # this lets us cheaply see if the other end of the socket is still + # connected. 
if disconnected, we'll get EOF back (expressed as zero + # bytes of data) otherwise we'll get one byte or an error indicating + # we'd have to block for data. + # + # note that we're not doing this with socket.MSG_DONTWAIT because 1) + # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us + # when timeout is non-zero. + original_timeout = self.handle.gettimeout() try: - self.handle.connect(res[4]) + self.handle.settimeout(0) + try: + peeked_bytes = self.handle.recv(1, socket.MSG_PEEK) + except (socket.error, OSError) as exc: # on modern python this is just BlockingIOError + if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN): + return True + return False + finally: + self.handle.settimeout(original_timeout) + + # the length will be zero if we got EOF (indicating connection closed) + return len(peeked_bytes) == 1 + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms / 1000.0 + + if self.handle is not None: + self.handle.settimeout(self._timeout) + + def _do_open(self, family, socktype): + return socket.socket(family, socktype) + + @property + def _address(self): + return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port) + + def open(self): + if self.handle: + raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open") + try: + addrs = self._resolveAddr() + except socket.gaierror as gai: + msg = 'failed to resolve sockaddr for ' + str(self._address) + logger.exception(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai) + for family, socktype, _, _, sockaddr in addrs: + handle = self._do_open(family, socktype) + + # TCP_KEEPALIVE + if self._socket_keepalive: + handle.setsockopt(socket.IPPROTO_TCP, socket.SO_KEEPALIVE, 1) + + handle.settimeout(self._timeout) + try: + handle.connect(sockaddr) + self.handle = handle + return + except socket.error: + handle.close() + logger.info('Could not connect to %s', sockaddr, exc_info=True) + msg = 'Could not connect to any of %s' % list(map(lambda a: a[4], + addrs)) + logger.error(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg) + + def read(self, sz): + try: + buff = self.handle.recv(sz) except socket.error as e: - if res is not res0[-1]: - continue - else: - raise e - break - except socket.error as e: - if self._unix_socket: - message = 'Could not connect to socket %s' % self._unix_socket - else: - message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, - message=message) + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. 
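# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): a client TSocket
# using the new socket_keepalive flag and the TTransportException types raised
# by open()/read() above. Host/port are placeholders.
from thrift.transport import TSocket, TTransport

sock = TSocket.TSocket('localhost', 9090, socket_keepalive=True)
sock.setTimeout(5000)   # milliseconds; applied to connect and subsequent reads
transport = TTransport.TBufferedTransport(sock)
try:
    transport.open()
except TTransport.TTransportException as exc:
    # exc.inner (new in this patch) holds the underlying exception when one was captured
    print('connect failed: %s' % exc)
# ---------------------------------------------------------------------------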
+ buff = '' + elif e.args[0] == errno.ETIMEDOUT: + raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e) + else: + raise TTransportException(message="unexpected exception", inner=e) + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff - def read(self, sz): - try: - buff = self.handle.recv(sz) - except socket.error as e: - if (e.args[0] == errno.ECONNRESET and - (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): - # freebsd and Mach don't follow POSIX semantic of recv - # and fail with ECONNRESET if peer performed shutdown. - # See corresponding comment and code in TSocket::read() - # in lib/cpp/src/transport/TSocket.cpp. - self.close() - # Trigger the check to raise the END_OF_FILE exception below. - buff = '' - else: - raise - if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, - message='TSocket read 0 bytes') - return buff + def write(self, buff): + if not self.handle: + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') + sent = 0 + have = len(buff) + while sent < have: + try: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + except socket.error as e: + raise TTransportException(message="unexpected exception", inner=e) - def write(self, buff): - if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, - message='Transport not open') - sent = 0 - have = len(buff) - while sent < have: - plus = self.handle.send(buff) - if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, - message='TSocket sent 0 bytes') - sent += plus - buff = buff[plus:] - - def flush(self): - pass + def flush(self): + pass class TServerSocket(TSocketBase, TServerTransportBase): - """Socket implementation of TServerTransport base.""" + """Socket implementation of TServerTransport base.""" - def __init__(self, host=None, port=9090, unix_socket=None): - self.host = host - self.port = port - self._unix_socket = unix_socket - self.handle = None + def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): + self.host = host + self.port = port + self._unix_socket = unix_socket + self._socket_family = socket_family + self.handle = None + self._backlog = 128 - def listen(self): - res0 = self._resolveAddr() - for res in res0: - if res[0] is socket.AF_INET6 or res is res0[-1]: - break + def setBacklog(self, backlog=None): + if not self.handle: + self._backlog = backlog + else: + # We cann't update backlog when it is already listening, since the + # handle has been created. + logger.warn('You have to set backlog before listen.') - # We need remove the old unix socket if the file exists and - # nobody is listening on it. 
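# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch): a TServerSocket with
# the new setBacklog() hook. Per the warning above, the backlog must be set
# before listen(), since listen() creates the underlying handle.
from thrift.transport import TSocket

server_sock = TSocket.TServerSocket(port=9090)
server_sock.setBacklog(64)        # default remains 128 if not set
server_sock.listen()
# client = server_sock.accept()   # returns a connected TSocket
# ---------------------------------------------------------------------------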
- if self._unix_socket: - tmp = socket.socket(res[0], res[1]) - try: - tmp.connect(res[4]) - except socket.error as err: - eno, message = err.args - if eno == errno.ECONNREFUSED: - os.unlink(res[4]) + def listen(self): + res0 = self._resolveAddr() + socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family + for res in res0: + if res[0] is socket_family or res is res0[-1]: + break - self.handle = socket.socket(res[0], res[1]) - self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(self.handle, 'settimeout'): - self.handle.settimeout(None) - self.handle.bind(res[4]) - self.handle.listen(128) + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error as err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) - def accept(self): - client, addr = self.handle.accept() - result = TSocket() - result.setHandle(client) - return result + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'settimeout'): + self.handle.settimeout(None) + self.handle.bind(res[4]) + self.handle.listen(self._backlog) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/thrift/transport/TTransport.py b/thrift/transport/TTransport.py index dcedd3d..9dbe95d 100644 --- a/thrift/transport/TTransport.py +++ b/thrift/transport/TTransport.py @@ -17,317 +17,440 @@ # under the License. # -from six import BytesIO from struct import pack, unpack from thrift.Thrift import TException +from ..compat import BufferIO class TTransportException(TException): - """Custom Transport Exception class""" + """Custom Transport Exception class""" - UNKNOWN = 0 - NOT_OPEN = 1 - ALREADY_OPEN = 2 - TIMED_OUT = 3 - END_OF_FILE = 4 + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + NEGATIVE_SIZE = 5 + SIZE_LIMIT = 6 + INVALID_CLIENT_TYPE = 7 - def __init__(self, type=UNKNOWN, message=None): - TException.__init__(self, message) - self.type = type + def __init__(self, type=UNKNOWN, message=None, inner=None): + TException.__init__(self, message) + self.type = type + self.inner = inner -class TTransportBase: - """Base class for Thrift transport layer.""" +class TTransportBase(object): + """Base class for Thrift transport layer.""" - def isOpen(self): - pass + def isOpen(self): + pass - def open(self): - pass + def open(self): + pass - def close(self): - pass + def close(self): + pass - def read(self, sz): - pass + def read(self, sz): + pass - def readAll(self, sz): - buff = b'' - have = 0 - while (have < sz): - chunk = self.read(sz - have) - have += len(chunk) - buff += chunk + def readAll(self, sz): + buff = b'' + have = 0 + while (have < sz): + chunk = self.read(sz - have) + chunkLen = len(chunk) + have += chunkLen + buff += chunk - if len(chunk) == 0: - raise EOFError() + if chunkLen == 0: + raise EOFError() - return buff + return buff - def write(self, buf): - pass + def write(self, buf): + pass - def flush(self): - pass + def flush(self): + pass # This class should be thought of as an interface. 
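# --- Illustrative aside (not part of the upstream patch) ---
# A minimal sketch of the readAll() contract defined in TTransportBase above:
# readAll(sz) keeps calling read() until exactly sz bytes have been collected,
# and raises EOFError as soon as read() returns an empty chunk. The
# ChunkedTransport class below is a hypothetical toy used only for this sketch.
from thrift.transport.TTransport import TTransportBase

class ChunkedTransport(TTransportBase):
    """Toy transport that hands back at most two bytes per read() call."""
    def __init__(self, data):
        self._data = data

    def read(self, sz):
        n = min(sz, 2)
        chunk, self._data = self._data[:n], self._data[n:]
        return chunk

t = ChunkedTransport(b'abcdef')
assert t.readAll(5) == b'abcde'   # assembled from three read() calls
# A further t.readAll(5) would raise EOFError: only b'f' remains before EOF.
# --- end of aside ---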
-class CReadableTransport: - """base class for transports that are readable from C""" +class CReadableTransport(object): + """base class for transports that are readable from C""" - # TODO(dreiss): Think about changing this interface to allow us to use - # a (Python, not c) StringIO instead, because it allows - # you to write after reading. + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. - # NOTE: This is a classic class, so properties will NOT work - # correctly for setting. - @property - def cstringio_buf(self): - """A cStringIO buffer that contains the current chunk we are reading.""" - pass + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass - def cstringio_refill(self, partialread, reqlen): - """Refills cstringio_buf. + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. - Returns the currently used buffer (which can but need not be the same as - the old cstringio_buf). partialread is what the C code has read from the - buffer, and should be inserted into the buffer before any more reads. The - return value must be a new, not borrowed reference. Something along the - lines of self._buf should be fine. + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. - If reqlen bytes can't be read, throw EOFError. - """ - pass + If reqlen bytes can't be read, throw EOFError. + """ + pass -class TServerTransportBase: - """Base class for Thrift server transports.""" +class TServerTransportBase(object): + """Base class for Thrift server transports.""" - def listen(self): - pass + def listen(self): + pass - def accept(self): - pass + def accept(self): + pass - def close(self): - pass + def close(self): + pass -class TTransportFactoryBase: - """Base class for a Transport Factory""" +class TTransportFactoryBase(object): + """Base class for a Transport Factory""" - def getTransport(self, trans): - return trans + def getTransport(self, trans): + return trans -class TBufferedTransportFactory: - """Factory transport that builds buffered transports""" +class TBufferedTransportFactory(object): + """Factory transport that builds buffered transports""" - def getTransport(self, trans): - buffered = TBufferedTransport(trans) - return buffered + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered class TBufferedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and buffers its I/O. + """Class that wraps another transport and buffers its I/O. - The implementation uses a (configurable) fixed-size read buffer - but buffers all writes until a flush is performed. - """ - DEFAULT_BUFFER = 4096 + The implementation uses a (configurable) fixed-size read buffer + but buffers all writes until a flush is performed. 
+ """ + DEFAULT_BUFFER = 4096 - def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): - self.__trans = trans - self.__wbuf = BytesIO() - self.__rbuf = BytesIO("") - self.__rbuf_size = rbuf_size + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): + self.__trans = trans + self.__wbuf = BufferIO() + # Pass string argument to initialize read buffer as cStringIO.InputType + self.__rbuf = BufferIO(b'') + self.__rbuf_size = rbuf_size - def isOpen(self): - return self.__trans.isOpen() + def isOpen(self): + return self.__trans.isOpen() - def open(self): - return self.__trans.open() + def open(self): + return self.__trans.open() - def close(self): - return self.__trans.close() + def close(self): + return self.__trans.close() - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size))) + return self.__rbuf.read(sz) - self.__rbuf = BytesIO(self.__trans.read(max(sz, self.__rbuf_size))) - return self.__rbuf.read(sz) + def write(self, buf): + try: + self.__wbuf.write(buf) + except Exception as e: + # on exception reset wbuf so it doesn't contain a partial function call + self.__wbuf = BufferIO() + raise e - def write(self, buf): - self.__wbuf.write(buf) + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + self.__trans.write(out) + self.__trans.flush() - def flush(self): - out = self.__wbuf.getvalue() - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = BytesIO() - self.__trans.write(out) - self.__trans.flush() + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.__rbuf_size: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.__rbuf_size) - def cstringio_refill(self, partialread, reqlen): - retstring = partialread - if reqlen < self.__rbuf_size: - # try to make a read of as much as we can. - retstring += self.__trans.read(self.__rbuf_size) + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) - # but make sure we do read reqlen bytes. - if len(retstring) < reqlen: - retstring += self.__trans.readAll(reqlen - len(retstring)) - - self.__rbuf = BytesIO(retstring) - return self.__rbuf + self.__rbuf = BufferIO(retstring) + return self.__rbuf class TMemoryBuffer(TTransportBase, CReadableTransport): - """Wraps a cStringIO object as a TTransport. + """Wraps a cBytesIO object as a TTransport. - NOTE: Unlike the C++ version of this class, you cannot write to it - then immediately read from it. If you want to read from a - TMemoryBuffer, you must either pass a string to the constructor. - TODO(dreiss): Make this work like the C++ version. - """ + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. 
+ """ - def __init__(self, value=None): - """value -- a value to read from for stringio + def __init__(self, value=None, offset=0): + """value -- a value to read from for stringio - If value is set, this will be a transport for reading, - otherwise, it is for writing""" - if value is not None: - self._buffer = BytesIO(value) - else: - self._buffer = BytesIO() + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = BufferIO(value) + else: + self._buffer = BufferIO() + if offset: + self._buffer.seek(offset) - def isOpen(self): - return not self._buffer.closed + def isOpen(self): + return not self._buffer.closed - def open(self): - pass + def open(self): + pass - def close(self): - self._buffer.close() + def close(self): + self._buffer.close() - def read(self, sz): - return self._buffer.read(sz) + def read(self, sz): + return self._buffer.read(sz) - def write(self, buf): - try: - self._buffer.write(buf) - except TypeError: - self._buffer.write(buf.encode('cp437')) + def write(self, buf): + self._buffer.write(buf) - def flush(self): - pass + def flush(self): + pass - def getvalue(self): - return self._buffer.getvalue() + def getvalue(self): + return self._buffer.getvalue() - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self._buffer + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer - def cstringio_refill(self, partialread, reqlen): - # only one shot at reading... - raise EOFError() + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() -class TFramedTransportFactory: - """Factory transport that builds framed transports""" +class TFramedTransportFactory(object): + """Factory transport that builds framed transports""" - def getTransport(self, trans): - framed = TFramedTransport(trans) - return framed + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed class TFramedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and frames its I/O when writing.""" + """Class that wraps another transport and frames its I/O when writing.""" - def __init__(self, trans,): - self.__trans = trans - self.__rbuf = BytesIO() - self.__wbuf = BytesIO() + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = BufferIO(b'') + self.__wbuf = BufferIO() - def isOpen(self): - return self.__trans.isOpen() + def isOpen(self): + return self.__trans.isOpen() - def open(self): - return self.__trans.open() + def open(self): + return self.__trans.open() - def close(self): - return self.__trans.close() + def close(self): + return self.__trans.close() - def read(self, sz): - ret = self.__rbuf.read(sz) - if len(ret) != 0: - return ret + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret - self.readFrame() - return self.__rbuf.read(sz) + self.readFrame() + return self.__rbuf.read(sz) - def readFrame(self): - buff = self.__trans.readAll(4) - sz, = unpack('!i', buff) - self.__rbuf = BytesIO(self.__trans.readAll(sz)) + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = BufferIO(self.__trans.readAll(sz)) - def write(self, buf): - self.__wbuf.write(buf) + def write(self, buf): + self.__wbuf.write(buf) - def flush(self): - wout = self.__wbuf.getvalue() - wsz = len(wout) - # reset wbuf before write/flush to preserve state on underlying failure - self.__wbuf = 
BytesIO() - # N.B.: Doing this string concatenation is WAY cheaper than making - # two separate calls to the underlying socket object. Socket writes in - # Python turn out to be REALLY expensive, but it seems to do a pretty - # good job of managing string buffer operations without excessive copies - buf = pack("!i", wsz) + wout - self.__trans.write(buf) - self.__trans.flush() + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() - # Implement the CReadableTransport interface. - @property - def cstringio_buf(self): - return self.__rbuf + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf - def cstringio_refill(self, prefix, reqlen): - # self.__rbuf will already be empty here because fastbinary doesn't - # ask for a refill until the previous buffer is empty. Therefore, - # we can start reading new frames immediately. - while len(prefix) < reqlen: - self.readFrame() - prefix += self.__rbuf.getvalue() - self.__rbuf = BytesIO(prefix) - return self.__rbuf + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self.readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf class TFileObjectTransport(TTransportBase): - """Wraps a file-like object to make it work as a Thrift transport.""" + """Wraps a file-like object to make it work as a Thrift transport.""" - def __init__(self, fileobj): - self.fileobj = fileobj + def __init__(self, fileobj): + self.fileobj = fileobj - def isOpen(self): - return True + def isOpen(self): + return True - def close(self): - self.fileobj.close() + def close(self): + self.fileobj.close() - def read(self, sz): - return self.fileobj.read(sz) + def read(self, sz): + return self.fileobj.read(sz) - def write(self, buf): - self.fileobj.write(buf) + def write(self, buf): + self.fileobj.write(buf) - def flush(self): - self.fileobj.flush() + def flush(self): + self.fileobj.flush() + + +class TSaslClientTransport(TTransportBase, CReadableTransport): + """ + SASL transport + """ + + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + def __init__(self, transport, host, service, mechanism='GSSAPI', + **sasl_kwargs): + """ + transport: an underlying transport to use, typically just a TSocket + host: the name of the server, from a SASL perspective + service: the name of the server's service, from a SASL perspective + mechanism: the name of the preferred mechanism to use + + All other kwargs will be passed to the puresasl.client.SASLClient + constructor. 
+ """ + + from puresasl.client import SASLClient + + self.transport = transport + self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) + + self.__wbuf = BufferIO() + self.__rbuf = BufferIO(b'') + + def open(self): + if not self.transport.isOpen(): + self.transport.open() + + self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii')) + self.send_sasl_msg(self.OK, self.sasl.process()) + + while True: + status, challenge = self.recv_sasl_msg() + if status == self.OK: + self.send_sasl_msg(self.OK, self.sasl.process(challenge)) + elif status == self.COMPLETE: + if not self.sasl.complete: + raise TTransportException( + TTransportException.NOT_OPEN, + "The server erroneously indicated " + "that SASL negotiation was complete") + else: + break + else: + raise TTransportException( + TTransportException.NOT_OPEN, + "Bad SASL negotiation status: %d (%s)" + % (status, challenge)) + + def send_sasl_msg(self, status, body): + header = pack(">BI", status, len(body)) + self.transport.write(header + body) + self.transport.flush() + + def recv_sasl_msg(self): + header = self.transport.readAll(5) + status, length = unpack(">BI", header) + if length > 0: + payload = self.transport.readAll(length) + else: + payload = "" + return status, payload + + def write(self, data): + self.__wbuf.write(data) + + def flush(self): + data = self.__wbuf.getvalue() + encoded = self.sasl.wrap(data) + self.transport.write(pack("!i", len(encoded)) + encoded) + self.transport.flush() + self.__wbuf = BufferIO() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self._read_frame() + return self.__rbuf.read(sz) + + def _read_frame(self): + header = self.transport.readAll(4) + length, = unpack('!i', header) + encoded = self.transport.readAll(length) + self.__rbuf = BufferIO(self.sasl.unwrap(encoded)) + + def close(self): + self.sasl.dispose() + self.transport.close() + + # based on TFramedTransport + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self._read_frame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf diff --git a/thrift/transport/TTwisted.py b/thrift/transport/TTwisted.py index ffe5494..a27f0ad 100644 --- a/thrift/transport/TTwisted.py +++ b/thrift/transport/TTwisted.py @@ -17,14 +17,15 @@ # under the License. 
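# --- Illustrative aside (not part of the upstream patch) ---
# Both TSaslClientTransport above and ThriftSASLClientProtocol below exchange
# SASL negotiation messages framed as a 1-byte status code, a 4-byte
# big-endian length, and the payload (struct.pack(">BI", ...)). A minimal,
# standard-library-only sketch of that header layout; the helper names are
# hypothetical and exist only for this example.
import struct

OK_STATUS = 2  # mirrors the OK constant defined in both classes

def frame_sasl_msg(status, body):
    """Frame a negotiation message: status byte + 4-byte length + payload."""
    return struct.pack(">BI", status, len(body)) + body

def unframe_sasl_msg(frame):
    """Split a framed negotiation message back into (status, payload)."""
    status, length = struct.unpack(">BI", frame[:5])
    return status, frame[5:5 + length]

msg = frame_sasl_msg(OK_STATUS, b"challenge-response")
assert unframe_sasl_msg(msg) == (OK_STATUS, b"challenge-response")
# After negotiation completes, data frames are instead prefixed with a 4-byte
# signed length (struct.pack("!i", ...)) around the sasl.wrap() output.
# --- end of aside ---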
# -from io import StringIO +from io import BytesIO +import struct -from zope.interface import implements, Interface, Attribute -from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ +from zope.interface import implementer, Interface, Attribute +from twisted.internet.protocol import ServerFactory, ClientFactory, \ connectionDone from twisted.internet import defer +from twisted.internet.threads import deferToThread from twisted.protocols import basic -from twisted.python import log from twisted.web import server, resource, http from thrift.transport import TTransport @@ -33,15 +34,15 @@ from thrift.transport import TTransport class TMessageSenderTransport(TTransport.TTransportBase): def __init__(self): - self.__wbuf = StringIO() + self.__wbuf = BytesIO() def write(self, buf): self.__wbuf.write(buf) def flush(self): msg = self.__wbuf.getvalue() - self.__wbuf = StringIO() - self.sendMessage(msg) + self.__wbuf = BytesIO() + return self.sendMessage(msg) def sendMessage(self, message): raise NotImplementedError @@ -54,7 +55,7 @@ class TCallbackTransport(TMessageSenderTransport): self.func = func def sendMessage(self, message): - self.func(message) + return self.func(message) class ThriftClientProtocol(basic.Int32StringReceiver): @@ -81,11 +82,18 @@ class ThriftClientProtocol(basic.Int32StringReceiver): self.started.callback(self.client) def connectionLost(self, reason=connectionDone): - for k, v in self.client._reqs.items(): + # the called errbacks can add items to our client's _reqs, + # so we need to use a tmp, and iterate until no more requests + # are added during errbacks + if self.client: tex = TTransport.TTransportException( type=TTransport.TTransportException.END_OF_FILE, - message='Connection closed') - v.errback(tex) + message='Connection closed (%s)' % reason) + while self.client._reqs: + _, v = self.client._reqs.popitem() + v.errback(tex) + del self.client._reqs + self.client = None def stringReceived(self, frame): tr = TTransport.TMemoryBuffer(frame) @@ -101,6 +109,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver): method(iprot, mtype, rseqid) +class ThriftSASLClientProtocol(ThriftClientProtocol): + + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + MAX_LENGTH = 2 ** 31 - 1 + + def __init__(self, client_class, iprot_factory, oprot_factory=None, + host=None, service=None, mechanism='GSSAPI', **sasl_kwargs): + """ + host: the name of the server, from a SASL perspective + service: the name of the server's service, from a SASL perspective + mechanism: the name of the preferred mechanism to use + + All other kwargs will be passed to the puresasl.client.SASLClient + constructor. 
+ """ + + from puresasl.client import SASLClient + self.SASLCLient = SASLClient + + ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory) + + self._sasl_negotiation_deferred = None + self._sasl_negotiation_status = None + self.client = None + + if host is not None: + self.createSASLClient(host, service, mechanism, **sasl_kwargs) + + def createSASLClient(self, host, service, mechanism, **kwargs): + self.sasl = self.SASLClient(host, service, mechanism, **kwargs) + + def dispatch(self, msg): + encoded = self.sasl.wrap(msg) + len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded)) + ThriftClientProtocol.dispatch(self, len_and_encoded) + + @defer.inlineCallbacks + def connectionMade(self): + self._sendSASLMessage(self.START, self.sasl.mechanism) + initial_message = yield deferToThread(self.sasl.process) + self._sendSASLMessage(self.OK, initial_message) + + while True: + status, challenge = yield self._receiveSASLMessage() + if status == self.OK: + response = yield deferToThread(self.sasl.process, challenge) + self._sendSASLMessage(self.OK, response) + elif status == self.COMPLETE: + if not self.sasl.complete: + msg = "The server erroneously indicated that SASL " \ + "negotiation was complete" + raise TTransport.TTransportException(msg, message=msg) + else: + break + else: + msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge) + raise TTransport.TTransportException(msg, message=msg) + + self._sasl_negotiation_deferred = None + ThriftClientProtocol.connectionMade(self) + + def _sendSASLMessage(self, status, body): + if body is None: + body = "" + header = struct.pack(">BI", status, len(body)) + self.transport.write(header + body) + + def _receiveSASLMessage(self): + self._sasl_negotiation_deferred = defer.Deferred() + self._sasl_negotiation_status = None + return self._sasl_negotiation_deferred + + def connectionLost(self, reason=connectionDone): + if self.client: + ThriftClientProtocol.connectionLost(self, reason) + + def dataReceived(self, data): + if self._sasl_negotiation_deferred: + # we got a sasl challenge in the format (status, length, challenge) + # save the status, let IntNStringReceiver piece the challenge data together + self._sasl_negotiation_status, = struct.unpack("B", data[0]) + ThriftClientProtocol.dataReceived(self, data[1:]) + else: + # normal frame, let IntNStringReceiver piece it together + ThriftClientProtocol.dataReceived(self, data) + + def stringReceived(self, frame): + if self._sasl_negotiation_deferred: + # the frame is just a SASL challenge + response = (self._sasl_negotiation_status, frame) + self._sasl_negotiation_deferred.callback(response) + else: + # there's a second 4 byte length prefix inside the frame + decoded_frame = self.sasl.unwrap(frame[4:]) + ThriftClientProtocol.stringReceived(self, decoded_frame) + + class ThriftServerProtocol(basic.Int32StringReceiver): MAX_LENGTH = 2 ** 31 - 1 @@ -126,7 +236,7 @@ class ThriftServerProtocol(basic.Int32StringReceiver): d = self.factory.processor.process(iprot, oprot) d.addCallbacks(self.processOk, self.processError, - callbackArgs=(tmo,)) + callbackArgs=(tmo,)) class IThriftServerFactory(Interface): @@ -147,10 +257,9 @@ class IThriftClientFactory(Interface): oprot_factory = Attribute("Output protocol factory") +@implementer(IThriftServerFactory) class ThriftServerFactory(ServerFactory): - implements(IThriftServerFactory) - protocol = ThriftServerProtocol def __init__(self, processor, iprot_factory, oprot_factory=None): @@ -162,10 +271,9 @@ class 
ThriftServerFactory(ServerFactory): self.oprot_factory = oprot_factory +@implementer(IThriftClientFactory) class ThriftClientFactory(ClientFactory): - implements(IThriftClientFactory) - protocol = ThriftClientProtocol def __init__(self, client_class, iprot_factory, oprot_factory=None): @@ -178,7 +286,7 @@ class ThriftClientFactory(ClientFactory): def buildProtocol(self, addr): p = self.protocol(self.client_class, self.iprot_factory, - self.oprot_factory) + self.oprot_factory) p.factory = self return p @@ -188,7 +296,7 @@ class ThriftResource(resource.Resource): allowedMethods = ('POST',) def __init__(self, processor, inputProtocolFactory, - outputProtocolFactory=None): + outputProtocolFactory=None): resource.Resource.__init__(self) self.inputProtocolFactory = inputProtocolFactory if outputProtocolFactory is None: diff --git a/thrift/transport/TZlibTransport.py b/thrift/transport/TZlibTransport.py index a21dc80..e848579 100644 --- a/thrift/transport/TZlibTransport.py +++ b/thrift/transport/TZlibTransport.py @@ -22,227 +22,227 @@ class, using the python standard library zlib module to implement data compression. """ - +from __future__ import division import zlib -from io import StringIO from .TTransport import TTransportBase, CReadableTransport +from ..compat import BufferIO class TZlibTransportFactory(object): - """Factory transport that builds zlib compressed transports. + """Factory transport that builds zlib compressed transports. - This factory caches the last single client/transport that it was passed - and returns the same TZlibTransport object that was created. + This factory caches the last single client/transport that it was passed + and returns the same TZlibTransport object that was created. - This caching means the TServer class will get the _same_ transport - object for both input and output transports from this factory. - (For non-threaded scenarios only, since the cache only holds one object) + This caching means the TServer class will get the _same_ transport + object for both input and output transports from this factory. + (For non-threaded scenarios only, since the cache only holds one object) - The purpose of this caching is to allocate only one TZlibTransport where - only one is really needed (since it must have separate read/write buffers), - and makes the statistics from getCompSavings() and getCompRatio() - easier to understand. - """ - # class scoped cache of last transport given and zlibtransport returned - _last_trans = None - _last_z = None - - def getTransport(self, trans, compresslevel=9): - """Wrap a transport, trans, with the TZlibTransport - compressed transport class, returning a new - transport to the caller. - - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Defaults to 9. - @type compresslevel: int - - This method returns a TZlibTransport which wraps the - passed C{trans} TTransport derived instance. + The purpose of this caching is to allocate only one TZlibTransport where + only one is really needed (since it must have separate read/write buffers), + and makes the statistics from getCompSavings() and getCompRatio() + easier to understand. 
""" - if trans == self._last_trans: - return self._last_z - ztrans = TZlibTransport(trans, compresslevel) - self._last_trans = trans - self._last_z = ztrans - return ztrans + # class scoped cache of last transport given and zlibtransport returned + _last_trans = None + _last_z = None + + def getTransport(self, trans, compresslevel=9): + """Wrap a transport, trans, with the TZlibTransport + compressed transport class, returning a new + transport to the caller. + + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Defaults to 9. + @type compresslevel: int + + This method returns a TZlibTransport which wraps the + passed C{trans} TTransport derived instance. + """ + if trans == self._last_trans: + return self._last_z + ztrans = TZlibTransport(trans, compresslevel) + self._last_trans = trans + self._last_z = ztrans + return ztrans class TZlibTransport(TTransportBase, CReadableTransport): - """Class that wraps a transport with zlib, compressing writes - and decompresses reads, using the python standard - library zlib module. - """ - # Read buffer size for the python fastbinary C extension, - # the TBinaryProtocolAccelerated class. - DEFAULT_BUFFSIZE = 4096 - - def __init__(self, trans, compresslevel=9): - """Create a new TZlibTransport, wrapping C{trans}, another - TTransport derived object. - - @param trans: A thrift transport object, i.e. a TSocket() object. - @type trans: TTransport - @param compresslevel: The zlib compression level, ranging - from 0 (no compression) to 9 (best compression). Default is 9. - @type compresslevel: int + """Class that wraps a transport with zlib, compressing writes + and decompresses reads, using the python standard + library zlib module. """ - self.__trans = trans - self.compresslevel = compresslevel - self.__rbuf = StringIO() - self.__wbuf = StringIO() - self._init_zlib() - self._init_stats() + # Read buffer size for the python fastbinary C extension, + # the TBinaryProtocolAccelerated class. + DEFAULT_BUFFSIZE = 4096 - def _reinit_buffers(self): - """Internal method to initialize/reset the internal StringIO objects - for read and write buffers. - """ - self.__rbuf = StringIO() - self.__wbuf = StringIO() + def __init__(self, trans, compresslevel=9): + """Create a new TZlibTransport, wrapping C{trans}, another + TTransport derived object. - def _init_stats(self): - """Internal method to reset the internal statistics counters - for compression ratios and bandwidth savings. - """ - self.bytes_in = 0 - self.bytes_out = 0 - self.bytes_in_comp = 0 - self.bytes_out_comp = 0 + @param trans: A thrift transport object, i.e. a TSocket() object. + @type trans: TTransport + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Default is 9. + @type compresslevel: int + """ + self.__trans = trans + self.compresslevel = compresslevel + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() + self._init_zlib() + self._init_stats() - def _init_zlib(self): - """Internal method for setting up the zlib compression and - decompression objects. - """ - self._zcomp_read = zlib.decompressobj() - self._zcomp_write = zlib.compressobj(self.compresslevel) + def _reinit_buffers(self): + """Internal method to initialize/reset the internal StringIO objects + for read and write buffers. + """ + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() - def getCompRatio(self): - """Get the current measured compression ratios (in,out) from - this transport. 
+ def _init_stats(self): + """Internal method to reset the internal statistics counters + for compression ratios and bandwidth savings. + """ + self.bytes_in = 0 + self.bytes_out = 0 + self.bytes_in_comp = 0 + self.bytes_out_comp = 0 - Returns a tuple of: - (inbound_compression_ratio, outbound_compression_ratio) + def _init_zlib(self): + """Internal method for setting up the zlib compression and + decompression objects. + """ + self._zcomp_read = zlib.decompressobj() + self._zcomp_write = zlib.compressobj(self.compresslevel) - The compression ratios are computed as: - compressed / uncompressed + def getCompRatio(self): + """Get the current measured compression ratios (in,out) from + this transport. - E.g., data that compresses by 10x will have a ratio of: 0.10 - and data that compresses to half of ts original size will - have a ratio of 0.5 + Returns a tuple of: + (inbound_compression_ratio, outbound_compression_ratio) - None is returned if no bytes have yet been processed in - a particular direction. - """ - r_percent, w_percent = (None, None) - if self.bytes_in > 0: - r_percent = self.bytes_in_comp / self.bytes_in - if self.bytes_out > 0: - w_percent = self.bytes_out_comp / self.bytes_out - return (r_percent, w_percent) + The compression ratios are computed as: + compressed / uncompressed - def getCompSavings(self): - """Get the current count of saved bytes due to data - compression. + E.g., data that compresses by 10x will have a ratio of: 0.10 + and data that compresses to half of ts original size will + have a ratio of 0.5 - Returns a tuple of: - (inbound_saved_bytes, outbound_saved_bytes) + None is returned if no bytes have yet been processed in + a particular direction. + """ + r_percent, w_percent = (None, None) + if self.bytes_in > 0: + r_percent = self.bytes_in_comp / self.bytes_in + if self.bytes_out > 0: + w_percent = self.bytes_out_comp / self.bytes_out + return (r_percent, w_percent) - Note: if compression is actually expanding your - data (only likely with very tiny thrift objects), then - the values returned will be negative. - """ - r_saved = self.bytes_in - self.bytes_in_comp - w_saved = self.bytes_out - self.bytes_out_comp - return (r_saved, w_saved) + def getCompSavings(self): + """Get the current count of saved bytes due to data + compression. - def isOpen(self): - """Return the underlying transport's open status""" - return self.__trans.isOpen() + Returns a tuple of: + (inbound_saved_bytes, outbound_saved_bytes) - def open(self): - """Open the underlying transport""" - self._init_stats() - return self.__trans.open() + Note: if compression is actually expanding your + data (only likely with very tiny thrift objects), then + the values returned will be negative. 
+ """ + r_saved = self.bytes_in - self.bytes_in_comp + w_saved = self.bytes_out - self.bytes_out_comp + return (r_saved, w_saved) - def listen(self): - """Invoke the underlying transport's listen() method""" - self.__trans.listen() + def isOpen(self): + """Return the underlying transport's open status""" + return self.__trans.isOpen() - def accept(self): - """Accept connections on the underlying transport""" - return self.__trans.accept() + def open(self): + """Open the underlying transport""" + self._init_stats() + return self.__trans.open() - def close(self): - """Close the underlying transport,""" - self._reinit_buffers() - self._init_zlib() - return self.__trans.close() + def listen(self): + """Invoke the underlying transport's listen() method""" + self.__trans.listen() - def read(self, sz): - """Read up to sz bytes from the decompressed bytes buffer, and - read from the underlying transport if the decompression - buffer is empty. - """ - ret = self.__rbuf.read(sz) - if len(ret) > 0: - return ret - # keep reading from transport until something comes back - while True: - if self.readComp(sz): - break - ret = self.__rbuf.read(sz) - return ret + def accept(self): + """Accept connections on the underlying transport""" + return self.__trans.accept() - def readComp(self, sz): - """Read compressed data from the underlying transport, then - decompress it and append it to the internal StringIO read buffer - """ - zbuf = self.__trans.read(sz) - zbuf = self._zcomp_read.unconsumed_tail + zbuf - buf = self._zcomp_read.decompress(zbuf) - self.bytes_in += len(zbuf) - self.bytes_in_comp += len(buf) - old = self.__rbuf.read() - self.__rbuf = StringIO(old + buf) - if len(old) + len(buf) == 0: - return False - return True + def close(self): + """Close the underlying transport,""" + self._reinit_buffers() + self._init_zlib() + return self.__trans.close() - def write(self, buf): - """Write some bytes, putting them into the internal write - buffer for eventual compression. - """ - self.__wbuf.write(buf) + def read(self, sz): + """Read up to sz bytes from the decompressed bytes buffer, and + read from the underlying transport if the decompression + buffer is empty. 
+ """ + ret = self.__rbuf.read(sz) + if len(ret) > 0: + return ret + # keep reading from transport until something comes back + while True: + if self.readComp(sz): + break + ret = self.__rbuf.read(sz) + return ret - def flush(self): - """Flush any queued up data in the write buffer and ensure the - compression buffer is flushed out to the underlying transport - """ - wout = self.__wbuf.getvalue() - if len(wout) > 0: - zbuf = self._zcomp_write.compress(wout) - self.bytes_out += len(wout) - self.bytes_out_comp += len(zbuf) - else: - zbuf = '' - ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) - self.bytes_out_comp += len(ztail) - if (len(zbuf) + len(ztail)) > 0: - self.__wbuf = StringIO() - self.__trans.write(zbuf + ztail) - self.__trans.flush() + def readComp(self, sz): + """Read compressed data from the underlying transport, then + decompress it and append it to the internal StringIO read buffer + """ + zbuf = self.__trans.read(sz) + zbuf = self._zcomp_read.unconsumed_tail + zbuf + buf = self._zcomp_read.decompress(zbuf) + self.bytes_in += len(zbuf) + self.bytes_in_comp += len(buf) + old = self.__rbuf.read() + self.__rbuf = BufferIO(old + buf) + if len(old) + len(buf) == 0: + return False + return True - @property - def cstringio_buf(self): - """Implement the CReadableTransport interface""" - return self.__rbuf + def write(self, buf): + """Write some bytes, putting them into the internal write + buffer for eventual compression. + """ + self.__wbuf.write(buf) - def cstringio_refill(self, partialread, reqlen): - """Implement the CReadableTransport interface for refill""" - retstring = partialread - if reqlen < self.DEFAULT_BUFFSIZE: - retstring += self.read(self.DEFAULT_BUFFSIZE) - while len(retstring) < reqlen: - retstring += self.read(reqlen - len(retstring)) - self.__rbuf = StringIO(retstring) - return self.__rbuf + def flush(self): + """Flush any queued up data in the write buffer and ensure the + compression buffer is flushed out to the underlying transport + """ + wout = self.__wbuf.getvalue() + if len(wout) > 0: + zbuf = self._zcomp_write.compress(wout) + self.bytes_out += len(wout) + self.bytes_out_comp += len(zbuf) + else: + zbuf = '' + ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) + self.bytes_out_comp += len(ztail) + if (len(zbuf) + len(ztail)) > 0: + self.__wbuf = BufferIO() + self.__trans.write(zbuf + ztail) + self.__trans.flush() + + @property + def cstringio_buf(self): + """Implement the CReadableTransport interface""" + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + """Implement the CReadableTransport interface for refill""" + retstring = partialread + if reqlen < self.DEFAULT_BUFFSIZE: + retstring += self.read(self.DEFAULT_BUFFSIZE) + while len(retstring) < reqlen: + retstring += self.read(reqlen - len(retstring)) + self.__rbuf = BufferIO(retstring) + return self.__rbuf diff --git a/thrift/transport/sslcompat.py b/thrift/transport/sslcompat.py new file mode 100644 index 0000000..ab00cb2 --- /dev/null +++ b/thrift/transport/sslcompat.py @@ -0,0 +1,100 @@ +# +# licensed to the apache software foundation (asf) under one +# or more contributor license agreements. see the notice file +# distributed with this work for additional information +# regarding copyright ownership. the asf licenses this file +# to you under the apache license, version 2.0 (the +# "license"); you may not use this file except in compliance +# with the license. 
you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, +# software distributed under the license is distributed on an +# "as is" basis, without warranties or conditions of any +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys + +from thrift.transport.TTransport import TTransportException + +logger = logging.getLogger(__name__) + + +def legacy_validate_callback(cert, hostname): + """legacy method to validate the peer's SSL certificate, and to check + the commonName of the certificate to ensure it matches the hostname we + used to make this connection. Does not support subjectAltName records + in certificates. + + raises TTransportException if the certificate fails validation. + """ + if 'subject' not in cert: + raise TTransportException( + TTransportException.NOT_OPEN, + 'No SSL certificate found from %s' % hostname) + fields = cert['subject'] + for field in fields: + # ensure structure we get back is what we expect + if not isinstance(field, tuple): + continue + cert_pair = field[0] + if len(cert_pair) < 2: + continue + cert_key, cert_value = cert_pair[0:2] + if cert_key != 'commonName': + continue + certhost = cert_value + # this check should be performed by some sort of Access Manager + if certhost == hostname: + # success, cert commonName matches desired hostname + return + else: + raise TTransportException( + TTransportException.UNKNOWN, + 'Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (hostname, certhost)) + raise TTransportException( + TTransportException.UNKNOWN, + 'Could not validate SSL certificate from host "%s". Cert=%s' + % (hostname, cert)) + + +def _optional_dependencies(): + try: + import ipaddress # noqa + logger.debug('ipaddress module is available') + ipaddr = True + except ImportError: + logger.warn('ipaddress module is unavailable') + ipaddr = False + + if sys.hexversion < 0x030500F0: + try: + from backports.ssl_match_hostname import match_hostname, __version__ as ver + ver = list(map(int, ver.split('.'))) + logger.debug('backports.ssl_match_hostname module is available') + match = match_hostname + if ver[0] * 10 + ver[1] >= 35: + return ipaddr, match + else: + logger.warn('backports.ssl_match_hostname module is too old') + ipaddr = False + except ImportError: + logger.warn('backports.ssl_match_hostname is unavailable') + ipaddr = False + try: + from ssl import match_hostname + logger.debug('ssl.match_hostname is available') + match = match_hostname + except ImportError: + logger.warn('using legacy validation callback') + match = legacy_validate_callback + return ipaddr, match + + +_match_has_ipaddress, _match_hostname = _optional_dependencies()
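# --- Illustrative aside (not part of the upstream patch) ---
# legacy_validate_callback() above expects the dictionary shape returned by
# ssl.SSLSocket.getpeercert(): 'subject' is a tuple of RDNs, each itself a
# tuple of (key, value) pairs. The host name "edex" is just a placeholder.
from thrift.transport.sslcompat import legacy_validate_callback
from thrift.transport.TTransport import TTransportException

peer_cert = {'subject': ((('commonName', 'edex'),),)}

legacy_validate_callback(peer_cert, 'edex')            # commonName matches: returns None
try:
    legacy_validate_callback(peer_cert, 'other-host')  # mismatch: raises
except TTransportException as exc:
    print('certificate rejected:', exc)
# --- end of aside ---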