feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared __init__.py for apitools."""
# Declare this package as a pkgutil-style namespace package so that
# multiple installed distributions can contribute modules under the
# same package name.
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View File

@@ -0,0 +1,19 @@
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared __init__.py for apitools."""
# Declare this package as a pkgutil-style namespace package so that
# multiple installed distributions can contribute modules under the
# same package name.
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View File

@@ -0,0 +1,615 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Services descriptor definitions.
Contains message definitions and functions for converting
service classes into transmittable message format.
Describing an Enum instance, Enum class, Field class or Message class will
generate an appropriate descriptor object that describes that class.
This message can itself be used to transmit information to clients wishing
to know the description of an enum value, enum, field or message without
needing to download the source code. This format is also compatible with
other, non-Python languages.
The descriptors are modeled to be binary compatible with
https://github.com/google/protobuf
NOTE: The names of types and fields are not always the same between these
descriptors and the ones defined in descriptor.proto. This was done in order
to make source code files that use these descriptors easier to read. For
example, it is not necessary to prefix TYPE to all the values in
FieldDescriptor.Variant as is done in descriptor.proto
FieldDescriptorProto.Type.
Example:
class Pixel(messages.Message):
x = messages.IntegerField(1, required=True)
y = messages.IntegerField(2, required=True)
color = messages.BytesField(3)
# Describe Pixel class using message descriptor.
fields = []
field = FieldDescriptor()
field.name = 'x'
field.number = 1
field.label = FieldDescriptor.Label.REQUIRED
field.variant = FieldDescriptor.Variant.INT64
fields.append(field)
field = FieldDescriptor()
field.name = 'y'
field.number = 2
field.label = FieldDescriptor.Label.REQUIRED
field.variant = FieldDescriptor.Variant.INT64
fields.append(field)
field = FieldDescriptor()
field.name = 'color'
field.number = 3
field.label = FieldDescriptor.Label.OPTIONAL
field.variant = FieldDescriptor.Variant.BYTES
fields.append(field)
message = MessageDescriptor()
message.name = 'Pixel'
message.fields = fields
# Describing is the equivalent of building the above message.
message == describe_message(Pixel)
Public Classes:
EnumValueDescriptor: Describes Enum values.
EnumDescriptor: Describes Enum classes.
FieldDescriptor: Describes field instances.
FileDescriptor: Describes a single 'file' unit.
FileSet: Describes a collection of file descriptors.
MessageDescriptor: Describes Message classes.
Public Functions:
describe_enum_value: Describe an individual enum-value.
describe_enum: Describe an Enum class.
describe_field: Describe a Field definition.
describe_file: Describe a 'file' unit from a Python module or object.
describe_file_set: Describe a file set from a list of modules or objects.
describe_message: Describe a Message definition.
"""
import codecs
import types
import six
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
# Public API of this module; names not listed here are implementation
# details.
__all__ = [
    'EnumDescriptor',
    'EnumValueDescriptor',
    'FieldDescriptor',
    'MessageDescriptor',
    'FileDescriptor',
    'FileSet',
    'DescriptorLibrary',
    'describe_enum',
    'describe_enum_value',
    'describe_field',
    'describe_message',
    'describe_file',
    'describe_file_set',
    'describe',
    'import_descriptor_loader',
]
# NOTE: MessageField is missing because message fields cannot have
# a default value at this time.
# TODO(user): Support default message values.
#
# Map to functions that convert default values of fields of a given type
# to a string. The function must return a value that is compatible with
# FieldDescriptor.default_value and therefore a unicode string.
#
# NOTE(review): codecs.escape_encode/escape_decode are undocumented
# CPython internals (C-style byte escaping); confirm availability before
# targeting alternative interpreters.
_DEFAULT_TO_STRING_MAP = {
    messages.IntegerField: six.text_type,
    messages.FloatField: six.text_type,
    messages.BooleanField: lambda value: value and u'true' or u'false',
    messages.BytesField: lambda value: codecs.escape_encode(value)[0],
    messages.StringField: lambda value: value,
    messages.EnumField: lambda value: six.text_type(value.number),
}
# Inverse of _DEFAULT_TO_STRING_MAP: parse the string form of a default
# back into the field's Python value.
_DEFAULT_FROM_STRING_MAP = {
    messages.IntegerField: int,
    messages.FloatField: float,
    messages.BooleanField: lambda value: value == u'true',
    messages.BytesField: lambda value: codecs.escape_decode(value)[0],
    messages.StringField: lambda value: value,
    messages.EnumField: int,
}
class EnumValueDescriptor(messages.Message):
    """Enum value descriptor.

    Fields:
      name: Name of enumeration value.
      number: Number of enumeration value.
    """
    # TODO(user): Why are these listed as optional in descriptor.proto.
    # Harmonize?
    name = messages.StringField(1, required=True)
    # INT32 variant mirrors descriptor.proto's EnumValueDescriptorProto,
    # per the module's stated binary-compatibility goal.
    number = messages.IntegerField(2,
                                   required=True,
                                   variant=messages.Variant.INT32)
class EnumDescriptor(messages.Message):
    """Enum class descriptor.

    Fields:
      name: Name of Enum without any qualification.
      values: Values defined by Enum class.
    """
    name = messages.StringField(1)
    values = messages.MessageField(EnumValueDescriptor, 2, repeated=True)
class FieldDescriptor(messages.Message):
    """Field definition descriptor.

    Enums:
      Variant: Wire format hint sub-types for field.
      Label: Values for optional, required and repeated fields.

    Fields:
      name: Name of field.
      number: Number of field.
      label: Label of field (optional, required or repeated).
      variant: Variant of field.
      type_name: Type name for message and enum fields.
      default_value: String representation of default value.
    """
    Variant = messages.Variant  # pylint:disable=invalid-name

    class Label(messages.Enum):
        """Field label."""
        OPTIONAL = 1
        REQUIRED = 2
        REPEATED = 3
    name = messages.StringField(1, required=True)
    # NOTE(review): tag number 2 is skipped here -- presumably reserved
    # for binary compatibility with descriptor.proto; confirm before
    # renumbering.
    number = messages.IntegerField(3,
                                   required=True,
                                   variant=messages.Variant.INT32)
    label = messages.EnumField(Label, 4, default=Label.OPTIONAL)
    variant = messages.EnumField(Variant, 5)
    type_name = messages.StringField(6)
    # For numeric types, contains the original text representation of
    # the value.
    # For booleans, "true" or "false".
    # For strings, contains the default text contents (not escaped in any
    # way).
    # For bytes, contains the C escaped value. All bytes < 128 are that are
    # traditionally considered unprintable are also escaped.
    default_value = messages.StringField(7)
class MessageDescriptor(messages.Message):
    """Message definition descriptor.

    Fields:
      name: Name of Message without any qualification.
      fields: Fields defined for message.
      message_types: Nested Message classes defined on message.
      enum_types: Nested Enum classes defined on message.
    """
    name = messages.StringField(1)
    fields = messages.MessageField(FieldDescriptor, 2, repeated=True)
    # Self-referential field: nested messages are described by this same
    # descriptor type, so it is referenced by fully qualified name.
    message_types = messages.MessageField(
        'apitools.base.protorpclite.descriptor.MessageDescriptor', 3,
        repeated=True)
    enum_types = messages.MessageField(EnumDescriptor, 4, repeated=True)
class FileDescriptor(messages.Message):
    """Description of file containing protobuf definitions.

    Fields:
      package: Fully qualified name of package that definitions belong to.
      message_types: Message definitions contained in file.
      enum_types: Enum definitions contained in file.
    """
    # NOTE(review): tag number 1 is unused -- presumably reserved (e.g.
    # for a file name field as in descriptor.proto); confirm before
    # reusing.
    package = messages.StringField(2)
    # TODO(user): Add dependency field
    message_types = messages.MessageField(MessageDescriptor, 4, repeated=True)
    enum_types = messages.MessageField(EnumDescriptor, 5, repeated=True)
class FileSet(messages.Message):
    """A collection of FileDescriptors.

    Fields:
      files: Files in file-set.
    """
    files = messages.MessageField(FileDescriptor, 1, repeated=True)
def describe_enum_value(enum_value):
    """Build descriptor for Enum instance.

    Args:
      enum_value: Enum value to provide descriptor for.

    Returns:
      Initialized EnumValueDescriptor instance describing the Enum instance.
    """
    descriptor = EnumValueDescriptor()
    descriptor.number = enum_value.number
    descriptor.name = six.text_type(enum_value.name)
    return descriptor
def describe_enum(enum_definition):
    """Build descriptor for Enum class.

    Args:
      enum_definition: Enum class to provide descriptor for.

    Returns:
      Initialized EnumDescriptor instance describing the Enum class.
    """
    descriptor = EnumDescriptor()
    # Only the unqualified leaf name is recorded.
    descriptor.name = enum_definition.definition_name().rsplit('.', 1)[-1]
    # Describe each value in ascending numeric order.
    value_descriptors = [
        describe_enum_value(enum_definition.lookup_by_number(number))
        for number in sorted(enum_definition.numbers())]
    if value_descriptors:
        descriptor.values = value_descriptors
    return descriptor
def describe_field(field_definition):
    """Build descriptor for Field instance.

    Args:
      field_definition: Field instance to provide descriptor for.

    Returns:
      Initialized FieldDescriptor instance describing the Field instance.
    """
    descriptor = FieldDescriptor()
    descriptor.name = field_definition.name
    descriptor.number = field_definition.number
    descriptor.variant = field_definition.variant
    # Message and enum fields additionally record their type's name.
    if isinstance(field_definition, messages.EnumField):
        descriptor.type_name = field_definition.type.definition_name()
    if isinstance(field_definition, messages.MessageField):
        descriptor.type_name = (
            field_definition.message_type.definition_name())
    # Defaults are serialized to their string form via the conversion map.
    if field_definition.default is not None:
        converter = _DEFAULT_TO_STRING_MAP[type(field_definition)]
        descriptor.default_value = converter(field_definition.default)
    # Set label: repeated wins over required; anything else is optional.
    if field_definition.repeated:
        descriptor.label = FieldDescriptor.Label.REPEATED
    elif field_definition.required:
        descriptor.label = FieldDescriptor.Label.REQUIRED
    else:
        descriptor.label = FieldDescriptor.Label.OPTIONAL
    return descriptor
def describe_message(message_definition):
    """Build descriptor for Message class.

    Args:
      message_definition: Message class to provide descriptor for.

    Returns:
      Initialized MessageDescriptor instance describing the Message class.
    """
    descriptor = MessageDescriptor()
    # Record only the unqualified leaf name.
    descriptor.name = message_definition.definition_name().rsplit('.', 1)[-1]
    # Fields are described in ascending field-number order.
    ordered_fields = sorted(message_definition.all_fields(),
                            key=lambda field: field.number)
    if ordered_fields:
        descriptor.fields = [describe_field(field)
                             for field in ordered_fields]
    # Recursively describe nested Message classes, if any are declared.
    if hasattr(message_definition, '__messages__'):
        descriptor.message_types = [
            describe_message(getattr(message_definition, name))
            for name in message_definition.__messages__]
    # Likewise for nested Enum classes.
    if hasattr(message_definition, '__enums__'):
        descriptor.enum_types = [
            describe_enum(getattr(message_definition, name))
            for name in message_definition.__enums__]
    return descriptor
def describe_file(module):
    """Build a file from a specified Python module.

    Args:
      module: Python module to describe.

    Returns:
      Initialized FileDescriptor instance describing the module.
    """
    descriptor = FileDescriptor()
    # Normalize a falsy package name to None.
    descriptor.package = util.get_package_for_module(module) or None
    found_messages = []
    found_enums = []
    # Need to iterate over all top level attributes of the module looking
    # for message and enum definitions. Each definition must be itself
    # described.
    for attribute_name in sorted(dir(module)):
        attribute = getattr(module, attribute_name)
        if not isinstance(attribute, type):
            continue
        if issubclass(attribute, messages.Message):
            found_messages.append(describe_message(attribute))
        elif issubclass(attribute, messages.Enum):
            found_enums.append(describe_enum(attribute))
    if found_messages:
        descriptor.message_types = found_messages
    if found_enums:
        descriptor.enum_types = found_enums
    return descriptor
def describe_file_set(modules):
    """Build a file set from a specified Python modules.

    Args:
      modules: Iterable of Python module to describe.

    Returns:
      Initialized FileSet instance describing the modules.
    """
    descriptor = FileSet()
    described_files = [describe_file(module) for module in modules]
    if described_files:
        descriptor.files = described_files
    return descriptor
def describe(value):
    """Describe any value as a descriptor.

    Helper function for describing any object with an appropriate descriptor
    object.

    Args:
      value: Value to describe as a descriptor.

    Returns:
      Descriptor message class if object is describable as a descriptor,
      else None.
    """
    # Dispatch on the runtime kind of `value`; order matters only in that
    # classes are tested after instances.
    if isinstance(value, types.ModuleType):
        return describe_file(value)
    if isinstance(value, messages.Field):
        return describe_field(value)
    if isinstance(value, messages.Enum):
        return describe_enum_value(value)
    if isinstance(value, type):
        if issubclass(value, messages.Message):
            return describe_message(value)
        if issubclass(value, messages.Enum):
            return describe_enum(value)
    return None
@util.positional(1)
def import_descriptor_loader(definition_name, importer=__import__):
    """Find objects by importing modules as needed.

    A definition loader is a function that resolves a definition name to a
    descriptor.

    The import finder resolves definitions to their names by importing
    modules when necessary.

    Args:
      definition_name: Name of definition to find.
      importer: Import function used for importing new modules.

    Returns:
      Appropriate descriptor for any describable type located by name.

    Raises:
      DefinitionNotFoundError when a name does not refer to either a
        definition or a module.
    """
    # Attempt to import descriptor as a module.
    if definition_name.startswith('.'):
        # Strip a single leading dot (relative name) before trying an
        # absolute import.
        definition_name = definition_name[1:]
    if not definition_name.startswith('.'):
        leaf = definition_name.split('.')[-1]
        if definition_name:
            try:
                module = importer(definition_name, '', '', [leaf])
            except ImportError:
                pass
            else:
                return describe(module)
    try:
        # Attempt to use messages.find_definition to find item.
        # NOTE(review): this passes the builtin __import__ rather than the
        # caller-supplied `importer` -- matches the code as written; confirm
        # whether that is intentional.
        return describe(messages.find_definition(definition_name,
                                                 importer=__import__))
    except messages.DefinitionNotFoundError as err:
        # There are things that find_definition will not find, but if
        # the parent is loaded, its children can be searched for a
        # match.
        split_name = definition_name.rsplit('.', 1)
        if len(split_name) > 1:
            parent, child = split_name
            try:
                parent_definition = import_descriptor_loader(
                    parent, importer=importer)
            except messages.DefinitionNotFoundError:
                # Fall through to original error.
                pass
            else:
                # Check the parent definition for a matching descriptor.
                if isinstance(parent_definition, EnumDescriptor):
                    search_list = parent_definition.values or []
                elif isinstance(parent_definition, MessageDescriptor):
                    search_list = parent_definition.fields or []
                else:
                    search_list = []
                for definition in search_list:
                    if definition.name == child:
                        return definition
        # Still didn't find. Reraise original exception.
        raise err
class DescriptorLibrary(object):
    """A descriptor library is an object that contains known definitions.

    A descriptor library contains a cache of descriptor objects mapped by
    definition name. It contains all types of descriptors except for
    file sets.

    When a definition name is requested that the library does not know
    about, it can be provided with a descriptor loader which attempts to
    resolve the missing descriptor.
    """

    @util.positional(1)
    def __init__(self,
                 descriptors=None,
                 descriptor_loader=import_descriptor_loader):
        """Constructor.

        Args:
          descriptors: A dictionary or dictionary-like object that can be
            used to store and cache descriptors by definition name.
          descriptor_loader: A function used for resolving missing
            descriptors. The function takes a definition name as its
            parameter and returns an appropriate descriptor. It may raise
            DefinitionNotFoundError.
        """
        self.__descriptor_loader = descriptor_loader
        self.__descriptors = descriptors if descriptors else {}

    def lookup_descriptor(self, definition_name):
        """Lookup descriptor by name.

        Get descriptor from library by name. If descriptor is not found
        will attempt to find via descriptor loader if provided.

        Args:
          definition_name: Definition name to find.

        Returns:
          Descriptor that describes definition name.

        Raises:
          DefinitionNotFoundError if no descriptor exists for definition
            name.
        """
        if definition_name in self.__descriptors:
            return self.__descriptors[definition_name]
        if not self.__descriptor_loader:
            raise messages.DefinitionNotFoundError(
                'Could not find definition for %s' % definition_name)
        # Resolve via the loader and cache the result for future lookups.
        definition = self.__descriptor_loader(definition_name)
        self.__descriptors[definition_name] = definition
        return definition

    def lookup_package(self, definition_name):
        """Determines the package name for any definition.

        Determine the package that any definition name belongs to. May
        check parent for package name and will resolve missing
        descriptors if provided descriptor loader.

        Args:
          definition_name: Definition name to find package for.

        Returns:
          Package name of the enclosing FileDescriptor, or None when no
          enclosing file descriptor can be found.
        """
        name = definition_name
        while True:
            descriptor = self.lookup_descriptor(name)
            if isinstance(descriptor, FileDescriptor):
                return descriptor.package
            # Strip the last qualification component and retry on the
            # parent name.
            dot_index = name.rfind('.')
            if dot_index < 0:
                return None
            name = name[:dot_index]

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Simple protocol message types.
Includes new message and field types that are outside what is defined by the
protocol buffers standard.
"""
import datetime
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
# Public API of this module.
__all__ = [
    'DateTimeField',
    'DateTimeMessage',
    'VoidMessage',
]
class VoidMessage(messages.Message):
    """Empty message with no fields."""
class DateTimeMessage(messages.Message):
    """Message to store/transmit a DateTime.

    Fields:
      milliseconds: Milliseconds since Jan 1st 1970 local time.
      time_zone_offset: Optional time zone offset, in minutes from UTC.
    """
    milliseconds = messages.IntegerField(1, required=True)
    time_zone_offset = messages.IntegerField(2)
class DateTimeField(messages.MessageField):
    """Field definition for datetime values.

    Stores a python datetime object as a field. If time zone information is
    included in the datetime object, it will be included in
    the encoded data when this is encoded/decoded.
    """
    # Python type carried by the field and the wire message used to
    # serialize it.
    type = datetime.datetime
    message_type = DateTimeMessage
    @util.positional(3)
    def __init__(self,
                 number,
                 **kwargs):
        # The message_type is fixed to DateTimeMessage; callers only supply
        # the field number and standard Field keyword arguments.
        super(DateTimeField, self).__init__(self.message_type,
                                            number,
                                            **kwargs)
    def value_from_message(self, message):
        """Convert DateTimeMessage to a datetime.

        Args:
          message: A DateTimeMessage instance.

        Returns:
          A datetime instance; naive (no tzinfo) when the message carried
          no time_zone_offset, otherwise aware with that fixed offset.
        """
        message = super(DateTimeField, self).value_from_message(message)
        if message.time_zone_offset is None:
            # No offset recorded: interpret milliseconds as a UTC epoch
            # value and return a naive datetime.
            return datetime.datetime.fromtimestamp(
                message.milliseconds / 1000.0, tz=datetime.timezone.utc).replace(tzinfo=None)
        # Need to subtract the time zone offset, because when we call
        # datetime.fromtimestamp, it will add the time zone offset to the
        # value we pass.
        milliseconds = (message.milliseconds -
                        60000 * message.time_zone_offset)
        timezone = util.TimeZoneOffset(message.time_zone_offset)
        return datetime.datetime.fromtimestamp(milliseconds / 1000.0,
                                               tz=timezone)
    def value_to_message(self, value):
        """Convert a datetime to a DateTimeMessage.

        Args:
          value: A datetime instance (naive or timezone-aware).

        Returns:
          A DateTimeMessage; time_zone_offset is only set for aware values.
        """
        value = super(DateTimeField, self).value_to_message(value)
        # First, determine the delta from the epoch, so we can fill in
        # DateTimeMessage's milliseconds field.
        if value.tzinfo is None:
            time_zone_offset = 0
            # Naive value: epoch is taken as UTC, stripped back to naive so
            # the subtraction below is between two naive datetimes.
            local_epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
        else:
            time_zone_offset = util.total_seconds(
                value.tzinfo.utcoffset(value))
            # Determine Jan 1, 1970 local time.
            local_epoch = datetime.datetime.fromtimestamp(-time_zone_offset,
                                                          tz=value.tzinfo)
        delta = value - local_epoch
        # Create and fill in the DateTimeMessage, including time zone if
        # one was specified.
        message = DateTimeMessage()
        message.milliseconds = int(util.total_seconds(delta) * 1000)
        if value.tzinfo is not None:
            utc_offset = value.tzinfo.utcoffset(value)
            if utc_offset is not None:
                message.time_zone_offset = int(
                    util.total_seconds(value.tzinfo.utcoffset(value)) / 60)
        return message

View File

@@ -0,0 +1,400 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""JSON support for message types.
Public classes:
MessageJSONEncoder: JSON encoder for message objects.
Public functions:
encode_message: Encodes a message in to a JSON string.
decode_message: Merge from a JSON string in to a message.
"""
import base64
import binascii
import logging
import six
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
# Public API of this module.
__all__ = [
    'ALTERNATIVE_CONTENT_TYPES',
    'CONTENT_TYPE',
    'MessageJSONEncoder',
    'encode_message',
    'decode_message',
    'ProtoJson',
]
def _load_json_module():
"""Try to load a valid json module.
There are more than one json modules that might be installed. They are
mostly compatible with one another but some versions may be different.
This function attempts to load various json modules in a preferred order.
It does a basic check to guess if a loaded version of json is compatible.
Returns:
Compatible json module.
Raises:
ImportError if there are no json modules or the loaded json module is
not compatible with ProtoRPC.
"""
first_import_error = None
for module_name in ['json',
'simplejson']:
try:
module = __import__(module_name, {}, {}, 'json')
if not hasattr(module, 'JSONEncoder'):
message = (
'json library "%s" is not compatible with ProtoRPC' %
module_name)
logging.warning(message)
raise ImportError(message)
else:
return module
except ImportError as err:
if not first_import_error:
first_import_error = err
logging.error('Must use valid json library (json or simplejson)')
raise first_import_error # pylint:disable=raising-bad-type
# Resolve the json implementation once at import time; the rest of the
# module uses this binding.
json = _load_json_module()
# TODO: Rename this to MessageJsonEncoder.
class MessageJSONEncoder(json.JSONEncoder):
    """Message JSON encoder class.

    Extension of JSONEncoder that can build JSON from a message object.
    """

    def __init__(self, protojson_protocol=None, **kwargs):
        """Constructor.

        Args:
          protojson_protocol: ProtoJson instance used to encode individual
            field values; falls back to the process-wide default.
        """
        super(MessageJSONEncoder, self).__init__(**kwargs)
        self.__protojson_protocol = (
            protojson_protocol or ProtoJson.get_default())

    def default(self, value):
        """Return dictionary instance from a message object.

        Args:
          value: Value to get dictionary for. If not encodable, will
            call superclasses default method.
        """
        if isinstance(value, messages.Enum):
            return str(value)
        if six.PY3 and isinstance(value, bytes):
            return value.decode('utf8')
        if not isinstance(value, messages.Message):
            return super(MessageJSONEncoder, self).default(value)
        encoded = {}
        protocol = self.__protojson_protocol
        for field in value.all_fields():
            assigned = value.get_assigned_value(field.name)
            if assigned not in (None, [], ()):
                encoded[field.name] = protocol.encode_field(field, assigned)
        # Handle unrecognized fields, so they're included when a message is
        # decoded then encoded. They were stored as-is at decode time, so
        # no further processing is needed here.
        for unknown_key in value.all_unrecognized_fields():
            unrecognized_value, _ = value.get_unrecognized_field_info(
                unknown_key)
            encoded[unknown_key] = unrecognized_value
        return encoded
class ProtoJson(object):
    """ProtoRPC JSON implementation class.

    Implementation of JSON based protocol used for serializing and
    deserializing message objects. Instances of this class are used with
    the remote.ProtocolConfig constructor or with
    remote.Protocols.add_protocol. See the remote.py module for more
    details.
    """

    # Primary MIME type served/accepted, plus legacy aliases some clients
    # still send.
    CONTENT_TYPE = 'application/json'
    ALTERNATIVE_CONTENT_TYPES = [
        'application/x-javascript',
        'text/javascript',
        'text/x-javascript',
        'text/x-json',
        'text/json',
    ]

    def encode_field(self, field, value):
        """Encode a python field value to a JSON value.

        Args:
          field: A ProtoRPC field instance.
          value: A python value supported by field.

        Returns:
          A JSON serializable value appropriate for field.
        """
        if isinstance(field, messages.BytesField):
            if field.repeated:
                value = [base64.b64encode(byte) for byte in value]
            else:
                value = base64.b64encode(value)
        elif isinstance(field, message_types.DateTimeField):
            # DateTimeField stores its data as a RFC 3339 compliant string.
            if field.repeated:
                value = [i.isoformat() for i in value]
            else:
                value = value.isoformat()
        return value

    def encode_message(self, message):
        """Encode Message instance to JSON string.

        Args:
          message: Message instance to encode in to JSON string.

        Returns:
          String encoding of Message instance in protocol JSON format.

        Raises:
          messages.ValidationError if message is not initialized.
        """
        message.check_initialized()
        return json.dumps(message, cls=MessageJSONEncoder,
                          protojson_protocol=self)

    def decode_message(self, message_type, encoded_message):
        """Merge JSON structure to Message instance.

        Args:
          message_type: Message to decode data to.
          encoded_message: JSON encoded version of message.

        Returns:
          Decoded instance of message_type.

        Raises:
          ValueError: If encoded_message is not valid JSON.
          messages.ValidationError if merged message is not initialized.
        """
        encoded_message = six.ensure_str(encoded_message)
        # An empty/whitespace payload decodes to an empty message.
        if not encoded_message.strip():
            return message_type()
        dictionary = json.loads(encoded_message)
        message = self.__decode_dictionary(message_type, dictionary)
        message.check_initialized()
        return message

    def __find_variant(self, value):
        """Find the messages.Variant type that describes this value.

        Args:
          value: The value whose variant type is being determined.

        Returns:
          The messages.Variant value that best describes value's type,
          or None if it's a type we don't know how to handle.
        """
        if isinstance(value, bool):
            return messages.Variant.BOOL
        elif isinstance(value, six.integer_types):
            return messages.Variant.INT64
        elif isinstance(value, float):
            return messages.Variant.DOUBLE
        elif isinstance(value, six.string_types):
            return messages.Variant.STRING
        elif isinstance(value, (list, tuple)):
            # Find the most specific variant that covers all elements.
            variant_priority = [None,
                                messages.Variant.INT64,
                                messages.Variant.DOUBLE,
                                messages.Variant.STRING]
            chosen_priority = 0
            for v in value:
                variant = self.__find_variant(v)
                try:
                    priority = variant_priority.index(variant)
                except ValueError:
                    # Bug fix: list.index raises ValueError (not
                    # IndexError) for a missing element, so variants
                    # outside variant_priority (e.g. BOOL) previously
                    # escaped this handler and crashed the decode.
                    priority = -1
                if priority > chosen_priority:
                    chosen_priority = priority
            return variant_priority[chosen_priority]
        # Unrecognized type.
        return None

    def __decode_dictionary(self, message_type, dictionary):
        """Merge dictionary in to message.

        Args:
          message_type: Message type to decode data to.
          dictionary: Dictionary to extract information from. Dictionary
            is as parsed from JSON. Nested objects will also be
            dictionaries.

        Returns:
          Decoded instance of message_type.
        """
        message = message_type()
        for key, value in six.iteritems(dictionary):
            if value is None:
                # Explicit JSON null resets a known field; nulls for
                # unknown keys are dropped.
                try:
                    message.reset(key)
                except AttributeError:
                    pass  # This is an unrecognized field, skip it.
                continue
            try:
                field = message.field_by_name(key)
            except KeyError:
                # Save unknown values.
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)
                continue
            is_enum_field = isinstance(field, messages.EnumField)
            is_unrecognized_field = False
            if field.repeated:
                # This should be unnecessary? Or in fact become an error.
                if not isinstance(value, list):
                    value = [value]
                valid_value = []
                for item in value:
                    try:
                        v = self.decode_field(field, item)
                        if is_enum_field and v is None:
                            continue
                    except messages.DecodeError:
                        # Unknown enum values are tolerated and kept as
                        # unrecognized; all other decode failures are
                        # real errors.
                        if not is_enum_field:
                            raise
                        is_unrecognized_field = True
                        continue
                    valid_value.append(v)
                setattr(message, field.name, valid_value)
                if is_unrecognized_field:
                    variant = self.__find_variant(value)
                    if variant:
                        message.set_unrecognized_field(key, value, variant)
                continue
            # This is just for consistency with the old behavior.
            if value == []:
                continue
            try:
                setattr(message, field.name, self.decode_field(field, value))
            except messages.DecodeError:
                # Save unknown enum values.
                if not is_enum_field:
                    raise
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)
        return message

    def decode_field(self, field, value):
        """Decode a JSON value to a python value.

        Args:
          field: A ProtoRPC field instance.
          value: A serialized JSON value.

        Returns:
          A Python value compatible with field.

        Raises:
          messages.DecodeError if the value cannot be decoded for field.
        """
        if isinstance(field, messages.EnumField):
            try:
                return field.type(value)
            except TypeError:
                raise messages.DecodeError(
                    'Invalid enum value "%s"' % (value or ''))
        elif isinstance(field, messages.BytesField):
            try:
                return base64.b64decode(value)
            except (binascii.Error, TypeError) as err:
                raise messages.DecodeError('Base64 decoding error: %s' % err)
        elif isinstance(field, message_types.DateTimeField):
            try:
                return util.decode_datetime(value, truncate_time=True)
            except ValueError as err:
                raise messages.DecodeError(err)
        elif (isinstance(field, messages.MessageField) and
              issubclass(field.type, messages.Message)):
            return self.__decode_dictionary(field.type, value)
        elif (isinstance(field, messages.FloatField) and
              isinstance(value, (six.integer_types, six.string_types))):
            # Accept numeric strings/ints for float fields; fall through
            # to returning the raw value if conversion fails.
            try:
                return float(value)
            except:  # pylint:disable=bare-except
                pass
        elif (isinstance(field, messages.IntegerField) and
              isinstance(value, six.string_types)):
            # Accept numeric strings for integer fields, same fallback.
            try:
                return int(value)
            except:  # pylint:disable=bare-except
                pass
        return value

    @staticmethod
    def get_default():
        """Get the default instance of ProtoJson."""
        try:
            return ProtoJson.__default
        except AttributeError:
            # Lazily create the singleton on first use.
            ProtoJson.__default = ProtoJson()
            return ProtoJson.__default

    @staticmethod
    def set_default(protocol):
        """Set the default instance of ProtoJson.

        Args:
          protocol: A ProtoJson instance.

        Raises:
          TypeError if protocol is not a ProtoJson instance.
        """
        if not isinstance(protocol, ProtoJson):
            raise TypeError('Expected protocol of type ProtoJson')
        ProtoJson.__default = protocol
# Module-level convenience aliases that delegate to the class constants
# and the default ProtoJson instance.
CONTENT_TYPE = ProtoJson.CONTENT_TYPE
ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES
encode_message = ProtoJson.get_default().encode_message
decode_message = ProtoJson.get_default().decode_message

View File

@@ -0,0 +1,667 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Test utilities for message testing.
Includes module interface test to ensure that public parts of module are
correctly declared in __all__.
Includes message types that correspond to those defined in
services_test.proto.
Includes additional test utilities to make sure encoding/decoding libraries
conform.
"""
import cgi
import datetime
import inspect
import os
import re
import socket
import types
import unittest
import six
from six.moves import range # pylint: disable=redefined-builtin
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
# Unicode of the word "Russian" in Cyrillic — non-ASCII text fixture.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'
# Every byte value 0-255, each followed by a NUL — binary data fixture.
BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))
class TestCase(unittest.TestCase):
    """unittest.TestCase extended with protorpc-specific assertion helpers."""

    def assertRaisesWithRegexpMatch(self,
                                    exception,
                                    regexp,
                                    function,
                                    *params,
                                    **kwargs):
        """Check that exception is raised and text matches regular expression.

        Args:
            exception: Exception type that is expected.
            regexp: String regular expression that is expected in the
                error message.
            function: Callable to test.
            params: Positional arguments forwarded to function.
            kwargs: Keyword arguments forwarded to function.
        """
        try:
            function(*params, **kwargs)
            self.fail('Expected exception %s was not raised' %
                      exception.__name__)
        except exception as err:
            matched = bool(re.match(regexp, str(err)))
            self.assertTrue(
                matched,
                'Expected match "%s", found "%s"' % (regexp, err))

    def assertHeaderSame(self, header1, header2):
        """Check that two HTTP headers are the same.

        Args:
            header1: Header value string 1.
            header2: Header value string 2.
        """
        parsed1 = cgi.parse_header(header1)
        parsed2 = cgi.parse_header(header2)
        # Compare main value and parameter dict separately for clearer
        # failure messages.
        self.assertEqual(parsed1[0], parsed2[0])
        self.assertEqual(parsed1[1], parsed2[1])

    def assertIterEqual(self, iter1, iter2):
        """Check two iterators or iterables are equal independent of order.

        Similar to Python 2.7 assertItemsEqual.  Named differently in order
        to avoid potential conflict.

        Args:
            iter1: An iterator or iterable.
            iter2: An iterator or iterable.
        """
        remaining2 = list(iter2)
        unmatched1 = []
        # Greedily pair each item of iter1 with an equal item of iter2,
        # removing matches so duplicates are counted correctly.
        for item1 in list(iter1):
            for index, item2 in enumerate(remaining2):
                if item1 == item2:
                    del remaining2[index]
                    break
            else:
                unmatched1.append(item1)
        error_message = []
        for item in unmatched1:
            error_message.append(
                ' Item from iter1 not found in iter2: %r' % item)
        for item in remaining2:
            error_message.append(
                ' Item from iter2 not found in iter1: %r' % item)
        if error_message:
            self.fail('Collections not equivalent:\n' +
                      '\n'.join(error_message))
