mirror of
https://github.com/Unidata/python-awips.git
synced 2025-02-23 22:57:56 -05:00
- Brought over all of the new Thrift files. To generate them: untar and unzip the thrift package in awips2-rpm, go into /lib/py/ and run `python setup.py build`, then copy all of the files that get placed in the subdirectory under /build. - Replaced DataAccessLayer.py with the current one from our v18.1.11 of python-awips.
329 lines
11 KiB
Python
329 lines
11 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
|
|
from io import BytesIO
|
|
import struct
|
|
|
|
from zope.interface import implementer, Interface, Attribute
|
|
from twisted.internet.protocol import ServerFactory, ClientFactory, \
|
|
connectionDone
|
|
from twisted.internet import defer
|
|
from twisted.internet.threads import deferToThread
|
|
from twisted.protocols import basic
|
|
from twisted.web import server, resource, http
|
|
|
|
from thrift.transport import TTransport
|
|
|
|
|
|
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that batches output into discrete messages.

    Bytes handed to ``write`` accumulate in an in-memory buffer. ``flush``
    passes everything accumulated so far to ``sendMessage`` as one message
    and starts over with an empty buffer. Subclasses supply ``sendMessage``.
    """

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        """Append ``buf`` to the message currently being assembled."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one message and reset the buffer.

        Returns whatever ``sendMessage`` returns.
        """
        # Install the fresh buffer before sending so the outgoing message
        # and any subsequent writes never share storage.
        pending, self.__wbuf = self.__wbuf, BytesIO()
        return self.sendMessage(pending.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError
|
|
|
|
|
|
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that hands every flushed message to a callable."""

    def __init__(self, func):
        """Remember ``func``; it is invoked with each flushed message."""
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Call the stored callable with ``message`` and return its result."""
        deliver = self.func
        return deliver(message)
|
|
|
|
|
|
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Twisted protocol that drives a Thrift client over framed messages.

    Outgoing calls are serialised by the generated client into a
    ``TCallbackTransport`` whose flushed messages are sent with
    ``sendString``. Each incoming frame is decoded and routed to the
    client's matching ``recv_<name>`` method.

    Attributes:
        client: the generated client instance, or ``None`` while no
            connection is established.
        recv_map: cache mapping a message name to the client's bound
            ``recv_<name>`` method.
        started: Deferred fired with the client once the connection is made.
    """

    # Frame-size limit used by the string receiver base class.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        Args:
            client_class: callable invoked as
                ``client_class(transport, oprot_factory)`` on connection.
            iprot_factory: protocol factory used to decode incoming frames.
            oprot_factory: protocol factory for outgoing messages; falls
                back to ``iprot_factory`` when ``None``.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # connectionLost() tests self.client, so the attribute has to exist
        # even when no connection was ever established.
        self.client = None
        self.recv_map = {}
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one serialised message as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Create the client and fire ``started`` with it."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE exception."""
        if not self.client:
            return

        tex = TTransport.TTransportException(
            type=TTransport.TTransportException.END_OF_FILE,
            message='Connection closed (%s)' % reason)
        # The errbacks may register new requests on the client, so keep
        # draining until the mapping is really empty instead of iterating
        # over a snapshot.
        while self.client._reqs:
            _, v = self.client._reqs.popitem()
            v.errback(tex)
        del self.client._reqs
        self.client = None

    def stringReceived(self, frame):
        """Decode one frame and pass it to the client's ``recv_`` method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            # First reply for this message name: look the handler up once
            # and cache the bound method for later frames.
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
|
|
|
|
|
|
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that negotiates SASL before exchanging calls.

    On connection the protocol sends a START message naming the mechanism,
    then trades challenge/response messages with the server until the
    server reports COMPLETE. After that every outgoing message is wrapped
    with ``sasl.wrap`` and every incoming frame unwrapped with
    ``sasl.unwrap``.
    """

    # Status codes carried in the first byte of each negotiation message.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    # Frame-size limit used by the string receiver base class.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.

        The SASL client is only created here when ``host`` is given;
        otherwise ``createSASLClient`` must be called before connecting.
        """
        # Imported lazily so puresasl is only required when SASL is used.
        from puresasl.client import SASLClient
        self.SASLClient = SASLClient
        # Historical misspelling of the attribute above; kept as an alias so
        # code that referenced the old name keeps working.
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Instantiate the SASL client used for negotiation and wrapping."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap ``msg`` with SASL, add a 4-byte length prefix and send it."""
        encoded = self.sasl.wrap(msg)
        # struct.pack returns bytes, so the join has to be a bytes join.
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL handshake, then complete the normal client start-up.

        Raises (through the returned Deferred) ``TTransportException`` when
        the server reports an unexpected status or claims completion while
        the local SASL client is not complete.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process may block, so it runs in a worker thread.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.UNKNOWN,
                        message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(
                    type=TTransport.TTransportException.UNKNOWN,
                    message=msg)

        # Clearing the deferred switches dataReceived/stringReceived from
        # negotiation mode to normal framed-message mode.
        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, length, body.

        ``body`` may be ``None`` (sent as empty), bytes, or text; text is
        encoded as UTF-8 so it can be concatenated with the packed header.
        """
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with ``(status, challenge)`` for the next
        negotiation message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Fail pending requests, but only if a client was ever created."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the SASL status byte during negotiation, then delegate."""
        if self._sasl_negotiation_deferred:
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            # NOTE(review): this treats the first byte of every chunk as the
            # status byte, i.e. it presumes a negotiation message is not
            # split across several dataReceived calls - confirm.
            # A one-byte slice is used because indexing bytes yields an int
            # on Python 3.
            self._sasl_negotiation_status, = struct.unpack("B", data[0:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Route a frame to the negotiation Deferred or to the client."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
|
|
|
|
|
|
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Twisted protocol that feeds length-prefixed frames to a processor.

    The processor and the protocol factories are read from ``self.factory``.
    """

    # Frame-size limit used by the string receiver base class.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one serialised reply as a length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Errback: drop the connection when the processor fails."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Callback: send whatever the processor wrote into ``tmo``.

        Nothing is sent when the output buffer is empty.
        """
        reply = tmo.getvalue()
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        """Process one incoming frame and arrange for the reply to be sent."""
        factory = self.factory
        in_buffer = TTransport.TMemoryBuffer(frame)
        out_buffer = TTransport.TMemoryBuffer()

        in_protocol = factory.iprot_factory.getProtocol(in_buffer)
        out_protocol = factory.oprot_factory.getProtocol(out_buffer)

        outcome = factory.processor.process(in_protocol, out_protocol)
        # addCallbacks (not addCallback + addErrback): processError handles
        # processor failures only, not errors raised inside processOk.
        outcome.addCallbacks(
            self.processOk,
            self.processError,
            callbackArgs=(out_buffer,),
        )
|
|
|
|
|
|
class IThriftServerFactory(Interface):
    """Attributes a server factory provides to the protocols it builds."""

    processor = Attribute("Thrift processor")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")
|
|
|
|
|
|
class IThriftClientFactory(Interface):
    """Attributes a client factory provides to the protocols it builds."""

    client_class = Attribute("Thrift client class")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")
|
|
|
|
|
|
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
    """Server factory carrying the processor and protocol factories."""

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """Store the processor and protocol factories.

        ``oprot_factory`` falls back to ``iprot_factory`` when it is
        ``None``.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory
        )
|
|
|
|
|
|
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
    """Client factory that builds protocols bound to a Thrift client class."""

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        ``oprot_factory`` falls back to ``iprot_factory`` when it is
        ``None``.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory
        )

    def buildProtocol(self, addr):
        """Create a protocol instance and attach this factory to it."""
        built = self.protocol(
            self.client_class,
            self.iprot_factory,
            self.oprot_factory,
        )
        built.factory = self
        return built
|
|
|
|
|
|
class ThriftResource(resource.Resource):
    """Twisted Web resource that serves a Thrift processor over HTTP POST.

    The request body is treated as one serialised Thrift message; the
    processor's output is returned as the response body.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        """Store the processor and protocol factories.

        Args:
            processor: object whose ``process(iprot, oprot)`` returns a
                Deferred.
            inputProtocolFactory: factory used to decode the request body.
            outputProtocolFactory: factory used to encode the response;
                falls back to ``inputProtocolFactory`` when ``None``.
        """
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        """Serve every child path with this same resource."""
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor output as a 200 response and finish."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and answer with a 500 response.

        Without this errback a failed Deferred would leave the request
        unfinished, so the HTTP client would wait indefinitely.
        """
        # Imported locally so the module's import block stays untouched.
        from twisted.python import log
        log.err(failure, "Thrift processor failed while handling an HTTP request")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        """Run the processor on the request body; respond asynchronously."""
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # addCallbacks keeps the two paths exclusive: _ebProcess only sees
        # processor failures, never errors raised while writing the reply,
        # so the request is not finished twice.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET
|