# ProtocolBuffer: legacy pure-Python protocol buffer runtime (App Engine / Cloud SDK).
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2002, Google LLC.
|
from __future__ import absolute_import
|
|
import array
|
|
import itertools
|
|
import re
|
|
import struct
|
|
import six
|
|
import six.moves.http_client
|
|
|
|
try:
|
|
# NOTE(user): Using non-google-style import to workaround a zipimport_tinypar
|
|
# issue for zip files embedded in par files. See http://b/13811096
|
|
import googlecloudsdk.appengine.proto.proto1 as proto1
|
|
except ImportError:
|
|
# Protect in case of missing deps / strange env (GAE?) / etc.
|
|
class ProtocolBufferDecodeError(Exception): pass
|
|
class ProtocolBufferEncodeError(Exception): pass
|
|
class ProtocolBufferReturnError(Exception): pass
|
|
else:
|
|
ProtocolBufferDecodeError = proto1.ProtocolBufferDecodeError
|
|
ProtocolBufferEncodeError = proto1.ProtocolBufferEncodeError
|
|
ProtocolBufferReturnError = proto1.ProtocolBufferReturnError
|
|
|
|
__all__ = ['ProtocolMessage', 'Encoder', 'Decoder',
|
|
'ExtendableProtocolMessage',
|
|
'ProtocolBufferDecodeError',
|
|
'ProtocolBufferEncodeError',
|
|
'ProtocolBufferReturnError']
|
|
|
|
URL_RE = re.compile('^(https?)://([^/]+)(/.*)$')
|
|
|
|
|
|
class ProtocolMessage:
|
|
"""
|
|
The parent class of all protocol buffers.
|
|
NOTE: the methods that unconditionally raise NotImplementedError are
|
|
reimplemented by the subclasses of this class.
|
|
Subclasses are automatically generated by tools/protocol_converter.
|
|
Encoding methods can raise ProtocolBufferEncodeError if a value for an
|
|
integer or long field is too large, or if any required field is not set.
|
|
Decoding methods can raise ProtocolBufferDecodeError if they couldn't
|
|
decode correctly, or the decoded message doesn't have all required fields.
|
|
"""
|
|
|
|
#####################################
|
|
# methods you should use #
|
|
#####################################
|
|
|
|
def __init__(self, contents=None):
|
|
"""Construct a new protocol buffer, with optional starting contents
|
|
in binary protocol buffer format."""
|
|
raise NotImplementedError
|
|
|
|
def Clear(self):
|
|
"""Erases all fields of protocol buffer (& resets to defaults
|
|
if fields have defaults)."""
|
|
raise NotImplementedError
|
|
|
|
def IsInitialized(self, debug_strs=None):
|
|
"""returns true iff all required fields have been set."""
|
|
raise NotImplementedError
|
|
|
|
def Encode(self):
|
|
"""Returns a string representing the protocol buffer object."""
|
|
try:
|
|
return self._CEncode()
|
|
except (NotImplementedError, AttributeError):
|
|
e = Encoder()
|
|
self.Output(e)
|
|
return e.buffer().tostring()
|
|
|
|
def SerializeToString(self):
|
|
"""Same as Encode(), but has same name as proto2's serialize function."""
|
|
return self.Encode()
|
|
|
|
def SerializePartialToString(self):
|
|
"""Returns a string representing the protocol buffer object.
|
|
Same as SerializeToString() but does not enforce required fields are set.
|
|
"""
|
|
try:
|
|
return self._CEncodePartial()
|
|
except (NotImplementedError, AttributeError):
|
|
e = Encoder()
|
|
self.OutputPartial(e)
|
|
return e.buffer().tostring()
|
|
|
|
def _CEncode(self):
|
|
"""Call into C++ encode code.
|
|
|
|
Generated protocol buffer classes will override this method to
|
|
provide C++-based serialization. If a subclass does not
|
|
implement this method, Encode() will fall back to
|
|
using pure-Python encoding.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def _CEncodePartial(self):
|
|
"""Same as _CEncode, except does not encode missing required fields."""
|
|
raise NotImplementedError
|
|
|
|
def ParseFromString(self, s):
|
|
"""Reads data from the string 's'.
|
|
Raises a ProtocolBufferDecodeError if, after successfully reading
|
|
in the contents of 's', this protocol message is still not initialized."""
|
|
self.Clear()
|
|
self.MergeFromString(s)
|
|
|
|
def ParsePartialFromString(self, s):
|
|
"""Reads data from the string 's'.
|
|
Does not enforce required fields are set."""
|
|
self.Clear()
|
|
self.MergePartialFromString(s)
|
|
|
|
def MergeFromString(self, s):
|
|
"""Adds in data from the string 's'.
|
|
Raises a ProtocolBufferDecodeError if, after successfully merging
|
|
in the contents of 's', this protocol message is still not initialized."""
|
|
self.MergePartialFromString(s)
|
|
dbg = []
|
|
if not self.IsInitialized(dbg):
|
|
raise ProtocolBufferDecodeError('\n\t'.join(dbg))
|
|
|
|
def MergePartialFromString(self, s):
|
|
"""Merges in data from the string 's'.
|
|
Does not enforce required fields are set."""
|
|
try:
|
|
self._CMergeFromString(s)
|
|
except (NotImplementedError, AttributeError):
|
|
# If we can't call into C++ to deserialize the string, use
|
|
# the (much slower) pure-Python implementation.
|
|
a = array.array('B')
|
|
a.fromstring(s)
|
|
d = Decoder(a, 0, len(a))
|
|
self.TryMerge(d)
|
|
|
|
def _CMergeFromString(self, s):
|
|
"""Call into C++ parsing code to merge from a string.
|
|
|
|
Does *not* check IsInitialized() before returning.
|
|
|
|
Generated protocol buffer classes will override this method to
|
|
provide C++-based deserialization. If a subclass does not
|
|
implement this method, MergeFromString() will fall back to
|
|
using pure-Python parsing.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def __getstate__(self):
|
|
"""Return the pickled representation of the data inside protocol buffer,
|
|
which is the same as its binary-encoded representation (as a string)."""
|
|
return self.Encode()
|
|
|
|
def __setstate__(self, contents_):
|
|
"""Restore the pickled representation of the data inside protocol buffer.
|
|
Note that the mechanism underlying pickle.load() does not call __init__."""
|
|
self.__init__(contents=contents_)
|
|
|
|
def sendCommand(self, server, url, response, follow_redirects=1,
|
|
secure=0, keyfile=None, certfile=None):
|
|
"""posts the protocol buffer to the desired url on the server
|
|
and puts the return data into the protocol buffer 'response'
|
|
|
|
NOTE: The underlying socket raises the 'error' exception
|
|
for all I/O related errors (can't connect, etc.).
|
|
|
|
If 'response' is None, the server's PB response will be ignored.
|
|
|
|
The optional 'follow_redirects' argument indicates the number
|
|
of HTTP redirects that are followed before giving up and raising an
|
|
exception. The default is 1.
|
|
|
|
If 'secure' is true, HTTPS will be used instead of HTTP. Also,
|
|
'keyfile' and 'certfile' may be set for client authentication.
|
|
"""
|
|
data = self.Encode()
|
|
if secure:
|
|
if keyfile and certfile:
|
|
conn = six.moves.http_client.HTTPSConnection(server, key_file=keyfile,
|
|
cert_file=certfile)
|
|
else:
|
|
conn = six.moves.http_client.HTTPSConnection(server)
|
|
else:
|
|
conn = six.moves.http_client.HTTPConnection(server)
|
|
conn.putrequest("POST", url)
|
|
conn.putheader("Content-Length", "%d" %len(data))
|
|
conn.endheaders()
|
|
conn.send(data)
|
|
resp = conn.getresponse()
|
|
if follow_redirects > 0 and resp.status == 302:
|
|
m = URL_RE.match(resp.getheader('Location'))
|
|
if m:
|
|
protocol, server, url = m.groups()
|
|
return self.sendCommand(server, url, response,
|
|
follow_redirects=follow_redirects - 1,
|
|
secure=(protocol == 'https'),
|
|
keyfile=keyfile,
|
|
certfile=certfile)
|
|
if resp.status != 200:
|
|
raise ProtocolBufferReturnError(resp.status)
|
|
if response is not None:
|
|
response.ParseFromString(resp.read())
|
|
return response
|
|
|
|
def sendSecureCommand(self, server, keyfile, certfile, url, response,
|
|
follow_redirects=1):
|
|
"""posts the protocol buffer via https to the desired url on the server,
|
|
using the specified key and certificate files, and puts the return
|
|
data int othe protocol buffer 'response'.
|
|
|
|
See caveats in sendCommand.
|
|
|
|
You need an SSL-aware build of the Python2 interpreter to use this command.
|
|
(Python1 is not supported). An SSL build of python2.2 is in
|
|
/home/build/buildtools/python-ssl-2.2 . An SSL build of python is
|
|
standard on all prod machines.
|
|
|
|
keyfile: Contains our private RSA key
|
|
certfile: Contains SSL certificate for remote host
|
|
Specify None for keyfile/certfile if you don't want to do client auth.
|
|
"""
|
|
return self.sendCommand(server, url, response,
|
|
follow_redirects=follow_redirects,
|
|
secure=1, keyfile=keyfile, certfile=certfile)
|
|
|
|
def __str__(self, prefix="", printElemNumber=0):
|
|
"""Returns nicely formatted contents of this protocol buffer."""
|
|
raise NotImplementedError
|
|
|
|
def ToASCII(self):
|
|
"""Returns the protocol buffer as a human-readable string."""
|
|
return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII)
|
|
|
|
def ToShortASCII(self):
|
|
"""Returns the protocol buffer as an ASCII string.
|
|
The output is short, leaving out newlines and some other niceties.
|
|
Defers to the C++ ProtocolPrinter class in SYMBOLIC_SHORT mode.
|
|
"""
|
|
return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)
|
|
|
|
# Note that these must be consistent with the ProtocolPrinter::Level C++
|
|
# enum.
|
|
_NUMERIC_ASCII = 0
|
|
_SYMBOLIC_SHORT_ASCII = 1
|
|
_SYMBOLIC_FULL_ASCII = 2
|
|
|
|
def _CToASCII(self, output_format):
|
|
"""Calls into C++ ASCII-generating code.
|
|
|
|
Generated protocol buffer classes will override this method to provide
|
|
C++-based ASCII output.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def ParseASCII(self, ascii_string):
|
|
"""Parses a string generated by ToASCII() or by the C++ DebugString()
|
|
method, initializing this protocol buffer with its contents. This method
|
|
raises a ValueError if it encounters an unknown field.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def ParseASCIIIgnoreUnknown(self, ascii_string):
|
|
"""Parses a string generated by ToASCII() or by the C++ DebugString()
|
|
method, initializing this protocol buffer with its contents. Ignores
|
|
unknown fields.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def Equals(self, other):
|
|
"""Returns whether or not this protocol buffer is equivalent to another.
|
|
|
|
This assumes that self and other are of the same type.
|
|
"""
|
|
raise NotImplementedError
|
|
|
|
def __eq__(self, other):
|
|
"""Implementation of operator ==."""
|
|
# If self and other are of different types we return NotImplemented, which
|
|
# tells the Python interpreter to try some other methods of measuring
|
|
# equality before finally performing an identity comparison. This allows
|
|
# other classes to implement custom __eq__ or __ne__ methods.
|
|
# See http://docs.sympy.org/_sources/python-comparisons.txt
|
|
if other.__class__ is self.__class__:
|
|
return self.Equals(other)
|
|
return NotImplemented
|
|
|
|
def __ne__(self, other):
|
|
"""Implementation of operator !=."""
|
|
# We repeat code for __ne__ instead of returning "not (self == other)"
|
|
# so that we can return NotImplemented when comparing against an object of
|
|
# a different type.
|
|
# See http://bugs.python.org/msg76374 for an example of when __ne__ might
|
|
# return something other than the Boolean opposite of __eq__.
|
|
if other.__class__ is self.__class__:
|
|
return not self.Equals(other)
|
|
return NotImplemented
|
|
|
|
#####################################
|
|
# methods power-users might want #
|
|
#####################################
|
|
|
|
def Output(self, e):
|
|
"""write self to the encoder 'e'."""
|
|
dbg = []
|
|
if not self.IsInitialized(dbg):
|
|
raise ProtocolBufferEncodeError('\n\t'.join(dbg))
|
|
self.OutputUnchecked(e)
|
|
return
|
|
|
|
def OutputUnchecked(self, e):
|
|
"""write self to the encoder 'e', don't check for initialization."""
|
|
raise NotImplementedError
|
|
|
|
def OutputPartial(self, e):
|
|
"""write self to the encoder 'e', don't check for initialization and
|
|
don't assume required fields exist."""
|
|
raise NotImplementedError
|
|
|
|
def Parse(self, d):
|
|
"""reads data from the Decoder 'd'."""
|
|
self.Clear()
|
|
self.Merge(d)
|
|
return
|
|
|
|
def Merge(self, d):
|
|
"""merges data from the Decoder 'd'."""
|
|
self.TryMerge(d)
|
|
dbg = []
|
|
if not self.IsInitialized(dbg):
|
|
raise ProtocolBufferDecodeError('\n\t'.join(dbg))
|
|
return
|
|
|
|
def TryMerge(self, d):
|
|
"""merges data from the Decoder 'd'."""
|
|
raise NotImplementedError
|
|
|
|
def CopyFrom(self, pb):
|
|
"""copy data from another protocol buffer"""
|
|
if (pb == self): return
|
|
self.Clear()
|
|
self.MergeFrom(pb)
|
|
|
|
def MergeFrom(self, pb):
|
|
"""merge data from another protocol buffer"""
|
|
raise NotImplementedError
|
|
|
|
#####################################
|
|
# helper methods for subclasses #
|
|
#####################################
|
|
|
|
def lengthVarInt32(self, n):
|
|
return self.lengthVarInt64(n)
|
|
|
|
def lengthVarInt64(self, n):
|
|
if n < 0:
|
|
return 10 # ceil(64/7)
|
|
result = 0
|
|
while 1:
|
|
result += 1
|
|
n >>= 7
|
|
if n == 0:
|
|
break
|
|
return result
|
|
|
|
def lengthString(self, n):
|
|
return self.lengthVarInt32(n) + n
|
|
|
|
def DebugFormat(self, value):
|
|
return "%s" % value
|
|
def DebugFormatInt32(self, value):
|
|
if (value <= -2000000000 or value >= 2000000000):
|
|
return self.DebugFormatFixed32(value)
|
|
return "%d" % value
|
|
def DebugFormatInt64(self, value):
|
|
if (value <= -20000000000000 or value >= 20000000000000):
|
|
return self.DebugFormatFixed64(value)
|
|
return "%d" % value
|
|
def DebugFormatString(self, value):
|
|
# For now we only escape the bare minimum to insure interoperability
|
|
# and redability. In the future we may want to mimick the c++ behavior
|
|
# more closely, but this will make the code a lot more messy.
|
|
def escape(c):
|
|
o = ord(c)
|
|
if o == 10: return r"\n" # optional escape
|
|
if o == 39: return r"\'" # optional escape
|
|
|
|
if o == 34: return r'\"' # necessary escape
|
|
if o == 92: return r"\\" # necessary escape
|
|
|
|
if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
|
|
return c
|
|
return '"' + "".join(escape(c) for c in value) + '"'
|
|
def DebugFormatFloat(self, value):
|
|
return "%ff" % value
|
|
def DebugFormatFixed32(self, value):
|
|
if (value < 0): value += (1<<32)
|
|
return "0x%x" % value
|
|
def DebugFormatFixed64(self, value):
|
|
if (value < 0): value += (1<<64)
|
|
return "0x%x" % value
|
|
def DebugFormatBool(self, value):
|
|
if value:
|
|
return "true"
|
|
else:
|
|
return "false"
|
|
|
|
# types of fields, must match Proto::Type and net/proto/protocoltype.proto
|
|
TYPE_DOUBLE = 1
|
|
TYPE_FLOAT = 2
|
|
TYPE_INT64 = 3
|
|
TYPE_UINT64 = 4
|
|
TYPE_INT32 = 5
|
|
TYPE_FIXED64 = 6
|
|
TYPE_FIXED32 = 7
|
|
TYPE_BOOL = 8
|
|
TYPE_STRING = 9
|
|
TYPE_GROUP = 10
|
|
TYPE_FOREIGN = 11
|
|
|
|
# debug string for extensions
|
|
_TYPE_TO_DEBUG_STRING = {
|
|
TYPE_INT32: ProtocolMessage.DebugFormatInt32,
|
|
TYPE_INT64: ProtocolMessage.DebugFormatInt64,
|
|
TYPE_UINT64: ProtocolMessage.DebugFormatInt64,
|
|
TYPE_FLOAT: ProtocolMessage.DebugFormatFloat,
|
|
TYPE_STRING: ProtocolMessage.DebugFormatString,
|
|
TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32,
|
|
TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64,
|
|
TYPE_BOOL: ProtocolMessage.DebugFormatBool }
|
|
|
|
# users of protocol buffers usually won't need to concern themselves
|
|
# with either Encoders or Decoders.
|
|
class Encoder:
|
|
|
|
# types of data
|
|
NUMERIC = 0
|
|
DOUBLE = 1
|
|
STRING = 2
|
|
STARTGROUP = 3
|
|
ENDGROUP = 4
|
|
FLOAT = 5
|
|
MAX_TYPE = 6
|
|
|
|
def __init__(self):
|
|
self.buf = array.array('B')
|
|
return
|
|
|
|
def buffer(self):
|
|
return self.buf
|
|
|
|
def put8(self, v):
|
|
if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError("u8 too big")
|
|
self.buf.append(v & 255)
|
|
return
|
|
|
|
def put16(self, v):
|
|
if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError("u16 too big")
|
|
self.buf.append((v >> 0) & 255)
|
|
self.buf.append((v >> 8) & 255)
|
|
return
|
|
|
|
def put32(self, v):
|
|
if v < 0 or v >= (1<<32): raise ProtocolBufferEncodeError("u32 too big")
|
|
self.buf.append((v >> 0) & 255)
|
|
self.buf.append((v >> 8) & 255)
|
|
self.buf.append((v >> 16) & 255)
|
|
self.buf.append((v >> 24) & 255)
|
|
return
|
|
|
|
def put64(self, v):
|
|
if v < 0 or v >= (1<<64): raise ProtocolBufferEncodeError("u64 too big")
|
|
self.buf.append((v >> 0) & 255)
|
|
self.buf.append((v >> 8) & 255)
|
|
self.buf.append((v >> 16) & 255)
|
|
self.buf.append((v >> 24) & 255)
|
|
self.buf.append((v >> 32) & 255)
|
|
self.buf.append((v >> 40) & 255)
|
|
self.buf.append((v >> 48) & 255)
|
|
self.buf.append((v >> 56) & 255)
|
|
return
|
|
|
|
def putVarInt32(self, v):
|
|
# Profiling has shown this code to be very performance critical
|
|
# so we duplicate code, go for early exits when possible, etc.
|
|
# VarInt32 gets more unrolling because VarInt32s are far and away
|
|
# the most common element in protobufs (field tags and string
|
|
# lengths), so they get more attention. They're also more
|
|
# likely to fit in one byte (string lengths again), so we
|
|
# check and bail out early if possible.
|
|
|
|
buf_append = self.buf.append # cache attribute lookup
|
|
if v & 127 == v:
|
|
buf_append(v)
|
|
return
|
|
if v >= 0x80000000 or v < -0x80000000: # python2.4 doesn't fold constants
|
|
raise ProtocolBufferEncodeError("int32 too big")
|
|
if v < 0:
|
|
v += 0x10000000000000000
|
|
while True:
|
|
bits = v & 127
|
|
v >>= 7
|
|
if v:
|
|
bits |= 128
|
|
buf_append(bits)
|
|
if not v:
|
|
break
|
|
return
|
|
|
|
def putVarInt64(self, v):
|
|
buf_append = self.buf.append
|
|
if v >= 0x8000000000000000 or v < -0x8000000000000000:
|
|
raise ProtocolBufferEncodeError("int64 too big")
|
|
if v < 0:
|
|
v += 0x10000000000000000
|
|
while True:
|
|
bits = v & 127
|
|
v >>= 7
|
|
if v:
|
|
bits |= 128
|
|
buf_append(bits)
|
|
if not v:
|
|
break
|
|
return
|
|
|
|
def putVarUint64(self, v):
|
|
buf_append = self.buf.append
|
|
if v < 0 or v >= 0x10000000000000000:
|
|
raise ProtocolBufferEncodeError("uint64 too big")
|
|
while True:
|
|
bits = v & 127
|
|
v >>= 7
|
|
if v:
|
|
bits |= 128
|
|
buf_append(bits)
|
|
if not v:
|
|
break
|
|
return
|
|
|
|
def putFloat(self, v):
|
|
a = array.array('B')
|
|
a.fromstring(struct.pack("<f", v))
|
|
self.buf.extend(a)
|
|
return
|
|
|
|
def putDouble(self, v):
|
|
a = array.array('B')
|
|
a.fromstring(struct.pack("<d", v))
|
|
self.buf.extend(a)
|
|
return
|
|
|
|
def putBoolean(self, v):
|
|
if v:
|
|
self.buf.append(1)
|
|
else:
|
|
self.buf.append(0)
|
|
return
|
|
|
|
def putPrefixedString(self, v):
|
|
# This change prevents corrupted encoding an YouTube, where
|
|
# our default encoding is utf-8 and unicode strings may occasionally be
|
|
# passed into ProtocolBuffers.
|
|
v = str(v)
|
|
self.putVarInt32(len(v))
|
|
self.buf.fromstring(v)
|
|
return
|
|
|
|
def putRawString(self, v):
|
|
self.buf.fromstring(v)
|
|
|
|
_TYPE_TO_METHOD = {
|
|
TYPE_DOUBLE: putDouble,
|
|
TYPE_FLOAT: putFloat,
|
|
TYPE_FIXED64: put64,
|
|
TYPE_FIXED32: put32,
|
|
TYPE_INT32: putVarInt32,
|
|
TYPE_INT64: putVarInt64,
|
|
TYPE_UINT64: putVarUint64,
|
|
TYPE_BOOL: putBoolean,
|
|
TYPE_STRING: putPrefixedString }
|
|
|
|
_TYPE_TO_BYTE_SIZE = {
|
|
TYPE_DOUBLE: 8,
|
|
TYPE_FLOAT: 4,
|
|
TYPE_FIXED64: 8,
|
|
TYPE_FIXED32: 4,
|
|
TYPE_BOOL: 1 }
|
|
|
|
class Decoder:
|
|
def __init__(self, buf, idx, limit):
|
|
self.buf = buf
|
|
self.idx = idx
|
|
self.limit = limit
|
|
return
|
|
|
|
def avail(self):
|
|
return self.limit - self.idx
|
|
|
|
def buffer(self):
|
|
return self.buf
|
|
|
|
def pos(self):
|
|
return self.idx
|
|
|
|
def skip(self, n):
|
|
if self.idx + n > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
self.idx += n
|
|
return
|
|
|
|
def skipData(self, tag):
|
|
t = tag & 7 # tag format type
|
|
if t == Encoder.NUMERIC:
|
|
self.getVarInt64()
|
|
elif t == Encoder.DOUBLE:
|
|
self.skip(8)
|
|
elif t == Encoder.STRING:
|
|
n = self.getVarInt32()
|
|
self.skip(n)
|
|
elif t == Encoder.STARTGROUP:
|
|
while 1:
|
|
t = self.getVarInt32()
|
|
if (t & 7) == Encoder.ENDGROUP:
|
|
break
|
|
else:
|
|
self.skipData(t)
|
|
if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP):
|
|
raise ProtocolBufferDecodeError("corrupted")
|
|
elif t == Encoder.ENDGROUP:
|
|
raise ProtocolBufferDecodeError("corrupted")
|
|
elif t == Encoder.FLOAT:
|
|
self.skip(4)
|
|
else:
|
|
raise ProtocolBufferDecodeError("corrupted")
|
|
|
|
# these are all unsigned gets
|
|
def get8(self):
|
|
if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
c = self.buf[self.idx]
|
|
self.idx += 1
|
|
return c
|
|
|
|
def get16(self):
|
|
if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
c = self.buf[self.idx]
|
|
d = self.buf[self.idx + 1]
|
|
self.idx += 2
|
|
return (d << 8) | c
|
|
|
|
def get32(self):
|
|
if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
c = self.buf[self.idx]
|
|
d = self.buf[self.idx + 1]
|
|
e = self.buf[self.idx + 2]
|
|
f = int(self.buf[self.idx + 3])
|
|
self.idx += 4
|
|
return (f << 24) | (e << 16) | (d << 8) | c
|
|
|
|
def get64(self):
|
|
if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
c = self.buf[self.idx]
|
|
d = self.buf[self.idx + 1]
|
|
e = self.buf[self.idx + 2]
|
|
f = int(self.buf[self.idx + 3])
|
|
g = int(self.buf[self.idx + 4])
|
|
h = int(self.buf[self.idx + 5])
|
|
i = int(self.buf[self.idx + 6])
|
|
j = int(self.buf[self.idx + 7])
|
|
self.idx += 8
|
|
return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24)
|
|
| (e << 16) | (d << 8) | c)
|
|
|
|
def getVarInt32(self):
|
|
# getVarInt32 gets different treatment than other integer getter
|
|
# functions due to the much larger number of varInt32s and also
|
|
# varInt32s that fit in one byte. See the comment at putVarInt32.
|
|
b = self.get8()
|
|
if not (b & 128):
|
|
return b
|
|
|
|
result = int(0)
|
|
shift = 0
|
|
|
|
while 1:
|
|
result |= (int(b & 127) << shift)
|
|
shift += 7
|
|
if not (b & 128):
|
|
if result >= 0x10000000000000000: # (1L << 64):
|
|
raise ProtocolBufferDecodeError("corrupted")
|
|
break
|
|
if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
|
|
b = self.get8()
|
|
|
|
if result >= 0x8000000000000000: # (1L << 63)
|
|
result -= 0x10000000000000000 # (1L << 64)
|
|
if result >= 0x80000000 or result < -0x80000000: # (1L << 31)
|
|
raise ProtocolBufferDecodeError("corrupted")
|
|
return result
|
|
|
|
def getVarInt64(self):
|
|
result = self.getVarUint64()
|
|
if result >= (1 << 63):
|
|
result -= (1 << 64)
|
|
return result
|
|
|
|
def getVarUint64(self):
|
|
result = int(0)
|
|
shift = 0
|
|
while 1:
|
|
if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
|
|
b = self.get8()
|
|
result |= (int(b & 127) << shift)
|
|
shift += 7
|
|
if not (b & 128):
|
|
if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted")
|
|
return result
|
|
return result # make pychecker happy
|
|
|
|
def getFloat(self):
|
|
if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
a = self.buf[self.idx:self.idx+4]
|
|
self.idx += 4
|
|
return struct.unpack("<f", a)[0]
|
|
|
|
def getDouble(self):
|
|
if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
|
|
a = self.buf[self.idx:self.idx+8]
|
|
self.idx += 8
|
|
return struct.unpack("<d", a)[0]
|
|
|
|
def getBoolean(self):
|
|
b = self.get8()
|
|
if b != 0 and b != 1: raise ProtocolBufferDecodeError("corrupted")
|
|
return b
|
|
|
|
def getPrefixedString(self):
|
|
length = self.getVarInt32()
|
|
if self.idx + length > self.limit:
|
|
raise ProtocolBufferDecodeError("truncated")
|
|
r = self.buf[self.idx : self.idx + length]
|
|
self.idx += length
|
|
return r.tostring()
|
|
|
|
def getRawString(self):
|
|
r = self.buf[self.idx:self.limit]
|
|
self.idx = self.limit
|
|
return r.tostring()
|
|
|
|
_TYPE_TO_METHOD = {
|
|
TYPE_DOUBLE: getDouble,
|
|
TYPE_FLOAT: getFloat,
|
|
TYPE_FIXED64: get64,
|
|
TYPE_FIXED32: get32,
|
|
TYPE_INT32: getVarInt32,
|
|
TYPE_INT64: getVarInt64,
|
|
TYPE_UINT64: getVarUint64,
|
|
TYPE_BOOL: getBoolean,
|
|
TYPE_STRING: getPrefixedString }
|
|
|
|
#####################################
|
|
# extensions #
|
|
#####################################
|
|
|
|
class ExtensionIdentifier(object):
  """Identity and wire metadata describing one protocol-buffer extension.

  The constructor fills the first six slots; containing_cls, composite_cls
  and message_name are assigned later by generated code.
  """
  __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
               'default', 'containing_cls', 'composite_cls', 'message_name')

  def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
               default):
    ctor_slots = ('full_name', 'number', 'field_type', 'wire_tag',
                  'is_repeated', 'default')
    ctor_values = (full_name, number, field_type, wire_tag, is_repeated,
                   default)
    for slot_name, slot_value in zip(ctor_slots, ctor_values):
      setattr(self, slot_name, slot_value)