class ModuleInterfaceTest(object):
    """Test to ensure a module's public interface is carefully constructed.

    A module interface is the set of public objects listed in the module
    __all__ attribute.  Modules that are considered public should declare
    this interface carefully.  At all times, the __all__ attribute should
    contain exactly the objects intended for public use, and all other
    objects in the module should be considered private.

    Protected attributes (those beginning with '_') and imported modules
    must not be part of this set of variables.  An exception is made for
    variables that begin and end with '__', which are implicitly part of
    the interface (e.g. __name__, __file__, __all__ itself, etc.).

    Modules imported into the tested module are an exception and may be
    left out of the __all__ definition.  The test detects them by checking
    whether a would-be public name is bound to a module object; explicitly
    exporting modules is for the time being not permitted.

    To use this test class, define a new class that inherits first from
    ModuleInterfaceTest and then from test_util.TestCase.  No other tests
    should be added to that test case.  If setUp is overridden, keep
    ModuleInterfaceTest first in the bases so its setUp runs.

    Multiple inheritance is required so that ModuleInterfaceTest is not
    itself collected and executed as a test case.

    The test class is expected to define the class attribute:

      MODULE: A reference to the module being validated for interface
          correctness.

    Example:
        Module definition (hello.py):

            import sys

            __all__ = ['hello']

            def _get_outputter():
                return sys.stdout

            def hello():
                _get_outputter().write('Hello\n')

        Test definition:

            import unittest
            from protorpc import test_util

            import hello

            class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                                      test_util.TestCase):
                MODULE = hello

            class HelloTest(test_util.TestCase):
                ... Test 'hello' module ...

            if __name__ == '__main__':
                unittest.main()
    """

    def setUp(self):
        """Verify that the subclass declared the MODULE attribute.

        This is a basic configuration check for the test itself, so it
        does not get its own test case.
        """
        if not hasattr(self, 'MODULE'):
            self.fail(
                "You must define 'MODULE' on ModuleInterfaceTest sub-class "
                "%s." % type(self).__name__)

    def testAllExist(self):
        """Test that all attributes defined in __all__ exist."""
        missing_attributes = [
            attribute for attribute in self.MODULE.__all__
            if not hasattr(self.MODULE, attribute)]
        if missing_attributes:
            self.fail('%s of __all__ are not defined in module.' %
                      missing_attributes)

    def testAllExported(self):
        """Test that all public attributes not imported are in __all__."""
        missing_attributes = []
        for attribute in dir(self.MODULE):
            if attribute.startswith('_'):
                continue
            if attribute in self.MODULE.__all__:
                continue
            # Imported modules and the py2 'with_statement' future flag
            # are exempt from the export requirement.
            if isinstance(getattr(self.MODULE, attribute),
                          types.ModuleType):
                continue
            if attribute == 'with_statement':
                continue
            missing_attributes.append(attribute)
        if missing_attributes:
            self.fail('%s are not modules and not defined in __all__.' %
                      missing_attributes)

    def testNoExportedProtectedVariables(self):
        """Test that there are no protected variables listed in __all__."""
        protected_variables = [
            attribute for attribute in self.MODULE.__all__
            if attribute.startswith('_')]
        if protected_variables:
            self.fail('%s are protected variables and may not be exported.' %
                      protected_variables)

    def testNoExportedModules(self):
        """Test that no modules exist in __all__."""
        exported_modules = []
        for attribute in self.MODULE.__all__:
            try:
                value = getattr(self.MODULE, attribute)
            except AttributeError:
                # Missing attributes are covered by testAllExist.
                continue
            if isinstance(value, types.ModuleType):
                exported_modules.append(attribute)
        if exported_modules:
            self.fail('%s are modules and may not be exported.' %
                      exported_modules)
class NestedMessage(messages.Message):
    """Simple message that gets nested in another message."""

    # Required, so encoding an empty instance fails validation
    # (exercised by conformance tests).
    a_value = messages.StringField(1, required=True)
class HasNestedMessage(messages.Message):
    """Message that has another message nested in it."""

    # Singular and repeated nestings of NestedMessage.
    nested = messages.MessageField(NestedMessage, 1)
    repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
class HasDefault(messages.Message):
    """Message with a single field that has a default value."""

    # Default is only serialized when explicitly assigned (see
    # conformance tests for do-not-send-default behavior).
    a_value = messages.StringField(1, default=u'a default')
class OptionalMessage(messages.Message):
    """Contains one optional field of every message type."""

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    # One field per wire variant; field numbers 1-8 and 10.
    double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
    float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
    int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
    uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
    int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
    bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
    string_value = messages.StringField(7, variant=messages.Variant.STRING)
    bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
    enum_value = messages.EnumField(SimpleEnum, 10)
class RepeatedMessage(messages.Message):
    """Contains all message types as repeated fields.

    Mirrors OptionalMessage field-for-field (same names, numbers and
    variants) with repeated=True on every field.
    """

    class SimpleEnum(messages.Enum):
        """Simple enumeration type."""
        VAL1 = 1
        VAL2 = 2

    double_value = messages.FloatField(1,
                                       variant=messages.Variant.DOUBLE,
                                       repeated=True)
    float_value = messages.FloatField(2,
                                      variant=messages.Variant.FLOAT,
                                      repeated=True)
    int64_value = messages.IntegerField(3,
                                        variant=messages.Variant.INT64,
                                        repeated=True)
    uint64_value = messages.IntegerField(4,
                                         variant=messages.Variant.UINT64,
                                         repeated=True)
    int32_value = messages.IntegerField(5,
                                        variant=messages.Variant.INT32,
                                        repeated=True)
    bool_value = messages.BooleanField(6,
                                       variant=messages.Variant.BOOL,
                                       repeated=True)
    string_value = messages.StringField(7,
                                        variant=messages.Variant.STRING,
                                        repeated=True)
    bytes_value = messages.BytesField(8,
                                      variant=messages.Variant.BYTES,
                                      repeated=True)
    enum_value = messages.EnumField(SimpleEnum,
                                    10,
                                    repeated=True)
class HasOptionalNestedMessage(messages.Message):
    """Message nesting OptionalMessage, singularly and repeated."""

    nested = messages.MessageField(OptionalMessage, 1)
    repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
    """Protocol conformance test base class.

    Each supported protocol should implement two methods that support
    encoding and decoding of Message objects in that format:

        encode_message(message) - Serialize to encoding.
        decode_message(message_type, encoded_message) - Deserialize from
            encoding.

    Tests for the modules where these functions are implemented should extend
    this class in order to support basic behavioral expectations.  This
    ensures that protocols correctly encode and decode message transparently
    to the caller.

    Subclasses must also extend the TestCase class, set PROTOLIB to the
    protocol module under test, and implement the following class attributes
    which define the encoded version of certain protocol buffers
    (plus encoded_string_types and encoded_repeated_nested_empty, used by
    individual tests below):

    encoded_partial:
        <OptionalMessage
          double_value: 1.23
          int64_value: -100000000000
          string_value: u"a string"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

    encoded_full:
        <OptionalMessage
          double_value: 1.23
          float_value: -2.5
          int64_value: -100000000000
          uint64_value: 102020202020
          int32_value: 1020
          bool_value: true
          string_value: u"a string\u044f"
          bytes_value: b"a bytes\xff\xfe"
          enum_value: OptionalMessage.SimpleEnum.VAL2
          >

    encoded_repeated:
        <RepeatedMessage
          double_value: [1.23, 2.3]
          float_value: [-2.5, 0.5]
          int64_value: [-100000000000, 20]
          uint64_value: [102020202020, 10]
          int32_value: [1020, 718]
          bool_value: [true, false]
          string_value: [u"a string\u044f", u"another string"]
          bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
          enum_value: [OptionalMessage.SimpleEnum.VAL2,
                       OptionalMessage.SimpleEnum.VAL1]
          >

    encoded_nested:
        <HasNestedMessage
          nested: <NestedMessage
            a_value: "a string"
            >
          >

    encoded_repeated_nested:
        <HasNestedMessage
          repeated_nested: [
              <NestedMessage a_value: "a string">,
              <NestedMessage a_value: "another string">
          ]
          >

    unexpected_tag_message:
        An encoded message that has an undefined tag or number in the stream.

    encoded_default_assigned:
        <HasDefault
          a_value: "a default"
          >

    encoded_nested_empty:
        <HasOptionalNestedMessage
          nested: <OptionalMessage>
          >

    encoded_invalid_enum:
        <OptionalMessage
          enum_value: (invalid value for serialization type)
          >

    encoded_invalid_repeated_enum:
        <RepeatedMessage
          enum_value: (invalid value for serialization type)
          >
    """

    # An empty message must round-trip through the empty encoding.
    encoded_empty_message = ''

    def testEncodeInvalidMessage(self):
        # NestedMessage.a_value is required; encoding an uninitialized
        # instance must raise a validation error.
        message = NestedMessage()
        self.assertRaises(messages.ValidationError,
                          self.PROTOLIB.encode_message, message)

    def CompareEncoded(self, expected_encoded, actual_encoded):
        """Compare two encoded protocol values.

        Can be overridden by sub-classes to special case comparison.
        For example, to eliminate white space from output that is not
        relevant to encoding.

        Args:
            expected_encoded: Expected string encoded value.
            actual_encoded: Actual string encoded value.
        """
        self.assertEqual(expected_encoded, actual_encoded)

    def EncodeDecode(self, encoded, expected_message):
        # Round-trip helper: decode must yield the expected message and
        # re-encoding that message must reproduce the original encoding.
        message = self.PROTOLIB.decode_message(type(expected_message), encoded)
        self.assertEqual(expected_message, message)
        self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))

    def testEmptyMessage(self):
        self.EncodeDecode(self.encoded_empty_message, OptionalMessage())

    def testPartial(self):
        """Test message with a few values set."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.int64_value = -100000000000
        message.int32_value = 1020
        message.string_value = u'a string'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2
        self.EncodeDecode(self.encoded_partial, message)

    def testFull(self):
        """Test all types."""
        message = OptionalMessage()
        message.double_value = 1.23
        message.float_value = -2.5
        message.int64_value = -100000000000
        message.uint64_value = 102020202020
        message.int32_value = 1020
        message.bool_value = True
        message.string_value = u'a string\u044f'
        message.bytes_value = b'a bytes\xff\xfe'
        message.enum_value = OptionalMessage.SimpleEnum.VAL2
        self.EncodeDecode(self.encoded_full, message)

    def testRepeated(self):
        """Test repeated fields."""
        message = RepeatedMessage()
        message.double_value = [1.23, 2.3]
        message.float_value = [-2.5, 0.5]
        message.int64_value = [-100000000000, 20]
        message.uint64_value = [102020202020, 10]
        message.int32_value = [1020, 718]
        message.bool_value = [True, False]
        message.string_value = [u'a string\u044f', u'another string']
        message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
                              RepeatedMessage.SimpleEnum.VAL1]
        self.EncodeDecode(self.encoded_repeated, message)

    def testNested(self):
        """Test nested messages."""
        nested_message = NestedMessage()
        nested_message.a_value = u'a string'

        message = HasNestedMessage()
        message.nested = nested_message

        self.EncodeDecode(self.encoded_nested, message)

    def testRepeatedNested(self):
        """Test repeated nested messages."""
        nested_message1 = NestedMessage()
        nested_message1.a_value = u'a string'
        nested_message2 = NestedMessage()
        nested_message2.a_value = u'another string'

        message = HasNestedMessage()
        message.repeated_nested = [nested_message1, nested_message2]

        self.EncodeDecode(self.encoded_repeated_nested, message)

    def testStringTypes(self):
        """Test that encoding str on StringField works."""
        message = OptionalMessage()
        message.string_value = 'Latin'
        self.EncodeDecode(self.encoded_string_types, message)

    def testEncodeUninitialized(self):
        """Test that cannot encode uninitialized message."""
        required = NestedMessage()

        self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                         "Message NestedMessage is missing "
                                         "required field a_value",
                                         self.PROTOLIB.encode_message,
                                         required)

    def testUnexpectedField(self):
        """Test decoding and encoding unexpected fields."""
        loaded_message = self.PROTOLIB.decode_message(
            OptionalMessage, self.unexpected_tag_message)
        # Message should be equal to an empty message, since unknown
        # values aren't included in equality.
        self.assertEqual(OptionalMessage(), loaded_message)
        # Verify that the encoded message matches the source, including the
        # unknown value.
        self.assertEqual(self.unexpected_tag_message,
                         self.PROTOLIB.encode_message(loaded_message))

    def testDoNotSendDefault(self):
        """Test that default is not sent when nothing is assigned."""
        self.EncodeDecode(self.encoded_empty_message, HasDefault())

    def testSendDefaultExplicitlyAssigned(self):
        """Test that default is sent when explicitly assigned."""
        message = HasDefault()

        message.a_value = HasDefault.a_value.default

        self.EncodeDecode(self.encoded_default_assigned, message)

    def testEncodingNestedEmptyMessage(self):
        """Test encoding a nested empty message."""
        message = HasOptionalNestedMessage()
        message.nested = OptionalMessage()

        self.EncodeDecode(self.encoded_nested_empty, message)

    def testEncodingRepeatedNestedEmptyMessage(self):
        """Test encoding repeated nested empty messages."""
        message = HasOptionalNestedMessage()
        message.repeated_nested = [OptionalMessage(), OptionalMessage()]

        self.EncodeDecode(self.encoded_repeated_nested_empty, message)

    def testContentType(self):
        self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str))

    def testDecodeInvalidEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(OptionalMessage,
                                               self.encoded_invalid_enum)
        message = OptionalMessage()
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_enum, encoded)

    def testDecodeInvalidRepeatedEnumType(self):
        # Since protos need to be able to add new enums, a message should be
        # successfully decoded even if the enum value is invalid. Encoding the
        # decoded message should result in equivalence with the original
        # encoded message containing an invalid enum.
        decoded = self.PROTOLIB.decode_message(
            RepeatedMessage, self.encoded_invalid_repeated_enum)
        message = RepeatedMessage()
        message.enum_value = [RepeatedMessage.SimpleEnum.VAL1]
        self.assertEqual(message, decoded)
        encoded = self.PROTOLIB.encode_message(decoded)
        self.assertEqual(self.encoded_invalid_repeated_enum, encoded)

    def testDateTimeNoTimeZone(self):
        """Test that DateTimeFields are encoded/decoded correctly."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)

    def testDateTimeWithTimeZone(self):
        """Test DateTimeFields with time zones."""

        class MyMessage(messages.Message):
            value = message_types.DateTimeField(1)

        value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
                                  util.TimeZoneOffset(8 * 60))
        message = MyMessage(value=value)
        decoded = self.PROTOLIB.decode_message(
            MyMessage, self.PROTOLIB.encode_message(message))
        self.assertEqual(decoded.value, value)
def pick_unused_port():
    """Find an unused TCP port to use in tests.

    Derived from Damon Kohler's example:
    http://code.activestate.com/recipes/531822-pick-unused-port

    Returns:
        An integer port number that was free at the time of the call.
    """
    # Bind to port 0 and let the OS pick an ephemeral port, then release it.
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('localhost', 0))
        return probe.getsockname()[1]
    finally:
        probe.close()
def get_module_name(module_attribute):
    """Get the module name.

    Args:
        module_attribute: An attribute of the module.

    Returns:
        The fully qualified module name, or the simple module file name
        where 'module_attribute' is defined if the module name is
        "__main__".
    """
    module_name = module_attribute.__module__
    if module_name != '__main__':
        return module_name
    # For __main__, fall back to the script's base file name sans extension.
    source_file = inspect.getfile(module_attribute)
    return os.path.basename(source_file).split('.')[0]

View File

@@ -0,0 +1,314 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utility library."""
from __future__ import with_statement
import datetime
import functools
import inspect
import logging
import os
import re
import sys
import six
# Public interface of this module (everything else is private).
__all__ = [
    'Error',
    'decode_datetime',
    'get_package_for_module',
    'positional',
    'TimeZoneOffset',
    'total_seconds',
]
class Error(Exception):
    """Base class for protorpc exceptions.

    Package-specific exceptions should derive from this class so callers
    can catch all protorpc errors with a single except clause.
    """
# Matches an RFC 3339 time zone suffix at the end of a datetime string:
# either the literal "Z" (UTC) or a signed "hh:mm" offset.  The original
# in-pattern comment gave "Z12:00" as an example, which this pattern can
# never match (the alternation is Z *or* a numeric offset, anchored at
# end of string) — corrected below.
_TIME_ZONE_RE_STRING = r"""
  # Examples:
  #   Z
  #   +01:00
  #   -05:30
  ((?P<z>Z) | (?P<sign>[-+])
   (?P<hours>\d\d) :
   (?P<minutes>\d\d))$
"""
_TIME_ZONE_RE = re.compile(_TIME_ZONE_RE_STRING, re.IGNORECASE | re.VERBOSE)
def positional(max_positional_args):
    """A decorator that declares only the first N arguments may be positional.

    This decorator makes it easy to support Python 3 style keyword-only
    parameters. For example, in Python 3 it is possible to write:

        def fn(pos1, *, kwonly1=None, kwonly2=None):
            ...

    All named parameters after * must be a keyword:

        fn(10, 'kw1', 'kw2')  # Raises exception.
        fn(10, kwonly1='kw1')  # Ok.

    Example:
        To define a function like above, do:

            @positional(1)
            def fn(pos1, kwonly1=None, kwonly2=None):
                ...

        If no default value is provided to a keyword argument, it
        becomes a required keyword argument:

            @positional(0)
            def fn(required_kw):
                ...

        This must be called with the keyword parameter:

            fn()  # Raises exception.
            fn(10)  # Raises exception.
            fn(required_kw=10)  # Ok.

        When defining instance or class methods always remember to account
        for 'self' and 'cls':

            class MyClass(object):

                @positional(2)
                def my_method(self, pos1, kwonly1=None):
                    ...

                @classmethod
                @positional(2)
                def my_method(cls, pos1, kwonly1=None):
                    ...

    One can omit the argument to 'positional' altogether, and then no
    arguments with default values may be passed positionally. This
    would be equivalent to placing a '*' before the first argument
    with a default value in Python 3. If there are no arguments with
    default values, and no argument is given to 'positional', an error
    is raised.

        @positional
        def fn(arg1, arg2, required_kw1=None, required_kw2=0):
            ...

        fn(1, 3, 5)  # Raises exception.
        fn(1, 3)  # Ok.
        fn(1, 3, required_kw1=5)  # Ok.

    Args:
        max_positional_args: Maximum number of positional arguments.  All
            parameters after this index must be passed as keywords.  May
            also be the decorated function itself, when 'positional' is
            used as a bare decorator.

    Returns:
        A decorator that prevents arguments after max_positional_args from
        being used as positional parameters.

    Raises:
        TypeError: If a keyword-only argument is provided as a positional
            parameter at call time.
        ValueError: If no maximum number of arguments is provided and the
            function has no arguments with default values.
    """
    def positional_decorator(wrapped):
        """Wrap a callable to enforce the positional-argument limit."""
        @functools.wraps(wrapped)
        def positional_wrapper(*args, **kwargs):
            if len(args) > max_positional_args:
                plural_s = 's' if max_positional_args != 1 else ''
                raise TypeError('%s() takes at most %d positional argument%s '
                                '(%d given)' % (wrapped.__name__,
                                                max_positional_args,
                                                plural_s, len(args)))
            return wrapped(*args, **kwargs)
        return positional_wrapper

    if not isinstance(max_positional_args, six.integer_types):
        # Used as a bare decorator: infer the limit as the number of
        # arguments that lack default values.
        decorated_function = max_positional_args
        arg_names, _, _, arg_defaults, *_ = inspect.getfullargspec(
            decorated_function)
        if arg_defaults is None:
            raise ValueError(
                'Functions with no keyword arguments must specify '
                'max_positional_args')
        limit = len(arg_names) - len(arg_defaults)
        return positional(limit)(decorated_function)
    return positional_decorator
@positional(1)
def get_package_for_module(module):
    """Get package name for a module.

    Helper calculates the package name of a module.

    Args:
        module: Module to get name for.  If module is a string, try to find
            module in sys.modules.

    Returns:
        If module contains a 'package' attribute, that value as text.
        Else, if module is not the '__main__' module, the module __name__.
        Else, the base name of the module file name.  None if a named
        module is not found in sys.modules.
    """
    if isinstance(module, six.string_types):
        module = sys.modules.get(module)
        if module is None:
            return None

    # An explicit 'package' attribute always wins.
    try:
        package = module.package
    except AttributeError:
        pass
    else:
        return six.text_type(package)

    if module.__name__ == '__main__':
        try:
            file_name = module.__file__
        except AttributeError:
            pass
        else:
            base_name = os.path.basename(file_name)
            split_name = os.path.splitext(base_name)
            if len(split_name) == 1:
                return six.text_type(base_name)
            return u'.'.join(split_name[:-1])

    return six.text_type(module.__name__)
def total_seconds(offset):
    """Backport of offset.total_seconds() from python 2.7+.

    Args:
        offset: A datetime.timedelta instance.

    Returns:
        The total duration of the timedelta expressed as a float number
        of seconds.
    """
    whole_seconds = offset.days * 24 * 60 * 60 + offset.seconds
    total_microseconds = whole_seconds * 10**6 + offset.microseconds
    # Divide by a float so fractional seconds are preserved (py2-safe).
    return total_microseconds / (10**6 * 1.0)
class TimeZoneOffset(datetime.tzinfo):
    """Fixed-offset time zone info as encoded/decoded for DateTimeFields."""

    def __init__(self, offset):
        """Initialize a time zone offset.

        Args:
            offset: Integer or timedelta time zone offset, in minutes from
                UTC.  This can be negative.
        """
        super(TimeZoneOffset, self).__init__()
        if isinstance(offset, datetime.timedelta):
            minutes = total_seconds(offset) / 60
        else:
            minutes = offset
        self.__offset = minutes

    def utcoffset(self, _):
        """Get a timedelta with the time zone's offset from UTC.

        Returns:
            The time zone offset from UTC, as a timedelta.
        """
        return datetime.timedelta(minutes=self.__offset)

    def dst(self, _):
        """Get the daylight savings time offset.

        The formats that ProtoRPC uses to encode/decode time zone
        information don't carry daylight savings time, so this is
        always a zero timedelta.

        Returns:
            A timedelta of 0.
        """
        return datetime.timedelta(0)
def decode_datetime(encoded_datetime, truncate_time=False):
    """Decode a DateTimeField parameter from a string to a python datetime.

    Args:
        encoded_datetime: A string in RFC 3339 format.
        truncate_time: If true, truncate fractional seconds with precision
            higher than microseconds instead of failing.

    Returns:
        A datetime object with the date and time specified in
        encoded_datetime.  The result is timezone-aware only when the
        input carried a time zone suffix.

    Raises:
        ValueError: If the string is not in a recognized format.
    """
    # Check if the string includes a time zone offset.  Break out the
    # part that doesn't include time zone info.  Convert to uppercase
    # because all our comparisons should be case-insensitive.
    time_zone_match = _TIME_ZONE_RE.search(encoded_datetime)
    if time_zone_match:
        time_string = encoded_datetime[:time_zone_match.start(1)].upper()
    else:
        time_string = encoded_datetime.upper()

    # Pick the format depending on whether fractional seconds are present.
    if '.' in time_string:
        format_string = '%Y-%m-%dT%H:%M:%S.%f'
    else:
        format_string = '%Y-%m-%dT%H:%M:%S'

    try:
        decoded_datetime = datetime.datetime.strptime(time_string,
                                                      format_string)
    except ValueError:
        if truncate_time and '.' in time_string:
            datetime_string, decimal_secs = time_string.split('.')
            if len(decimal_secs) > 6:
                # datetime can handle only microsecs precision.
                truncated_time_string = '{}.{}'.format(
                    datetime_string, decimal_secs[:6])
                decoded_datetime = datetime.datetime.strptime(
                    truncated_time_string,
                    format_string)
                logging.warning(
                    'Truncating the datetime string from %s to %s',
                    time_string, truncated_time_string)
            else:
                # Fractional part was already <= microsecond precision,
                # so the parse failure has another cause; re-raise it.
                raise
        else:
            raise

    if not time_zone_match:
        return decoded_datetime

    # Time zone info was included in the parameter.  Add a tzinfo
    # object to the datetime.  Datetimes can't be changed after they're
    # created, so we'll need to create a new one.
    if time_zone_match.group('z'):
        offset_minutes = 0
    else:
        sign = time_zone_match.group('sign')
        hours, minutes = [int(value) for value in
                          time_zone_match.group('hours', 'minutes')]
        offset_minutes = hours * 60 + minutes
        if sign == '-':
            offset_minutes *= -1

    return datetime.datetime(decoded_datetime.year,
                             decoded_datetime.month,
                             decoded_datetime.day,
                             decoded_datetime.hour,
                             decoded_datetime.minute,
                             decoded_datetime.second,
                             decoded_datetime.microsecond,
                             TimeZoneOffset(offset_minutes))

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python
#
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file replaces the version that imports everything to the top level.
Different from the normal dependency generation process, this file just replaces
a single file (apitools/base/py/__init__.py) from Google internal apitools
during the dependency generation process. This allows us to make a breaking
change to apitools without complicating the process of updating apitools, as
changes are frequently pulled in from upstream.
"""

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Top-level imports for apitools base files."""
# pylint:disable=wildcard-import
# pylint:disable=redefined-builtin
from apitools.base.py.base_api import *
from apitools.base.py.batch import *
from apitools.base.py.credentials_lib import *
from apitools.base.py.encoding import *
from apitools.base.py.exceptions import *
from apitools.base.py.extra_types import *
from apitools.base.py.http_wrapper import *
from apitools.base.py.list_pager import *
from apitools.base.py.transfer import *
from apitools.base.py.util import *
try:
# pylint:disable=no-name-in-module
from apitools.base.py.internal import *
except ImportError:
pass

View File

@@ -0,0 +1,753 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for api services."""
import base64
import contextlib
import datetime
import logging
import pprint
import six
from six.moves import http_client
from six.moves import urllib
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import http_wrapper
from apitools.base.py import util
# Public interface of this module.
__all__ = [
    'ApiMethodInfo',
    'ApiUploadInfo',
    'BaseApiClient',
    'BaseApiService',
    'NormalizeApiEndpoint',
]

# TODO(user): Remove this once we quiet the spurious logging in
# oauth2client (or drop oauth2client).
logging.getLogger('oauth2client.util').setLevel(logging.ERROR)

# Upper bound, in characters, on request URLs built by this module.
# NOTE(review): behavior for URLs beyond this limit is defined at the call
# sites elsewhere in this module — confirm there before changing the value.
_MAX_URL_LENGTH = 2048
class ApiUploadInfo(messages.Message):
    """Media upload information for a method.

    Fields:
      accept: (repeated) MIME Media Ranges for acceptable media uploads
          to this method.
      max_size: (integer) Maximum size of a media upload, such as 3MB
          or 1TB (converted to an integer).
      resumable_path: Path to use for resumable uploads.
      resumable_multipart: (boolean) Whether or not the resumable endpoint
          supports multipart uploads.
      simple_path: Path to use for simple uploads.
      simple_multipart: (boolean) Whether or not the simple endpoint
          supports multipart uploads.
    """
    # Field numbers follow the order documented above.
    accept = messages.StringField(1, repeated=True)
    max_size = messages.IntegerField(2)
    resumable_path = messages.StringField(3)
    resumable_multipart = messages.BooleanField(4)
    simple_path = messages.StringField(5)
    simple_multipart = messages.BooleanField(6)
