feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared __init__.py for apitools."""
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View File

@@ -0,0 +1,615 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Services descriptor definitions.
Contains message definitions and functions for converting
service classes into transmittable message format.
Describing an Enum instance, Enum class, Field class or Message class will
generate an appropriate descriptor object that describes that class.
This message can itself be used to transmit information to clients wishing
to know the description of an enum value, enum, field or message without
needing to download the source code. This format is also compatible with
other, non-Python languages.
The descriptors are modeled to be binary compatible with
https://github.com/google/protobuf
NOTE: The names of types and fields are not always the same between these
descriptors and the ones defined in descriptor.proto. This was done in order
to make source code files that use these descriptors easier to read. For
example, it is not necessary to prefix TYPE to all the values in
FieldDescriptor.Variant as is done in descriptor.proto
FieldDescriptorProto.Type.
Example:
class Pixel(messages.Message):
x = messages.IntegerField(1, required=True)
y = messages.IntegerField(2, required=True)
color = messages.BytesField(3)
# Describe Pixel class using message descriptor.
fields = []
field = FieldDescriptor()
field.name = 'x'
field.number = 1
field.label = FieldDescriptor.Label.REQUIRED
field.variant = FieldDescriptor.Variant.INT64
fields.append(field)
field = FieldDescriptor()
field.name = 'y'
field.number = 2
field.label = FieldDescriptor.Label.REQUIRED
field.variant = FieldDescriptor.Variant.INT64
fields.append(field)
field = FieldDescriptor()
field.name = 'color'
field.number = 3
field.label = FieldDescriptor.Label.OPTIONAL
field.variant = FieldDescriptor.Variant.BYTES
fields.append(field)
message = MessageDescriptor()
message.name = 'Pixel'
message.fields = fields
# Describing is the equivalent of building the above message.
message == describe_message(Pixel)
Public Classes:
EnumValueDescriptor: Describes Enum values.
EnumDescriptor: Describes Enum classes.
FieldDescriptor: Describes field instances.
FileDescriptor: Describes a single 'file' unit.
FileSet: Describes a collection of file descriptors.
MessageDescriptor: Describes Message classes.
Public Functions:
describe_enum_value: Describe an individual enum-value.
describe_enum: Describe an Enum class.
describe_field: Describe a Field definition.
describe_file: Describe a 'file' unit from a Python module or object.
describe_file_set: Describe a file set from a list of modules or objects.
describe_message: Describe a Message definition.
"""
import codecs
import types
import six
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
__all__ = [
'EnumDescriptor',
'EnumValueDescriptor',
'FieldDescriptor',
'MessageDescriptor',
'FileDescriptor',
'FileSet',
'DescriptorLibrary',
'describe_enum',
'describe_enum_value',
'describe_field',
'describe_message',
'describe_file',
'describe_file_set',
'describe',
'import_descriptor_loader',
]
# NOTE: MessageField is missing because message fields cannot have
# a default value at this time.
# TODO(user): Support default message values.
#
# Map to functions that convert default values of fields of a given type
# to a string. The function must return a value that is compatible with
# FieldDescriptor.default_value and therefore a unicode string.
_DEFAULT_TO_STRING_MAP = {
messages.IntegerField: six.text_type,
messages.FloatField: six.text_type,
messages.BooleanField: lambda value: value and u'true' or u'false',
messages.BytesField: lambda value: codecs.escape_encode(value)[0],
messages.StringField: lambda value: value,
messages.EnumField: lambda value: six.text_type(value.number),
}
_DEFAULT_FROM_STRING_MAP = {
messages.IntegerField: int,
messages.FloatField: float,
messages.BooleanField: lambda value: value == u'true',
messages.BytesField: lambda value: codecs.escape_decode(value)[0],
messages.StringField: lambda value: value,
messages.EnumField: int,
}
class EnumValueDescriptor(messages.Message):
"""Enum value descriptor.
Fields:
name: Name of enumeration value.
number: Number of enumeration value.
"""
# TODO(user): Why are these listed as optional in descriptor.proto.
# Harmonize?
name = messages.StringField(1, required=True)
number = messages.IntegerField(2,
required=True,
variant=messages.Variant.INT32)
class EnumDescriptor(messages.Message):
"""Enum class descriptor.
Fields:
name: Name of Enum without any qualification.
values: Values defined by Enum class.
"""
name = messages.StringField(1)
values = messages.MessageField(EnumValueDescriptor, 2, repeated=True)
class FieldDescriptor(messages.Message):
"""Field definition descriptor.
Enums:
Variant: Wire format hint sub-types for field.
Label: Values for optional, required and repeated fields.
Fields:
name: Name of field.
number: Number of field.
variant: Variant of field.
type_name: Type name for message and enum fields.
default_value: String representation of default value.
"""
Variant = messages.Variant # pylint:disable=invalid-name
class Label(messages.Enum):
"""Field label."""
OPTIONAL = 1
REQUIRED = 2
REPEATED = 3
name = messages.StringField(1, required=True)
number = messages.IntegerField(3,
required=True,
variant=messages.Variant.INT32)
label = messages.EnumField(Label, 4, default=Label.OPTIONAL)
variant = messages.EnumField(Variant, 5)
type_name = messages.StringField(6)
# For numeric types, contains the original text representation of
# the value.
# For booleans, "true" or "false".
# For strings, contains the default text contents (not escaped in any
# way).
# For bytes, contains the C escaped value. All bytes < 128 are that are
# traditionally considered unprintable are also escaped.
default_value = messages.StringField(7)
class MessageDescriptor(messages.Message):
"""Message definition descriptor.
Fields:
name: Name of Message without any qualification.
fields: Fields defined for message.
message_types: Nested Message classes defined on message.
enum_types: Nested Enum classes defined on message.
"""
name = messages.StringField(1)
fields = messages.MessageField(FieldDescriptor, 2, repeated=True)
message_types = messages.MessageField(
'apitools.base.protorpclite.descriptor.MessageDescriptor', 3,
repeated=True)
enum_types = messages.MessageField(EnumDescriptor, 4, repeated=True)
class FileDescriptor(messages.Message):
"""Description of file containing protobuf definitions.
Fields:
package: Fully qualified name of package that definitions belong to.
message_types: Message definitions contained in file.
enum_types: Enum definitions contained in file.
"""
package = messages.StringField(2)
# TODO(user): Add dependency field
message_types = messages.MessageField(MessageDescriptor, 4, repeated=True)
enum_types = messages.MessageField(EnumDescriptor, 5, repeated=True)
class FileSet(messages.Message):
"""A collection of FileDescriptors.
Fields:
files: Files in file-set.
"""
files = messages.MessageField(FileDescriptor, 1, repeated=True)
def describe_enum_value(enum_value):
"""Build descriptor for Enum instance.
Args:
enum_value: Enum value to provide descriptor for.
Returns:
Initialized EnumValueDescriptor instance describing the Enum instance.
"""
enum_value_descriptor = EnumValueDescriptor()
enum_value_descriptor.name = six.text_type(enum_value.name)
enum_value_descriptor.number = enum_value.number
return enum_value_descriptor
def describe_enum(enum_definition):
"""Build descriptor for Enum class.
Args:
enum_definition: Enum class to provide descriptor for.
Returns:
Initialized EnumDescriptor instance describing the Enum class.
"""
enum_descriptor = EnumDescriptor()
enum_descriptor.name = enum_definition.definition_name().split('.')[-1]
values = []
for number in sorted(enum_definition.numbers()):
value = enum_definition.lookup_by_number(number)
values.append(describe_enum_value(value))
if values:
enum_descriptor.values = values
return enum_descriptor
def describe_field(field_definition):
"""Build descriptor for Field instance.
Args:
field_definition: Field instance to provide descriptor for.
Returns:
Initialized FieldDescriptor instance describing the Field instance.
"""
field_descriptor = FieldDescriptor()
field_descriptor.name = field_definition.name
field_descriptor.number = field_definition.number
field_descriptor.variant = field_definition.variant
if isinstance(field_definition, messages.EnumField):
field_descriptor.type_name = field_definition.type.definition_name()
if isinstance(field_definition, messages.MessageField):
field_descriptor.type_name = (
field_definition.message_type.definition_name())
if field_definition.default is not None:
field_descriptor.default_value = _DEFAULT_TO_STRING_MAP[
type(field_definition)](field_definition.default)
# Set label.
if field_definition.repeated:
field_descriptor.label = FieldDescriptor.Label.REPEATED
elif field_definition.required:
field_descriptor.label = FieldDescriptor.Label.REQUIRED
else:
field_descriptor.label = FieldDescriptor.Label.OPTIONAL
return field_descriptor
def describe_message(message_definition):
"""Build descriptor for Message class.
Args:
message_definition: Message class to provide descriptor for.
Returns:
Initialized MessageDescriptor instance describing the Message class.
"""
message_descriptor = MessageDescriptor()
message_descriptor.name = message_definition.definition_name().split(
'.')[-1]
fields = sorted(message_definition.all_fields(),
key=lambda v: v.number)
if fields:
message_descriptor.fields = [describe_field(field) for field in fields]
try:
nested_messages = message_definition.__messages__
except AttributeError:
pass
else:
message_descriptors = []
for name in nested_messages:
value = getattr(message_definition, name)
message_descriptors.append(describe_message(value))
message_descriptor.message_types = message_descriptors
try:
nested_enums = message_definition.__enums__
except AttributeError:
pass
else:
enum_descriptors = []
for name in nested_enums:
value = getattr(message_definition, name)
enum_descriptors.append(describe_enum(value))
message_descriptor.enum_types = enum_descriptors
return message_descriptor
def describe_file(module):
"""Build a file from a specified Python module.
Args:
module: Python module to describe.
Returns:
Initialized FileDescriptor instance describing the module.
"""
descriptor = FileDescriptor()
descriptor.package = util.get_package_for_module(module)
if not descriptor.package:
descriptor.package = None
message_descriptors = []
enum_descriptors = []
# Need to iterate over all top level attributes of the module looking for
# message and enum definitions. Each definition must be itself described.
for name in sorted(dir(module)):
value = getattr(module, name)
if isinstance(value, type):
if issubclass(value, messages.Message):
message_descriptors.append(describe_message(value))
elif issubclass(value, messages.Enum):
enum_descriptors.append(describe_enum(value))
if message_descriptors:
descriptor.message_types = message_descriptors
if enum_descriptors:
descriptor.enum_types = enum_descriptors
return descriptor
def describe_file_set(modules):
"""Build a file set from a specified Python modules.
Args:
modules: Iterable of Python module to describe.
Returns:
Initialized FileSet instance describing the modules.
"""
descriptor = FileSet()
file_descriptors = []
for module in modules:
file_descriptors.append(describe_file(module))
if file_descriptors:
descriptor.files = file_descriptors
return descriptor
def describe(value):
"""Describe any value as a descriptor.
Helper function for describing any object with an appropriate descriptor
object.
Args:
value: Value to describe as a descriptor.
Returns:
Descriptor message class if object is describable as a descriptor, else
None.
"""
if isinstance(value, types.ModuleType):
return describe_file(value)
elif isinstance(value, messages.Field):
return describe_field(value)
elif isinstance(value, messages.Enum):
return describe_enum_value(value)
elif isinstance(value, type):
if issubclass(value, messages.Message):
return describe_message(value)
elif issubclass(value, messages.Enum):
return describe_enum(value)
return None
@util.positional(1)
def import_descriptor_loader(definition_name, importer=__import__):
"""Find objects by importing modules as needed.
A definition loader is a function that resolves a definition name to a
descriptor.
The import finder resolves definitions to their names by importing modules
when necessary.
Args:
definition_name: Name of definition to find.
importer: Import function used for importing new modules.
Returns:
Appropriate descriptor for any describable type located by name.
Raises:
DefinitionNotFoundError when a name does not refer to either a definition
or a module.
"""
# Attempt to import descriptor as a module.
if definition_name.startswith('.'):
definition_name = definition_name[1:]
if not definition_name.startswith('.'):
leaf = definition_name.split('.')[-1]
if definition_name:
try:
module = importer(definition_name, '', '', [leaf])
except ImportError:
pass
else:
return describe(module)
try:
# Attempt to use messages.find_definition to find item.
return describe(messages.find_definition(definition_name,
importer=__import__))
except messages.DefinitionNotFoundError as err:
# There are things that find_definition will not find, but if
# the parent is loaded, its children can be searched for a
# match.
split_name = definition_name.rsplit('.', 1)
if len(split_name) > 1:
parent, child = split_name
try:
parent_definition = import_descriptor_loader(
parent, importer=importer)
except messages.DefinitionNotFoundError:
# Fall through to original error.
pass
else:
# Check the parent definition for a matching descriptor.
if isinstance(parent_definition, EnumDescriptor):
search_list = parent_definition.values or []
elif isinstance(parent_definition, MessageDescriptor):
search_list = parent_definition.fields or []
else:
search_list = []
for definition in search_list:
if definition.name == child:
return definition
# Still didn't find. Reraise original exception.
raise err
class DescriptorLibrary(object):
"""A descriptor library is an object that contains known definitions.
A descriptor library contains a cache of descriptor objects mapped by
definition name. It contains all types of descriptors except for
file sets.
When a definition name is requested that the library does not know about
it can be provided with a descriptor loader which attempt to resolve the
missing descriptor.
"""
@util.positional(1)
def __init__(self,
descriptors=None,
descriptor_loader=import_descriptor_loader):
"""Constructor.
Args:
descriptors: A dictionary or dictionary-like object that can be used
to store and cache descriptors by definition name.
definition_loader: A function used for resolving missing descriptors.
The function takes a definition name as its parameter and returns
an appropriate descriptor. It may raise DefinitionNotFoundError.
"""
self.__descriptor_loader = descriptor_loader
self.__descriptors = descriptors or {}
def lookup_descriptor(self, definition_name):
"""Lookup descriptor by name.
Get descriptor from library by name. If descriptor is not found will
attempt to find via descriptor loader if provided.
Args:
definition_name: Definition name to find.
Returns:
Descriptor that describes definition name.
Raises:
DefinitionNotFoundError if not descriptor exists for definition name.
"""
try:
return self.__descriptors[definition_name]
except KeyError:
pass
if self.__descriptor_loader:
definition = self.__descriptor_loader(definition_name)
self.__descriptors[definition_name] = definition
return definition
else:
raise messages.DefinitionNotFoundError(
'Could not find definition for %s' % definition_name)
def lookup_package(self, definition_name):
"""Determines the package name for any definition.
Determine the package that any definition name belongs to. May
check parent for package name and will resolve missing
descriptors if provided descriptor loader.
Args:
definition_name: Definition name to find package for.
"""
while True:
descriptor = self.lookup_descriptor(definition_name)
if isinstance(descriptor, FileDescriptor):
return descriptor.package
else:
index = definition_name.rfind('.')
if index < 0:
return None
definition_name = definition_name[:index]

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Simple protocol message types.
Includes new message and field types that are outside what is defined by the
protocol buffers standard.
"""
import datetime
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
__all__ = [
'DateTimeField',
'DateTimeMessage',
'VoidMessage',
]
class VoidMessage(messages.Message):
"""Empty message."""
class DateTimeMessage(messages.Message):
"""Message to store/transmit a DateTime.
Fields:
milliseconds: Milliseconds since Jan 1st 1970 local time.
time_zone_offset: Optional time zone offset, in minutes from UTC.
"""
milliseconds = messages.IntegerField(1, required=True)
time_zone_offset = messages.IntegerField(2)
class DateTimeField(messages.MessageField):
"""Field definition for datetime values.
Stores a python datetime object as a field. If time zone information is
included in the datetime object, it will be included in
the encoded data when this is encoded/decoded.
"""
type = datetime.datetime
message_type = DateTimeMessage
@util.positional(3)
def __init__(self,
number,
**kwargs):
super(DateTimeField, self).__init__(self.message_type,
number,
**kwargs)
def value_from_message(self, message):
"""Convert DateTimeMessage to a datetime.
Args:
A DateTimeMessage instance.
Returns:
A datetime instance.
"""
message = super(DateTimeField, self).value_from_message(message)
if message.time_zone_offset is None:
return datetime.datetime.fromtimestamp(
message.milliseconds / 1000.0, tz=datetime.timezone.utc).replace(tzinfo=None)
# Need to subtract the time zone offset, because when we call
# datetime.fromtimestamp, it will add the time zone offset to the
# value we pass.
milliseconds = (message.milliseconds -
60000 * message.time_zone_offset)
timezone = util.TimeZoneOffset(message.time_zone_offset)
return datetime.datetime.fromtimestamp(milliseconds / 1000.0,
tz=timezone)
def value_to_message(self, value):
value = super(DateTimeField, self).value_to_message(value)
# First, determine the delta from the epoch, so we can fill in
# DateTimeMessage's milliseconds field.
if value.tzinfo is None:
time_zone_offset = 0
local_epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None)
else:
time_zone_offset = util.total_seconds(
value.tzinfo.utcoffset(value))
# Determine Jan 1, 1970 local time.
local_epoch = datetime.datetime.fromtimestamp(-time_zone_offset,
tz=value.tzinfo)
delta = value - local_epoch
# Create and fill in the DateTimeMessage, including time zone if
# one was specified.
message = DateTimeMessage()
message.milliseconds = int(util.total_seconds(delta) * 1000)
if value.tzinfo is not None:
utc_offset = value.tzinfo.utcoffset(value)
if utc_offset is not None:
message.time_zone_offset = int(
util.total_seconds(value.tzinfo.utcoffset(value)) / 60)
return message

View File

@@ -0,0 +1,400 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""JSON support for message types.
Public classes:
MessageJSONEncoder: JSON encoder for message objects.
Public functions:
encode_message: Encodes a message in to a JSON string.
decode_message: Merge from a JSON string in to a message.
"""
import base64
import binascii
import logging
import six
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
__all__ = [
'ALTERNATIVE_CONTENT_TYPES',
'CONTENT_TYPE',
'MessageJSONEncoder',
'encode_message',
'decode_message',
'ProtoJson',
]
def _load_json_module():
"""Try to load a valid json module.
There are more than one json modules that might be installed. They are
mostly compatible with one another but some versions may be different.
This function attempts to load various json modules in a preferred order.
It does a basic check to guess if a loaded version of json is compatible.
Returns:
Compatible json module.
Raises:
ImportError if there are no json modules or the loaded json module is
not compatible with ProtoRPC.
"""
first_import_error = None
for module_name in ['json',
'simplejson']:
try:
module = __import__(module_name, {}, {}, 'json')
if not hasattr(module, 'JSONEncoder'):
message = (
'json library "%s" is not compatible with ProtoRPC' %
module_name)
logging.warning(message)
raise ImportError(message)
else:
return module
except ImportError as err:
if not first_import_error:
first_import_error = err
logging.error('Must use valid json library (json or simplejson)')
raise first_import_error # pylint:disable=raising-bad-type
json = _load_json_module()
# TODO: Rename this to MessageJsonEncoder.
class MessageJSONEncoder(json.JSONEncoder):
"""Message JSON encoder class.
Extension of JSONEncoder that can build JSON from a message object.
"""
def __init__(self, protojson_protocol=None, **kwargs):
"""Constructor.
Args:
protojson_protocol: ProtoJson instance.
"""
super(MessageJSONEncoder, self).__init__(**kwargs)
self.__protojson_protocol = (
protojson_protocol or ProtoJson.get_default())
def default(self, value):
"""Return dictionary instance from a message object.
Args:
value: Value to get dictionary for. If not encodable, will
call superclasses default method.
"""
if isinstance(value, messages.Enum):
return str(value)
if six.PY3 and isinstance(value, bytes):
return value.decode('utf8')
if isinstance(value, messages.Message):
result = {}
for field in value.all_fields():
item = value.get_assigned_value(field.name)
if item not in (None, [], ()):
result[field.name] = (
self.__protojson_protocol.encode_field(field, item))
# Handle unrecognized fields, so they're included when a message is
# decoded then encoded.
for unknown_key in value.all_unrecognized_fields():
unrecognized_field, _ = value.get_unrecognized_field_info(
unknown_key)
# Unknown fields are not encoded as they should have been
# processed before we get to here.
result[unknown_key] = unrecognized_field
return result
return super(MessageJSONEncoder, self).default(value)
class ProtoJson(object):
"""ProtoRPC JSON implementation class.
Implementation of JSON based protocol used for serializing and
deserializing message objects. Instances of remote.ProtocolConfig
constructor or used with remote.Protocols.add_protocol. See the
remote.py module for more details.
"""
CONTENT_TYPE = 'application/json'
ALTERNATIVE_CONTENT_TYPES = [
'application/x-javascript',
'text/javascript',
'text/x-javascript',
'text/x-json',
'text/json',
]
def encode_field(self, field, value):
"""Encode a python field value to a JSON value.
Args:
field: A ProtoRPC field instance.
value: A python value supported by field.
Returns:
A JSON serializable value appropriate for field.
"""
if isinstance(field, messages.BytesField):
if field.repeated:
value = [base64.b64encode(byte) for byte in value]
else:
value = base64.b64encode(value)
elif isinstance(field, message_types.DateTimeField):
# DateTimeField stores its data as a RFC 3339 compliant string.
if field.repeated:
value = [i.isoformat() for i in value]
else:
value = value.isoformat()
return value
def encode_message(self, message):
"""Encode Message instance to JSON string.
Args:
Message instance to encode in to JSON string.
Returns:
String encoding of Message instance in protocol JSON format.
Raises:
messages.ValidationError if message is not initialized.
"""
message.check_initialized()
return json.dumps(message, cls=MessageJSONEncoder,
protojson_protocol=self)
def decode_message(self, message_type, encoded_message):
"""Merge JSON structure to Message instance.
Args:
message_type: Message to decode data to.
encoded_message: JSON encoded version of message.
Returns:
Decoded instance of message_type.
Raises:
ValueError: If encoded_message is not valid JSON.
messages.ValidationError if merged message is not initialized.
"""
encoded_message = six.ensure_str(encoded_message)
if not encoded_message.strip():
return message_type()
dictionary = json.loads(encoded_message)
message = self.__decode_dictionary(message_type, dictionary)
message.check_initialized()
return message
def __find_variant(self, value):
"""Find the messages.Variant type that describes this value.
Args:
value: The value whose variant type is being determined.
Returns:
The messages.Variant value that best describes value's type,
or None if it's a type we don't know how to handle.
"""
if isinstance(value, bool):
return messages.Variant.BOOL
elif isinstance(value, six.integer_types):
return messages.Variant.INT64
elif isinstance(value, float):
return messages.Variant.DOUBLE
elif isinstance(value, six.string_types):
return messages.Variant.STRING
elif isinstance(value, (list, tuple)):
# Find the most specific variant that covers all elements.
variant_priority = [None,
messages.Variant.INT64,
messages.Variant.DOUBLE,
messages.Variant.STRING]
chosen_priority = 0
for v in value:
variant = self.__find_variant(v)
try:
priority = variant_priority.index(variant)
except IndexError:
priority = -1
if priority > chosen_priority:
chosen_priority = priority
return variant_priority[chosen_priority]
# Unrecognized type.
return None
def __decode_dictionary(self, message_type, dictionary):
"""Merge dictionary in to message.
Args:
message: Message to merge dictionary in to.
dictionary: Dictionary to extract information from. Dictionary
is as parsed from JSON. Nested objects will also be dictionaries.
"""
message = message_type()
for key, value in six.iteritems(dictionary):
if value is None:
try:
message.reset(key)
except AttributeError:
pass # This is an unrecognized field, skip it.
continue
try:
field = message.field_by_name(key)
except KeyError:
# Save unknown values.
variant = self.__find_variant(value)
if variant:
message.set_unrecognized_field(key, value, variant)
continue
is_enum_field = isinstance(field, messages.EnumField)
is_unrecognized_field = False
if field.repeated:
# This should be unnecessary? Or in fact become an error.
if not isinstance(value, list):
value = [value]
valid_value = []
for item in value:
try:
v = self.decode_field(field, item)
if is_enum_field and v is None:
continue
except messages.DecodeError:
if not is_enum_field:
raise
is_unrecognized_field = True
continue
valid_value.append(v)
setattr(message, field.name, valid_value)
if is_unrecognized_field:
variant = self.__find_variant(value)
if variant:
message.set_unrecognized_field(key, value, variant)
continue
# This is just for consistency with the old behavior.
if value == []:
continue
try:
setattr(message, field.name, self.decode_field(field, value))
except messages.DecodeError:
# Save unknown enum values.
if not is_enum_field:
raise
variant = self.__find_variant(value)
if variant:
message.set_unrecognized_field(key, value, variant)
return message
def decode_field(self, field, value):
"""Decode a JSON value to a python value.
Args:
field: A ProtoRPC field instance.
value: A serialized JSON value.
Return:
A Python value compatible with field.
"""
if isinstance(field, messages.EnumField):
try:
return field.type(value)
except TypeError:
raise messages.DecodeError(
'Invalid enum value "%s"' % (value or ''))
elif isinstance(field, messages.BytesField):
try:
return base64.b64decode(value)
except (binascii.Error, TypeError) as err:
raise messages.DecodeError('Base64 decoding error: %s' % err)
elif isinstance(field, message_types.DateTimeField):
try:
return util.decode_datetime(value, truncate_time=True)
except ValueError as err:
raise messages.DecodeError(err)
elif (isinstance(field, messages.MessageField) and
issubclass(field.type, messages.Message)):
return self.__decode_dictionary(field.type, value)
elif (isinstance(field, messages.FloatField) and
isinstance(value, (six.integer_types, six.string_types))):
try:
return float(value)
except: # pylint:disable=bare-except
pass
elif (isinstance(field, messages.IntegerField) and
isinstance(value, six.string_types)):
try:
return int(value)
except: # pylint:disable=bare-except
pass
return value
@staticmethod
def get_default():
"""Get default instanceof ProtoJson."""
try:
return ProtoJson.__default
except AttributeError:
ProtoJson.__default = ProtoJson()
return ProtoJson.__default
@staticmethod
def set_default(protocol):
"""Set the default instance of ProtoJson.
Args:
protocol: A ProtoJson instance.
"""
if not isinstance(protocol, ProtoJson):
raise TypeError('Expected protocol of type ProtoJson')
ProtoJson.__default = protocol
CONTENT_TYPE = ProtoJson.CONTENT_TYPE
ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES
encode_message = ProtoJson.get_default().encode_message
decode_message = ProtoJson.get_default().decode_message