|
|
|
|
class ExtendableProtocolMessage(ProtocolMessage):
|
|
def HasExtension(self, extension):
|
|
"""Checks if the message contains a certain non-repeated extension."""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
return extension in self._extension_fields
|
|
|
|
def ClearExtension(self, extension):
|
|
"""Clears the value of extension, so that HasExtension() returns false or
|
|
ExtensionSize() returns 0."""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if extension in self._extension_fields:
|
|
del self._extension_fields[extension]
|
|
|
|
def GetExtension(self, extension, index=None):
|
|
"""Gets the extension value for a certain extension.
|
|
|
|
Args:
|
|
extension: The ExtensionIdentifier for the extension.
|
|
index: The index of element to get in a repeated field. Only needed if
|
|
the extension is repeated.
|
|
|
|
Returns:
|
|
The value of the extension if exists, otherwise the default value of the
|
|
extension will be returned.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if extension in self._extension_fields:
|
|
result = self._extension_fields[extension]
|
|
else:
|
|
if extension.is_repeated:
|
|
result = []
|
|
elif extension.composite_cls:
|
|
result = extension.composite_cls()
|
|
else:
|
|
result = extension.default
|
|
if extension.is_repeated:
|
|
result = result[index]
|
|
return result
|
|
|
|
def SetExtension(self, extension, *args):
|
|
"""Sets the extension value for a certain scalar type extension.
|
|
|
|
Arg varies according to extension type:
|
|
- Singular:
|
|
message.SetExtension(extension, value)
|
|
- Repeated:
|
|
message.SetExtension(extension, index, value)
|
|
where
|
|
extension: The ExtensionIdentifier for the extension.
|
|
index: The index of element to set in a repeated field. Only needed if
|
|
the extension is repeated.
|
|
value: The value to set.
|
|
|
|
Raises:
|
|
TypeError if a message type extension is given.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if extension.composite_cls:
|
|
raise TypeError(
|
|
'Cannot assign to extension "%s" because it is a composite type.' %
|
|
extension.full_name)
|
|
if extension.is_repeated:
|
|
try:
|
|
index, value = args
|
|
except ValueError:
|
|
raise TypeError(
|
|
"SetExtension(extension, index, value) for repeated extension "
|
|
"takes exactly 4 arguments: (%d given)" % (len(args) + 2))
|
|
self._extension_fields[extension][index] = value
|
|
else:
|
|
try:
|
|
(value,) = args
|
|
except ValueError:
|
|
raise TypeError(
|
|
"SetExtension(extension, value) for singular extension "
|
|
"takes exactly 3 arguments: (%d given)" % (len(args) + 2))
|
|
self._extension_fields[extension] = value
|
|
|
|
def MutableExtension(self, extension, index=None):
|
|
"""Gets a mutable reference of a message type extension.
|
|
|
|
For repeated extension, index must be specified, and only one element will
|
|
be returned. For optional extension, if the extension does not exist, a new
|
|
message will be created and set in parent message.
|
|
|
|
Args:
|
|
extension: The ExtensionIdentifier for the extension.
|
|
index: The index of element to mutate in a repeated field. Only needed if
|
|
the extension is repeated.
|
|
|
|
Returns:
|
|
The mutable message reference.
|
|
|
|
Raises:
|
|
TypeError if non-message type extension is given.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if extension.composite_cls is None:
|
|
raise TypeError(
|
|
'MutableExtension() cannot be applied to "%s", because it is not a '
|
|
'composite type.' % extension.full_name)
|
|
if extension.is_repeated:
|
|
if index is None:
|
|
raise TypeError(
|
|
'MutableExtension(extension, index) for repeated extension '
|
|
'takes exactly 2 arguments: (1 given)')
|
|
return self.GetExtension(extension, index)
|
|
if extension in self._extension_fields:
|
|
return self._extension_fields[extension]
|
|
else:
|
|
result = extension.composite_cls()
|
|
self._extension_fields[extension] = result
|
|
return result
|
|
|
|
def ExtensionList(self, extension):
|
|
"""Returns a mutable list of extensions.
|
|
|
|
Raises:
|
|
TypeError if the extension is not repeated.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if not extension.is_repeated:
|
|
raise TypeError(
|
|
'ExtensionList() cannot be applied to "%s", because it is not a '
|
|
'repeated extension.' % extension.full_name)
|
|
if extension in self._extension_fields:
|
|
return self._extension_fields[extension]
|
|
result = []
|
|
self._extension_fields[extension] = result
|
|
return result
|
|
|
|
def ExtensionSize(self, extension):
|
|
"""Returns the size of a repeated extension.
|
|
|
|
Raises:
|
|
TypeError if the extension is not repeated.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if not extension.is_repeated:
|
|
raise TypeError(
|
|
'ExtensionSize() cannot be applied to "%s", because it is not a '
|
|
'repeated extension.' % extension.full_name)
|
|
if extension in self._extension_fields:
|
|
return len(self._extension_fields[extension])
|
|
return 0
|
|
|
|
def AddExtension(self, extension, value=None):
|
|
"""Appends a new element into a repeated extension.
|
|
|
|
Arg varies according to the extension field type:
|
|
- Scalar/String:
|
|
message.AddExtension(extension, value)
|
|
- Message:
|
|
mutable_message = AddExtension(extension)
|
|
|
|
Args:
|
|
extension: The ExtensionIdentifier for the extension.
|
|
value: The value of the extension if the extension is scalar/string type.
|
|
The value must NOT be set for message type extensions; set values on
|
|
the returned message object instead.
|
|
|
|
Returns:
|
|
A mutable new message if it's a message type extension, or None otherwise.
|
|
|
|
Raises:
|
|
TypeError if the extension is not repeated, or value is given for message
|
|
type extensions.
|
|
"""
|
|
self._VerifyExtensionIdentifier(extension)
|
|
if not extension.is_repeated:
|
|
raise TypeError(
|
|
'AddExtension() cannot be applied to "%s", because it is not a '
|
|
'repeated extension.' % extension.full_name)
|
|
if extension in self._extension_fields:
|
|
field = self._extension_fields[extension]
|
|
else:
|
|
field = []
|
|
self._extension_fields[extension] = field
|
|
# Composite field
|
|
if extension.composite_cls:
|
|
if value is not None:
|
|
raise TypeError(
|
|
'value must not be set in AddExtension() for "%s", because it is '
|
|
'a message type extension. Set values on the returned message '
|
|
'instead.' % extension.full_name)
|
|
msg = extension.composite_cls()
|
|
field.append(msg)
|
|
return msg
|
|
# Scalar and string field
|
|
field.append(value)
|
|
|
|
def _VerifyExtensionIdentifier(self, extension):
|
|
if extension.containing_cls != self.__class__:
|
|
raise TypeError("Containing type of %s is %s, but not %s."
|
|
% (extension.full_name,
|
|
extension.containing_cls.__name__,
|
|
self.__class__.__name__))
|
|
|
|
def _MergeExtensionFields(self, x):
|
|
for ext, val in x._extension_fields.items():
|
|
if ext.is_repeated:
|
|
for single_val in val:
|
|
if ext.composite_cls is None:
|
|
self.AddExtension(ext, single_val)
|
|
else:
|
|
self.AddExtension(ext).MergeFrom(single_val)
|
|
else:
|
|
if ext.composite_cls is None:
|
|
self.SetExtension(ext, val)
|
|
else:
|
|
self.MutableExtension(ext).MergeFrom(val)
|
|
|
|
def _ListExtensions(self):
|
|
return sorted(
|
|
(ext for ext in self._extension_fields
|
|
if (not ext.is_repeated) or self.ExtensionSize(ext) > 0),
|
|
key=lambda item: item.number)
|
|
|
|
def _ExtensionEquals(self, x):
|
|
extensions = self._ListExtensions()
|
|
if extensions != x._ListExtensions():
|
|
return False
|
|
for ext in extensions:
|
|
if ext.is_repeated:
|
|
if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
|
|
for e1, e2 in zip(self.ExtensionList(ext),
|
|
x.ExtensionList(ext)):
|
|
if e1 != e2: return False
|
|
else:
|
|
if self.GetExtension(ext) != x.GetExtension(ext): return False
|
|
return True
|
|
|
|
def _OutputExtensionFields(self, out, partial, extensions, start_index,
|
|
end_field_number):
|
|
"""Serialize a range of extensions.
|
|
|
|
To generate canonical output when encoding, we interleave fields and
|
|
extensions to preserve tag order.
|
|
|
|
Generated code will prepare a list of ExtensionIdentifier sorted in field
|
|
number order and call this method to serialize a specific range of
|
|
extensions. The range is specified by the two arguments, start_index and
|
|
end_field_number.
|
|
|
|
The method will serialize all extensions[i] with i >= start_index and
|
|
extensions[i].number < end_field_number. Since extensions argument is sorted
|
|
by field_number, this is a contiguous range; the first index j not included
|
|
in that range is returned. The return value can be used as the start_index
|
|
in the next call to serialize the next range of extensions.
|
|
|
|
Args:
|
|
extensions: A list of ExtensionIdentifier sorted in field number order.
|
|
start_index: The start index in the extensions list.
|
|
end_field_number: The end field number of the extension range.
|
|
|
|
Returns:
|
|
The first index that is not in the range. Or the size of extensions if all
|
|
the extensions are within the range.
|
|
"""
|
|
def OutputSingleField(ext, value):
|
|
out.putVarInt32(ext.wire_tag)
|
|
if ext.field_type == TYPE_GROUP:
|
|
if partial:
|
|
value.OutputPartial(out)
|
|
else:
|
|
value.OutputUnchecked(out)
|
|
out.putVarInt32(ext.wire_tag + 1) # End the group
|
|
elif ext.field_type == TYPE_FOREIGN:
|
|
if partial:
|
|
out.putVarInt32(value.ByteSizePartial())
|
|
value.OutputPartial(out)
|
|
else:
|
|
out.putVarInt32(value.ByteSize())
|
|
value.OutputUnchecked(out)
|
|
else:
|
|
Encoder._TYPE_TO_METHOD[ext.field_type](out, value)
|
|
|
|
for ext_index, ext in enumerate(
|
|
itertools.islice(extensions, start_index, None), start=start_index):
|
|
if ext.number >= end_field_number:
|
|
# exceeding extension range end.
|
|
return ext_index
|
|
if ext.is_repeated:
|
|
for field in self._extension_fields[ext]:
|
|
OutputSingleField(ext, field)
|
|
else:
|
|
OutputSingleField(ext, self._extension_fields[ext])
|
|
return len(extensions)
|
|
|
|
def _ParseOneExtensionField(self, wire_tag, d):
|
|
number = wire_tag >> 3
|
|
if number in self._extensions_by_field_number:
|
|
ext = self._extensions_by_field_number[number]
|
|
if wire_tag != ext.wire_tag:
|
|
# wire_tag doesn't match; discard as unknown field.
|
|
return
|
|
if ext.field_type == TYPE_FOREIGN:
|
|
length = d.getVarInt32()
|
|
tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
|
|
if ext.is_repeated:
|
|
self.AddExtension(ext).TryMerge(tmp)
|
|
else:
|
|
self.MutableExtension(ext).TryMerge(tmp)
|
|
d.skip(length)
|
|
elif ext.field_type == TYPE_GROUP:
|
|
if ext.is_repeated:
|
|
self.AddExtension(ext).TryMerge(d)
|
|
else:
|
|
self.MutableExtension(ext).TryMerge(d)
|
|
else:
|
|
value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
|
|
if ext.is_repeated:
|
|
self.AddExtension(ext, value)
|
|
else:
|
|
self.SetExtension(ext, value)
|
|
else:
|
|
# discard unknown extensions.
|
|
d.skipData(wire_tag)
|
|
|
|
def _ExtensionByteSize(self, partial):
|
|
size = 0
|
|
for extension, value in six.iteritems(self._extension_fields):
|
|
ftype = extension.field_type
|
|
tag_size = self.lengthVarInt64(extension.wire_tag)
|
|
if ftype == TYPE_GROUP:
|
|
tag_size *= 2 # end tag
|
|
if extension.is_repeated:
|
|
size += tag_size * len(value)
|
|
for single_value in value:
|
|
size += self._FieldByteSize(ftype, single_value, partial)
|
|
else:
|
|
size += tag_size + self._FieldByteSize(ftype, value, partial)
|
|
return size
|
|
|
|
def _FieldByteSize(self, ftype, value, partial):
|
|
size = 0
|
|
if ftype == TYPE_STRING:
|
|
size = self.lengthString(len(value))
|
|
elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
|
|
if partial:
|
|
size = self.lengthString(value.ByteSizePartial())
|
|
else:
|
|
size = self.lengthString(value.ByteSize())
|
|
elif ftype == TYPE_INT64 or \
|
|
ftype == TYPE_UINT64 or \
|
|
ftype == TYPE_INT32:
|
|
size = self.lengthVarInt64(value)
|
|
else:
|
|
if ftype in Encoder._TYPE_TO_BYTE_SIZE:
|
|
size = Encoder._TYPE_TO_BYTE_SIZE[ftype]
|
|
else:
|
|
raise AssertionError(
|
|
'Extension type %d is not recognized.' % ftype)
|
|
return size
|
|
|
|
def _ExtensionDebugString(self, prefix, printElemNumber):
|
|
res = ''
|
|
extensions = self._ListExtensions()
|
|
for extension in extensions:
|
|
value = self._extension_fields[extension]
|
|
if extension.is_repeated:
|
|
cnt = 0
|
|
for e in value:
|
|
elm=""
|
|
if printElemNumber: elm = "(%d)" % cnt
|
|
if extension.composite_cls is not None:
|
|
res += prefix + "[%s%s] {\n" % \
|
|
(extension.full_name, elm)
|
|
res += e.__str__(prefix + " ", printElemNumber)
|
|
res += prefix + "}\n"
|
|
else:
|
|
if extension.composite_cls is not None:
|
|
res += prefix + "[%s] {\n" % extension.full_name
|
|
res += value.__str__(
|
|
prefix + " ", printElemNumber)
|
|
res += prefix + "}\n"
|
|
else:
|
|
if extension.field_type in _TYPE_TO_DEBUG_STRING:
|
|
text_value = _TYPE_TO_DEBUG_STRING[
|
|
extension.field_type](self, value)
|
|
else:
|
|
text_value = self.DebugFormat(value)
|
|
res += prefix + "[%s]: %s\n" % (extension.full_name, text_value)
|
|
return res
|
|
|
|
@staticmethod
|
|
def _RegisterExtension(cls, extension, composite_cls=None):
|
|
extension.containing_cls = cls
|
|
extension.composite_cls = composite_cls
|
|
if composite_cls is not None:
|
|
extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
|
|
actual_handle = cls._extensions_by_field_number.setdefault(
|
|
extension.number, extension)
|
|
if actual_handle is not extension:
|
|
raise AssertionError(
|
|
'Extensions "%s" and "%s" both try to extend message type "%s" with '
|
|
'field number %d.' %
|
|
(extension.full_name, actual_handle.full_name,
|
|
cls.__name__, extension.number))
|