class ApiMethodInfo(messages.Message):

    """Configuration info for an API method.

    All fields are strings unless noted otherwise.

    Fields:
      relative_path: Relative path for this method.
      flat_path: Expanded version (if any) of relative_path.
      method_id: ID for this method.
      http_method: HTTP verb to use for this method.
      path_params: (repeated) path parameters for this method.
      query_params: (repeated) query parameters for this method.
      ordered_params: (repeated) ordered list of parameters for
          this method.
      description: description of this method.
      request_type_name: name of the request type.
      response_type_name: name of the response type.
      request_field: if not null, the field to pass as the body
          of this POST request. may also be the REQUEST_IS_BODY
          value below to indicate the whole message is the body.
      upload_config: (ApiUploadInfo) Information about the upload
          configuration supported by this method.
      supports_download: (boolean) If True, this method supports
          downloading the request via the `alt=media` query
          parameter.
      api_version_param: API version system parameter for this
          method.
    """

    # Field numbers identify these fields in serialized messages; do not
    # renumber them.
    relative_path = messages.StringField(1)
    flat_path = messages.StringField(2)
    method_id = messages.StringField(3)
    http_method = messages.StringField(4)
    path_params = messages.StringField(5, repeated=True)
    query_params = messages.StringField(6, repeated=True)
    ordered_params = messages.StringField(7, repeated=True)
    description = messages.StringField(8)
    request_type_name = messages.StringField(9)
    response_type_name = messages.StringField(10)
    request_field = messages.StringField(11, default='')
    upload_config = messages.MessageField(ApiUploadInfo, 12)
    supports_download = messages.BooleanField(13, default=False)
    api_version_param = messages.StringField(14)


# Sentinel for ApiMethodInfo.request_field indicating that the whole
# request message (not one of its fields) is the HTTP body.
REQUEST_IS_BODY = '<request>'
def _LoadClass(name, messages_module):
if name.startswith('message_types.'):
_, _, classname = name.partition('.')
return getattr(message_types, classname)
elif '.' not in name:
return getattr(messages_module, name)
else:
raise exceptions.GeneratedClientError('Unknown class %s' % name)
def _RequireClassAttrs(obj, attrs):
for attr in attrs:
attr_name = attr.upper()
if not hasattr(obj, '%s' % attr_name) or not getattr(obj, attr_name):
msg = 'No %s specified for object of class %s.' % (
attr_name, type(obj).__name__)
raise exceptions.GeneratedClientError(msg)
def NormalizeApiEndpoint(api_endpoint):
    """Return api_endpoint, guaranteed to end with exactly its own '/'.

    Args:
      api_endpoint: Base URL of an API endpoint.

    Returns:
      The endpoint with a trailing '/' appended if it lacked one.
    """
    return api_endpoint if api_endpoint.endswith('/') else api_endpoint + '/'
def _urljoin(base, url): # pylint: disable=invalid-name
"""Custom urljoin replacement supporting : before / in url."""
# In general, it's unsafe to simply join base and url. However, for
# the case of discovery documents, we know:
# * base will never contain params, query, or fragment
# * url will never contain a scheme or net_loc.
# In general, this means we can safely join on /; we just need to
# ensure we end up with precisely one / joining base and url. The
# exception here is the case of media uploads, where url will be an
# absolute url.
if url.startswith('http://') or url.startswith('https://'):
return urllib.parse.urljoin(base, url)
new_base = base if base.endswith('/') else base + '/'
new_url = url[1:] if url.startswith('/') else url
return new_base + new_url
class _UrlBuilder(object):
    """Mutable container for the pieces of a URL.

    Tracks scheme + netloc, a relative path, and query parameters stored
    as a dict of lists (the urllib.parse.parse_qs representation).
    """

    def __init__(self, base_url, relative_path=None, query_params=None):
        parts = urllib.parse.urlsplit(_urljoin(
            base_url, relative_path or ''))
        if parts.fragment:
            raise exceptions.ConfigurationValueError(
                'Unexpected url fragment: %s' % parts.fragment)
        # parse_qs yields {name: [values]}; caller-supplied params are
        # layered on top and may override parsed ones.
        self.query_params = urllib.parse.parse_qs(parts.query or '')
        if query_params is not None:
            self.query_params.update(query_params)
        self.__scheme = parts.scheme
        self.__netloc = parts.netloc
        self.relative_path = parts.path or ''

    @classmethod
    def FromUrl(cls, url):
        """Build a _UrlBuilder by decomposing a complete URL."""
        pieces = urllib.parse.urlsplit(url)
        return cls(
            urllib.parse.urlunsplit(
                (pieces.scheme, pieces.netloc, '', None, None)),
            relative_path=pieces.path or '',
            query_params=urllib.parse.parse_qs(pieces.query))

    @property
    def base_url(self):
        """Scheme + netloc with empty path/query/fragment."""
        return urllib.parse.urlunsplit(
            (self.__scheme, self.__netloc, '', '', ''))

    @base_url.setter
    def base_url(self, value):
        parts = urllib.parse.urlsplit(value)
        if parts.path or parts.query or parts.fragment:
            raise exceptions.ConfigurationValueError(
                'Invalid base url: %s' % value)
        self.__scheme, self.__netloc = parts.scheme, parts.netloc

    @property
    def query(self):
        """The encoded query string for the current query_params."""
        # TODO(user): In the case that some of the query params are
        # non-ASCII, we may silently fail to encode correctly. We should
        # figure out who is responsible for owning the object -> str
        # conversion.
        return urllib.parse.urlencode(self.query_params, doseq=True)

    @property
    def url(self):
        """The full URL; rejects unexpanded {param} path templates."""
        if '{' in self.relative_path or '}' in self.relative_path:
            raise exceptions.ConfigurationValueError(
                'Cannot create url with relative path %s'
                % self.relative_path)
        return urllib.parse.urlunsplit((
            self.__scheme, self.__netloc, self.relative_path,
            self.query, ''))
def _SkipGetCredentials():
    """Hook for skipping credentials. For internal use.

    Always returns False here; BaseApiClient.__init__ consults it before
    fetching credentials. NOTE(review): presumably replaced/patched in
    internal builds -- confirm before relying on it.
    """
    return False
class BaseApiClient(object):
    """Base class for generated API client libraries.

    Holds shared configuration (endpoint URL, credentials, HTTP object,
    retry policy, logging switches) plus the request/response processing
    hooks used by every BaseApiService attached to this client.
    """

    # Module containing the generated protorpc message classes for this
    # API; generated subclasses must override this.
    MESSAGES_MODULE = None

    # Per-API constants filled in by generated subclasses.
    _API_KEY = ''
    _CLIENT_ID = ''
    _CLIENT_SECRET = ''
    _PACKAGE = ''
    _SCOPES = []
    _USER_AGENT = ''

    def __init__(self, url, credentials=None, get_credentials=True, http=None,
                 model=None, log_request=False, log_response=False,
                 num_retries=5, max_retry_wait=60, credentials_args=None,
                 default_global_params=None, additional_http_headers=None,
                 check_response_func=None, retry_func=None,
                 response_encoding=None):
        """Create a new API client.

        Args:
          url: Base URL for the API; a trailing '/' is appended if absent.
          credentials: (optional) Credentials to authorize requests with.
          get_credentials: (boolean) Whether to fetch credentials when
              none are supplied.
          http: (optional) HTTP object for requests; a default is created
              if not given.
          model: Deprecated and unused.
          log_request: (boolean) Log outgoing requests.
          log_response: (boolean) Log incoming responses.
          num_retries: (non-negative integer) Retries for failed requests.
          max_retry_wait: (positive integer) Max seconds between retries.
          credentials_args: (optional dict) Extra keyword arguments for
              credential fetching.
          default_global_params: (optional) StandardQueryParameters
              message merged into every request.
          additional_http_headers: (optional dict) Headers added to every
              request.
          check_response_func: (optional) Response-check callback passed
              through to http_wrapper.MakeRequest.
          retry_func: (optional) Retry callback passed through to
              http_wrapper.MakeRequest.
          response_encoding: (optional) Encoding used to decode response
              content bytes (e.g. 'utf8').
        """
        _RequireClassAttrs(self, ('_package', '_scopes', 'messages_module'))
        if default_global_params is not None:
            util.Typecheck(default_global_params, self.params_type)
        self.__default_global_params = default_global_params
        self.log_request = log_request
        self.log_response = log_response
        self.__num_retries = 5
        self.__max_retry_wait = 60
        # We let the @property machinery below do our validation.
        self.num_retries = num_retries
        self.max_retry_wait = max_retry_wait
        self._credentials = credentials
        get_credentials = get_credentials and not _SkipGetCredentials()
        if get_credentials and not credentials:
            credentials_args = credentials_args or {}
            self._SetCredentials(**credentials_args)
        self._url = NormalizeApiEndpoint(url)
        self._http = http or http_wrapper.GetHttp()

        # Note that "no credentials" is totally possible.
        if self._credentials is not None:
            self._http = self._credentials.authorize(self._http)

        # TODO(user): Remove this field when we switch to proto2.
        self.__include_fields = None

        self.additional_http_headers = additional_http_headers or {}
        self.check_response_func = check_response_func
        self.retry_func = retry_func
        self.response_encoding = response_encoding

        # Since we can't change the init arguments without regenerating
        # clients, offer this hook to affect FinalizeTransferUrl behavior.
        self.overwrite_transfer_urls_with_client_base = False

        # TODO(user): Finish deprecating these fields.
        _ = model

        self.__response_type_model = 'proto'

    def _SetCredentials(self, **kwds):
        """Fetch credentials, and set them for this client.

        Note that we can't simply return credentials, since creating them
        may involve side-effecting self.

        Args:
          **kwds: Additional keyword arguments are passed on to
              GetCredentials.

        Returns:
          None. Sets self._credentials.
        """
        args = {
            'api_key': self._API_KEY,
            'client': self,
            'client_id': self._CLIENT_ID,
            'client_secret': self._CLIENT_SECRET,
            'package_name': self._PACKAGE,
            'scopes': self._SCOPES,
            'user_agent': self._USER_AGENT,
        }
        args.update(kwds)
        # credentials_lib can be expensive to import so do it only if needed.
        from apitools.base.py import credentials_lib
        # TODO(user): It's a bit dangerous to pass this
        # still-half-initialized self into this method, but we might need
        # to set attributes on it associated with our credentials.
        # Consider another way around this (maybe a callback?) and whether
        # or not it's worth it.
        self._credentials = credentials_lib.GetCredentials(**args)

    @classmethod
    def ClientInfo(cls):
        """Return this client's OAuth configuration as a dict."""
        return {
            'client_id': cls._CLIENT_ID,
            'client_secret': cls._CLIENT_SECRET,
            'scope': ' '.join(sorted(util.NormalizeScopes(cls._SCOPES))),
            'user_agent': cls._USER_AGENT,
        }

    @property
    def base_model_class(self):
        """Deprecated; always None."""
        return None

    @property
    def http(self):
        """The (possibly authorized) HTTP object used for requests."""
        return self._http

    @property
    def url(self):
        """The normalized base URL for this API endpoint."""
        return self._url

    @classmethod
    def GetScopes(cls):
        """Return the OAuth scopes requested by this client."""
        return cls._SCOPES

    @property
    def params_type(self):
        """The StandardQueryParameters message class for this API."""
        return _LoadClass('StandardQueryParameters', self.MESSAGES_MODULE)

    @property
    def user_agent(self):
        """The user-agent string sent with every request."""
        return self._USER_AGENT

    @property
    def _default_global_params(self):
        # Lazily create an empty params message the first time it's needed.
        if self.__default_global_params is None:
            # pylint: disable=not-callable
            self.__default_global_params = self.params_type()
        return self.__default_global_params

    def AddGlobalParam(self, name, value):
        """Set the global (StandardQueryParameters) field name to value."""
        params = self._default_global_params
        setattr(params, name, value)

    @property
    def global_params(self):
        """A copy of the default global parameters message."""
        return encoding.CopyProtoMessage(self._default_global_params)

    @contextlib.contextmanager
    def IncludeFields(self, include_fields):
        """Context in which the given fields are serialized on requests."""
        self.__include_fields = include_fields
        yield
        self.__include_fields = None

    @property
    def response_type_model(self):
        """Either 'proto' (default) or 'json' (raw JSON responses)."""
        return self.__response_type_model

    @contextlib.contextmanager
    def JsonResponseModel(self):
        """In this context, return raw JSON instead of proto."""
        old_model = self.response_type_model
        self.__response_type_model = 'json'
        yield
        self.__response_type_model = old_model

    @property
    def num_retries(self):
        """Number of retries for failed requests (non-negative int)."""
        return self.__num_retries

    @num_retries.setter
    def num_retries(self, value):
        util.Typecheck(value, six.integer_types)
        if value < 0:
            raise exceptions.InvalidDataError(
                'Cannot have negative value for num_retries')
        self.__num_retries = value

    @property
    def max_retry_wait(self):
        """Maximum seconds to wait between retries (positive int)."""
        return self.__max_retry_wait

    @max_retry_wait.setter
    def max_retry_wait(self, value):
        util.Typecheck(value, six.integer_types)
        if value <= 0:
            # Fixed typo in this error message ('postiive' -> 'positive').
            raise exceptions.InvalidDataError(
                'max_retry_wait must be a positive integer')
        self.__max_retry_wait = value

    @contextlib.contextmanager
    def WithRetries(self, num_retries):
        """Context in which requests use the given retry count."""
        old_num_retries = self.num_retries
        self.num_retries = num_retries
        yield
        self.num_retries = old_num_retries

    def ProcessRequest(self, method_config, request):
        """Hook for pre-processing of requests."""
        if self.log_request:
            logging.info(
                'Calling method %s with %s: %s', method_config.method_id,
                method_config.request_type_name, request)
        return request

    def ProcessHttpRequest(self, http_request):
        """Hook for pre-processing of http requests."""
        http_request.headers.update(self.additional_http_headers)
        if self.log_request:
            logging.info('Making http %s to %s',
                         http_request.http_method, http_request.url)
            logging.info('Headers: %s', pprint.pformat(http_request.headers))
            if http_request.body:
                # TODO(user): Make this safe to print in the case of
                # non-printable body characters.
                logging.info('Body:\n%s',
                             http_request.loggable_body or http_request.body)
            else:
                logging.info('Body: (none)')
        return http_request

    def ProcessResponse(self, method_config, response):
        """Hook for post-processing of responses."""
        if self.log_response:
            logging.info('Response of type %s: %s',
                         method_config.response_type_name, response)
        return response

    # TODO(user): Decide where these two functions should live.
    def SerializeMessage(self, message):
        """Serialize message to JSON, honoring any IncludeFields context."""
        return encoding.MessageToJson(
            message, include_fields=self.__include_fields)

    def DeserializeMessage(self, response_type, data):
        """Deserialize the given data as method_config.response_type."""
        try:
            message = encoding.JsonToMessage(response_type, data)
        except (exceptions.InvalidDataFromServerError,
                messages.ValidationError, ValueError) as e:
            raise exceptions.InvalidDataFromServerError(
                'Error decoding response "%s" as type %s: %s' % (
                    data, response_type.__name__, e))
        return message

    def FinalizeTransferUrl(self, url):
        """Modify the url for a given transfer, based on auth and version."""
        url_builder = _UrlBuilder.FromUrl(url)
        if getattr(self.global_params, 'key', None):
            url_builder.query_params['key'] = self.global_params.key
        if self.overwrite_transfer_urls_with_client_base:
            client_url_builder = _UrlBuilder.FromUrl(self._url)
            url_builder.base_url = client_url_builder.base_url
        return url_builder.url
class BaseApiService(object):

    """Base class for generated API services.

    A service wraps one resource of an API: it caches per-method
    configuration, turns a request message into an http_wrapper.Request,
    executes it through the owning BaseApiClient, and decodes the
    response back into a message.
    """

    def __init__(self, client):
        self.__client = client
        # Caches keyed by method name, filled lazily / by subclasses.
        self._method_configs = {}
        self._upload_configs = {}

    @property
    def _client(self):
        # Owning BaseApiClient (internal alias of `client`).
        return self.__client

    @property
    def client(self):
        # Owning BaseApiClient this service issues requests through.
        return self.__client

    def GetMethodConfig(self, method):
        """Returns service cached method config for given method."""
        method_config = self._method_configs.get(method)
        if method_config:
            return method_config
        func = getattr(self, method, None)
        if func is None:
            raise KeyError(method)
        # Generated service methods carry a callable 'method_config'
        # attribute producing their ApiMethodInfo; cache its result.
        method_config = getattr(func, 'method_config', None)
        if method_config is None:
            raise KeyError(method)
        self._method_configs[method] = config = method_config()
        return config

    @classmethod
    def GetMethodsList(cls):
        """Return names of all generated API methods on this service."""
        return [f.__name__ for f in six.itervalues(cls.__dict__)
                if getattr(f, 'method_config', None)]

    def GetUploadConfig(self, method):
        """Return the cached upload config for method, or None."""
        return self._upload_configs.get(method)

    def GetRequestType(self, method):
        """Return the request message class for method."""
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.request_type_name)

    def GetResponseType(self, method):
        """Return the response message class for method."""
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.response_type_name)

    def __CombineGlobalParams(self, global_params, default_params):
        """Combine the given params with the defaults."""
        util.Typecheck(global_params, (type(None), self.__client.params_type))
        result = self.__client.params_type()
        global_params = global_params or self.__client.params_type()
        # Per-call params win over defaults; unassigned fields stay unset.
        for field in result.all_fields():
            value = global_params.get_assigned_value(field.name)
            if value is None:
                value = default_params.get_assigned_value(field.name)
            if value not in (None, [], ()):
                setattr(result, field.name, value)
        return result

    def __EncodePrettyPrint(self, query_info):
        """Apply the special query encoding for prettyPrint/pp flags."""
        # The prettyPrint flag needs custom encoding: it should be encoded
        # as 0 if False, and ignored otherwise (True is the default).
        if not query_info.pop('prettyPrint', True):
            query_info['prettyPrint'] = 0
        # The One Platform equivalent of prettyPrint is pp, which also needs
        # custom encoding.
        if not query_info.pop('pp', True):
            query_info['pp'] = 0
        return query_info

    def __FinalUrlValue(self, value, field):
        """Encode value for the URL, using field to skip encoding for bytes."""
        if isinstance(field, messages.BytesField) and value is not None:
            # Bytes fields go on the wire URL-safe base64 encoded.
            return base64.urlsafe_b64encode(value)
        elif isinstance(value, six.text_type):
            return value.encode('utf8')
        elif isinstance(value, six.binary_type):
            return value.decode('utf8')
        elif isinstance(value, datetime.datetime):
            return value.isoformat()
        return value

    def __ConstructQueryParams(self, query_params, request, global_params):
        """Construct a dictionary of query parameters for this request."""
        # First, handle the global params.
        global_params = self.__CombineGlobalParams(
            global_params, self.__client.global_params)
        global_param_names = util.MapParamNames(
            [x.name for x in self.__client.params_type.all_fields()],
            self.__client.params_type)
        global_params_type = type(global_params)
        query_info = dict(
            (param,
             self.__FinalUrlValue(getattr(global_params, param),
                                  getattr(global_params_type, param)))
            for param in global_param_names)
        # Next, add the query params.
        query_param_names = util.MapParamNames(query_params, type(request))
        request_type = type(request)
        query_info.update(
            (param,
             self.__FinalUrlValue(getattr(request, param, None),
                                  getattr(request_type, param)))
            for param in query_param_names)
        # Drop unset values, then apply wire-name/encoding fixups.
        query_info = dict((k, v) for k, v in query_info.items()
                          if v is not None)
        query_info = self.__EncodePrettyPrint(query_info)
        query_info = util.MapRequestParams(query_info, type(request))
        return query_info

    def __ConstructRelativePath(self, method_config, request,
                                relative_path=None):
        """Determine the relative path for request."""
        python_param_names = util.MapParamNames(
            method_config.path_params, type(request))
        params = dict([(param, getattr(request, param, None))
                       for param in python_param_names])
        params = util.MapRequestParams(params, type(request))
        # Substitute the params into the {param} path template.
        return util.ExpandRelativePath(method_config, params,
                                       relative_path=relative_path)

    def __FinalizeRequest(self, http_request, url_builder):
        """Make any final general adjustments to the request."""
        # Overlong GET URLs are rewritten as POST with a method-override
        # header and the query string moved into the body.
        if (http_request.http_method == 'GET' and
                len(http_request.url) > _MAX_URL_LENGTH):
            http_request.http_method = 'POST'
            http_request.headers['x-http-method-override'] = 'GET'
            http_request.headers[
                'content-type'] = 'application/x-www-form-urlencoded'
            http_request.body = url_builder.query
            url_builder.query_params = {}
        http_request.url = url_builder.url

    def __ProcessHttpResponse(self, method_config, http_response, request):
        """Process the given http response."""
        # Anything other than 200/201/204 is surfaced as an HttpError.
        if http_response.status_code not in (http_client.OK,
                                             http_client.CREATED,
                                             http_client.NO_CONTENT):
            raise exceptions.HttpError.FromResponse(
                http_response, method_config=method_config, request=request)
        if http_response.status_code == http_client.NO_CONTENT:
            # TODO(user): Find out why _replace doesn't seem to work
            # here.
            http_response = http_wrapper.Response(
                info=http_response.info, content='{}',
                request_url=http_response.request_url)
        content = http_response.content
        if self._client.response_encoding and isinstance(content, bytes):
            content = content.decode(self._client.response_encoding)
        # Under JsonResponseModel, hand back the raw JSON string.
        if self.__client.response_type_model == 'json':
            return content
        response_type = _LoadClass(method_config.response_type_name,
                                   self.__client.MESSAGES_MODULE)
        return self.__client.DeserializeMessage(response_type, content)

    def __SetBaseHeaders(self, http_request, client):
        """Fill in the basic headers on http_request."""
        # TODO(user): Make the default a little better here, and
        # include the apitools version.
        user_agent = client.user_agent or 'apitools-client/1.0'
        http_request.headers['user-agent'] = user_agent
        http_request.headers['accept'] = 'application/json'
        http_request.headers['accept-encoding'] = 'gzip, deflate'

    def __SetBaseSystemParams(self, http_request, method_config):
        """Fill in the system parameters to always set for the method."""
        if method_config.api_version_param:
            http_request.headers['X-Goog-Api-Version'] = (
                method_config.api_version_param)

    def __SetBody(self, http_request, method_config, request, upload):
        """Fill in the body on http_request."""
        if not method_config.request_field:
            return
        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        if method_config.request_field == REQUEST_IS_BODY:
            # The whole request message is the HTTP body.
            body_value = request
            body_type = request_type
        else:
            body_value = getattr(request, method_config.request_field)
            body_field = request_type.field_by_name(
                method_config.request_field)
            util.Typecheck(body_field, messages.MessageField)
            body_type = body_field.type
        # If there was no body provided, we use an empty message of the
        # appropriate type.
        body_value = body_value or body_type()
        if upload and not body_value:
            # We're going to fill in the body later.
            return
        util.Typecheck(body_value, body_type)
        http_request.headers['content-type'] = 'application/json'
        http_request.body = self.__client.SerializeMessage(body_value)

    def PrepareHttpRequest(self, method_config, request, global_params=None,
                           upload=None, upload_config=None, download=None):
        """Prepares an HTTP request to be sent."""
        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        util.Typecheck(request, request_type)
        request = self.__client.ProcessRequest(method_config, request)

        http_request = http_wrapper.Request(
            http_method=method_config.http_method)
        self.__SetBaseHeaders(http_request, self.__client)
        self.__SetBaseSystemParams(http_request, method_config)
        self.__SetBody(http_request, method_config, request, upload)

        url_builder = _UrlBuilder(
            self.__client.url, relative_path=method_config.relative_path)
        url_builder.query_params = self.__ConstructQueryParams(
            method_config.query_params, request, global_params)

        # It's important that upload and download go before we fill in the
        # relative path, so that they can replace it.
        if upload is not None:
            upload.ConfigureRequest(upload_config, http_request, url_builder)
        if download is not None:
            download.ConfigureRequest(http_request, url_builder)

        url_builder.relative_path = self.__ConstructRelativePath(
            method_config, request, relative_path=url_builder.relative_path)
        self.__FinalizeRequest(http_request, url_builder)

        return self.__client.ProcessHttpRequest(http_request)

    def _RunMethod(self, method_config, request, global_params=None,
                   upload=None, upload_config=None, download=None):
        """Call this method with request."""
        if upload is not None and download is not None:
            # TODO(user): This just involves refactoring the logic
            # below into callbacks that we can pass around; in particular,
            # the order should be that the upload gets the initial request,
            # and then passes its reply to a download if one exists, and
            # then that goes to ProcessResponse and is returned.
            raise exceptions.NotYetImplementedError(
                'Cannot yet use both upload and download at once')

        http_request = self.PrepareHttpRequest(
            method_config, request, global_params, upload, upload_config,
            download)

        # TODO(user): Make num_retries customizable on Transfer
        # objects, and pass in self.__client.num_retries when initializing
        # an upload or download.
        if download is not None:
            # Downloads manage their own HTTP exchange; nothing to return.
            download.InitializeDownload(http_request, client=self.client)
            return

        http_response = None
        if upload is not None:
            http_response = upload.InitializeUpload(
                http_request, client=self.client)
        if http_response is None:
            http = self.__client.http
            if upload and upload.bytes_http:
                http = upload.bytes_http
            opts = {
                'retries': self.__client.num_retries,
                'max_retry_wait': self.__client.max_retry_wait,
            }
            if self.__client.check_response_func:
                opts['check_response_func'] = self.__client.check_response_func
            if self.__client.retry_func:
                opts['retry_func'] = self.__client.retry_func
            http_response = http_wrapper.MakeRequest(
                http, http_request, **opts)

        return self.ProcessHttpResponse(method_config, http_response, request)

    def ProcessHttpResponse(self, method_config, http_response, request=None):
        """Convert an HTTP response to the expected message type."""
        return self.__client.ProcessResponse(
            method_config,
            self.__ProcessHttpResponse(method_config, http_response, request))

View File

