# Copyright 2016 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2002, Google LLC. from __future__ import absolute_import import array import itertools import re import struct import six import six.moves.http_client try: # NOTE(user): Using non-google-style import to workaround a zipimport_tinypar # issue for zip files embedded in par files. See http://b/13811096 import googlecloudsdk.appengine.proto.proto1 as proto1 except ImportError: # Protect in case of missing deps / strange env (GAE?) / etc. class ProtocolBufferDecodeError(Exception): pass class ProtocolBufferEncodeError(Exception): pass class ProtocolBufferReturnError(Exception): pass else: ProtocolBufferDecodeError = proto1.ProtocolBufferDecodeError ProtocolBufferEncodeError = proto1.ProtocolBufferEncodeError ProtocolBufferReturnError = proto1.ProtocolBufferReturnError __all__ = ['ProtocolMessage', 'Encoder', 'Decoder', 'ExtendableProtocolMessage', 'ProtocolBufferDecodeError', 'ProtocolBufferEncodeError', 'ProtocolBufferReturnError'] URL_RE = re.compile('^(https?)://([^/]+)(/.*)$') class ProtocolMessage: """ The parent class of all protocol buffers. NOTE: the methods that unconditionally raise NotImplementedError are reimplemented by the subclasses of this class. Subclasses are automatically generated by tools/protocol_converter. Encoding methods can raise ProtocolBufferEncodeError if a value for an integer or long field is too large, or if any required field is not set. Decoding methods can raise ProtocolBufferDecodeError if they couldn't decode correctly, or the decoded message doesn't have all required fields. """ ##################################### # methods you should use # ##################################### def __init__(self, contents=None): """Construct a new protocol buffer, with optional starting contents in binary protocol buffer format.""" raise NotImplementedError def Clear(self): """Erases all fields of protocol buffer (& resets to defaults if fields have defaults).""" raise NotImplementedError def IsInitialized(self, debug_strs=None): """returns true iff all required fields have been set.""" raise NotImplementedError def Encode(self): """Returns a string representing the protocol buffer object.""" try: return self._CEncode() except (NotImplementedError, AttributeError): e = Encoder() self.Output(e) return e.buffer().tostring() def SerializeToString(self): """Same as Encode(), but has same name as proto2's serialize function.""" return self.Encode() def SerializePartialToString(self): """Returns a string representing the protocol buffer object. Same as SerializeToString() but does not enforce required fields are set. """ try: return self._CEncodePartial() except (NotImplementedError, AttributeError): e = Encoder() self.OutputPartial(e) return e.buffer().tostring() def _CEncode(self): """Call into C++ encode code. Generated protocol buffer classes will override this method to provide C++-based serialization. If a subclass does not implement this method, Encode() will fall back to using pure-Python encoding. """ raise NotImplementedError def _CEncodePartial(self): """Same as _CEncode, except does not encode missing required fields.""" raise NotImplementedError def ParseFromString(self, s): """Reads data from the string 's'. Raises a ProtocolBufferDecodeError if, after successfully reading in the contents of 's', this protocol message is still not initialized.""" self.Clear() self.MergeFromString(s) def ParsePartialFromString(self, s): """Reads data from the string 's'. Does not enforce required fields are set.""" self.Clear() self.MergePartialFromString(s) def MergeFromString(self, s): """Adds in data from the string 's'. Raises a ProtocolBufferDecodeError if, after successfully merging in the contents of 's', this protocol message is still not initialized.""" self.MergePartialFromString(s) dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferDecodeError('\n\t'.join(dbg)) def MergePartialFromString(self, s): """Merges in data from the string 's'. Does not enforce required fields are set.""" try: self._CMergeFromString(s) except (NotImplementedError, AttributeError): # If we can't call into C++ to deserialize the string, use # the (much slower) pure-Python implementation. a = array.array('B') a.fromstring(s) d = Decoder(a, 0, len(a)) self.TryMerge(d) def _CMergeFromString(self, s): """Call into C++ parsing code to merge from a string. Does *not* check IsInitialized() before returning. Generated protocol buffer classes will override this method to provide C++-based deserialization. If a subclass does not implement this method, MergeFromString() will fall back to using pure-Python parsing. """ raise NotImplementedError def __getstate__(self): """Return the pickled representation of the data inside protocol buffer, which is the same as its binary-encoded representation (as a string).""" return self.Encode() def __setstate__(self, contents_): """Restore the pickled representation of the data inside protocol buffer. Note that the mechanism underlying pickle.load() does not call __init__.""" self.__init__(contents=contents_) def sendCommand(self, server, url, response, follow_redirects=1, secure=0, keyfile=None, certfile=None): """posts the protocol buffer to the desired url on the server and puts the return data into the protocol buffer 'response' NOTE: The underlying socket raises the 'error' exception for all I/O related errors (can't connect, etc.). If 'response' is None, the server's PB response will be ignored. The optional 'follow_redirects' argument indicates the number of HTTP redirects that are followed before giving up and raising an exception. The default is 1. If 'secure' is true, HTTPS will be used instead of HTTP. Also, 'keyfile' and 'certfile' may be set for client authentication. """ data = self.Encode() if secure: if keyfile and certfile: conn = six.moves.http_client.HTTPSConnection(server, key_file=keyfile, cert_file=certfile) else: conn = six.moves.http_client.HTTPSConnection(server) else: conn = six.moves.http_client.HTTPConnection(server) conn.putrequest("POST", url) conn.putheader("Content-Length", "%d" %len(data)) conn.endheaders() conn.send(data) resp = conn.getresponse() if follow_redirects > 0 and resp.status == 302: m = URL_RE.match(resp.getheader('Location')) if m: protocol, server, url = m.groups() return self.sendCommand(server, url, response, follow_redirects=follow_redirects - 1, secure=(protocol == 'https'), keyfile=keyfile, certfile=certfile) if resp.status != 200: raise ProtocolBufferReturnError(resp.status) if response is not None: response.ParseFromString(resp.read()) return response def sendSecureCommand(self, server, keyfile, certfile, url, response, follow_redirects=1): """posts the protocol buffer via https to the desired url on the server, using the specified key and certificate files, and puts the return data int othe protocol buffer 'response'. See caveats in sendCommand. You need an SSL-aware build of the Python2 interpreter to use this command. (Python1 is not supported). An SSL build of python2.2 is in /home/build/buildtools/python-ssl-2.2 . An SSL build of python is standard on all prod machines. keyfile: Contains our private RSA key certfile: Contains SSL certificate for remote host Specify None for keyfile/certfile if you don't want to do client auth. """ return self.sendCommand(server, url, response, follow_redirects=follow_redirects, secure=1, keyfile=keyfile, certfile=certfile) def __str__(self, prefix="", printElemNumber=0): """Returns nicely formatted contents of this protocol buffer.""" raise NotImplementedError def ToASCII(self): """Returns the protocol buffer as a human-readable string.""" return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII) def ToShortASCII(self): """Returns the protocol buffer as an ASCII string. The output is short, leaving out newlines and some other niceties. Defers to the C++ ProtocolPrinter class in SYMBOLIC_SHORT mode. """ return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII) # Note that these must be consistent with the ProtocolPrinter::Level C++ # enum. _NUMERIC_ASCII = 0 _SYMBOLIC_SHORT_ASCII = 1 _SYMBOLIC_FULL_ASCII = 2 def _CToASCII(self, output_format): """Calls into C++ ASCII-generating code. Generated protocol buffer classes will override this method to provide C++-based ASCII output. """ raise NotImplementedError def ParseASCII(self, ascii_string): """Parses a string generated by ToASCII() or by the C++ DebugString() method, initializing this protocol buffer with its contents. This method raises a ValueError if it encounters an unknown field. """ raise NotImplementedError def ParseASCIIIgnoreUnknown(self, ascii_string): """Parses a string generated by ToASCII() or by the C++ DebugString() method, initializing this protocol buffer with its contents. Ignores unknown fields. """ raise NotImplementedError def Equals(self, other): """Returns whether or not this protocol buffer is equivalent to another. This assumes that self and other are of the same type. """ raise NotImplementedError def __eq__(self, other): """Implementation of operator ==.""" # If self and other are of different types we return NotImplemented, which # tells the Python interpreter to try some other methods of measuring # equality before finally performing an identity comparison. This allows # other classes to implement custom __eq__ or __ne__ methods. # See http://docs.sympy.org/_sources/python-comparisons.txt if other.__class__ is self.__class__: return self.Equals(other) return NotImplemented def __ne__(self, other): """Implementation of operator !=.""" # We repeat code for __ne__ instead of returning "not (self == other)" # so that we can return NotImplemented when comparing against an object of # a different type. # See http://bugs.python.org/msg76374 for an example of when __ne__ might # return something other than the Boolean opposite of __eq__. if other.__class__ is self.__class__: return not self.Equals(other) return NotImplemented ##################################### # methods power-users might want # ##################################### def Output(self, e): """write self to the encoder 'e'.""" dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferEncodeError('\n\t'.join(dbg)) self.OutputUnchecked(e) return def OutputUnchecked(self, e): """write self to the encoder 'e', don't check for initialization.""" raise NotImplementedError def OutputPartial(self, e): """write self to the encoder 'e', don't check for initialization and don't assume required fields exist.""" raise NotImplementedError def Parse(self, d): """reads data from the Decoder 'd'.""" self.Clear() self.Merge(d) return def Merge(self, d): """merges data from the Decoder 'd'.""" self.TryMerge(d) dbg = [] if not self.IsInitialized(dbg): raise ProtocolBufferDecodeError('\n\t'.join(dbg)) return def TryMerge(self, d): """merges data from the Decoder 'd'.""" raise NotImplementedError def CopyFrom(self, pb): """copy data from another protocol buffer""" if (pb == self): return self.Clear() self.MergeFrom(pb) def MergeFrom(self, pb): """merge data from another protocol buffer""" raise NotImplementedError ##################################### # helper methods for subclasses # ##################################### def lengthVarInt32(self, n): return self.lengthVarInt64(n) def lengthVarInt64(self, n): if n < 0: return 10 # ceil(64/7) result = 0 while 1: result += 1 n >>= 7 if n == 0: break return result def lengthString(self, n): return self.lengthVarInt32(n) + n def DebugFormat(self, value): return "%s" % value def DebugFormatInt32(self, value): if (value <= -2000000000 or value >= 2000000000): return self.DebugFormatFixed32(value) return "%d" % value def DebugFormatInt64(self, value): if (value <= -20000000000000 or value >= 20000000000000): return self.DebugFormatFixed64(value) return "%d" % value def DebugFormatString(self, value): # For now we only escape the bare minimum to insure interoperability # and redability. In the future we may want to mimick the c++ behavior # more closely, but this will make the code a lot more messy. def escape(c): o = ord(c) if o == 10: return r"\n" # optional escape if o == 39: return r"\'" # optional escape if o == 34: return r'\"' # necessary escape if o == 92: return r"\\" # necessary escape if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes return c return '"' + "".join(escape(c) for c in value) + '"' def DebugFormatFloat(self, value): return "%ff" % value def DebugFormatFixed32(self, value): if (value < 0): value += (1<<32) return "0x%x" % value def DebugFormatFixed64(self, value): if (value < 0): value += (1<<64) return "0x%x" % value def DebugFormatBool(self, value): if value: return "true" else: return "false" # types of fields, must match Proto::Type and net/proto/protocoltype.proto TYPE_DOUBLE = 1 TYPE_FLOAT = 2 TYPE_INT64 = 3 TYPE_UINT64 = 4 TYPE_INT32 = 5 TYPE_FIXED64 = 6 TYPE_FIXED32 = 7 TYPE_BOOL = 8 TYPE_STRING = 9 TYPE_GROUP = 10 TYPE_FOREIGN = 11 # debug string for extensions _TYPE_TO_DEBUG_STRING = { TYPE_INT32: ProtocolMessage.DebugFormatInt32, TYPE_INT64: ProtocolMessage.DebugFormatInt64, TYPE_UINT64: ProtocolMessage.DebugFormatInt64, TYPE_FLOAT: ProtocolMessage.DebugFormatFloat, TYPE_STRING: ProtocolMessage.DebugFormatString, TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32, TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64, TYPE_BOOL: ProtocolMessage.DebugFormatBool } # users of protocol buffers usually won't need to concern themselves # with either Encoders or Decoders. class Encoder: # types of data NUMERIC = 0 DOUBLE = 1 STRING = 2 STARTGROUP = 3 ENDGROUP = 4 FLOAT = 5 MAX_TYPE = 6 def __init__(self): self.buf = array.array('B') return def buffer(self): return self.buf def put8(self, v): if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError("u8 too big") self.buf.append(v & 255) return def put16(self, v): if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError("u16 too big") self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) return def put32(self, v): if v < 0 or v >= (1<<32): raise ProtocolBufferEncodeError("u32 too big") self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) self.buf.append((v >> 16) & 255) self.buf.append((v >> 24) & 255) return def put64(self, v): if v < 0 or v >= (1<<64): raise ProtocolBufferEncodeError("u64 too big") self.buf.append((v >> 0) & 255) self.buf.append((v >> 8) & 255) self.buf.append((v >> 16) & 255) self.buf.append((v >> 24) & 255) self.buf.append((v >> 32) & 255) self.buf.append((v >> 40) & 255) self.buf.append((v >> 48) & 255) self.buf.append((v >> 56) & 255) return def putVarInt32(self, v): # Profiling has shown this code to be very performance critical # so we duplicate code, go for early exits when possible, etc. # VarInt32 gets more unrolling because VarInt32s are far and away # the most common element in protobufs (field tags and string # lengths), so they get more attention. They're also more # likely to fit in one byte (string lengths again), so we # check and bail out early if possible. buf_append = self.buf.append # cache attribute lookup if v & 127 == v: buf_append(v) return if v >= 0x80000000 or v < -0x80000000: # python2.4 doesn't fold constants raise ProtocolBufferEncodeError("int32 too big") if v < 0: v += 0x10000000000000000 while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def putVarInt64(self, v): buf_append = self.buf.append if v >= 0x8000000000000000 or v < -0x8000000000000000: raise ProtocolBufferEncodeError("int64 too big") if v < 0: v += 0x10000000000000000 while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def putVarUint64(self, v): buf_append = self.buf.append if v < 0 or v >= 0x10000000000000000: raise ProtocolBufferEncodeError("uint64 too big") while True: bits = v & 127 v >>= 7 if v: bits |= 128 buf_append(bits) if not v: break return def putFloat(self, v): a = array.array('B') a.fromstring(struct.pack(" self.limit: raise ProtocolBufferDecodeError("truncated") self.idx += n return def skipData(self, tag): t = tag & 7 # tag format type if t == Encoder.NUMERIC: self.getVarInt64() elif t == Encoder.DOUBLE: self.skip(8) elif t == Encoder.STRING: n = self.getVarInt32() self.skip(n) elif t == Encoder.STARTGROUP: while 1: t = self.getVarInt32() if (t & 7) == Encoder.ENDGROUP: break else: self.skipData(t) if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP): raise ProtocolBufferDecodeError("corrupted") elif t == Encoder.ENDGROUP: raise ProtocolBufferDecodeError("corrupted") elif t == Encoder.FLOAT: self.skip(4) else: raise ProtocolBufferDecodeError("corrupted") # these are all unsigned gets def get8(self): if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated") c = self.buf[self.idx] self.idx += 1 return c def get16(self): if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated") c = self.buf[self.idx] d = self.buf[self.idx + 1] self.idx += 2 return (d << 8) | c def get32(self): if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated") c = self.buf[self.idx] d = self.buf[self.idx + 1] e = self.buf[self.idx + 2] f = int(self.buf[self.idx + 3]) self.idx += 4 return (f << 24) | (e << 16) | (d << 8) | c def get64(self): if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated") c = self.buf[self.idx] d = self.buf[self.idx + 1] e = self.buf[self.idx + 2] f = int(self.buf[self.idx + 3]) g = int(self.buf[self.idx + 4]) h = int(self.buf[self.idx + 5]) i = int(self.buf[self.idx + 6]) j = int(self.buf[self.idx + 7]) self.idx += 8 return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24) | (e << 16) | (d << 8) | c) def getVarInt32(self): # getVarInt32 gets different treatment than other integer getter # functions due to the much larger number of varInt32s and also # varInt32s that fit in one byte. See the comment at putVarInt32. b = self.get8() if not (b & 128): return b result = int(0) shift = 0 while 1: result |= (int(b & 127) << shift) shift += 7 if not (b & 128): if result >= 0x10000000000000000: # (1L << 64): raise ProtocolBufferDecodeError("corrupted") break if shift >= 64: raise ProtocolBufferDecodeError("corrupted") b = self.get8() if result >= 0x8000000000000000: # (1L << 63) result -= 0x10000000000000000 # (1L << 64) if result >= 0x80000000 or result < -0x80000000: # (1L << 31) raise ProtocolBufferDecodeError("corrupted") return result def getVarInt64(self): result = self.getVarUint64() if result >= (1 << 63): result -= (1 << 64) return result def getVarUint64(self): result = int(0) shift = 0 while 1: if shift >= 64: raise ProtocolBufferDecodeError("corrupted") b = self.get8() result |= (int(b & 127) << shift) shift += 7 if not (b & 128): if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted") return result return result # make pychecker happy def getFloat(self): if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated") a = self.buf[self.idx:self.idx+4] self.idx += 4 return struct.unpack(" self.limit: raise ProtocolBufferDecodeError("truncated") a = self.buf[self.idx:self.idx+8] self.idx += 8 return struct.unpack(" self.limit: raise ProtocolBufferDecodeError("truncated") r = self.buf[self.idx : self.idx + length] self.idx += length return r.tostring() def getRawString(self): r = self.buf[self.idx:self.limit] self.idx = self.limit return r.tostring() _TYPE_TO_METHOD = { TYPE_DOUBLE: getDouble, TYPE_FLOAT: getFloat, TYPE_FIXED64: get64, TYPE_FIXED32: get32, TYPE_INT32: getVarInt32, TYPE_INT64: getVarInt64, TYPE_UINT64: getVarUint64, TYPE_BOOL: getBoolean, TYPE_STRING: getPrefixedString } ##################################### # extensions # ##################################### class ExtensionIdentifier(object): __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated', 'default', 'containing_cls', 'composite_cls', 'message_name') def __init__(self, full_name, number, field_type, wire_tag, is_repeated, default): self.full_name = full_name self.number = number self.field_type = field_type self.wire_tag = wire_tag self.is_repeated = is_repeated self.default = default class ExtendableProtocolMessage(ProtocolMessage): def HasExtension(self, extension): """Checks if the message contains a certain non-repeated extension.""" self._VerifyExtensionIdentifier(extension) return extension in self._extension_fields def ClearExtension(self, extension): """Clears the value of extension, so that HasExtension() returns false or ExtensionSize() returns 0.""" self._VerifyExtensionIdentifier(extension) if extension in self._extension_fields: del self._extension_fields[extension] def GetExtension(self, extension, index=None): """Gets the extension value for a certain extension. Args: extension: The ExtensionIdentifier for the extension. index: The index of element to get in a repeated field. Only needed if the extension is repeated. Returns: The value of the extension if exists, otherwise the default value of the extension will be returned. """ self._VerifyExtensionIdentifier(extension) if extension in self._extension_fields: result = self._extension_fields[extension] else: if extension.is_repeated: result = [] elif extension.composite_cls: result = extension.composite_cls() else: result = extension.default if extension.is_repeated: result = result[index] return result def SetExtension(self, extension, *args): """Sets the extension value for a certain scalar type extension. Arg varies according to extension type: - Singular: message.SetExtension(extension, value) - Repeated: message.SetExtension(extension, index, value) where extension: The ExtensionIdentifier for the extension. index: The index of element to set in a repeated field. Only needed if the extension is repeated. value: The value to set. Raises: TypeError if a message type extension is given. """ self._VerifyExtensionIdentifier(extension) if extension.composite_cls: raise TypeError( 'Cannot assign to extension "%s" because it is a composite type.' % extension.full_name) if extension.is_repeated: try: index, value = args except ValueError: raise TypeError( "SetExtension(extension, index, value) for repeated extension " "takes exactly 4 arguments: (%d given)" % (len(args) + 2)) self._extension_fields[extension][index] = value else: try: (value,) = args except ValueError: raise TypeError( "SetExtension(extension, value) for singular extension " "takes exactly 3 arguments: (%d given)" % (len(args) + 2)) self._extension_fields[extension] = value def MutableExtension(self, extension, index=None): """Gets a mutable reference of a message type extension. For repeated extension, index must be specified, and only one element will be returned. For optional extension, if the extension does not exist, a new message will be created and set in parent message. Args: extension: The ExtensionIdentifier for the extension. index: The index of element to mutate in a repeated field. Only needed if the extension is repeated. Returns: The mutable message reference. Raises: TypeError if non-message type extension is given. """ self._VerifyExtensionIdentifier(extension) if extension.composite_cls is None: raise TypeError( 'MutableExtension() cannot be applied to "%s", because it is not a ' 'composite type.' % extension.full_name) if extension.is_repeated: if index is None: raise TypeError( 'MutableExtension(extension, index) for repeated extension ' 'takes exactly 2 arguments: (1 given)') return self.GetExtension(extension, index) if extension in self._extension_fields: return self._extension_fields[extension] else: result = extension.composite_cls() self._extension_fields[extension] = result return result def ExtensionList(self, extension): """Returns a mutable list of extensions. Raises: TypeError if the extension is not repeated. """ self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'ExtensionList() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: return self._extension_fields[extension] result = [] self._extension_fields[extension] = result return result def ExtensionSize(self, extension): """Returns the size of a repeated extension. Raises: TypeError if the extension is not repeated. """ self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'ExtensionSize() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: return len(self._extension_fields[extension]) return 0 def AddExtension(self, extension, value=None): """Appends a new element into a repeated extension. Arg varies according to the extension field type: - Scalar/String: message.AddExtension(extension, value) - Message: mutable_message = AddExtension(extension) Args: extension: The ExtensionIdentifier for the extension. value: The value of the extension if the extension is scalar/string type. The value must NOT be set for message type extensions; set values on the returned message object instead. Returns: A mutable new message if it's a message type extension, or None otherwise. Raises: TypeError if the extension is not repeated, or value is given for message type extensions. """ self._VerifyExtensionIdentifier(extension) if not extension.is_repeated: raise TypeError( 'AddExtension() cannot be applied to "%s", because it is not a ' 'repeated extension.' % extension.full_name) if extension in self._extension_fields: field = self._extension_fields[extension] else: field = [] self._extension_fields[extension] = field # Composite field if extension.composite_cls: if value is not None: raise TypeError( 'value must not be set in AddExtension() for "%s", because it is ' 'a message type extension. Set values on the returned message ' 'instead.' % extension.full_name) msg = extension.composite_cls() field.append(msg) return msg # Scalar and string field field.append(value) def _VerifyExtensionIdentifier(self, extension): if extension.containing_cls != self.__class__: raise TypeError("Containing type of %s is %s, but not %s." % (extension.full_name, extension.containing_cls.__name__, self.__class__.__name__)) def _MergeExtensionFields(self, x): for ext, val in x._extension_fields.items(): if ext.is_repeated: for single_val in val: if ext.composite_cls is None: self.AddExtension(ext, single_val) else: self.AddExtension(ext).MergeFrom(single_val) else: if ext.composite_cls is None: self.SetExtension(ext, val) else: self.MutableExtension(ext).MergeFrom(val) def _ListExtensions(self): return sorted( (ext for ext in self._extension_fields if (not ext.is_repeated) or self.ExtensionSize(ext) > 0), key=lambda item: item.number) def _ExtensionEquals(self, x): extensions = self._ListExtensions() if extensions != x._ListExtensions(): return False for ext in extensions: if ext.is_repeated: if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False for e1, e2 in zip(self.ExtensionList(ext), x.ExtensionList(ext)): if e1 != e2: return False else: if self.GetExtension(ext) != x.GetExtension(ext): return False return True def _OutputExtensionFields(self, out, partial, extensions, start_index, end_field_number): """Serialize a range of extensions. To generate canonical output when encoding, we interleave fields and extensions to preserve tag order. Generated code will prepare a list of ExtensionIdentifier sorted in field number order and call this method to serialize a specific range of extensions. The range is specified by the two arguments, start_index and end_field_number. The method will serialize all extensions[i] with i >= start_index and extensions[i].number < end_field_number. Since extensions argument is sorted by field_number, this is a contiguous range; the first index j not included in that range is returned. The return value can be used as the start_index in the next call to serialize the next range of extensions. Args: extensions: A list of ExtensionIdentifier sorted in field number order. start_index: The start index in the extensions list. end_field_number: The end field number of the extension range. Returns: The first index that is not in the range. Or the size of extensions if all the extensions are within the range. """ def OutputSingleField(ext, value): out.putVarInt32(ext.wire_tag) if ext.field_type == TYPE_GROUP: if partial: value.OutputPartial(out) else: value.OutputUnchecked(out) out.putVarInt32(ext.wire_tag + 1) # End the group elif ext.field_type == TYPE_FOREIGN: if partial: out.putVarInt32(value.ByteSizePartial()) value.OutputPartial(out) else: out.putVarInt32(value.ByteSize()) value.OutputUnchecked(out) else: Encoder._TYPE_TO_METHOD[ext.field_type](out, value) for ext_index, ext in enumerate( itertools.islice(extensions, start_index, None), start=start_index): if ext.number >= end_field_number: # exceeding extension range end. return ext_index if ext.is_repeated: for field in self._extension_fields[ext]: OutputSingleField(ext, field) else: OutputSingleField(ext, self._extension_fields[ext]) return len(extensions) def _ParseOneExtensionField(self, wire_tag, d): number = wire_tag >> 3 if number in self._extensions_by_field_number: ext = self._extensions_by_field_number[number] if wire_tag != ext.wire_tag: # wire_tag doesn't match; discard as unknown field. return if ext.field_type == TYPE_FOREIGN: length = d.getVarInt32() tmp = Decoder(d.buffer(), d.pos(), d.pos() + length) if ext.is_repeated: self.AddExtension(ext).TryMerge(tmp) else: self.MutableExtension(ext).TryMerge(tmp) d.skip(length) elif ext.field_type == TYPE_GROUP: if ext.is_repeated: self.AddExtension(ext).TryMerge(d) else: self.MutableExtension(ext).TryMerge(d) else: value = Decoder._TYPE_TO_METHOD[ext.field_type](d) if ext.is_repeated: self.AddExtension(ext, value) else: self.SetExtension(ext, value) else: # discard unknown extensions. d.skipData(wire_tag) def _ExtensionByteSize(self, partial): size = 0 for extension, value in six.iteritems(self._extension_fields): ftype = extension.field_type tag_size = self.lengthVarInt64(extension.wire_tag) if ftype == TYPE_GROUP: tag_size *= 2 # end tag if extension.is_repeated: size += tag_size * len(value) for single_value in value: size += self._FieldByteSize(ftype, single_value, partial) else: size += tag_size + self._FieldByteSize(ftype, value, partial) return size def _FieldByteSize(self, ftype, value, partial): size = 0 if ftype == TYPE_STRING: size = self.lengthString(len(value)) elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP: if partial: size = self.lengthString(value.ByteSizePartial()) else: size = self.lengthString(value.ByteSize()) elif ftype == TYPE_INT64 or \ ftype == TYPE_UINT64 or \ ftype == TYPE_INT32: size = self.lengthVarInt64(value) else: if ftype in Encoder._TYPE_TO_BYTE_SIZE: size = Encoder._TYPE_TO_BYTE_SIZE[ftype] else: raise AssertionError( 'Extension type %d is not recognized.' % ftype) return size def _ExtensionDebugString(self, prefix, printElemNumber): res = '' extensions = self._ListExtensions() for extension in extensions: value = self._extension_fields[extension] if extension.is_repeated: cnt = 0 for e in value: elm="" if printElemNumber: elm = "(%d)" % cnt if extension.composite_cls is not None: res += prefix + "[%s%s] {\n" % \ (extension.full_name, elm) res += e.__str__(prefix + " ", printElemNumber) res += prefix + "}\n" else: if extension.composite_cls is not None: res += prefix + "[%s] {\n" % extension.full_name res += value.__str__( prefix + " ", printElemNumber) res += prefix + "}\n" else: if extension.field_type in _TYPE_TO_DEBUG_STRING: text_value = _TYPE_TO_DEBUG_STRING[ extension.field_type](self, value) else: text_value = self.DebugFormat(value) res += prefix + "[%s]: %s\n" % (extension.full_name, text_value) return res @staticmethod def _RegisterExtension(cls, extension, composite_cls=None): extension.containing_cls = cls extension.composite_cls = composite_cls if composite_cls is not None: extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME actual_handle = cls._extensions_by_field_number.setdefault( extension.number, extension) if actual_handle is not extension: raise AssertionError( 'Extensions "%s" and "%s" both try to extend message type "%s" with ' 'field number %d.' % (extension.full_name, actual_handle.full_name, cls.__name__, extension.number))