Mirror of https://github.com/Unidata/python-awips.git (synced 2025-02-23 14:57:56 -05:00)
Final set of changes for v20 python code:
- Brought over all new thrift files: untarred and unzipped the thrift package in awips2-rpm, went into /lib/py/, ran `python setup.py build`, then copied all of the files that get put in the subdirectory under /build.
- Replaced DataAccessLayer.py with the current one from our v18.1.11 of python-awips.
This commit is contained in:
parent 018afefee6
commit 3c7bd9f0de
29 changed files with 4972 additions and 2175 deletions
|
@@ -1,59 +1,33 @@
|
||||||
# #
|
|
||||||
# This software was developed and / or modified by Raytheon Company,
|
|
||||||
# pursuant to Contract DG133W-05-CQ-1067 with the US Government.
|
|
||||||
#
|
|
||||||
# U.S. EXPORT CONTROLLED TECHNICAL DATA
|
|
||||||
# This software product contains export-restricted data whose
|
|
||||||
# export/transfer/disclosure is restricted by U.S. law. Dissemination
|
|
||||||
# to non-U.S. persons whether in the United States or abroad requires
|
|
||||||
# an export license or other authorization.
|
|
||||||
#
|
|
||||||
# Contractor Name: Raytheon Company
|
|
||||||
# Contractor Address: 6825 Pine Street, Suite 340
|
|
||||||
# Mail Stop B8
|
|
||||||
# Omaha, NE 68106
|
|
||||||
# 402.291.0100
|
|
||||||
#
|
|
||||||
# See the AWIPS II Master Rights File ("Master Rights File.pdf") for
|
|
||||||
# further licensing information.
|
|
||||||
# #
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Published interface for awips.dataaccess package
|
# Published interface for awips.dataaccess package
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# SOFTWARE HISTORY
|
# SOFTWARE HISTORY
|
||||||
#
|
#
|
||||||
# Date Ticket# Engineer Description
|
# Date Ticket# Engineer Description
|
||||||
# ------------ ---------- ----------- --------------------------
|
# ------------ ------- ---------- -------------------------
|
||||||
# 12/10/12 njensen Initial Creation.
|
# 12/10/12 njensen Initial Creation.
|
||||||
# Feb 14, 2013 1614 bsteffen refactor data access framework
|
# Feb 14, 2013 1614 bsteffen refactor data access framework to use single request.
|
||||||
# to use single request.
|
# 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args
|
||||||
# 04/10/13 1871 mnash move getLatLonCoords to JGridData and add default args
|
# 05/29/13 2023 dgilling Hook up ThriftClientRouter.
|
||||||
# 05/29/13 2023 dgilling Hook up ThriftClientRouter.
|
# 03/03/14 2673 bsteffen Add ability to query only ref times.
|
||||||
# 03/03/14 2673 bsteffen Add ability to query only ref times.
|
# 07/22/14 3185 njensen Added optional/default args to newDataRequest
|
||||||
# 07/22/14 3185 njensen Added optional/default args to newDataRequest
|
# 07/30/14 3185 njensen Renamed valid identifiers to optional
|
||||||
# 07/30/14 3185 njensen Renamed valid identifiers to optional
|
# Apr 26, 2015 4259 njensen Updated for new JEP API
|
||||||
# Apr 26, 2015 4259 njensen Updated for new JEP API
|
# Apr 13, 2016 5379 tgurney Add getIdentifierValues(), getRequiredIdentifiers(),
|
||||||
# Apr 13, 2016 5379 tgurney Add getIdentifierValues()
|
# and getOptionalIdentifiers()
|
||||||
# Jun 01, 2016 5587 tgurney Add new signatures for
|
# Oct 07, 2016 ---- mjames@ucar Added getForecastRun
|
||||||
# getRequiredIdentifiers() and
|
# Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon
|
||||||
# getOptionalIdentifiers()
|
# Oct 11, 2018 ---- mjames@ucar Added getMetarObs() getSynopticObs()
|
||||||
# Oct 18, 2016 5916 bsteffen Add setLazyLoadGridLatLon
|
|
||||||
#
|
#
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
THRIFT_HOST = "edex"
|
THRIFT_HOST = "edex"
|
||||||
|
|
||||||
USING_NATIVE_THRIFT = False
|
USING_NATIVE_THRIFT = False
|
||||||
|
|
||||||
|
|
||||||
if 'jep' in sys.modules:
|
if 'jep' in sys.modules:
|
||||||
# intentionally do not catch if this fails to import, we want it to
|
# intentionally do not catch if this fails to import, we want it to
|
||||||
# be obvious that something is configured wrong when running from within
|
# be obvious that something is configured wrong when running from within
|
||||||
|
@@ -66,6 +40,147 @@ else:
|
||||||
USING_NATIVE_THRIFT = True
|
USING_NATIVE_THRIFT = True
|
||||||
|
|
||||||
|
|
||||||
|
def getRadarProductIDs(availableParms):
|
||||||
|
"""
|
||||||
|
Get only the numeric identifiers for NEXRAD3 products.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
availableParms: Full list of radar parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of filtered parameters
|
||||||
|
"""
|
||||||
|
productIDs = []
|
||||||
|
for p in list(availableParms):
|
||||||
|
try:
|
||||||
|
if isinstance(int(p), int):
|
||||||
|
productIDs.append(str(p))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return productIDs
|
||||||
|
|
||||||
|
|
||||||
|
def getRadarProductNames(availableParms):
|
||||||
|
"""
|
||||||
|
Get only the named idetifiers for NEXRAD3 products.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
availableParms: Full list of radar parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of filtered parameters
|
||||||
|
"""
|
||||||
|
productNames = []
|
||||||
|
for p in list(availableParms):
|
||||||
|
if len(p) > 3:
|
||||||
|
productNames.append(p)
|
||||||
|
|
||||||
|
return productNames
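# --- Illustrative usage sketch (editor's addition, not part of this commit). ---
# The two helpers above split a radar parameter list into numeric product IDs
# and human-readable product names; a plain Python list is enough to see it.
sample_parms = ["94", "99", "153", "Digital Hybrid Scan Refl"]
print(getRadarProductIDs(sample_parms))    # ['94', '99', '153']
print(getRadarProductNames(sample_parms))  # ['Digital Hybrid Scan Refl']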
|
||||||
|
|
||||||
|
|
||||||
|
def getMetarObs(response):
|
||||||
|
"""
|
||||||
|
Processes a DataAccessLayer "obs" response into a dictionary,
|
||||||
|
with special consideration for multi-value parameters
|
||||||
|
"presWeather", "skyCover", and "skyLayerBase".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: DAL getGeometry() list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of METAR obs
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
single_val_params = ["timeObs", "stationName", "longitude", "latitude",
|
||||||
|
"temperature", "dewpoint", "windDir",
|
||||||
|
"windSpeed", "seaLevelPress"]
|
||||||
|
multi_val_params = ["presWeather", "skyCover", "skyLayerBase"]
|
||||||
|
params = single_val_params + multi_val_params
|
||||||
|
station_names, pres_weather, sky_cov, sky_layer_base = [], [], [], []
|
||||||
|
obs = dict({params: [] for params in params})
|
||||||
|
for ob in response:
|
||||||
|
avail_params = ob.getParameters()
|
||||||
|
if "presWeather" in avail_params:
|
||||||
|
pres_weather.append(ob.getString("presWeather"))
|
||||||
|
elif "skyCover" in avail_params and "skyLayerBase" in avail_params:
|
||||||
|
sky_cov.append(ob.getString("skyCover"))
|
||||||
|
sky_layer_base.append(ob.getNumber("skyLayerBase"))
|
||||||
|
else:
|
||||||
|
# If we already have a record for this stationName, skip
|
||||||
|
if ob.getString('stationName') not in station_names:
|
||||||
|
station_names.append(ob.getString('stationName'))
|
||||||
|
for param in single_val_params:
|
||||||
|
if param in avail_params:
|
||||||
|
if param == 'timeObs':
|
||||||
|
obs[param].append(datetime.fromtimestamp(ob.getNumber(param) / 1000.0))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
obs[param].append(ob.getNumber(param))
|
||||||
|
except TypeError:
|
||||||
|
obs[param].append(ob.getString(param))
|
||||||
|
else:
|
||||||
|
obs[param].append(None)
|
||||||
|
|
||||||
|
obs['presWeather'].append(pres_weather)
|
||||||
|
obs['skyCover'].append(sky_cov)
|
||||||
|
obs['skyLayerBase'].append(sky_layer_base)
|
||||||
|
pres_weather = []
|
||||||
|
sky_cov = []
|
||||||
|
sky_layer_base = []
|
||||||
|
return obs
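# --- Illustrative usage sketch (editor's addition, not part of this commit). ---
# Feeding getMetarObs() from a DataAccessLayer "obs" request; the EDEX host and
# station IDs below are placeholder assumptions.
from awips.dataaccess import DataAccessLayer

DataAccessLayer.changeEDEXHost("edex-cloud.unidata.ucar.edu")  # assumed public server
request = DataAccessLayer.newDataRequest("obs")
request.setParameters("timeObs", "stationName", "temperature", "dewpoint",
                      "skyCover", "skyLayerBase", "presWeather")
request.setLocationNames("KBOS", "KJFK")
times = DataAccessLayer.getAvailableTimes(request)
response = DataAccessLayer.getGeometryData(request, times[-10:])
obs = DataAccessLayer.getMetarObs(response)
print(obs["stationName"], obs["temperature"])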
|
||||||
|
|
||||||
|
|
||||||
|
def getSynopticObs(response):
|
||||||
|
"""
|
||||||
|
Processes a DataAccessLayer "sfcobs" response into a dictionary
|
||||||
|
of available parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: DAL getGeometry() list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of synop obs
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
station_names = []
|
||||||
|
params = response[0].getParameters()
|
||||||
|
sfcobs = dict({params: [] for params in params})
|
||||||
|
for sfcob in response:
|
||||||
|
# If we already have a record for this stationId, skip
|
||||||
|
if sfcob.getString('stationId') not in station_names:
|
||||||
|
station_names.append(sfcob.getString('stationId'))
|
||||||
|
for param in params:
|
||||||
|
if param == 'timeObs':
|
||||||
|
sfcobs[param].append(datetime.fromtimestamp(sfcob.getNumber(param) / 1000.0))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
sfcobs[param].append(sfcob.getNumber(param))
|
||||||
|
except TypeError:
|
||||||
|
sfcobs[param].append(sfcob.getString(param))
|
||||||
|
|
||||||
|
return sfcobs
|
||||||
|
|
||||||
|
|
||||||
|
def getForecastRun(cycle, times):
|
||||||
|
"""
|
||||||
|
Get the latest forecast run (list of objects) from all
|
||||||
|
cycles and times returned from DataAccessLayer "grid"
|
||||||
|
response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cycle: Forecast cycle reference time
|
||||||
|
times: All available times/cycles
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataTime array for a single forecast run
|
||||||
|
"""
|
||||||
|
fcstRun = []
|
||||||
|
for t in times:
|
||||||
|
if str(t)[:19] == str(cycle):
|
||||||
|
fcstRun.append(t)
|
||||||
|
return fcstRun
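# --- Illustrative usage sketch (editor's addition, not part of this commit). ---
# getForecastRun() keeps only the DataTimes whose reference time matches the
# chosen cycle; the model name is a placeholder assumption.
from awips.dataaccess import DataAccessLayer

request = DataAccessLayer.newDataRequest("grid")
request.setLocationNames("RAP13")                         # hypothetical model
cycles = DataAccessLayer.getAvailableTimes(request, refTimeOnly=True)
times = DataAccessLayer.getAvailableTimes(request)
fcst_run = DataAccessLayer.getForecastRun(cycles[-1], times)
grids = DataAccessLayer.getGridData(request, fcst_run)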
|
||||||
|
|
||||||
|
|
||||||
def getAvailableTimes(request, refTimeOnly=False):
|
def getAvailableTimes(request, refTimeOnly=False):
|
||||||
"""
|
"""
|
||||||
|
@@ -74,7 +189,7 @@ def getAvailableTimes(request, refTimeOnly=False):
|
||||||
Args:
|
Args:
|
||||||
request: the IDataRequest to get data for
|
request: the IDataRequest to get data for
|
||||||
refTimeOnly: optional, use True if only unique refTimes should be
|
refTimeOnly: optional, use True if only unique refTimes should be
|
||||||
returned (without a forecastHr)
|
returned (without a forecastHr)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of DataTimes
|
a list of DataTimes
|
||||||
|
@@ -91,7 +206,7 @@ def getGridData(request, times=[]):
|
||||||
Args:
|
Args:
|
||||||
request: the IDataRequest to get data for
|
request: the IDataRequest to get data for
|
||||||
times: a list of DataTimes, a TimeRange, or None if the data is time
|
times: a list of DataTimes, a TimeRange, or None if the data is time
|
||||||
agnostic
|
agnostic
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of IGridData
|
a list of IGridData
|
||||||
|
@@ -108,10 +223,10 @@ def getGeometryData(request, times=[]):
|
||||||
Args:
|
Args:
|
||||||
request: the IDataRequest to get data for
|
request: the IDataRequest to get data for
|
||||||
times: a list of DataTimes, a TimeRange, or None if the data is time
|
times: a list of DataTimes, a TimeRange, or None if the data is time
|
||||||
agnostic
|
agnostic
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a list of IGeometryData
|
a list of IGeometryData
|
||||||
"""
|
"""
|
||||||
return router.getGeometryData(request, times)
|
return router.getGeometryData(request, times)
|
||||||
|
|
||||||
|
@@ -204,8 +319,9 @@ def getIdentifierValues(request, identifierKey):
|
||||||
"""
|
"""
|
||||||
return router.getIdentifierValues(request, identifierKey)
|
return router.getIdentifierValues(request, identifierKey)
|
||||||
|
|
||||||
|
|
||||||
def newDataRequest(datatype=None, **kwargs):
|
def newDataRequest(datatype=None, **kwargs):
|
||||||
""""
|
"""
|
||||||
Creates a new instance of IDataRequest suitable for the runtime environment.
|
Creates a new instance of IDataRequest suitable for the runtime environment.
|
||||||
All args are optional and exist solely for convenience.
|
All args are optional and exist solely for convenience.
|
||||||
|
|
||||||
|
@@ -215,13 +331,14 @@ def newDataRequest(datatype=None, **kwargs):
|
||||||
levels: a list of levels to set on the request
|
levels: a list of levels to set on the request
|
||||||
locationNames: a list of locationNames to set on the request
|
locationNames: a list of locationNames to set on the request
|
||||||
envelope: an envelope to limit the request
|
envelope: an envelope to limit the request
|
||||||
**kwargs: any leftover kwargs will be set as identifiers
|
kwargs: any leftover kwargs will be set as identifiers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a new IDataRequest
|
a new IDataRequest
|
||||||
"""
|
"""
|
||||||
return router.newDataRequest(datatype, **kwargs)
|
return router.newDataRequest(datatype, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def getSupportedDatatypes():
|
def getSupportedDatatypes():
|
||||||
"""
|
"""
|
||||||
Gets the datatypes that are supported by the framework
|
Gets the datatypes that are supported by the framework
|
||||||
|
@@ -239,7 +356,7 @@ def changeEDEXHost(newHostName):
|
||||||
method will throw a TypeError.
|
method will throw a TypeError.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
newHostHame: the EDEX host to connect to
|
newHostName: the EDEX host to connect to
|
||||||
"""
|
"""
|
||||||
if USING_NATIVE_THRIFT:
|
if USING_NATIVE_THRIFT:
|
||||||
global THRIFT_HOST
|
global THRIFT_HOST
|
||||||
|
@@ -249,6 +366,7 @@ def changeEDEXHost(newHostName):
|
||||||
else:
|
else:
|
||||||
raise TypeError("Cannot call changeEDEXHost when using JepRouter.")
|
raise TypeError("Cannot call changeEDEXHost when using JepRouter.")
|
||||||
|
|
||||||
|
|
||||||
def setLazyLoadGridLatLon(lazyLoadGridLatLon):
|
def setLazyLoadGridLatLon(lazyLoadGridLatLon):
|
||||||
"""
|
"""
|
||||||
Provide a hint to the Data Access Framework indicating whether to load the
|
Provide a hint to the Data Access Framework indicating whether to load the
|
||||||
|
@@ -261,7 +379,7 @@ def setLazyLoadGridLatLon(lazyLoadGridLatLon):
|
||||||
set to False if it is guaranteed that all lat/lon information is needed and
|
set to False if it is guaranteed that all lat/lon information is needed and
|
||||||
it would be better to get any performance overhead for generating the
|
it would be better to get any performance overhead for generating the
|
||||||
lat/lon data out of the way during the initial request.
|
lat/lon data out of the way during the initial request.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lazyLoadGridLatLon: Boolean value indicating whether to lazy load.
|
lazyLoadGridLatLon: Boolean value indicating whether to lazy load.
|
||||||
|
|
thrift/TMultiplexedProcessor.py (new file, 82 lines)
|
@@ -0,0 +1,82 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from thrift.Thrift import TProcessor, TMessageType
|
||||||
|
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
|
||||||
|
from thrift.protocol.TProtocol import TProtocolException
|
||||||
|
|
||||||
|
|
||||||
|
class TMultiplexedProcessor(TProcessor):
|
||||||
|
def __init__(self):
|
||||||
|
self.defaultProcessor = None
|
||||||
|
self.services = {}
|
||||||
|
|
||||||
|
def registerDefault(self, processor):
|
||||||
|
"""
|
||||||
|
If a non-multiplexed processor connects to the server and wants to
|
||||||
|
communicate, use the given processor to handle it. This mechanism
|
||||||
|
allows servers to upgrade from non-multiplexed to multiplexed in a
|
||||||
|
backwards-compatible way and still handle old clients.
|
||||||
|
"""
|
||||||
|
self.defaultProcessor = processor
|
||||||
|
|
||||||
|
def registerProcessor(self, serviceName, processor):
|
||||||
|
self.services[serviceName] = processor
|
||||||
|
|
||||||
|
def on_message_begin(self, func):
|
||||||
|
for key in self.services.keys():
|
||||||
|
self.services[key].on_message_begin(func)
|
||||||
|
|
||||||
|
def process(self, iprot, oprot):
|
||||||
|
(name, type, seqid) = iprot.readMessageBegin()
|
||||||
|
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
|
||||||
|
raise TProtocolException(
|
||||||
|
TProtocolException.NOT_IMPLEMENTED,
|
||||||
|
"TMultiplexedProtocol only supports CALL & ONEWAY")
|
||||||
|
|
||||||
|
index = name.find(TMultiplexedProtocol.SEPARATOR)
|
||||||
|
if index < 0:
|
||||||
|
if self.defaultProcessor:
|
||||||
|
return self.defaultProcessor.process(
|
||||||
|
StoredMessageProtocol(iprot, (name, type, seqid)), oprot)
|
||||||
|
else:
|
||||||
|
raise TProtocolException(
|
||||||
|
TProtocolException.NOT_IMPLEMENTED,
|
||||||
|
"Service name not found in message name: " + name + ". " +
|
||||||
|
"Did you forget to use TMultiplexedProtocol in your client?")
|
||||||
|
|
||||||
|
serviceName = name[0:index]
|
||||||
|
call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
|
||||||
|
if serviceName not in self.services:
|
||||||
|
raise TProtocolException(
|
||||||
|
TProtocolException.NOT_IMPLEMENTED,
|
||||||
|
"Service name not found: " + serviceName + ". " +
|
||||||
|
"Did you forget to call registerProcessor()?")
|
||||||
|
|
||||||
|
standardMessage = (call, type, seqid)
|
||||||
|
return self.services[serviceName].process(
|
||||||
|
StoredMessageProtocol(iprot, standardMessage), oprot)
|
||||||
|
|
||||||
|
|
||||||
|
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
|
||||||
|
def __init__(self, protocol, messageBegin):
|
||||||
|
self.messageBegin = messageBegin
|
||||||
|
|
||||||
|
def readMessageBegin(self):
|
||||||
|
return self.messageBegin
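# --- Illustrative usage sketch (editor's addition, not part of this commit). ---
# Each service name gets its own processor; process() routes "Service:method"
# calls to it and falls back to the default processor for unprefixed names.
from thrift.Thrift import TProcessor

class EchoProcessor(TProcessor):
    def process(self, iprot, oprot):
        pass  # a generated processor would dispatch the call and write a reply

muxed = TMultiplexedProcessor()
muxed.registerProcessor("Echo", EchoProcessor())
muxed.registerDefault(EchoProcessor())  # also serve non-multiplexed clients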
|
thrift/TRecursive.py (new file, 83 lines)
|
@@ -0,0 +1,83 @@
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from thrift.Thrift import TType
|
||||||
|
|
||||||
|
TYPE_IDX = 1
|
||||||
|
SPEC_ARGS_IDX = 3
|
||||||
|
SPEC_ARGS_CLASS_REF_IDX = 0
|
||||||
|
SPEC_ARGS_THRIFT_SPEC_IDX = 1
|
||||||
|
|
||||||
|
|
||||||
|
def fix_spec(all_structs):
|
||||||
|
"""Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
|
||||||
|
for struc in all_structs:
|
||||||
|
spec = struc.thrift_spec
|
||||||
|
for thrift_spec in spec:
|
||||||
|
if thrift_spec is None:
|
||||||
|
continue
|
||||||
|
elif thrift_spec[TYPE_IDX] == TType.STRUCT:
|
||||||
|
other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec
|
||||||
|
thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other
|
||||||
|
elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET):
|
||||||
|
_fix_list_or_set(thrift_spec[SPEC_ARGS_IDX])
|
||||||
|
elif thrift_spec[TYPE_IDX] == TType.MAP:
|
||||||
|
_fix_map(thrift_spec[SPEC_ARGS_IDX])
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_list_or_set(element_type):
|
||||||
|
# For a list or set, the thrift_spec entry looks like,
|
||||||
|
# (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1
|
||||||
|
# so ``element_type`` will be,
|
||||||
|
# (TType.STRUCT, [RecList, None], False)
|
||||||
|
if element_type[0] == TType.STRUCT:
|
||||||
|
element_type[1][1] = element_type[1][0].thrift_spec
|
||||||
|
elif element_type[0] in (TType.LIST, TType.SET):
|
||||||
|
_fix_list_or_set(element_type[1])
|
||||||
|
elif element_type[0] == TType.MAP:
|
||||||
|
_fix_map(element_type[1])
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_map(element_type):
|
||||||
|
# For a map of key -> value type, ``element_type`` will be,
|
||||||
|
# (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
|
||||||
|
# which is just a normal struct definition.
|
||||||
|
#
|
||||||
|
# For a map of key -> list / set, ``element_type`` will be,
|
||||||
|
# (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
|
||||||
|
# and we need to process the 3rd element as a list.
|
||||||
|
#
|
||||||
|
# For a map of key -> map, ``element_type`` will be,
|
||||||
|
# (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT,
|
||||||
|
# [RecMapMap, None], False), False)
|
||||||
|
# and need to process 3rd element as a map.
|
||||||
|
|
||||||
|
# Is the map key a struct?
|
||||||
|
if element_type[0] == TType.STRUCT:
|
||||||
|
element_type[1][1] = element_type[1][0].thrift_spec
|
||||||
|
elif element_type[0] in (TType.LIST, TType.SET):
|
||||||
|
_fix_list_or_set(element_type[1])
|
||||||
|
elif element_type[0] == TType.MAP:
|
||||||
|
_fix_map(element_type[1])
|
||||||
|
|
||||||
|
# Is the map value a struct?
|
||||||
|
if element_type[2] == TType.STRUCT:
|
||||||
|
element_type[3][1] = element_type[3][0].thrift_spec
|
||||||
|
elif element_type[2] in (TType.LIST, TType.SET):
|
||||||
|
_fix_list_or_set(element_type[3])
|
||||||
|
elif element_type[2] == TType.MAP:
|
||||||
|
_fix_map(element_type[3])
|
|
@@ -19,17 +19,18 @@
|
||||||
|
|
||||||
from os import path
|
from os import path
|
||||||
from SCons.Builder import Builder
|
from SCons.Builder import Builder
|
||||||
|
from six.moves import map
|
||||||
|
|
||||||
|
|
||||||
def scons_env(env, add=''):
|
def scons_env(env, add=''):
|
||||||
opath = path.dirname(path.abspath('$TARGET'))
|
opath = path.dirname(path.abspath('$TARGET'))
|
||||||
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
|
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
|
||||||
cppbuild = Builder(action=lstr)
|
cppbuild = Builder(action=lstr)
|
||||||
env.Append(BUILDERS={'ThriftCpp': cppbuild})
|
env.Append(BUILDERS={'ThriftCpp': cppbuild})
|
||||||
|
|
||||||
|
|
||||||
def gen_cpp(env, dir, file):
|
def gen_cpp(env, dir, file):
|
||||||
scons_env(env)
|
scons_env(env)
|
||||||
suffixes = ['_types.h', '_types.cpp']
|
suffixes = ['_types.h', '_types.cpp']
|
||||||
targets = ['gen-cpp/' + file + s for s in suffixes]
|
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
|
||||||
return env.ThriftCpp(targets, dir + file + '.thrift')
|
return env.ThriftCpp(targets, dir + file + '.thrift')
|
||||||
|
|
thrift/TTornado.py (new file, 188 lines)
|
@@ -0,0 +1,188 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
from collections import deque
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from tornado import gen, iostream, ioloop, tcpserver, concurrent
|
||||||
|
|
||||||
|
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _Lock(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._waiters = deque()
|
||||||
|
|
||||||
|
def acquired(self):
|
||||||
|
return len(self._waiters) > 0
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def acquire(self):
|
||||||
|
blocker = self._waiters[-1] if self.acquired() else None
|
||||||
|
future = concurrent.Future()
|
||||||
|
self._waiters.append(future)
|
||||||
|
if blocker:
|
||||||
|
yield blocker
|
||||||
|
|
||||||
|
raise gen.Return(self._lock_context())
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
assert self.acquired(), 'Lock not acquired'
|
||||||
|
future = self._waiters.popleft()
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _lock_context(self):
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.release()
|
||||||
|
|
||||||
|
|
||||||
|
class TTornadoStreamTransport(TTransportBase):
|
||||||
|
"""a framed, buffered transport over a Tornado stream"""
|
||||||
|
def __init__(self, host, port, stream=None, io_loop=None):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.io_loop = io_loop or ioloop.IOLoop.current()
|
||||||
|
self.__wbuf = BytesIO()
|
||||||
|
self._read_lock = _Lock()
|
||||||
|
|
||||||
|
# servers provide a ready-to-go stream
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
|
def with_timeout(self, timeout, future):
|
||||||
|
return gen.with_timeout(timeout, future, self.io_loop)
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def open(self, timeout=None):
|
||||||
|
logger.debug('socket connecting')
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||||
|
self.stream = iostream.IOStream(sock)
|
||||||
|
|
||||||
|
try:
|
||||||
|
connect = self.stream.connect((self.host, self.port))
|
||||||
|
if timeout is not None:
|
||||||
|
yield self.with_timeout(timeout, connect)
|
||||||
|
else:
|
||||||
|
yield connect
|
||||||
|
except (socket.error, IOError, ioloop.TimeoutError) as e:
|
||||||
|
message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
|
||||||
|
raise TTransportException(
|
||||||
|
type=TTransportException.NOT_OPEN,
|
||||||
|
message=message)
|
||||||
|
|
||||||
|
raise gen.Return(self)
|
||||||
|
|
||||||
|
def set_close_callback(self, callback):
|
||||||
|
"""
|
||||||
|
Should be called only after open() returns
|
||||||
|
"""
|
||||||
|
self.stream.set_close_callback(callback)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
# don't raise if we intend to close
|
||||||
|
self.stream.set_close_callback(None)
|
||||||
|
self.stream.close()
|
||||||
|
|
||||||
|
def read(self, _):
|
||||||
|
# The generated code for Tornado shouldn't do individual reads -- only
|
||||||
|
# frames at a time
|
||||||
|
assert False, "you're doing it wrong"
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def io_exception_context(self):
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except (socket.error, IOError) as e:
|
||||||
|
raise TTransportException(
|
||||||
|
type=TTransportException.END_OF_FILE,
|
||||||
|
message=str(e))
|
||||||
|
except iostream.StreamBufferFullError as e:
|
||||||
|
raise TTransportException(
|
||||||
|
type=TTransportException.UNKNOWN,
|
||||||
|
message=str(e))
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def readFrame(self):
|
||||||
|
# IOStream processes reads one at a time
|
||||||
|
with (yield self._read_lock.acquire()):
|
||||||
|
with self.io_exception_context():
|
||||||
|
frame_header = yield self.stream.read_bytes(4)
|
||||||
|
if len(frame_header) == 0:
|
||||||
|
raise iostream.StreamClosedError('Read zero bytes from stream')
|
||||||
|
frame_length, = struct.unpack('!i', frame_header)
|
||||||
|
frame = yield self.stream.read_bytes(frame_length)
|
||||||
|
raise gen.Return(frame)
|
||||||
|
|
||||||
|
def write(self, buf):
|
||||||
|
self.__wbuf.write(buf)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
frame = self.__wbuf.getvalue()
|
||||||
|
# reset wbuf before write/flush to preserve state on underlying failure
|
||||||
|
frame_length = struct.pack('!i', len(frame))
|
||||||
|
self.__wbuf = BytesIO()
|
||||||
|
with self.io_exception_context():
|
||||||
|
return self.stream.write(frame_length + frame)
|
||||||
|
|
||||||
|
|
||||||
|
class TTornadoServer(tcpserver.TCPServer):
|
||||||
|
def __init__(self, processor, iprot_factory, oprot_factory=None,
|
||||||
|
*args, **kwargs):
|
||||||
|
super(TTornadoServer, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self._processor = processor
|
||||||
|
self._iprot_factory = iprot_factory
|
||||||
|
self._oprot_factory = (oprot_factory if oprot_factory is not None
|
||||||
|
else iprot_factory)
|
||||||
|
|
||||||
|
@gen.coroutine
|
||||||
|
def handle_stream(self, stream, address):
|
||||||
|
host, port = address[:2]
|
||||||
|
trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
|
||||||
|
io_loop=self.io_loop)
|
||||||
|
oprot = self._oprot_factory.getProtocol(trans)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not trans.stream.closed():
|
||||||
|
try:
|
||||||
|
frame = yield trans.readFrame()
|
||||||
|
except TTransportException as e:
|
||||||
|
if e.type == TTransportException.END_OF_FILE:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
tr = TMemoryBuffer(frame)
|
||||||
|
iprot = self._iprot_factory.getProtocol(tr)
|
||||||
|
yield self._processor.process(iprot, oprot)
|
||||||
|
except Exception:
|
||||||
|
logger.exception('thrift exception in handle_stream')
|
||||||
|
trans.close()
|
||||||
|
|
||||||
|
logger.info('client disconnected %s:%d', host, port)
|
thrift/Thrift.py (276 lines changed)
|
@@ -17,141 +17,177 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import sys
|
|
||||||
|
class TType(object):
|
||||||
|
STOP = 0
|
||||||
|
VOID = 1
|
||||||
|
BOOL = 2
|
||||||
|
BYTE = 3
|
||||||
|
I08 = 3
|
||||||
|
DOUBLE = 4
|
||||||
|
I16 = 6
|
||||||
|
I32 = 8
|
||||||
|
I64 = 10
|
||||||
|
STRING = 11
|
||||||
|
UTF7 = 11
|
||||||
|
STRUCT = 12
|
||||||
|
MAP = 13
|
||||||
|
SET = 14
|
||||||
|
LIST = 15
|
||||||
|
UTF8 = 16
|
||||||
|
UTF16 = 17
|
||||||
|
|
||||||
|
_VALUES_TO_NAMES = (
|
||||||
|
'STOP',
|
||||||
|
'VOID',
|
||||||
|
'BOOL',
|
||||||
|
'BYTE',
|
||||||
|
'DOUBLE',
|
||||||
|
None,
|
||||||
|
'I16',
|
||||||
|
None,
|
||||||
|
'I32',
|
||||||
|
None,
|
||||||
|
'I64',
|
||||||
|
'STRING',
|
||||||
|
'STRUCT',
|
||||||
|
'MAP',
|
||||||
|
'SET',
|
||||||
|
'LIST',
|
||||||
|
'UTF8',
|
||||||
|
'UTF16',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TType:
|
class TMessageType(object):
|
||||||
STOP = 0
|
CALL = 1
|
||||||
VOID = 1
|
REPLY = 2
|
||||||
BOOL = 2
|
EXCEPTION = 3
|
||||||
BYTE = 3
|
ONEWAY = 4
|
||||||
I08 = 3
|
|
||||||
DOUBLE = 4
|
|
||||||
I16 = 6
|
|
||||||
I32 = 8
|
|
||||||
I64 = 10
|
|
||||||
STRING = 11
|
|
||||||
UTF7 = 11
|
|
||||||
STRUCT = 12
|
|
||||||
MAP = 13
|
|
||||||
SET = 14
|
|
||||||
LIST = 15
|
|
||||||
UTF8 = 16
|
|
||||||
UTF16 = 17
|
|
||||||
|
|
||||||
_VALUES_TO_NAMES = ('STOP',
|
|
||||||
'VOID',
|
|
||||||
'BOOL',
|
|
||||||
'BYTE',
|
|
||||||
'DOUBLE',
|
|
||||||
None,
|
|
||||||
'I16',
|
|
||||||
None,
|
|
||||||
'I32',
|
|
||||||
None,
|
|
||||||
'I64',
|
|
||||||
'STRING',
|
|
||||||
'STRUCT',
|
|
||||||
'MAP',
|
|
||||||
'SET',
|
|
||||||
'LIST',
|
|
||||||
'UTF8',
|
|
||||||
'UTF16')
|
|
||||||
|
|
||||||
|
|
||||||
class TMessageType:
|
class TProcessor(object):
|
||||||
CALL = 1
|
"""Base class for processor, which works on two streams."""
|
||||||
REPLY = 2
|
|
||||||
EXCEPTION = 3
|
|
||||||
ONEWAY = 4
|
|
||||||
|
|
||||||
|
def process(self, iprot, oprot):
|
||||||
|
"""
|
||||||
|
Process a request. The normal behvaior is to have the
|
||||||
|
processor invoke the correct handler and then it is the
|
||||||
|
server's responsibility to write the response to oprot.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
class TProcessor:
|
def on_message_begin(self, func):
|
||||||
"""Base class for procsessor, which works on two streams."""
|
"""
|
||||||
|
Install a callback that receives (name, type, seqid)
|
||||||
def process(iprot, oprot):
|
after the message header is read.
|
||||||
pass
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TException(Exception):
|
class TException(Exception):
|
||||||
"""Base class for all thrift exceptions."""
|
"""Base class for all thrift exceptions."""
|
||||||
|
|
||||||
# BaseException.message is deprecated in Python v[2.6,3.0)
|
def __init__(self, message=None):
|
||||||
if (2, 6, 0) <= sys.version_info < (3, 0):
|
Exception.__init__(self, message)
|
||||||
def _get_message(self):
|
super(TException, self).__setattr__("message", message)
|
||||||
return self._message
|
|
||||||
|
|
||||||
def _set_message(self, message):
|
|
||||||
self._message = message
|
|
||||||
message = property(_get_message, _set_message)
|
|
||||||
|
|
||||||
def __init__(self, message=None):
|
|
||||||
Exception.__init__(self, message)
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
|
|
||||||
class TApplicationException(TException):
|
class TApplicationException(TException):
|
||||||
"""Application level thrift exceptions."""
|
"""Application level thrift exceptions."""
|
||||||
|
|
||||||
UNKNOWN = 0
|
UNKNOWN = 0
|
||||||
UNKNOWN_METHOD = 1
|
UNKNOWN_METHOD = 1
|
||||||
INVALID_MESSAGE_TYPE = 2
|
INVALID_MESSAGE_TYPE = 2
|
||||||
WRONG_METHOD_NAME = 3
|
WRONG_METHOD_NAME = 3
|
||||||
BAD_SEQUENCE_ID = 4
|
BAD_SEQUENCE_ID = 4
|
||||||
MISSING_RESULT = 5
|
MISSING_RESULT = 5
|
||||||
INTERNAL_ERROR = 6
|
INTERNAL_ERROR = 6
|
||||||
PROTOCOL_ERROR = 7
|
PROTOCOL_ERROR = 7
|
||||||
|
INVALID_TRANSFORM = 8
|
||||||
|
INVALID_PROTOCOL = 9
|
||||||
|
UNSUPPORTED_CLIENT_TYPE = 10
|
||||||
|
|
||||||
def __init__(self, type=UNKNOWN, message=None):
|
def __init__(self, type=UNKNOWN, message=None):
|
||||||
TException.__init__(self, message)
|
TException.__init__(self, message)
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.message:
|
if self.message:
|
||||||
return self.message
|
return self.message
|
||||||
elif self.type == self.UNKNOWN_METHOD:
|
elif self.type == self.UNKNOWN_METHOD:
|
||||||
return 'Unknown method'
|
return 'Unknown method'
|
||||||
elif self.type == self.INVALID_MESSAGE_TYPE:
|
elif self.type == self.INVALID_MESSAGE_TYPE:
|
||||||
return 'Invalid message type'
|
return 'Invalid message type'
|
||||||
elif self.type == self.WRONG_METHOD_NAME:
|
elif self.type == self.WRONG_METHOD_NAME:
|
||||||
return 'Wrong method name'
|
return 'Wrong method name'
|
||||||
elif self.type == self.BAD_SEQUENCE_ID:
|
elif self.type == self.BAD_SEQUENCE_ID:
|
||||||
return 'Bad sequence ID'
|
return 'Bad sequence ID'
|
||||||
elif self.type == self.MISSING_RESULT:
|
elif self.type == self.MISSING_RESULT:
|
||||||
return 'Missing result'
|
return 'Missing result'
|
||||||
else:
|
elif self.type == self.INTERNAL_ERROR:
|
||||||
return 'Default (unknown) TApplicationException'
|
return 'Internal error'
|
||||||
|
elif self.type == self.PROTOCOL_ERROR:
|
||||||
def read(self, iprot):
|
return 'Protocol error'
|
||||||
iprot.readStructBegin()
|
elif self.type == self.INVALID_TRANSFORM:
|
||||||
while True:
|
return 'Invalid transform'
|
||||||
(fname, ftype, fid) = iprot.readFieldBegin()
|
elif self.type == self.INVALID_PROTOCOL:
|
||||||
if ftype == TType.STOP:
|
return 'Invalid protocol'
|
||||||
break
|
elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
|
||||||
if fid == 1:
|
return 'Unsupported client type'
|
||||||
if ftype == TType.STRING:
|
|
||||||
self.message = iprot.readString()
|
|
||||||
else:
|
else:
|
||||||
iprot.skip(ftype)
|
return 'Default (unknown) TApplicationException'
|
||||||
elif fid == 2:
|
|
||||||
if ftype == TType.I32:
|
|
||||||
self.type = iprot.readI32()
|
|
||||||
else:
|
|
||||||
iprot.skip(ftype)
|
|
||||||
else:
|
|
||||||
iprot.skip(ftype)
|
|
||||||
iprot.readFieldEnd()
|
|
||||||
iprot.readStructEnd()
|
|
||||||
|
|
||||||
def write(self, oprot):
|
def read(self, iprot):
|
||||||
oprot.writeStructBegin('TApplicationException')
|
iprot.readStructBegin()
|
||||||
if self.message is not None:
|
while True:
|
||||||
oprot.writeFieldBegin('message', TType.STRING, 1)
|
(fname, ftype, fid) = iprot.readFieldBegin()
|
||||||
oprot.writeString(self.message)
|
if ftype == TType.STOP:
|
||||||
oprot.writeFieldEnd()
|
break
|
||||||
if self.type is not None:
|
if fid == 1:
|
||||||
oprot.writeFieldBegin('type', TType.I32, 2)
|
if ftype == TType.STRING:
|
||||||
oprot.writeI32(self.type)
|
self.message = iprot.readString()
|
||||||
oprot.writeFieldEnd()
|
else:
|
||||||
oprot.writeFieldStop()
|
iprot.skip(ftype)
|
||||||
oprot.writeStructEnd()
|
elif fid == 2:
|
||||||
|
if ftype == TType.I32:
|
||||||
|
self.type = iprot.readI32()
|
||||||
|
else:
|
||||||
|
iprot.skip(ftype)
|
||||||
|
else:
|
||||||
|
iprot.skip(ftype)
|
||||||
|
iprot.readFieldEnd()
|
||||||
|
iprot.readStructEnd()
|
||||||
|
|
||||||
|
def write(self, oprot):
|
||||||
|
oprot.writeStructBegin('TApplicationException')
|
||||||
|
if self.message is not None:
|
||||||
|
oprot.writeFieldBegin('message', TType.STRING, 1)
|
||||||
|
oprot.writeString(self.message)
|
||||||
|
oprot.writeFieldEnd()
|
||||||
|
if self.type is not None:
|
||||||
|
oprot.writeFieldBegin('type', TType.I32, 2)
|
||||||
|
oprot.writeI32(self.type)
|
||||||
|
oprot.writeFieldEnd()
|
||||||
|
oprot.writeFieldStop()
|
||||||
|
oprot.writeStructEnd()
|
||||||
|
|
||||||
|
|
||||||
|
class TFrozenDict(dict):
|
||||||
|
"""A dictionary that is "frozen" like a frozenset"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(TFrozenDict, self).__init__(*args, **kwargs)
|
||||||
|
# Sort the items so they will be in a consistent order.
|
||||||
|
# XOR in the hash of the class so we don't collide with
|
||||||
|
# the hash of a list of tuples.
|
||||||
|
self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
|
||||||
|
|
||||||
|
def __setitem__(self, *args):
|
||||||
|
raise TypeError("Can't modify frozen TFreezableDict")
|
||||||
|
|
||||||
|
def __delitem__(self, *args):
|
||||||
|
raise TypeError("Can't modify frozen TFreezableDict")
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return self.__hashval
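# --- Illustrative sketch (editor's addition, not part of this commit). ---
# TFrozenDict is hashable and rejects mutation, so it can serve as a dict key
# or set member for "frozen" generated structs.
fd = TFrozenDict({"a": 1, "b": 2})
lookup = {fd: "frozen"}        # usable as a key because __hash__ is defined
try:
    fd["a"] = 3
except TypeError:
    pass                       # mutation is refused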
|
||||||
|
|
thrift/compat.py (new file, 46 lines)
|
@@ -0,0 +1,46 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
|
||||||
|
from cStringIO import StringIO as BufferIO
|
||||||
|
|
||||||
|
def binary_to_str(bin_val):
|
||||||
|
return bin_val
|
||||||
|
|
||||||
|
def str_to_binary(str_val):
|
||||||
|
return str_val
|
||||||
|
|
||||||
|
def byte_index(bytes_val, i):
|
||||||
|
return ord(bytes_val[i])
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
from io import BytesIO as BufferIO # noqa
|
||||||
|
|
||||||
|
def binary_to_str(bin_val):
|
||||||
|
return bin_val.decode('utf8')
|
||||||
|
|
||||||
|
def str_to_binary(str_val):
|
||||||
|
return bytes(str_val, 'utf8')
|
||||||
|
|
||||||
|
def byte_index(bytes_val, i):
|
||||||
|
return bytes_val[i]
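# --- Illustrative sketch (editor's addition, not part of this commit). ---
# The helpers above give one str/bytes API across Python 2 and 3.
assert binary_to_str(str_to_binary("abc")) == "abc"
assert byte_index(str_to_binary("abc"), 1) == ord("b")
buf = BufferIO(str_to_binary("payload"))   # BytesIO on py3, cStringIO on py2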
|
|
@@ -17,65 +17,70 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from thrift.Thrift import *
|
|
||||||
from thrift.protocol import TBinaryProtocol
|
|
||||||
from thrift.transport import TTransport
|
from thrift.transport import TTransport
|
||||||
|
|
||||||
try:
|
|
||||||
from thrift.protocol import fastbinary
|
|
||||||
except:
|
|
||||||
fastbinary = None
|
|
||||||
|
|
||||||
|
|
||||||
class TBase(object):
|
class TBase(object):
|
||||||
__slots__ = []
|
__slots__ = ()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
L = ['%s=%r' % (key, getattr(self, key))
|
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
|
||||||
for key in self.__slots__]
|
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
||||||
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return False
|
return False
|
||||||
for attr in self.__slots__:
|
for attr in self.__slots__:
|
||||||
my_val = getattr(self, attr)
|
my_val = getattr(self, attr)
|
||||||
other_val = getattr(other, attr)
|
other_val = getattr(other, attr)
|
||||||
if my_val != other_val:
|
if my_val != other_val:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not (self == other)
|
return not (self == other)
|
||||||
|
|
||||||
def read(self, iprot):
|
def read(self, iprot):
|
||||||
if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
if (iprot._fast_decode is not None and
|
||||||
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
||||||
self.thrift_spec is not None and
|
self.thrift_spec is not None):
|
||||||
fastbinary is not None):
|
iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
|
||||||
fastbinary.decode_binary(self,
|
else:
|
||||||
iprot.trans,
|
iprot.readStruct(self, self.thrift_spec)
|
||||||
(self.__class__, self.thrift_spec))
|
|
||||||
return
|
|
||||||
iprot.readStruct(self, self.thrift_spec)
|
|
||||||
|
|
||||||
def write(self, oprot):
|
def write(self, oprot):
|
||||||
if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
|
if (oprot._fast_encode is not None and self.thrift_spec is not None):
|
||||||
self.thrift_spec is not None and
|
oprot.trans.write(
|
||||||
fastbinary is not None):
|
oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
|
||||||
oprot.trans.write(
|
else:
|
||||||
fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
|
oprot.writeStruct(self, self.thrift_spec)
|
||||||
return
|
|
||||||
oprot.writeStruct(self, self.thrift_spec)
|
|
||||||
|
|
||||||
|
|
||||||
class TExceptionBase(Exception):
|
class TExceptionBase(TBase, Exception):
|
||||||
# old style class so python2.4 can raise exceptions derived from this
|
pass
|
||||||
# This can't inherit from TBase because of that limitation.
|
|
||||||
__slots__ = []
|
|
||||||
|
|
||||||
__repr__ = TBase.__repr__.__func__
|
|
||||||
__eq__ = TBase.__eq__.__func__
|
class TFrozenBase(TBase):
|
||||||
__ne__ = TBase.__ne__.__func__
|
def __setitem__(self, *args):
|
||||||
read = TBase.read.__func__
|
raise TypeError("Can't modify frozen struct")
|
||||||
write = TBase.write.__func__
|
|
||||||
|
def __delitem__(self, *args):
|
||||||
|
raise TypeError("Can't modify frozen struct")
|
||||||
|
|
||||||
|
def __hash__(self, *args):
|
||||||
|
return hash(self.__class__) ^ hash(self.__slots__)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def read(cls, iprot):
|
||||||
|
if (iprot._fast_decode is not None and
|
||||||
|
isinstance(iprot.trans, TTransport.CReadableTransport) and
|
||||||
|
cls.thrift_spec is not None):
|
||||||
|
self = cls()
|
||||||
|
return iprot._fast_decode(None, iprot,
|
||||||
|
[self.__class__, self.thrift_spec])
|
||||||
|
else:
|
||||||
|
return iprot.readStruct(cls, cls.thrift_spec, True)
|
||||||
|
|
||||||
|
|
||||||
|
class TFrozenExceptionBase(TFrozenBase, TExceptionBase):
|
||||||
|
pass
|
||||||
|
|
|
@@ -17,248 +17,285 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from .TProtocol import *
|
from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory
|
||||||
from struct import pack, unpack
|
from struct import pack, unpack
|
||||||
|
|
||||||
|
|
||||||
class TBinaryProtocol(TProtocolBase):
|
class TBinaryProtocol(TProtocolBase):
|
||||||
"""Binary implementation of the Thrift protocol driver."""
|
"""Binary implementation of the Thrift protocol driver."""
|
||||||
|
|
||||||
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
|
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
|
||||||
# positive, converting this into a long. If we hardcode the int value
|
# positive, converting this into a long. If we hardcode the int value
|
||||||
# instead it'll stay in 32 bit-land.
|
# instead it'll stay in 32 bit-land.
|
||||||
|
|
||||||
# VERSION_MASK = 0xffff0000
|
# VERSION_MASK = 0xffff0000
|
||||||
VERSION_MASK = -65536
|
VERSION_MASK = -65536
|
||||||
|
|
||||||
# VERSION_1 = 0x80010000
|
# VERSION_1 = 0x80010000
|
||||||
VERSION_1 = -2147418112
|
VERSION_1 = -2147418112
|
||||||
|
|
||||||
TYPE_MASK = 0x000000ff
|
TYPE_MASK = 0x000000ff
|
||||||
|
|
||||||
def __init__(self, trans, strictRead=False, strictWrite=True):
|
def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
|
||||||
TProtocolBase.__init__(self, trans)
|
TProtocolBase.__init__(self, trans)
|
||||||
self.strictRead = strictRead
|
self.strictRead = strictRead
|
||||||
self.strictWrite = strictWrite
|
self.strictWrite = strictWrite
|
||||||
|
self.string_length_limit = kwargs.get('string_length_limit', None)
|
||||||
|
self.container_length_limit = kwargs.get('container_length_limit', None)
|
||||||
|
|
||||||
def writeMessageBegin(self, name, type, seqid):
|
def _check_string_length(self, length):
|
||||||
if self.strictWrite:
|
self._check_length(self.string_length_limit, length)
|
||||||
self.writeI32(TBinaryProtocol.VERSION_1 | type)
|
|
||||||
self.writeString(name)
|
|
||||||
self.writeI32(seqid)
|
|
||||||
else:
|
|
||||||
self.writeString(name)
|
|
||||||
self.writeByte(type)
|
|
||||||
self.writeI32(seqid)
|
|
||||||
|
|
||||||
def writeMessageEnd(self):
|
def _check_container_length(self, length):
|
||||||
pass
|
self._check_length(self.container_length_limit, length)
|
||||||
|
|
||||||
def writeStructBegin(self, name):
|
def writeMessageBegin(self, name, type, seqid):
|
||||||
pass
|
if self.strictWrite:
|
||||||
|
self.writeI32(TBinaryProtocol.VERSION_1 | type)
|
||||||
|
self.writeString(name)
|
||||||
|
self.writeI32(seqid)
|
||||||
|
else:
|
||||||
|
self.writeString(name)
|
||||||
|
self.writeByte(type)
|
||||||
|
self.writeI32(seqid)
|
||||||
|
|
||||||
def writeStructEnd(self):
|
def writeMessageEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeFieldBegin(self, name, type, id):
|
def writeStructBegin(self, name):
|
||||||
self.writeByte(type)
|
pass
|
||||||
self.writeI16(id)
|
|
||||||
|
|
||||||
def writeFieldEnd(self):
|
def writeStructEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeFieldStop(self):
|
def writeFieldBegin(self, name, type, id):
|
||||||
self.writeByte(TType.STOP)
|
self.writeByte(type)
|
||||||
|
self.writeI16(id)
|
||||||
|
|
||||||
def writeMapBegin(self, ktype, vtype, size):
|
def writeFieldEnd(self):
|
||||||
self.writeByte(ktype)
|
pass
|
||||||
self.writeByte(vtype)
|
|
||||||
self.writeI32(size)
|
|
||||||
|
|
||||||
def writeMapEnd(self):
|
def writeFieldStop(self):
|
||||||
pass
|
self.writeByte(TType.STOP)
|
||||||
|
|
||||||
def writeListBegin(self, etype, size):
|
def writeMapBegin(self, ktype, vtype, size):
|
||||||
self.writeByte(etype)
|
self.writeByte(ktype)
|
||||||
self.writeI32(size)
|
self.writeByte(vtype)
|
||||||
|
self.writeI32(size)
|
||||||
|
|
||||||
def writeListEnd(self):
|
def writeMapEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeSetBegin(self, etype, size):
|
def writeListBegin(self, etype, size):
|
||||||
self.writeByte(etype)
|
self.writeByte(etype)
|
||||||
self.writeI32(size)
|
self.writeI32(size)
|
||||||
|
|
||||||
def writeSetEnd(self):
|
def writeListEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeBool(self, bool):
|
def writeSetBegin(self, etype, size):
|
||||||
if bool:
|
self.writeByte(etype)
|
||||||
self.writeByte(1)
|
self.writeI32(size)
|
||||||
else:
|
|
||||||
self.writeByte(0)
|
|
||||||
|
|
||||||
def writeByte(self, byte):
|
def writeSetEnd(self):
|
||||||
buff = pack("!b", byte)
|
pass
|
||||||
self.trans.write(buff)
|
|
||||||
|
|
||||||
def writeI16(self, i16):
|
def writeBool(self, bool):
|
||||||
buff = pack("!h", i16)
|
if bool:
|
||||||
self.trans.write(buff)
|
self.writeByte(1)
|
||||||
|
else:
|
||||||
|
self.writeByte(0)
|
||||||
|
|
||||||
def writeI32(self, i32):
|
def writeByte(self, byte):
|
||||||
buff = pack("!i", i32)
|
buff = pack("!b", byte)
|
||||||
self.trans.write(buff)
|
self.trans.write(buff)
|
||||||
|
|
||||||
def writeI64(self, i64):
|
def writeI16(self, i16):
|
||||||
buff = pack("!q", i64)
|
buff = pack("!h", i16)
|
||||||
self.trans.write(buff)
|
self.trans.write(buff)
|
||||||
|
|
||||||
def writeDouble(self, dub):
|
def writeI32(self, i32):
|
||||||
buff = pack("!d", dub)
|
buff = pack("!i", i32)
|
||||||
self.trans.write(buff)
|
self.trans.write(buff)
|
||||||
|
|
||||||
def writeString(self, str):
|
def writeI64(self, i64):
|
||||||
self.writeI32(len(str))
|
buff = pack("!q", i64)
|
||||||
self.trans.write(str)
|
self.trans.write(buff)
|
||||||
|
|
||||||
def readMessageBegin(self):
|
def writeDouble(self, dub):
|
||||||
sz = self.readI32()
|
buff = pack("!d", dub)
|
||||||
if sz < 0:
|
self.trans.write(buff)
|
||||||
version = sz & TBinaryProtocol.VERSION_MASK
|
|
||||||
if version != TBinaryProtocol.VERSION_1:
|
|
||||||
raise TProtocolException(
|
|
||||||
type=TProtocolException.BAD_VERSION,
|
|
||||||
message='Bad version in readMessageBegin: %d' % (sz))
|
|
||||||
type = sz & TBinaryProtocol.TYPE_MASK
|
|
||||||
name = self.readString()
|
|
||||||
seqid = self.readI32()
|
|
||||||
else:
|
|
||||||
if self.strictRead:
|
|
||||||
raise TProtocolException(type=TProtocolException.BAD_VERSION,
|
|
||||||
message='No protocol version header')
|
|
||||||
name = self.trans.readAll(sz)
|
|
||||||
type = self.readByte()
|
|
||||||
seqid = self.readI32()
|
|
||||||
return (name, type, seqid)
|
|
||||||
|
|
||||||
def readMessageEnd(self):
|
def writeBinary(self, str):
|
||||||
pass
|
self.writeI32(len(str))
|
||||||
|
self.trans.write(str)
|
||||||
|
|
||||||
def readStructBegin(self):
|
def readMessageBegin(self):
|
||||||
pass
|
sz = self.readI32()
|
||||||
|
if sz < 0:
|
||||||
|
version = sz & TBinaryProtocol.VERSION_MASK
|
||||||
|
if version != TBinaryProtocol.VERSION_1:
|
||||||
|
raise TProtocolException(
|
||||||
|
type=TProtocolException.BAD_VERSION,
|
||||||
|
message='Bad version in readMessageBegin: %d' % (sz))
|
||||||
|
type = sz & TBinaryProtocol.TYPE_MASK
|
||||||
|
name = self.readString()
|
||||||
|
seqid = self.readI32()
|
||||||
|
else:
|
||||||
|
if self.strictRead:
|
||||||
|
raise TProtocolException(type=TProtocolException.BAD_VERSION,
|
||||||
|
message='No protocol version header')
|
||||||
|
name = self.trans.readAll(sz)
|
||||||
|
type = self.readByte()
|
||||||
|
seqid = self.readI32()
|
||||||
|
return (name, type, seqid)
|
||||||
|
|
||||||
def readStructEnd(self):
|
def readMessageEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readFieldBegin(self):
|
def readStructBegin(self):
|
||||||
type = self.readByte()
|
pass
|
||||||
if type == TType.STOP:
|
|
||||||
return (None, type, 0)
|
|
||||||
id = self.readI16()
|
|
||||||
return (None, type, id)
|
|
||||||
|
|
||||||
def readFieldEnd(self):
|
def readStructEnd(self):
|
||||||
pass
|
pass
|
    def readFieldBegin(self):
        type = self.readByte()
        if type == TType.STOP:
            return (None, type, 0)
        id = self.readI16()
        return (None, type, id)

    def readFieldEnd(self):
        pass

    def readMapBegin(self):
        ktype = self.readByte()
        vtype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (ktype, vtype, size)

    def readMapEnd(self):
        pass

    def readListBegin(self):
        etype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (etype, size)

    def readListEnd(self):
        pass

    def readSetBegin(self):
        etype = self.readByte()
        size = self.readI32()
        self._check_container_length(size)
        return (etype, size)

    def readSetEnd(self):
        pass

    def readBool(self):
        byte = self.readByte()
        if byte == 0:
            return False
        return True

    def readByte(self):
        buff = self.trans.readAll(1)
        val, = unpack('!b', buff)
        return val

    def readI16(self):
        buff = self.trans.readAll(2)
        val, = unpack('!h', buff)
        return val

    def readI32(self):
        buff = self.trans.readAll(4)
        val, = unpack('!i', buff)
        return val

    def readI64(self):
        buff = self.trans.readAll(8)
        val, = unpack('!q', buff)
        return val

    def readDouble(self):
        buff = self.trans.readAll(8)
        val, = unpack('!d', buff)
        return val

    def readBinary(self):
        size = self.readI32()
        self._check_string_length(size)
        s = self.trans.readAll(size)
        return s


class TBinaryProtocolFactory(TProtocolFactory):
    def __init__(self, strictRead=False, strictWrite=True, **kwargs):
        self.strictRead = strictRead
        self.strictWrite = strictWrite
        self.string_length_limit = kwargs.get('string_length_limit', None)
        self.container_length_limit = kwargs.get('container_length_limit', None)

    def getProtocol(self, trans):
        prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
                               string_length_limit=self.string_length_limit,
                               container_length_limit=self.container_length_limit)
        return prot


class TBinaryProtocolAccelerated(TBinaryProtocol):
    """C-Accelerated version of TBinaryProtocol.

    This class does not override any of TBinaryProtocol's methods,
    but the generated code recognizes it directly and will call into
    our C module to do the encoding, bypassing this object entirely.
    We inherit from TBinaryProtocol so that the normal TBinaryProtocol
    encoding can happen if the fastbinary module doesn't work for some
    reason. (TODO(dreiss): Make this happen sanely in more cases.)
    To disable this behavior, pass fallback=False constructor argument.

    In order to take advantage of the C module, just use
    TBinaryProtocolAccelerated instead of TBinaryProtocol.

    NOTE: This code was contributed by an external developer.
    The internal Thrift team has reviewed and tested it,
    but we cannot guarantee that it is production-ready.
    Please feel free to report bugs and/or success stories
    to the public mailing list.
    """
    pass

    def __init__(self, *args, **kwargs):
        fallback = kwargs.pop('fallback', True)
        super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if not fallback:
                raise
        else:
            self._fast_decode = fastbinary.decode_binary
            self._fast_encode = fastbinary.encode_binary


class TBinaryProtocolAcceleratedFactory(TProtocolFactory):
    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        self._fallback = fallback

    def getProtocol(self, trans):
        return TBinaryProtocolAccelerated(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
            fallback=self._fallback)
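As a quick illustration (editor's sketch, not part of the commit), the factory above can round-trip a primitive through an in-memory transport; it assumes the stock TTransport.TMemoryBuffer shipped alongside these files:

from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory

factory = TBinaryProtocolFactory(string_length_limit=10 * 1024 * 1024)
wbuf = TTransport.TMemoryBuffer()
factory.getProtocol(wbuf).writeI32(2023)            # big-endian '!i' encoding
rbuf = TTransport.TMemoryBuffer(wbuf.getvalue())
assert factory.getProtocol(rbuf).readI32() == 2023  # and back again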
thrift/protocol/TCompactProtocol.py
@@ -17,9 +17,11 @@
# under the License.
#

from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits
from struct import pack, unpack

from ..compat import binary_to_str, str_to_binary

__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']

CLEAR = 0
@@ -34,370 +36,452 @@ BOOL_READ = 8


def make_helper(v_from, container):
    def helper(func):
        def nested(self, *args, **kwargs):
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper


writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)


def makeZigZag(n, bits):
    checkIntegerLimits(n, bits)
    return (n << 1) ^ (n >> (bits - 1))


def fromZigZag(n):
    return (n >> 1) ^ -(n & 1)
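# Worked example (editor's note, not part of the committed file): zigzag folds
# the sign into the low bit so that small magnitudes stay small on the wire:
#     makeZigZag(0, 32)  -> 0      fromZigZag(0) -> 0
#     makeZigZag(-1, 32) -> 1      fromZigZag(1) -> -1
#     makeZigZag(1, 32)  -> 2      fromZigZag(2) -> 1
#     makeZigZag(-2, 32) -> 3      fromZigZag(3) -> -2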
def writeVarint(trans, n):
    assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!"
    out = bytearray()
    while True:
        if n & ~0x7f == 0:
            out.append(n)
            break
        else:
            out.append((n & 0xff) | 0x80)
            n = n >> 7
    trans.write(bytes(out))


def readVarint(trans):
    result = 0
    shift = 0
    while True:
        x = trans.readAll(1)
        byte = ord(x)
        result |= (byte & 0x7f) << shift
        if byte >> 7 == 0:
            return result
        shift += 7
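# Worked example (editor's note, not part of the committed file): each varint
# byte carries 7 payload bits plus a continuation flag in the high bit, least
# significant group first. writeVarint(trans, 300) therefore emits b'\xac\x02'
# (300 = 0b10_0101100), and readVarint() of those two bytes returns 300.
# Together with zigzag this is how every i16/i32/i64 below goes on the wire.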
class CompactType(object):
    STOP = 0x00
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C


CTYPES = {
    TType.STOP: CompactType.STOP,
    TType.BOOL: CompactType.TRUE,  # used for collection
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

TTYPES = {}
for k, v in CTYPES.items():
    TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v


class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver."""

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_BITS = 0x07
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        self.__last_fid = 0
        self.__bool_fid = None
        self.__bool_value = None
        self.__structs = []
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def _check_string_length(self, length):
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        self._check_length(self.container_length_limit, length)

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # The sequence id is a signed 32-bit integer but the compact protocol
        # writes this out as a "var int" which is always positive, and attempting
        # to write a negative number results in an infinite loop, so we may
        # need to do some conversion here...
        tseqid = seqid
        if tseqid < 0:
            tseqid = 2147483648 + (2147483648 + tseqid)
        self.__writeVarint(tseqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE
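    # Worked example (editor's note): a negative seqid such as -1 is folded
    # into the unsigned range before the varint write, 2147483648 +
    # (2147483648 - 1) = 4294967295; readMessageBegin() below applies the
    # inverse mapping and recovers -1.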
    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid
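    # Worked example (editor's note): a field of type i32 (0x05) whose id is 2
    # greater than the previous field id packs delta and type into one byte,
    # 2 << 4 | 0x05 = 0x25; larger or non-positive deltas fall back to a full
    # type byte followed by a zigzag-encoded field id.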
    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    def readFieldBegin(self):
        assert self.state == FIELD_READ, self.state
        type = self.__readUByte()
        if type & 0x0f == TType.STOP:
            return (None, 0, 0)
        delta = type >> 4
        if delta == 0:
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        type = type & 0x0f
        if type == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif type == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(type), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            raise TProtocolException("Length < 0")
        return result

    def readMessageBegin(self):
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # the sequence is a compact "var int" which is treated as unsigned,
        # however the sequence is actually signed...
        if seqid > 2147483647:
            seqid = -2147483648 - (2147483648 - seqid)
        name = binary_to_str(self.__readBinary())
        return (name, type, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        if self.state == BOOL_READ:
            return self.__bool_value == CompactType.TRUE
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        return TTYPES[byte & 0x0f]


class TCompactProtocolFactory(TProtocolFactory):
    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def getProtocol(self, trans):
        return TCompactProtocol(trans,
                                self.string_length_limit,
                                self.container_length_limit)


class TCompactProtocolAccelerated(TCompactProtocol):
    """C-Accelerated version of TCompactProtocol.

    This class does not override any of TCompactProtocol's methods,
    but the generated code recognizes it directly and will call into
    our C module to do the encoding, bypassing this object entirely.
    We inherit from TCompactProtocol so that the normal TCompactProtocol
    encoding can happen if the fastbinary module doesn't work for some
    reason.
    To disable this behavior, pass fallback=False constructor argument.

    In order to take advantage of the C module, just use
    TCompactProtocolAccelerated instead of TCompactProtocol.
    """
    pass

    def __init__(self, *args, **kwargs):
        fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            if not fallback:
                raise
        else:
            self._fast_decode = fastbinary.decode_compact
            self._fast_encode = fastbinary.encode_compact


class TCompactProtocolAcceleratedFactory(TProtocolFactory):
    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        self._fallback = fallback

    def getProtocol(self, trans):
        return TCompactProtocolAccelerated(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
            fallback=self._fallback)
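For a quick sanity check (editor's sketch, not part of the commit), the compact protocol can round-trip a bare message envelope over an in-memory transport; TMemoryBuffer and TMessageType from the same thrift package are assumed:

from thrift.Thrift import TMessageType
from thrift.transport import TTransport
from thrift.protocol.TCompactProtocol import TCompactProtocol

wbuf = TTransport.TMemoryBuffer()
out = TCompactProtocol(wbuf)
out.writeMessageBegin('ping', TMessageType.CALL, 1)
out.writeStructBegin('ping_args')
out.writeFieldStop()                   # empty argument struct
out.writeStructEnd()
out.writeMessageEnd()

inp = TCompactProtocol(TTransport.TMemoryBuffer(wbuf.getvalue()))
print(inp.readMessageBegin())          # ('ping', TMessageType.CALL, 1)
inp.readStructBegin()
print(inp.readFieldBegin())            # (None, 0, 0) signals the field stop
inp.readStructEnd()
inp.readMessageEnd()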
thrift/protocol/THeaderProtocol.py (new file, 232 lines)
@@ -0,0 +1,232 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
from thrift.protocol.TProtocol import TProtocolBase, TProtocolException, TProtocolFactory
from thrift.Thrift import TApplicationException, TMessageType
from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType


PROTOCOLS_BY_ID = {
    THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
    THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
}


class THeaderProtocol(TProtocolBase):
    """A framed protocol with headers and payload transforms.

    THeaderProtocol frames other Thrift protocols and adds support for optional
    out-of-band headers. The currently supported subprotocols are
    TBinaryProtocol and TCompactProtocol. When used as a client, the
    subprotocol to frame can be chosen with the `default_protocol` parameter to
    the constructor.

    It's also possible to apply transforms to the encoded message payload. The
    only transform currently supported is to gzip.

    When used in a server, THeaderProtocol can accept messages from
    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
    includes framed and unframed transports and both TBinaryProtocol and
    TCompactProtocol. The server will respond in the appropriate dialect for
    the connected client. HTTP clients are not currently supported.

    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
    or TProcessPoolServer.

    See doc/specs/HeaderFormat.md for details of the wire format.

    """

    def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
        # much of the actual work for THeaderProtocol happens down in
        # THeaderTransport since we need to do low-level shenanigans to detect
        # if the client is sending us headers or one of the headerless formats
        # we support. this wraps the real transport with the one that does all
        # the magic.
        if not isinstance(transport, THeaderTransport):
            transport = THeaderTransport(transport, allowed_client_types, default_protocol)
        super(THeaderProtocol, self).__init__(transport)
        self._set_protocol()

    def get_headers(self):
        return self.trans.get_headers()

    def set_header(self, key, value):
        self.trans.set_header(key, value)

    def clear_headers(self):
        self.trans.clear_headers()

    def add_transform(self, transform_id):
        self.trans.add_transform(transform_id)

    def writeMessageBegin(self, name, ttype, seqid):
        self.trans.sequence_id = seqid
        return self._protocol.writeMessageBegin(name, ttype, seqid)

    def writeMessageEnd(self):
        return self._protocol.writeMessageEnd()

    def writeStructBegin(self, name):
        return self._protocol.writeStructBegin(name)

    def writeStructEnd(self):
        return self._protocol.writeStructEnd()

    def writeFieldBegin(self, name, ttype, fid):
        return self._protocol.writeFieldBegin(name, ttype, fid)

    def writeFieldEnd(self):
        return self._protocol.writeFieldEnd()

    def writeFieldStop(self):
        return self._protocol.writeFieldStop()

    def writeMapBegin(self, ktype, vtype, size):
        return self._protocol.writeMapBegin(ktype, vtype, size)

    def writeMapEnd(self):
        return self._protocol.writeMapEnd()

    def writeListBegin(self, etype, size):
        return self._protocol.writeListBegin(etype, size)

    def writeListEnd(self):
        return self._protocol.writeListEnd()

    def writeSetBegin(self, etype, size):
        return self._protocol.writeSetBegin(etype, size)

    def writeSetEnd(self):
        return self._protocol.writeSetEnd()

    def writeBool(self, bool_val):
        return self._protocol.writeBool(bool_val)

    def writeByte(self, byte):
        return self._protocol.writeByte(byte)

    def writeI16(self, i16):
        return self._protocol.writeI16(i16)

    def writeI32(self, i32):
        return self._protocol.writeI32(i32)

    def writeI64(self, i64):
        return self._protocol.writeI64(i64)

    def writeDouble(self, dub):
        return self._protocol.writeDouble(dub)

    def writeBinary(self, str_val):
        return self._protocol.writeBinary(str_val)

    def _set_protocol(self):
        try:
            protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
        except KeyError:
            raise TApplicationException(
                TProtocolException.INVALID_PROTOCOL,
                "Unknown protocol requested.",
            )

        self._protocol = protocol_cls(self.trans)
        self._fast_encode = self._protocol._fast_encode
        self._fast_decode = self._protocol._fast_decode

    def readMessageBegin(self):
        try:
            self.trans.readFrame(0)
            self._set_protocol()
        except TApplicationException as exc:
            self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
            exc.write(self._protocol)
            self._protocol.writeMessageEnd()
            self.trans.flush()

        return self._protocol.readMessageBegin()

    def readMessageEnd(self):
        return self._protocol.readMessageEnd()

    def readStructBegin(self):
        return self._protocol.readStructBegin()

    def readStructEnd(self):
        return self._protocol.readStructEnd()

    def readFieldBegin(self):
        return self._protocol.readFieldBegin()

    def readFieldEnd(self):
        return self._protocol.readFieldEnd()

    def readMapBegin(self):
        return self._protocol.readMapBegin()

    def readMapEnd(self):
        return self._protocol.readMapEnd()

    def readListBegin(self):
        return self._protocol.readListBegin()

    def readListEnd(self):
        return self._protocol.readListEnd()

    def readSetBegin(self):
        return self._protocol.readSetBegin()

    def readSetEnd(self):
        return self._protocol.readSetEnd()

    def readBool(self):
        return self._protocol.readBool()

    def readByte(self):
        return self._protocol.readByte()

    def readI16(self):
        return self._protocol.readI16()

    def readI32(self):
        return self._protocol.readI32()

    def readI64(self):
        return self._protocol.readI64()

    def readDouble(self):
        return self._protocol.readDouble()

    def readBinary(self):
        return self._protocol.readBinary()


class THeaderProtocolFactory(TProtocolFactory):
    def __init__(
        self,
        allowed_client_types=(THeaderClientType.HEADERS,),
        default_protocol=THeaderSubprotocolID.BINARY,
    ):
        self.allowed_client_types = allowed_client_types
        self.default_protocol = default_protocol

    def getProtocol(self, trans):
        return THeaderProtocol(trans, self.allowed_client_types, self.default_protocol)
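A minimal construction sketch (editor's note, not part of the commit; it assumes the accompanying THeaderTransport module from the same thrift drop):

from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport.THeaderTransport import THeaderClientType

# Accept only header-framed clients, which is the factory default shown above;
# the upstream THeaderClientType enum also defines legacy framed/unframed
# members that can be added to this tuple.
factory = THeaderProtocolFactory(allowed_client_types=(THeaderClientType.HEADERS,))
# factory.getProtocol(transport) wraps `transport` in a THeaderTransport and
# frames a TBinaryProtocol payload by default.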
thrift/protocol/TJSONProtocol.py (new file, 677 lines)
@@ -0,0 +1,677 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from .TProtocol import (TType, TProtocolBase, TProtocolException,
                        TProtocolFactory, checkIntegerLimits)
import base64
import math
import sys

from ..compat import str_to_binary


__all__ = ['TJSONProtocol',
           'TJSONProtocolFactory',
           'TSimpleJSONProtocol',
           'TSimpleJSONProtocolFactory']

VERSION = 1

COMMA = b','
COLON = b':'
LBRACE = b'{'
RBRACE = b'}'
LBRACKET = b'['
RBRACKET = b']'
QUOTE = b'"'
BACKSLASH = b'\\'
ZERO = b'0'

ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')
ESCAPE_CHAR_VALS = {
    '"': '\\"',
    '\\': '\\\\',
    '\b': '\\b',
    '\f': '\\f',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    # '/': '\\/',
}
ESCAPE_CHARS = {
    b'"': '"',
    b'\\': '\\',
    b'b': '\b',
    b'f': '\f',
    b'n': '\n',
    b'r': '\r',
    b't': '\t',
    b'/': '/',
}
NUMERIC_CHAR = b'+-.0123456789Ee'

CTYPES = {
    TType.BOOL: 'tf',
    TType.BYTE: 'i8',
    TType.I16: 'i16',
    TType.I32: 'i32',
    TType.I64: 'i64',
    TType.DOUBLE: 'dbl',
    TType.STRING: 'str',
    TType.STRUCT: 'rec',
    TType.LIST: 'lst',
    TType.SET: 'set',
    TType.MAP: 'map',
}

JTYPES = {}
for key in CTYPES.keys():
    JTYPES[CTYPES[key]] = key


class JSONBaseContext(object):

    def __init__(self, protocol):
        self.protocol = protocol
        self.first = True

    def doIO(self, function):
        pass

    def write(self):
        pass

    def read(self):
        pass

    def escapeNum(self):
        return False

    def __str__(self):
        return self.__class__.__name__


class JSONListContext(JSONBaseContext):

    def doIO(self, function):
        if self.first is True:
            self.first = False
        else:
            function(COMMA)

    def write(self):
        self.doIO(self.protocol.trans.write)

    def read(self):
        self.doIO(self.protocol.readJSONSyntaxChar)


class JSONPairContext(JSONBaseContext):

    def __init__(self, protocol):
        super(JSONPairContext, self).__init__(protocol)
        self.colon = True

    def doIO(self, function):
        if self.first:
            self.first = False
            self.colon = True
        else:
            function(COLON if self.colon else COMMA)
            self.colon = not self.colon

    def write(self):
        self.doIO(self.protocol.trans.write)

    def read(self):
        self.doIO(self.protocol.readJSONSyntaxChar)

    def escapeNum(self):
        return self.colon

    def __str__(self):
        return '%s, colon=%s' % (self.__class__.__name__, self.colon)


class LookaheadReader():
    hasData = False
    data = ''

    def __init__(self, protocol):
        self.protocol = protocol

    def read(self):
        if self.hasData is True:
            self.hasData = False
        else:
            self.data = self.protocol.trans.read(1)
        return self.data

    def peek(self):
        if self.hasData is False:
            self.data = self.protocol.trans.read(1)
        self.hasData = True
        return self.data


class TJSONProtocolBase(TProtocolBase):

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.resetWriteContext()
        self.resetReadContext()

    # We don't have length limit implementation for JSON protocols
    @property
    def string_length_limit(self):
        return None

    @property
    def container_length_limit(self):
        return None

    def resetWriteContext(self):
        self.context = JSONBaseContext(self)
        self.contextStack = [self.context]

    def resetReadContext(self):
        self.resetWriteContext()
        self.reader = LookaheadReader(self)

    def pushContext(self, ctx):
        self.contextStack.append(ctx)
        self.context = ctx

    def popContext(self):
        self.contextStack.pop()
        if self.contextStack:
            self.context = self.contextStack[-1]
        else:
            self.context = JSONBaseContext(self)

    def writeJSONString(self, string):
        self.context.write()
        json_str = ['"']
        for s in string:
            escaped = ESCAPE_CHAR_VALS.get(s, s)
            json_str.append(escaped)
        json_str.append('"')
        self.trans.write(str_to_binary(''.join(json_str)))

    def writeJSONNumber(self, number, formatter='{0}'):
        self.context.write()
        jsNumber = str(formatter.format(number)).encode('ascii')
        if self.context.escapeNum():
            self.trans.write(QUOTE)
            self.trans.write(jsNumber)
            self.trans.write(QUOTE)
        else:
            self.trans.write(jsNumber)

    def writeJSONBase64(self, binary):
        self.context.write()
        self.trans.write(QUOTE)
        self.trans.write(base64.b64encode(binary))
        self.trans.write(QUOTE)

    def writeJSONObjectStart(self):
        self.context.write()
        self.trans.write(LBRACE)
        self.pushContext(JSONPairContext(self))

    def writeJSONObjectEnd(self):
        self.popContext()
        self.trans.write(RBRACE)

    def writeJSONArrayStart(self):
        self.context.write()
        self.trans.write(LBRACKET)
        self.pushContext(JSONListContext(self))

    def writeJSONArrayEnd(self):
        self.popContext()
        self.trans.write(RBRACKET)

    def readJSONSyntaxChar(self, character):
        current = self.reader.read()
        if character != current:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Unexpected character: %s" % current)

    def _isHighSurrogate(self, codeunit):
        return codeunit >= 0xd800 and codeunit <= 0xdbff

    def _isLowSurrogate(self, codeunit):
        return codeunit >= 0xdc00 and codeunit <= 0xdfff

    def _toChar(self, high, low=None):
        if not low:
            if sys.version_info[0] == 2:
                return ("\\u%04x" % high).decode('unicode-escape') \
                    .encode('utf-8')
            else:
                return chr(high)
        else:
            codepoint = (1 << 16) + ((high & 0x3ff) << 10)
            codepoint += low & 0x3ff
            if sys.version_info[0] == 2:
                s = "\\U%08x" % codepoint
                return s.decode('unicode-escape').encode('utf-8')
            else:
                return chr(codepoint)

    def readJSONString(self, skipContext):
        highSurrogate = None
        string = []
        if skipContext is False:
            self.context.read()
        self.readJSONSyntaxChar(QUOTE)
        while True:
            character = self.reader.read()
            if character == QUOTE:
                break
            if ord(character) == ESCSEQ0:
                character = self.reader.read()
                if ord(character) == ESCSEQ1:
                    character = self.trans.read(4).decode('ascii')
                    codeunit = int(character, 16)
                    if self._isHighSurrogate(codeunit):
                        if highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected low surrogate char")
                        highSurrogate = codeunit
                        continue
                    elif self._isLowSurrogate(codeunit):
                        if not highSurrogate:
                            raise TProtocolException(
                                TProtocolException.INVALID_DATA,
                                "Expected high surrogate char")
                        character = self._toChar(highSurrogate, codeunit)
                        highSurrogate = None
                    else:
                        character = self._toChar(codeunit)
                else:
                    if character not in ESCAPE_CHARS:
                        raise TProtocolException(
                            TProtocolException.INVALID_DATA,
                            "Expected control char")
                    character = ESCAPE_CHARS[character]
            elif character in ESCAPE_CHAR_VALS:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Unescaped control char")
            elif sys.version_info[0] > 2:
                utf8_bytes = bytearray([ord(character)])
                while ord(self.reader.peek()) >= 0x80:
                    utf8_bytes.append(ord(self.reader.read()))
                character = utf8_bytes.decode('utf8')
            string.append(character)

        if highSurrogate:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Expected low surrogate char")
        return ''.join(string)

    def isJSONNumeric(self, character):
        return (True if NUMERIC_CHAR.find(character) != -1 else False)

    def readJSONQuotes(self):
        if (self.context.escapeNum()):
            self.readJSONSyntaxChar(QUOTE)

    def readJSONNumericChars(self):
        numeric = []
        while True:
            character = self.reader.peek()
            if self.isJSONNumeric(character) is False:
                break
            numeric.append(self.reader.read())
        return b''.join(numeric).decode('ascii')

    def readJSONInteger(self):
        self.context.read()
        self.readJSONQuotes()
        numeric = self.readJSONNumericChars()
        self.readJSONQuotes()
        try:
            return int(numeric)
        except ValueError:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Bad data encountered in numeric data")

    def readJSONDouble(self):
        self.context.read()
        if self.reader.peek() == QUOTE:
            string = self.readJSONString(True)
            try:
                double = float(string)
                if (self.context.escapeNum is False and
                        not math.isinf(double) and
                        not math.isnan(double)):
                    raise TProtocolException(
                        TProtocolException.INVALID_DATA,
                        "Numeric data unexpectedly quoted")
                return double
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encountered in numeric data")
        else:
            if self.context.escapeNum() is True:
                self.readJSONSyntaxChar(QUOTE)
            try:
                return float(self.readJSONNumericChars())
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encountered in numeric data")

    def readJSONBase64(self):
        string = self.readJSONString(False)
        size = len(string)
        m = size % 4
        # Force padding since b64encode method does not allow it
        if m != 0:
            for i in range(4 - m):
                string += '='
        return base64.b64decode(string)
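    # Worked example (editor's note): base64 data may arrive without trailing
    # padding, so a 6-character value such as 'YWJjZA' (len % 4 == 2) gets '=='
    # re-appended here before base64.b64decode() returns b'abcd'.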
    def readJSONObjectStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACE)
        self.pushContext(JSONPairContext(self))

    def readJSONObjectEnd(self):
        self.readJSONSyntaxChar(RBRACE)
        self.popContext()

    def readJSONArrayStart(self):
        self.context.read()
        self.readJSONSyntaxChar(LBRACKET)
        self.pushContext(JSONListContext(self))

    def readJSONArrayEnd(self):
        self.readJSONSyntaxChar(RBRACKET)
        self.popContext()


class TJSONProtocol(TJSONProtocolBase):

    def readMessageBegin(self):
        self.resetReadContext()
        self.readJSONArrayStart()
        if self.readJSONInteger() != VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     "Message contained bad version.")
        name = self.readJSONString(False)
        typen = self.readJSONInteger()
        seqid = self.readJSONInteger()
        return (name, typen, seqid)

    def readMessageEnd(self):
        self.readJSONArrayEnd()

    def readStructBegin(self):
        self.readJSONObjectStart()

    def readStructEnd(self):
        self.readJSONObjectEnd()

    def readFieldBegin(self):
        character = self.reader.peek()
        ttype = 0
        id = 0
        if character == RBRACE:
            ttype = TType.STOP
        else:
            id = self.readJSONInteger()
            self.readJSONObjectStart()
            ttype = JTYPES[self.readJSONString(False)]
        return (None, ttype, id)

    def readFieldEnd(self):
        self.readJSONObjectEnd()

    def readMapBegin(self):
        self.readJSONArrayStart()
        keyType = JTYPES[self.readJSONString(False)]
        valueType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        self.readJSONObjectStart()
        return (keyType, valueType, size)

    def readMapEnd(self):
        self.readJSONObjectEnd()
        self.readJSONArrayEnd()

    def readCollectionBegin(self):
        self.readJSONArrayStart()
        elemType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        return (elemType, size)
    readListBegin = readCollectionBegin
    readSetBegin = readCollectionBegin

    def readCollectionEnd(self):
        self.readJSONArrayEnd()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd

    def readBool(self):
        return (False if self.readJSONInteger() == 0 else True)

    def readNumber(self):
        return self.readJSONInteger()
    readByte = readNumber
    readI16 = readNumber
    readI32 = readNumber
    readI64 = readNumber

    def readDouble(self):
        return self.readJSONDouble()

    def readString(self):
        return self.readJSONString(False)

    def readBinary(self):
        return self.readJSONBase64()

    def writeMessageBegin(self, name, request_type, seqid):
        self.resetWriteContext()
        self.writeJSONArrayStart()
        self.writeJSONNumber(VERSION)
        self.writeJSONString(name)
        self.writeJSONNumber(request_type)
        self.writeJSONNumber(seqid)

    def writeMessageEnd(self):
        self.writeJSONArrayEnd()

    def writeStructBegin(self, name):
        self.writeJSONObjectStart()

    def writeStructEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldBegin(self, name, ttype, id):
        self.writeJSONNumber(id)
        self.writeJSONObjectStart()
        self.writeJSONString(CTYPES[ttype])

    def writeFieldEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldStop(self):
        pass

    def writeMapBegin(self, ktype, vtype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[ktype])
        self.writeJSONString(CTYPES[vtype])
        self.writeJSONNumber(size)
        self.writeJSONObjectStart()

    def writeMapEnd(self):
        self.writeJSONObjectEnd()
        self.writeJSONArrayEnd()

    def writeListBegin(self, etype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[etype])
        self.writeJSONNumber(size)

    def writeListEnd(self):
        self.writeJSONArrayEnd()

    def writeSetBegin(self, etype, size):
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[etype])
        self.writeJSONNumber(size)

    def writeSetEnd(self):
        self.writeJSONArrayEnd()

    def writeBool(self, boolean):
        self.writeJSONNumber(1 if boolean is True else 0)

    def writeByte(self, byte):
        checkIntegerLimits(byte, 8)
        self.writeJSONNumber(byte)

    def writeI16(self, i16):
        checkIntegerLimits(i16, 16)
        self.writeJSONNumber(i16)

    def writeI32(self, i32):
        checkIntegerLimits(i32, 32)
        self.writeJSONNumber(i32)

    def writeI64(self, i64):
        checkIntegerLimits(i64, 64)
        self.writeJSONNumber(i64)

    def writeDouble(self, dbl):
        # 17 significant digits should be just enough for any double precision
        # value.
        self.writeJSONNumber(dbl, '{0:.17g}')

    def writeString(self, string):
        self.writeJSONString(string)

    def writeBinary(self, binary):
        self.writeJSONBase64(binary)


class TJSONProtocolFactory(TProtocolFactory):
    def getProtocol(self, trans):
        return TJSONProtocol(trans)

    @property
    def string_length_limit(self):
        return None

    @property
    def container_length_limit(self):
        return None
||||||
|
|
||||||
|
class TSimpleJSONProtocol(TJSONProtocolBase):
|
||||||
|
"""Simple, readable, write-only JSON protocol.
|
||||||
|
|
||||||
|
Useful for interacting with scripting languages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def readMessageBegin(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def readMessageEnd(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def readStructBegin(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def readStructEnd(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def writeMessageBegin(self, name, request_type, seqid):
|
||||||
|
self.resetWriteContext()
|
||||||
|
|
||||||
|
def writeMessageEnd(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def writeStructBegin(self, name):
|
||||||
|
self.writeJSONObjectStart()
|
||||||
|
|
||||||
|
def writeStructEnd(self):
|
||||||
|
self.writeJSONObjectEnd()
|
||||||
|
|
||||||
|
def writeFieldBegin(self, name, ttype, fid):
|
||||||
|
self.writeJSONString(name)
|
||||||
|
|
||||||
|
def writeFieldEnd(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def writeMapBegin(self, ktype, vtype, size):
|
||||||
|
self.writeJSONObjectStart()
|
||||||
|
|
||||||
|
def writeMapEnd(self):
|
||||||
|
self.writeJSONObjectEnd()
|
||||||
|
|
||||||
|
def _writeCollectionBegin(self, etype, size):
|
||||||
|
self.writeJSONArrayStart()
|
||||||
|
|
||||||
|
def _writeCollectionEnd(self):
|
||||||
|
self.writeJSONArrayEnd()
|
||||||
|
writeListBegin = _writeCollectionBegin
|
||||||
|
writeListEnd = _writeCollectionEnd
|
||||||
|
writeSetBegin = _writeCollectionBegin
|
||||||
|
writeSetEnd = _writeCollectionEnd
|
||||||
|
|
||||||
|
def writeByte(self, byte):
|
||||||
|
checkIntegerLimits(byte, 8)
|
||||||
|
self.writeJSONNumber(byte)
|
||||||
|
|
||||||
|
def writeI16(self, i16):
|
||||||
|
checkIntegerLimits(i16, 16)
|
||||||
|
self.writeJSONNumber(i16)
|
||||||
|
|
||||||
|
def writeI32(self, i32):
|
||||||
|
checkIntegerLimits(i32, 32)
|
||||||
|
self.writeJSONNumber(i32)
|
||||||
|
|
||||||
|
def writeI64(self, i64):
|
||||||
|
checkIntegerLimits(i64, 64)
|
||||||
|
self.writeJSONNumber(i64)
|
||||||
|
|
||||||
|
def writeBool(self, boolean):
|
||||||
|
self.writeJSONNumber(1 if boolean is True else 0)
|
||||||
|
|
||||||
|
def writeDouble(self, dbl):
|
||||||
|
self.writeJSONNumber(dbl)
|
||||||
|
|
||||||
|
def writeString(self, string):
|
||||||
|
self.writeJSONString(string)
|
||||||
|
|
||||||
|
def writeBinary(self, binary):
|
||||||
|
self.writeJSONBase64(binary)
|
||||||
|
|
||||||
|
|
||||||
|
class TSimpleJSONProtocolFactory(TProtocolFactory):
|
||||||
|
|
||||||
|
def getProtocol(self, trans):
|
||||||
|
return TSimpleJSONProtocol(trans)
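
A minimal usage sketch (editor's addition, not part of the vendored file): round-tripping a message envelope through TJSONProtocol over an in-memory transport. It assumes the package above is importable as `thrift`; the method name "ping" and sequence id are illustrative.

    from thrift.Thrift import TMessageType
    from thrift.transport import TTransport
    from thrift.protocol.TJSONProtocol import TJSONProtocolFactory

    # Write a CALL envelope with an empty args struct: [1,"ping",1,42,{}]
    obuf = TTransport.TMemoryBuffer()
    oprot = TJSONProtocolFactory().getProtocol(obuf)
    oprot.writeMessageBegin('ping', TMessageType.CALL, 42)
    oprot.writeStructBegin('ping_args')
    oprot.writeFieldStop()
    oprot.writeStructEnd()
    oprot.writeMessageEnd()

    # Read the same bytes back with a second protocol instance.
    ibuf = TTransport.TMemoryBuffer(obuf.getvalue())
    iprot = TJSONProtocolFactory().getProtocol(ibuf)
    name, mtype, seqid = iprot.readMessageBegin()   # ('ping', TMessageType.CALL, 42)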
thrift/protocol/TMultiplexedProtocol.py (new file, 39 lines)
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from thrift.Thrift import TMessageType
from thrift.protocol import TProtocolDecorator

SEPARATOR = ":"


class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
    def __init__(self, protocol, serviceName):
        self.serviceName = serviceName

    def writeMessageBegin(self, name, type, seqid):
        if (type == TMessageType.CALL or
                type == TMessageType.ONEWAY):
            super(TMultiplexedProtocol, self).writeMessageBegin(
                self.serviceName + SEPARATOR + name,
                type,
                seqid
            )
        else:
            super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid)
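
A hedged usage sketch (editor's addition): wrapping a concrete protocol so every outgoing CALL or ONEWAY is prefixed with a service name. The host, port, and the "Calculator" service name are illustrative assumptions, not part of this commit.

    from thrift.transport import TSocket, TTransport
    from thrift.protocol import TBinaryProtocol, TMultiplexedProtocol

    transport = TTransport.TBufferedTransport(TSocket.TSocket('localhost', 9090))
    base = TBinaryProtocol.TBinaryProtocol(transport)
    calc = TMultiplexedProtocol.TMultiplexedProtocol(base, 'Calculator')
    # Each CALL/ONEWAY written through `calc` goes out as "Calculator:<method>",
    # which a multiplexing processor on the server side uses to route the request.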
@@ -17,390 +17,412 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from thrift.Thrift import *
|
from thrift.Thrift import TException, TType, TFrozenDict
|
||||||
|
from thrift.transport.TTransport import TTransportException
|
||||||
|
from ..compat import binary_to_str, str_to_binary
|
||||||
|
|
||||||
|
import six
|
||||||
|
import sys
|
||||||
|
from itertools import islice
|
||||||
|
from six.moves import zip
|
||||||
|
|
||||||
|
|
||||||
class TProtocolException(TException):
|
class TProtocolException(TException):
|
||||||
"""Custom Protocol Exception class"""
|
"""Custom Protocol Exception class"""
|
||||||
|
|
||||||
UNKNOWN = 0
|
UNKNOWN = 0
|
||||||
INVALID_DATA = 1
|
INVALID_DATA = 1
|
||||||
NEGATIVE_SIZE = 2
|
NEGATIVE_SIZE = 2
|
||||||
SIZE_LIMIT = 3
|
SIZE_LIMIT = 3
|
||||||
BAD_VERSION = 4
|
BAD_VERSION = 4
|
||||||
|
NOT_IMPLEMENTED = 5
|
||||||
|
DEPTH_LIMIT = 6
|
||||||
|
INVALID_PROTOCOL = 7
|
||||||
|
|
||||||
def __init__(self, type=UNKNOWN, message=None):
|
def __init__(self, type=UNKNOWN, message=None):
|
||||||
TException.__init__(self, message)
|
TException.__init__(self, message)
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
|
|
||||||
class TProtocolBase:
|
class TProtocolBase(object):
|
||||||
"""Base class for Thrift protocol driver."""
|
"""Base class for Thrift protocol driver."""
|
||||||
|
|
||||||
def __init__(self, trans):
|
def __init__(self, trans):
|
||||||
self.trans = trans
|
self.trans = trans
|
||||||
|
self._fast_decode = None
|
||||||
|
self._fast_encode = None
|
||||||
|
|
||||||
def writeMessageBegin(self, name, type, seqid):
|
@staticmethod
|
||||||
pass
|
def _check_length(limit, length):
|
||||||
|
if length < 0:
|
||||||
|
raise TTransportException(TTransportException.NEGATIVE_SIZE,
|
||||||
|
'Negative length: %d' % length)
|
||||||
|
if limit is not None and length > limit:
|
||||||
|
raise TTransportException(TTransportException.SIZE_LIMIT,
|
||||||
|
'Length exceeded max allowed: %d' % limit)
|
||||||
|
|
||||||
def writeMessageEnd(self):
|
def writeMessageBegin(self, name, ttype, seqid):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeStructBegin(self, name):
|
def writeMessageEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeStructEnd(self):
|
def writeStructBegin(self, name):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeFieldBegin(self, name, type, id):
|
def writeStructEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeFieldEnd(self):
|
def writeFieldBegin(self, name, ttype, fid):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeFieldStop(self):
|
def writeFieldEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeMapBegin(self, ktype, vtype, size):
|
def writeFieldStop(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeMapEnd(self):
|
def writeMapBegin(self, ktype, vtype, size):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeListBegin(self, etype, size):
|
def writeMapEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeListEnd(self):
|
def writeListBegin(self, etype, size):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeSetBegin(self, etype, size):
|
def writeListEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeSetEnd(self):
|
def writeSetBegin(self, etype, size):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeBool(self, bool):
|
def writeSetEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeByte(self, byte):
|
def writeBool(self, bool_val):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeI16(self, i16):
|
def writeByte(self, byte):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeI32(self, i32):
|
def writeI16(self, i16):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeI64(self, i64):
|
def writeI32(self, i32):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeDouble(self, dub):
|
def writeI64(self, i64):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeString(self, str):
|
def writeDouble(self, dub):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readMessageBegin(self):
|
def writeString(self, str_val):
|
||||||
pass
|
self.writeBinary(str_to_binary(str_val))
|
||||||
|
|
||||||
def readMessageEnd(self):
|
def writeBinary(self, str_val):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readStructBegin(self):
|
def writeUtf8(self, str_val):
|
||||||
pass
|
self.writeString(str_val.encode('utf8'))
|
||||||
|
|
||||||
def readStructEnd(self):
|
def readMessageBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readFieldBegin(self):
|
def readMessageEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readFieldEnd(self):
|
def readStructBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readMapBegin(self):
|
def readStructEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readMapEnd(self):
|
def readFieldBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readListBegin(self):
|
def readFieldEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readListEnd(self):
|
def readMapBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readSetBegin(self):
|
def readMapEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readSetEnd(self):
|
def readListBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readBool(self):
|
def readListEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readByte(self):
|
def readSetBegin(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readI16(self):
|
def readSetEnd(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readI32(self):
|
def readBool(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readI64(self):
|
def readByte(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readDouble(self):
|
def readI16(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def readString(self):
|
def readI32(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def skip(self, type):
|
def readI64(self):
|
||||||
if type == TType.STOP:
|
pass
|
||||||
return
|
|
||||||
elif type == TType.BOOL:
|
|
||||||
self.readBool()
|
|
||||||
elif type == TType.BYTE:
|
|
||||||
self.readByte()
|
|
||||||
elif type == TType.I16:
|
|
||||||
self.readI16()
|
|
||||||
elif type == TType.I32:
|
|
||||||
self.readI32()
|
|
||||||
elif type == TType.I64:
|
|
||||||
self.readI64()
|
|
||||||
elif type == TType.DOUBLE:
|
|
||||||
self.readDouble()
|
|
||||||
elif type == TType.STRING:
|
|
||||||
self.readString()
|
|
||||||
elif type == TType.STRUCT:
|
|
||||||
name = self.readStructBegin()
|
|
||||||
while True:
|
|
||||||
(name, type, id) = self.readFieldBegin()
|
|
||||||
if type == TType.STOP:
|
|
||||||
break
|
|
||||||
self.skip(type)
|
|
||||||
self.readFieldEnd()
|
|
||||||
self.readStructEnd()
|
|
||||||
elif type == TType.MAP:
|
|
||||||
(ktype, vtype, size) = self.readMapBegin()
|
|
||||||
for i in range(size):
|
|
||||||
self.skip(ktype)
|
|
||||||
self.skip(vtype)
|
|
||||||
self.readMapEnd()
|
|
||||||
elif type == TType.SET:
|
|
||||||
(etype, size) = self.readSetBegin()
|
|
||||||
for i in range(size):
|
|
||||||
self.skip(etype)
|
|
||||||
self.readSetEnd()
|
|
||||||
elif type == TType.LIST:
|
|
||||||
(etype, size) = self.readListBegin()
|
|
||||||
for i in range(size):
|
|
||||||
self.skip(etype)
|
|
||||||
self.readListEnd()
|
|
||||||
|
|
||||||
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
|
def readDouble(self):
|
||||||
_TTYPE_HANDLERS = (
|
pass
|
||||||
(None, None, False), # 0 TType.STOP
|
|
||||||
(None, None, False), # 1 TType.VOID # TODO: handle void?
|
|
||||||
('readBool', 'writeBool', False), # 2 TType.BOOL
|
|
||||||
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
|
|
||||||
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
|
|
||||||
(None, None, False), # 5 undefined
|
|
||||||
('readI16', 'writeI16', False), # 6 TType.I16
|
|
||||||
(None, None, False), # 7 undefined
|
|
||||||
('readI32', 'writeI32', False), # 8 TType.I32
|
|
||||||
(None, None, False), # 9 undefined
|
|
||||||
('readI64', 'writeI64', False), # 10 TType.I64
|
|
||||||
('readString', 'writeString', False), # 11 TType.STRING and UTF7
|
|
||||||
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
|
|
||||||
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
|
|
||||||
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
|
|
||||||
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
|
|
||||||
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
|
|
||||||
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
|
|
||||||
)
|
|
||||||
|
|
||||||
def readFieldByTType(self, ttype, spec):
|
def readString(self):
|
||||||
try:
|
return binary_to_str(self.readBinary())
|
||||||
(r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
|
|
||||||
except IndexError:
|
|
||||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
|
||||||
message='Invalid field type %d' % (ttype))
|
|
||||||
if r_handler is None:
|
|
||||||
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
|
||||||
message='Invalid field type %d' % (ttype))
|
|
||||||
reader = getattr(self, r_handler)
|
|
||||||
if not is_container:
|
|
||||||
return reader()
|
|
||||||
return reader(spec)
|
|
||||||
|
|
||||||
def readContainerList(self, spec):
|
def readBinary(self):
|
||||||
results = []
|
pass
|
||||||
ttype, tspec = spec[0], spec[1]
|
|
||||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
|
||||||
reader = getattr(self, r_handler)
|
|
||||||
(list_type, list_len) = self.readListBegin()
|
|
||||||
if tspec is None:
|
|
||||||
# list values are simple types
|
|
||||||
for idx in range(list_len):
|
|
||||||
results.append(reader())
|
|
||||||
else:
|
|
||||||
# this is like an inlined readFieldByTType
|
|
||||||
container_reader = self._TTYPE_HANDLERS[list_type][0]
|
|
||||||
val_reader = getattr(self, container_reader)
|
|
||||||
for idx in range(list_len):
|
|
||||||
val = val_reader(tspec)
|
|
||||||
results.append(val)
|
|
||||||
self.readListEnd()
|
|
||||||
return results
|
|
||||||
|
|
||||||
def readContainerSet(self, spec):
|
def readUtf8(self):
|
||||||
results = set()
|
return self.readString().decode('utf8')
|
||||||
ttype, tspec = spec[0], spec[1]
|
|
||||||
r_handler = self._TTYPE_HANDLERS[ttype][0]
|
|
||||||
reader = getattr(self, r_handler)
|
|
||||||
(set_type, set_len) = self.readSetBegin()
|
|
||||||
if tspec is None:
|
|
||||||
# set members are simple types
|
|
||||||
for idx in range(set_len):
|
|
||||||
results.add(reader())
|
|
||||||
else:
|
|
||||||
container_reader = self._TTYPE_HANDLERS[set_type][0]
|
|
||||||
val_reader = getattr(self, container_reader)
|
|
||||||
for idx in range(set_len):
|
|
||||||
results.add(val_reader(tspec))
|
|
||||||
self.readSetEnd()
|
|
||||||
return results
|
|
||||||
|
|
||||||
def readContainerStruct(self, spec):
|
def skip(self, ttype):
|
||||||
(obj_class, obj_spec) = spec
|
if ttype == TType.BOOL:
|
||||||
obj = obj_class()
|
self.readBool()
|
||||||
obj.read(self)
|
elif ttype == TType.BYTE:
|
||||||
return obj
|
self.readByte()
|
||||||
|
elif ttype == TType.I16:
|
||||||
def readContainerMap(self, spec):
|
self.readI16()
|
||||||
results = dict()
|
elif ttype == TType.I32:
|
||||||
key_ttype, key_spec = spec[0], spec[1]
|
self.readI32()
|
||||||
val_ttype, val_spec = spec[2], spec[3]
|
elif ttype == TType.I64:
|
||||||
(map_ktype, map_vtype, map_len) = self.readMapBegin()
|
self.readI64()
|
||||||
# TODO: compare types we just decoded with thrift_spec and
|
elif ttype == TType.DOUBLE:
|
||||||
# abort/skip if types disagree
|
self.readDouble()
|
||||||
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
|
elif ttype == TType.STRING:
|
||||||
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
|
self.readString()
|
||||||
# list values are simple types
|
elif ttype == TType.STRUCT:
|
||||||
for idx in range(map_len):
|
name = self.readStructBegin()
|
||||||
if key_spec is None:
|
while True:
|
||||||
k_val = key_reader()
|
(name, ttype, id) = self.readFieldBegin()
|
||||||
else:
|
if ttype == TType.STOP:
|
||||||
k_val = self.readFieldByTType(key_ttype, key_spec)
|
break
|
||||||
if val_spec is None:
|
self.skip(ttype)
|
||||||
v_val = val_reader()
|
self.readFieldEnd()
|
||||||
else:
|
self.readStructEnd()
|
||||||
v_val = self.readFieldByTType(val_ttype, val_spec)
|
elif ttype == TType.MAP:
|
||||||
# this raises a TypeError with unhashable keys types
|
(ktype, vtype, size) = self.readMapBegin()
|
||||||
# i.e. this fails: d=dict(); d[[0,1]] = 2
|
for i in range(size):
|
||||||
results[k_val] = v_val
|
self.skip(ktype)
|
||||||
self.readMapEnd()
|
self.skip(vtype)
|
||||||
return results
|
self.readMapEnd()
|
||||||
|
elif ttype == TType.SET:
|
||||||
def readStruct(self, obj, thrift_spec):
|
(etype, size) = self.readSetBegin()
|
||||||
self.readStructBegin()
|
for i in range(size):
|
||||||
while True:
|
self.skip(etype)
|
||||||
(fname, ftype, fid) = self.readFieldBegin()
|
self.readSetEnd()
|
||||||
if ftype == TType.STOP:
|
elif ttype == TType.LIST:
|
||||||
break
|
(etype, size) = self.readListBegin()
|
||||||
try:
|
for i in range(size):
|
||||||
field = thrift_spec[fid]
|
self.skip(etype)
|
||||||
except IndexError:
|
self.readListEnd()
|
||||||
self.skip(ftype)
|
|
||||||
else:
|
|
||||||
if field is not None and ftype == field[1]:
|
|
||||||
fname = field[2]
|
|
||||||
fspec = field[3]
|
|
||||||
val = self.readFieldByTType(ftype, fspec)
|
|
||||||
setattr(obj, fname, val)
|
|
||||||
else:
|
else:
|
||||||
self.skip(ftype)
|
raise TProtocolException(
|
||||||
self.readFieldEnd()
|
TProtocolException.INVALID_DATA,
|
||||||
self.readStructEnd()
|
"invalid TType")
|
||||||
|
|
||||||
def writeContainerStruct(self, val, spec):
|
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
|
||||||
val.write(self)
|
_TTYPE_HANDLERS = (
|
||||||
|
(None, None, False), # 0 TType.STOP
|
||||||
|
(None, None, False), # 1 TType.VOID # TODO: handle void?
|
||||||
|
('readBool', 'writeBool', False), # 2 TType.BOOL
|
||||||
|
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
|
||||||
|
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
|
||||||
|
(None, None, False), # 5 undefined
|
||||||
|
('readI16', 'writeI16', False), # 6 TType.I16
|
||||||
|
(None, None, False), # 7 undefined
|
||||||
|
('readI32', 'writeI32', False), # 8 TType.I32
|
||||||
|
(None, None, False), # 9 undefined
|
||||||
|
('readI64', 'writeI64', False), # 10 TType.I64
|
||||||
|
('readString', 'writeString', False), # 11 TType.STRING and UTF7
|
||||||
|
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
|
||||||
|
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
|
||||||
|
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
|
||||||
|
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
|
||||||
|
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
|
||||||
|
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
|
||||||
|
)
|
||||||
|
|
||||||
def writeContainerList(self, val, spec):
|
def _ttype_handlers(self, ttype, spec):
|
||||||
self.writeListBegin(spec[0], len(val))
|
if spec == 'BINARY':
|
||||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
if ttype != TType.STRING:
|
||||||
e_writer = getattr(self, w_handler)
|
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||||
if not is_container:
|
message='Invalid binary field type %d' % ttype)
|
||||||
for elem in val:
|
return ('readBinary', 'writeBinary', False)
|
||||||
e_writer(elem)
|
if sys.version_info[0] == 2 and spec == 'UTF8':
|
||||||
else:
|
if ttype != TType.STRING:
|
||||||
for elem in val:
|
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||||
e_writer(elem, spec[1])
|
message='Invalid string field type %d' % ttype)
|
||||||
self.writeListEnd()
|
return ('readUtf8', 'writeUtf8', False)
|
||||||
|
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
|
||||||
|
|
||||||
def writeContainerSet(self, val, spec):
|
def _read_by_ttype(self, ttype, spec, espec):
|
||||||
self.writeSetBegin(spec[0], len(val))
|
reader_name, _, is_container = self._ttype_handlers(ttype, espec)
|
||||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
|
if reader_name is None:
|
||||||
e_writer = getattr(self, w_handler)
|
raise TProtocolException(type=TProtocolException.INVALID_DATA,
|
||||||
if not is_container:
|
message='Invalid type %d' % (ttype))
|
||||||
for elem in val:
|
reader_func = getattr(self, reader_name)
|
||||||
e_writer(elem)
|
read = (lambda: reader_func(espec)) if is_container else reader_func
|
||||||
else:
|
while True:
|
||||||
for elem in val:
|
yield read()
|
||||||
e_writer(elem, spec[1])
|
|
||||||
self.writeSetEnd()
|
|
||||||
|
|
||||||
def writeContainerMap(self, val, spec):
|
def readFieldByTType(self, ttype, spec):
|
||||||
k_type = spec[0]
|
return next(self._read_by_ttype(ttype, spec, spec))
|
||||||
v_type = spec[2]
|
|
||||||
ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
|
|
||||||
ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
|
|
||||||
k_writer = getattr(self, ktype_name)
|
|
||||||
v_writer = getattr(self, vtype_name)
|
|
||||||
self.writeMapBegin(k_type, v_type, len(val))
|
|
||||||
for m_key, m_val in val.items():
|
|
||||||
if not k_is_container:
|
|
||||||
k_writer(m_key)
|
|
||||||
else:
|
|
||||||
k_writer(m_key, spec[1])
|
|
||||||
if not v_is_container:
|
|
||||||
v_writer(m_val)
|
|
||||||
else:
|
|
||||||
v_writer(m_val, spec[3])
|
|
||||||
self.writeMapEnd()
|
|
||||||
|
|
||||||
def writeStruct(self, obj, thrift_spec):
|
def readContainerList(self, spec):
|
||||||
self.writeStructBegin(obj.__class__.__name__)
|
ttype, tspec, is_immutable = spec
|
||||||
for field in thrift_spec:
|
(list_type, list_len) = self.readListBegin()
|
||||||
if field is None:
|
# TODO: compare types we just decoded with thrift_spec
|
||||||
continue
|
elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
|
||||||
fname = field[2]
|
results = (tuple if is_immutable else list)(elems)
|
||||||
val = getattr(obj, fname)
|
self.readListEnd()
|
||||||
if val is None:
|
return results
|
||||||
# skip writing out unset fields
|
|
||||||
continue
|
|
||||||
fid = field[0]
|
|
||||||
ftype = field[1]
|
|
||||||
fspec = field[3]
|
|
||||||
# get the writer method for this value
|
|
||||||
self.writeFieldBegin(fname, ftype, fid)
|
|
||||||
self.writeFieldByTType(ftype, val, fspec)
|
|
||||||
self.writeFieldEnd()
|
|
||||||
self.writeFieldStop()
|
|
||||||
self.writeStructEnd()
|
|
||||||
|
|
||||||
def writeFieldByTType(self, ttype, val, spec):
|
def readContainerSet(self, spec):
|
||||||
r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
|
ttype, tspec, is_immutable = spec
|
||||||
writer = getattr(self, w_handler)
|
(set_type, set_len) = self.readSetBegin()
|
||||||
if is_container:
|
# TODO: compare types we just decoded with thrift_spec
|
||||||
writer(val, spec)
|
elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
|
||||||
else:
|
results = (frozenset if is_immutable else set)(elems)
|
||||||
writer(val)
|
self.readSetEnd()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def readContainerStruct(self, spec):
|
||||||
|
(obj_class, obj_spec) = spec
|
||||||
|
|
||||||
|
# If obj_class.read is a classmethod (e.g. in frozen structs),
|
||||||
|
# call it as such.
|
||||||
|
if getattr(obj_class.read, '__self__', None) is obj_class:
|
||||||
|
obj = obj_class.read(self)
|
||||||
|
else:
|
||||||
|
obj = obj_class()
|
||||||
|
obj.read(self)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def readContainerMap(self, spec):
|
||||||
|
ktype, kspec, vtype, vspec, is_immutable = spec
|
||||||
|
(map_ktype, map_vtype, map_len) = self.readMapBegin()
|
||||||
|
# TODO: compare types we just decoded with thrift_spec and
|
||||||
|
# abort/skip if types disagree
|
||||||
|
keys = self._read_by_ttype(ktype, spec, kspec)
|
||||||
|
vals = self._read_by_ttype(vtype, spec, vspec)
|
||||||
|
keyvals = islice(zip(keys, vals), map_len)
|
||||||
|
results = (TFrozenDict if is_immutable else dict)(keyvals)
|
||||||
|
self.readMapEnd()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def readStruct(self, obj, thrift_spec, is_immutable=False):
|
||||||
|
if is_immutable:
|
||||||
|
fields = {}
|
||||||
|
self.readStructBegin()
|
||||||
|
while True:
|
||||||
|
(fname, ftype, fid) = self.readFieldBegin()
|
||||||
|
if ftype == TType.STOP:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
field = thrift_spec[fid]
|
||||||
|
except IndexError:
|
||||||
|
self.skip(ftype)
|
||||||
|
else:
|
||||||
|
if field is not None and ftype == field[1]:
|
||||||
|
fname = field[2]
|
||||||
|
fspec = field[3]
|
||||||
|
val = self.readFieldByTType(ftype, fspec)
|
||||||
|
if is_immutable:
|
||||||
|
fields[fname] = val
|
||||||
|
else:
|
||||||
|
setattr(obj, fname, val)
|
||||||
|
else:
|
||||||
|
self.skip(ftype)
|
||||||
|
self.readFieldEnd()
|
||||||
|
self.readStructEnd()
|
||||||
|
if is_immutable:
|
||||||
|
return obj(**fields)
|
||||||
|
|
||||||
|
def writeContainerStruct(self, val, spec):
|
||||||
|
val.write(self)
|
||||||
|
|
||||||
|
def writeContainerList(self, val, spec):
|
||||||
|
ttype, tspec, _ = spec
|
||||||
|
self.writeListBegin(ttype, len(val))
|
||||||
|
for _ in self._write_by_ttype(ttype, val, spec, tspec):
|
||||||
|
pass
|
||||||
|
self.writeListEnd()
|
||||||
|
|
||||||
|
def writeContainerSet(self, val, spec):
|
||||||
|
ttype, tspec, _ = spec
|
||||||
|
self.writeSetBegin(ttype, len(val))
|
||||||
|
for _ in self._write_by_ttype(ttype, val, spec, tspec):
|
||||||
|
pass
|
||||||
|
self.writeSetEnd()
|
||||||
|
|
||||||
|
def writeContainerMap(self, val, spec):
|
||||||
|
ktype, kspec, vtype, vspec, _ = spec
|
||||||
|
self.writeMapBegin(ktype, vtype, len(val))
|
||||||
|
for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
|
||||||
|
self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
|
||||||
|
pass
|
||||||
|
self.writeMapEnd()
|
||||||
|
|
||||||
|
def writeStruct(self, obj, thrift_spec):
|
||||||
|
self.writeStructBegin(obj.__class__.__name__)
|
||||||
|
for field in thrift_spec:
|
||||||
|
if field is None:
|
||||||
|
continue
|
||||||
|
fname = field[2]
|
||||||
|
val = getattr(obj, fname)
|
||||||
|
if val is None:
|
||||||
|
# skip writing out unset fields
|
||||||
|
continue
|
||||||
|
fid = field[0]
|
||||||
|
ftype = field[1]
|
||||||
|
fspec = field[3]
|
||||||
|
self.writeFieldBegin(fname, ftype, fid)
|
||||||
|
self.writeFieldByTType(ftype, val, fspec)
|
||||||
|
self.writeFieldEnd()
|
||||||
|
self.writeFieldStop()
|
||||||
|
self.writeStructEnd()
|
||||||
|
|
||||||
|
def _write_by_ttype(self, ttype, vals, spec, espec):
|
||||||
|
_, writer_name, is_container = self._ttype_handlers(ttype, espec)
|
||||||
|
writer_func = getattr(self, writer_name)
|
||||||
|
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
|
||||||
|
for v in vals:
|
||||||
|
yield write(v)
|
||||||
|
|
||||||
|
def writeFieldByTType(self, ttype, val, spec):
|
||||||
|
next(self._write_by_ttype(ttype, [val], spec, spec))
|
||||||
|
|
||||||
|
|
||||||
class TProtocolFactory:
|
def checkIntegerLimits(i, bits):
|
||||||
def getProtocol(self, trans):
|
if bits == 8 and (i < -128 or i > 127):
|
||||||
pass
|
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||||
|
"i8 requires -128 <= number <= 127")
|
||||||
|
elif bits == 16 and (i < -32768 or i > 32767):
|
||||||
|
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||||
|
"i16 requires -32768 <= number <= 32767")
|
||||||
|
elif bits == 32 and (i < -2147483648 or i > 2147483647):
|
||||||
|
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||||
|
"i32 requires -2147483648 <= number <= 2147483647")
|
||||||
|
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
|
||||||
|
raise TProtocolException(TProtocolException.INVALID_DATA,
|
||||||
|
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
|
||||||
|
|
||||||
|
|
||||||
|
class TProtocolFactory(object):
|
||||||
|
def getProtocol(self, trans):
|
||||||
|
pass
|
||||||
|
|
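
A small sketch (editor's addition) of the integer range guard that the typed writers above rely on; importing it from thrift.protocol.TProtocol is an assumption about the vendored layout.

    from thrift.protocol.TProtocol import TProtocolException, checkIntegerLimits

    checkIntegerLimits(127, 8)            # within i8 range, returns None
    try:
        checkIntegerLimits(300, 8)        # out of range for an i8
    except TProtocolException as exc:
        assert exc.type == TProtocolException.INVALID_DATA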
thrift/protocol/TProtocolDecorator.py (new file, 26 lines)
@@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


class TProtocolDecorator(object):
    def __new__(cls, protocol, *args, **kwargs):
        decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]),
                             (cls, protocol.__class__),
                             protocol.__dict__)
        return object.__new__(decorated_cls)
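
A hedged illustration (editor's addition) of what the __new__ above produces: the decorated instance gets a synthesized class that also inherits from the wrapped protocol's class, so any method the decorator does not override falls through to the original protocol.

    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol, TMultiplexedProtocol

    inner = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer())
    wrapped = TMultiplexedProtocol.TMultiplexedProtocol(inner, 'Calculator')
    print(type(wrapped).__name__)                                # DecoratedTBinaryProtocol
    print(isinstance(wrapped, TBinaryProtocol.TBinaryProtocol))  # True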
@@ -17,4 +17,5 @@
 # under the License.
 #
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
+__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
+           'TJSONProtocol', 'TProtocol', 'TProtocolDecorator']
BIN  thrift/protocol/fastbinary.cpython-38-darwin.so (Executable file, binary file not shown)
|
@@ -17,71 +17,115 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import http.server
|
import ssl
|
||||||
|
|
||||||
|
from six.moves import BaseHTTPServer
|
||||||
|
|
||||||
|
from thrift.Thrift import TMessageType
|
||||||
from thrift.server import TServer
|
from thrift.server import TServer
|
||||||
from thrift.transport import TTransport
|
from thrift.transport import TTransport
|
||||||
|
|
||||||
|
|
||||||
class ResponseException(Exception):
|
class ResponseException(Exception):
|
||||||
"""Allows handlers to override the HTTP response
|
"""Allows handlers to override the HTTP response
|
||||||
|
|
||||||
Normally, THttpServer always sends a 200 response. If a handler wants
|
Normally, THttpServer always sends a 200 response. If a handler wants
|
||||||
to override this behavior (e.g., to simulate a misconfigured or
|
to override this behavior (e.g., to simulate a misconfigured or
|
||||||
overloaded web server during testing), it can raise a ResponseException.
|
overloaded web server during testing), it can raise a ResponseException.
|
||||||
The function passed to the constructor will be called with the
|
The function passed to the constructor will be called with the
|
||||||
RequestHandler as its only argument.
|
RequestHandler as its only argument. Note that this is irrelevant
|
||||||
"""
|
for ONEWAY requests, as the HTTP response must be sent before the
|
||||||
def __init__(self, handler):
|
RPC is processed.
|
||||||
self.handler = handler
|
"""
|
||||||
|
def __init__(self, handler):
|
||||||
|
self.handler = handler
|
||||||
|
|
||||||
|
|
||||||
class THttpServer(TServer.TServer):
|
class THttpServer(TServer.TServer):
|
||||||
"""A simple HTTP-based Thrift server
|
"""A simple HTTP-based Thrift server
|
||||||
|
|
||||||
This class is not very performant, but it is useful (for example) for
|
This class is not very performant, but it is useful (for example) for
|
||||||
acting as a mock version of an Apache-based PHP Thrift endpoint.
|
acting as a mock version of an Apache-based PHP Thrift endpoint.
|
||||||
"""
|
Also important to note the HTTP implementation pretty much violates the
|
||||||
def __init__(self,
|
transport/protocol/processor/server layering, by performing the transport
|
||||||
processor,
|
functions here. This means things like oneway handling are oddly exposed.
|
||||||
server_address,
|
|
||||||
inputProtocolFactory,
|
|
||||||
outputProtocolFactory=None,
|
|
||||||
server_class=http.server.HTTPServer):
|
|
||||||
"""Set up protocol factories and HTTP server.
|
|
||||||
|
|
||||||
See BaseHTTPServer for server_address.
|
|
||||||
See TServer for protocol factories.
|
|
||||||
"""
|
"""
|
||||||
if outputProtocolFactory is None:
|
def __init__(self,
|
||||||
outputProtocolFactory = inputProtocolFactory
|
processor,
|
||||||
|
server_address,
|
||||||
|
inputProtocolFactory,
|
||||||
|
outputProtocolFactory=None,
|
||||||
|
server_class=BaseHTTPServer.HTTPServer,
|
||||||
|
**kwargs):
|
||||||
|
"""Set up protocol factories and HTTP (or HTTPS) server.
|
||||||
|
|
||||||
TServer.TServer.__init__(self, processor, None, None, None,
|
See BaseHTTPServer for server_address.
|
||||||
inputProtocolFactory, outputProtocolFactory)
|
See TServer for protocol factories.
|
||||||
|
|
||||||
thttpserver = self
|
To make a secure server, provide the named arguments:
|
||||||
|
* cafile - to validate clients [optional]
|
||||||
|
* cert_file - the server cert
|
||||||
|
* key_file - the server's key
|
||||||
|
"""
|
||||||
|
if outputProtocolFactory is None:
|
||||||
|
outputProtocolFactory = inputProtocolFactory
|
||||||
|
|
||||||
class RequestHander(http.server.BaseHTTPRequestHandler):
|
TServer.TServer.__init__(self, processor, None, None, None,
|
||||||
def do_POST(self):
|
inputProtocolFactory, outputProtocolFactory)
|
||||||
# Don't care about the request path.
|
|
||||||
itrans = TTransport.TFileObjectTransport(self.rfile)
|
|
||||||
otrans = TTransport.TFileObjectTransport(self.wfile)
|
|
||||||
itrans = TTransport.TBufferedTransport(
|
|
||||||
itrans, int(self.headers['Content-Length']))
|
|
||||||
otrans = TTransport.TMemoryBuffer()
|
|
||||||
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
|
|
||||||
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
|
|
||||||
try:
|
|
||||||
thttpserver.processor.process(iprot, oprot)
|
|
||||||
except ResponseException as exn:
|
|
||||||
exn.handler(self)
|
|
||||||
else:
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_header("content-type", "application/x-thrift")
|
|
||||||
self.end_headers()
|
|
||||||
self.wfile.write(otrans.getvalue())
|
|
||||||
|
|
||||||
self.httpd = server_class(server_address, RequestHander)
|
thttpserver = self
|
||||||
|
self._replied = None
|
||||||
|
|
||||||
def serve(self):
|
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
self.httpd.serve_forever()
|
def do_POST(self):
|
||||||
|
# Don't care about the request path.
|
||||||
|
thttpserver._replied = False
|
||||||
|
iftrans = TTransport.TFileObjectTransport(self.rfile)
|
||||||
|
itrans = TTransport.TBufferedTransport(
|
||||||
|
iftrans, int(self.headers['Content-Length']))
|
||||||
|
otrans = TTransport.TMemoryBuffer()
|
||||||
|
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
|
||||||
|
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
|
||||||
|
try:
|
||||||
|
thttpserver.processor.on_message_begin(self.on_begin)
|
||||||
|
thttpserver.processor.process(iprot, oprot)
|
||||||
|
except ResponseException as exn:
|
||||||
|
exn.handler(self)
|
||||||
|
else:
|
||||||
|
if not thttpserver._replied:
|
||||||
|
# If the request was ONEWAY we would have replied already
|
||||||
|
data = otrans.getvalue()
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-Length", len(data))
|
||||||
|
self.send_header("Content-Type", "application/x-thrift")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(data)
|
||||||
|
|
||||||
|
def on_begin(self, name, type, seqid):
|
||||||
|
"""
|
||||||
|
Inspect the message header.
|
||||||
|
|
||||||
|
This allows us to post an immediate transport response
|
||||||
|
if the request is a ONEWAY message type.
|
||||||
|
"""
|
||||||
|
if type == TMessageType.ONEWAY:
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-Type", "application/x-thrift")
|
||||||
|
self.end_headers()
|
||||||
|
thttpserver._replied = True
|
||||||
|
|
||||||
|
self.httpd = server_class(server_address, RequestHander)
|
||||||
|
|
||||||
|
if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
|
||||||
|
context = ssl.create_default_context(cafile=kwargs.get('cafile'))
|
||||||
|
context.check_hostname = False
|
||||||
|
context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
|
||||||
|
context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
|
||||||
|
self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
|
||||||
|
|
||||||
|
def serve(self):
|
||||||
|
self.httpd.serve_forever()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self.httpd.socket.close()
|
||||||
|
# self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!
|
||||||
|
|
|
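
A usage sketch (editor's addition) for the HTTP server shown above. The stub processor stands in for a Thrift-generated Processor, and the port and certificate paths are assumptions; the cert_file/key_file keywords mirror the secure-server options described in the docstring.

    from thrift.protocol import TJSONProtocol
    from thrift.server import THttpServer

    class StubProcessor(object):
        """Stand-in for a generated Service.Processor (illustrative only)."""
        def on_message_begin(self, func):
            pass
        def process(self, iprot, oprot):
            pass

    server = THttpServer.THttpServer(StubProcessor(),
                                     ('0.0.0.0', 9090),
                                     TJSONProtocol.TJSONProtocolFactory(),
                                     cert_file='server.crt',   # assumed paths; omit
                                     key_file='server.key')    # both for plain HTTP
    server.serve()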
@@ -24,18 +24,23 @@ only from the main thread.
|
||||||
The thread pool should be sized for concurrent tasks, not
|
The thread pool should be sized for concurrent tasks, not
|
||||||
maximum connections
|
maximum connections
|
||||||
"""
|
"""
|
||||||
import threading
|
|
||||||
import socket
|
|
||||||
import queue
|
|
||||||
import select
|
|
||||||
import struct
|
|
||||||
import logging
|
import logging
|
||||||
|
import select
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from six.moves import queue
|
||||||
|
|
||||||
from thrift.transport import TTransport
|
from thrift.transport import TTransport
|
||||||
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
|
||||||
|
|
||||||
__all__ = ['TNonblockingServer']
|
__all__ = ['TNonblockingServer']
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Worker(threading.Thread):
|
class Worker(threading.Thread):
|
||||||
"""Worker is a small helper to process incoming connection."""
|
"""Worker is a small helper to process incoming connection."""
|
||||||
|
@ -54,8 +59,9 @@ class Worker(threading.Thread):
|
||||||
processor.process(iprot, oprot)
|
processor.process(iprot, oprot)
|
||||||
callback(True, otrans.getvalue())
|
callback(True, otrans.getvalue())
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Exception while processing request")
|
logger.exception("Exception while processing request", exc_info=True)
|
||||||
callback(False, '')
|
callback(False, b'')
|
||||||
|
|
||||||
|
|
||||||
WAIT_LEN = 0
|
WAIT_LEN = 0
|
||||||
WAIT_MESSAGE = 1
|
WAIT_MESSAGE = 1
|
||||||
|
@ -81,11 +87,24 @@ def socket_exception(func):
|
||||||
try:
|
try:
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
except socket.error:
|
except socket.error:
|
||||||
|
logger.debug('ignoring socket exception', exc_info=True)
|
||||||
self.close()
|
self.close()
|
||||||
return read
|
return read
|
||||||
|
|
||||||
|
|
||||||
class Connection:
|
class Message(object):
|
||||||
|
def __init__(self, offset, len_, header):
|
||||||
|
self.offset = offset
|
||||||
|
self.len = len_
|
||||||
|
self.buffer = None
|
||||||
|
self.is_header = header
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end(self):
|
||||||
|
return self.offset + self.len
|
||||||
|
|
||||||
|
|
||||||
|
class Connection(object):
|
||||||
"""Basic class is represented connection.
|
"""Basic class is represented connection.
|
||||||
|
|
||||||
It can be in state:
|
It can be in state:
|
||||||
|
@ -102,68 +121,60 @@ class Connection:
|
||||||
self.socket.setblocking(False)
|
self.socket.setblocking(False)
|
||||||
self.status = WAIT_LEN
|
self.status = WAIT_LEN
|
||||||
self.len = 0
|
self.len = 0
|
||||||
self.message = ''
|
self.received = deque()
|
||||||
|
self._reading = Message(0, 4, True)
|
||||||
|
self._rbuf = b''
|
||||||
|
self._wbuf = b''
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.wake_up = wake_up
|
self.wake_up = wake_up
|
||||||
|
self.remaining = False
|
||||||
def _read_len(self):
|
|
||||||
"""Reads length of request.
|
|
||||||
|
|
||||||
It's a safer alternative to self.socket.recv(4)
|
|
||||||
"""
|
|
||||||
read = self.socket.recv(4 - len(self.message))
|
|
||||||
if len(read) == 0:
|
|
||||||
# if we read 0 bytes and self.message is empty, then
|
|
||||||
# the client closed the connection
|
|
||||||
if len(self.message) != 0:
|
|
||||||
logging.error("can't read frame size from socket")
|
|
||||||
self.close()
|
|
||||||
return
|
|
||||||
self.message += read
|
|
||||||
if len(self.message) == 4:
|
|
||||||
self.len, = struct.unpack('!i', self.message)
|
|
||||||
if self.len < 0:
|
|
||||||
logging.error("negative frame size, it seems client "
|
|
||||||
"doesn't use FramedTransport")
|
|
||||||
self.close()
|
|
||||||
elif self.len == 0:
|
|
||||||
logging.error("empty frame, it's really strange")
|
|
||||||
self.close()
|
|
||||||
else:
|
|
||||||
self.message = ''
|
|
||||||
self.status = WAIT_MESSAGE
|
|
||||||
|
|
||||||
@socket_exception
|
@socket_exception
|
||||||
def read(self):
|
def read(self):
|
||||||
"""Reads data from stream and switch state."""
|
"""Reads data from stream and switch state."""
|
||||||
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
|
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
|
||||||
if self.status == WAIT_LEN:
|
assert not self.received
|
||||||
self._read_len()
|
buf_size = 8192
|
||||||
# go back to the main loop here for simplicity instead of
|
first = True
|
||||||
# falling through, even though there is a good chance that
|
done = False
|
||||||
# the message is already available
|
while not done:
|
||||||
elif self.status == WAIT_MESSAGE:
|
read = self.socket.recv(buf_size)
|
||||||
read = self.socket.recv(self.len - len(self.message))
|
rlen = len(read)
|
||||||
if len(read) == 0:
|
done = rlen < buf_size
|
||||||
logging.error("can't read frame from socket (get %d of "
|
self._rbuf += read
|
||||||
"%d bytes)" % (len(self.message), self.len))
|
if first and rlen == 0:
|
||||||
|
if self.status != WAIT_LEN or self._rbuf:
|
||||||
|
logger.error('could not read frame from socket')
|
||||||
|
else:
|
||||||
|
logger.debug('read zero length. client might have disconnected')
|
||||||
self.close()
|
self.close()
|
||||||
return
|
while len(self._rbuf) >= self._reading.end:
|
||||||
self.message += read
|
if self._reading.is_header:
|
||||||
if len(self.message) == self.len:
|
mlen, = struct.unpack('!i', self._rbuf[:4])
|
||||||
|
self._reading = Message(self._reading.end, mlen, False)
|
||||||
|
self.status = WAIT_MESSAGE
|
||||||
|
else:
|
||||||
|
self._reading.buffer = self._rbuf
|
||||||
|
self.received.append(self._reading)
|
||||||
|
self._rbuf = self._rbuf[self._reading.end:]
|
||||||
|
self._reading = Message(0, 4, True)
|
||||||
|
first = False
|
||||||
|
if self.received:
|
||||||
self.status = WAIT_PROCESS
|
self.status = WAIT_PROCESS
|
||||||
|
break
|
||||||
|
self.remaining = not done
|
||||||
|
|
||||||
@socket_exception
|
@socket_exception
|
||||||
def write(self):
|
def write(self):
|
||||||
"""Writes data from socket and switch state."""
|
"""Writes data from socket and switch state."""
|
||||||
assert self.status == SEND_ANSWER
|
assert self.status == SEND_ANSWER
|
||||||
sent = self.socket.send(self.message)
|
sent = self.socket.send(self._wbuf)
|
||||||
if sent == len(self.message):
|
if sent == len(self._wbuf):
|
||||||
self.status = WAIT_LEN
|
self.status = WAIT_LEN
|
||||||
self.message = ''
|
self._wbuf = b''
|
||||||
self.len = 0
|
self.len = 0
|
||||||
else:
|
else:
|
||||||
self.message = self.message[sent:]
|
self._wbuf = self._wbuf[sent:]
|
||||||
|
|
||||||
@locked
|
@locked
|
||||||
def ready(self, all_ok, message):
|
def ready(self, all_ok, message):
|
||||||
|
@ -183,13 +194,13 @@ class Connection:
|
||||||
self.close()
|
self.close()
|
||||||
self.wake_up()
|
self.wake_up()
|
||||||
return
|
return
|
||||||
self.len = ''
|
self.len = 0
|
||||||
if len(message) == 0:
|
if len(message) == 0:
|
||||||
# it was a oneway request, do not write answer
|
# it was a oneway request, do not write answer
|
||||||
self.message = ''
|
self._wbuf = b''
|
||||||
self.status = WAIT_LEN
|
self.status = WAIT_LEN
|
||||||
else:
|
else:
|
||||||
self.message = struct.pack('!i', len(message)) + message
|
self._wbuf = struct.pack('!i', len(message)) + message
|
||||||
self.status = SEND_ANSWER
|
self.status = SEND_ANSWER
|
||||||
self.wake_up()
|
self.wake_up()
|
||||||
|
|
||||||
|
@ -219,7 +230,7 @@ class Connection:
|
||||||
self.socket.close()
|
self.socket.close()
|
||||||
|
|
||||||
|
|
||||||
class TNonblockingServer:
|
class TNonblockingServer(object):
|
||||||
"""Non-blocking server."""
|
"""Non-blocking server."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -259,7 +270,7 @@ class TNonblockingServer:
|
||||||
def wake_up(self):
|
def wake_up(self):
|
||||||
"""Wake up main thread.
|
"""Wake up main thread.
|
||||||
|
|
||||||
The server usualy waits in select call in we should terminate one.
|
The server usually waits in select call in we should terminate one.
|
||||||
The simplest way is using socketpair.
|
The simplest way is using socketpair.
|
||||||
|
|
||||||
Select always wait to read from the first socket of socketpair.
|
Select always wait to read from the first socket of socketpair.
|
||||||
|
@ -267,7 +278,7 @@ class TNonblockingServer:
|
||||||
In this case, we can just write anything to the second socket from
|
In this case, we can just write anything to the second socket from
|
||||||
socketpair.
|
socketpair.
|
||||||
"""
|
"""
|
||||||
self._write.send('1')
|
self._write.send(b'1')
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the server.
|
"""Stop the server.
|
||||||
|
@ -288,14 +299,20 @@ class TNonblockingServer:
|
||||||
"""Does select on open connections."""
|
"""Does select on open connections."""
|
||||||
readable = [self.socket.handle.fileno(), self._read.fileno()]
|
readable = [self.socket.handle.fileno(), self._read.fileno()]
|
||||||
writable = []
|
writable = []
|
||||||
|
remaining = []
|
||||||
for i, connection in list(self.clients.items()):
|
for i, connection in list(self.clients.items()):
|
||||||
if connection.is_readable():
|
if connection.is_readable():
|
||||||
readable.append(connection.fileno())
|
readable.append(connection.fileno())
|
||||||
|
if connection.remaining or connection.received:
|
||||||
|
remaining.append(connection.fileno())
|
||||||
if connection.is_writeable():
|
if connection.is_writeable():
|
||||||
writable.append(connection.fileno())
|
writable.append(connection.fileno())
|
||||||
if connection.is_closed():
|
if connection.is_closed():
|
||||||
del self.clients[i]
|
del self.clients[i]
|
||||||
return select.select(readable, writable, readable)
|
if remaining:
|
||||||
|
return remaining, [], [], False
|
||||||
|
else:
|
||||||
|
return select.select(readable, writable, readable) + (True,)
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
"""Handle requests.
|
"""Handle requests.
|
||||||
|
@ -303,20 +320,27 @@ class TNonblockingServer:
|
||||||
WARNING! You must call prepare() BEFORE calling handle()
|
WARNING! You must call prepare() BEFORE calling handle()
|
||||||
"""
|
"""
|
||||||
assert self.prepared, "You have to call prepare before handle"
|
assert self.prepared, "You have to call prepare before handle"
|
||||||
rset, wset, xset = self._select()
|
rset, wset, xset, selected = self._select()
|
||||||
for readable in rset:
|
for readable in rset:
|
||||||
if readable == self._read.fileno():
|
if readable == self._read.fileno():
|
||||||
# don't care i just need to clean readable flag
|
# don't care i just need to clean readable flag
|
||||||
self._read.recv(1024)
|
self._read.recv(1024)
|
||||||
elif readable == self.socket.handle.fileno():
|
elif readable == self.socket.handle.fileno():
|
||||||
client = self.socket.accept().handle
|
try:
|
||||||
self.clients[client.fileno()] = Connection(client,
|
client = self.socket.accept()
|
||||||
self.wake_up)
|
if client:
|
||||||
|
self.clients[client.handle.fileno()] = Connection(client.handle,
|
||||||
|
self.wake_up)
|
||||||
|
except socket.error:
|
||||||
|
logger.debug('error while accepting', exc_info=True)
|
||||||
else:
|
else:
|
||||||
connection = self.clients[readable]
|
connection = self.clients[readable]
|
||||||
connection.read()
|
if selected:
|
||||||
if connection.status == WAIT_PROCESS:
|
connection.read()
|
||||||
itransport = TTransport.TMemoryBuffer(connection.message)
|
if connection.received:
|
||||||
|
connection.status = WAIT_PROCESS
|
||||||
|
msg = connection.received.popleft()
|
||||||
|
itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
|
||||||
otransport = TTransport.TMemoryBuffer()
|
otransport = TTransport.TMemoryBuffer()
|
||||||
iprot = self.in_protocol.getProtocol(itransport)
|
iprot = self.in_protocol.getProtocol(itransport)
|
||||||
oprot = self.out_protocol.getProtocol(otransport)
|
oprot = self.out_protocol.getProtocol(otransport)
|
||||||
|
|
|
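
A minimal wiring sketch (editor's addition) for the non-blocking server above; the stub processor, thread count, and port are assumptions.

    from thrift.transport import TSocket
    from thrift.protocol import TBinaryProtocol
    from thrift.server.TNonblockingServer import TNonblockingServer

    class StubProcessor(object):
        def process(self, iprot, oprot):
            pass    # a generated Processor would dispatch to the handler here

    server = TNonblockingServer(StubProcessor(),
                                TSocket.TServerSocket(port=9090),
                                TBinaryProtocol.TBinaryProtocolFactory(),
                                threads=4)
    server.prepare()      # must run before handle(), as the docstring above warns
    while True:
        server.handle()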
@@ -19,11 +19,13 @@
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from multiprocessing import Process, Value, Condition, reduction
|
|
||||||
|
from multiprocessing import Process, Value, Condition
|
||||||
|
|
||||||
from .TServer import TServer
|
from .TServer import TServer
|
||||||
from thrift.transport.TTransport import TTransportException
|
from thrift.transport.TTransport import TTransportException
|
||||||
import collections.abc
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TProcessPoolServer(TServer):
|
class TProcessPoolServer(TServer):
|
||||||
|
@ -41,7 +43,7 @@ class TProcessPoolServer(TServer):
|
||||||
self.postForkCallback = None
|
self.postForkCallback = None
|
||||||
|
|
||||||
def setPostForkCallback(self, callback):
|
def setPostForkCallback(self, callback):
|
||||||
if not isinstance(callback, collections.abc.Callable):
|
if not callable(callback):
|
||||||
raise TypeError("This is not a callback!")
|
raise TypeError("This is not a callback!")
|
||||||
self.postForkCallback = callback
|
self.postForkCallback = callback
|
||||||
|
|
||||||
|
@ -57,11 +59,13 @@ class TProcessPoolServer(TServer):
|
||||||
while self.isRunning.value:
|
while self.isRunning.value:
|
||||||
try:
|
try:
|
||||||
client = self.serverTransport.accept()
|
client = self.serverTransport.accept()
|
||||||
|
if not client:
|
||||||
|
continue
|
||||||
self.serveClient(client)
|
self.serveClient(client)
|
||||||
except (KeyboardInterrupt, SystemExit):
|
except (KeyboardInterrupt, SystemExit):
|
||||||
return 0
|
return 0
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
logging.exception(x)
|
logger.exception(x)
|
||||||
|
|
||||||
def serveClient(self, client):
|
def serveClient(self, client):
|
||||||
"""Process input/output from a client for as long as possible"""
|
"""Process input/output from a client for as long as possible"""
|
||||||
|
@ -73,10 +77,10 @@ class TProcessPoolServer(TServer):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
self.processor.process(iprot, oprot)
|
self.processor.process(iprot, oprot)
|
||||||
except TTransportException as tx:
|
except TTransportException:
|
||||||
pass
|
pass
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
logging.exception(x)
|
logger.exception(x)
|
||||||
|
|
||||||
itrans.close()
|
itrans.close()
|
||||||
otrans.close()
|
otrans.close()
|
||||||
|
@ -97,7 +101,7 @@ class TProcessPoolServer(TServer):
|
||||||
w.start()
|
w.start()
|
||||||
self.workers.append(w)
|
self.workers.append(w)
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
logging.exception(x)
|
logger.exception(x)
|
||||||
|
|
||||||
# wait until the condition is set by stop()
|
# wait until the condition is set by stop()
|
||||||
while True:
|
while True:
|
||||||
|
@ -108,7 +112,7 @@ class TProcessPoolServer(TServer):
|
||||||
except (SystemExit, KeyboardInterrupt):
|
except (SystemExit, KeyboardInterrupt):
|
||||||
break
|
break
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
logging.exception(x)
|
logger.exception(x)
|
||||||
|
|
||||||
self.isRunning.value = False
|
self.isRunning.value = False
|
||||||
|
|
||||||
|
|
|
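
A short sketch (editor's addition) of the process-pool server with a post-fork callback, matching setPostForkCallback() shown above; the worker count, port, and stub processor are assumptions.

    from thrift.transport import TSocket, TTransport
    from thrift.protocol import TBinaryProtocol
    from thrift.server.TProcessPoolServer import TProcessPoolServer

    class StubProcessor(object):
        def process(self, iprot, oprot):
            pass

    server = TProcessPoolServer(StubProcessor(),
                                TSocket.TServerSocket(port=9090),
                                TTransport.TBufferedTransportFactory(),
                                TBinaryProtocol.TBinaryProtocolFactory())
    server.setNumWorkers(4)
    server.setPostForkCallback(lambda: None)   # must be callable, else TypeError
    server.serve()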
@@ -17,253 +17,307 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import queue
|
from six.moves import queue
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
|
||||||
|
|
||||||
from thrift.Thrift import TProcessor
|
|
||||||
from thrift.protocol import TBinaryProtocol
|
from thrift.protocol import TBinaryProtocol
|
||||||
|
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
|
||||||
from thrift.transport import TTransport
|
from thrift.transport import TTransport
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class TServer:
|
|
||||||
"""Base interface for a server, which must have a serve() method.
|
|
||||||
|
|
||||||
Three constructors for all servers:
|
class TServer(object):
|
||||||
1) (processor, serverTransport)
|
"""Base interface for a server, which must have a serve() method.
|
||||||
2) (processor, serverTransport, transportFactory, protocolFactory)
|
|
||||||
3) (processor, serverTransport,
|
|
||||||
inputTransportFactory, outputTransportFactory,
|
|
||||||
inputProtocolFactory, outputProtocolFactory)
|
|
||||||
"""
|
|
||||||
def __init__(self, *args):
|
|
||||||
if (len(args) == 2):
|
|
||||||
self.__initArgs__(args[0], args[1],
|
|
||||||
TTransport.TTransportFactoryBase(),
|
|
||||||
TTransport.TTransportFactoryBase(),
|
|
||||||
TBinaryProtocol.TBinaryProtocolFactory(),
|
|
||||||
TBinaryProtocol.TBinaryProtocolFactory())
|
|
||||||
elif (len(args) == 4):
|
|
||||||
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
|
|
||||||
elif (len(args) == 6):
|
|
||||||
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
|
|
||||||
|
|
||||||
def __initArgs__(self, processor, serverTransport,
|
Three constructors for all servers:
|
||||||
inputTransportFactory, outputTransportFactory,
|
1) (processor, serverTransport)
|
||||||
inputProtocolFactory, outputProtocolFactory):
|
2) (processor, serverTransport, transportFactory, protocolFactory)
|
||||||
self.processor = processor
|
3) (processor, serverTransport,
|
||||||
self.serverTransport = serverTransport
|
inputTransportFactory, outputTransportFactory,
|
||||||
self.inputTransportFactory = inputTransportFactory
|
inputProtocolFactory, outputProtocolFactory)
|
||||||
self.outputTransportFactory = outputTransportFactory
|
"""
|
||||||
self.inputProtocolFactory = inputProtocolFactory
|
def __init__(self, *args):
|
||||||
self.outputProtocolFactory = outputProtocolFactory
|
if (len(args) == 2):
|
||||||
|
self.__initArgs__(args[0], args[1],
|
||||||
|
TTransport.TTransportFactoryBase(),
|
||||||
|
TTransport.TTransportFactoryBase(),
|
||||||
|
TBinaryProtocol.TBinaryProtocolFactory(),
|
||||||
|
TBinaryProtocol.TBinaryProtocolFactory())
|
||||||
|
elif (len(args) == 4):
|
||||||
|
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
|
||||||
|
elif (len(args) == 6):
|
||||||
|
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
|
||||||
|
|
||||||
def serve(self):
|
def __initArgs__(self, processor, serverTransport,
|
||||||
pass
|
inputTransportFactory, outputTransportFactory,
|
||||||
|
inputProtocolFactory, outputProtocolFactory):
|
||||||
|
self.processor = processor
|
||||||
|
self.serverTransport = serverTransport
|
||||||
|
self.inputTransportFactory = inputTransportFactory
|
||||||
|
self.outputTransportFactory = outputTransportFactory
|
||||||
|
self.inputProtocolFactory = inputProtocolFactory
|
||||||
|
self.outputProtocolFactory = outputProtocolFactory
|
||||||
|
|
||||||
|
input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
|
||||||
|
output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
|
||||||
|
if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
|
||||||
|
raise ValueError("THeaderProtocol servers require that both the input and "
|
||||||
|
"output protocols are THeaderProtocol.")
|
||||||
|
|
||||||
|
def serve(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
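A hedged illustration of the three constructor signatures the TServer docstring enumerates; the processor and port are placeholders, not values from this change:

from thrift.protocol import TBinaryProtocol
from thrift.server.TServer import TSimpleServer
from thrift.transport import TSocket, TTransport

processor = None  # stand-in for a generated Processor instance
transport = TSocket.TServerSocket(port=9090)

# 1) (processor, serverTransport): base transport factory + binary protocol defaults
s1 = TSimpleServer(processor, transport)
# 2) one factory pair reused for both input and output
s2 = TSimpleServer(processor, transport,
                   TTransport.TBufferedTransportFactory(),
                   TBinaryProtocol.TBinaryProtocolFactory())
# 3) separate input/output transport and protocol factories
s3 = TSimpleServer(processor, transport,
                   TTransport.TBufferedTransportFactory(),
                   TTransport.TBufferedTransportFactory(),
                   TBinaryProtocol.TBinaryProtocolFactory(),
                   TBinaryProtocol.TBinaryProtocolFactory())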
class TSimpleServer(TServer):
|
class TSimpleServer(TServer):
|
||||||
"""Simple single-threaded server that just pumps around one transport."""
|
"""Simple single-threaded server that just pumps around one transport."""
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
TServer.__init__(self, *args)
|
TServer.__init__(self, *args)
|
||||||
|
|
||||||
def serve(self):
|
def serve(self):
|
||||||
self.serverTransport.listen()
|
self.serverTransport.listen()
|
||||||
while True:
|
|
||||||
client = self.serverTransport.accept()
|
|
||||||
itrans = self.inputTransportFactory.getTransport(client)
|
|
||||||
otrans = self.outputTransportFactory.getTransport(client)
|
|
||||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
|
||||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
|
||||||
try:
|
|
||||||
while True:
|
while True:
|
||||||
self.processor.process(iprot, oprot)
|
client = self.serverTransport.accept()
|
||||||
except TTransport.TTransportException as tx:
|
if not client:
|
||||||
pass
|
continue
|
||||||
except Exception as x:
|
|
||||||
logging.exception(x)
|
|
||||||
|
|
||||||
itrans.close()
|
itrans = self.inputTransportFactory.getTransport(client)
|
||||||
otrans.close()
|
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||||
|
|
||||||
|
# for THeaderProtocol, we must use the same protocol instance for
|
||||||
|
# input and output so that the response is in the same dialect that
|
||||||
|
# the server detected the request was in.
|
||||||
|
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
|
||||||
|
otrans = None
|
||||||
|
oprot = iprot
|
||||||
|
else:
|
||||||
|
otrans = self.outputTransportFactory.getTransport(client)
|
||||||
|
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.processor.process(iprot, oprot)
|
||||||
|
except TTransport.TTransportException:
|
||||||
|
pass
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
|
itrans.close()
|
||||||
|
if otrans:
|
||||||
|
otrans.close()
|
||||||
|
|
||||||
|
|
||||||
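The new comments in TSimpleServer.serve() explain that a THeaderProtocol reply must reuse the protocol instance that decoded the request. A minimal sketch of serving with the header protocol, assuming a placeholder processor; because the four-argument constructor reuses one factory for both directions, the THeaderProtocol consistency check in __initArgs__ passes:

from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.server.TServer import TSimpleServer
from thrift.transport import TSocket, TTransport

processor = None  # stand-in for generated service code
server = TSimpleServer(processor,
                       TSocket.TServerSocket(port=9090),
                       TTransport.TBufferedTransportFactory(),
                       THeaderProtocolFactory())  # used for both input and output
# server.serve()  # each accepted client then gets oprot = iprot internally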
class TThreadedServer(TServer):
|
class TThreadedServer(TServer):
|
||||||
"""Threaded server that spawns a new thread per each connection."""
|
"""Threaded server that spawns a new thread per each connection."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
TServer.__init__(self, *args)
|
TServer.__init__(self, *args)
|
||||||
self.daemon = kwargs.get("daemon", False)
|
self.daemon = kwargs.get("daemon", False)
|
||||||
|
|
||||||
def serve(self):
|
def serve(self):
|
||||||
self.serverTransport.listen()
|
self.serverTransport.listen()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
client = self.serverTransport.accept()
|
client = self.serverTransport.accept()
|
||||||
t = threading.Thread(target=self.handle, args=(client,))
|
if not client:
|
||||||
t.setDaemon(self.daemon)
|
continue
|
||||||
t.start()
|
t = threading.Thread(target=self.handle, args=(client,))
|
||||||
except KeyboardInterrupt:
|
t.setDaemon(self.daemon)
|
||||||
raise
|
t.start()
|
||||||
except Exception as x:
|
except KeyboardInterrupt:
|
||||||
logging.exception(x)
|
raise
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
def handle(self, client):
|
def handle(self, client):
|
||||||
itrans = self.inputTransportFactory.getTransport(client)
|
itrans = self.inputTransportFactory.getTransport(client)
|
||||||
otrans = self.outputTransportFactory.getTransport(client)
|
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
|
||||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
self.processor.process(iprot, oprot)
|
|
||||||
except TTransport.TTransportException as tx:
|
|
||||||
pass
|
|
||||||
except Exception as x:
|
|
||||||
logging.exception(x)
|
|
||||||
|
|
||||||
itrans.close()
|
# for THeaderProtocol, we must use the same protocol instance for input
|
||||||
otrans.close()
|
# and output so that the response is in the same dialect that the
|
||||||
|
# server detected the request was in.
|
||||||
|
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
|
||||||
|
otrans = None
|
||||||
|
oprot = iprot
|
||||||
|
else:
|
||||||
|
otrans = self.outputTransportFactory.getTransport(client)
|
||||||
|
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.processor.process(iprot, oprot)
|
||||||
|
except TTransport.TTransportException:
|
||||||
|
pass
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
|
itrans.close()
|
||||||
|
if otrans:
|
||||||
|
otrans.close()
|
||||||
|
|
||||||
|
|
||||||
class TThreadPoolServer(TServer):
|
class TThreadPoolServer(TServer):
|
||||||
"""Server with a fixed size pool of threads which service requests."""
|
"""Server with a fixed size pool of threads which service requests."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
TServer.__init__(self, *args)
|
TServer.__init__(self, *args)
|
||||||
self.clients = queue.Queue()
|
self.clients = queue.Queue()
|
||||||
self.threads = 10
|
self.threads = 10
|
||||||
self.daemon = kwargs.get("daemon", False)
|
self.daemon = kwargs.get("daemon", False)
|
||||||
|
|
||||||
def setNumThreads(self, num):
|
def setNumThreads(self, num):
|
||||||
"""Set the number of worker threads that should be created"""
|
"""Set the number of worker threads that should be created"""
|
||||||
self.threads = num
|
self.threads = num
|
||||||
|
|
||||||
def serveThread(self):
|
def serveThread(self):
|
||||||
"""Loop around getting clients from the shared queue and process them."""
|
"""Loop around getting clients from the shared queue and process them."""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
client = self.clients.get()
|
client = self.clients.get()
|
||||||
self.serveClient(client)
|
self.serveClient(client)
|
||||||
except Exception as x:
|
except Exception as x:
|
||||||
logging.exception(x)
|
logger.exception(x)
|
||||||
|
|
||||||
def serveClient(self, client):
|
def serveClient(self, client):
|
||||||
"""Process input/output from a client for as long as possible"""
|
"""Process input/output from a client for as long as possible"""
|
||||||
itrans = self.inputTransportFactory.getTransport(client)
|
itrans = self.inputTransportFactory.getTransport(client)
|
||||||
otrans = self.outputTransportFactory.getTransport(client)
|
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
|
||||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
self.processor.process(iprot, oprot)
|
|
||||||
except TTransport.TTransportException as tx:
|
|
||||||
pass
|
|
||||||
except Exception as x:
|
|
||||||
logging.exception(x)
|
|
||||||
|
|
||||||
itrans.close()
|
# for THeaderProtocol, we must use the same protocol instance for input
|
||||||
otrans.close()
|
# and output so that the response is in the same dialect that the
|
||||||
|
# server detected the request was in.
|
||||||
|
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
|
||||||
|
otrans = None
|
||||||
|
oprot = iprot
|
||||||
|
else:
|
||||||
|
otrans = self.outputTransportFactory.getTransport(client)
|
||||||
|
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||||
|
|
||||||
def serve(self):
|
try:
|
||||||
"""Start a fixed number of worker threads and put client into a queue"""
|
while True:
|
||||||
for i in range(self.threads):
|
self.processor.process(iprot, oprot)
|
||||||
try:
|
except TTransport.TTransportException:
|
||||||
t = threading.Thread(target=self.serveThread)
|
pass
|
||||||
t.setDaemon(self.daemon)
|
except Exception as x:
|
||||||
t.start()
|
logger.exception(x)
|
||||||
except Exception as x:
|
|
||||||
logging.exception(x)
|
|
||||||
|
|
||||||
# Pump the socket for clients
|
itrans.close()
|
||||||
self.serverTransport.listen()
|
if otrans:
|
||||||
while True:
|
otrans.close()
|
||||||
try:
|
|
||||||
client = self.serverTransport.accept()
|
def serve(self):
|
||||||
self.clients.put(client)
|
"""Start a fixed number of worker threads and put client into a queue"""
|
||||||
except Exception as x:
|
for i in range(self.threads):
|
||||||
logging.exception(x)
|
try:
|
||||||
|
t = threading.Thread(target=self.serveThread)
|
||||||
|
t.setDaemon(self.daemon)
|
||||||
|
t.start()
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
|
# Pump the socket for clients
|
||||||
|
self.serverTransport.listen()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
client = self.serverTransport.accept()
|
||||||
|
if not client:
|
||||||
|
continue
|
||||||
|
self.clients.put(client)
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
|
|
||||||
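For TThreadPoolServer, the pool size and the daemon flag are the only knobs shown above; a hedged configuration sketch with placeholder names:

from thrift.protocol import TBinaryProtocol
from thrift.server.TServer import TThreadPoolServer
from thrift.transport import TSocket, TTransport

processor = None  # stand-in for a generated Processor
server = TThreadPoolServer(processor,
                           TSocket.TServerSocket(port=9090),
                           TTransport.TBufferedTransportFactory(),
                           TBinaryProtocol.TBinaryProtocolFactory(),
                           daemon=True)   # worker threads will not block interpreter exit
server.setNumThreads(20)                  # default is 10, per __init__ above
# server.serve()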
class TForkingServer(TServer):
|
class TForkingServer(TServer):
|
||||||
"""A Thrift server that forks a new process for each request
|
"""A Thrift server that forks a new process for each request
|
||||||
|
|
||||||
This is more scalable than the threaded server as it does not cause
|
This is more scalable than the threaded server as it does not cause
|
||||||
GIL contention.
|
GIL contention.
|
||||||
|
|
||||||
Note that this has different semantics from the threading server.
|
Note that this has different semantics from the threading server.
|
||||||
Specifically, updates to shared variables will no longer be shared.
|
Specifically, updates to shared variables will no longer be shared.
|
||||||
It will also not work on windows.
|
It will also not work on windows.
|
||||||
|
|
||||||
This code is heavily inspired by SocketServer.ForkingMixIn in the
|
This code is heavily inspired by SocketServer.ForkingMixIn in the
|
||||||
Python stdlib.
|
Python stdlib.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
TServer.__init__(self, *args)
|
TServer.__init__(self, *args)
|
||||||
self.children = []
|
self.children = []
|
||||||
|
|
||||||
def serve(self):
|
def serve(self):
|
||||||
def try_close(file):
|
def try_close(file):
|
||||||
try:
|
|
||||||
file.close()
|
|
||||||
except IOError as e:
|
|
||||||
logging.warning(e, exc_info=True)
|
|
||||||
|
|
||||||
self.serverTransport.listen()
|
|
||||||
while True:
|
|
||||||
client = self.serverTransport.accept()
|
|
||||||
try:
|
|
||||||
pid = os.fork()
|
|
||||||
|
|
||||||
if pid: # parent
|
|
||||||
# add before collect, otherwise you race w/ waitpid
|
|
||||||
self.children.append(pid)
|
|
||||||
self.collect_children()
|
|
||||||
|
|
||||||
# Parent must close socket or the connection may not get
|
|
||||||
# closed promptly
|
|
||||||
itrans = self.inputTransportFactory.getTransport(client)
|
|
||||||
otrans = self.outputTransportFactory.getTransport(client)
|
|
||||||
try_close(itrans)
|
|
||||||
try_close(otrans)
|
|
||||||
else:
|
|
||||||
itrans = self.inputTransportFactory.getTransport(client)
|
|
||||||
otrans = self.outputTransportFactory.getTransport(client)
|
|
||||||
|
|
||||||
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
|
||||||
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
|
||||||
|
|
||||||
ecode = 0
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
while True:
|
file.close()
|
||||||
self.processor.process(iprot, oprot)
|
except IOError as e:
|
||||||
except TTransport.TTransportException as tx:
|
logger.warning(e, exc_info=True)
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
ecode = 1
|
|
||||||
finally:
|
|
||||||
try_close(itrans)
|
|
||||||
try_close(otrans)
|
|
||||||
|
|
||||||
os._exit(ecode)
|
self.serverTransport.listen()
|
||||||
|
while True:
|
||||||
|
client = self.serverTransport.accept()
|
||||||
|
if not client:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
pid = os.fork()
|
||||||
|
|
||||||
except TTransport.TTransportException as tx:
|
if pid: # parent
|
||||||
pass
|
# add before collect, otherwise you race w/ waitpid
|
||||||
except Exception as x:
|
self.children.append(pid)
|
||||||
logging.exception(x)
|
self.collect_children()
|
||||||
|
|
||||||
def collect_children(self):
|
# Parent must close socket or the connection may not get
|
||||||
while self.children:
|
# closed promptly
|
||||||
try:
|
itrans = self.inputTransportFactory.getTransport(client)
|
||||||
pid, status = os.waitpid(0, os.WNOHANG)
|
otrans = self.outputTransportFactory.getTransport(client)
|
||||||
except os.error:
|
try_close(itrans)
|
||||||
pid = None
|
try_close(otrans)
|
||||||
|
else:
|
||||||
|
itrans = self.inputTransportFactory.getTransport(client)
|
||||||
|
iprot = self.inputProtocolFactory.getProtocol(itrans)
|
||||||
|
|
||||||
if pid:
|
# for THeaderProtocol, we must use the same protocol
|
||||||
self.children.remove(pid)
|
# instance for input and output so that the response is in
|
||||||
else:
|
# the same dialect that the server detected the request was
|
||||||
break
|
# in.
|
||||||
|
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
|
||||||
|
otrans = None
|
||||||
|
oprot = iprot
|
||||||
|
else:
|
||||||
|
otrans = self.outputTransportFactory.getTransport(client)
|
||||||
|
oprot = self.outputProtocolFactory.getProtocol(otrans)
|
||||||
|
|
||||||
|
ecode = 0
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.processor.process(iprot, oprot)
|
||||||
|
except TTransport.TTransportException:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
ecode = 1
|
||||||
|
finally:
|
||||||
|
try_close(itrans)
|
||||||
|
if otrans:
|
||||||
|
try_close(otrans)
|
||||||
|
|
||||||
|
os._exit(ecode)
|
||||||
|
|
||||||
|
except TTransport.TTransportException:
|
||||||
|
pass
|
||||||
|
except Exception as x:
|
||||||
|
logger.exception(x)
|
||||||
|
|
||||||
|
def collect_children(self):
|
||||||
|
while self.children:
|
||||||
|
try:
|
||||||
|
pid, status = os.waitpid(0, os.WNOHANG)
|
||||||
|
except os.error:
|
||||||
|
pid = None
|
||||||
|
|
||||||
|
if pid:
|
||||||
|
self.children.remove(pid)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
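TForkingServer's docstring notes that each request runs in its own process, so the parent only bookkeeps child PIDs and reaps them with a non-blocking waitpid(). A stripped-down, Thrift-independent sketch of that fork-and-reap pattern:

import os

def handle_forked(handle_request, client, children):
    # Fork-per-request skeleton mirroring TForkingServer.serve() (illustrative only).
    pid = os.fork()
    if pid:                       # parent: remember the child, reap any finished ones
        children.append(pid)
        while children:
            done, _status = os.waitpid(0, os.WNOHANG)
            if not done:
                break             # nothing has exited yet; check again on the next accept()
            children.remove(done)
    else:                         # child: do the work, then exit without running cleanup
        code = 0
        try:
            handle_request(client)
        except Exception:
            code = 1
        os._exit(code)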
352 thrift/transport/THeaderTransport.py Normal file
|
@@ -0,0 +1,352 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import struct
|
||||||
|
import zlib
|
||||||
|
|
||||||
|
from thrift.compat import BufferIO, byte_index
|
||||||
|
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
|
||||||
|
from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
|
||||||
|
from thrift.Thrift import TApplicationException
|
||||||
|
from thrift.transport.TTransport import (
|
||||||
|
CReadableTransport,
|
||||||
|
TMemoryBuffer,
|
||||||
|
TTransportBase,
|
||||||
|
TTransportException,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
U16 = struct.Struct("!H")
|
||||||
|
I32 = struct.Struct("!i")
|
||||||
|
HEADER_MAGIC = 0x0FFF
|
||||||
|
HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
|
||||||
|
|
||||||
|
|
||||||
|
class THeaderClientType(object):
|
||||||
|
HEADERS = 0x00
|
||||||
|
|
||||||
|
FRAMED_BINARY = 0x01
|
||||||
|
UNFRAMED_BINARY = 0x02
|
||||||
|
|
||||||
|
FRAMED_COMPACT = 0x03
|
||||||
|
UNFRAMED_COMPACT = 0x04
|
||||||
|
|
||||||
|
|
||||||
|
class THeaderSubprotocolID(object):
|
||||||
|
BINARY = 0x00
|
||||||
|
COMPACT = 0x02
|
||||||
|
|
||||||
|
|
||||||
|
class TInfoHeaderType(object):
|
||||||
|
KEY_VALUE = 0x01
|
||||||
|
|
||||||
|
|
||||||
|
class THeaderTransformID(object):
|
||||||
|
ZLIB = 0x01
|
||||||
|
|
||||||
|
|
||||||
|
READ_TRANSFORMS_BY_ID = {
|
||||||
|
THeaderTransformID.ZLIB: zlib.decompress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
WRITE_TRANSFORMS_BY_ID = {
|
||||||
|
THeaderTransformID.ZLIB: zlib.compress,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _readString(trans):
|
||||||
|
size = readVarint(trans)
|
||||||
|
if size < 0:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.NEGATIVE_SIZE,
|
||||||
|
"Negative length"
|
||||||
|
)
|
||||||
|
return trans.read(size)
|
||||||
|
|
||||||
|
|
||||||
|
def _writeString(trans, value):
|
||||||
|
writeVarint(trans, len(value))
|
||||||
|
trans.write(value)
|
||||||
|
|
||||||
|
|
||||||
|
class THeaderTransport(TTransportBase, CReadableTransport):
|
||||||
|
def __init__(self, transport, allowed_client_types, default_protocol=THeaderSubprotocolID.BINARY):
|
||||||
|
self._transport = transport
|
||||||
|
self._client_type = THeaderClientType.HEADERS
|
||||||
|
self._allowed_client_types = allowed_client_types
|
||||||
|
|
||||||
|
self._read_buffer = BufferIO(b"")
|
||||||
|
self._read_headers = {}
|
||||||
|
|
||||||
|
self._write_buffer = BufferIO()
|
||||||
|
self._write_headers = {}
|
||||||
|
self._write_transforms = []
|
||||||
|
|
||||||
|
self.flags = 0
|
||||||
|
self.sequence_id = 0
|
||||||
|
self._protocol_id = default_protocol
|
||||||
|
self._max_frame_size = HARD_MAX_FRAME_SIZE
|
||||||
|
|
||||||
|
def isOpen(self):
|
||||||
|
return self._transport.isOpen()
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
return self._transport.open()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
return self._transport.close()
|
||||||
|
|
||||||
|
def get_headers(self):
|
||||||
|
return self._read_headers
|
||||||
|
|
||||||
|
def set_header(self, key, value):
|
||||||
|
if not isinstance(key, bytes):
|
||||||
|
raise ValueError("header names must be bytes")
|
||||||
|
if not isinstance(value, bytes):
|
||||||
|
raise ValueError("header values must be bytes")
|
||||||
|
self._write_headers[key] = value
|
||||||
|
|
||||||
|
def clear_headers(self):
|
||||||
|
self._write_headers.clear()
|
||||||
|
|
||||||
|
def add_transform(self, transform_id):
|
||||||
|
if transform_id not in WRITE_TRANSFORMS_BY_ID:
|
||||||
|
raise ValueError("unknown transform")
|
||||||
|
self._write_transforms.append(transform_id)
|
||||||
|
|
||||||
|
def set_max_frame_size(self, size):
|
||||||
|
if not 0 < size < HARD_MAX_FRAME_SIZE:
|
||||||
|
raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
|
||||||
|
self._max_frame_size = size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol_id(self):
|
||||||
|
if self._client_type == THeaderClientType.HEADERS:
|
||||||
|
return self._protocol_id
|
||||||
|
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
|
||||||
|
return THeaderSubprotocolID.BINARY
|
||||||
|
elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
|
||||||
|
return THeaderSubprotocolID.COMPACT
|
||||||
|
else:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.INVALID_CLIENT_TYPE,
|
||||||
|
"Protocol ID not know for client type %d" % self._client_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def read(self, sz):
|
||||||
|
# if there are bytes left in the buffer, produce those first.
|
||||||
|
bytes_read = self._read_buffer.read(sz)
|
||||||
|
bytes_left_to_read = sz - len(bytes_read)
|
||||||
|
if bytes_left_to_read == 0:
|
||||||
|
return bytes_read
|
||||||
|
|
||||||
|
# if we've determined this is an unframed client, just pass the read
|
||||||
|
# through to the underlying transport until we're reset again at the
|
||||||
|
# beginning of the next message.
|
||||||
|
if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
|
||||||
|
return bytes_read + self._transport.read(bytes_left_to_read)
|
||||||
|
|
||||||
|
# we're empty and (maybe) framed. fill the buffers with the next frame.
|
||||||
|
self.readFrame(bytes_left_to_read)
|
||||||
|
return bytes_read + self._read_buffer.read(bytes_left_to_read)
|
||||||
|
|
||||||
|
def _set_client_type(self, client_type):
|
||||||
|
if client_type not in self._allowed_client_types:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.INVALID_CLIENT_TYPE,
|
||||||
|
"Client type %d not allowed by server." % client_type,
|
||||||
|
)
|
||||||
|
self._client_type = client_type
|
||||||
|
|
||||||
|
def readFrame(self, req_sz):
|
||||||
|
# the first word could either be the length field of a framed message
|
||||||
|
# or the first bytes of an unframed message.
|
||||||
|
first_word = self._transport.readAll(I32.size)
|
||||||
|
frame_size, = I32.unpack(first_word)
|
||||||
|
is_unframed = False
|
||||||
|
if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
|
||||||
|
self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
|
||||||
|
is_unframed = True
|
||||||
|
elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
|
||||||
|
byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
|
||||||
|
self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
|
||||||
|
is_unframed = True
|
||||||
|
|
||||||
|
if is_unframed:
|
||||||
|
bytes_left_to_read = req_sz - I32.size
|
||||||
|
if bytes_left_to_read > 0:
|
||||||
|
rest = self._transport.read(bytes_left_to_read)
|
||||||
|
else:
|
||||||
|
rest = b""
|
||||||
|
self._read_buffer = BufferIO(first_word + rest)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ok, we're still here so we're framed.
|
||||||
|
if frame_size > self._max_frame_size:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.SIZE_LIMIT,
|
||||||
|
"Frame was too large.",
|
||||||
|
)
|
||||||
|
read_buffer = BufferIO(self._transport.readAll(frame_size))
|
||||||
|
|
||||||
|
# the next word is either going to be the version field of a
|
||||||
|
# binary/compact protocol message or the magic value + flags of a
|
||||||
|
# header protocol message.
|
||||||
|
second_word = read_buffer.read(I32.size)
|
||||||
|
version, = I32.unpack(second_word)
|
||||||
|
read_buffer.seek(0)
|
||||||
|
if version >> 16 == HEADER_MAGIC:
|
||||||
|
self._set_client_type(THeaderClientType.HEADERS)
|
||||||
|
self._read_buffer = self._parse_header_format(read_buffer)
|
||||||
|
elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
|
||||||
|
self._set_client_type(THeaderClientType.FRAMED_BINARY)
|
||||||
|
self._read_buffer = read_buffer
|
||||||
|
elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
|
||||||
|
byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
|
||||||
|
self._set_client_type(THeaderClientType.FRAMED_COMPACT)
|
||||||
|
self._read_buffer = read_buffer
|
||||||
|
else:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.INVALID_CLIENT_TYPE,
|
||||||
|
"Could not detect client transport type.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_header_format(self, buffer):
|
||||||
|
# make BufferIO look like TTransport for varint helpers
|
||||||
|
buffer_transport = TMemoryBuffer()
|
||||||
|
buffer_transport._buffer = buffer
|
||||||
|
|
||||||
|
buffer.read(2) # discard the magic bytes
|
||||||
|
self.flags, = U16.unpack(buffer.read(U16.size))
|
||||||
|
self.sequence_id, = I32.unpack(buffer.read(I32.size))
|
||||||
|
|
||||||
|
header_length = U16.unpack(buffer.read(U16.size))[0] * 4
|
||||||
|
end_of_headers = buffer.tell() + header_length
|
||||||
|
if end_of_headers > len(buffer.getvalue()):
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.SIZE_LIMIT,
|
||||||
|
"Header size is larger than whole frame.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._protocol_id = readVarint(buffer_transport)
|
||||||
|
|
||||||
|
transforms = []
|
||||||
|
transform_count = readVarint(buffer_transport)
|
||||||
|
for _ in range(transform_count):
|
||||||
|
transform_id = readVarint(buffer_transport)
|
||||||
|
if transform_id not in READ_TRANSFORMS_BY_ID:
|
||||||
|
raise TApplicationException(
|
||||||
|
TApplicationException.INVALID_TRANSFORM,
|
||||||
|
"Unknown transform: %d" % transform_id,
|
||||||
|
)
|
||||||
|
transforms.append(transform_id)
|
||||||
|
transforms.reverse()
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
while buffer.tell() < end_of_headers:
|
||||||
|
header_type = readVarint(buffer_transport)
|
||||||
|
if header_type == TInfoHeaderType.KEY_VALUE:
|
||||||
|
count = readVarint(buffer_transport)
|
||||||
|
for _ in range(count):
|
||||||
|
key = _readString(buffer_transport)
|
||||||
|
value = _readString(buffer_transport)
|
||||||
|
headers[key] = value
|
||||||
|
else:
|
||||||
|
break # ignore unknown headers
|
||||||
|
self._read_headers = headers
|
||||||
|
|
||||||
|
# skip padding / anything we didn't understand
|
||||||
|
buffer.seek(end_of_headers)
|
||||||
|
|
||||||
|
payload = buffer.read()
|
||||||
|
for transform_id in transforms:
|
||||||
|
transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
|
||||||
|
payload = transform_fn(payload)
|
||||||
|
return BufferIO(payload)
|
||||||
|
|
||||||
|
def write(self, buf):
|
||||||
|
self._write_buffer.write(buf)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
payload = self._write_buffer.getvalue()
|
||||||
|
self._write_buffer = BufferIO()
|
||||||
|
|
||||||
|
buffer = BufferIO()
|
||||||
|
if self._client_type == THeaderClientType.HEADERS:
|
||||||
|
for transform_id in self._write_transforms:
|
||||||
|
transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
|
||||||
|
payload = transform_fn(payload)
|
||||||
|
|
||||||
|
headers = BufferIO()
|
||||||
|
writeVarint(headers, self._protocol_id)
|
||||||
|
writeVarint(headers, len(self._write_transforms))
|
||||||
|
for transform_id in self._write_transforms:
|
||||||
|
writeVarint(headers, transform_id)
|
||||||
|
if self._write_headers:
|
||||||
|
writeVarint(headers, TInfoHeaderType.KEY_VALUE)
|
||||||
|
writeVarint(headers, len(self._write_headers))
|
||||||
|
for key, value in self._write_headers.items():
|
||||||
|
_writeString(headers, key)
|
||||||
|
_writeString(headers, value)
|
||||||
|
self._write_headers = {}
|
||||||
|
padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
|
||||||
|
headers.write(b"\x00" * padding_needed)
|
||||||
|
header_bytes = headers.getvalue()
|
||||||
|
|
||||||
|
buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
|
||||||
|
buffer.write(U16.pack(HEADER_MAGIC))
|
||||||
|
buffer.write(U16.pack(self.flags))
|
||||||
|
buffer.write(I32.pack(self.sequence_id))
|
||||||
|
buffer.write(U16.pack(len(header_bytes) // 4))
|
||||||
|
buffer.write(header_bytes)
|
||||||
|
buffer.write(payload)
|
||||||
|
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
|
||||||
|
buffer.write(I32.pack(len(payload)))
|
||||||
|
buffer.write(payload)
|
||||||
|
elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
|
||||||
|
buffer.write(payload)
|
||||||
|
else:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.INVALID_CLIENT_TYPE,
|
||||||
|
"Unknown client type.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# the frame length field doesn't count towards the frame payload size
|
||||||
|
frame_bytes = buffer.getvalue()
|
||||||
|
frame_payload_size = len(frame_bytes) - 4
|
||||||
|
if frame_payload_size > self._max_frame_size:
|
||||||
|
raise TTransportException(
|
||||||
|
TTransportException.SIZE_LIMIT,
|
||||||
|
"Attempting to send frame that is too large.",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._transport.write(frame_bytes)
|
||||||
|
self._transport.flush()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cstringio_buf(self):
|
||||||
|
return self._read_buffer
|
||||||
|
|
||||||
|
def cstringio_refill(self, partialread, reqlen):
|
||||||
|
result = bytearray(partialread)
|
||||||
|
while len(result) < reqlen:
|
||||||
|
result += self.read(reqlen - len(result))
|
||||||
|
self._read_buffer = BufferIO(result)
|
||||||
|
return self._read_buffer
|
|
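The flush() method above writes a HEADERS frame as: a 4-byte length, the 16-bit magic, 16-bit flags, a 32-bit sequence id, the header block length in 4-byte words, the padded header block, then the payload (the constant 10 is the size of the four fixed fields after the length word). A hedged, hand-rolled illustration of that layout for the binary subprotocol with no transforms or info headers; small varint values are written as single bytes here:

import struct

U16 = struct.Struct("!H")
I32 = struct.Struct("!i")
HEADER_MAGIC = 0x0FFF

def pack_header_frame(payload, flags=0, sequence_id=0, protocol_id=0x00):
    headers = bytes([protocol_id, 0])                    # varint protocol id, varint transform count (0)
    headers += b"\x00" * ((4 - len(headers) % 4) % 4)    # pad the header block to a 4-byte boundary
    frame = (U16.pack(HEADER_MAGIC) + U16.pack(flags) +
             I32.pack(sequence_id) + U16.pack(len(headers) // 4) +
             headers + payload)
    return I32.pack(len(frame)) + frame                  # leading word counts everything after itself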
@@ -17,133 +17,175 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
import http.client
|
from io import BytesIO
|
||||||
import os
|
import os
|
||||||
import socket
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import urllib.request, urllib.parse, urllib.error
|
|
||||||
import urllib.parse
|
|
||||||
import warnings
|
import warnings
|
||||||
|
import base64
|
||||||
|
|
||||||
from io import StringIO
|
from six.moves import urllib
|
||||||
|
from six.moves import http_client
|
||||||
|
|
||||||
from .TTransport import *
|
from .TTransport import TTransportBase
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
class THttpClient(TTransportBase):
|
class THttpClient(TTransportBase):
|
||||||
"""Http implementation of TTransport base."""
|
"""Http implementation of TTransport base."""
|
||||||
|
|
||||||
def __init__(self, uri_or_host, port=None, path=None):
|
def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None):
|
||||||
"""THttpClient supports two different types constructor parameters.
|
"""THttpClient supports two different types of construction:
|
||||||
|
|
||||||
THttpClient(host, port, path) - deprecated
|
THttpClient(host, port, path) - deprecated
|
||||||
THttpClient(uri)
|
THttpClient(uri, [port=<n>, path=<s>, cafile=<filename>, cert_file=<filename>, key_file=<filename>, ssl_context=<context>])
|
||||||
|
|
||||||
Only the second supports https.
|
Only the second supports https. To properly authenticate against the server,
|
||||||
"""
|
provide the client's identity by specifying cert_file and key_file. To properly
|
||||||
if port is not None:
|
authenticate the server, specify either cafile or ssl_context with a CA defined.
|
||||||
warnings.warn(
|
NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile.
|
||||||
"Please use the THttpClient('http://host:port/path') syntax",
|
"""
|
||||||
DeprecationWarning,
|
if port is not None:
|
||||||
stacklevel=2)
|
warnings.warn(
|
||||||
self.host = uri_or_host
|
"Please use the THttpClient('http{s}://host:port/path') constructor",
|
||||||
self.port = port
|
DeprecationWarning,
|
||||||
assert path
|
stacklevel=2)
|
||||||
self.path = path
|
self.host = uri_or_host
|
||||||
self.scheme = 'http'
|
self.port = port
|
||||||
else:
|
assert path
|
||||||
parsed = urllib.parse.urlparse(uri_or_host)
|
self.path = path
|
||||||
self.scheme = parsed.scheme
|
self.scheme = 'http'
|
||||||
assert self.scheme in ('http', 'https')
|
else:
|
||||||
if self.scheme == 'http':
|
parsed = urllib.parse.urlparse(uri_or_host)
|
||||||
self.port = parsed.port or http.client.HTTP_PORT
|
self.scheme = parsed.scheme
|
||||||
elif self.scheme == 'https':
|
assert self.scheme in ('http', 'https')
|
||||||
self.port = parsed.port or http.client.HTTPS_PORT
|
if self.scheme == 'http':
|
||||||
self.host = parsed.hostname
|
self.port = parsed.port or http_client.HTTP_PORT
|
||||||
self.path = parsed.path
|
elif self.scheme == 'https':
|
||||||
if parsed.query:
|
self.port = parsed.port or http_client.HTTPS_PORT
|
||||||
self.path += '?%s' % parsed.query
|
self.certfile = cert_file
|
||||||
self.__wbuf = StringIO()
|
self.keyfile = key_file
|
||||||
self.__http = None
|
self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context
|
||||||
self.__timeout = None
|
self.host = parsed.hostname
|
||||||
self.__custom_headers = None
|
self.path = parsed.path
|
||||||
|
if parsed.query:
|
||||||
|
self.path += '?%s' % parsed.query
|
||||||
|
try:
|
||||||
|
proxy = urllib.request.getproxies()[self.scheme]
|
||||||
|
except KeyError:
|
||||||
|
proxy = None
|
||||||
|
else:
|
||||||
|
if urllib.request.proxy_bypass(self.host):
|
||||||
|
proxy = None
|
||||||
|
if proxy:
|
||||||
|
parsed = urllib.parse.urlparse(proxy)
|
||||||
|
self.realhost = self.host
|
||||||
|
self.realport = self.port
|
||||||
|
self.host = parsed.hostname
|
||||||
|
self.port = parsed.port
|
||||||
|
self.proxy_auth = self.basic_proxy_auth_header(parsed)
|
||||||
|
else:
|
||||||
|
self.realhost = self.realport = self.proxy_auth = None
|
||||||
|
self.__wbuf = BytesIO()
|
||||||
|
self.__http = None
|
||||||
|
self.__http_response = None
|
||||||
|
self.__timeout = None
|
||||||
|
self.__custom_headers = None
|
||||||
|
|
||||||
def open(self):
|
@staticmethod
|
||||||
if self.scheme == 'http':
|
def basic_proxy_auth_header(proxy):
|
||||||
self.__http = http.client.HTTP(self.host, self.port)
|
if proxy is None or not proxy.username:
|
||||||
else:
|
return None
|
||||||
self.__http = http.client.HTTPS(self.host, self.port)
|
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
|
||||||
|
urllib.parse.unquote(proxy.password))
|
||||||
|
cr = base64.b64encode(ap).strip()
|
||||||
|
return "Basic " + cr
|
||||||
|
|
||||||
def close(self):
|
def using_proxy(self):
|
||||||
self.__http.close()
|
return self.realhost is not None
|
||||||
self.__http = None
|
|
||||||
|
|
||||||
def isOpen(self):
|
def open(self):
|
||||||
return self.__http is not None
|
if self.scheme == 'http':
|
||||||
|
self.__http = http_client.HTTPConnection(self.host, self.port,
|
||||||
|
timeout=self.__timeout)
|
||||||
|
elif self.scheme == 'https':
|
||||||
|
self.__http = http_client.HTTPSConnection(self.host, self.port,
|
||||||
|
key_file=self.keyfile,
|
||||||
|
cert_file=self.certfile,
|
||||||
|
timeout=self.__timeout,
|
||||||
|
context=self.context)
|
||||||
|
if self.using_proxy():
|
||||||
|
self.__http.set_tunnel(self.realhost, self.realport,
|
||||||
|
{"Proxy-Authorization": self.proxy_auth})
|
||||||
|
|
||||||
def setTimeout(self, ms):
|
def close(self):
|
||||||
if not hasattr(socket, 'getdefaulttimeout'):
|
self.__http.close()
|
||||||
raise NotImplementedError
|
self.__http = None
|
||||||
|
self.__http_response = None
|
||||||
|
|
||||||
if ms is None:
|
def isOpen(self):
|
||||||
self.__timeout = None
|
return self.__http is not None
|
||||||
else:
|
|
||||||
self.__timeout = ms / 1000.0
|
|
||||||
|
|
||||||
def setCustomHeaders(self, headers):
|
def setTimeout(self, ms):
|
||||||
self.__custom_headers = headers
|
if ms is None:
|
||||||
|
self.__timeout = None
|
||||||
|
else:
|
||||||
|
self.__timeout = ms / 1000.0
|
||||||
|
|
||||||
def read(self, sz):
|
def setCustomHeaders(self, headers):
|
||||||
return self.__http.file.read(sz)
|
self.__custom_headers = headers
|
||||||
|
|
||||||
def write(self, buf):
|
def read(self, sz):
|
||||||
self.__wbuf.write(buf)
|
return self.__http_response.read(sz)
|
||||||
|
|
||||||
def __withTimeout(f):
|
def write(self, buf):
|
||||||
def _f(*args, **kwargs):
|
self.__wbuf.write(buf)
|
||||||
orig_timeout = socket.getdefaulttimeout()
|
|
||||||
socket.setdefaulttimeout(args[0].__timeout)
|
|
||||||
result = f(*args, **kwargs)
|
|
||||||
socket.setdefaulttimeout(orig_timeout)
|
|
||||||
return result
|
|
||||||
return _f
|
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
if self.isOpen():
|
if self.isOpen():
|
||||||
self.close()
|
self.close()
|
||||||
self.open()
|
self.open()
|
||||||
|
|
||||||
# Pull data out of buffer
|
# Pull data out of buffer
|
||||||
data = self.__wbuf.getvalue()
|
data = self.__wbuf.getvalue()
|
||||||
self.__wbuf = StringIO()
|
self.__wbuf = BytesIO()
|
||||||
|
|
||||||
# HTTP request
|
# HTTP request
|
||||||
self.__http.putrequest('POST', self.path)
|
if self.using_proxy() and self.scheme == "http":
|
||||||
|
# need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
|
||||||
|
self.__http.putrequest('POST', "http://%s:%s%s" %
|
||||||
|
(self.realhost, self.realport, self.path))
|
||||||
|
else:
|
||||||
|
self.__http.putrequest('POST', self.path)
|
||||||
|
|
||||||
# Write headers
|
# Write headers
|
||||||
self.__http.putheader('Host', self.host)
|
self.__http.putheader('Content-Type', 'application/x-thrift')
|
||||||
self.__http.putheader('Content-Type', 'application/x-thrift')
|
self.__http.putheader('Content-Length', str(len(data)))
|
||||||
self.__http.putheader('Content-Length', str(len(data)))
|
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
|
||||||
|
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
|
||||||
|
|
||||||
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
|
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
|
||||||
user_agent = 'Python/THttpClient'
|
user_agent = 'Python/THttpClient'
|
||||||
script = os.path.basename(sys.argv[0])
|
script = os.path.basename(sys.argv[0])
|
||||||
if script:
|
if script:
|
||||||
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
|
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
|
||||||
self.__http.putheader('User-Agent', user_agent)
|
self.__http.putheader('User-Agent', user_agent)
|
||||||
|
|
||||||
if self.__custom_headers:
|
if self.__custom_headers:
|
||||||
for key, val in self.__custom_headers.items():
|
for key, val in six.iteritems(self.__custom_headers):
|
||||||
self.__http.putheader(key, val)
|
self.__http.putheader(key, val)
|
||||||
|
|
||||||
self.__http.endheaders()
|
self.__http.endheaders()
|
||||||
|
|
||||||
# Write payload
|
# Write payload
|
||||||
self.__http.send(data)
|
self.__http.send(data)
|
||||||
|
|
||||||
# Get reply to flush the request
|
# Get reply to flush the request
|
||||||
self.code, self.message, self.headers = self.__http.getreply()
|
self.__http_response = self.__http.getresponse()
|
||||||
|
self.code = self.__http_response.status
|
||||||
|
self.message = self.__http_response.reason
|
||||||
|
self.headers = self.__http_response.msg
|
||||||
|
|
||||||
# Decorate if we know how to timeout
|
# Saves the cookie sent by the server response
|
||||||
if hasattr(socket, 'getdefaulttimeout'):
|
if 'Set-Cookie' in self.headers:
|
||||||
flush = __withTimeout(flush)
|
self.__http.putheader('Cookie', self.headers['Set-Cookie'])
|
||||||
|
|
|
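The rewritten THttpClient accepts the TLS material directly in the constructor, as the docstring above lists. A hedged usage sketch; the URL is a placeholder and the generated client class is assumed:

import ssl

from thrift.protocol import TBinaryProtocol
from thrift.transport import THttpClient

ctx = ssl.create_default_context()   # or pass cafile=... instead; ssl_context wins if both are given
transport = THttpClient.THttpClient('https://example.invalid:9443/thrift/service',
                                    ssl_context=ctx)
transport.setCustomHeaders({'User-Agent': 'python-awips-example'})
protocol = TBinaryProtocol.TBinaryProtocol(transport)
# transport.open(); client = ExampleService.Client(protocol)  # generated code, assumed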
@@ -17,186 +17,392 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from .sslcompat import _match_hostname, _match_has_ipaddress
|
||||||
from thrift.transport import TSocket
|
from thrift.transport import TSocket
|
||||||
from thrift.transport.TTransport import TTransportException
|
from thrift.transport.TTransport import TTransportException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
warnings.filterwarnings(
|
||||||
|
'default', category=DeprecationWarning, module=__name__)
|
||||||
|
|
||||||
class TSSLSocket(TSocket.TSocket):
|
|
||||||
"""
|
|
||||||
SSL implementation of client-side TSocket
|
|
||||||
|
|
||||||
This class creates outbound sockets wrapped using the
|
class TSSLBase(object):
|
||||||
python standard ssl module for encrypted connections.
|
# SSLContext is not available for Python < 2.7.9
|
||||||
|
_has_ssl_context = sys.hexversion >= 0x020709F0
|
||||||
|
|
||||||
The protocol used is set using the class variable
|
# ciphers argument is not available for Python < 2.7.0
|
||||||
SSL_VERSION, which must be one of ssl.PROTOCOL_* and
|
_has_ciphers = sys.hexversion >= 0x020700F0
|
||||||
defaults to ssl.PROTOCOL_TLSv1 for greatest security.
|
|
||||||
"""
|
|
||||||
SSL_VERSION = ssl.PROTOCOL_TLSv1
|
|
||||||
|
|
||||||
def __init__(self,
|
# For python >= 2.7.9, use latest TLS that both client and server
|
||||||
host='localhost',
|
# supports.
|
||||||
port=9090,
|
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
|
||||||
validate=True,
|
# For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
|
||||||
ca_certs=None,
|
# unavailable.
|
||||||
unix_socket=None):
|
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
|
||||||
"""Create SSL TSocket
|
ssl.PROTOCOL_TLSv1
|
||||||
|
|
||||||
@param validate: Set to False to disable SSL certificate validation
|
def _init_context(self, ssl_version):
|
||||||
@type validate: bool
|
if self._has_ssl_context:
|
||||||
@param ca_certs: Filename to the Certificate Authority pem file, possibly a
|
self._context = ssl.SSLContext(ssl_version)
|
||||||
file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to
|
if self._context.protocol == ssl.PROTOCOL_SSLv23:
|
||||||
the ssl_wrap function as the 'ca_certs' parameter.
|
self._context.options |= ssl.OP_NO_SSLv2
|
||||||
@type ca_certs: str
|
self._context.options |= ssl.OP_NO_SSLv3
|
||||||
|
else:
|
||||||
|
self._context = None
|
||||||
|
self._ssl_version = ssl_version
|
||||||
|
|
||||||
Raises an IOError exception if validate is True and the ca_certs file is
|
@property
|
||||||
None, not present or unreadable.
|
def _should_verify(self):
|
||||||
|
if self._has_ssl_context:
|
||||||
|
return self._context.verify_mode != ssl.CERT_NONE
|
||||||
|
else:
|
||||||
|
return self.cert_reqs != ssl.CERT_NONE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_version(self):
|
||||||
|
if self._has_ssl_context:
|
||||||
|
return self.ssl_context.protocol
|
||||||
|
else:
|
||||||
|
return self._ssl_version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ssl_context(self):
|
||||||
|
return self._context
|
||||||
|
|
||||||
|
SSL_VERSION = _default_protocol
|
||||||
"""
|
"""
|
||||||
self.validate = validate
|
Default SSL version.
|
||||||
self.is_valid = False
|
For backwards compatibility, it can be modified.
|
||||||
self.peercert = None
|
Use __init__ keyword argument "ssl_version" instead.
|
||||||
if not validate:
|
"""
|
||||||
self.cert_reqs = ssl.CERT_NONE
|
|
||||||
else:
|
|
||||||
self.cert_reqs = ssl.CERT_REQUIRED
|
|
||||||
self.ca_certs = ca_certs
|
|
||||||
if validate:
|
|
||||||
if ca_certs is None or not os.access(ca_certs, os.R_OK):
|
|
||||||
raise IOError('Certificate Authority ca_certs file "%s" '
|
|
||||||
'is not readable, cannot validate SSL '
|
|
||||||
'certificates.' % (ca_certs))
|
|
||||||
TSocket.TSocket.__init__(self, host, port, unix_socket)
|
|
||||||
|
|
||||||
def open(self):
|
def _deprecated_arg(self, args, kwargs, pos, key):
|
||||||
try:
|
if len(args) <= pos:
|
||||||
res0 = self._resolveAddr()
|
return
|
||||||
for res in res0:
|
real_pos = pos + 3
|
||||||
sock_family, sock_type = res[0:2]
|
warnings.warn(
|
||||||
ip_port = res[4]
|
'%dth positional argument is deprecated.'
|
||||||
plain_sock = socket.socket(sock_family, sock_type)
|
'please use keyword argument instead.'
|
||||||
self.handle = ssl.wrap_socket(plain_sock,
|
% real_pos, DeprecationWarning, stacklevel=3)
|
||||||
ssl_version=self.SSL_VERSION,
|
|
||||||
do_handshake_on_connect=True,
|
if key in kwargs:
|
||||||
ca_certs=self.ca_certs,
|
raise TypeError(
|
||||||
cert_reqs=self.cert_reqs)
|
'Duplicate argument: %dth argument and %s keyword argument.'
|
||||||
self.handle.settimeout(self._timeout)
|
% (real_pos, key))
|
||||||
|
kwargs[key] = args[pos]
|
||||||
|
|
||||||
|
def _unix_socket_arg(self, host, port, args, kwargs):
|
||||||
|
key = 'unix_socket'
|
||||||
|
if host is None and port is None and len(args) == 1 and key not in kwargs:
|
||||||
|
kwargs[key] = args[0]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if key == 'SSL_VERSION':
|
||||||
|
warnings.warn(
|
||||||
|
'SSL_VERSION is deprecated.'
|
||||||
|
'please use ssl_version attribute instead.',
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
return self.ssl_version
|
||||||
|
|
||||||
|
def __init__(self, server_side, host, ssl_opts):
|
||||||
|
self._server_side = server_side
|
||||||
|
if TSSLBase.SSL_VERSION != self._default_protocol:
|
||||||
|
warnings.warn(
|
||||||
|
'SSL_VERSION is deprecated.'
|
||||||
|
'please use ssl_version keyword argument instead.',
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
self._context = ssl_opts.pop('ssl_context', None)
|
||||||
|
self._server_hostname = None
|
||||||
|
if not self._server_side:
|
||||||
|
self._server_hostname = ssl_opts.pop('server_hostname', host)
|
||||||
|
if self._context:
|
||||||
|
self._custom_context = True
|
||||||
|
if ssl_opts:
|
||||||
|
raise ValueError(
|
||||||
|
'Incompatible arguments: ssl_context and %s'
|
||||||
|
% ' '.join(ssl_opts.keys()))
|
||||||
|
if not self._has_ssl_context:
|
||||||
|
raise ValueError(
|
||||||
|
'ssl_context is not available for this version of Python')
|
||||||
|
else:
|
||||||
|
self._custom_context = False
|
||||||
|
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
|
||||||
|
self._init_context(ssl_version)
|
||||||
|
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
|
||||||
|
self.ca_certs = ssl_opts.pop('ca_certs', None)
|
||||||
|
self.keyfile = ssl_opts.pop('keyfile', None)
|
||||||
|
self.certfile = ssl_opts.pop('certfile', None)
|
||||||
|
self.ciphers = ssl_opts.pop('ciphers', None)
|
||||||
|
|
||||||
|
if ssl_opts:
|
||||||
|
raise ValueError(
|
||||||
|
'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
|
||||||
|
|
||||||
|
if self._should_verify:
|
||||||
|
if not self.ca_certs:
|
||||||
|
raise ValueError(
|
||||||
|
'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
|
||||||
|
if not os.access(self.ca_certs, os.R_OK):
|
||||||
|
raise IOError('Certificate Authority ca_certs file "%s" '
|
||||||
|
'is not readable, cannot validate SSL '
|
||||||
|
'certificates.' % (self.ca_certs))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def certfile(self):
|
||||||
|
return self._certfile
|
||||||
|
|
||||||
|
@certfile.setter
|
||||||
|
def certfile(self, certfile):
|
||||||
|
if self._server_side and not certfile:
|
||||||
|
raise ValueError('certfile is needed for server-side')
|
||||||
|
if certfile and not os.access(certfile, os.R_OK):
|
||||||
|
raise IOError('No such certfile found: %s' % (certfile))
|
||||||
|
self._certfile = certfile
|
||||||
|
|
||||||
|
def _wrap_socket(self, sock):
|
||||||
|
if self._has_ssl_context:
|
||||||
|
if not self._custom_context:
|
||||||
|
self.ssl_context.verify_mode = self.cert_reqs
|
||||||
|
if self.certfile:
|
||||||
|
self.ssl_context.load_cert_chain(self.certfile,
|
||||||
|
self.keyfile)
|
||||||
|
if self.ciphers:
|
||||||
|
self.ssl_context.set_ciphers(self.ciphers)
|
||||||
|
if self.ca_certs:
|
||||||
|
self.ssl_context.load_verify_locations(self.ca_certs)
|
||||||
|
return self.ssl_context.wrap_socket(
|
||||||
|
sock, server_side=self._server_side,
|
||||||
|
server_hostname=self._server_hostname)
|
||||||
|
else:
|
||||||
|
ssl_opts = {
|
||||||
|
'ssl_version': self._ssl_version,
|
||||||
|
'server_side': self._server_side,
|
||||||
|
'ca_certs': self.ca_certs,
|
||||||
|
'keyfile': self.keyfile,
|
||||||
|
'certfile': self.certfile,
|
||||||
|
'cert_reqs': self.cert_reqs,
|
||||||
|
}
|
||||||
|
if self.ciphers:
|
||||||
|
if self._has_ciphers:
|
||||||
|
ssl_opts['ciphers'] = self.ciphers
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
'ciphers is specified but ignored due to old Python version')
|
||||||
|
return ssl.wrap_socket(sock, **ssl_opts)
|
||||||
|
|
||||||
|
|
||||||
|
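On interpreters with SSLContext support, _init_context() above builds a context that negotiates the newest mutually supported TLS version while refusing SSLv2 and SSLv3. The equivalent standalone setup, shown only for illustration:

import ssl

context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.verify_mode = ssl.CERT_REQUIRED             # matching the cert_reqs default in TSSLBase
# context.load_verify_locations('/path/to/ca.pem')  # needed when verification is enabled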
class TSSLSocket(TSocket.TSocket, TSSLBase):
|
||||||
|
"""
|
||||||
|
SSL implementation of TSocket
|
||||||
|
|
||||||
|
This class creates outbound sockets wrapped using the
|
||||||
|
python standard ssl module for encrypted connections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# New signature
|
||||||
|
# def __init__(self, host='localhost', port=9090, unix_socket=None,
|
||||||
|
# **ssl_args):
|
||||||
|
# Deprecated signature
|
||||||
|
# def __init__(self, host='localhost', port=9090, validate=True,
|
||||||
|
# ca_certs=None, keyfile=None, certfile=None,
|
||||||
|
# unix_socket=None, ciphers=None):
|
||||||
|
def __init__(self, host='localhost', port=9090, *args, **kwargs):
|
||||||
|
"""Positional arguments: ``host``, ``port``, ``unix_socket``
|
||||||
|
|
||||||
|
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
|
||||||
|
``ssl_version``, ``ca_certs``,
|
||||||
|
``ciphers`` (Python 2.7.0 or later),
|
||||||
|
``server_hostname`` (Python 2.7.9 or later)
|
||||||
|
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
|
||||||
|
|
||||||
|
Alternative keyword arguments: (Python 2.7.9 or later)
|
||||||
|
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
|
||||||
|
``server_hostname``: Passed to SSLContext.wrap_socket
|
||||||
|
|
||||||
|
Common keyword argument:
|
||||||
|
``validate_callback`` (cert, hostname) -> None:
|
||||||
|
Called after SSL handshake. Can raise when hostname does not
|
||||||
|
match the cert.
|
||||||
|
``socket_keepalive`` enable TCP keepalive, default off.
|
||||||
|
"""
|
||||||
|
self.is_valid = False
|
||||||
|
self.peercert = None
|
||||||
|
|
||||||
|
if args:
|
||||||
|
if len(args) > 6:
|
||||||
|
raise TypeError('Too many positional arguments')
|
||||||
|
if not self._unix_socket_arg(host, port, args, kwargs):
|
||||||
|
self._deprecated_arg(args, kwargs, 0, 'validate')
|
||||||
|
self._deprecated_arg(args, kwargs, 1, 'ca_certs')
|
||||||
|
self._deprecated_arg(args, kwargs, 2, 'keyfile')
|
||||||
|
self._deprecated_arg(args, kwargs, 3, 'certfile')
|
||||||
|
self._deprecated_arg(args, kwargs, 4, 'unix_socket')
|
||||||
|
self._deprecated_arg(args, kwargs, 5, 'ciphers')
|
||||||
|
|
||||||
|
validate = kwargs.pop('validate', None)
|
||||||
|
if validate is not None:
|
||||||
|
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
|
||||||
|
warnings.warn(
|
||||||
|
'validate is deprecated. please use cert_reqs=ssl.%s instead'
|
||||||
|
% cert_reqs_name,
|
||||||
|
DeprecationWarning, stacklevel=2)
|
||||||
|
if 'cert_reqs' in kwargs:
|
||||||
|
raise TypeError('Cannot specify both validate and cert_reqs')
|
||||||
|
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
|
||||||
|
|
||||||
|
unix_socket = kwargs.pop('unix_socket', None)
|
||||||
|
socket_keepalive = kwargs.pop('socket_keepalive', False)
|
||||||
|
self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
|
||||||
|
TSSLBase.__init__(self, False, host, kwargs)
|
||||||
|
TSocket.TSocket.__init__(self, host, port, unix_socket,
|
||||||
|
socket_keepalive=socket_keepalive)
|
||||||
|
|
||||||

    def close(self):
        try:
            self.handle.settimeout(0.001)
            self.handle = self.handle.unwrap()
        except (ssl.SSLError, socket.error, OSError):
            # could not complete shutdown in a reasonable amount of time. bail.
            pass
        TSocket.TSocket.close(self)

    @property
    def validate(self):
        warnings.warn('validate is deprecated. please use cert_reqs instead',
                      DeprecationWarning, stacklevel=2)
        return self.cert_reqs != ssl.CERT_NONE

    @validate.setter
    def validate(self, value):
        warnings.warn('validate is deprecated. please use cert_reqs instead',
                      DeprecationWarning, stacklevel=2)
        self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE

    def _do_open(self, family, socktype):
        plain_sock = socket.socket(family, socktype)
        try:
            return self._wrap_socket(plain_sock)
        except Exception as ex:
            plain_sock.close()
            msg = 'failed to initialize SSL'
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)

    def open(self):
        super(TSSLSocket, self).open()
        if self._should_verify:
            self.peercert = self.handle.getpeercert()
            try:
                self._validate_callback(self.peercert, self._server_hostname)
                self.is_valid = True
            except TTransportException:
                raise
            except Exception as ex:
                raise TTransportException(message=str(ex), inner=ex)


class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
    """SSL implementation of TServerSocket

    This uses the ssl module's wrap_socket() method to provide SSL
    negotiated encryption.
    """

    # New signature
    # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
    # Deprecated signature
    # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
    def __init__(self, host=None, port=9090, *args, **kwargs):
        """Positional arguments: ``host``, ``port``, ``unix_socket``

        Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
        ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
        See ssl.wrap_socket documentation.

        Alternative keyword arguments: (Python 2.7.9 or later)
        ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
        ``server_hostname``: Passed to SSLContext.wrap_socket

        Common keyword argument:
        ``validate_callback`` (cert, hostname) -> None:
            Called after SSL handshake. Can raise when hostname does not
            match the cert.
        """
        if args:
            if len(args) > 3:
                raise TypeError('Too many positional argument')
            if not self._unix_socket_arg(host, port, args, kwargs):
                self._deprecated_arg(args, kwargs, 0, 'certfile')
            self._deprecated_arg(args, kwargs, 1, 'unix_socket')
            self._deprecated_arg(args, kwargs, 2, 'ciphers')

        if 'ssl_context' not in kwargs:
            # Preserve existing behaviors for default values
            if 'cert_reqs' not in kwargs:
                kwargs['cert_reqs'] = ssl.CERT_NONE
            if 'certfile' not in kwargs:
                kwargs['certfile'] = 'cert.pem'

        unix_socket = kwargs.pop('unix_socket', None)
        self._validate_callback = \
            kwargs.pop('validate_callback', _match_hostname)
        TSSLBase.__init__(self, True, None, kwargs)
        TSocket.TServerSocket.__init__(self, host, port, unix_socket)
        if self._should_verify and not _match_has_ipaddress:
            raise ValueError('Need ipaddress and backports.ssl_match_hostname '
                             'module to verify client certificate')

    def setCertfile(self, certfile):
        """Set or change the server certificate file used to wrap new
        connections.

        @param certfile: The filename of the server certificate,
                         i.e. '/etc/certs/server.pem'
        @type certfile: str

        Raises an IOError exception if the certfile is not present or unreadable.
        """
        warnings.warn(
            'setCertfile is deprecated. please use certfile property instead.',
            DeprecationWarning, stacklevel=2)
        self.certfile = certfile

    def accept(self):
        plain_client, addr = self.handle.accept()
        try:
            client = self._wrap_socket(plain_client)
        except (ssl.SSLError, socket.error, OSError):
            logger.exception('Error while accepting from %s', addr)
            # failed handshake/ssl wrap, close socket to client
            plain_client.close()
            # raise
            # We can't raise the exception, because it kills most TServer derived
            # serve() methods.
            # Instead, return None, and let the TServer instance deal with it in
            # other exception handling. (but TSimpleServer dies anyway)
            return None

        if self._should_verify:
            client.peercert = client.getpeercert()
            try:
                self._validate_callback(client.peercert, addr[0])
                client.is_valid = True
            except Exception:
                logger.warn('Failed to validate client certificate address: %s',
                            addr[0], exc_info=True)
                client.close()
                plain_client.close()
                return None

        result = TSocket.TSocket()
        result.handle = client
        return result
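A corresponding server-side sketch (the certificate path is a placeholder):

    from thrift.transport import TSSLSocket

    server_sock = TSSLSocket.TSSLServerSocket(host=None, port=9090,
                                              certfile='/etc/certs/server.pem')
    server_sock.listen()
    client = server_sock.accept()   # returns None if the TLS handshake fails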
@@ -18,159 +18,222 @@
#

import errno
import logging
import os
import socket
import sys

from .TTransport import TTransportBase, TTransportException, TServerTransportBase

logger = logging.getLogger(__name__)


class TSocketBase(TTransportBase):
    def _resolveAddr(self):
        if self._unix_socket is not None:
            return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
                     self._unix_socket)]
        else:
            return socket.getaddrinfo(self.host,
                                      self.port,
                                      self._socket_family,
                                      socket.SOCK_STREAM,
                                      0,
                                      socket.AI_PASSIVE)

    def close(self):
        if self.handle:
            self.handle.close()
            self.handle = None


class TSocket(TSocketBase):
    """Socket implementation of TTransport base."""

    def __init__(self, host='localhost', port=9090, unix_socket=None,
                 socket_family=socket.AF_UNSPEC,
                 socket_keepalive=False):
        """Initialize a TSocket

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        @param unix_socket(str)  The filename of a unix socket to connect to.
                                 (host and port will be ignored.)
        @param socket_family(int)  The socket family to use with this socket.
        @param socket_keepalive(bool) enable TCP keepalive, default off.
        """
        self.host = host
        self.port = port
        self.handle = None
        self._unix_socket = unix_socket
        self._timeout = None
        self._socket_family = socket_family
        self._socket_keepalive = socket_keepalive

    def setHandle(self, h):
        self.handle = h

    def isOpen(self):
        if self.handle is None:
            return False

        # this lets us cheaply see if the other end of the socket is still
        # connected. if disconnected, we'll get EOF back (expressed as zero
        # bytes of data) otherwise we'll get one byte or an error indicating
        # we'd have to block for data.
        #
        # note that we're not doing this with socket.MSG_DONTWAIT because 1)
        # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
        # when timeout is non-zero.
        original_timeout = self.handle.gettimeout()
        try:
            self.handle.settimeout(0)
            try:
                peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
            except (socket.error, OSError) as exc:  # on modern python this is just BlockingIOError
                if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                    return True
                return False
        finally:
            self.handle.settimeout(original_timeout)

        # the length will be zero if we got EOF (indicating connection closed)
        return len(peeked_bytes) == 1

    def setTimeout(self, ms):
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0

        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def _do_open(self, family, socktype):
        return socket.socket(family, socktype)

    @property
    def _address(self):
        return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)

    def open(self):
        if self.handle:
            raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
        try:
            addrs = self._resolveAddr()
        except socket.gaierror as gai:
            msg = 'failed to resolve sockaddr for ' + str(self._address)
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
        for family, socktype, _, _, sockaddr in addrs:
            handle = self._do_open(family, socktype)

            # TCP_KEEPALIVE
            if self._socket_keepalive:
                handle.setsockopt(socket.IPPROTO_TCP, socket.SO_KEEPALIVE, 1)

            handle.settimeout(self._timeout)
            try:
                handle.connect(sockaddr)
                self.handle = handle
                return
            except socket.error:
                handle.close()
                logger.info('Could not connect to %s', sockaddr, exc_info=True)
        msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
                                                          addrs))
        logger.error(msg)
        raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)
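A quick usage sketch of the new constructor arguments (host and port are
placeholders for whatever Thrift endpoint is being contacted):

    from thrift.transport import TSocket

    sock = TSocket.TSocket('edex.example.com', 9581, socket_keepalive=True)
    sock.setTimeout(30000)      # milliseconds
    sock.open()
    print(sock.isOpen())        # uses the MSG_PEEK liveness check above
    sock.close()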

    def read(self, sz):
        try:
            buff = self.handle.recv(sz)
        except socket.error as e:
            if (e.args[0] == errno.ECONNRESET and
                    (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = ''
            elif e.args[0] == errno.ETIMEDOUT:
                raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
            else:
                raise TTransportException(message="unexpected exception", inner=e)
        if len(buff) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        if not self.handle:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        sent = 0
        have = len(buff)
        while sent < have:
            try:
                plus = self.handle.send(buff)
                if plus == 0:
                    raise TTransportException(type=TTransportException.END_OF_FILE,
                                              message='TSocket sent 0 bytes')
                sent += plus
                buff = buff[plus:]
            except socket.error as e:
                raise TTransportException(message="unexpected exception", inner=e)

    def flush(self):
        pass


class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base."""

    def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
        self.host = host
        self.port = port
        self._unix_socket = unix_socket
        self._socket_family = socket_family
        self.handle = None
        self._backlog = 128

    def setBacklog(self, backlog=None):
        if not self.handle:
            self._backlog = backlog
        else:
            # We can't update the backlog once it is already listening, since the
            # handle has been created.
            logger.warn('You have to set backlog before listen.')

    def listen(self):
        res0 = self._resolveAddr()
        socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
        for res in res0:
            if res[0] is socket_family or res is res0[-1]:
                break

        # We need remove the old unix socket if the file exists and
        # nobody is listening on it.
        if self._unix_socket:
            tmp = socket.socket(res[0], res[1])
            try:
                tmp.connect(res[4])
            except socket.error as err:
                eno, message = err.args
                if eno == errno.ECONNREFUSED:
                    os.unlink(res[4])

        self.handle = socket.socket(res[0], res[1])
        self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(self.handle, 'settimeout'):
            self.handle.settimeout(None)
        self.handle.bind(res[4])
        self.handle.listen(self._backlog)

    def accept(self):
        client, addr = self.handle.accept()
        result = TSocket()
        result.setHandle(client)
        return result
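The new setBacklog() only takes effect before listen(); a minimal accept-loop
sketch (host/port are placeholders):

    from thrift.transport import TSocket

    server = TSocket.TServerSocket(host='127.0.0.1', port=9090)
    server.setBacklog(64)           # must be called before listen()
    server.listen()
    while True:
        client = server.accept()    # returns a connected TSocket
        # hand `client` off to a TServer / processor here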
@@ -17,317 +17,440 @@
# under the License.
#

from struct import pack, unpack
from thrift.Thrift import TException
from ..compat import BufferIO


class TTransportException(TException):
    """Custom Transport Exception class"""

    UNKNOWN = 0
    NOT_OPEN = 1
    ALREADY_OPEN = 2
    TIMED_OUT = 3
    END_OF_FILE = 4
    NEGATIVE_SIZE = 5
    SIZE_LIMIT = 6
    INVALID_CLIENT_TYPE = 7

    def __init__(self, type=UNKNOWN, message=None, inner=None):
        TException.__init__(self, message)
        self.type = type
        self.inner = inner
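The new error codes and the `inner` exception can be inspected in calling code;
a sketch (host/port are placeholders):

    from thrift.transport import TSocket, TTransport

    sock = TSocket.TSocket('localhost', 9090)
    try:
        sock.open()
    except TTransport.TTransportException as e:
        if e.type == TTransport.TTransportException.NOT_OPEN:
            print('connect failed:', e, 'caused by', e.inner)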


class TTransportBase(object):
    """Base class for Thrift transport layer."""

    def isOpen(self):
        pass

    def open(self):
        pass

    def close(self):
        pass

    def read(self, sz):
        pass

    def readAll(self, sz):
        buff = b''
        have = 0
        while (have < sz):
            chunk = self.read(sz - have)
            chunkLen = len(chunk)
            have += chunkLen
            buff += chunk

            if chunkLen == 0:
                raise EOFError()

        return buff

    def write(self, buf):
        pass

    def flush(self):
        pass


# This class should be thought of as an interface.
class CReadableTransport(object):
    """base class for transports that are readable from C"""

    # TODO(dreiss): Think about changing this interface to allow us to use
    #               a (Python, not c) StringIO instead, because it allows
    #               you to write after reading.

    # NOTE: This is a classic class, so properties will NOT work
    #       correctly for setting.
    @property
    def cstringio_buf(self):
        """A cStringIO buffer that contains the current chunk we are reading."""
        pass

    def cstringio_refill(self, partialread, reqlen):
        """Refills cstringio_buf.

        Returns the currently used buffer (which can but need not be the same as
        the old cstringio_buf). partialread is what the C code has read from the
        buffer, and should be inserted into the buffer before any more reads. The
        return value must be a new, not borrowed reference. Something along the
        lines of self._buf should be fine.

        If reqlen bytes can't be read, throw EOFError.
        """
        pass


class TServerTransportBase(object):
    """Base class for Thrift server transports."""

    def listen(self):
        pass

    def accept(self):
        pass

    def close(self):
        pass


class TTransportFactoryBase(object):
    """Base class for a Transport Factory"""

    def getTransport(self, trans):
        return trans


class TBufferedTransportFactory(object):
    """Factory transport that builds buffered transports"""

    def getTransport(self, trans):
        buffered = TBufferedTransport(trans)
        return buffered


class TBufferedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and buffers its I/O.

    The implementation uses a (configurable) fixed-size read buffer
    but buffers all writes until a flush is performed.
    """
    DEFAULT_BUFFER = 4096

    def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
        self.__trans = trans
        self.__wbuf = BufferIO()
        # Pass string argument to initialize read buffer as cStringIO.InputType
        self.__rbuf = BufferIO(b'')
        self.__rbuf_size = rbuf_size

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
        return self.__rbuf.read(sz)

    def write(self, buf):
        try:
            self.__wbuf.write(buf)
        except Exception as e:
            # on exception reset wbuf so it doesn't contain a partial function call
            self.__wbuf = BufferIO()
            raise e

    def flush(self):
        out = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        self.__trans.write(out)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, partialread, reqlen):
        retstring = partialread
        if reqlen < self.__rbuf_size:
            # try to make a read of as much as we can.
            retstring += self.__trans.read(self.__rbuf_size)

        # but make sure we do read reqlen bytes.
        if len(retstring) < reqlen:
            retstring += self.__trans.readAll(reqlen - len(retstring))

        self.__rbuf = BufferIO(retstring)
        return self.__rbuf


class TMemoryBuffer(TTransportBase, CReadableTransport):
    """Wraps a cBytesIO object as a TTransport.

    NOTE: Unlike the C++ version of this class, you cannot write to it
          then immediately read from it. If you want to read from a
          TMemoryBuffer, you must either pass a string to the constructor.
    TODO(dreiss): Make this work like the C++ version.
    """

    def __init__(self, value=None, offset=0):
        """value -- a value to read from for stringio

        If value is set, this will be a transport for reading,
        otherwise, it is for writing"""
        if value is not None:
            self._buffer = BufferIO(value)
        else:
            self._buffer = BufferIO()
        if offset:
            self._buffer.seek(offset)

    def isOpen(self):
        return not self._buffer.closed

    def open(self):
        pass

    def close(self):
        self._buffer.close()

    def read(self, sz):
        return self._buffer.read(sz)

    def write(self, buf):
        self._buffer.write(buf)

    def flush(self):
        pass

    def getvalue(self):
        return self._buffer.getvalue()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self._buffer

    def cstringio_refill(self, partialread, reqlen):
        # only one shot at reading...
        raise EOFError()
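TMemoryBuffer is handy for (de)serializing without any I/O; a sketch using the
binary protocol:

    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol

    buf = TTransport.TMemoryBuffer()
    proto = TBinaryProtocol.TBinaryProtocol(buf)
    proto.writeI32(42)
    payload = buf.getvalue()            # raw serialized bytes

    reader = TTransport.TMemoryBuffer(payload)
    rproto = TBinaryProtocol.TBinaryProtocol(reader)
    assert rproto.readI32() == 42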


class TFramedTransportFactory(object):
    """Factory transport that builds framed transports"""

    def getTransport(self, trans):
        framed = TFramedTransport(trans)
        return framed


class TFramedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and frames its I/O when writing."""

    def __init__(self, trans,):
        self.__trans = trans
        self.__rbuf = BufferIO(b'')
        self.__wbuf = BufferIO()

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        buff = self.__trans.readAll(4)
        sz, = unpack('!i', buff)
        self.__rbuf = BufferIO(self.__trans.readAll(sz))

    def write(self, buf):
        self.__wbuf.write(buf)

    def flush(self):
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        buf = pack("!i", wsz) + wout
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
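Each flush() emits one 4-byte big-endian length prefix followed by the frame
body; wrapping a socket is a one-liner (host/port are placeholders):

    from thrift.transport import TSocket, TTransport

    sock = TSocket.TSocket('localhost', 9090)
    framed = TTransport.TFramedTransport(sock)
    framed.open()
    framed.write(b'payload')
    framed.flush()      # sends pack('!i', 7) + b'payload' in a single write
    framed.close()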


class TFileObjectTransport(TTransportBase):
    """Wraps a file-like object to make it work as a Thrift transport."""

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def isOpen(self):
        return True

    def close(self):
        self.fileobj.close()

    def read(self, sz):
        return self.fileobj.read(sz)

    def write(self, buf):
        self.fileobj.write(buf)

    def flush(self):
        self.fileobj.flush()


class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def send_sasl_msg(self, status, body):
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            payload = ""
        return status, payload

    def write(self, data):
        self.__wbuf.write(data)

    def flush(self):
        data = self.__wbuf.getvalue()
        encoded = self.sasl.wrap(data)
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()
        self.__wbuf = BufferIO()

    def read(self, sz):
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf
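A hedged sketch of the new SASL transport: it requires the optional puresasl
package, and the host, service, and credentials below are placeholders:

    from thrift.transport import TSocket, TTransport

    sock = TSocket.TSocket('hive.example.com', 10000)
    sasl = TTransport.TSaslClientTransport(sock, host='hive.example.com',
                                           service='hive', mechanism='PLAIN',
                                           username='user', password='secret')
    sasl.open()     # runs the START/OK/COMPLETE negotiation shown above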
@@ -17,14 +17,15 @@
# under the License.
#

from io import BytesIO
import struct

from zope.interface import implementer, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
    connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.web import server, resource, http

from thrift.transport import TTransport

@@ -33,15 +34,15 @@ from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase):

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        self.__wbuf.write(buf)

    def flush(self):
        msg = self.__wbuf.getvalue()
        self.__wbuf = BytesIO()
        return self.sendMessage(msg)

    def sendMessage(self, message):
        raise NotImplementedError

@@ -54,7 +55,7 @@ class TCallbackTransport(TMessageSenderTransport):
        self.func = func

    def sendMessage(self, message):
        return self.func(message)


class ThriftClientProtocol(basic.Int32StringReceiver):

@@ -81,11 +82,18 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        # the called errbacks can add items to our client's _reqs,
        # so we need to use a tmp, and iterate until no more requests
        # are added during errbacks
        if self.client:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed (%s)' % reason)
            while self.client._reqs:
                _, v = self.client._reqs.popitem()
                v.errback(tex)
            del self.client._reqs
            self.client = None

    def stringReceived(self, frame):
        tr = TTransport.TMemoryBuffer(frame)

@@ -101,6 +109,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
        method(iprot, mtype, rseqid)


class ThriftSASLClientProtocol(ThriftClientProtocol):

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        from puresasl.client import SASLClient
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        encoded = self.sasl.wrap(msg)
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransport.TTransportException(msg, message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(msg, message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        if self._sasl_negotiation_deferred:
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            self._sasl_negotiation_status, = struct.unpack("B", data[0])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)


class ThriftServerProtocol(basic.Int32StringReceiver):

    MAX_LENGTH = 2 ** 31 - 1

@@ -126,7 +236,7 @@ class ThriftServerProtocol(basic.Int32StringReceiver):

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))


class IThriftServerFactory(Interface):

@@ -147,10 +257,9 @@ class IThriftClientFactory(Interface):
    oprot_factory = Attribute("Output protocol factory")


@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):

@@ -162,10 +271,9 @@ class ThriftServerFactory(ServerFactory):
        self.oprot_factory = oprot_factory


@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):

@@ -178,7 +286,7 @@ class ThriftClientFactory(ClientFactory):

    def buildProtocol(self, addr):
        p = self.protocol(self.client_class, self.iprot_factory,
                          self.oprot_factory)
        p.factory = self
        return p

@@ -188,7 +296,7 @@ class ThriftResource(resource.Resource):
    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
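For context, a minimal Twisted client wiring sketch; the host, port, and the
`Calculator` client class are placeholders standing in for any thrift-generated
service client:

    from twisted.internet import reactor
    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TTwisted

    factory = TTwisted.ThriftClientFactory(Calculator.Client,
                                           TBinaryProtocol.TBinaryProtocolFactory())
    reactor.connectTCP('localhost', 9090, factory)
    reactor.run()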
@@ -22,227 +22,227 @@ class, using the python standard library zlib module to implement
data compression.
"""

from __future__ import division
import zlib

from .TTransport import TTransportBase, CReadableTransport
from ..compat import BufferIO


class TZlibTransportFactory(object):
    """Factory transport that builds zlib compressed transports.

    This factory caches the last single client/transport that it was passed
    and returns the same TZlibTransport object that was created.

    This caching means the TServer class will get the _same_ transport
    object for both input and output transports from this factory.
    (For non-threaded scenarios only, since the cache only holds one object)

    The purpose of this caching is to allocate only one TZlibTransport where
    only one is really needed (since it must have separate read/write buffers),
    and makes the statistics from getCompSavings() and getCompRatio()
    easier to understand.
    """
    # class scoped cache of last transport given and zlibtransport returned
    _last_trans = None
    _last_z = None

    def getTransport(self, trans, compresslevel=9):
        """Wrap a transport, trans, with the TZlibTransport
        compressed transport class, returning a new
        transport to the caller.

        @param compresslevel: The zlib compression level, ranging
        from 0 (no compression) to 9 (best compression).  Defaults to 9.
        @type compresslevel: int

        This method returns a TZlibTransport which wraps the
        passed C{trans} TTransport derived instance.
        """
        if trans == self._last_trans:
            return self._last_z
        ztrans = TZlibTransport(trans, compresslevel)
        self._last_trans = trans
        self._last_z = ztrans
        return ztrans
|
|
||||||
|
|
||||||
class TZlibTransport(TTransportBase, CReadableTransport):
|
class TZlibTransport(TTransportBase, CReadableTransport):
|
||||||
"""Class that wraps a transport with zlib, compressing writes
|
"""Class that wraps a transport with zlib, compressing writes
|
||||||
and decompresses reads, using the python standard
|
and decompresses reads, using the python standard
|
||||||
library zlib module.
|
library zlib module.
|
||||||
"""
|
|
||||||
# Read buffer size for the python fastbinary C extension,
|
|
||||||
# the TBinaryProtocolAccelerated class.
|
|
||||||
DEFAULT_BUFFSIZE = 4096
|
|
||||||
|
|
||||||
def __init__(self, trans, compresslevel=9):
|
|
||||||
"""Create a new TZlibTransport, wrapping C{trans}, another
|
|
||||||
TTransport derived object.
|
|
||||||
|
|
||||||
@param trans: A thrift transport object, i.e. a TSocket() object.
|
|
||||||
@type trans: TTransport
|
|
||||||
@param compresslevel: The zlib compression level, ranging
|
|
||||||
from 0 (no compression) to 9 (best compression). Default is 9.
|
|
||||||
@type compresslevel: int
|
|
||||||
"""
|
"""
|
||||||
self.__trans = trans
|
# Read buffer size for the python fastbinary C extension,
|
||||||
self.compresslevel = compresslevel
|
# the TBinaryProtocolAccelerated class.
|
||||||
self.__rbuf = StringIO()
|
DEFAULT_BUFFSIZE = 4096
|
||||||
self.__wbuf = StringIO()
|
|
||||||
self._init_zlib()
|
|
||||||
self._init_stats()
|
|
||||||
|
|
||||||
def _reinit_buffers(self):
|
def __init__(self, trans, compresslevel=9):
|
||||||
"""Internal method to initialize/reset the internal StringIO objects
|
"""Create a new TZlibTransport, wrapping C{trans}, another
|
||||||
for read and write buffers.
|
TTransport derived object.
|
||||||
"""
|
|
||||||
self.__rbuf = StringIO()
|
|
||||||
self.__wbuf = StringIO()
|
|
||||||
|
|
||||||
def _init_stats(self):
|
@param trans: A thrift transport object, i.e. a TSocket() object.
|
||||||
"""Internal method to reset the internal statistics counters
|
@type trans: TTransport
|
||||||
for compression ratios and bandwidth savings.
|
@param compresslevel: The zlib compression level, ranging
|
||||||
"""
|
from 0 (no compression) to 9 (best compression). Default is 9.
|
||||||
self.bytes_in = 0
|
@type compresslevel: int
|
||||||
self.bytes_out = 0
|
"""
|
||||||
self.bytes_in_comp = 0
|
self.__trans = trans
|
||||||
self.bytes_out_comp = 0
|
self.compresslevel = compresslevel
|
||||||
|
self.__rbuf = BufferIO()
|
||||||
|
self.__wbuf = BufferIO()
|
||||||
|
self._init_zlib()
|
||||||
|
self._init_stats()
|
||||||
|
|
||||||
def _init_zlib(self):
|
def _reinit_buffers(self):
|
||||||
"""Internal method for setting up the zlib compression and
|
"""Internal method to initialize/reset the internal StringIO objects
|
||||||
decompression objects.
|
for read and write buffers.
|
||||||
"""
|
"""
|
||||||
self._zcomp_read = zlib.decompressobj()
|
self.__rbuf = BufferIO()
|
||||||
self._zcomp_write = zlib.compressobj(self.compresslevel)
|
self.__wbuf = BufferIO()
|
||||||
|
|
||||||
def getCompRatio(self):
|
def _init_stats(self):
|
||||||
"""Get the current measured compression ratios (in,out) from
|
"""Internal method to reset the internal statistics counters
|
||||||
this transport.
|
for compression ratios and bandwidth savings.
|
||||||
|
"""
|
||||||
|
self.bytes_in = 0
|
||||||
|
self.bytes_out = 0
|
||||||
|
self.bytes_in_comp = 0
|
||||||
|
self.bytes_out_comp = 0
|
||||||
|
|
||||||
Returns a tuple of:
|
def _init_zlib(self):
|
||||||
(inbound_compression_ratio, outbound_compression_ratio)
|
"""Internal method for setting up the zlib compression and
|
||||||
|
decompression objects.
|
||||||
|
"""
|
||||||
|
self._zcomp_read = zlib.decompressobj()
|
||||||
|
self._zcomp_write = zlib.compressobj(self.compresslevel)
|
||||||
|
|
||||||
The compression ratios are computed as:
|
def getCompRatio(self):
|
||||||
compressed / uncompressed
|
"""Get the current measured compression ratios (in,out) from
|
||||||
|
this transport.
|
||||||
|
|
||||||
E.g., data that compresses by 10x will have a ratio of: 0.10
|
Returns a tuple of:
|
||||||
and data that compresses to half of ts original size will
|
(inbound_compression_ratio, outbound_compression_ratio)
|
||||||
have a ratio of 0.5
|
|
||||||
|
|
||||||
None is returned if no bytes have yet been processed in
|
The compression ratios are computed as:
|
||||||
a particular direction.
|
compressed / uncompressed
|
||||||
"""
|
|
||||||
r_percent, w_percent = (None, None)
|
|
||||||
if self.bytes_in > 0:
|
|
||||||
r_percent = self.bytes_in_comp / self.bytes_in
|
|
||||||
if self.bytes_out > 0:
|
|
||||||
w_percent = self.bytes_out_comp / self.bytes_out
|
|
||||||
return (r_percent, w_percent)
|
|
||||||
|
|
||||||
def getCompSavings(self):
|
E.g., data that compresses by 10x will have a ratio of: 0.10
|
||||||
"""Get the current count of saved bytes due to data
|
and data that compresses to half of ts original size will
|
||||||
compression.
|
have a ratio of 0.5
|
||||||
|
|
||||||
Returns a tuple of:
|
None is returned if no bytes have yet been processed in
|
||||||
(inbound_saved_bytes, outbound_saved_bytes)
|
a particular direction.
|
||||||
|
"""
|
||||||
|
r_percent, w_percent = (None, None)
|
||||||
|
if self.bytes_in > 0:
|
||||||
|
r_percent = self.bytes_in_comp / self.bytes_in
|
||||||
|
if self.bytes_out > 0:
|
||||||
|
w_percent = self.bytes_out_comp / self.bytes_out
|
||||||
|
return (r_percent, w_percent)
|
||||||
|
|
||||||
Note: if compression is actually expanding your
|
def getCompSavings(self):
|
||||||
data (only likely with very tiny thrift objects), then
|
"""Get the current count of saved bytes due to data
|
||||||
the values returned will be negative.
|
compression.
|
||||||
"""
|
|
||||||
r_saved = self.bytes_in - self.bytes_in_comp
|
|
||||||
w_saved = self.bytes_out - self.bytes_out_comp
|
|
||||||
return (r_saved, w_saved)
|
|
||||||
|
|
||||||
def isOpen(self):
|
Returns a tuple of:
|
||||||
"""Return the underlying transport's open status"""
|
(inbound_saved_bytes, outbound_saved_bytes)
|
||||||
return self.__trans.isOpen()
|
|
||||||
|
|
||||||
def open(self):
|
Note: if compression is actually expanding your
|
||||||
"""Open the underlying transport"""
|
data (only likely with very tiny thrift objects), then
|
||||||
self._init_stats()
|
the values returned will be negative.
|
||||||
return self.__trans.open()
|
"""
|
||||||
|
r_saved = self.bytes_in - self.bytes_in_comp
|
||||||
|
w_saved = self.bytes_out - self.bytes_out_comp
|
||||||
|
return (r_saved, w_saved)
|
||||||
|
|
||||||
def listen(self):
|
def isOpen(self):
|
||||||
"""Invoke the underlying transport's listen() method"""
|
"""Return the underlying transport's open status"""
|
||||||
self.__trans.listen()
|
return self.__trans.isOpen()
|
||||||
|
|
||||||
def accept(self):
|
def open(self):
|
||||||
"""Accept connections on the underlying transport"""
|
"""Open the underlying transport"""
|
||||||
return self.__trans.accept()
|
self._init_stats()
|
||||||
|
return self.__trans.open()
|
||||||
|
|
||||||
def close(self):
|
def listen(self):
|
||||||
"""Close the underlying transport,"""
|
"""Invoke the underlying transport's listen() method"""
|
||||||
self._reinit_buffers()
|
self.__trans.listen()
|
||||||
self._init_zlib()
|
|
||||||
return self.__trans.close()
|
|
||||||
|
|
||||||
def read(self, sz):
|
def accept(self):
|
||||||
"""Read up to sz bytes from the decompressed bytes buffer, and
|
"""Accept connections on the underlying transport"""
|
||||||
read from the underlying transport if the decompression
|
return self.__trans.accept()
|
||||||
buffer is empty.
|
|
||||||
"""
|
|
||||||
ret = self.__rbuf.read(sz)
|
|
||||||
if len(ret) > 0:
|
|
||||||
return ret
|
|
||||||
# keep reading from transport until something comes back
|
|
||||||
while True:
|
|
||||||
if self.readComp(sz):
|
|
||||||
break
|
|
||||||
ret = self.__rbuf.read(sz)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def readComp(self, sz):
|
def close(self):
|
||||||
"""Read compressed data from the underlying transport, then
|
"""Close the underlying transport,"""
|
||||||
decompress it and append it to the internal StringIO read buffer
|
self._reinit_buffers()
|
||||||
"""
|
self._init_zlib()
|
||||||
zbuf = self.__trans.read(sz)
|
return self.__trans.close()
|
||||||
zbuf = self._zcomp_read.unconsumed_tail + zbuf
|
|
||||||
buf = self._zcomp_read.decompress(zbuf)
|
|
||||||
self.bytes_in += len(zbuf)
|
|
||||||
self.bytes_in_comp += len(buf)
|
|
||||||
old = self.__rbuf.read()
|
|
||||||
self.__rbuf = StringIO(old + buf)
|
|
||||||
if len(old) + len(buf) == 0:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def write(self, buf):
|
def read(self, sz):
|
||||||
"""Write some bytes, putting them into the internal write
|
"""Read up to sz bytes from the decompressed bytes buffer, and
|
||||||
buffer for eventual compression.
|
read from the underlying transport if the decompression
|
||||||
"""
|
buffer is empty.
|
||||||
self.__wbuf.write(buf)
|
"""
|
||||||
|
ret = self.__rbuf.read(sz)
|
||||||
|
if len(ret) > 0:
|
||||||
|
return ret
|
||||||
|
# keep reading from transport until something comes back
|
||||||
|
while True:
|
||||||
|
if self.readComp(sz):
|
||||||
|
break
|
||||||
|
ret = self.__rbuf.read(sz)
|
||||||
|
return ret
|
||||||
|
|
||||||
def flush(self):
|
def readComp(self, sz):
|
||||||
"""Flush any queued up data in the write buffer and ensure the
|
"""Read compressed data from the underlying transport, then
|
||||||
compression buffer is flushed out to the underlying transport
|
decompress it and append it to the internal StringIO read buffer
|
||||||
"""
|
"""
|
||||||
wout = self.__wbuf.getvalue()
|
zbuf = self.__trans.read(sz)
|
||||||
if len(wout) > 0:
|
zbuf = self._zcomp_read.unconsumed_tail + zbuf
|
||||||
zbuf = self._zcomp_write.compress(wout)
|
buf = self._zcomp_read.decompress(zbuf)
|
||||||
self.bytes_out += len(wout)
|
self.bytes_in += len(zbuf)
|
||||||
self.bytes_out_comp += len(zbuf)
|
self.bytes_in_comp += len(buf)
|
||||||
else:
|
old = self.__rbuf.read()
|
||||||
zbuf = ''
|
self.__rbuf = BufferIO(old + buf)
|
||||||
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
|
if len(old) + len(buf) == 0:
|
||||||
self.bytes_out_comp += len(ztail)
|
return False
|
||||||
if (len(zbuf) + len(ztail)) > 0:
|
return True
|
||||||
self.__wbuf = StringIO()
|
|
||||||
self.__trans.write(zbuf + ztail)
|
|
||||||
self.__trans.flush()
|
|
||||||
|
|
||||||
@property
|
def write(self, buf):
|
||||||
def cstringio_buf(self):
|
"""Write some bytes, putting them into the internal write
|
||||||
"""Implement the CReadableTransport interface"""
|
buffer for eventual compression.
|
||||||
return self.__rbuf
|
"""
|
||||||
|
self.__wbuf.write(buf)
|
||||||
|
|
||||||
def cstringio_refill(self, partialread, reqlen):
|
def flush(self):
|
||||||
"""Implement the CReadableTransport interface for refill"""
|
"""Flush any queued up data in the write buffer and ensure the
|
||||||
retstring = partialread
|
compression buffer is flushed out to the underlying transport
|
||||||
if reqlen < self.DEFAULT_BUFFSIZE:
|
"""
|
||||||
retstring += self.read(self.DEFAULT_BUFFSIZE)
|
wout = self.__wbuf.getvalue()
|
||||||
while len(retstring) < reqlen:
|
if len(wout) > 0:
|
||||||
retstring += self.read(reqlen - len(retstring))
|
zbuf = self._zcomp_write.compress(wout)
|
||||||
self.__rbuf = StringIO(retstring)
|
self.bytes_out += len(wout)
|
||||||
return self.__rbuf
|
self.bytes_out_comp += len(zbuf)
|
||||||
|
else:
|
||||||
|
zbuf = ''
|
||||||
|
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
|
||||||
|
self.bytes_out_comp += len(ztail)
|
||||||
|
if (len(zbuf) + len(ztail)) > 0:
|
||||||
|
self.__wbuf = BufferIO()
|
||||||
|
self.__trans.write(zbuf + ztail)
|
||||||
|
self.__trans.flush()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cstringio_buf(self):
|
||||||
|
"""Implement the CReadableTransport interface"""
|
||||||
|
return self.__rbuf
|
||||||
|
|
||||||
|
def cstringio_refill(self, partialread, reqlen):
|
||||||
|
"""Implement the CReadableTransport interface for refill"""
|
||||||
|
retstring = partialread
|
||||||
|
if reqlen < self.DEFAULT_BUFFSIZE:
|
||||||
|
retstring += self.read(self.DEFAULT_BUFFSIZE)
|
||||||
|
while len(retstring) < reqlen:
|
||||||
|
retstring += self.read(reqlen - len(retstring))
|
||||||
|
self.__rbuf = BufferIO(retstring)
|
||||||
|
return self.__rbuf
|
||||||
|
|
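The factory caching and the compression statistics described in the docstrings above can be exercised without a live server by wrapping an in-memory transport. The sketch below is illustrative only: it assumes the bundled thrift package (including thrift/compat.py, which supplies BufferIO) is importable as `thrift`, and the payload and compression level are arbitrary choices.

    from thrift.transport.TTransport import TMemoryBuffer
    from thrift.transport.TZlibTransport import TZlibTransport, TZlibTransportFactory

    # The factory caches the last wrapped transport, so asking twice for the
    # same underlying transport returns the same TZlibTransport instance.
    inner = TMemoryBuffer()
    factory = TZlibTransportFactory()
    ztrans = factory.getTransport(inner, compresslevel=9)
    assert factory.getTransport(inner) is ztrans

    # Write a repetitive payload, flush it through zlib, and inspect the stats.
    ztrans.write(b'AWIPS grid data ' * 512)
    ztrans.flush()

    read_ratio, write_ratio = ztrans.getCompRatio()    # read side is None: nothing read yet
    read_saved, write_saved = ztrans.getCompSavings()
    print(write_ratio, write_saved)

The read-side counters fill in the same way once read()/readComp() pull compressed bytes back out of the wrapped transport.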
100  thrift/transport/sslcompat.py  Normal file
@ -0,0 +1,100 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import logging
import sys

from thrift.transport.TTransport import TTransportException

logger = logging.getLogger(__name__)


def legacy_validate_callback(cert, hostname):
    """legacy method to validate the peer's SSL certificate, and to check
    the commonName of the certificate to ensure it matches the hostname we
    used to make this connection. Does not support subjectAltName records
    in certificates.

    raises TTransportException if the certificate fails validation.
    """
    if 'subject' not in cert:
        raise TTransportException(
            TTransportException.NOT_OPEN,
            'No SSL certificate found from %s' % hostname)
    fields = cert['subject']
    for field in fields:
        # ensure structure we get back is what we expect
        if not isinstance(field, tuple):
            continue
        cert_pair = field[0]
        if len(cert_pair) < 2:
            continue
        cert_key, cert_value = cert_pair[0:2]
        if cert_key != 'commonName':
            continue
        certhost = cert_value
        # this check should be performed by some sort of Access Manager
        if certhost == hostname:
            # success, cert commonName matches desired hostname
            return
        else:
            raise TTransportException(
                TTransportException.UNKNOWN,
                'Hostname we connected to "%s" doesn\'t match certificate '
                'provided commonName "%s"' % (hostname, certhost))
    raise TTransportException(
        TTransportException.UNKNOWN,
        'Could not validate SSL certificate from host "%s". Cert=%s'
        % (hostname, cert))


def _optional_dependencies():
    try:
        import ipaddress  # noqa
        logger.debug('ipaddress module is available')
        ipaddr = True
    except ImportError:
        logger.warn('ipaddress module is unavailable')
        ipaddr = False

    if sys.hexversion < 0x030500F0:
        try:
            from backports.ssl_match_hostname import match_hostname, __version__ as ver
            ver = list(map(int, ver.split('.')))
            logger.debug('backports.ssl_match_hostname module is available')
            match = match_hostname
            if ver[0] * 10 + ver[1] >= 35:
                return ipaddr, match
            else:
                logger.warn('backports.ssl_match_hostname module is too old')
                ipaddr = False
        except ImportError:
            logger.warn('backports.ssl_match_hostname is unavailable')
            ipaddr = False
    try:
        from ssl import match_hostname
        logger.debug('ssl.match_hostname is available')
        match = match_hostname
    except ImportError:
        logger.warn('using legacy validation callback')
        match = legacy_validate_callback
    return ipaddr, match


_match_has_ipaddress, _match_hostname = _optional_dependencies()
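For illustration, legacy_validate_callback can be driven directly with a dictionary shaped like the return value of ssl.SSLSocket.getpeercert(). This is a minimal sketch, not part of the commit: the hostnames are placeholders, and the import path assumes the new module lives at thrift/transport/sslcompat.py as added here.

    from thrift.transport.sslcompat import legacy_validate_callback
    from thrift.transport.TTransport import TTransportException

    # Minimal getpeercert()-style dict; subject entries are nested tuples.
    cert = {'subject': ((('commonName', 'edex.example.com'),),)}

    # Matching commonName: the callback returns quietly.
    legacy_validate_callback(cert, 'edex.example.com')

    # Mismatched hostname: the callback raises TTransportException.
    try:
        legacy_validate_callback(cert, 'other.example.com')
    except TTransportException as exc:
        print(exc)

When ssl.match_hostname (or a sufficiently new backports.ssl_match_hostname) is importable, _optional_dependencies() binds _match_hostname to it instead, so the legacy callback is only the fallback path.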