@@ -0,0 +1,506 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for handling batch HTTP requests for apitools."""
import collections
import email.generator as generator
import email.mime.multipart as mime_multipart
import email.mime.nonmultipart as mime_nonmultipart
import email.parser as email_parser
import itertools
import time
import uuid
import six
from six.moves import http_client
from six.moves import urllib_parse
from six.moves import range # pylint: disable=redefined-builtin
from apitools.base.py import exceptions
from apitools.base.py import http_wrapper
__all__ = [
'BatchApiRequest',
]
class RequestResponseAndHandler(collections.namedtuple(
        'RequestResponseAndHandler', ['request', 'response', 'handler'])):

    """Container for data related to completing an HTTP request.

    This contains an HTTP request, its response, and a callback for handling
    the response from the server.

    Being a namedtuple, instances are immutable; an updated response is
    recorded by creating a replacement tuple.

    Attributes:
      request: An http_wrapper.Request object representing the HTTP request.
      response: The http_wrapper.Response object returned from the server.
      handler: A callback function accepting two arguments, response
          and exception. Response is an http_wrapper.Response object, and
          exception is an apiclient.errors.HttpError object if an error
          occurred, or otherwise None.
    """
class BatchApiRequest(object):

    """Batches multiple api requests into a single request.

    Individual calls are registered with Add() and sent together (with
    retries) by Execute(), which returns the list of ApiCall objects.
    """

    class ApiCall(object):

        """Holds request and response information for each request.

        ApiCalls are ultimately exposed to the client once the HTTP
        batch request has been completed.

        Attributes:
          http_request: A client-supplied http_wrapper.Request to be
              submitted to the server.
          response: A http_wrapper.Response object given by the server as a
              response to the user request, or None if an error occurred.
          exception: An apiclient.errors.HttpError object if an error
              occurred, or None.
        """

        def __init__(self, request, retryable_codes, service, method_config):
            """Initialize an individual API request.

            Args:
              request: An http_wrapper.Request object.
              retryable_codes: A list of integer HTTP codes that can
                  be retried.
              service: A service inheriting from base_api.BaseApiService.
              method_config: Method config for the desired API request.
            """
            # 401 is always treated as retryable so the caller can refresh
            # credentials and resubmit.
            self.__retryable_codes = list(
                set(retryable_codes + [http_client.UNAUTHORIZED]))
            self.__http_response = None
            self.__service = service
            self.__method_config = method_config

            self.http_request = request
            # TODO(user): Add some validation to these fields.
            self.__response = None
            self.__exception = None

        @property
        def is_error(self):
            # True once a handler recorded an exception for this call.
            return self.exception is not None

        @property
        def response(self):
            # Decoded response message, set only on terminal success.
            return self.__response

        @property
        def exception(self):
            return self.__exception

        @property
        def authorization_failed(self):
            # True when the last response was a 401.
            return (self.__http_response and (
                self.__http_response.status_code == http_client.UNAUTHORIZED))

        @property
        def terminal_state(self):
            # A call is terminal once it has a response whose status is
            # not retryable; such calls are excluded from further batches.
            if self.__http_response is None:
                return False
            response_code = self.__http_response.status_code
            return response_code not in self.__retryable_codes

        def HandleResponse(self, http_response, exception):
            """Handles incoming http response to the request in http_request.

            This is intended to be used as a callback function for
            BatchHttpRequest.Add.

            Args:
              http_response: Deserialized http_wrapper.Response object.
              exception: apiclient.errors.HttpError object if an error
                  occurred.
            """
            self.__http_response = http_response
            self.__exception = exception
            if self.terminal_state and not self.__exception:
                self.__response = self.__service.ProcessHttpResponse(
                    self.__method_config, self.__http_response)

    def __init__(self, batch_url=None, retryable_codes=None,
                 response_encoding=None):
        """Initialize a batch API request object.

        Args:
          batch_url: Base URL for batch API calls.
          retryable_codes: A list of integer HTTP codes that can be retried.
          response_encoding: The encoding type of response content.
        """
        self.api_requests = []
        self.retryable_codes = retryable_codes or []
        self.batch_url = batch_url or 'https://www.googleapis.com/batch'
        self.response_encoding = response_encoding

    def Add(self, service, method, request, global_params=None):
        """Add a request to the batch.

        Args:
          service: A class inheriting base_api.BaseApiService.
          method: A string indicating the desired method from the service.
              See the example in the class docstring.
          request: An input message appropriate for the specified
              service.method.
          global_params: Optional additional parameters to pass into
              method.PrepareHttpRequest.

        Returns:
          None
        """
        # Retrieve the configs for the desired method and service.
        method_config = service.GetMethodConfig(method)
        upload_config = service.GetUploadConfig(method)

        # Prepare the HTTP Request.
        http_request = service.PrepareHttpRequest(
            method_config, request, global_params=global_params,
            upload_config=upload_config)

        # Create the request and add it to our master list.
        api_request = self.ApiCall(
            http_request, self.retryable_codes, service, method_config)
        self.api_requests.append(api_request)

    def Execute(self, http, sleep_between_polls=5, max_retries=5,
                max_batch_size=None, batch_request_callback=None):
        """Execute all of the requests in the batch.

        Args:
          http: httplib2.Http object for use in the request.
          sleep_between_polls: Integer number of seconds to sleep between
              polls.
          max_retries: Max retries. Any requests that have not succeeded by
              this number of retries simply report the last response or
              exception, whatever it happened to be.
          max_batch_size: int, if specified requests will be split in batches
              of given size.
          batch_request_callback: function of (http_response, exception)
              passed to BatchHttpRequest which will be run on any given
              results.

        Returns:
          List of ApiCalls.
        """
        requests = [request for request in self.api_requests
                    if not request.terminal_state]
        batch_size = max_batch_size or len(requests)

        for attempt in range(max_retries):
            if attempt:
                time.sleep(sleep_between_polls)

            # Send the still-pending requests in slices of batch_size.
            for i in range(0, len(requests), batch_size):
                # Create a batch_http_request object and populate it with
                # incomplete requests.
                batch_http_request = BatchHttpRequest(
                    batch_url=self.batch_url,
                    callback=batch_request_callback,
                    response_encoding=self.response_encoding
                )
                for request in itertools.islice(requests,
                                                i, i + batch_size):
                    batch_http_request.Add(
                        request.http_request, request.HandleResponse)
                batch_http_request.Execute(http)

                # If any request in this slice got a 401, refresh the
                # credentials once; the next attempt re-sends those calls.
                if hasattr(http.request, 'credentials'):
                    if any(request.authorization_failed
                           for request in itertools.islice(requests,
                                                           i, i + batch_size)):
                        http.request.credentials.refresh(http)

            # Collect retryable requests.
            requests = [request for request in self.api_requests if not
                        request.terminal_state]

            if not requests:
                break

        return self.api_requests
class BatchHttpRequest(object):
"""Batches multiple http_wrapper.Request objects into a single request."""
    def __init__(self, batch_url, callback=None, response_encoding=None):
        """Constructor for a BatchHttpRequest.

        Args:
          batch_url: URL to send batch requests to.
          callback: A callback to be called for each response, of the
              form callback(response, exception). The first parameter is
              the deserialized Response object. The second is an
              apiclient.errors.HttpError exception object if an HTTP error
              occurred while processing the request, or None if no error
              occurred.
          response_encoding: The encoding type of response content.
        """
        # Endpoint to which these requests are sent.
        self.__batch_url = batch_url

        # Global callback to be called for each individual response in the
        # batch.
        self.__callback = callback

        # Response content will be decoded if this is provided.
        self.__response_encoding = response_encoding

        # Mapping of request id to request/response/handler records.
        # NOTE(review): populated by Add() (not visible in this chunk) --
        # confirm the key type against that method.
        self.__request_response_handlers = {}

        # The last auto generated id (an itertools counter).
        self.__last_auto_id = itertools.count()

        # Unique ID on which to base the Content-ID headers.
        self.__base_id = uuid.uuid4()
def _ConvertIdToHeader(self, request_id):
"""Convert an id to a Content-ID header value.
Args:
request_id: String identifier for a individual request.
Returns:
A Content-ID header with the id_ encoded into it. A UUID is
prepended to the value because Content-ID headers are
supposed to be universally unique.
"""
return '<%s+%s>' % (self.__base_id, urllib_parse.quote(request_id))
@staticmethod
def _ConvertHeaderToId(header):
"""Convert a Content-ID header value to an id.
Presumes the Content-ID header conforms to the format that
_ConvertIdToHeader() returns.
Args:
header: A string indicating the Content-ID header value.
Returns:
The extracted id value.
Raises:
BatchError if the header is not in the expected format.
"""
if not (header.startswith('<') or header.endswith('>')):
raise exceptions.BatchError(
'Invalid value for Content-ID: %s' % header)
if '+' not in header:
raise exceptions.BatchError(
'Invalid value for Content-ID: %s' % header)
_, request_id = header[1:-1].rsplit('+', 1)
return urllib_parse.unquote(request_id)
    def _SerializeRequest(self, request):
        """Convert a http_wrapper.Request object into a string.

        Args:
          request: A http_wrapper.Request to serialize.

        Returns:
          The request as a string in application/http format.
        """
        # Construct status line
        parsed = urllib_parse.urlsplit(request.url)
        # Only path + query belong on the request line; scheme and host are
        # implied by the batch endpoint and the Host header added below.
        request_line = urllib_parse.urlunsplit(
            ('', '', parsed.path, parsed.query, ''))
        if not isinstance(request_line, six.text_type):
            request_line = request_line.decode('utf-8')
        status_line = u' '.join((
            request.http_method,
            request_line,
            u'HTTP/1.1\n'
        ))
        major, minor = request.headers.get(
            'content-type', 'application/json').split('/')
        msg = mime_nonmultipart.MIMENonMultipart(major, minor)
        # MIMENonMultipart adds its own Content-Type header.
        # Keep all of the other headers in `request.headers`.
        for key, value in request.headers.items():
            if key == 'content-type':
                continue
            msg[key] = value
        msg['Host'] = parsed.netloc
        msg.set_unixfrom(None)
        if request.body is not None:
            msg.set_payload(request.body)
        # Serialize the mime message.
        str_io = six.StringIO()
        # maxheaderlen=0 means don't line wrap headers.
        gen = generator.Generator(str_io, maxheaderlen=0)
        gen.flatten(msg, unixfrom=False)
        body = str_io.getvalue()
        return status_line + body
def _DeserializeResponse(self, payload):
"""Convert string into Response and content.
Args:
payload: Header and body string to be deserialized.
Returns:
A Response object
"""
# Strip off the status line.
status_line, payload = payload.split('\n', 1)
_, status, _ = status_line.split(' ', 2)
# Parse the rest of the response.
parser = email_parser.Parser()
msg = parser.parsestr(payload)
# Get the headers.
info = dict(msg)
info['status'] = status
# Create Response from the parsed headers.
content = msg.get_payload()
return http_wrapper.Response(info, content, self.__batch_url)
def _NewId(self):
"""Create a new id.
Auto incrementing number that avoids conflicts with ids already used.
Returns:
A new unique id string.
"""
return str(next(self.__last_auto_id))
def Add(self, request, callback=None):
"""Add a new request.
Args:
request: A http_wrapper.Request to add to the batch.
callback: A callback to be called for this response, of the
form callback(response, exception). The first parameter is the
deserialized response object. The second is an
apiclient.errors.HttpError exception object if an HTTP error
occurred while processing the request, or None if no errors
occurred.
Returns:
None
"""
handler = RequestResponseAndHandler(request, None, callback)
self.__request_response_handlers[self._NewId()] = handler
    def _Execute(self, http):
        """Serialize batch request, send to server, process response.

        Args:
          http: A httplib2.Http object to be used to make the request with.

        Raises:
          httplib2.HttpLib2Error if a transport error has occurred.
          apiclient.errors.BatchError if the response is the wrong format.
        """
        message = mime_multipart.MIMEMultipart('mixed')
        # Message should not write out its own headers.
        setattr(message, '_write_headers', lambda self: None)
        # Add all the individual requests.
        for key in self.__request_response_handlers:
            msg = mime_nonmultipart.MIMENonMultipart('application', 'http')
            msg['Content-Transfer-Encoding'] = 'binary'
            # Content-ID lets us match each response part back to its
            # originating request below.
            msg['Content-ID'] = self._ConvertIdToHeader(key)
            body = self._SerializeRequest(
                self.__request_response_handlers[key].request)
            msg.set_payload(body)
            message.attach(msg)
        request = http_wrapper.Request(self.__batch_url, 'POST')
        request.body = message.as_string()
        request.headers['content-type'] = (
            'multipart/mixed; boundary="%s"') % message.get_boundary()
        response = http_wrapper.MakeRequest(http, request)
        if response.status_code >= 300:
            raise exceptions.HttpError.FromResponse(response)
        # Prepend with a content-type header so Parser can handle it.
        header = 'content-type: %s\r\n\r\n' % response.info['content-type']
        content = response.content
        if isinstance(content, bytes) and self.__response_encoding:
            content = response.content.decode(self.__response_encoding)
        parser = email_parser.Parser()
        mime_response = parser.parsestr(header + content)
        if not mime_response.is_multipart():
            raise exceptions.BatchError(
                'Response not in multipart/mixed format.')
        # Route each response part back to its request via Content-ID.
        for part in mime_response.get_payload():
            request_id = self._ConvertHeaderToId(part['Content-ID'])
            response = self._DeserializeResponse(part.get_payload())
            # Disable protected access because namedtuple._replace(...)
            # is not actually meant to be protected.
            # pylint: disable=protected-access
            self.__request_response_handlers[request_id] = (
                self.__request_response_handlers[request_id]._replace(
                    response=response))
def Execute(self, http):
"""Execute all the requests as a single batched HTTP request.
Args:
http: A httplib2.Http object to be used with the request.
Returns:
None
Raises:
BatchError if the response is the wrong format.
"""
self._Execute(http)
for key in self.__request_response_handlers:
response = self.__request_response_handlers[key].response
callback = self.__request_response_handlers[key].handler
exception = None
if response.status_code >= 300:
exception = exceptions.HttpError.FromResponse(response)
if callback is not None:
callback(response, exception)
if self.__callback is not None:
self.__callback(response, exception)

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Small helper class to provide a small slice of a stream.
This class reads ahead to detect if we are at the end of the stream.
"""
from apitools.base.py import exceptions
# TODO(user): Consider replacing this with a StringIO.
class BufferedStream(object):

    """Buffers a stream, reading ahead to determine if we're at the end."""

    def __init__(self, stream, start, size):
        self.__stream = stream
        self.__start_pos = start
        self.__buffer_pos = 0
        # Read ahead immediately; a short read tells us the wrapped stream
        # has been exhausted.
        self.__buffered_data = stream.read(size)
        self.__stream_at_end = len(self.__buffered_data) < size
        self.__end_pos = start + len(self.__buffered_data)

    def __str__(self):
        return ('Buffered stream %s from position %s-%s with %s '
                'bytes remaining' % (self.__stream, self.__start_pos,
                                     self.__end_pos, self._bytes_remaining))

    def __len__(self):
        return len(self.__buffered_data)

    @property
    def stream_exhausted(self):
        # True once the wrapped stream returned fewer bytes than requested.
        return self.__stream_at_end

    @property
    def stream_end_position(self):
        return self.__end_pos

    @property
    def _bytes_remaining(self):
        return len(self.__buffered_data) - self.__buffer_pos

    def read(self, size=None):  # pylint: disable=invalid-name
        """Reads from the buffer."""
        if size is None or size < 0:
            raise exceptions.NotYetImplementedError(
                'Illegal read of size %s requested on BufferedStream. '
                'Wrapped stream %s is at position %s-%s, '
                '%s bytes remaining.' %
                (size, self.__stream, self.__start_pos, self.__end_pos,
                 self._bytes_remaining))
        data = ''
        remaining = self._bytes_remaining
        if remaining:
            chunk_size = min(size, remaining)
            cursor = self.__buffer_pos
            data = self.__buffered_data[cursor:cursor + chunk_size]
            self.__buffer_pos = cursor + chunk_size
        return data

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python
#
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compression support for apitools."""
from collections import deque
from apitools.base.py import gzip
__all__ = [
'CompressStream',
]
# pylint: disable=invalid-name
# Note: Apitools only uses the default chunksize when compressing.
def CompressStream(in_stream, length=None, compresslevel=2,
                   chunksize=16777216):
    """Compresses an input stream into a file-like buffer.

    This reads from the input stream until either we've stored at least
    length compressed bytes, or the input stream has been exhausted.

    This supports streams of unknown size.

    Args:
        in_stream: The input stream to read from.
        length: The target number of compressed bytes to buffer in the output
            stream. If length is none, the input stream will be compressed
            until it's exhausted.

            The actual length of the output buffer can vary from the target.
            If the input stream is exhausted, the output buffer may be smaller
            than expected. If the data is incompressible, the maximum length
            can be exceeded by can be calculated to be:

              chunksize + 5 * (floor((chunksize - 1) / 16383) + 1) + 17

            This accounts for additional header data gzip adds. For the
            default 16MiB chunksize, this results in the max size of the
            output buffer being:

              length + 16Mib + 5142 bytes

        compresslevel: Optional, defaults to 2. The desired compression level.
        chunksize: Optional, defaults to 16MiB. The chunk size used when
            reading data from the input stream to write into the output
            buffer.

    Returns:
        A file-like output buffer of compressed bytes, the number of bytes
        read from the input stream, and a flag denoting if the input stream
        was exhausted.
    """
    total_read = 0
    exhausted = False
    out_stream = StreamingBuffer()
    with gzip.GzipFile(mode='wb',
                       fileobj=out_stream,
                       compresslevel=compresslevel) as gzip_writer:
        # Stop once the output holds at least `length` compressed bytes;
        # with no length, run until the input is exhausted.
        while not length or out_stream.length < length:
            chunk = in_stream.read(chunksize)
            gzip_writer.write(chunk)
            total_read += len(chunk)
            # A short read means the input stream has no more data.
            if len(chunk) < chunksize:
                exhausted = True
                break
    return out_stream, total_read, exhausted
class StreamingBuffer(object):

    """Provides a file-like object that writes to a temporary buffer.

    When data is read from the buffer, it is permanently removed. This is
    useful when there are memory constraints preventing the entire buffer from
    being stored in memory.
    """

    def __init__(self):
        # Queue of pending byte chunks, oldest chunk first.
        self.__buf = deque()
        # Total number of buffered bytes across all queued chunks.
        self.__size = 0

    def __len__(self):
        return self.__size

    def __nonzero__(self):
        # For 32-bit python2.x, len() cannot exceed a 32-bit number; avoid
        # accidental len() calls from httplib in the form of "if this_object:".
        return bool(self.__size)

    @property
    def length(self):
        # For 32-bit python2.x, len() cannot exceed a 32-bit number.
        return self.__size

    def write(self, data):
        # Gzip can write many 0 byte chunks for highly compressible data.
        # Prevent them from being added internally.
        if data:
            self.__buf.append(data)
            self.__size += len(data)

    def read(self, size=None):
        """Read at most size bytes from this buffer.

        Bytes read from this buffer are consumed and are permanently removed.

        Args:
          size: If provided, read no more than size bytes from the buffer.
            Otherwise, this reads the entire buffer.

        Returns:
          The bytes read from this buffer.
        """
        remaining = self.__size if size is None else size
        chunks = []
        while remaining > 0 and self.__buf:
            chunk = self.__buf.popleft()
            remaining -= len(chunk)
            chunks.append(chunk)
        if remaining < 0:
            # The final chunk overshot; push the surplus back on the queue.
            chunks[-1], surplus = (chunks[-1][:remaining],
                                   chunks[-1][remaining:])
            self.__buf.appendleft(surplus)
        result = b''.join(chunks)
        self.__size -= len(result)
        return result

View File

@@ -0,0 +1,776 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common credentials classes and constructors."""
from __future__ import print_function
import argparse
import contextlib
import datetime
import json
import os
import threading
import warnings
import httplib2
import oauth2client
import oauth2client.client
from oauth2client import service_account
from oauth2client import tools # for gflags declarations
import six
from six.moves import http_client
from six.moves import urllib
from apitools.base.py import exceptions
from apitools.base.py import util
# App Engine does not support ctypes which are required for the
# monotonic time used in fasteners. Conversely, App Engine does
# not support colocated concurrent processes, so process locks
# are not needed.
try:
import fasteners
_FASTENERS_AVAILABLE = True
except ImportError as import_error:
server_env = os.environ.get('SERVER_SOFTWARE', '')
if not (server_env.startswith('Development') or
server_env.startswith('Google App Engine')):
raise import_error
_FASTENERS_AVAILABLE = False
# Note: we try the oauth2client imports two ways, to accomodate layout
# changes in oauth2client 2.0+. We can remove these once we no longer
# support oauth2client < 2.0.
#
# pylint: disable=wrong-import-order,ungrouped-imports
try:
from oauth2client.contrib import gce
except ImportError:
from oauth2client import gce
try:
from oauth2client.contrib import multiprocess_file_storage
_NEW_FILESTORE = True
except ImportError:
_NEW_FILESTORE = False
try:
from oauth2client.contrib import multistore_file
except ImportError:
from oauth2client import multistore_file
try:
import gflags
FLAGS = gflags.FLAGS
except ImportError:
FLAGS = None
__all__ = [
'CredentialsFromFile',
'GaeAssertionCredentials',
'GceAssertionCredentials',
'GetCredentials',
'GetUserinfo',
'ServiceAccountCredentialsFromFile',
]
# Lock when accessing the cache file to avoid resource contention.
cache_file_lock = threading.Lock()
def SetCredentialsCacheFileLock(lock):
    """Replace the module-level credential cache file lock.

    Args:
      lock: A lock object (acquire/release interface compatible with
          threading.Lock) used to guard access to the cache file.
    """
    global cache_file_lock  # pylint: disable=global-statement
    cache_file_lock = lock
# List of additional methods we use when attempting to construct
# credentials. Users can register their own methods here, which we try
# before the defaults.
_CREDENTIALS_METHODS = []
def _RegisterCredentialsMethod(method, position=None):
    """Register a new method for fetching credentials.

    This new method should be a function with signature:
      client_info, **kwds -> Credentials or None
    This method can be used as a decorator, unless position needs to
    be supplied.

    Note that method must *always* accept arbitrary keyword arguments.

    Args:
      method: New credential-fetching method.
      position: (default: None) Where in the list of methods to
          add this; if None, we append. In all but rare cases,
          this should be either 0 or None.

    Returns:
      method, for use as a decorator.
    """
    if position is None:
        # Appending is the same as inserting at the end of the list.
        _CREDENTIALS_METHODS.append(method)
    else:
        insert_at = min(position, len(_CREDENTIALS_METHODS))
        _CREDENTIALS_METHODS.insert(insert_at, method)
    return method
def GetCredentials(package_name, scopes, client_id, client_secret, user_agent,
                   credentials_filename=None,
                   api_key=None,  # pylint: disable=unused-argument
                   client=None,  # pylint: disable=unused-argument
                   oauth2client_args=None,
                   **kwds):
    """Attempt to get credentials, using an oauth dance as the last resort."""
    scopes = util.NormalizeScopes(scopes)
    client_info = {
        'client_id': client_id,
        'client_secret': client_secret,
        'scope': ' '.join(sorted(scopes)),
        'user_agent': user_agent or '%s-generated/0.1' % package_name,
    }
    # User-registered methods take precedence; the first one to produce
    # credentials wins.
    for credentials_method in _CREDENTIALS_METHODS:
        credentials = credentials_method(client_info, **kwds)
        if credentials is not None:
            return credentials
    # Fall back to the cached-token file, running an oauth flow if needed.
    if not credentials_filename:
        credentials_filename = os.path.expanduser('~/.apitools.token')
    credentials = CredentialsFromFile(credentials_filename, client_info,
                                      oauth2client_args=oauth2client_args)
    if credentials is not None:
        return credentials
    raise exceptions.CredentialsError('Could not create valid credentials')
def ServiceAccountCredentialsFromFile(filename, scopes, user_agent=None):
    """Use the credentials in filename to create a token for scopes.

    Args:
      filename: Path to a JSON service-account keyfile; '~' is expanded.
      scopes: Iterable of scope strings to request.
      user_agent: Optional user agent string to attach to the credentials.

    Returns:
      An oauth2client credentials object for the service account.
    """
    filename = os.path.expanduser(filename)
    # We have two options, based on our version of oauth2client.
    # NOTE(review): this is a *string* comparison of version numbers; it
    # happens to order correctly for published oauth2client releases, but
    # is fragile in general.
    if oauth2client.__version__ > '1.5.2':
        # oauth2client >= 2.0.0
        credentials = (
            service_account.ServiceAccountCredentials.from_json_keyfile_name(
                filename, scopes=scopes))
        if credentials is not None:
            if user_agent is not None:
                credentials.user_agent = user_agent
        return credentials
    else:
        # oauth2client < 2.0.0
        with open(filename) as keyfile:
            service_account_info = json.load(keyfile)
        account_type = service_account_info.get('type')
        if account_type != oauth2client.client.SERVICE_ACCOUNT:
            raise exceptions.CredentialsError(
                'Invalid service account credentials: %s' % (filename,))
        # pylint: disable=protected-access
        credentials = service_account._ServiceAccountCredentials(
            service_account_id=service_account_info['client_id'],
            service_account_email=service_account_info['client_email'],
            private_key_id=service_account_info['private_key_id'],
            private_key_pkcs8_text=service_account_info['private_key'],
            scopes=scopes, user_agent=user_agent)
        # pylint: enable=protected-access
        return credentials
def ServiceAccountCredentialsFromP12File(
        service_account_name, private_key_filename, scopes, user_agent):
    """Create a new credential from the named .p12 keyfile.

    Args:
      service_account_name: Email-style name of the service account.
      private_key_filename: Path to the .p12 key file; '~' is expanded.
      scopes: Iterable of scope strings to request.
      user_agent: User agent string to attach to the credentials.

    Returns:
      An oauth2client credentials object for the service account.
    """
    private_key_filename = os.path.expanduser(private_key_filename)
    scopes = util.NormalizeScopes(scopes)
    # NOTE(review): string comparison of version numbers — see the note in
    # ServiceAccountCredentialsFromFile.
    if oauth2client.__version__ > '1.5.2':
        # oauth2client >= 2.0.0
        credentials = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_name, private_key_filename, scopes=scopes))
        if credentials is not None:
            credentials.user_agent = user_agent
        return credentials
    else:
        # oauth2client < 2.0.0
        with open(private_key_filename, 'rb') as key_file:
            return oauth2client.client.SignedJwtAssertionCredentials(
                service_account_name, key_file.read(), scopes,
                user_agent=user_agent)
def _GceMetadataRequest(relative_url, use_metadata_ip=False):
    """Request the given url from the GCE metadata service."""
    if use_metadata_ip:
        host = os.environ.get('GCE_METADATA_IP', '169.254.169.254')
    else:
        host = os.environ.get('GCE_METADATA_ROOT', 'metadata.google.internal')
    url = 'http://' + host + '/computeMetadata/v1/' + relative_url
    # Extra header requirement can be found here:
    # https://developers.google.com/compute/docs/metadata
    request = urllib.request.Request(
        url, headers={'Metadata-Flavor': 'Google'})
    # Bypass any configured proxies; the metadata server is local.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        return opener.open(request)
    except urllib.error.URLError as e:
        raise exceptions.CommunicationError(
            'Could not reach metadata service: %s' % e.reason)
class GceAssertionCredentials(gce.AppAssertionCredentials):
    """Assertion credentials for GCE instances."""
    def __init__(self, scopes=None, service_account_name='default', **kwds):
        """Initializes the credentials instance.

        Args:
          scopes: The scopes to get. If None, whatever scopes that are
              available to the instance are used.
          service_account_name: The service account to retrieve the scopes
              from.
          **kwds: Additional keyword args.
        """
        # If there is a connectivity issue with the metadata server,
        # detection calls may fail even if we've already successfully
        # identified these scopes in the same execution. However, the
        # available scopes don't change once an instance is created,
        # so there is no reason to perform more than one query.
        self.__service_account_name = six.ensure_text(
            service_account_name,
            encoding='utf-8',)
        cached_scopes = None
        cache_filename = kwds.get('cache_filename')
        if cache_filename:
            cached_scopes = self._CheckCacheFileForMatch(
                cache_filename, scopes)
        # On a cache miss, fall back to asking the metadata server.
        scopes = cached_scopes or self._ScopesFromMetadataServer(scopes)
        if cache_filename and not cached_scopes:
            self._WriteCacheFile(cache_filename, scopes)
        # We check the scopes above, but don't need them again after
        # this point. Newer versions of oauth2client let us drop them
        # here, but since we support older versions as well, we just
        # catch and squelch the warning.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            super(GceAssertionCredentials, self).__init__(scope=scopes, **kwds)
    @classmethod
    def Get(cls, *args, **kwds):
        # Convenience constructor: returns an instance, or None if
        # construction fails (e.g. when not running on GCE).
        try:
            return cls(*args, **kwds)
        except exceptions.Error:
            return None
    def _CheckCacheFileForMatch(self, cache_filename, scopes):
        """Checks the cache file to see if it matches the given credentials.

        Args:
          cache_filename: Cache filename to check.
          scopes: Scopes for the desired credentials.

        Returns:
          List of scopes (if cache matches) or None.
        """
        creds = {  # Credentials metadata dict.
            'scopes': sorted(list(scopes)) if scopes else None,
            'svc_acct_name': self.__service_account_name,
        }
        cache_file = _MultiProcessCacheFile(cache_filename)
        try:
            cached_creds_str = cache_file.LockedRead()
            if not cached_creds_str:
                return None
            cached_creds = json.loads(cached_creds_str)
            if creds['svc_acct_name'] == cached_creds['svc_acct_name']:
                if creds['scopes'] in (None, cached_creds['scopes']):
                    return cached_creds['scopes']
        except KeyboardInterrupt:
            raise
        except:  # pylint: disable=bare-except
            # Treat exceptions as a cache miss.
            pass
        # Falling through returns None, i.e. a cache miss.
    def _WriteCacheFile(self, cache_filename, scopes):
        """Writes the credential metadata to the cache file.

        This does not save the credentials themselves (CredentialStore class
        optionally handles that after this class is initialized).

        Args:
          cache_filename: Cache filename to check.
          scopes: Scopes for the desired credentials.
        """
        # Credentials metadata dict.
        scopes = sorted([six.ensure_text(scope) for scope in scopes])
        creds = {'scopes': scopes,
                 'svc_acct_name': self.__service_account_name}
        creds_str = json.dumps(creds)
        cache_file = _MultiProcessCacheFile(cache_filename)
        try:
            cache_file.LockedWrite(creds_str)
        except KeyboardInterrupt:
            raise
        except:  # pylint: disable=bare-except
            # Treat exceptions as a cache miss.
            pass
    def _ScopesFromMetadataServer(self, scopes):
        """Returns instance scopes based on GCE metadata server."""
        if not util.DetectGce():
            raise exceptions.ResourceUnavailableError(
                'GCE credentials requested outside a GCE instance')
        if not self.GetServiceAccount(self.__service_account_name):
            raise exceptions.ResourceUnavailableError(
                'GCE credentials requested but service account '
                '%s does not exist.' % self.__service_account_name)
        if scopes:
            scope_ls = util.NormalizeScopes(scopes)
            instance_scopes = self.GetInstanceScopes()
            # Set comparison: a proper superset means we asked for scopes
            # that the instance does not actually have.
            if scope_ls > instance_scopes:
                raise exceptions.CredentialsError(
                    'Instance did not have access to scopes %s' % (
                        sorted(list(scope_ls - instance_scopes)),))
        else:
            scopes = self.GetInstanceScopes()
        return scopes
    def GetServiceAccount(self, account):
        # Returns True iff `account` is listed by the metadata server.
        relative_url = 'instance/service-accounts'
        response = _GceMetadataRequest(relative_url)
        response_lines = [six.ensure_str(line).rstrip(u'/\n\r')
                          for line in response.readlines()]
        return account in response_lines
    def GetInstanceScopes(self):
        # Queries the metadata server for the scopes granted to this
        # instance's service account.
        relative_url = 'instance/service-accounts/{0}/scopes'.format(
            self.__service_account_name)
        response = _GceMetadataRequest(relative_url)
        return util.NormalizeScopes(six.ensure_str(scope).strip()
                                    for scope in response.readlines())
    # pylint: disable=arguments-differ
    def _refresh(self, do_request):
        """Refresh self.access_token.

        This function replaces AppAssertionCredentials._refresh, which
        does not use the credential store and is therefore poorly
        suited for multi-threaded scenarios.

        Args:
          do_request: A function matching httplib2.Http.request's signature.
        """
        # pylint: disable=protected-access
        oauth2client.client.OAuth2Credentials._refresh(self, do_request)
        # pylint: enable=protected-access
    def _do_refresh_request(self, unused_http_request):
        """Refresh self.access_token by querying the metadata server.

        If self.store is initialized, store acquired credentials there.
        """
        relative_url = 'instance/service-accounts/{0}/token'.format(
            self.__service_account_name)
        try:
            response = _GceMetadataRequest(relative_url)
        except exceptions.CommunicationError:
            # Mark the credentials invalid (and persist that) before
            # propagating the failure.
            self.invalid = True
            if self.store:
                self.store.locked_put(self)
            raise
        content = six.ensure_str(response.read())
        try:
            credential_info = json.loads(content)
        except ValueError:
            raise exceptions.CredentialsError(
                'Could not parse response as JSON: %s' % content)
        self.access_token = credential_info['access_token']
        if 'expires_in' in credential_info:
            expires_in = int(credential_info['expires_in'])
            # oauth2client expects a naive UTC datetime for token_expiry.
            self.token_expiry = (
                datetime.timedelta(seconds=expires_in) +
                datetime.datetime.now(tz=datetime.timezone.utc).replace(tzinfo=None))
        else:
            self.token_expiry = None
        self.invalid = False
        if self.store:
            self.store.locked_put(self)
    def to_json(self):
        # OAuth2Client made gce.AppAssertionCredentials unserializable as of
        # v3.0, but we need those credentials to be serializable for use with
        # this library, so we use AppAssertionCredentials' parent's to_json
        # method.
        # pylint: disable=bad-super-call
        return super(gce.AppAssertionCredentials, self).to_json()
    @classmethod
    def from_json(cls, json_data):
        # Rebuilds an instance from its to_json() representation.
        data = json.loads(json_data)
        kwargs = {}
        if 'cache_filename' in data.get('kwargs', []):
            kwargs['cache_filename'] = data['kwargs']['cache_filename']
        # Newer versions of GceAssertionCredentials don't have a "scope"
        # attribute.
        scope_list = None
        if 'scope' in data:
            scope_list = [data['scope']]
        credentials = GceAssertionCredentials(scopes=scope_list, **kwargs)
        if 'access_token' in data:
            credentials.access_token = data['access_token']
        if 'token_expiry' in data:
            credentials.token_expiry = datetime.datetime.strptime(
                data['token_expiry'], oauth2client.client.EXPIRY_FORMAT)
        if 'invalid' in data:
            credentials.invalid = data['invalid']
        return credentials
    @property
    def serialization_data(self):
        raise NotImplementedError(
            'Cannot serialize credentials for GCE service accounts.')