View File

@@ -0,0 +1,667 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Test utilities for message testing.
Includes module interface test to ensure that public parts of module are
correctly declared in __all__.
Includes message types that correspond to those defined in
services_test.proto.
Includes additional test utilities to make sure encoding/decoding libraries
conform.
"""
import cgi
import datetime
import inspect
import os
import re
import socket
import types
import unittest
import six
from six.moves import range # pylint: disable=redefined-builtin
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import util
# Unicode of the word "Russian" in cyrillic.
RUSSIAN = u'\u0440\u0443\u0441\u0441\u043a\u0438\u0439'
# All characters binary value interspersed with nulls.
BINARY = b''.join(six.int2byte(value) + b'\0' for value in range(256))
class TestCase(unittest.TestCase):
def assertRaisesWithRegexpMatch(self,
exception,
regexp,
function,
*params,
**kwargs):
"""Check that exception is raised and text matches regular expression.
Args:
exception: Exception type that is expected.
regexp: String regular expression that is expected in error message.
function: Callable to test.
params: Parameters to forward to function.
kwargs: Keyword arguments to forward to function.
"""
try:
function(*params, **kwargs)
self.fail('Expected exception %s was not raised' %
exception.__name__)
except exception as err:
match = bool(re.match(regexp, str(err)))
self.assertTrue(match, 'Expected match "%s", found "%s"' % (regexp,
err))
def assertHeaderSame(self, header1, header2):
"""Check that two HTTP headers are the same.
Args:
header1: Header value string 1.
header2: header value string 2.
"""
value1, params1 = cgi.parse_header(header1)
value2, params2 = cgi.parse_header(header2)
self.assertEqual(value1, value2)
self.assertEqual(params1, params2)
def assertIterEqual(self, iter1, iter2):
"""Check two iterators or iterables are equal independent of order.
Similar to Python 2.7 assertItemsEqual. Named differently in order to
avoid potential conflict.
Args:
iter1: An iterator or iterable.
iter2: An iterator or iterable.
"""
list1 = list(iter1)
list2 = list(iter2)
unmatched1 = list()
while list1:
item1 = list1[0]
del list1[0]
for index in range(len(list2)):
if item1 == list2[index]:
del list2[index]
break
else:
unmatched1.append(item1)
error_message = []
for item in unmatched1:
error_message.append(
' Item from iter1 not found in iter2: %r' % item)
for item in list2:
error_message.append(
' Item from iter2 not found in iter1: %r' % item)
if error_message:
self.fail('Collections not equivalent:\n' +
'\n'.join(error_message))
class ModuleInterfaceTest(object):
"""Test to ensure module interface is carefully constructed.
A module interface is the set of public objects listed in the
module __all__ attribute. Modules that that are considered public
should have this interface carefully declared. At all times, the
__all__ attribute should have objects intended to be publically
used and all other objects in the module should be considered
unused.
Protected attributes (those beginning with '_') and other imported
modules should not be part of this set of variables. An exception
is for variables that begin and end with '__' which are implicitly
part of the interface (eg. __name__, __file__, __all__ itself,
etc.).
Modules that are imported in to the tested modules are an
exception and may be left out of the __all__ definition. The test
is done by checking the value of what would otherwise be a public
name and not allowing it to be exported if it is an instance of a
module. Modules that are explicitly exported are for the time
being not permitted.
To use this test class a module should define a new class that
inherits first from ModuleInterfaceTest and then from
test_util.TestCase. No other tests should be added to this test
case, making the order of inheritance less important, but if setUp
for some reason is overidden, it is important that
ModuleInterfaceTest is first in the list so that its setUp method
is invoked.
Multiple inheritance is required so that ModuleInterfaceTest is
not itself a test, and is not itself executed as one.
The test class is expected to have the following class attributes
defined:
MODULE: A reference to the module that is being validated for interface
correctness.
Example:
Module definition (hello.py):
import sys
__all__ = ['hello']
def _get_outputter():
return sys.stdout
def hello():
_get_outputter().write('Hello\n')
Test definition:
import unittest
from protorpc import test_util
import hello
class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
test_util.TestCase):
MODULE = hello
class HelloTest(test_util.TestCase):
... Test 'hello' module ...
if __name__ == '__main__':
unittest.main()
"""
def setUp(self):
"""Set up makes sure that MODULE and IMPORTED_MODULES is defined.
This is a basic configuration test for the test itself so does not
get it's own test case.
"""
if not hasattr(self, 'MODULE'):
self.fail(
"You must define 'MODULE' on ModuleInterfaceTest sub-class "
"%s." % type(self).__name__)
def testAllExist(self):
"""Test that all attributes defined in __all__ exist."""
missing_attributes = []
for attribute in self.MODULE.__all__:
if not hasattr(self.MODULE, attribute):
missing_attributes.append(attribute)
if missing_attributes:
self.fail('%s of __all__ are not defined in module.' %
missing_attributes)
def testAllExported(self):
"""Test that all public attributes not imported are in __all__."""
missing_attributes = []
for attribute in dir(self.MODULE):
if not attribute.startswith('_'):
if (attribute not in self.MODULE.__all__ and
not isinstance(getattr(self.MODULE, attribute),
types.ModuleType) and
attribute != 'with_statement'):
missing_attributes.append(attribute)
if missing_attributes:
self.fail('%s are not modules and not defined in __all__.' %
missing_attributes)
def testNoExportedProtectedVariables(self):
"""Test that there are no protected variables listed in __all__."""
protected_variables = []
for attribute in self.MODULE.__all__:
if attribute.startswith('_'):
protected_variables.append(attribute)
if protected_variables:
self.fail('%s are protected variables and may not be exported.' %
protected_variables)
def testNoExportedModules(self):
"""Test that no modules exist in __all__."""
exported_modules = []
for attribute in self.MODULE.__all__:
try:
value = getattr(self.MODULE, attribute)
except AttributeError:
# This is a different error case tested for in testAllExist.
pass
else:
if isinstance(value, types.ModuleType):
exported_modules.append(attribute)
if exported_modules:
self.fail('%s are modules and may not be exported.' %
exported_modules)
class NestedMessage(messages.Message):
"""Simple message that gets nested in another message."""
a_value = messages.StringField(1, required=True)
class HasNestedMessage(messages.Message):
"""Message that has another message nested in it."""
nested = messages.MessageField(NestedMessage, 1)
repeated_nested = messages.MessageField(NestedMessage, 2, repeated=True)
class HasDefault(messages.Message):
"""Has a default value."""
a_value = messages.StringField(1, default=u'a default')
class OptionalMessage(messages.Message):
"""Contains all message types."""
class SimpleEnum(messages.Enum):
"""Simple enumeration type."""
VAL1 = 1
VAL2 = 2
double_value = messages.FloatField(1, variant=messages.Variant.DOUBLE)
float_value = messages.FloatField(2, variant=messages.Variant.FLOAT)
int64_value = messages.IntegerField(3, variant=messages.Variant.INT64)
uint64_value = messages.IntegerField(4, variant=messages.Variant.UINT64)
int32_value = messages.IntegerField(5, variant=messages.Variant.INT32)
bool_value = messages.BooleanField(6, variant=messages.Variant.BOOL)
string_value = messages.StringField(7, variant=messages.Variant.STRING)
bytes_value = messages.BytesField(8, variant=messages.Variant.BYTES)
enum_value = messages.EnumField(SimpleEnum, 10)
class RepeatedMessage(messages.Message):
"""Contains all message types as repeated fields."""
class SimpleEnum(messages.Enum):
"""Simple enumeration type."""
VAL1 = 1
VAL2 = 2
double_value = messages.FloatField(1,
variant=messages.Variant.DOUBLE,
repeated=True)
float_value = messages.FloatField(2,
variant=messages.Variant.FLOAT,
repeated=True)
int64_value = messages.IntegerField(3,
variant=messages.Variant.INT64,
repeated=True)
uint64_value = messages.IntegerField(4,
variant=messages.Variant.UINT64,
repeated=True)
int32_value = messages.IntegerField(5,
variant=messages.Variant.INT32,
repeated=True)
bool_value = messages.BooleanField(6,
variant=messages.Variant.BOOL,
repeated=True)
string_value = messages.StringField(7,
variant=messages.Variant.STRING,
repeated=True)
bytes_value = messages.BytesField(8,
variant=messages.Variant.BYTES,
repeated=True)
enum_value = messages.EnumField(SimpleEnum,
10,
repeated=True)
class HasOptionalNestedMessage(messages.Message):
nested = messages.MessageField(OptionalMessage, 1)
repeated_nested = messages.MessageField(OptionalMessage, 2, repeated=True)
# pylint:disable=anomalous-unicode-escape-in-string
class ProtoConformanceTestBase(object):
"""Protocol conformance test base class.
Each supported protocol should implement two methods that support encoding
and decoding of Message objects in that format:
encode_message(message) - Serialize to encoding.
encode_message(message, encoded_message) - Deserialize from encoding.
Tests for the modules where these functions are implemented should extend
this class in order to support basic behavioral expectations. This ensures
that protocols correctly encode and decode message transparently to the
caller.
In order to support these test, the base class should also extend
the TestCase class and implement the following class attributes
which define the encoded version of certain protocol buffers:
encoded_partial:
<OptionalMessage
double_value: 1.23
int64_value: -100000000000
string_value: u"a string"
enum_value: OptionalMessage.SimpleEnum.VAL2
>
encoded_full:
<OptionalMessage
double_value: 1.23
float_value: -2.5
int64_value: -100000000000
uint64_value: 102020202020
int32_value: 1020
bool_value: true
string_value: u"a string\u044f"
bytes_value: b"a bytes\xff\xfe"
enum_value: OptionalMessage.SimpleEnum.VAL2
>
encoded_repeated:
<RepeatedMessage
double_value: [1.23, 2.3]
float_value: [-2.5, 0.5]
int64_value: [-100000000000, 20]
uint64_value: [102020202020, 10]
int32_value: [1020, 718]
bool_value: [true, false]
string_value: [u"a string\u044f", u"another string"]
bytes_value: [b"a bytes\xff\xfe", b"another bytes"]
enum_value: [OptionalMessage.SimpleEnum.VAL2,
OptionalMessage.SimpleEnum.VAL 1]
>
encoded_nested:
<HasNestedMessage
nested: <NestedMessage
a_value: "a string"
>
>
encoded_repeated_nested:
<HasNestedMessage
repeated_nested: [
<NestedMessage a_value: "a string">,
<NestedMessage a_value: "another string">
]
>
unexpected_tag_message:
An encoded message that has an undefined tag or number in the stream.
encoded_default_assigned:
<HasDefault
a_value: "a default"
>
encoded_nested_empty:
<HasOptionalNestedMessage
nested: <OptionalMessage>
>
encoded_invalid_enum:
<OptionalMessage
enum_value: (invalid value for serialization type)
>
encoded_invalid_repeated_enum:
<RepeatedMessage
enum_value: (invalid value for serialization type)
>
"""
encoded_empty_message = ''
def testEncodeInvalidMessage(self):
message = NestedMessage()
self.assertRaises(messages.ValidationError,
self.PROTOLIB.encode_message, message)
def CompareEncoded(self, expected_encoded, actual_encoded):
"""Compare two encoded protocol values.
Can be overridden by sub-classes to special case comparison.
For example, to eliminate white space from output that is not
relevant to encoding.
Args:
expected_encoded: Expected string encoded value.
actual_encoded: Actual string encoded value.
"""
self.assertEqual(expected_encoded, actual_encoded)
def EncodeDecode(self, encoded, expected_message):
message = self.PROTOLIB.decode_message(type(expected_message), encoded)
self.assertEqual(expected_message, message)
self.CompareEncoded(encoded, self.PROTOLIB.encode_message(message))
def testEmptyMessage(self):
self.EncodeDecode(self.encoded_empty_message, OptionalMessage())
def testPartial(self):
"""Test message with a few values set."""
message = OptionalMessage()
message.double_value = 1.23
message.int64_value = -100000000000
message.int32_value = 1020
message.string_value = u'a string'
message.enum_value = OptionalMessage.SimpleEnum.VAL2
self.EncodeDecode(self.encoded_partial, message)
def testFull(self):
"""Test all types."""
message = OptionalMessage()
message.double_value = 1.23
message.float_value = -2.5
message.int64_value = -100000000000
message.uint64_value = 102020202020
message.int32_value = 1020
message.bool_value = True
message.string_value = u'a string\u044f'
message.bytes_value = b'a bytes\xff\xfe'
message.enum_value = OptionalMessage.SimpleEnum.VAL2
self.EncodeDecode(self.encoded_full, message)
def testRepeated(self):
"""Test repeated fields."""
message = RepeatedMessage()
message.double_value = [1.23, 2.3]
message.float_value = [-2.5, 0.5]
message.int64_value = [-100000000000, 20]
message.uint64_value = [102020202020, 10]
message.int32_value = [1020, 718]
message.bool_value = [True, False]
message.string_value = [u'a string\u044f', u'another string']
message.bytes_value = [b'a bytes\xff\xfe', b'another bytes']
message.enum_value = [RepeatedMessage.SimpleEnum.VAL2,
RepeatedMessage.SimpleEnum.VAL1]
self.EncodeDecode(self.encoded_repeated, message)
def testNested(self):
"""Test nested messages."""
nested_message = NestedMessage()
nested_message.a_value = u'a string'
message = HasNestedMessage()
message.nested = nested_message
self.EncodeDecode(self.encoded_nested, message)
def testRepeatedNested(self):
"""Test repeated nested messages."""
nested_message1 = NestedMessage()
nested_message1.a_value = u'a string'
nested_message2 = NestedMessage()
nested_message2.a_value = u'another string'
message = HasNestedMessage()
message.repeated_nested = [nested_message1, nested_message2]
self.EncodeDecode(self.encoded_repeated_nested, message)
def testStringTypes(self):
"""Test that encoding str on StringField works."""
message = OptionalMessage()
message.string_value = 'Latin'
self.EncodeDecode(self.encoded_string_types, message)
def testEncodeUninitialized(self):
"""Test that cannot encode uninitialized message."""
required = NestedMessage()
self.assertRaisesWithRegexpMatch(messages.ValidationError,
"Message NestedMessage is missing "
"required field a_value",
self.PROTOLIB.encode_message,
required)
def testUnexpectedField(self):
"""Test decoding and encoding unexpected fields."""
loaded_message = self.PROTOLIB.decode_message(
OptionalMessage, self.unexpected_tag_message)
# Message should be equal to an empty message, since unknown
# values aren't included in equality.
self.assertEqual(OptionalMessage(), loaded_message)
# Verify that the encoded message matches the source, including the
# unknown value.
self.assertEqual(self.unexpected_tag_message,
self.PROTOLIB.encode_message(loaded_message))
def testDoNotSendDefault(self):
"""Test that default is not sent when nothing is assigned."""
self.EncodeDecode(self.encoded_empty_message, HasDefault())
def testSendDefaultExplicitlyAssigned(self):
"""Test that default is sent when explcitly assigned."""
message = HasDefault()
message.a_value = HasDefault.a_value.default
self.EncodeDecode(self.encoded_default_assigned, message)
def testEncodingNestedEmptyMessage(self):
"""Test encoding a nested empty message."""
message = HasOptionalNestedMessage()
message.nested = OptionalMessage()
self.EncodeDecode(self.encoded_nested_empty, message)
def testEncodingRepeatedNestedEmptyMessage(self):
"""Test encoding a nested empty message."""
message = HasOptionalNestedMessage()
message.repeated_nested = [OptionalMessage(), OptionalMessage()]
self.EncodeDecode(self.encoded_repeated_nested_empty, message)
def testContentType(self):
self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str))
def testDecodeInvalidEnumType(self):
# Since protos need to be able to add new enums, a message should be
# successfully decoded even if the enum value is invalid. Encoding the
# decoded message should result in equivalence with the original
# encoded message containing an invalid enum.
decoded = self.PROTOLIB.decode_message(OptionalMessage,
self.encoded_invalid_enum)
message = OptionalMessage()
self.assertEqual(message, decoded)
encoded = self.PROTOLIB.encode_message(decoded)
self.assertEqual(self.encoded_invalid_enum, encoded)
def testDecodeInvalidRepeatedEnumType(self):
# Since protos need to be able to add new enums, a message should be
# successfully decoded even if the enum value is invalid. Encoding the
# decoded message should result in equivalence with the original
# encoded message containing an invalid enum.
decoded = self.PROTOLIB.decode_message(RepeatedMessage,
self.encoded_invalid_repeated_enum)
message = RepeatedMessage()
message.enum_value = [RepeatedMessage.SimpleEnum.VAL1]
self.assertEqual(message, decoded)
encoded = self.PROTOLIB.encode_message(decoded)
self.assertEqual(self.encoded_invalid_repeated_enum, encoded)
def testDateTimeNoTimeZone(self):
"""Test that DateTimeFields are encoded/decoded correctly."""
class MyMessage(messages.Message):
value = message_types.DateTimeField(1)
value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000)
message = MyMessage(value=value)
decoded = self.PROTOLIB.decode_message(
MyMessage, self.PROTOLIB.encode_message(message))
self.assertEqual(decoded.value, value)
def testDateTimeWithTimeZone(self):
"""Test DateTimeFields with time zones."""
class MyMessage(messages.Message):
value = message_types.DateTimeField(1)
value = datetime.datetime(2013, 1, 3, 11, 36, 30, 123000,
util.TimeZoneOffset(8 * 60))
message = MyMessage(value=value)
decoded = self.PROTOLIB.decode_message(
MyMessage, self.PROTOLIB.encode_message(message))
self.assertEqual(decoded.value, value)
def pick_unused_port():
"""Find an unused port to use in tests.
Derived from Damon Kohlers example:
http://code.activestate.com/recipes/531822-pick-unused-port
"""
temp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
temp.bind(('localhost', 0))
port = temp.getsockname()[1]
finally:
temp.close()
return port
def get_module_name(module_attribute):
"""Get the module name.
Args:
module_attribute: An attribute of the module.
Returns:
The fully qualified module name or simple module name where
'module_attribute' is defined if the module name is "__main__".
"""
if module_attribute.__module__ == '__main__':
module_file = inspect.getfile(module_attribute)
default = os.path.basename(module_file).split('.')[0]
return default
return module_attribute.__module__

View File

@@ -0,0 +1,314 @@
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Common utility library."""
from __future__ import with_statement
import datetime
import functools
import inspect
import logging
import os
import re
import sys
import six
__all__ = [
'Error',
'decode_datetime',
'get_package_for_module',
'positional',
'TimeZoneOffset',
'total_seconds',
]
class Error(Exception):
"""Base class for protorpc exceptions."""
_TIME_ZONE_RE_STRING = r"""
# Examples:
# +01:00
# -05:30
# Z12:00
((?P<z>Z) | (?P<sign>[-+])
(?P<hours>\d\d) :
(?P<minutes>\d\d))$
"""
_TIME_ZONE_RE = re.compile(_TIME_ZONE_RE_STRING, re.IGNORECASE | re.VERBOSE)
def positional(max_positional_args):
"""A decorator that declares only the first N arguments may be positional.
This decorator makes it easy to support Python 3 style keyword-only
parameters. For example, in Python 3 it is possible to write:
def fn(pos1, *, kwonly1=None, kwonly1=None):
...
All named parameters after * must be a keyword:
fn(10, 'kw1', 'kw2') # Raises exception.
fn(10, kwonly1='kw1') # Ok.
Example:
To define a function like above, do:
@positional(1)
def fn(pos1, kwonly1=None, kwonly2=None):
...
If no default value is provided to a keyword argument, it
becomes a required keyword argument:
@positional(0)
def fn(required_kw):
...
This must be called with the keyword parameter:
fn() # Raises exception.
fn(10) # Raises exception.
fn(required_kw=10) # Ok.
When defining instance or class methods always remember to account for
'self' and 'cls':
class MyClass(object):
@positional(2)
def my_method(self, pos1, kwonly1=None):
...
@classmethod
@positional(2)
def my_method(cls, pos1, kwonly1=None):
...
One can omit the argument to 'positional' altogether, and then no
arguments with default values may be passed positionally. This
would be equivalent to placing a '*' before the first argument
with a default value in Python 3. If there are no arguments with
default values, and no argument is given to 'positional', an error
is raised.
@positional
def fn(arg1, arg2, required_kw1=None, required_kw2=0):
...
fn(1, 3, 5) # Raises exception.
fn(1, 3) # Ok.
fn(1, 3, required_kw1=5) # Ok.
Args:
max_positional_arguments: Maximum number of positional arguments. All
parameters after the this index must be keyword only.
Returns:
A decorator that prevents using arguments after max_positional_args from
being used as positional parameters.
Raises:
TypeError if a keyword-only argument is provided as a positional
parameter.
ValueError if no maximum number of arguments is provided and the function
has no arguments with default values.
"""
def positional_decorator(wrapped):
"""Creates a function wraper to enforce number of arguments."""
@functools.wraps(wrapped)
def positional_wrapper(*args, **kwargs):
if len(args) > max_positional_args:
plural_s = ''
if max_positional_args != 1:
plural_s = 's'
raise TypeError('%s() takes at most %d positional argument%s '
'(%d given)' % (wrapped.__name__,
max_positional_args,
plural_s, len(args)))
return wrapped(*args, **kwargs)
return positional_wrapper
if isinstance(max_positional_args, six.integer_types):
return positional_decorator
else:
args, _, _, defaults, *_ = inspect.getfullargspec(max_positional_args)
if defaults is None:
raise ValueError(
'Functions with no keyword arguments must specify '
'max_positional_args')
return positional(len(args) - len(defaults))(max_positional_args)
@positional(1)
def get_package_for_module(module):
"""Get package name for a module.
Helper calculates the package name of a module.
Args:
module: Module to get name for. If module is a string, try to find
module in sys.modules.
Returns:
If module contains 'package' attribute, uses that as package name.
Else, if module is not the '__main__' module, the module __name__.
Else, the base name of the module file name. Else None.
"""
if isinstance(module, six.string_types):
try:
module = sys.modules[module]
except KeyError:
return None
try:
return six.text_type(module.package)
except AttributeError:
if module.__name__ == '__main__':
try:
file_name = module.__file__
except AttributeError:
pass
else:
base_name = os.path.basename(file_name)
split_name = os.path.splitext(base_name)
if len(split_name) == 1:
return six.text_type(base_name)
return u'.'.join(split_name[:-1])
return six.text_type(module.__name__)
def total_seconds(offset):
"""Backport of offset.total_seconds() from python 2.7+."""
seconds = offset.days * 24 * 60 * 60 + offset.seconds
microseconds = seconds * 10**6 + offset.microseconds
return microseconds / (10**6 * 1.0)
class TimeZoneOffset(datetime.tzinfo):
"""Time zone information as encoded/decoded for DateTimeFields."""
def __init__(self, offset):
"""Initialize a time zone offset.
Args:
offset: Integer or timedelta time zone offset, in minutes from UTC.
This can be negative.
"""
super(TimeZoneOffset, self).__init__()
if isinstance(offset, datetime.timedelta):
offset = total_seconds(offset) / 60
self.__offset = offset
def utcoffset(self, _):
"""Get the a timedelta with the time zone's offset from UTC.
Returns:
The time zone offset from UTC, as a timedelta.
"""
return datetime.timedelta(minutes=self.__offset)
def dst(self, _):
"""Get the daylight savings time offset.
The formats that ProtoRPC uses to encode/decode time zone
information don't contain any information about daylight
savings time. So this always returns a timedelta of 0.
Returns:
A timedelta of 0.
"""
return datetime.timedelta(0)
def decode_datetime(encoded_datetime, truncate_time=False):
"""Decode a DateTimeField parameter from a string to a python datetime.
Args:
encoded_datetime: A string in RFC 3339 format.
truncate_time: If true, truncate time string with precision higher than
microsecs.
Returns:
A datetime object with the date and time specified in encoded_datetime.
Raises:
ValueError: If the string is not in a recognized format.
"""
# Check if the string includes a time zone offset. Break out the
# part that doesn't include time zone info. Convert to uppercase
# because all our comparisons should be case-insensitive.
time_zone_match = _TIME_ZONE_RE.search(encoded_datetime)
if time_zone_match:
time_string = encoded_datetime[:time_zone_match.start(1)].upper()
else:
time_string = encoded_datetime.upper()
if '.' in time_string:
format_string = '%Y-%m-%dT%H:%M:%S.%f'
else:
format_string = '%Y-%m-%dT%H:%M:%S'
try:
decoded_datetime = datetime.datetime.strptime(time_string,
format_string)
except ValueError:
if truncate_time and '.' in time_string:
datetime_string, decimal_secs = time_string.split('.')
if len(decimal_secs) > 6:
# datetime can handle only microsecs precision.
truncated_time_string = '{}.{}'.format(
datetime_string, decimal_secs[:6])
decoded_datetime = datetime.datetime.strptime(
truncated_time_string,
format_string)
logging.warning(
'Truncating the datetime string from %s to %s',
time_string, truncated_time_string)
else:
raise
else:
raise
if not time_zone_match:
return decoded_datetime
# Time zone info was included in the parameter. Add a tzinfo
# object to the datetime. Datetimes can't be changed after they're
# created, so we'll need to create a new one.
if time_zone_match.group('z'):
offset_minutes = 0
else:
sign = time_zone_match.group('sign')
hours, minutes = [int(value) for value in
time_zone_match.group('hours', 'minutes')]
offset_minutes = hours * 60 + minutes
if sign == '-':
offset_minutes *= -1
return datetime.datetime(decoded_datetime.year,
decoded_datetime.month,
decoded_datetime.day,
decoded_datetime.hour,
decoded_datetime.minute,
decoded_datetime.second,
decoded_datetime.microsecond,
TimeZoneOffset(offset_minutes))