# TODO(user): Currently, we can't even *load*
# `oauth2client.appengine` without being on appengine, because of how
# it handles imports. Fix that by splitting that module into
# GAE-specific and GAE-independent bits, and guarding imports.
class GaeAssertionCredentials(oauth2client.client.AssertionCredentials):
    """Assertion credentials for Google App Engine apps."""
    def __init__(self, scopes, **kwds):
        """Initializes the credentials instance.

        Args:
          scopes: The scopes to request.
          **kwds: Additional keyword args for AssertionCredentials.

        Raises:
          exceptions.ResourceUnavailableError: if not running inside an
              App Engine environment.
        """
        if not util.DetectGae():
            # The message previously said "GCE", a copy/paste error from
            # GceAssertionCredentials; this class is App Engine only.
            raise exceptions.ResourceUnavailableError(
                'GAE credentials requested outside a GAE instance')
        self._scopes = list(util.NormalizeScopes(scopes))
        super(GaeAssertionCredentials, self).__init__(None, **kwds)
    @classmethod
    def Get(cls, *args, **kwds):
        # Convenience constructor: returns an instance, or None if
        # construction fails (e.g. when not running on App Engine).
        try:
            return cls(*args, **kwds)
        except exceptions.Error:
            return None
    @classmethod
    def from_json(cls, json_data):
        # Rebuilds an instance from its serialized JSON form; only the
        # scopes are needed to reconstruct it.
        data = json.loads(json_data)
        return GaeAssertionCredentials(data['_scopes'])
    def _refresh(self, _):
        """Refresh self.access_token.

        Args:
          _: (ignored) A function matching httplib2.Http.request's signature.
        """
        # pylint: disable=import-error
        from google.appengine.api import app_identity
        try:
            token, _ = app_identity.get_access_token(self._scopes)
        except app_identity.Error as e:
            raise exceptions.CredentialsError(str(e))
        self.access_token = token
    def sign_blob(self, blob):
        """Cryptographically sign a blob (of bytes).

        This method is provided to support a common interface, but
        the actual key used for a Google App Engine service account
        is not available, so it can't be used to sign content.

        Args:
          blob: bytes, Message to be signed.

        Raises:
          NotImplementedError, always.
        """
        # Fixed copy/paste error: this class is for App Engine, not
        # Compute Engine.
        raise NotImplementedError(
            'App Engine service accounts cannot sign blobs')
def _GetRunFlowFlags(args=None):
    """Retrieves command line flags based on gflags module.

    Args:
      args: Optional list of argument strings to parse; defaults to
          sys.argv via argparse.

    Returns:
      An argparse.Namespace of oauth2client tool flags, overlaid with any
      matching gflags values when the gflags module is available.
    """
    # There's one rare situation where gsutil will not have argparse
    # available, but doesn't need anything depending on argparse anyway,
    # since they're bringing their own credentials. So we just allow this
    # to fail with an ImportError in those cases.
    #
    parser = argparse.ArgumentParser(parents=[tools.argparser])
    # Get command line argparse flags.
    flags, _ = parser.parse_known_args(args=args)
    # Allow `gflags` and `argparse` to be used side-by-side.
    if hasattr(FLAGS, 'auth_host_name'):
        flags.auth_host_name = FLAGS.auth_host_name
    if hasattr(FLAGS, 'auth_host_port'):
        flags.auth_host_port = FLAGS.auth_host_port
    if hasattr(FLAGS, 'auth_local_webserver'):
        # argparse spells this flag with an inverted sense; translate.
        flags.noauth_local_webserver = (not FLAGS.auth_local_webserver)
    return flags
# TODO(user): Switch this from taking a path to taking a stream.
def CredentialsFromFile(path, client_info, oauth2client_args=None):
"""Read credentials from a file."""
user_agent = client_info['user_agent']
scope_key = client_info['scope']
if not isinstance(scope_key, six.string_types):
scope_key = ':'.join(scope_key)
storage_key = client_info['client_id'] + user_agent + scope_key
if _NEW_FILESTORE:
credential_store = multiprocess_file_storage.MultiprocessFileStorage(
path, storage_key)
else:
credential_store = multistore_file.get_credential_storage_custom_string_key( # noqa
path, storage_key)
if hasattr(FLAGS, 'auth_local_webserver'):
FLAGS.auth_local_webserver = False
credentials = credential_store.get()
if credentials is None or credentials.invalid:
print('Generating new OAuth credentials ...')
for _ in range(20):
# If authorization fails, we want to retry, rather than let this
# cascade up and get caught elsewhere. If users want out of the
# retry loop, they can ^C.
try:
flow = oauth2client.client.OAuth2WebServerFlow(**client_info)
flags = _GetRunFlowFlags(args=oauth2client_args)
credentials = tools.run_flow(flow, credential_store, flags)
break
except (oauth2client.client.FlowExchangeError, SystemExit) as e:
# Here SystemExit is "no credential at all", and the
# FlowExchangeError is "invalid" -- usually because
# you reused a token.
print('Invalid authorization: %s' % (e,))
except httplib2.HttpLib2Error as e:
print('Communication error: %s' % (e,))
raise exceptions.CredentialsError(
'Communication error creating credentials: %s' % e)
return credentials
class _MultiProcessCacheFile(object):

    """Simple multithreading and multiprocessing safe cache file.

    Notes on behavior:
    * the fasteners.InterProcessLock object cannot reliably prevent threads
      from double-acquiring a lock. A threading lock is used in addition to
      the InterProcessLock. The threading lock is always acquired first and
      released last.
    * The interprocess lock will not deadlock. If a process can not acquire
      the interprocess lock within `_lock_timeout` the call will return as
      a cache miss or unsuccessful cache write.
    * App Engine environments cannot be process locked because (1) the runtime
      does not provide monotonic time and (2) different processes may or may
      not share the same machine. Because of this, process locks are disabled
      and locking is only guaranteed to protect against multithreaded access.
    """

    # Seconds to wait for the interprocess lock before giving up.
    _lock_timeout = 1
    # Text encoding used for the cache file contents.
    _encoding = 'utf-8'
    # Shared across instances: serializes cache access within this process.
    _thread_lock = threading.Lock()

    def __init__(self, filename):
        self._file = None
        self._filename = filename
        if _FASTENERS_AVAILABLE:
            self._process_lock_getter = self._ProcessLockAcquired
            self._process_lock = fasteners.InterProcessLock(
                '{0}.lock'.format(filename))
        else:
            # No interprocess locking available (e.g. App Engine); fall
            # back to thread-only locking.
            self._process_lock_getter = self._DummyLockAcquired
            self._process_lock = None

    @contextlib.contextmanager
    def _ProcessLockAcquired(self):
        """Context manager for process locks with timeout."""
        # Initialize before the try block: if acquire() itself raises, the
        # finally clause below must not hit an unbound name (which would
        # mask the original error with a NameError).
        is_locked = False
        try:
            is_locked = self._process_lock.acquire(timeout=self._lock_timeout)
            yield is_locked
        finally:
            if is_locked:
                self._process_lock.release()

    @contextlib.contextmanager
    def _DummyLockAcquired(self):
        """Lock context manager for environments without process locks."""
        yield True

    def LockedRead(self):
        """Acquire an interprocess lock and dump cache contents.

        This method safely acquires the locks then reads a string
        from the cache file. If the file does not exist and cannot
        be created, it will return None. If the locks cannot be
        acquired, this will also return None.

        Returns:
          cache data - string if present, None on failure.
        """
        file_contents = None
        with self._thread_lock:
            if not self._EnsureFileExists():
                return None
            with self._process_lock_getter() as acquired_plock:
                if not acquired_plock:
                    return None
                with open(self._filename, 'rb') as f:
                    file_contents = f.read().decode(encoding=self._encoding)
        return file_contents

    def LockedWrite(self, cache_data):
        """Acquire an interprocess lock and write a string.

        This method safely acquires the locks then writes a string
        to the cache file. If the string is written successfully
        the function will return True, if the write fails for any
        reason it will return False.

        Args:
          cache_data: string or bytes to write.

        Returns:
          bool: success
        """
        if isinstance(cache_data, six.text_type):
            cache_data = cache_data.encode(encoding=self._encoding)
        with self._thread_lock:
            if not self._EnsureFileExists():
                return False
            with self._process_lock_getter() as acquired_plock:
                if not acquired_plock:
                    return False
                with open(self._filename, 'wb') as f:
                    f.write(cache_data)
        return True

    def _EnsureFileExists(self):
        """Touches a file; returns False on error, True on success."""
        if not os.path.exists(self._filename):
            # Create with restrictive permissions (0600) since the file
            # may hold credential material.
            old_umask = os.umask(0o177)
            try:
                open(self._filename, 'a+b').close()
            except (IOError, OSError):
                # Python 2's open() raises IOError, which is not an OSError
                # subclass there; catch both so a failed create is reported
                # as a clean miss on either version.
                return False
            finally:
                os.umask(old_umask)
        return True
# TODO(user): Push this into oauth2client.
def GetUserinfo(credentials, http=None):  # pylint: disable=invalid-name
    """Get the userinfo associated with the given credentials.

    This is dependent on the token having either the userinfo.email or
    userinfo.profile scope for the given token.

    Args:
      credentials: (oauth2client.client.Credentials) incoming credentials
      http: (httplib2.Http, optional) http instance to use

    Returns:
      The email address for this token, or None if the required scopes
      aren't available.
    """
    http = http or httplib2.Http()
    # We ignore communication woes here (i.e. SSL errors, socket
    # timeout), as handling these should be done in a common location.
    response, content = http.request(_GetUserinfoUrl(credentials))
    if response.status == http_client.BAD_REQUEST:
        # The token may have gone stale; refresh and retry once.
        credentials.refresh(http)
        response, content = http.request(_GetUserinfoUrl(credentials))
    # An empty reply from the server decodes as an empty dict.
    return json.loads(content or '{}')
def _GetUserinfoUrl(credentials):
url_root = 'https://oauth2.googleapis.com/tokeninfo'
query_args = {'access_token': credentials.access_token}
return '?'.join((url_root, urllib.parse.urlencode(query_args)))
@_RegisterCredentialsMethod
def _GetServiceAccountCredentials(
        client_info, service_account_name=None, service_account_keyfile=None,
        service_account_json_keyfile=None, **unused_kwds):
    """Returns ServiceAccountCredentials from the given keyfile.

    Args:
      client_info: client info dict; 'scope' and 'user_agent' are used.
      service_account_name: (str, optional) account name for .p12 auth.
      service_account_keyfile: (str, optional) path to a .p12 keyfile.
      service_account_json_keyfile: (str, optional) path to a .json keyfile;
          takes precedence over the .p12 pair.
      **unused_kwds: other credential-lookup options, ignored here.

    Returns:
      Service account credentials, or None if no account was specified.

    Raises:
      exceptions.CredentialsError: if only one of name/keyfile is given.
    """
    scopes = client_info['scope'].split()
    user_agent = client_info['user_agent']
    # Use the .json credentials, if provided.
    if service_account_json_keyfile:
        return ServiceAccountCredentialsFromFile(
            service_account_json_keyfile, scopes, user_agent=user_agent)
    # Fall back to .p12, which needs both the name and the keyfile.
    if bool(service_account_name) != bool(service_account_keyfile):
        raise exceptions.CredentialsError(
            'Service account name or keyfile provided without the other')
    if service_account_name is not None:
        return ServiceAccountCredentialsFromP12File(
            service_account_name, service_account_keyfile, scopes, user_agent)
@_RegisterCredentialsMethod
def _GetGaeServiceAccount(client_info, **unused_kwds):
    """Returns App Engine service account credentials for the scopes."""
    requested_scopes = client_info['scope'].split(' ')
    return GaeAssertionCredentials.Get(scopes=requested_scopes)
@_RegisterCredentialsMethod
def _GetGceServiceAccount(client_info, **unused_kwds):
    """Returns GCE metadata-server credentials for the scopes."""
    requested_scopes = client_info['scope'].split(' ')
    return GceAssertionCredentials.Get(scopes=requested_scopes)
@_RegisterCredentialsMethod
def _GetApplicationDefaultCredentials(
        client_info, skip_application_default_credentials=False,
        **unused_kwds):
    """Returns ADC with right scopes.

    Args:
      client_info: client info dict; only the 'scope' entry is read here.
      skip_application_default_credentials: if True, never consult ADC.
      **unused_kwds: other credential-lookup options, ignored here.

    Returns:
      Scoped application-default credentials, or None when unavailable or
      not applicable to the requested scopes.
    """
    scopes = client_info['scope'].split()
    if skip_application_default_credentials:
        return None
    gc = oauth2client.client.GoogleCredentials
    # Serialize file access across threads while probing the well-known
    # credential files.
    with cache_file_lock:
        try:
            # pylint: disable=protected-access
            # We've already done our own check for GAE/GCE
            # credentials, we don't want to pay for checking again.
            credentials = gc._implicit_credentials_from_files()
        except oauth2client.client.ApplicationDefaultCredentialsError:
            return None
    # If we got back a non-service account credential, we need to use
    # a heuristic to decide whether or not the application default
    # credential will work for us. We assume that if we're requesting
    # cloud-platform, our scopes are a subset of cloud scopes, and the
    # ADC will work.
    cp = 'https://www.googleapis.com/auth/cloud-platform'
    if credentials is None:
        return None
    # Non-GoogleCredentials (e.g. service accounts) can always be
    # re-scoped; plain ADC is only trusted for the cloud-platform scope.
    if not isinstance(credentials, gc) or cp in scopes:
        return credentials.create_scoped(scopes)
    return None

View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common code for converting proto to other formats, such as JSON."""
# pylint:disable=wildcard-import
from apitools.base.py.encoding_helper import *
import apitools.base.py.extra_types # pylint:disable=unused-import
# pylint:disable=undefined-all-variable
# Explicit public surface of this module; every name here is re-exported
# from encoding_helper via the wildcard import above.
__all__ = [
    'CopyProtoMessage',
    'JsonToMessage',
    'MessageToJson',
    'DictToMessage',
    'MessageToDict',
    'PyValueToMessage',
    'MessageToPyValue',
    'MessageToRepr',
    'GetCustomJsonFieldMapping',
    'AddCustomJsonFieldMapping',
    'GetCustomJsonEnumMapping',
    'AddCustomJsonEnumMapping',
]

View File

@@ -0,0 +1,808 @@
#!/usr/bin/env python
#
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common code for converting proto to other formats, such as JSON."""
import base64
import collections
import datetime
import json
import six
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import protojson
from apitools.base.py import exceptions
# Encoder/decoder callable pair stored in the codec registries below.
_Codec = collections.namedtuple('_Codec', ['encoder', 'decoder'])
# Result of one codec step: the (possibly transformed) value, plus whether
# the codec fully handled it ('complete') or further processing is needed.
CodecResult = collections.namedtuple('CodecResult', ['value', 'complete'])
class EdgeType(object):
    """The type of transition made by an edge."""
    # A plain (singular) field access.
    SCALAR = 1
    # A numeric index into a repeated field.
    REPEATED = 2
    # A key lookup into a map-typed field.
    MAP = 3
class ProtoEdge(collections.namedtuple('ProtoEdge',
                                       ['type_', 'field', 'index'])):
    """A description of a one-level transition from a message to a value.

    Protobuf messages can be arbitrarily nested as fields can be defined
    with any "message" type. This nesting property means that there are
    often many levels of proto messages within a single message instance.
    This class can unambiguously describe a single step from a message to
    some nested value.

    Properties:
      type_: EdgeType, The type of transition represented by this edge.
      field: str, The name of the message-typed field.
      index: Any, Additional data needed to make the transition. The
          semantics of the "index" property change based on the value of
          "type_":
            SCALAR: ignored.
            REPEATED: a numeric index into "field"'s list.
            MAP: a key into "field"'s mapping.
    """
    __slots__ = ()

    def __str__(self):
        # Scalar edges render as a bare field name; repeated/map edges
        # render with their index/key in brackets.
        if self.type_ != EdgeType.SCALAR:
            return '{}[{}]'.format(self.field, self.index)
        return self.field
# TODO(user): Make these non-global.
# Registries for customizing (de)serialization, populated by the
# MapUnrecognizedFields / Register* decorators below. Keyed by message
# class, field instance, or field type respectively.
_UNRECOGNIZED_FIELD_MAPPINGS = {}
_CUSTOM_MESSAGE_CODECS = {}
_CUSTOM_FIELD_CODECS = {}
_FIELD_TYPE_CODECS = {}
def MapUnrecognizedFields(field_name):
    """Register field_name as a container for unrecognized fields."""
    def Register(cls):
        # Remember which field of cls should absorb unknown wire fields.
        _UNRECOGNIZED_FIELD_MAPPINGS[cls] = field_name
        return cls

    return Register
def RegisterCustomMessageCodec(encoder, decoder):
    """Register a custom encoder/decoder for this message class."""
    def Register(cls):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _CUSTOM_MESSAGE_CODECS[cls] = codec
        return cls

    return Register
def RegisterCustomFieldCodec(encoder, decoder):
    """Register a custom encoder/decoder for this field."""
    def Register(field):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _CUSTOM_FIELD_CODECS[field] = codec
        return field

    return Register
def RegisterFieldTypeCodec(encoder, decoder):
    """Register a custom encoder/decoder for all fields of this type."""
    def Register(field_type):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _FIELD_TYPE_CODECS[field_type] = codec
        return field_type

    return Register
def CopyProtoMessage(message):
    """Make a deep copy of a message via a JSON round-trip."""
    serialized = MessageToJson(message)
    return JsonToMessage(type(message), serialized)
def MessageToJson(message, include_fields=None):
    """Convert the given message to JSON."""
    encoded = _ProtoJsonApiTools.Get().encode_message(message)
    # Splice in any explicitly requested (but unset) fields.
    return _IncludeFields(encoded, message, include_fields)
def JsonToMessage(message_type, message):
    """Convert the given JSON to a message of type message_type."""
    codec = _ProtoJsonApiTools.Get()
    return codec.decode_message(message_type, message)
# TODO(user): Do this directly, instead of via JSON.
def DictToMessage(d, message_type):
    """Convert the given dictionary to a message of type message_type."""
    serialized = json.dumps(d)
    return JsonToMessage(message_type, serialized)
def MessageToDict(message):
    """Convert the given message to a dictionary."""
    serialized = MessageToJson(message)
    return json.loads(serialized)
def DictToAdditionalPropertyMessage(properties, additional_property_type,
                                    sort_items=False):
    """Convert the given dictionary to an AdditionalProperty message.

    Args:
      properties: dict to convert.
      additional_property_type: message type exposing an AdditionalProperty
          nested type and an additionalProperties repeated field.
      sort_items: if True, emit the pairs in sorted key order.

    Returns:
      An instance of additional_property_type holding one key/value pair
      per dictionary entry.
    """
    items = properties.items()
    if sort_items:
        items = sorted(items)
    pairs = [
        additional_property_type.AdditionalProperty(key=key, value=value)
        for key, value in items]
    return additional_property_type(additionalProperties=pairs)
def PyValueToMessage(message_type, value):
    """Convert the given python value to a message of type message_type."""
    serialized = json.dumps(value)
    return JsonToMessage(message_type, serialized)
def MessageToPyValue(message):
    """Convert the given message to a python value."""
    serialized = MessageToJson(message)
    return json.loads(serialized)
def MessageToRepr(msg, multiline=False, **kwargs):
    """Return a repr-style string for a protorpc message.

    protorpc.Message.__repr__ does not return anything that could be
    considered python code. Adding this function lets us print a protorpc
    message in such a way that it could be pasted into code later, and used
    to compare against other things.

    Args:
      msg: protorpc.Message, the message to be repr'd.
      multiline: bool, True if the returned string should have each field
          assignment on its own line.
      **kwargs: {str:str}, Additional flags for how to format the string.

    Known **kwargs:
      shortstrings: bool, True if all string values should be
          truncated at 100 characters, since when mocking the contents
          typically don't matter except for IDs, and IDs are usually
          less than 100 characters.
      no_modules: bool, True if the long module name should not be printed
          with each type.

    Returns:
      str, A string of valid python (assuming the right imports have been
      made) that recreates the message passed into this function.
    """
    # TODO(user): craigcitro suggests a pretty-printer from apitools/gen.
    # 'indent' is threaded through the recursive calls via **kwargs.
    indent = kwargs.get('indent', 0)

    def IndentKwargs(kwargs):
        # Copy of kwargs with the indent bumped one level (4 spaces) deeper.
        kwargs = dict(kwargs)
        kwargs['indent'] = kwargs.get('indent', 0) + 4
        return kwargs

    if isinstance(msg, list):
        # Lists repr each element recursively, optionally one per line.
        s = '['
        for item in msg:
            if multiline:
                s += '\n' + ' ' * (indent + 4)
            s += MessageToRepr(
                item, multiline=multiline, **IndentKwargs(kwargs)) + ','
        if multiline:
            s += '\n' + ' ' * indent
        s += ']'
        return s
    if isinstance(msg, messages.Message):
        # Messages render as a constructor call with every field assigned,
        # sorted by field name for stable output.
        s = type(msg).__name__ + '('
        if not kwargs.get('no_modules'):
            s = msg.__module__ + '.' + s
        names = sorted([field.name for field in msg.all_fields()])
        for name in names:
            field = msg.field_by_name(name)
            if multiline:
                s += '\n' + ' ' * (indent + 4)
            value = getattr(msg, field.name)
            s += field.name + '=' + MessageToRepr(
                value, multiline=multiline, **IndentKwargs(kwargs)) + ','
        if multiline:
            s += '\n' + ' ' * indent
        s += ')'
        return s
    if isinstance(msg, six.string_types):
        # Optionally truncate long strings; falls through to repr() below.
        if kwargs.get('shortstrings') and len(msg) > 100:
            msg = msg[:100]
    if isinstance(msg, datetime.datetime):
        # datetime reprs aren't round-trippable with an arbitrary tzinfo;
        # substitute one whose repr names protorpclite's TimeZoneOffset.

        class SpecialTZInfo(datetime.tzinfo):

            def __init__(self, offset):
                super(SpecialTZInfo, self).__init__()
                self.offset = offset

            def __repr__(self):
                s = 'TimeZoneOffset(' + repr(self.offset) + ')'
                if not kwargs.get('no_modules'):
                    s = 'apitools.base.protorpclite.util.' + s
                return s

        msg = datetime.datetime(
            msg.year, msg.month, msg.day, msg.hour, msg.minute, msg.second,
            msg.microsecond, SpecialTZInfo(msg.tzinfo.utcoffset(0)))
    return repr(msg)
def _GetField(message, field_path):
for field in field_path:
if field not in dir(message):
raise KeyError('no field "%s"' % field)
message = getattr(message, field)
return message
def _SetField(dictblob, field_path, value):
for field in field_path[:-1]:
dictblob = dictblob.setdefault(field, {})
dictblob[field_path[-1]] = value
def _IncludeFields(encoded_message, message, include_fields):
    """Add the requested (unset) fields to the encoded JSON.

    Each dotted name in include_fields is written into the JSON as null
    (or [] for list-valued fields) so the server sees it explicitly.

    Raises:
      exceptions.InvalidDataError: if a requested field does not exist.
    """
    if include_fields is None:
        return encoded_message
    result = json.loads(encoded_message)
    for field_name in include_fields:
        path = field_name.split('.')
        try:
            value = _GetField(message, path)
        except KeyError:
            raise exceptions.InvalidDataError(
                'No field named %s in message of type %s' % (
                    field_name, type(message)))
        nullvalue = [] if isinstance(value, list) else None
        _SetField(result, path, nullvalue)
    return json.dumps(result)
def _GetFieldCodecs(field, attr):
    """Collect registered codec callables ('encoder'/'decoder') for field,
    per-field registrations first, then per-field-type ones."""
    candidates = (
        getattr(_CUSTOM_FIELD_CODECS.get(field), attr, None),
        getattr(_FIELD_TYPE_CODECS.get(type(field)), attr, None),
    )
    return [codec for codec in candidates if codec is not None]
class _ProtoJsonApiTools(protojson.ProtoJson):
    """JSON encoder used by apitools clients.

    Extends the stock protojson codec with per-type/per-field custom
    codecs, custom JSON field and enum name remappings, unknown-field
    preservation, and deterministic (sorted-key) output.
    """
    # Lazily-created singleton; see Get().
    _INSTANCE = None

    @classmethod
    def Get(cls):
        """Return the shared singleton instance of this codec."""
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def decode_message(self, message_type, encoded_message):
        """Decode the JSON string encoded_message into a message_type."""
        if message_type in _CUSTOM_MESSAGE_CODECS:
            # Whole-message custom decoders take precedence over everything.
            return _CUSTOM_MESSAGE_CODECS[
                message_type].decoder(encoded_message)
        # Rewrite custom JSON names to python names, decode, then recover
        # values the stock decoder dropped (unknown enums/messages/fields).
        result = _DecodeCustomFieldNames(message_type, encoded_message)
        result = super(_ProtoJsonApiTools, self).decode_message(
            message_type, result)
        result = _ProcessUnknownEnums(result, encoded_message)
        result = _ProcessUnknownMessages(result, encoded_message)
        return _DecodeUnknownFields(result, encoded_message)

    def decode_field(self, field, value):
        """Decode the given JSON value.

        Args:
          field: a messages.Field for the field we're decoding.
          value: a python value we'd like to decode.

        Returns:
          A value suitable for assignment to field.
        """
        # Custom codecs run first; a "complete" result short-circuits.
        for decoder in _GetFieldCodecs(field, 'decoder'):
            result = decoder(field, value)
            value = result.value
            if result.complete:
                return value
        if isinstance(field, messages.MessageField):
            field_value = self.decode_message(
                field.message_type, json.dumps(value))
        elif isinstance(field, messages.EnumField):
            # Map a custom wire name back to its python enum name first.
            value = GetCustomJsonEnumMapping(
                field.type, json_name=value) or value
            try:
                field_value = super(
                    _ProtoJsonApiTools, self).decode_field(field, value)
            except messages.DecodeError:
                # Unknown string enum values are preserved separately (see
                # _ProcessUnknownEnums); only re-raise for non-strings.
                if not isinstance(value, six.string_types):
                    raise
                field_value = None
        else:
            field_value = super(
                _ProtoJsonApiTools, self).decode_field(field, value)
        return field_value

    def encode_message(self, message):
        """Encode message (or a FieldList of messages) as a JSON string."""
        if isinstance(message, messages.FieldList):
            return '[%s]' % (', '.join(self.encode_message(x)
                                       for x in message))
        # pylint: disable=unidiomatic-typecheck
        if type(message) in _CUSTOM_MESSAGE_CODECS:
            return _CUSTOM_MESSAGE_CODECS[type(message)].encoder(message)
        # Move catch-all pairs back into unrecognized-field storage so they
        # serialize as ordinary top-level JSON fields.
        message = _EncodeUnknownFields(message)
        result = super(_ProtoJsonApiTools, self).encode_message(message)
        result = _EncodeCustomFieldNames(message, result)
        # Round-trip through json for deterministic, key-sorted output.
        return json.dumps(json.loads(result), sort_keys=True)

    def encode_field(self, field, value):
        """Encode the given value as JSON.

        Args:
          field: a messages.Field for the field we're encoding.
          value: a value for field.

        Returns:
          A python value suitable for json.dumps.
        """
        # Custom codecs run first; a "complete" result short-circuits.
        for encoder in _GetFieldCodecs(field, 'encoder'):
            result = encoder(field, value)
            value = result.value
            if result.complete:
                return value
        if isinstance(field, messages.EnumField):
            # Apply custom enum wire names where registered.
            if field.repeated:
                remapped_value = [GetCustomJsonEnumMapping(
                    field.type, python_name=e.name) or e.name for e in value]
            else:
                remapped_value = GetCustomJsonEnumMapping(
                    field.type, python_name=value.name)
            if remapped_value:
                return remapped_value
        if (isinstance(field, messages.MessageField) and
                not isinstance(field, message_types.DateTimeField)):
            value = json.loads(self.encode_message(value))
        return super(_ProtoJsonApiTools, self).encode_field(field, value)
# TODO(user): Fold this and _IncludeFields in as codecs.
def _DecodeUnknownFields(message, encoded_message):
    """Rewrite unknown fields in message into message.destination.

    Moves wire fields protorpc did not recognize into the catch-all
    key/value field registered for this type via MapUnrecognizedFields.
    """
    destination = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if destination is None:
        # No catch-all field registered for this type; nothing to do.
        return message
    pair_field = message.field_by_name(destination)
    if not isinstance(pair_field, messages.MessageField):
        raise exceptions.InvalidDataFromServerError(
            'Unrecognized fields must be mapped to a compound '
            'message type.')
    pair_type = pair_field.message_type
    # TODO(user): Add more error checking around the pair
    # type being exactly what we suspect (field names, etc).
    if isinstance(pair_type.value, messages.MessageField):
        # Message-typed values must be re-decoded from the raw JSON, since
        # the stock decoder dropped them entirely.
        new_values = _DecodeUnknownMessages(
            message, json.loads(encoded_message), pair_type)
    else:
        new_values = _DecodeUnrecognizedFields(message, pair_type)
    setattr(message, destination, new_values)
    # We could probably get away with not setting this, but
    # why not clear it?
    setattr(message, '_Message__unrecognized_fields', {})
    return message
def _DecodeUnknownMessages(message, encoded_message, pair_type):
    """Process unknown fields in encoded_message of a message type.

    Args:
      message: the partially-decoded proto message.
      encoded_message: dict, the already-json.loads'd wire payload.
      pair_type: the key/value pair message type used as the catch-all.

    Returns:
      List of pair_type instances, one per unknown wire field.
    """
    field_type = pair_type.value.type
    new_values = []
    all_field_names = [x.name for x in message.all_fields()]
    for name, value_dict in six.iteritems(encoded_message):
        if name in all_field_names:
            # Recognized field: it was decoded normally.
            continue
        value = PyValueToMessage(field_type, value_dict)
        if pair_type.value.repeated:
            value = _AsMessageList(value)
        new_pair = pair_type(key=name, value=value)
        new_values.append(new_pair)
    return new_values
def _DecodeUnrecognizedFields(message, pair_type):
    """Process unrecognized fields in message.

    Converts the unrecognized fields protorpc captured during decode into
    a list of key/value pairs of pair_type.
    """
    new_values = []
    codec = _ProtoJsonApiTools.Get()
    for unknown_field in message.all_unrecognized_fields():
        # TODO(user): Consider validating the variant if
        # the assignment below doesn't take care of it. It may
        # also be necessary to check it in the case that the
        # type has multiple encodings.
        value, _ = message.get_unrecognized_field_info(unknown_field)
        value_type = pair_type.field_by_name('value')
        if isinstance(value_type, messages.MessageField):
            decoded_value = DictToMessage(value, pair_type.value.message_type)
        else:
            decoded_value = codec.decode_field(
                pair_type.value, value)
        try:
            new_pair_key = str(unknown_field)
        except UnicodeEncodeError:
            # Py2: str() fails on non-ASCII keys; decode via protojson.
            new_pair_key = protojson.ProtoJson().decode_field(
                pair_type.key, unknown_field)
        new_pair = pair_type(key=new_pair_key, value=decoded_value)
        new_values.append(new_pair)
    return new_values
def _CopyProtoMessageVanillaProtoJson(message):
    """Deep-copy message using the stock protojson codec (no apitools
    hooks), to avoid recursing back into this module."""
    codec = protojson.ProtoJson()
    encoded = codec.encode_message(message)
    return codec.decode_message(type(message), encoded)
def _EncodeUnknownFields(message):
    """Remap unknown fields in message out of message.source.

    Inverse of _DecodeUnknownFields: copies message and moves its catch-all
    key/value pairs back into protorpc's unrecognized-field storage so they
    serialize as ordinary top-level JSON fields.
    """
    source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if source is None:
        return message
    # CopyProtoMessage uses _ProtoJsonApiTools, which uses this message. Use
    # the vanilla protojson-based copy function to avoid infinite recursion.
    result = _CopyProtoMessageVanillaProtoJson(message)
    pairs_field = message.field_by_name(source)
    if not isinstance(pairs_field, messages.MessageField):
        raise exceptions.InvalidUserInputError(
            'Invalid pairs field %s' % pairs_field)
    pairs_type = pairs_field.message_type
    value_field = pairs_type.field_by_name('value')
    value_variant = value_field.variant
    pairs = getattr(message, source)
    codec = _ProtoJsonApiTools.Get()
    for pair in pairs:
        encoded_value = codec.encode_field(value_field, pair.value)
        result.set_unrecognized_field(pair.key, encoded_value, value_variant)
    # Clear the catch-all field on the copy; its contents now live in the
    # unrecognized-field storage populated above.
    setattr(result, source, [])
    return result
def _SafeEncodeBytes(field, value):
    """Encode the bytes in value as urlsafe base64."""
    try:
        if field.repeated:
            encoded = [base64.urlsafe_b64encode(byte) for byte in value]
        else:
            encoded = base64.urlsafe_b64encode(value)
    except TypeError:
        # Not bytes-like: hand the value back for further processing.
        return CodecResult(value=value, complete=False)
    return CodecResult(value=encoded, complete=True)
def _SafeDecodeBytes(unused_field, value):
    """Decode the urlsafe base64 value into bytes."""
    try:
        decoded = base64.urlsafe_b64decode(str(value))
    except TypeError:
        # Not valid base64: hand the value back for further processing.
        return CodecResult(value=value, complete=False)
    return CodecResult(value=decoded, complete=True)
def _ProcessUnknownEnums(message, encoded_message):
    """Add unknown enum values from encoded_message as unknown fields.

    ProtoRPC diverges from the usual protocol buffer behavior here and
    doesn't allow unknown fields. Throwing on unknown fields makes it
    impossible to let servers add new enum values and stay compatible
    with older clients, which isn't reasonable for us. We simply store
    unrecognized enum values as unknown fields, and all is well.

    Args:
      message: Proto message we've decoded thus far.
      encoded_message: JSON string we're decoding.

    Returns:
      message, with any unknown enums stored as unrecognized fields.
    """
    if not encoded_message:
        return message
    decoded_message = json.loads(six.ensure_str(encoded_message))
    for field in message.all_fields():
        if (isinstance(field, messages.EnumField) and
                field.name in decoded_message):
            value = message.get_assigned_value(field.name)
            # A None scalar, or a repeated value shorter than the wire
            # list, means the decoder dropped unknown enum values.
            if ((field.repeated and len(value) != len(decoded_message[field.name])) or
                    value is None):
                message.set_unrecognized_field(
                    field.name, decoded_message[field.name], messages.Variant.ENUM)
    return message
def _ProcessUnknownMessages(message, encoded_message):
    """Store any remaining unknown fields as strings.

    ProtoRPC currently ignores unknown values for which no type can be
    determined (and logs a "No variant found" message). For the purposes
    of reserializing, this is quite harmful (since it throws away
    information). Here we simply add those as unknown fields of type
    string (so that they can easily be reserialized).

    Args:
      message: Proto message we've decoded thus far.
      encoded_message: JSON string we're decoding.

    Returns:
      message, with any remaining unrecognized fields saved.
    """
    if not encoded_message:
        return message
    decoded_message = json.loads(six.ensure_str(encoded_message))
    # Anything already decoded or captured as unrecognized is accounted for.
    known_names = [x.name for x in message.all_fields()] + list(
        message.all_unrecognized_fields())
    for field_name in decoded_message.keys():
        if field_name in known_names:
            continue
        message.set_unrecognized_field(field_name, decoded_message[field_name],
                                       messages.Variant.STRING)
    return message
# Bytes fields are wire-encoded as urlsafe base64 by default.
RegisterFieldTypeCodec(_SafeEncodeBytes, _SafeDecodeBytes)(messages.BytesField)
# Note that these could share a dictionary, since they're keyed by
# distinct types, but it's not really worth it.
# Maps enum/message type -> {python_name: json_name}.
_JSON_ENUM_MAPPINGS = {}
_JSON_FIELD_MAPPINGS = {}
def AddCustomJsonEnumMapping(enum_type, python_name, json_name,
                             package=None):  # pylint: disable=unused-argument
    """Add a custom wire encoding for a given enum value.

    This is primarily used in generated code, to handle enum values
    which happen to be Python keywords.

    Args:
      enum_type: (messages.Enum) An enum type
      python_name: (basestring) Python name for this value.
      json_name: (basestring) JSON name to be used on the wire.
      package: (NoneType, optional) No effect, exists for legacy
          compatibility.
    """
    if not issubclass(enum_type, messages.Enum):
        raise exceptions.TypecheckError(
            'Cannot set JSON enum mapping for non-enum "%s"' % enum_type)
    if python_name not in enum_type.names():
        raise exceptions.InvalidDataError(
            'Enum value %s not a value for type %s' % (python_name, enum_type))
    mappings_for_type = _JSON_ENUM_MAPPINGS.setdefault(enum_type, {})
    # Reject conflicting registrations before recording the new one.
    _CheckForExistingMappings('enum', enum_type, python_name, json_name)
    mappings_for_type[python_name] = json_name
def AddCustomJsonFieldMapping(message_type, python_name, json_name,
                              package=None):  # pylint: disable=unused-argument
    """Add a custom wire encoding for a given message field.

    This is primarily used in generated code, to handle fields whose
    python names happen to be Python keywords.

    Args:
      message_type: (messages.Message) A message type
      python_name: (basestring) Python name for this value.
      json_name: (basestring) JSON name to be used on the wire.
      package: (NoneType, optional) No effect, exists for legacy
          compatibility.
    """
    if not issubclass(message_type, messages.Message):
        raise exceptions.TypecheckError(
            'Cannot set JSON field mapping for '
            'non-message "%s"' % message_type)
    try:
        # Lookup is for validation only; the result is unused.
        message_type.field_by_name(python_name)
    except KeyError:
        raise exceptions.InvalidDataError(
            'Field %s not recognized for type %s' % (
                python_name, message_type))
    mappings_for_type = _JSON_FIELD_MAPPINGS.setdefault(message_type, {})
    # Reject conflicting registrations before recording the new one.
    _CheckForExistingMappings('field', message_type, python_name, json_name)
    mappings_for_type[python_name] = json_name
def GetCustomJsonEnumMapping(enum_type, python_name=None, json_name=None):
    """Return the appropriate remapping for the given enum, or None."""
    return _FetchRemapping(
        enum_type, 'enum', python_name=python_name, json_name=json_name,
        mappings=_JSON_ENUM_MAPPINGS)
def GetCustomJsonFieldMapping(message_type, python_name=None, json_name=None):
    """Return the appropriate remapping for the given field, or None."""
    return _FetchRemapping(
        message_type, 'field', python_name=python_name, json_name=json_name,
        mappings=_JSON_FIELD_MAPPINGS)
def _FetchRemapping(type_name, mapping_type, python_name=None, json_name=None,
mappings=None):
"""Common code for fetching a key or value from a remapping dict."""
if python_name and json_name:
raise exceptions.InvalidDataError(
'Cannot specify both python_name and json_name '
'for %s remapping' % mapping_type)
if not (python_name or json_name):
raise exceptions.InvalidDataError(
'Must specify either python_name or json_name for %s remapping' % (
mapping_type,))
field_remappings = mappings.get(type_name, {})
if field_remappings:
if python_name:
return field_remappings.get(python_name)
elif json_name:
if json_name in list(field_remappings.values()):
return [k for k in field_remappings
if field_remappings[k] == json_name][0]
return None
def _CheckForExistingMappings(mapping_type, message_type,
                              python_name, json_name):
    """Validate that no conflicting mappings exist for the given values.

    Raises:
      exceptions.InvalidDataError: if either name is already mapped to a
          different counterpart.
    """
    if mapping_type == 'field':
        getter = GetCustomJsonFieldMapping
    elif mapping_type == 'enum':
        getter = GetCustomJsonEnumMapping
    existing = getter(message_type, python_name=python_name)
    if existing is not None and existing != json_name:
        raise exceptions.InvalidDataError(
            'Cannot add mapping for %s "%s", already mapped to "%s"' % (
                mapping_type, python_name, existing))
    existing = getter(message_type, json_name=json_name)
    if existing is not None and existing != python_name:
        raise exceptions.InvalidDataError(
            'Cannot add mapping for %s "%s", already mapped to "%s"' % (
                mapping_type, json_name, existing))
def _EncodeCustomFieldNames(message, encoded_value):
    """Rewrite python field names in encoded_value (a JSON string) to their
    registered custom JSON wire names."""
    field_remappings = list(_JSON_FIELD_MAPPINGS.get(type(message), {})
                            .items())
    if field_remappings:
        decoded_value = json.loads(encoded_value)
        for python_name, json_name in field_remappings:
            # Test key membership in the decoded object, not substring
            # membership in the raw JSON text: the latter fires whenever the
            # name merely appears inside some value and then crashes on
            # decoded_value.pop() with a KeyError.
            if python_name in decoded_value:
                decoded_value[json_name] = decoded_value.pop(python_name)
        encoded_value = json.dumps(decoded_value)
    return encoded_value
def _DecodeCustomFieldNames(message_type, encoded_message):
    """Rewrite custom JSON wire names in encoded_message (a JSON string)
    back to their python field names before decoding."""
    remappings = _JSON_FIELD_MAPPINGS.get(message_type, {})
    if not remappings:
        return encoded_message
    decoded = json.loads(encoded_message)
    for python_name, json_name in list(remappings.items()):
        if json_name in decoded:
            decoded[python_name] = decoded.pop(json_name)
    return json.dumps(decoded)
def _AsMessageList(msg):
    """Convert the provided list-as-JsonValue to a list."""
    # This really needs to live in extra_types, but extra_types needs
    # to import this file to be able to register codecs.
    # TODO(user): Split out a codecs module and fix this ugly
    # import.
    from apitools.base.py import extra_types

    def _IsRepeatedJsonValue(candidate):
        """Return True if candidate is a repeated value as a JsonValue."""
        if isinstance(candidate, extra_types.JsonArray):
            return True
        return (isinstance(candidate, extra_types.JsonValue) and
                bool(candidate.array_value))

    if not _IsRepeatedJsonValue(msg):
        raise ValueError('invalid argument to _AsMessageList')
    # Unwrap JsonValue -> JsonArray -> plain list of entries.
    if isinstance(msg, extra_types.JsonValue):
        msg = msg.array_value
    if isinstance(msg, extra_types.JsonArray):
        msg = msg.entries
    return msg
def _IsMap(message, field):
    """Returns whether the "field" is actually a map-type."""
    value = message.get_assigned_value(field.name)
    if not isinstance(value, messages.Message):
        return False
    try:
        additional_properties = value.field_by_name('additionalProperties')
    except KeyError:
        return False
    # Map entries are modeled as a repeated additionalProperties field.
    return additional_properties.repeated
def _MapItems(message, field):
    """Yields the (key, value) pair of the map values."""
    assert _IsMap(message, field)
    map_message = message.get_assigned_value(field.name)
    for kv_pair in map_message.get_assigned_value('additionalProperties'):
        yield kv_pair.key, kv_pair.value
def UnrecognizedFieldIter(message, _edges=()):  # pylint: disable=invalid-name
    """Yields the locations of unrecognized fields within "message".

    If a sub-message is found to have unrecognized fields, that sub-message
    will not be searched any further. We prune the search of the sub-message
    because we assume it is malformed and further checks will not yield
    productive errors.

    Args:
      message: The Message instance to search.
      _edges: Internal arg for passing state.

    Yields:
      (edges_to_message, field_names):
        edges_to_message: List[ProtoEdge], The edges (relative to "message")
            describing the path to the sub-message where the unrecognized
            fields were found.
        field_names: List[Str], The names of the field(s) that were
            unrecognized in the sub-message.
    """
    if not isinstance(message, messages.Message):
        # This is a primitive leaf, no errors found down this path.
        return
    field_names = message.all_unrecognized_fields()
    if field_names:
        # This message is malformed. Stop recursing and report it.
        yield _edges, field_names
        return
    # Recurse through all fields in the current message.
    for field in message.all_fields():
        value = message.get_assigned_value(field.name)
        if field.repeated:
            # One REPEATED edge per list element.
            for i, item in enumerate(value):
                repeated_edge = ProtoEdge(EdgeType.REPEATED, field.name, i)
                iter_ = UnrecognizedFieldIter(item, _edges + (repeated_edge,))
                for (e, y) in iter_:
                    yield e, y
        elif _IsMap(message, field):
            # One MAP edge per key/value entry.
            for key, item in _MapItems(message, field):
                map_edge = ProtoEdge(EdgeType.MAP, field.name, key)
                iter_ = UnrecognizedFieldIter(item, _edges + (map_edge,))
                for (e, y) in iter_:
                    yield e, y
        else:
            # Singular field: a single SCALAR edge (index unused).
            scalar_edge = ProtoEdge(EdgeType.SCALAR, field.name, None)
            iter_ = UnrecognizedFieldIter(value, _edges + (scalar_edge,))
            for (e, y) in iter_:
                yield e, y

View File

@@ -0,0 +1,207 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exceptions for generated client libraries."""
# Root of the apitools exception hierarchy: every error raised by this
# package derives from Error, so callers can catch the whole family with a
# single except clause.
class Error(Exception):
    """Base class for all exceptions."""
class TypecheckError(Error, TypeError):
    """An object of an incorrect type is provided."""
class NotFoundError(Error):
    """A specified resource could not be found."""
class UserError(Error):
    """Base class for errors related to user input."""
class InvalidDataError(Error):
    """Base class for any invalid data error."""
class CommunicationError(Error):
    """Any communication error talking to an API server."""
class HttpError(CommunicationError):
    """Error making a request. Soon to be HttpError."""

    def __init__(self, response, content, url,
                 method_config=None, request=None):
        # The exception message mirrors __str__, built from the same parts.
        super(HttpError, self).__init__(
            HttpError._build_message(response, content, url))
        self.response = response
        self.content = content
        self.url = url
        self.method_config = method_config
        self.request = request

    def __str__(self):
        return HttpError._build_message(self.response, self.content, self.url)

    @staticmethod
    def _build_message(response, content, url):
        # Bytes content is decoded leniently so the message never raises.
        if isinstance(content, bytes):
            content = content.decode('ascii', 'replace')
        return 'HttpError accessing <%s>: response: <%s>, content <%s>' % (
            url, response, content)

    @property
    def status_code(self):
        # TODO(user): Turn this into something better than a
        # KeyError if there is no status.
        return int(self.response['status'])

    @classmethod
    def FromResponse(cls, http_response, **kwargs):
        """Build the most specific HttpError subclass for http_response."""
        error_cls = cls
        try:
            status = int(http_response.info.get('status'))
        except ValueError:
            pass
        else:
            error_cls = _HTTP_ERRORS.get(status, cls)
        return error_cls(http_response.info, http_response.content,
                         http_response.request_url, **kwargs)
# Concrete HttpError subclasses for common 4xx statuses. HttpError's
# FromResponse consults _HTTP_ERRORS below to raise the most specific type.
class HttpBadRequestError(HttpError):
    """HTTP 400 Bad Request."""
class HttpUnauthorizedError(HttpError):
    """HTTP 401 Unauthorized."""
class HttpForbiddenError(HttpError):
    """HTTP 403 Forbidden."""
class HttpNotFoundError(HttpError):
    """HTTP 404 Not Found."""
class HttpConflictError(HttpError):
    """HTTP 409 Conflict."""
# Maps HTTP status code -> HttpError subclass to raise for that status.
_HTTP_ERRORS = {
    400: HttpBadRequestError,
    401: HttpUnauthorizedError,
    403: HttpForbiddenError,
    404: HttpNotFoundError,
    409: HttpConflictError,
}
class InvalidUserInputError(InvalidDataError):
    """User-provided input is invalid."""
class InvalidDataFromServerError(InvalidDataError, CommunicationError):
    """Data received from the server is malformed."""
class BatchError(Error):
    """Error generated while constructing a batch request."""
class ConfigurationError(Error):
    """Base class for configuration errors."""
class GeneratedClientError(Error):
    """The generated client configuration is invalid."""
class ConfigurationValueError(UserError):
    """Some part of the user-specified client configuration is invalid."""
class ResourceUnavailableError(Error):
    """User requested an unavailable resource."""
class CredentialsError(Error):
    """Errors related to invalid credentials."""
class TransferError(CommunicationError):
    """Errors related to transfers."""
class TransferRetryError(TransferError):
    """Retryable errors related to transfers."""
class TransferInvalidError(TransferError):
    """The given transfer is invalid."""
class RequestError(CommunicationError):
    """The request was not successful."""
class RetryAfterError(HttpError):
    """The response contained a retry-after header."""
    def __init__(self, response, content, url, retry_after, **kwargs):
        super(RetryAfterError, self).__init__(response, content, url, **kwargs)
        # Server-requested backoff, coerced to an integer number of seconds.
        self.retry_after = int(retry_after)
    @classmethod
    def FromResponse(cls, http_response, **kwargs):
        # Unlike HttpError.FromResponse, always constructs this class and
        # forwards the response's retry_after value.
        return cls(http_response.info, http_response.content,
                   http_response.request_url, http_response.retry_after,
                   **kwargs)
class BadStatusCodeError(HttpError):
    """The request completed but returned a bad status code."""
class NotYetImplementedError(GeneratedClientError):
    """This functionality is not yet implemented."""
class StreamExhausted(Error):
    """Attempted to read more bytes from a stream than were available."""

View File

@@ -0,0 +1,312 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extra types understood by apitools."""
import datetime
import json
import numbers
import six
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import protojson
from apitools.base.py import encoding_helper as encoding
from apitools.base.py import exceptions
from apitools.base.py import util
if six.PY3:
from collections.abc import Iterable
else:
from collections import Iterable
# Public API of this module.
__all__ = [
    'DateField',
    'DateTimeMessage',
    'JsonArray',
    'JsonObject',
    'JsonValue',
    'JsonProtoEncoder',
    'JsonProtoDecoder',
]
# pylint:disable=invalid-name
# Re-exported so clients can use extra_types.DateTimeMessage without
# importing protorpclite's message_types directly.
DateTimeMessage = message_types.DateTimeMessage
# pylint:enable=invalid-name
# We insert our own metaclass here to avoid letting ProtoRPC
# register this as the default field type for strings.
# * since ProtoRPC does this via metaclasses, we don't have any
# choice but to use one ourselves
# * since a subclass's metaclass must inherit from its superclass's
# metaclass, we're forced to have this hard-to-read inheritance.
#
# pylint: disable=protected-access
class _FieldMeta(messages._FieldMeta):
    # Replaces protorpc's metaclass __init__ (which would register DateField
    # as the default field type for its python type) with plain
    # type.__init__, so defining DateField has no global registration side
    # effect.
    def __init__(cls, name, bases, dct):  # pylint: disable=no-self-argument
        # pylint: disable=super-init-not-called,non-parent-init-called
        type.__init__(cls, name, bases, dct)
# pylint: enable=protected-access
class DateField(six.with_metaclass(_FieldMeta, messages.Field)):
    """Field definition for Date values."""
    # Dates are carried on the wire as strings (encoded/decoded by the
    # codec registered at the bottom of this module).
    VARIANTS = frozenset([messages.Variant.STRING])
    DEFAULT_VARIANT = messages.Variant.STRING
    type = datetime.date
def _ValidateJsonValue(json_value):
    """Raise InvalidDataError unless exactly one field of json_value is set."""
    assigned_count = sum(
        1 for field in json_value.all_fields()
        if json_value.get_assigned_value(field.name) is not None)
    if assigned_count != 1:
        raise exceptions.InvalidDataError(
            'Malformed JsonValue: %s' % json_value)
def _JsonValueToPythonValue(json_value):
    """Convert the given JsonValue to a json string."""
    util.Typecheck(json_value, JsonValue)
    # Validation guarantees exactly one field is assigned below.
    _ValidateJsonValue(json_value)
    if json_value.is_null:
        return None
    for field in json_value.all_fields():
        value = json_value.get_assigned_value(field.name)
        if value is None:
            continue
        if not isinstance(field, messages.MessageField):
            # Scalar fields (bool/string/double/integer) pass through.
            return value
        if field.message_type is JsonObject:
            return _JsonObjectToPythonValue(value)
        if field.message_type is JsonArray:
            return _JsonArrayToPythonValue(value)
def _JsonObjectToPythonValue(json_value):
    """Convert a JsonObject message to a plain python dict."""
    util.Typecheck(json_value, JsonObject)
    return {prop.key: _JsonValueToPythonValue(prop.value)
            for prop in json_value.properties}
def _JsonArrayToPythonValue(json_value):
    """Convert a JsonArray message to a plain python list."""
    util.Typecheck(json_value, JsonArray)
    return list(map(_JsonValueToPythonValue, json_value.entries))
# Bounds used when deciding whether a python int fits a JSON int64 field.
# NOTE(review): due to operator precedence, `2 << 63 - 1` is `2 << 62`
# == 2**63 and `-(2 << 63)` is -2**64 — both wider than the true signed
# 64-bit limits (2**63 - 1 and -2**63). Confirm before relying on these as
# exact int64 bounds.
_MAXINT64 = 2 << 63 - 1
_MININT64 = -(2 << 63)
def _PythonValueToJsonValue(py_value):
    """Convert the given python value to a JsonValue.

    Args:
      py_value: None, a bool, string, number, dict, or iterable.

    Returns:
      A JsonValue message wrapping py_value.

    Raises:
      exceptions.InvalidDataError: if py_value is of an unsupported type.
    """
    if py_value is None:
        return JsonValue(is_null=True)
    # bool must be tested before Number: bool is a Number subclass.
    if isinstance(py_value, bool):
        return JsonValue(boolean_value=py_value)
    if isinstance(py_value, six.string_types):
        return JsonValue(string_value=py_value)
    if isinstance(py_value, numbers.Number):
        if isinstance(py_value, six.integer_types):
            # Only values that actually fit in a signed 64-bit integer may
            # use the INT64 integer_value field. (Previously, negatives down
            # to -2**64 + 1 slipped through an incorrectly-computed lower
            # bound.) Wider integers fall back to double_value below.
            if -(1 << 63) <= py_value <= (1 << 63) - 1:
                return JsonValue(integer_value=py_value)
        return JsonValue(double_value=float(py_value))
    if isinstance(py_value, dict):
        return JsonValue(object_value=_PythonValueToJsonObject(py_value))
    if isinstance(py_value, Iterable):
        return JsonValue(array_value=_PythonValueToJsonArray(py_value))
    raise exceptions.InvalidDataError(
        'Cannot convert "%s" to JsonValue' % py_value)
def _PythonValueToJsonObject(py_value):
    """Convert a python dict to a JsonObject message."""
    util.Typecheck(py_value, dict)
    props = [JsonObject.Property(key=key, value=_PythonValueToJsonValue(val))
             for key, val in py_value.items()]
    return JsonObject(properties=props)
def _PythonValueToJsonArray(py_value):
    """Convert a python iterable to a JsonArray message."""
    return JsonArray(entries=[_PythonValueToJsonValue(entry)
                              for entry in py_value])
class JsonValue(messages.Message):
    """Any valid JSON value."""
    # Is this JSON object `null`?
    is_null = messages.BooleanField(1, default=False)
    # Exactly one of the following is provided if is_null is False; none
    # should be provided if is_null is True.
    boolean_value = messages.BooleanField(2)
    string_value = messages.StringField(3)
    # We keep two numeric fields to keep int64 round-trips exact.
    double_value = messages.FloatField(4, variant=messages.Variant.DOUBLE)
    integer_value = messages.IntegerField(5, variant=messages.Variant.INT64)
    # Compound types
    object_value = messages.MessageField('JsonObject', 6)
    array_value = messages.MessageField('JsonArray', 7)
class JsonObject(messages.Message):
    """A JSON object value.
    Messages:
      Property: A property of a JsonObject.
    Fields:
      properties: A list of properties of a JsonObject.
    """
    class Property(messages.Message):
        """A property of a JSON object.
        Fields:
          key: Name of the property.
          value: A JsonValue attribute.
        """
        key = messages.StringField(1)
        value = messages.MessageField(JsonValue, 2)
    properties = messages.MessageField(Property, 1, repeated=True)
class JsonArray(messages.Message):
    """A JSON array value."""
    entries = messages.MessageField(JsonValue, 1, repeated=True)
# Dispatch table from JSON proto message class to its python converter,
# consulted by _JsonProtoToPythonValue.
_JSON_PROTO_TO_PYTHON_MAP = {
    JsonArray: _JsonArrayToPythonValue,
    JsonObject: _JsonObjectToPythonValue,
    JsonValue: _JsonValueToPythonValue,
}
# Tuple form for Typecheck/isinstance checks.
_JSON_PROTO_TYPES = tuple(_JSON_PROTO_TO_PYTHON_MAP.keys())
def _JsonProtoToPythonValue(json_proto):
    """Dispatch a JSON proto message to the converter for its exact type."""
    util.Typecheck(json_proto, _JSON_PROTO_TYPES)
    converter = _JSON_PROTO_TO_PYTHON_MAP[type(json_proto)]
    return converter(json_proto)
def _PythonValueToJsonProto(py_value):
    """Wrap a python value in the most specific JSON proto message."""
    if isinstance(py_value, dict):
        return _PythonValueToJsonObject(py_value)
    is_listlike = (isinstance(py_value, Iterable) and
                   not isinstance(py_value, six.string_types))
    if is_listlike:
        return _PythonValueToJsonArray(py_value)
    return _PythonValueToJsonValue(py_value)
def _JsonProtoToJson(json_proto, unused_encoder=None):
    """Serialize a JSON proto message to a JSON string."""
    return json.dumps(_JsonProtoToPythonValue(json_proto))
def _JsonToJsonProto(json_data, unused_decoder=None):
    """Parse a JSON string into the matching JSON proto message."""
    return _PythonValueToJsonProto(json.loads(json_data))
def _JsonToJsonValue(json_data, unused_decoder=None):
    """Parse a JSON string, always wrapping the result in a JsonValue."""
    result = _PythonValueToJsonProto(json.loads(json_data))
    if isinstance(result, JsonValue):
        return result
    if isinstance(result, JsonObject):
        return JsonValue(object_value=result)
    if isinstance(result, JsonArray):
        return JsonValue(array_value=result)
    raise exceptions.InvalidDataError(
        'Malformed JsonValue: %s' % json_data)
# pylint:disable=invalid-name
# Public aliases for the JSON codec entry points.
JsonProtoEncoder = _JsonProtoToJson
JsonProtoDecoder = _JsonToJsonProto
# pylint:enable=invalid-name
# JsonValue decodes via _JsonToJsonValue so bare objects/arrays are wrapped
# back into a JsonValue; JsonObject/JsonArray use the plain decoder.
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=_JsonToJsonValue)(JsonValue)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonObject)
encoding.RegisterCustomMessageCodec(
    encoder=JsonProtoEncoder, decoder=JsonProtoDecoder)(JsonArray)
def _EncodeDateTimeField(field, value):
    """Encode a DateTimeField via the stock protojson encoder."""
    encoded = protojson.ProtoJson().encode_field(field, value)
    return encoding.CodecResult(value=encoded, complete=True)
def _DecodeDateTimeField(unused_field, value):
    """Decode a DateTimeField via the stock protojson decoder."""
    decoded = protojson.ProtoJson().decode_field(
        message_types.DateTimeField(1), value)
    return encoding.CodecResult(value=decoded, complete=True)
encoding.RegisterFieldTypeCodec(_EncodeDateTimeField, _DecodeDateTimeField)(
    message_types.DateTimeField)
def _EncodeInt64Field(field, value):
    """Handle the special case of int64 as a string."""
    # Only 64-bit variants need the string treatment; everything else is
    # passed through untouched for the next codec.
    if field.variant not in (messages.Variant.INT64,
                             messages.Variant.UINT64):
        return encoding.CodecResult(value=value, complete=False)
    if field.repeated:
        encoded = [str(item) for item in value]
    else:
        encoded = str(value)
    return encoding.CodecResult(value=encoded, complete=True)
def _DecodeInt64Field(unused_field, value):
    # Decoding needs no special handling; integers parse fine as-is.
    return encoding.CodecResult(value=value, complete=False)
encoding.RegisterFieldTypeCodec(_EncodeInt64Field, _DecodeInt64Field)(
    messages.IntegerField)
def _EncodeDateField(field, value):
    """Encoder for datetime.date objects (ISO-8601 strings)."""
    if field.repeated:
        encoded = [date_value.isoformat() for date_value in value]
    else:
        encoded = value.isoformat()
    return encoding.CodecResult(value=encoded, complete=True)
def _DecodeDateField(unused_field, value):
    """Decoder producing datetime.date from 'YYYY-MM-DD' strings."""
    parsed = datetime.datetime.strptime(value, '%Y-%m-%d').date()
    return encoding.CodecResult(value=parsed, complete=True)
encoding.RegisterFieldTypeCodec(_EncodeDateField, _DecodeDateField)(DateField)

View File

@@ -0,0 +1,617 @@
# Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
# 2011, 2012, 2013, 2014, 2015, 2016, 2017 Python Software Foundation; All
# Rights Reserved
#
# This is a backport from python 3.4 into python 2.7. Text and exclusive mode
# support are removed as they're unsupported in 2.7. This backport patches a
# streaming bug that exists in python 2.7.
"""Functions that read and write gzipped files.
The user of the file doesn't have to worry about the compression,
but random access is not allowed."""
# based on Andrew Kuchling's minigzip.py distributed with the zlib module
import six
from six.moves import builtins
from six.moves import range
import struct
import sys
import time
import os
import zlib
import io
__all__ = ["GzipFile", "open", "compress", "decompress"]
# Gzip member-header flag bits (FLG field, RFC 1952).
FTEXT, FHCRC, FEXTRA, FNAME, FCOMMENT = 1, 2, 4, 8, 16
# Internal markers for GzipFile.mode.
READ, WRITE = 1, 2
def open(filename, mode="rb", compresslevel=9):
    """Shorthand for GzipFile(filename, mode, compresslevel).

    The filename argument is required; mode defaults to 'rb'
    and compresslevel defaults to 9.
    """
    return GzipFile(filename, mode, compresslevel)


def write32u(output, value):
    """Write value to output as an unsigned 32-bit little-endian word."""
    # The L format writes the bit pattern correctly whether signed
    # or unsigned.
    output.write(struct.pack("<L", value))
class _PaddedFile(object):
    """Minimal read-only file object that prepends a string to the contents
    of an actual file. Shouldn't be used outside of gzip.py, as it lacks
    essential functionality."""

    def __init__(self, f, prepend=b''):
        # _buffer holds the prepended bytes; _read is the cursor into it,
        # or None once the buffer is exhausted/bypassed.
        self._buffer = prepend
        self._length = len(prepend)
        self.file = f
        self._read = 0

    def read(self, size):
        # Serve from the prepend buffer first, then fall through to the
        # underlying file.
        if self._read is None:
            return self.file.read(size)
        if self._read + size <= self._length:
            read = self._read
            self._read += size
            return self._buffer[read:self._read]
        else:
            read = self._read
            self._read = None
            return self._buffer[read:] + \
                self.file.read(size - self._length + read)

    def prepend(self, prepend=b'', readprevious=False):
        # Push bytes back so the next read() returns them. With
        # readprevious=True, rewinds the cursor over already-served bytes
        # instead of growing the buffer.
        if self._read is None:
            self._buffer = prepend
        elif readprevious and len(prepend) <= self._read:
            self._read -= len(prepend)
            return
        else:
            self._buffer = self._buffer[self._read:] + prepend
        self._length = len(self._buffer)
        self._read = 0

    def unused(self):
        # Bytes that were prepended but not yet consumed.
        if self._read is None:
            return b''
        return self._buffer[self._read:]

    def seek(self, offset, whence=0):
        # This is only ever called with offset=whence=0
        if whence == 1 and self._read is not None:
            if 0 <= offset + self._read <= self._length:
                self._read += offset
                return
            else:
                offset += self._length - self._read
        self._read = None
        self._buffer = None
        return self.file.seek(offset, whence)

    def __getattr__(self, name):
        # Everything else is delegated to the wrapped file object.
        return getattr(self.file, name)
class GzipFile(io.BufferedIOBase):
    """The GzipFile class simulates most of the methods of a file object with
    the exception of the readinto() and truncate() methods.

    This class only supports opening files in binary mode. If you need to open
    a compressed file in text mode, use the gzip.open() function.
    """

    # Set when we opened the underlying file ourselves, so close() knows to
    # close it as well.
    myfileobj = None
    max_read_chunk = 10 * 1024 * 1024  # 10Mb

    def __init__(self, filename=None, mode=None,
                 compresslevel=9, fileobj=None, mtime=None):
        """Constructor for the GzipFile class.

        At least one of fileobj and filename must be given a
        non-trivial value.

        The new class instance is based on fileobj, which can be a regular
        file, an io.BytesIO object, or any other object which simulates a file.
        It defaults to None, in which case filename is opened to provide
        a file object.

        When fileobj is not None, the filename argument is only used to be
        included in the gzip file header, which may include the original
        filename of the uncompressed file. It defaults to the filename of
        fileobj, if discernible; otherwise, it defaults to the empty string,
        and in this case the original filename is not included in the header.

        The mode argument can be any of 'r', 'rb', 'a', 'ab', 'w', or 'wb',
        depending on whether the file will be read or written. The default
        is the mode of fileobj if discernible; otherwise, the default is 'rb'.
        A mode of 'r' is equivalent to one of 'rb', and similarly for 'w' and
        'wb', and 'a' and 'ab'.

        The compresslevel argument is an integer from 0 to 9 controlling the
        level of compression; 1 is fastest and produces the least compression,
        and 9 is slowest and produces the most compression. 0 is no compression
        at all. The default is 9.

        The mtime argument is an optional numeric timestamp to be written
        to the stream when compressing. All gzip compressed streams
        are required to contain a timestamp. If omitted or None, the
        current time is used. This module ignores the timestamp when
        decompressing; however, some programs, such as gunzip, make use
        of it. The format of the timestamp is the same as that of the
        return value of time.time() and of the st_mtime member of the
        object returned by os.stat().
        """
        # Text and universal-newline modes are unsupported in this backport.
        if mode and ('t' in mode or 'U' in mode):
            raise ValueError("Invalid mode: {!r}".format(mode))
        if mode and 'b' not in mode:
            mode += 'b'
        if fileobj is None:
            fileobj = self.myfileobj = builtins.open(filename, mode or 'rb')
        if filename is None:
            filename = getattr(fileobj, 'name', '')
            if not isinstance(filename, six.string_types):
                filename = ''
        if mode is None:
            mode = getattr(fileobj, 'mode', 'rb')

        if mode.startswith('r'):
            self.mode = READ
            # Set flag indicating start of a new member
            self._new_member = True
            # Buffer data read from gzip file. extrastart is offset in
            # stream where buffer starts. extrasize is number of
            # bytes remaining in buffer from current stream position.
            self.extrabuf = b""
            self.extrasize = 0
            self.extrastart = 0
            self.name = filename
            # Starts small, scales exponentially
            self.min_readsize = 100
            fileobj = _PaddedFile(fileobj)
        elif mode.startswith(('w', 'a')):
            self.mode = WRITE
            self._init_write(filename)
            self.compress = zlib.compressobj(compresslevel,
                                             zlib.DEFLATED,
                                             -zlib.MAX_WBITS,
                                             zlib.DEF_MEM_LEVEL,
                                             0)
        else:
            raise ValueError("Invalid mode: {!r}".format(mode))

        self.fileobj = fileobj
        self.offset = 0
        self.mtime = mtime

        if self.mode == WRITE:
            self._write_gzip_header()

    @property
    def filename(self):
        # Deprecated alias for the name attribute.
        import warnings
        warnings.warn("use the name attribute", DeprecationWarning, 2)
        if self.mode == WRITE and self.name[-3:] != ".gz":
            return self.name + ".gz"
        return self.name

    def __repr__(self):
        fileobj = self.fileobj
        if isinstance(fileobj, _PaddedFile):
            fileobj = fileobj.file
        s = repr(fileobj)
        return '<gzip ' + s[1:-1] + ' ' + hex(id(self)) + '>'

    def _check_closed(self):
        """Raises a ValueError if the underlying file object has been closed.
        """
        if self.closed:
            raise ValueError('I/O operation on closed file.')

    def _init_write(self, filename):
        # Reset the CRC/size accumulators for a new compressed member.
        self.name = filename
        self.crc = zlib.crc32(b"") & 0xffffffff
        self.size = 0
        self.writebuf = []
        self.bufsize = 0

    def _write_gzip_header(self):
        # Emit the RFC 1952 member header.
        self.fileobj.write(b'\037\213')  # magic header
        self.fileobj.write(b'\010')  # compression method
        try:
            # RFC 1952 requires the FNAME field to be Latin-1. Do not
            # include filenames that cannot be represented that way.
            fname = os.path.basename(self.name)
            if not isinstance(fname, six.binary_type):
                fname = fname.encode('latin-1')
            if fname.endswith(b'.gz'):
                fname = fname[:-3]
        except UnicodeEncodeError:
            fname = b''
        flags = 0
        if fname:
            flags = FNAME
        self.fileobj.write(six.unichr(flags).encode('latin-1'))
        mtime = self.mtime
        if mtime is None:
            mtime = time.time()
        write32u(self.fileobj, int(mtime))
        # XFL and OS bytes of the header.
        self.fileobj.write(b'\002')
        self.fileobj.write(b'\377')
        if fname:
            self.fileobj.write(fname + b'\000')

    def _init_read(self):
        # Reset the CRC/size accumulators before reading a new member.
        self.crc = zlib.crc32(b"") & 0xffffffff
        self.size = 0

    def _read_exact(self, n):
        # Read exactly n bytes; a short read means a truncated stream.
        data = self.fileobj.read(n)
        while len(data) < n:
            b = self.fileobj.read(n - len(data))
            if not b:
                raise EOFError("Compressed file ended before the "
                               "end-of-stream marker was reached")
            data += b
        return data

    def _read_gzip_header(self):
        # Parse one member header; returns False at a clean end of stream.
        magic = self.fileobj.read(2)
        if magic == b'':
            return False
        if magic != b'\037\213':
            raise OSError('Not a gzipped file')
        method, flag, self.mtime = struct.unpack("<BBIxx", self._read_exact(8))
        if method != 8:
            raise OSError('Unknown compression method')
        if flag & FEXTRA:
            # Read & discard the extra field, if present
            extra_len, = struct.unpack("<H", self._read_exact(2))
            self._read_exact(extra_len)
        if flag & FNAME:
            # Read and discard a null-terminated string containing the filename
            while True:
                s = self.fileobj.read(1)
                if not s or s == b'\000':
                    break
        if flag & FCOMMENT:
            # Read and discard a null-terminated string containing a comment
            while True:
                s = self.fileobj.read(1)
                if not s or s == b'\000':
                    break
        if flag & FHCRC:
            self._read_exact(2)  # Read & discard the 16-bit header CRC
        # Streaming patch: hand any bytes the padded file still holds to the
        # decompressor immediately.
        unused = self.fileobj.unused()
        if unused:
            uncompress = self.decompress.decompress(unused)
            self._add_read_data(uncompress)
        return True

    def write(self, data):
        """Compress data and write it to the underlying file object."""
        self._check_closed()
        if self.mode != WRITE:
            import errno
            raise OSError(errno.EBADF, "write() on read-only GzipFile object")
        if self.fileobj is None:
            raise ValueError("write() on closed GzipFile object")
        # Convert data type if called by io.BufferedWriter.
        if isinstance(data, memoryview):
            data = data.tobytes()
        if len(data) > 0:
            self.fileobj.write(self.compress.compress(data))
            self.size += len(data)
            self.crc = zlib.crc32(data, self.crc) & 0xffffffff
            self.offset += len(data)
        return len(data)

    def read(self, size=-1):
        """Read up to size uncompressed bytes (all remaining if size < 0)."""
        self._check_closed()
        if self.mode != READ:
            import errno
            raise OSError(errno.EBADF, "read() on write-only GzipFile object")
        if self.extrasize <= 0 and self.fileobj is None:
            return b''
        readsize = 1024
        if size < 0:  # get the whole thing
            while self._read(readsize):
                readsize = min(self.max_read_chunk, readsize * 2)
            size = self.extrasize
        else:  # just get some more of it
            while size > self.extrasize:
                if not self._read(readsize):
                    if size > self.extrasize:
                        size = self.extrasize
                    break
                readsize = min(self.max_read_chunk, readsize * 2)
        offset = self.offset - self.extrastart
        chunk = self.extrabuf[offset: offset + size]
        self.extrasize = self.extrasize - size
        self.offset += size
        return chunk

    def read1(self, size=-1):
        """Read a single chunk of up to size uncompressed bytes."""
        self._check_closed()
        if self.mode != READ:
            import errno
            raise OSError(errno.EBADF, "read1() on write-only GzipFile object")
        if self.extrasize <= 0 and self.fileobj is None:
            return b''
        # For certain input data, a single call to _read() may not return
        # any data. In this case, retry until we get some data or reach EOF.
        while self.extrasize <= 0 and self._read():
            pass
        if size < 0 or size > self.extrasize:
            size = self.extrasize
        offset = self.offset - self.extrastart
        chunk = self.extrabuf[offset: offset + size]
        self.extrasize -= size
        self.offset += size
        return chunk

    def peek(self, n):
        """Return up to n buffered bytes without advancing the position."""
        if self.mode != READ:
            import errno
            raise OSError(errno.EBADF, "peek() on write-only GzipFile object")
        # Do not return ridiculously small buffers, for one common idiom
        # is to call peek(1) and expect more bytes in return.
        if n < 100:
            n = 100
        if self.extrasize == 0:
            if self.fileobj is None:
                return b''
            # Ensure that we don't return b"" if we haven't reached EOF.
            # 1024 is the same buffering heuristic used in read()
            while self.extrasize == 0 and self._read(max(n, 1024)):
                pass
        offset = self.offset - self.extrastart
        remaining = self.extrasize
        assert remaining == len(self.extrabuf) - offset
        return self.extrabuf[offset:offset + n]

    def _unread(self, buf):
        # Push already-returned bytes back into the decompressed buffer.
        self.extrasize = len(buf) + self.extrasize
        self.offset -= len(buf)

    def _read(self, size=1024):
        # Decompress one chunk into extrabuf; returns False at end of data.
        if self.fileobj is None:
            return False
        if self._new_member:
            # If the _new_member flag is set, we have to
            # jump to the next member, if there is one.
            self._init_read()
            if not self._read_gzip_header():
                return False
            self.decompress = zlib.decompressobj(-zlib.MAX_WBITS)
            self._new_member = False
        # Read a chunk of data from the file
        buf = self.fileobj.read(size)
        # If the EOF has been reached, flush the decompression object
        # and mark this object as finished.
        if buf == b"":
            uncompress = self.decompress.flush()
            # Prepend the already read bytes to the fileobj so they can be
            # seen by _read_eof()
            self.fileobj.prepend(self.decompress.unused_data, True)
            self._read_eof()
            self._add_read_data(uncompress)
            return False
        uncompress = self.decompress.decompress(buf)
        self._add_read_data(uncompress)
        if self.decompress.unused_data != b"":
            # Ending case: we've come to the end of a member in the file,
            # so seek back to the start of the unused data, finish up
            # this member, and read a new gzip header.
            # Prepend the already read bytes to the fileobj so they can be
            # seen by _read_eof() and _read_gzip_header()
            self.fileobj.prepend(self.decompress.unused_data, True)
            # Check the CRC and file size, and set the flag so we read
            # a new member on the next call
            self._read_eof()
            self._new_member = True
        return True

    def _add_read_data(self, data):
        # Fold freshly decompressed bytes into the buffer and accumulators.
        self.crc = zlib.crc32(data, self.crc) & 0xffffffff
        offset = self.offset - self.extrastart
        self.extrabuf = self.extrabuf[offset:] + data
        self.extrasize = self.extrasize + len(data)
        self.extrastart = self.offset
        self.size = self.size + len(data)

    def _read_eof(self):
        # We've read to the end of the file
        # We check the that the computed CRC and size of the
        # uncompressed data matches the stored values. Note that the size
        # stored is the true file size mod 2**32.
        crc32, isize = struct.unpack("<II", self._read_exact(8))
        if crc32 != self.crc:
            raise OSError("CRC check failed %s != %s" % (hex(crc32),
                                                         hex(self.crc)))
        elif isize != (self.size & 0xffffffff):
            raise OSError("Incorrect length of data produced")
        # Gzip files can be padded with zeroes and still have archives.
        # Consume all zero bytes and set the file position to the first
        # non-zero byte. See http://www.gzip.org/#faq8
        c = b"\x00"
        while c == b"\x00":
            c = self.fileobj.read(1)
        if c:
            self.fileobj.prepend(c, True)

    @property
    def closed(self):
        return self.fileobj is None

    def close(self):
        # Finish the member (flush + CRC/size trailer) and release files.
        fileobj = self.fileobj
        if fileobj is None:
            return
        self.fileobj = None
        try:
            if self.mode == WRITE:
                fileobj.write(self.compress.flush())
                write32u(fileobj, self.crc)
                # self.size may exceed 2GB, or even 4GB
                write32u(fileobj, self.size & 0xffffffff)
        finally:
            myfileobj = self.myfileobj
            if myfileobj:
                self.myfileobj = None
                myfileobj.close()

    def flush(self, zlib_mode=zlib.Z_SYNC_FLUSH):
        """Flush buffered compressed data to the underlying file object."""
        self._check_closed()
        if self.mode == WRITE:
            # Ensure the compressor's buffer is flushed
            self.fileobj.write(self.compress.flush(zlib_mode))
            self.fileobj.flush()

    def fileno(self):
        """Invoke the underlying file object's fileno() method.

        This will raise AttributeError if the underlying file object
        doesn't support fileno().
        """
        return self.fileobj.fileno()

    def rewind(self):
        '''Return the uncompressed stream file position indicator to the
        beginning of the file'''
        if self.mode != READ:
            raise OSError("Can't rewind in write mode")
        self.fileobj.seek(0)
        self._new_member = True
        self.extrabuf = b""
        self.extrasize = 0
        self.extrastart = 0
        self.offset = 0

    def readable(self):
        return self.mode == READ

    def writable(self):
        return self.mode == WRITE

    def seekable(self):
        return True

    def seek(self, offset, whence=0):
        """Seek within the uncompressed stream (no seek-from-end)."""
        if whence:
            if whence == 1:
                offset = self.offset + offset
            else:
                raise ValueError('Seek from end not supported')
        if self.mode == WRITE:
            if offset < self.offset:
                raise OSError('Negative seek in write mode')
            count = offset - self.offset
            # NOTE(review): bytes(1024) is 1024 zero bytes on python 3 but
            # the string '1024' on python 2 -- confirm intended behavior for
            # this 2.7 backport before relying on write-mode seek.
            chunk = bytes(1024)
            for i in range(count // 1024):
                self.write(chunk)
            self.write(bytes(count % 1024))
        elif self.mode == READ:
            if offset < self.offset:
                # for negative seek, rewind and do positive seek
                self.rewind()
            count = offset - self.offset
            for i in range(count // 1024):
                self.read(1024)
            self.read(count % 1024)
        return self.offset

    def readline(self, size=-1):
        """Read one line; size caps the number of bytes returned."""
        if size < 0:
            # Shortcut common case - newline found in buffer.
            offset = self.offset - self.extrastart
            i = self.extrabuf.find(b'\n', offset) + 1
            if i > 0:
                self.extrasize -= i - offset
                self.offset += i - offset
                return self.extrabuf[offset: i]
            size = sys.maxsize
            readsize = self.min_readsize
        else:
            readsize = size
        bufs = []
        while size != 0:
            c = self.read(readsize)
            i = c.find(b'\n')
            # We set i=size to break out of the loop under two
            # conditions: 1) there's no newline, and the chunk is
            # larger than size, or 2) there is a newline, but the
            # resulting line would be longer than 'size'.
            if (size <= i) or (i == -1 and len(c) > size):
                i = size - 1
            if i >= 0 or c == b'':
                bufs.append(c[:i + 1])  # Add portion of last chunk
                self._unread(c[i + 1:])  # Push back rest of chunk
                break
            # Append chunk to list, decrease 'size',
            bufs.append(c)
            size = size - len(c)
            readsize = min(size, readsize * 2)
        if readsize > self.min_readsize:
            self.min_readsize = min(readsize, self.min_readsize * 2, 512)
        return b''.join(bufs)  # Return resulting line
def compress(data, compresslevel=9):
    """Compress data in one shot and return the compressed string.

    Optional argument is the compression level, in range of 0-9.
    """
    out = io.BytesIO()
    with GzipFile(fileobj=out, mode='wb', compresslevel=compresslevel) as gz:
        gz.write(data)
    return out.getvalue()


def decompress(data):
    """Decompress a gzip compressed string in one shot.

    Return the decompressed string.
    """
    with GzipFile(fileobj=io.BytesIO(data)) as gz:
        return gz.read()

View File

@@ -0,0 +1,422 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HTTP wrapper for apitools.
This library wraps the underlying http library we use, which is
currently httplib2.
"""
import collections
import contextlib
import logging
import socket
import time
import httplib2
import six
from six.moves import http_client
from six.moves.urllib import parse
from apitools.base.py import exceptions
from apitools.base.py import util
# pylint: disable=ungrouped-imports
try:
from oauth2client.client import HttpAccessTokenRefreshError as TokenRefreshError # noqa
except ImportError:
from oauth2client.client import AccessTokenRefreshError as TokenRefreshError # noqa
# Public API of this module.
__all__ = [
    'CheckResponse',
    'GetHttp',
    'HandleExceptionsAndRebuildHttpConnections',
    'MakeRequest',
    'RebuildHttpConnections',
    'Request',
    'Response',
    'RethrowExceptionHandler',
]
# 308 and 429 don't have names in httplib.
RESUME_INCOMPLETE = 308
TOO_MANY_REQUESTS = 429
# Status codes treated as redirects (including the unnamed 308).
_REDIRECT_STATUS_CODES = (
    http_client.MOVED_PERMANENTLY,
    http_client.FOUND,
    http_client.SEE_OTHER,
    http_client.TEMPORARY_REDIRECT,
    RESUME_INCOMPLETE,
)
# Argument bundle passed to retry/exception handler callbacks.
# http: An httplib2.Http instance.
# http_request: A http_wrapper.Request.
# exc: Exception being raised.
# num_retries: Number of retries consumed; used for exponential backoff.
ExceptionRetryArgs = collections.namedtuple(
    'ExceptionRetryArgs', ['http', 'http_request', 'exc', 'num_retries',
                           'max_retry_wait', 'total_wait_sec'])
@contextlib.contextmanager
def _Httplib2Debuglevel(http_request, level, http=None):
    """Temporarily change the value of httplib2.debuglevel, if necessary.

    If http_request has a `loggable_body` distinct from `body`, then we
    need to prevent httplib2 from logging the full body. This sets
    httplib2.debuglevel for the duration of the `with` block; however,
    that alone won't change the value of existing HTTP connections. If
    an httplib2.Http object is provided, we'll also change the level on
    any cached connections attached to it.

    Args:
      http_request: a Request we're logging.
      level: (int) the debuglevel for logging.
      http: (optional) an httplib2.Http whose connections we should
          set the debuglevel on.

    Yields:
      None.
    """
    if http_request.loggable_body is None:
        # Nothing sensitive/huge to suppress; leave debuglevel untouched.
        yield
        return
    old_level = httplib2.debuglevel
    http_levels = {}
    httplib2.debuglevel = level
    if http is not None:
        for connection_key, connection in http.connections.items():
            # httplib2 stores two kinds of values in this dict, connection
            # classes and instances. Since the connection types are all
            # old-style classes, we can't easily distinguish by connection
            # type -- so instead we use the key pattern.
            if ':' not in connection_key:
                continue
            http_levels[connection_key] = connection.debuglevel
            connection.set_debuglevel(level)
    try:
        yield
    finally:
        # Restore the previous levels even if the request raised; without
        # this, a failed request would leave verbose logging enabled
        # globally and on all cached connections.
        httplib2.debuglevel = old_level
        if http is not None:
            for connection_key, old_level in http_levels.items():
                if connection_key in http.connections:
                    http.connections[connection_key].set_debuglevel(old_level)
class Request(object):
    """Container for the pieces of a single HTTP request."""

    def __init__(self, url='', http_method='GET', headers=None, body=''):
        self.url = url
        self.http_method = http_method
        self.headers = headers or {}
        self.__body = None
        self.__loggable_body = None
        self.body = body

    @property
    def loggable_body(self):
        """Printable stand-in for the body (used to avoid logging media)."""
        return self.__loggable_body

    @loggable_body.setter
    def loggable_body(self, value):
        if self.body is None:
            raise exceptions.RequestError(
                'Cannot set loggable body on request with no body')
        self.__loggable_body = value

    @property
    def body(self):
        """The request payload: a string, a stream-like object, or None."""
        return self.__body

    @body.setter
    def body(self, value):
        """Sets the request body; handles logging and length measurement."""
        self.__body = value
        if value is None:
            # No payload: drop any stale content-length header.
            self.headers.pop('content-length', None)
            return
        # Prefer an explicit `length` attribute over len(): on 32-bit
        # Python 2, len() cannot express values above 4 GiB.
        size = getattr(value, 'length', None)
        if not size:
            size = len(value)
        self.headers['content-length'] = str(size)
        # Non-string bodies (media streams) get a placeholder so we never
        # attempt to print a potentially huge or binary payload.
        if not isinstance(value, six.string_types):
            self.loggable_body = '<media body>'
# Note: currently the order of fields here is important, since we want
# to be able to pass in the result from httplib2.request.
class Response(collections.namedtuple(
        'HttpResponse', ['info', 'content', 'request_url'])):
    """Immutable triple (info, content, request_url) for an HTTP response."""
    __slots__ = ()

    def __len__(self):
        return self.length

    @property
    def length(self):
        """Length of this response, preferring header data over len(content).

        Exposed as an attribute because len() itself can fail for
        responses larger than sys.maxint.

        Returns:
          Response length (as int or long)
        """
        def _RangeLength(content_range):
            # 'bytes 0-4/10' -> span '0-4' -> inclusive length 5.
            _, _, spec = content_range.partition(' ')
            span = spec.partition('/')[0]
            first, _, last = span.partition('-')
            return int(last) - int(first) + 1

        headers = self.info
        if '-content-encoding' in headers and 'content-range' in headers:
            # httplib2 rewrites content-length for compressed transfers,
            # so it can't be trusted there -- but content-range can.
            return _RangeLength(headers['content-range'])
        if 'content-length' in headers:
            return int(headers.get('content-length'))
        if 'content-range' in headers:
            return _RangeLength(headers['content-range'])
        return len(self.content)

    @property
    def status_code(self):
        """HTTP status of this response, as an int."""
        return int(self.info['status'])

    @property
    def retry_after(self):
        """Value of the retry-after header in seconds, or None if absent."""
        if 'retry-after' in self.info:
            return int(self.info['retry-after'])

    @property
    def is_redirect(self):
        """True when the status is a redirect and a location header exists."""
        return (self.status_code in _REDIRECT_STATUS_CODES and
                'location' in self.info)
def CheckResponse(response):
    """Validate an HTTP response, raising for missing/retryable responses.

    Args:
      response: A Response object, or None.

    Raises:
      exceptions.RequestError: if response is None.
      exceptions.BadStatusCodeError: on a 5XX or 429 status.
      exceptions.RetryAfterError: if the response carries a
          retry-after header.
    """
    if response is None:
        # Caller shouldn't call us if the response is None, but handle
        # anyway.  NOTE: we must not dereference response.request_url
        # here -- response is None, so that would raise AttributeError
        # and mask the intended RequestError.
        raise exceptions.RequestError(
            'Request did not return a response.')
    elif (response.status_code >= 500 or
          response.status_code == TOO_MANY_REQUESTS):
        raise exceptions.BadStatusCodeError.FromResponse(response)
    elif response.retry_after:
        raise exceptions.RetryAfterError.FromResponse(response)
def RebuildHttpConnections(http):
    """Drop all live connections cached on an httplib2.Http instance.

    httplib2 overloads http.connections with two kinds of entries:
    bare scheme keys mapping to connection *classes*, and
    'scheme:authority' keys mapping to actual open connections.
    Removing only the latter forces httplib2 to rebuild fresh
    connections (from the retained classes) on the next request.

    Args:
      http: An httplib2.Http instance.
    """
    connections = getattr(http, 'connections', None)
    if not connections:
        return
    # Keys containing ':' name live connections; bare scheme keys map to
    # connection classes and must be kept.
    stale_keys = [key for key in connections if ':' in key]
    for key in stale_keys:
        del connections[key]
def RethrowExceptionHandler(*unused_args):
    """Retry handler that re-raises the exception currently being handled.

    Pass as retry_func to MakeRequest when failures should propagate
    immediately instead of being retried.
    """
    # pylint: disable=misplaced-bare-raise
    raise
def HandleExceptionsAndRebuildHttpConnections(retry_args):
    """Exception handler for http failures.

    This catches known transient failures, rebuilds the underlying HTTP
    connections, and sleeps before the caller retries.  Unknown
    exceptions are re-raised.

    Args:
      retry_args: An ExceptionRetryArgs tuple.

    Raises:
      The exception in retry_args.exc, if it is not a known retryable
      failure.
    """
    # If the server indicates how long to wait, use that value.  Otherwise,
    # calculate the wait time on our own.
    retry_after = None

    # Transport failures
    if isinstance(retry_args.exc, (http_client.BadStatusLine,
                                   http_client.IncompleteRead,
                                   http_client.ResponseNotReady)):
        logging.debug('Caught HTTP error %s, retrying: %s',
                      type(retry_args.exc).__name__, retry_args.exc)
    # NOTE: socket.gaierror and socket.timeout are subclasses of
    # socket.error (OSError on Python 3), so they must be tested *before*
    # the generic socket.error branch; otherwise these branches are
    # unreachable and the generic message is logged instead.
    elif isinstance(retry_args.exc, socket.gaierror):
        logging.debug(
            'Caught socket address error, retrying: %s', retry_args.exc)
    elif isinstance(retry_args.exc, socket.timeout):
        logging.debug(
            'Caught socket timeout error, retrying: %s', retry_args.exc)
    elif isinstance(retry_args.exc, socket.error):
        logging.debug('Caught socket error, retrying: %s', retry_args.exc)
    elif isinstance(retry_args.exc, httplib2.ServerNotFoundError):
        logging.debug(
            'Caught server not found error, retrying: %s', retry_args.exc)
    elif isinstance(retry_args.exc, ValueError):
        # oauth2client tries to JSON-decode the response, which can result
        # in a ValueError if the response was invalid. Until that is fixed in
        # oauth2client, need to handle it here.
        logging.debug('Response content was invalid (%s), retrying',
                      retry_args.exc)
    elif (isinstance(retry_args.exc, TokenRefreshError) and
          hasattr(retry_args.exc, 'status') and
          (retry_args.exc.status == TOO_MANY_REQUESTS or
           retry_args.exc.status >= 500)):
        logging.debug(
            'Caught transient credential refresh error (%s), retrying',
            retry_args.exc)
    elif isinstance(retry_args.exc, exceptions.RequestError):
        logging.debug('Request returned no response, retrying')
    # API-level failures
    elif isinstance(retry_args.exc, exceptions.BadStatusCodeError):
        logging.debug('Response returned status %s, retrying',
                      retry_args.exc.status_code)
    elif isinstance(retry_args.exc, exceptions.RetryAfterError):
        logging.debug('Response returned a retry-after header, retrying')
        retry_after = retry_args.exc.retry_after
    else:
        raise retry_args.exc
    RebuildHttpConnections(retry_args.http)
    logging.debug('Retrying request to url %s after exception %s',
                  retry_args.http_request.url, retry_args.exc)
    time.sleep(
        retry_after or util.CalculateWaitForRetry(
            retry_args.num_retries, max_wait=retry_args.max_retry_wait))
def MakeRequest(http, http_request, retries=7, max_retry_wait=60,
                redirections=5,
                retry_func=HandleExceptionsAndRebuildHttpConnections,
                check_response_func=CheckResponse):
    """Send http_request via the given http, with error/retry handling.

    Args:
      http: An httplib2.Http instance, or a http multiplexer that delegates to
          an underlying http, for example, HTTPMultiplexer.
      http_request: A Request to send.
      retries: (int, default 7) Number of retries to attempt on retryable
          replies (such as 429 or 5XX).
      max_retry_wait: (int, default 60) Maximum number of seconds to wait
          when retrying.
      redirections: (int, default 5) Number of redirects to follow.
      retry_func: Function to handle retries on exceptions. Argument is an
          ExceptionRetryArgs tuple.
      check_response_func: Function to validate the HTTP response.
          Arguments are (Response, response content, url).

    Raises:
      InvalidDataFromServerError: if there is no response after retries.

    Returns:
      A Response object.
    """
    start_time = time.time()
    attempt = 0
    # Provide compatibility for breaking change in httplib2 0.16.0+:
    # https://github.com/googleapis/google-api-python-client/issues/803
    if hasattr(http, 'redirect_codes'):
        http.redirect_codes = set(http.redirect_codes) - {308}
    while True:
        try:
            return _MakeRequestNoRetry(
                http, http_request, redirections=redirections,
                check_response_func=check_response_func)
        # retry_func will consume the exception types it handles and raise.
        # pylint: disable=broad-except
        except Exception as exc:
            attempt += 1
            if attempt >= retries:
                raise
            retry_func(ExceptionRetryArgs(
                http, http_request, exc, attempt, max_retry_wait,
                time.time() - start_time))
def _MakeRequestNoRetry(http, http_request, redirections=5,
                        check_response_func=CheckResponse):
    """Send http_request via the given http, without any retry logic.

    Translates between the plain httplib2 request/response values and the
    Request and Response types above.

    Args:
      http: An httplib2.Http instance, or a http multiplexer that delegates to
          an underlying http, for example, HTTPMultiplexer.
      http_request: A Request to send.
      redirections: (int, default 5) Number of redirects to follow.
      check_response_func: Function to validate the HTTP response.
          Arguments are (Response, response content, url).

    Returns:
      A Response object.

    Raises:
      RequestError if no response could be parsed.
    """
    connection_type = None
    # Callers may override the connection class per URL scheme (e.g. for
    # managing callbacks or hash digestion); honor such overrides here.
    scheme_map = getattr(http, 'connections', None)
    if scheme_map:
        scheme = parse.urlsplit(http_request.url).scheme
        if scheme and scheme in scheme_map:
            connection_type = scheme_map[scheme]

    # Custom printing only at debuglevel 4
    effective_level = 4 if httplib2.debuglevel == 4 else 0
    with _Httplib2Debuglevel(http_request, effective_level, http=http):
        info, content = http.request(
            str(http_request.url), method=str(http_request.http_method),
            body=http_request.body, headers=http_request.headers,
            redirections=redirections, connection_type=connection_type)

    if info is None:
        raise exceptions.RequestError()

    response = Response(info, content, http_request.url)
    check_response_func(response)
    return response
_HTTP_FACTORIES = []
def _RegisterHttpFactory(factory):
    """Register a callable consulted by GetHttp before the httplib2 default.

    Args:
      factory: callable taking the same kwargs as httplib2.Http and
          returning an http instance, or None to decline.
    """
    _HTTP_FACTORIES.append(factory)
def GetHttp(**kwds):
    """Construct an HTTP client instance.

    Registered factories are consulted in registration order; the first
    non-None result wins.  Falls back to a plain httplib2.Http.
    """
    for factory in _HTTP_FACTORIES:
        candidate = factory(**kwds)
        if candidate is not None:
            return candidate
    return httplib2.Http(**kwds)

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper function that executes a series of List queries for many APIs."""
from apitools.base.py import encoding
import six
__all__ = [
'YieldFromList',
]
def _GetattrNested(message, attribute):
    """Fetch an attribute named by a string or by a tuple path.

    Behaves like getattr() when attribute is a string; when attribute is
    a tuple, follows each field in turn as a dotted accessor path.

    (ex _GetattrNested(msg, ('foo', 'bar', 'baz')) gets msg.foo.bar.baz
    """
    if isinstance(attribute, six.string_types):
        return getattr(message, attribute)
    # Tuple path: walk left to right.  An empty path yields the message
    # itself, matching the recursive base case of the original design.
    target = message
    for field in attribute:
        target = getattr(target, field)
    return target
def _SetattrNested(message, attribute, value):
    """Set an attribute named by a string or by a tuple path.

    Behaves like setattr() when attribute is a string; when attribute is
    a tuple, sets the nested attribute named by the fields as if they
    were a dotted accessor path.

    (ex _SetattrNested(msg, ('foo', 'bar', 'baz'), 'v') sets msg.foo.bar.baz
    to 'v'
    """
    if isinstance(attribute, six.string_types):
        return setattr(message, attribute, value)
    if not attribute:
        raise ValueError("Need an attribute to set")
    # Resolve the parent object of the final field, then set on it; an
    # empty prefix resolves to `message` itself (single-field case).
    parent = _GetattrNested(message, attribute[:-1])
    return setattr(parent, attribute[-1], value)
def YieldFromList(
        service, request, global_params=None, limit=None, batch_size=100,
        method='List', field='items', predicate=None,
        current_token_attribute='pageToken',
        next_token_attribute='nextPageToken',
        batch_size_attribute='maxResults',
        get_field_func=_GetattrNested):
    """Make a series of List requests, keeping track of page tokens.

    Args:
      service: apitools_base.BaseApiService, A service with a .List() method.
      request: protorpc.messages.Message, The request message
          corresponding to the service's .List() method, with all the
          attributes populated except the .maxResults and .pageToken
          attributes.
      global_params: protorpc.messages.Message, The global query parameters to
          provide when calling the given method.
      limit: int, The maximum number of records to yield. None if all available
          records should be yielded.
      batch_size: int, The number of items to retrieve per request.
      method: str, The name of the method used to fetch resources.
      field: str, The field in the response that will be a list of items.
      predicate: lambda, A function that returns true for items to be yielded.
      current_token_attribute: str or tuple, The name of the attribute in a
          request message holding the page token for the page being
          requested. If a tuple, path to attribute.
      next_token_attribute: str or tuple, The name of the attribute in a
          response message holding the page token for the next page. If a
          tuple, path to the attribute.
      batch_size_attribute: str or tuple, The name of the attribute in a
          response message holding the maximum number of results to be
          returned. None if caller-specified batch size is unsupported.
          If a tuple, path to the attribute.
      get_field_func: Function that returns the items to be yielded. Argument
          is response message, and field.

    Yields:
      protorpc.message.Message, The resources listed by the service.
    """
    # Copy the request so we never mutate the caller's message while
    # paging, and clear any stale page token it may carry.
    request = encoding.CopyProtoMessage(request)
    _SetattrNested(request, current_token_attribute, None)
    # `limit` counts down as items are yielded; a falsy (0) limit stops
    # the loop, while None means "no limit".
    while limit is None or limit:
        if batch_size_attribute:
            # On Py3, None is not comparable so min() below will fail.
            # On Py2, None is always less than any number so if batch_size
            # is None, the request_batch_size will always be None regardless
            # of the value of limit. This doesn't generally strike me as the
            # correct behavior, but this change preserves the existing Py2
            # behavior on Py3.
            if batch_size is None:
                request_batch_size = None
            else:
                # Never request more items than we still intend to yield.
                request_batch_size = min(batch_size, limit or batch_size)
            _SetattrNested(request, batch_size_attribute, request_batch_size)
        response = getattr(service, method)(request,
                                            global_params=global_params)
        items = get_field_func(response, field)
        if predicate:
            items = list(filter(predicate, items))
        for item in items:
            yield item
            if limit is None:
                continue
            # Decrement per yielded (post-predicate) item; stop as soon as
            # the quota is exhausted.
            limit -= 1
            if not limit:
                return
        # Advance to the next page; a missing/empty token means we're done.
        token = _GetattrNested(response, next_token_attribute)
        if not token:
            return
        _SetattrNested(request, current_token_attribute, token)

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Small helper class to provide a small slice of a stream."""
from apitools.base.py import exceptions
class StreamSlice(object):
    """A read-only, bounded view over a prefix of a stream."""

    def __init__(self, stream, max_bytes):
        # Underlying stream plus bookkeeping for how much of the slice
        # remains unread.
        self.__stream = stream
        self.__remaining_bytes = max_bytes
        self.__max_bytes = max_bytes

    def __str__(self):
        return 'Slice of stream %s with %s/%s bytes not yet read' % (
            self.__stream, self.__remaining_bytes, self.__max_bytes)

    def __len__(self):
        return self.__max_bytes

    def __nonzero__(self):
        # For 32-bit python2.x, len() cannot exceed a 32-bit number; avoid
        # accidental len() calls from httplib in the form of "if this_object:".
        return bool(self.__max_bytes)

    @property
    def length(self):
        # For 32-bit python2.x, len() cannot exceed a 32-bit number.
        return self.__max_bytes

    def read(self, size=None):  # pylint: disable=missing-docstring
        """Read at most size bytes from this slice.

        Compared to other streams, there is one case where we may
        unexpectedly raise an exception on read: if the underlying stream
        is exhausted (i.e. returns no bytes on read), and the size of this
        slice indicates we should still be able to read more bytes, we
        raise exceptions.StreamExhausted.

        Args:
          size: If provided, read no more than size bytes from the stream.

        Returns:
          The bytes read from this slice.

        Raises:
          exceptions.StreamExhausted
        """
        if size is None:
            request_size = self.__remaining_bytes
        else:
            request_size = min(size, self.__remaining_bytes)
        data = self.__stream.read(request_size)
        if request_size > 0 and not data:
            raise exceptions.StreamExhausted(
                'Not enough bytes in stream; expected %d, exhausted '
                'after %d' % (
                    self.__max_bytes,
                    self.__max_bytes - self.__remaining_bytes))
        self.__remaining_bytes -= len(data)
        return data

View File

@@ -0,0 +1,16 @@
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package marker file."""

View File

@@ -0,0 +1,405 @@
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The mock module allows easy mocking of apitools clients.
This module allows you to mock out the constructor of a particular apitools
client, for a specific API and version. Then, when the client is created, it
will be run against an expected session that you define. This way code that is
not aware of the testing framework can construct new clients as normal, as long
as it's all done within the context of a mock.
"""
import difflib
import sys
import six
from apitools.base.protorpclite import messages
from apitools.base.py import base_api
from apitools.base.py import encoding
from apitools.base.py import exceptions
class Error(Exception):
    """Base class for all exceptions raised by this mock module."""
def _MessagesEqual(msg1, msg2):
    """Compare two protorpc messages (or lists/scalars) for equality.

    Python's == operator is not reliable for all message comparisons,
    specifically when lists are involved, so lists and message fields
    are compared element by element here.

    Args:
      msg1: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.
      msg2: protorpc.messages.Message or [protorpc.messages.Message] or number
          or string, One of the messages to compare.

    Returns:
      If the messages are isomorphic.
    """
    if isinstance(msg1, list) and isinstance(msg2, list):
        # Element-wise comparison; unequal lengths can never match.
        return (len(msg1) == len(msg2) and
                all(_MessagesEqual(a, b) for a, b in zip(msg1, msg2)))
    if not (isinstance(msg1, messages.Message) and
            isinstance(msg2, messages.Message)):
        # At least one side isn't a Message; plain equality is reliable.
        return msg1 == msg2
    # Both sides are Messages: recursively compare every field.
    return all(_MessagesEqual(getattr(msg1, field.name),
                              getattr(msg2, field.name))
               for field in msg1.all_fields())
class UnexpectedRequestException(Error):
    """Raised when a call does not match the next expected call."""

    def __init__(self, received_call, expected_call):
        expected_key, expected_request = expected_call
        received_key, received_request = received_call

        expected_repr = encoding.MessageToRepr(
            expected_request, multiline=True)
        received_repr = encoding.MessageToRepr(
            received_request, multiline=True)

        if expected_key != received_key:
            # Wrong method entirely; show both calls side by side.
            msg = ('expected: {expected_key}({expected_request})\n'
                   'received: {received_key}({received_request})\n').format(
                expected_key=expected_key,
                expected_request=expected_repr,
                received_key=received_key,
                received_request=received_repr)
        else:
            # Right method, wrong payload; include a unified diff of the
            # request reprs to make the mismatch easy to spot.
            diff = '\n'.join(difflib.unified_diff(
                expected_repr.splitlines(), received_repr.splitlines()))
            msg = ('for request to {key},\n'
                   'expected: {expected_request}\n'
                   'received: {received_request}\n'
                   'diff: {diff}\n').format(
                key=expected_key,
                expected_request=expected_repr,
                received_request=received_repr,
                diff=diff)
        super(UnexpectedRequestException, self).__init__(msg)
class ExpectedRequestsException(Error):
    """Raised when expected calls were never issued before unmocking."""

    def __init__(self, expected_calls):
        # List every unconsumed expectation in registration order.
        parts = ['expected:\n']
        for key, request in expected_calls:
            parts.append('{key}({request})\n'.format(
                key=key,
                request=encoding.MessageToRepr(request, multiline=True)))
        super(ExpectedRequestsException, self).__init__(''.join(parts))
class _ExpectedRequestResponse(object):
    """One queued expectation: a request plus its canned response/exception."""

    def __init__(self, key, request, response=None, exception=None):
        self.__key = key
        self.__request = request

        # At most one of response/exception may be supplied, and each must
        # be on the correct side of the Error hierarchy.
        if response and exception:
            raise exceptions.ConfigurationValueError(
                'Should specify at most one of response and exception')
        if response and isinstance(response, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Responses should not be an instance of Error')
        if exception and not isinstance(exception, exceptions.Error):
            raise exceptions.ConfigurationValueError(
                'Exceptions must be instances of Error')

        self.__response = response
        self.__exception = exception

    @property
    def key(self):
        """The method key this expectation is registered under."""
        return self.__key

    @property
    def request(self):
        """The request message this expectation will match against."""
        return self.__request

    def ValidateAndRespond(self, key, request):
        """Validate that key and request match expectations, and respond if so.

        Args:
          key: str, Actual key to compare against expectations.
          request: protorpc.messages.Message or [protorpc.messages.Message]
            or number or string, Actual request to compare againt expectations

        Raises:
          UnexpectedRequestException: If key or request dont match
              expectations.
          apitools_base.Error: If a non-None exception is specified to
              be thrown.

        Returns:
          The response that was specified to be returned.
        """
        matches = (key == self.__key and
                   (self.__request == request or
                    _MessagesEqual(request, self.__request)))
        if not matches:
            raise UnexpectedRequestException((key, request),
                                             (self.__key, self.__request))

        if self.__exception:
            # Can only throw apitools_base.Error.
            raise self.__exception  # pylint: disable=raising-bad-type

        return self.__response
class _MockedMethod(object):
    """A mocked API service method.

    Instances replace real service methods on a mocked client; each call
    consumes the next queued _ExpectedRequestResponse (FIFO) from the
    owning mock Client, or falls through to the real method when the
    expectation carries no canned response.
    """

    def __init__(self, key, mocked_client, real_method):
        # Mirror the wrapped method's __name__ so introspection sees a
        # normal service method.
        self.__name__ = real_method.__name__
        self.__key = key
        self.__mocked_client = mocked_client
        self.__real_method = real_method
        self.method_config = real_method.method_config
        config = self.method_config()
        # Resolve the request/response message classes once, up front,
        # for use by _TypeCheck in Expect().
        self.__request_type = getattr(self.__mocked_client.MESSAGES_MODULE,
                                      config.request_type_name)
        self.__response_type = getattr(self.__mocked_client.MESSAGES_MODULE,
                                       config.response_type_name)

    def _TypeCheck(self, msg, is_request):
        """Ensure the given message is of the expected type of this method.

        Args:
          msg: The message instance to check.
          is_request: True to validate against the expected request type,
              False to validate against the expected response type.

        Raises:
          exceptions.ConfigurationValueError: If the type of the message was
              not correct.
        """
        if is_request:
            mode = 'request'
            real_type = self.__request_type
        else:
            mode = 'response'
            real_type = self.__response_type
        if not isinstance(msg, real_type):
            raise exceptions.ConfigurationValueError(
                'Expected {} is not of the correct type for method [{}].\n'
                '   Required: [{}]\n'
                '   Given:    [{}]'.format(
                    mode, self.__key, real_type, type(msg)))

    def Expect(self, request, response=None, exception=None,
               enable_type_checking=True, **unused_kwargs):
        """Add an expectation on the mocked method.

        Exactly one of response and exception should be specified.

        Args:
          request: The request that should be expected
          response: The response that should be returned or None if
              exception is provided.
          exception: An exception that should be thrown, or None.
          enable_type_checking: When true, the message type of the request
              and response (if provided) will be checked against the types
              required by this method.
        """
        # TODO(user): the unused_kwargs provides a placeholder for
        # future things that can be passed to Expect(), like special
        # params to the method call.

        # Ensure that the registered request and response mocks actually
        # match what this method accepts and returns.
        if enable_type_checking:
            self._TypeCheck(request, is_request=True)
            if response:
                self._TypeCheck(response, is_request=False)

        # pylint: disable=protected-access
        # Class in same module.
        self.__mocked_client._request_responses.append(
            _ExpectedRequestResponse(self.__key,
                                     request,
                                     response=response,
                                     exception=exception))
        # pylint: enable=protected-access

    def __call__(self, request, **unused_kwargs):
        # TODO(user): allow the testing code to expect certain
        # values in these currently unused_kwargs, especially the
        # upload parameter used by media-heavy services like bigquery
        # or bigstore.

        # pylint: disable=protected-access
        # Class in same module.
        # Expectations are consumed strictly in FIFO order across the
        # whole client, so cross-service call ordering is also verified.
        if self.__mocked_client._request_responses:
            request_response = self.__mocked_client._request_responses.pop(0)
        else:
            raise UnexpectedRequestException(
                (self.__key, request), (None, None))
        # pylint: enable=protected-access

        response = request_response.ValidateAndRespond(self.__key, request)

        if response is None and self.__real_method:
            # Expectation had no canned response: pass the call through to
            # the real service, printing the real response to help the
            # test author write a concrete expectation.
            response = self.__real_method(request)
            print(encoding.MessageToRepr(
                response, multiline=True, shortstrings=True))
            return response

        return response
def _MakeMockedService(api_name, collection_name,
                       mock_client, service, real_service):
    """Build a BaseApiService subclass whose methods are _MockedMethods.

    Every method listed by `service` is replaced with a _MockedMethod
    keyed '<api>.<collection>.<method>'; when a real service instance is
    supplied, its bound methods back the mocks for passthrough calls.
    """
    class MockedService(base_api.BaseApiService):
        pass

    for method_name in service.GetMethodsList():
        delegate = None
        if real_service:
            delegate = getattr(real_service, method_name)
        mock_key = api_name + '.' + collection_name + '.' + method_name
        setattr(MockedService, method_name,
                _MockedMethod(mock_key, mock_client, delegate))
    return MockedService
class Client(object):
    """Mock an apitools client."""

    def __init__(self, client_class, real_client=None):
        """Mock an apitools API, given its class.

        Args:
          client_class: The class for the API. eg, if you
                from apis.sqladmin import v1beta3
              then you can pass v1beta3.SqladminV1beta3 to this class
              and anything within its context will use your mocked
              version.
          real_client: apitools Client, The client to make requests
              against when the expected response is None.
        """

        if not real_client:
            # Credential-free client; used only for its service structure
            # and URL/http configuration, never for real authenticated calls.
            real_client = client_class(get_credentials=False)

        # Remember the pristine class so Unmock() can restore it after
        # Mock() swaps self.__class__ for a patched subclass.
        self.__orig_class = self.__class__
        self.__client_class = client_class
        # Original service classes displaced from client_class, keyed by
        # attribute name, so they can be reinstated on Unmock().
        self.__real_service_classes = {}
        self.__real_client = real_client

        # Queue of _ExpectedRequestResponse objects consumed FIFO by the
        # mocked methods.
        self._request_responses = []
        self.__real_include_fields = None

    def __enter__(self):
        return self.Mock()

    def Mock(self):
        """Stub out the client class with mocked services."""
        client = self.__real_client or self.__client_class(
            get_credentials=False)

        class Patched(self.__class__, self.__client_class):
            pass
        # Swap our class so this object also passes isinstance checks for
        # the real client class while mocked.
        self.__class__ = Patched

        for name in dir(self.__client_class):
            service_class = getattr(self.__client_class, name)
            if not isinstance(service_class, type):
                continue
            if not issubclass(service_class, base_api.BaseApiService):
                continue
            self.__real_service_classes[name] = service_class
            # pylint: disable=protected-access
            collection_name = service_class._NAME
            # pylint: enable=protected-access
            api_name = '%s_%s' % (self.__client_class._PACKAGE,
                                  self.__client_class._URL_VERSION)
            mocked_service_class = _MakeMockedService(
                api_name, collection_name, self,
                service_class,
                service_class(client) if self.__real_client else None)

            # Patch both the class attribute (so freshly constructed
            # clients get mocks) and this instance's collection attribute.
            setattr(self.__client_class, name, mocked_service_class)

            setattr(self, collection_name, mocked_service_class(self))

        self.__real_include_fields = self.__client_class.IncludeFields
        self.__client_class.IncludeFields = self.IncludeFields

        # pylint: disable=attribute-defined-outside-init
        self._url = client._url
        self._http = client._http

        return self

    def __exit__(self, exc_type, value, traceback):
        is_active_exception = value is not None
        self.Unmock(suppress=is_active_exception)
        if is_active_exception:
            six.reraise(exc_type, value, traceback)
        return True

    def Unmock(self, suppress=False):
        """Restore the patched client class and verify expectations.

        Args:
          suppress: bool, True to skip raising for unissued expectations
              (used when an exception is already propagating).

        Raises:
          ExpectedRequestsException: if expectations remain unconsumed and
              no other exception is in flight.
        """
        self.__class__ = self.__orig_class
        for name, service_class in self.__real_service_classes.items():
            setattr(self.__client_class, name, service_class)
            delattr(self, service_class._NAME)
        self.__real_service_classes = {}
        del self._url
        del self._http

        self.__client_class.IncludeFields = self.__real_include_fields
        self.__real_include_fields = None

        # Snapshot any unconsumed expectations before clearing, so the
        # error message can list them.
        requests = [(rq_rs.key, rq_rs.request)
                    for rq_rs in self._request_responses]
        self._request_responses = []

        if requests and not suppress and sys.exc_info()[1] is None:
            raise ExpectedRequestsException(requests)

    def IncludeFields(self, include_fields):
        # Delegate to the real client's IncludeFields while mocked.
        if self.__real_client:
            return self.__real_include_fields(self.__real_client,
                                              include_fields)

View File

@@ -0,0 +1,237 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assorted utilities shared between parts of apitools."""
import os
import random
import six
from six.moves import http_client
import six.moves.urllib.error as urllib_error
import six.moves.urllib.parse as urllib_parse
import six.moves.urllib.request as urllib_request
from apitools.base.protorpclite import messages
from apitools.base.py import encoding_helper as encoding
from apitools.base.py import exceptions
if six.PY3:
from collections.abc import Iterable
else:
from collections import Iterable
__all__ = [
'DetectGae',
'DetectGce',
]
_RESERVED_URI_CHARS = r":/?#[]@!$&'()*+,;="
def DetectGae():
    """Determine whether or not we're running on GAE.

    This is based on:
      https://developers.google.com/appengine/docs/python/#The_Environment

    Returns:
      True iff we're running on GAE.
    """
    software = os.environ.get('SERVER_SOFTWARE', '')
    # Both dev_appserver and production report a recognizable prefix.
    return software.startswith(('Development/', 'Google App Engine/'))
def DetectGce():
    """Determine whether or not we're running on GCE.

    This is based on:
      https://cloud.google.com/compute/docs/metadata#runninggce

    Returns:
        True iff we're running on a GCE instance.
    """
    metadata_url = 'http://{}'.format(
        os.environ.get('GCE_METADATA_ROOT', 'metadata.google.internal'))
    try:
        # Bypass any configured proxy; the metadata server is only
        # reachable directly from the instance.
        o = urllib_request.build_opener(urllib_request.ProxyHandler({})).open(
            urllib_request.Request(
                metadata_url, headers={'Metadata-Flavor': 'Google'}))
    except urllib_error.URLError:
        return False
    try:
        # The metadata server echoes this header back, which guards against
        # other hosts answering on the same name (e.g. captive portals).
        return (o.getcode() == http_client.OK and
                o.headers.get('metadata-flavor') == 'Google')
    finally:
        # Fix: the response object was previously never closed, leaking the
        # underlying socket.
        o.close()
def NormalizeScopes(scope_spec):
    """Normalize scope_spec to a set of strings.

    Accepts either a single space-delimited string of scopes or any
    iterable of scope strings.

    Raises:
        exceptions.TypecheckError: if scope_spec is neither.
    """
    if isinstance(scope_spec, six.string_types):
        # A single string is treated as a space-separated scope list.
        return set(six.ensure_str(scope_spec).split(' '))
    if isinstance(scope_spec, Iterable):
        return set(six.ensure_str(entry) for entry in scope_spec)
    raise exceptions.TypecheckError(
        'NormalizeScopes expected string or iterable, found %s' % (
            type(scope_spec),))
def Typecheck(arg, arg_type, msg=None):
    """Ensure arg is an instance of arg_type, returning it unchanged.

    Args:
        arg: Value to check.
        arg_type: A type or tuple of types to check against.
        msg: Optional custom error message.

    Returns:
        arg, when the check passes.

    Raises:
        exceptions.TypecheckError: if arg is not of the expected type.
    """
    if isinstance(arg, arg_type):
        return arg
    if msg is None:
        if isinstance(arg_type, tuple):
            msg = 'Type of arg is "%s", not one of %r' % (
                type(arg), arg_type)
        else:
            msg = 'Type of arg is "%s", not "%s"' % (type(arg), arg_type)
    raise exceptions.TypecheckError(msg)
def ExpandRelativePath(method_config, params, relative_path=None):
    """Determine the relative path for request.

    Substitutes each of method_config.path_params into the path template,
    percent-encoding values as needed.

    Args:
        method_config: Method config providing relative_path / path_params.
        params: Dict mapping path parameter names to values.
        relative_path: Optional override for method_config.relative_path.

    Returns:
        The expanded relative path.

    Raises:
        exceptions.InvalidUserInputError: if a required parameter is
            missing, None, or cannot be encoded.
    """
    path = relative_path or method_config.relative_path or ''
    for param in method_config.path_params:
        placeholder = '{%s}' % param
        # "Reserved word expansion" ({+param}) leaves URI-reserved
        # characters unescaped; for details, see:
        # http://tools.ietf.org/html/rfc6570#section-3.2.2
        safe_chars = ''
        reserved_placeholder = '{+%s}' % param
        if reserved_placeholder in path:
            safe_chars = _RESERVED_URI_CHARS
            path = path.replace(reserved_placeholder, placeholder)
        if placeholder not in path:
            raise exceptions.InvalidUserInputError(
                'Missing path parameter %s' % param)
        try:
            # TODO(user): Do we want to support some sophisticated
            # mapping here?
            value = params[param]
        except KeyError:
            raise exceptions.InvalidUserInputError(
                'Request missing required parameter %s' % param)
        if value is None:
            raise exceptions.InvalidUserInputError(
                'Request missing required parameter %s' % param)
        try:
            if not isinstance(value, six.string_types):
                value = str(value)
            encoded = urllib_parse.quote(value.encode('utf_8'), safe_chars)
            path = path.replace(placeholder, encoded)
        except TypeError as e:
            raise exceptions.InvalidUserInputError(
                'Error setting required parameter %s to value %s: %s' % (
                    param, value, e))
    return path
def CalculateWaitForRetry(retry_attempt, max_wait=60):
    """Calculates amount of time to wait before a retry attempt.

    Wait time grows exponentially with the number of attempts. A
    random amount of jitter is added to spread out retry attempts from
    different clients.

    Args:
        retry_attempt: Retry attempt counter.
        max_wait: Upper bound for wait time [seconds].

    Returns:
        Number of seconds to wait before retrying request.
    """
    base_wait = 2 ** retry_attempt
    # Jitter is up to +/- 25% of the exponential backoff value.
    jitter_bound = base_wait / 4.0
    jittered = base_wait + random.uniform(-jitter_bound, jitter_bound)
    # Clamp into [1, max_wait] so we neither hammer nor stall forever.
    return max(1, min(jittered, max_wait))
def AcceptableMimeType(accept_patterns, mime_type):
    """Return True iff mime_type is acceptable for one of accept_patterns.

    Note that this function assumes that all patterns in accept_patterns
    will be simple types of the form "type/subtype", where one or both
    of these can be "*". We do not support parameters (i.e. "; q=") in
    patterns.

    Args:
        accept_patterns: list of acceptable MIME types.
        mime_type: the mime type we would like to match.

    Returns:
        Whether or not mime_type matches (at least) one of these patterns.
    """
    if '/' not in mime_type:
        raise exceptions.InvalidUserInputError(
            'Invalid MIME type: "%s"' % mime_type)
    unsupported_patterns = [p for p in accept_patterns if ';' in p]
    if unsupported_patterns:
        raise exceptions.GeneratedClientError(
            'MIME patterns with parameter unsupported: "%s"' % ', '.join(
                unsupported_patterns))

    def _Matches(pattern, candidate):
        """Return True iff candidate is acceptable for pattern."""
        # Some systems use a single '*' instead of '*/*'.
        if pattern == '*':
            pattern = '*/*'
        pairs = zip(pattern.split('/'), candidate.split('/'))
        return all(want in ('*', got) for want, got in pairs)

    return any(_Matches(pattern, mime_type) for pattern in accept_patterns)
def MapParamNames(params, request_type):
    """Reverse parameter remappings for URL construction.

    Translates each JSON-wire parameter name back to its Python field
    name where a custom mapping is registered, keeping order.
    """
    mapped = []
    for name in params:
        custom = encoding.GetCustomJsonFieldMapping(request_type,
                                                    json_name=name)
        mapped.append(custom or name)
    return mapped
def MapRequestParams(params, request_type):
    """Perform any renames/remappings needed for URL construction.

    Currently, we have several ways to customize JSON encoding, in
    particular of field names and enums. This works fine for JSON
    bodies, but also needs to be applied for path and query parameters
    in the URL.

    This function takes a dictionary from param names to values, and
    performs any registered mappings. We also need the request type (to
    look up the mappings).

    Args:
        params: (dict) Map from param names to values
        request_type: (protorpc.messages.Message) request type for this
            API call

    Returns:
        A new dict of the same size, with all registered mappings applied.
    """
    remapped = dict(params)
    for name, value in params.items():
        custom_name = encoding.GetCustomJsonFieldMapping(
            request_type, python_name=name)
        if custom_name is not None:
            # Move the value under its wire-format key.
            remapped[custom_name] = remapped.pop(name)
            name = custom_name
        if isinstance(value, messages.Enum):
            enum_value = encoding.GetCustomJsonEnumMapping(
                type(value), python_name=str(value))
            remapped[name] = enum_value or str(value)
    return remapped