feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared __init__.py for apitools."""
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

View File

@@ -0,0 +1,551 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extended protorpc descriptors.
This takes existing protorpc Descriptor classes and adds extra
properties not directly supported in proto itself, notably field and
message descriptions. We need this in order to generate protorpc
message files with comments.
Note that for most of these classes, we can't simply wrap the existing
message, since we need to change the type of the subfields. We could
have a "plain" descriptor attached, but that seems like unnecessary
bookkeeping. Where possible, we purposely reuse existing tag numbers;
for new fields, we start numbering at 100.
"""
import abc
import operator
import textwrap
import six
from apitools.base.protorpclite import descriptor as protorpc_descriptor
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.py import extra_types
class ExtendedEnumValueDescriptor(messages.Message):
"""Enum value descriptor with additional fields.
Fields:
name: Name of enumeration value.
number: Number of enumeration value.
description: Description of this enum value.
"""
name = messages.StringField(1)
number = messages.IntegerField(2, variant=messages.Variant.INT32)
description = messages.StringField(100)
class ExtendedEnumDescriptor(messages.Message):
"""Enum class descriptor with additional fields.
Fields:
name: Name of Enum without any qualification.
values: Values defined by Enum class.
description: Description of this enum class.
full_name: Fully qualified name of this enum class.
enum_mappings: Mappings from python to JSON names for enum values.
"""
class JsonEnumMapping(messages.Message):
"""Mapping from a python name to the wire name for an enum."""
python_name = messages.StringField(1)
json_name = messages.StringField(2)
name = messages.StringField(1)
values = messages.MessageField(
ExtendedEnumValueDescriptor, 2, repeated=True)
description = messages.StringField(100)
full_name = messages.StringField(101)
enum_mappings = messages.MessageField(
'JsonEnumMapping', 102, repeated=True)
class ExtendedFieldDescriptor(messages.Message):
"""Field descriptor with additional fields.
Fields:
field_descriptor: The underlying field descriptor.
name: The name of this field.
description: Description of this field.
"""
field_descriptor = messages.MessageField(
protorpc_descriptor.FieldDescriptor, 100)
# We duplicate the names for easier bookkeeping.
name = messages.StringField(101)
description = messages.StringField(102)
class ExtendedMessageDescriptor(messages.Message):
"""Message descriptor with additional fields.
Fields:
name: Name of Message without any qualification.
fields: Fields defined for message.
message_types: Nested Message classes defined on message.
enum_types: Nested Enum classes defined on message.
description: Description of this message.
full_name: Full qualified name of this message.
decorators: Decorators to include in the definition when printing.
Printed in the given order from top to bottom (so the last entry
is the innermost decorator).
alias_for: This type is just an alias for the named type.
field_mappings: Mappings from python to json field names.
"""
class JsonFieldMapping(messages.Message):
"""Mapping from a python name to the wire name for a field."""
python_name = messages.StringField(1)
json_name = messages.StringField(2)
name = messages.StringField(1)
fields = messages.MessageField(ExtendedFieldDescriptor, 2, repeated=True)
message_types = messages.MessageField(
'extended_descriptor.ExtendedMessageDescriptor', 3, repeated=True)
enum_types = messages.MessageField(
ExtendedEnumDescriptor, 4, repeated=True)
description = messages.StringField(100)
full_name = messages.StringField(101)
decorators = messages.StringField(102, repeated=True)
alias_for = messages.StringField(103)
field_mappings = messages.MessageField(
'JsonFieldMapping', 104, repeated=True)
class ExtendedFileDescriptor(messages.Message):
"""File descriptor with additional fields.
Fields:
package: Fully qualified name of package that definitions belong to.
message_types: Message definitions contained in file.
enum_types: Enum definitions contained in file.
description: Description of this file.
additional_imports: Extra imports used in this package.
"""
package = messages.StringField(2)
message_types = messages.MessageField(
ExtendedMessageDescriptor, 4, repeated=True)
enum_types = messages.MessageField(
ExtendedEnumDescriptor, 5, repeated=True)
description = messages.StringField(100)
additional_imports = messages.StringField(101, repeated=True)
def _WriteFile(file_descriptor, package, version, proto_printer):
"""Write the given extended file descriptor to the printer."""
proto_printer.PrintPreamble(package, version, file_descriptor)
_PrintEnums(proto_printer, file_descriptor.enum_types)
_PrintMessages(proto_printer, file_descriptor.message_types)
custom_json_mappings = _FetchCustomMappings(file_descriptor.enum_types)
custom_json_mappings.extend(
_FetchCustomMappings(file_descriptor.message_types))
for mapping in custom_json_mappings:
proto_printer.PrintCustomJsonMapping(mapping)
def WriteMessagesFile(file_descriptor, package, version, printer):
"""Write the given extended file descriptor to out as a message file."""
_WriteFile(file_descriptor, package, version,
_Proto2Printer(printer))
def WritePythonFile(file_descriptor, package, version, printer):
"""Write the given extended file descriptor to out."""
_WriteFile(file_descriptor, package, version,
_ProtoRpcPrinter(printer))
def PrintIndentedDescriptions(printer, ls, name, prefix=''):
if ls:
with printer.Indent(indent=prefix):
with printer.CommentContext():
width = printer.CalculateWidth() - len(prefix)
printer()
printer(name + ':')
for x in ls:
description = '%s: %s' % (x.name, x.description)
for line in textwrap.wrap(description, width,
initial_indent=' ',
subsequent_indent=' '):
printer(line)
def _FetchCustomMappings(descriptor_ls):
"""Find and return all custom mappings for descriptors in descriptor_ls."""
custom_mappings = []
for descriptor in descriptor_ls:
if isinstance(descriptor, ExtendedEnumDescriptor):
custom_mappings.extend(
_FormatCustomJsonMapping('Enum', m, descriptor)
for m in descriptor.enum_mappings)
elif isinstance(descriptor, ExtendedMessageDescriptor):
custom_mappings.extend(
_FormatCustomJsonMapping('Field', m, descriptor)
for m in descriptor.field_mappings)
custom_mappings.extend(
_FetchCustomMappings(descriptor.enum_types))
custom_mappings.extend(
_FetchCustomMappings(descriptor.message_types))
return custom_mappings
def _FormatCustomJsonMapping(mapping_type, mapping, descriptor):
return '\n'.join((
'encoding.AddCustomJson%sMapping(' % mapping_type,
" %s, '%s', '%s')" % (descriptor.full_name, mapping.python_name,
mapping.json_name),
))
def _EmptyMessage(message_type):
return not any((message_type.enum_types,
message_type.message_types,
message_type.fields))
class ProtoPrinter(six.with_metaclass(abc.ABCMeta, object)):
"""Interface for proto printers."""
@abc.abstractmethod
def PrintPreamble(self, package, version, file_descriptor):
"""Print the file docstring and import lines."""
@abc.abstractmethod
def PrintEnum(self, enum_type):
"""Print the given enum declaration."""
@abc.abstractmethod
def PrintMessage(self, message_type):
"""Print the given message declaration."""
class _Proto2Printer(ProtoPrinter):
"""Printer for proto2 definitions."""
def __init__(self, printer):
self.__printer = printer
def __PrintEnumCommentLines(self, enum_type):
description = enum_type.description or '%s enum type.' % enum_type.name
for line in textwrap.wrap(description,
self.__printer.CalculateWidth() - 3):
self.__printer('// %s', line)
PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values',
prefix='// ')
def __PrintEnumValueCommentLines(self, enum_value):
if enum_value.description:
width = self.__printer.CalculateWidth() - 3
for line in textwrap.wrap(enum_value.description, width):
self.__printer('// %s', line)
def PrintEnum(self, enum_type):
self.__PrintEnumCommentLines(enum_type)
self.__printer('enum %s {', enum_type.name)
with self.__printer.Indent():
enum_values = sorted(
enum_type.values, key=operator.attrgetter('number'))
for enum_value in enum_values:
self.__printer()
self.__PrintEnumValueCommentLines(enum_value)
self.__printer('%s = %s;', enum_value.name, enum_value.number)
self.__printer('}')
self.__printer()
def PrintPreamble(self, package, version, file_descriptor):
self.__printer('// Generated message classes for %s version %s.',
package, version)
self.__printer('// NOTE: This file is autogenerated and should not be '
'edited by hand.')
description_lines = textwrap.wrap(file_descriptor.description, 75)
if description_lines:
self.__printer('//')
for line in description_lines:
self.__printer('// %s', line)
self.__printer()
self.__printer('syntax = "proto2";')
self.__printer('package %s;', file_descriptor.package)
def __PrintMessageCommentLines(self, message_type):
"""Print the description of this message."""
description = message_type.description or '%s message type.' % (
message_type.name)
width = self.__printer.CalculateWidth() - 3
for line in textwrap.wrap(description, width):
self.__printer('// %s', line)
PrintIndentedDescriptions(self.__printer, message_type.enum_types,
'Enums', prefix='// ')
PrintIndentedDescriptions(self.__printer, message_type.message_types,
'Messages', prefix='// ')
PrintIndentedDescriptions(self.__printer, message_type.fields,
'Fields', prefix='// ')
def __PrintFieldDescription(self, description):
for line in textwrap.wrap(description,
self.__printer.CalculateWidth() - 3):
self.__printer('// %s', line)
def __PrintFields(self, fields):
for extended_field in fields:
field = extended_field.field_descriptor
field_type = messages.Field.lookup_field_type_by_variant(
field.variant)
self.__printer()
self.__PrintFieldDescription(extended_field.description)
label = str(field.label).lower()
if field_type in (messages.EnumField, messages.MessageField):
proto_type = field.type_name
else:
proto_type = str(field.variant).lower()
default_statement = ''
if field.default_value:
if field_type in [messages.BytesField, messages.StringField]:
default_value = '"%s"' % field.default_value
elif field_type is messages.BooleanField:
default_value = str(field.default_value).lower()
else:
default_value = str(field.default_value)
default_statement = ' [default = %s]' % default_value
self.__printer(
'%s %s %s = %d%s;',
label, proto_type, field.name, field.number, default_statement)
def PrintMessage(self, message_type):
self.__printer()
self.__PrintMessageCommentLines(message_type)
if _EmptyMessage(message_type):
self.__printer('message %s {}', message_type.name)
return
self.__printer('message %s {', message_type.name)
with self.__printer.Indent():
_PrintEnums(self, message_type.enum_types)
_PrintMessages(self, message_type.message_types)
self.__PrintFields(message_type.fields)
self.__printer('}')
def PrintCustomJsonMapping(self, mapping_lines):
raise NotImplementedError(
'Custom JSON encoding not supported for proto2')
class _ProtoRpcPrinter(ProtoPrinter):
"""Printer for ProtoRPC definitions."""
def __init__(self, printer):
self.__printer = printer
def __PrintClassSeparator(self):
self.__printer()
if not self.__printer.indent:
self.__printer()
def __PrintEnumDocstringLines(self, enum_type):
description = enum_type.description or '%s enum type.' % enum_type.name
for line in textwrap.wrap('r"""%s' % description,
self.__printer.CalculateWidth()):
self.__printer(line)
PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values')
self.__printer('"""')
def PrintEnum(self, enum_type):
self.__printer('class %s(_messages.Enum):', enum_type.name)
with self.__printer.Indent():
self.__PrintEnumDocstringLines(enum_type)
enum_values = sorted(
enum_type.values, key=operator.attrgetter('number'))
for enum_value in enum_values:
self.__printer('%s = %s', enum_value.name, enum_value.number)
if not enum_type.values:
self.__printer('pass')
self.__PrintClassSeparator()
def __PrintAdditionalImports(self, imports):
"""Print additional imports needed for protorpc."""
google_imports = [x for x in imports if 'google' in x]
other_imports = [x for x in imports if 'google' not in x]
if other_imports:
for import_ in sorted(other_imports):
self.__printer(import_)
self.__printer()
# Note: If we ever were going to add imports from this package, we'd
# need to sort those out and put them at the end.
if google_imports:
for import_ in sorted(google_imports):
self.__printer(import_)
self.__printer()
def PrintPreamble(self, package, version, file_descriptor):
self.__printer('"""Generated message classes for %s version %s.',
package, version)
self.__printer()
for line in textwrap.wrap(file_descriptor.description, 78):
self.__printer(line)
self.__printer('"""')
self.__printer('# NOTE: This file is autogenerated and should not be '
'edited by hand.')
self.__printer()
self.__printer('from __future__ import absolute_import')
self.__printer()
self.__PrintAdditionalImports(file_descriptor.additional_imports)
self.__printer()
self.__printer("package = '%s'", file_descriptor.package)
self.__printer()
self.__printer()
def __PrintMessageDocstringLines(self, message_type):
"""Print the docstring for this message."""
description = message_type.description or '%s message type.' % (
message_type.name)
short_description = (
_EmptyMessage(message_type) and
len(description) < (self.__printer.CalculateWidth() - 6))
with self.__printer.CommentContext():
if short_description:
# Note that we use explicit string interpolation here since
# we're in comment context.
self.__printer('r"""%s"""' % description)
return
for line in textwrap.wrap('r"""%s' % description,
self.__printer.CalculateWidth()):
self.__printer(line)
PrintIndentedDescriptions(self.__printer, message_type.enum_types,
'Enums')
PrintIndentedDescriptions(
self.__printer, message_type.message_types, 'Messages')
PrintIndentedDescriptions(
self.__printer, message_type.fields, 'Fields')
self.__printer('"""')
self.__printer()
def PrintMessage(self, message_type):
if message_type.alias_for:
self.__printer(
'%s = %s', message_type.name, message_type.alias_for)
self.__PrintClassSeparator()
return
for decorator in message_type.decorators:
self.__printer('@%s', decorator)
self.__printer('class %s(_messages.Message):', message_type.name)
with self.__printer.Indent():
self.__PrintMessageDocstringLines(message_type)
_PrintEnums(self, message_type.enum_types)
_PrintMessages(self, message_type.message_types)
_PrintFields(message_type.fields, self.__printer)
self.__PrintClassSeparator()
def PrintCustomJsonMapping(self, mapping):
self.__printer(mapping)
def _PrintEnums(proto_printer, enum_types):
"""Print all enums to the given proto_printer."""
enum_types = sorted(enum_types, key=operator.attrgetter('name'))
for enum_type in enum_types:
proto_printer.PrintEnum(enum_type)
def _PrintMessages(proto_printer, message_list):
message_list = sorted(message_list, key=operator.attrgetter('name'))
for message_type in message_list:
proto_printer.PrintMessage(message_type)
_MESSAGE_FIELD_MAP = {
message_types.DateTimeMessage.definition_name(): (
message_types.DateTimeField),
}
def _PrintFields(fields, printer):
for extended_field in fields:
field = extended_field.field_descriptor
printed_field_info = {
'name': field.name,
'module': '_messages',
'type_name': '',
'type_format': '',
'number': field.number,
'label_format': '',
'variant_format': '',
'default_format': '',
}
message_field = _MESSAGE_FIELD_MAP.get(field.type_name)
if message_field:
printed_field_info['module'] = '_message_types'
field_type = message_field
elif field.type_name == 'extra_types.DateField':
printed_field_info['module'] = 'extra_types'
field_type = extra_types.DateField
else:
field_type = messages.Field.lookup_field_type_by_variant(
field.variant)
if field_type in (messages.EnumField, messages.MessageField):
printed_field_info['type_format'] = "'%s', " % field.type_name
if field.label == protorpc_descriptor.FieldDescriptor.Label.REQUIRED:
printed_field_info['label_format'] = ', required=True'
elif field.label == protorpc_descriptor.FieldDescriptor.Label.REPEATED:
printed_field_info['label_format'] = ', repeated=True'
if field_type.DEFAULT_VARIANT != field.variant:
printed_field_info['variant_format'] = (
', variant=_messages.Variant.%s' % field.variant)
if field.default_value:
if field_type in [messages.BytesField, messages.StringField]:
default_value = repr(field.default_value)
elif field_type is messages.EnumField:
try:
default_value = str(int(field.default_value))
except ValueError:
default_value = repr(field.default_value)
else:
default_value = field.default_value
printed_field_info[
'default_format'] = ', default=%s' % (default_value,)
printed_field_info['type_name'] = field_type.__name__
args = ''.join('%%(%s)s' % field for field in (
'type_format',
'number',
'label_format',
'variant_format',
'default_format'))
format_str = '%%(name)s = %%(module)s.%%(type_name)s(%s)' % args
printer(format_str % printed_field_info)

View File

@@ -0,0 +1,349 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command-line interface to gen_client."""
import argparse
import contextlib
import io
import json
import logging
import os
import pkgutil
import sys
from apitools.base.py import exceptions
from apitools.gen import gen_client_lib
from apitools.gen import util
def _CopyLocalFile(filename):
with contextlib.closing(io.open(filename, 'w')) as out:
src_data = pkgutil.get_data(
'apitools.base.py', filename)
if src_data is None:
raise exceptions.GeneratedClientError(
'Could not find file %s' % filename)
out.write(src_data)
def _GetDiscoveryDocFromFlags(args):
"""Get the discovery doc from flags."""
if args.discovery_url:
try:
return util.FetchDiscoveryDoc(args.discovery_url)
except exceptions.CommunicationError:
raise exceptions.GeneratedClientError(
'Could not fetch discovery doc')
infile = os.path.expanduser(args.infile) or '/dev/stdin'
with io.open(infile, encoding='utf8') as f:
return json.loads(util.ReplaceHomoglyphs(f.read()))
def _GetCodegenFromFlags(args):
"""Create a codegen object from flags."""
discovery_doc = _GetDiscoveryDocFromFlags(args)
names = util.Names(
args.strip_prefix,
args.experimental_name_convention,
args.experimental_capitalize_enums)
if args.client_json:
try:
with io.open(args.client_json, encoding='utf8') as client_json:
f = json.loads(util.ReplaceHomoglyphs(client_json.read()))
web = f.get('installed', f.get('web', {}))
client_id = web.get('client_id')
client_secret = web.get('client_secret')
except IOError:
raise exceptions.NotFoundError(
'Failed to open client json file: %s' % args.client_json)
else:
client_id = args.client_id
client_secret = args.client_secret
if not client_id:
logging.warning('No client ID supplied')
client_id = ''
if not client_secret:
logging.warning('No client secret supplied')
client_secret = ''
client_info = util.ClientInfo.Create(
discovery_doc, args.scope, client_id, client_secret,
args.user_agent, names, args.api_key, args.version_identifier)
outdir = os.path.expanduser(args.outdir) or client_info.default_directory
if os.path.exists(outdir) and not args.overwrite:
raise exceptions.ConfigurationValueError(
'Output directory exists, pass --overwrite to replace '
'the existing files.')
if not os.path.exists(outdir):
os.makedirs(outdir)
return gen_client_lib.DescriptorGenerator(
discovery_doc, client_info, names, args.root_package, outdir,
base_package=args.base_package,
protorpc_package=args.protorpc_package,
init_wildcards_file=(args.init_file == 'wildcards'),
use_proto2=args.experimental_proto2_output,
unelidable_request_methods=args.unelidable_request_methods,
apitools_version=args.apitools_version)
# TODO(user): Delete this if we don't need this functionality.
def _WriteBaseFiles(codegen):
with util.Chdir(codegen.outdir):
_CopyLocalFile('base_api.py')
_CopyLocalFile('credentials_lib.py')
_CopyLocalFile('exceptions.py')
def _WriteIntermediateInit(codegen):
with io.open('__init__.py', 'w') as out:
codegen.WriteIntermediateInit(out)
def _WriteProtoFiles(codegen):
with util.Chdir(codegen.outdir):
with io.open(codegen.client_info.messages_proto_file_name, 'w') as out:
codegen.WriteMessagesProtoFile(out)
with io.open(codegen.client_info.services_proto_file_name, 'w') as out:
codegen.WriteServicesProtoFile(out)
def _WriteGeneratedFiles(args, codegen):
if codegen.use_proto2:
_WriteProtoFiles(codegen)
with util.Chdir(codegen.outdir):
with io.open(codegen.client_info.messages_file_name, 'w') as out:
codegen.WriteMessagesFile(out)
with io.open(codegen.client_info.client_file_name, 'w') as out:
codegen.WriteClientLibrary(out)
def _WriteInit(codegen):
with util.Chdir(codegen.outdir):
with io.open('__init__.py', 'w') as out:
codegen.WriteInit(out)
def _WriteSetupPy(codegen):
with io.open('setup.py', 'w') as out:
codegen.WriteSetupPy(out)
def GenerateClient(args):
"""Driver for client code generation."""
codegen = _GetCodegenFromFlags(args)
if codegen is None:
logging.error('Failed to create codegen, exiting.')
return 128
_WriteGeneratedFiles(args, codegen)
if args.init_file != 'none':
_WriteInit(codegen)
def GeneratePipPackage(args):
"""Generate a client as a pip-installable tarball."""
discovery_doc = _GetDiscoveryDocFromFlags(args)
package = discovery_doc['name']
original_outdir = os.path.expanduser(args.outdir)
args.outdir = os.path.join(
args.outdir, 'apitools/clients/%s' % package)
args.root_package = 'apitools.clients.%s' % package
codegen = _GetCodegenFromFlags(args)
if codegen is None:
logging.error('Failed to create codegen, exiting.')
return 1
_WriteGeneratedFiles(args, codegen)
_WriteInit(codegen)
with util.Chdir(original_outdir):
_WriteSetupPy(codegen)
with util.Chdir('apitools'):
_WriteIntermediateInit(codegen)
with util.Chdir('clients'):
_WriteIntermediateInit(codegen)
def GenerateProto(args):
"""Generate just the two proto files for a given API."""
codegen = _GetCodegenFromFlags(args)
_WriteProtoFiles(codegen)
class _SplitCommaSeparatedList(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values.split(','))
def main(argv=None):
if argv is None:
argv = sys.argv
parser = argparse.ArgumentParser(
description='Apitools Client Code Generator')
discovery_group = parser.add_mutually_exclusive_group()
discovery_group.add_argument(
'--infile',
help=('Filename for the discovery document. Mutually exclusive with '
'--discovery_url'))
discovery_group.add_argument(
'--discovery_url',
help=('URL (or "name.version") of the discovery document to use. '
'Mutually exclusive with --infile.'))
parser.add_argument(
'--base_package',
default='apitools.base.py',
help='Base package path of apitools (defaults to apitools.base.py')
parser.add_argument(
'--protorpc_package',
default='apitools.base.protorpclite',
help=('Base package path of protorpc '
'(defaults to apitools.base.protorpclite'))
parser.add_argument(
'--version-identifier',
help=('Version identifier to use for the generated client (defaults to '
'"version" value in discovery doc). This must be a valid '
'identifier when used in a Python module name.'))
parser.add_argument(
'--outdir',
default='',
help='Directory name for output files. (Defaults to the API name.)')
parser.add_argument(
'--overwrite',
default=False, action='store_true',
help='Only overwrite the output directory if this flag is specified.')
parser.add_argument(
'--root_package',
default='',
help=('Python import path for where these modules '
'should be imported from.'))
parser.add_argument(
'--strip_prefix', nargs='*',
default=[],
help=('Prefix to strip from type names in the discovery document. '
'(May be specified multiple times.)'))
parser.add_argument(
'--api_key',
help=('API key to use for API access.'))
parser.add_argument(
'--client_json',
help=('Use the given file downloaded from the dev. console for '
'client_id and client_secret.'))
parser.add_argument(
'--client_id',
default='CLIENT_ID',
help='Client ID to use for the generated client.')
parser.add_argument(
'--client_secret',
default='CLIENT_SECRET',
help='Client secret for the generated client.')
parser.add_argument(
'--scope', nargs='*',
default=[],
help=('Scopes to request in the generated client. '
'May be specified more than once.'))
parser.add_argument(
'--user_agent',
default='x_Tw5K8nnjoRAqULM9PFAC2b',
help=('User agent for the generated client. '
'Defaults to <api>-generated/0.1.'))
parser.add_argument(
'--generate_cli', dest='generate_cli', action='store_true',
help='Ignored.')
parser.add_argument(
'--nogenerate_cli', dest='generate_cli', action='store_false',
help='Ignored.')
parser.add_argument(
'--init-file',
choices=['none', 'empty', 'wildcards'],
type=lambda s: s.lower(),
default='wildcards',
help='Controls whether and how to generate package __init__.py file.')
parser.add_argument(
'--unelidable_request_methods',
action=_SplitCommaSeparatedList,
default=[],
help=('Full method IDs of methods for which we should NOT try to '
'elide the request type. (Should be a comma-separated list.'))
parser.add_argument(
'--apitools_version',
default='', dest='apitools_version',
help=('Apitools version used as a requirement in generated clients. '
'Defaults to version of apitools used to generate the clients.'))
parser.add_argument(
'--experimental_capitalize_enums',
default=False, action='store_true',
help='Dangerous: attempt to rewrite enum values to be uppercase.')
parser.add_argument(
'--experimental_name_convention',
choices=util.Names.NAME_CONVENTIONS,
default=util.Names.DEFAULT_NAME_CONVENTION,
help='Dangerous: use a particular style for generated names.')
parser.add_argument(
'--experimental_proto2_output',
default=False, action='store_true',
help='Dangerous: also output a proto2 message file.')
subparsers = parser.add_subparsers(help='Type of generated code')
client_parser = subparsers.add_parser(
'client', help='Generate apitools client in destination folder')
client_parser.set_defaults(func=GenerateClient)
pip_package_parser = subparsers.add_parser(
'pip_package', help='Generate apitools client pip package')
pip_package_parser.set_defaults(func=GeneratePipPackage)
proto_parser = subparsers.add_parser(
'proto', help='Generate apitools client protos')
proto_parser.set_defaults(func=GenerateProto)
args = parser.parse_args(argv[1:])
return args.func(args) or 0
if __name__ == '__main__':
sys.exit(main())

View File

@@ -0,0 +1,269 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple tool for generating a client library.
Relevant links:
https://developers.google.com/discovery/v1/reference/apis#resource
"""
import datetime
from apitools.gen import message_registry
from apitools.gen import service_registry
from apitools.gen import util
def _ApitoolsVersion():
"""Returns version of the currently installed google-apitools package."""
try:
import pkg_resources
except ImportError:
return 'X.X.X'
try:
return pkg_resources.get_distribution('google-apitools').version
except pkg_resources.DistributionNotFound:
return 'X.X.X'
def _StandardQueryParametersSchema(discovery_doc):
"""Sets up dict of standard query parameters."""
standard_query_schema = {
'id': 'StandardQueryParameters',
'type': 'object',
'description': 'Query parameters accepted by all methods.',
'properties': discovery_doc.get('parameters', {}),
}
# We add an entry for the trace, since Discovery doesn't.
standard_query_schema['properties']['trace'] = {
'type': 'string',
'description': ('A tracing token of the form "token:<tokenid>" '
'to include in api requests.'),
'location': 'query',
}
return standard_query_schema
class DescriptorGenerator(object):
"""Code generator for a given discovery document."""
def __init__(self, discovery_doc, client_info, names, root_package, outdir,
base_package, protorpc_package, init_wildcards_file=True,
use_proto2=False, unelidable_request_methods=None,
apitools_version=''):
self.__discovery_doc = discovery_doc
self.__client_info = client_info
self.__outdir = outdir
self.__use_proto2 = use_proto2
self.__description = util.CleanDescription(
self.__discovery_doc.get('description', ''))
self.__package = self.__client_info.package
self.__version = self.__client_info.version
self.__revision = discovery_doc.get('revision', '1')
self.__init_wildcards_file = init_wildcards_file
self.__root_package = root_package
self.__base_files_package = base_package
self.__protorpc_package = protorpc_package
self.__names = names
# Order is important here: we need the schemas before we can
# define the services.
self.__message_registry = message_registry.MessageRegistry(
self.__client_info, self.__names, self.__description,
self.__root_package, self.__base_files_package,
self.__protorpc_package)
schemas = self.__discovery_doc.get('schemas', {})
for schema_name, schema in sorted(schemas.items()):
self.__message_registry.AddDescriptorFromSchema(
schema_name, schema)
# We need to add one more message type for the global parameters.
standard_query_schema = _StandardQueryParametersSchema(
self.__discovery_doc)
self.__message_registry.AddDescriptorFromSchema(
standard_query_schema['id'], standard_query_schema)
# Now that we know all the messages, we need to correct some
# fields from MessageFields to EnumFields.
self.__message_registry.FixupMessageFields()
self.__services_registry = service_registry.ServiceRegistry(
self.__client_info,
self.__message_registry,
self.__names,
self.__root_package,
self.__base_files_package,
unelidable_request_methods or [])
services = self.__discovery_doc.get('resources', {})
for service_name, methods in sorted(services.items()):
self.__services_registry.AddServiceFromResource(
service_name, methods)
# We might also have top-level methods.
api_methods = self.__discovery_doc.get('methods', [])
if api_methods:
self.__services_registry.AddServiceFromResource(
'api', {'methods': api_methods})
# pylint: disable=protected-access
self.__client_info = self.__client_info._replace(
scopes=self.__services_registry.scopes)
# The apitools version that will be used in prerequisites for the
# generated packages.
self.__apitools_version = (
apitools_version if apitools_version else _ApitoolsVersion())
@property
def client_info(self):
return self.__client_info
@property
def discovery_doc(self):
return self.__discovery_doc
@property
def names(self):
return self.__names
@property
def outdir(self):
return self.__outdir
@property
def package(self):
return self.__package
@property
def use_proto2(self):
return self.__use_proto2
@property
def apitools_version(self):
return self.__apitools_version
def _GetPrinter(self, out):
printer = util.SimplePrettyPrinter(out)
return printer
def WriteInit(self, out):
"""Write a simple __init__.py for the generated client."""
printer = self._GetPrinter(out)
if self.__init_wildcards_file:
printer('"""Common imports for generated %s client library."""',
self.__client_info.package)
printer('# pylint:disable=wildcard-import')
else:
printer('"""Package marker file."""')
printer()
printer('from __future__ import absolute_import')
printer()
printer('import pkgutil')
printer()
if self.__init_wildcards_file:
printer('from %s import *', self.__base_files_package)
if self.__root_package == '.':
import_prefix = '.'
else:
import_prefix = '%s.' % self.__root_package
printer('from %s%s import *',
import_prefix, self.__client_info.client_rule_name)
printer('from %s%s import *',
import_prefix, self.__client_info.messages_rule_name)
printer()
printer('__path__ = pkgutil.extend_path(__path__, __name__)')
def WriteIntermediateInit(self, out):
"""Write a simple __init__.py for an intermediate directory."""
printer = self._GetPrinter(out)
printer('#!/usr/bin/env python')
printer('"""Shared __init__.py for apitools."""')
printer()
printer('from pkgutil import extend_path')
printer('__path__ = extend_path(__path__, __name__)')
def WriteSetupPy(self, out):
"""Write a setup.py for upload to PyPI."""
printer = self._GetPrinter(out)
year = datetime.datetime.now().year
printer('# Copyright %s Google Inc. All Rights Reserved.' % year)
printer('#')
printer('# Licensed under the Apache License, Version 2.0 (the'
'"License");')
printer('# you may not use this file except in compliance with '
'the License.')
printer('# You may obtain a copy of the License at')
printer('#')
printer('# http://www.apache.org/licenses/LICENSE-2.0')
printer('#')
printer('# Unless required by applicable law or agreed to in writing, '
'software')
printer('# distributed under the License is distributed on an "AS IS" '
'BASIS,')
printer('# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either '
'express or implied.')
printer('# See the License for the specific language governing '
'permissions and')
printer('# limitations under the License.')
printer()
printer('import setuptools')
printer('REQUIREMENTS = [')
with printer.Indent(indent=' '):
parts = self.apitools_version.split('.')
major = parts.pop(0)
minor = parts.pop(0)
printer('"google-apitools>=%s,~=%s.%s",',
self.apitools_version, major, minor)
printer('"httplib2>=0.9",')
printer('"oauth2client>=1.4.12",')
printer(']')
printer('_PACKAGE = "apitools.clients.%s"' % self.__package)
printer()
printer('setuptools.setup(')
# TODO(user): Allow customization of these options.
with printer.Indent(indent=' '):
printer('name="google-apitools-%s-%s",',
self.__package, self.__version)
printer('version="%s.%s",',
self.apitools_version, self.__revision)
printer('description="Autogenerated apitools library for %s",' % (
self.__package,))
printer('url="https://github.com/google/apitools",')
printer('author="Craig Citro",')
printer('author_email="craigcitro@google.com",')
printer('packages=setuptools.find_packages(),')
printer('install_requires=REQUIREMENTS,')
printer('classifiers=[')
with printer.Indent(indent=' '):
printer('"Programming Language :: Python :: 2.7",')
printer('"License :: OSI Approved :: Apache Software '
'License",')
printer('],')
printer('license="Apache 2.0",')
printer('keywords="apitools apitools-%s %s",' % (
self.__package, self.__package))
printer(')')
def WriteMessagesFile(self, out):
self.__message_registry.WriteFile(self._GetPrinter(out))
def WriteMessagesProtoFile(self, out):
self.__message_registry.WriteProtoFile(self._GetPrinter(out))
def WriteServicesProtoFile(self, out):
self.__services_registry.WriteProtoFile(self._GetPrinter(out))
def WriteClientLibrary(self, out):
self.__services_registry.WriteFile(self._GetPrinter(out))

View File

@@ -0,0 +1,473 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Message registry for apitools."""
import collections
import contextlib
import json
import six
from apitools.base.protorpclite import descriptor
from apitools.base.protorpclite import messages
from apitools.gen import extended_descriptor
from apitools.gen import util
TypeInfo = collections.namedtuple('TypeInfo', ('type_name', 'variant'))
class MessageRegistry(object):
"""Registry for message types.
This closely mirrors a messages.FileDescriptor, but adds additional
attributes (such as message and field descriptions) and some extra
code for validation and cycle detection.
"""
# Type information from these two maps comes from here:
# https://developers.google.com/discovery/v1/type-format
PRIMITIVE_TYPE_INFO_MAP = {
'string': TypeInfo(type_name='string',
variant=messages.StringField.DEFAULT_VARIANT),
'integer': TypeInfo(type_name='integer',
variant=messages.IntegerField.DEFAULT_VARIANT),
'boolean': TypeInfo(type_name='boolean',
variant=messages.BooleanField.DEFAULT_VARIANT),
'number': TypeInfo(type_name='number',
variant=messages.FloatField.DEFAULT_VARIANT),
'any': TypeInfo(type_name='extra_types.JsonValue',
variant=messages.Variant.MESSAGE),
}
PRIMITIVE_FORMAT_MAP = {
'int32': TypeInfo(type_name='integer',
variant=messages.Variant.INT32),
'uint32': TypeInfo(type_name='integer',
variant=messages.Variant.UINT32),
'int64': TypeInfo(type_name='string',
variant=messages.Variant.INT64),
'uint64': TypeInfo(type_name='string',
variant=messages.Variant.UINT64),
'double': TypeInfo(type_name='number',
variant=messages.Variant.DOUBLE),
'float': TypeInfo(type_name='number',
variant=messages.Variant.FLOAT),
'byte': TypeInfo(type_name='byte',
variant=messages.BytesField.DEFAULT_VARIANT),
'date': TypeInfo(type_name='extra_types.DateField',
variant=messages.Variant.STRING),
'date-time': TypeInfo(
type_name=('apitools.base.protorpclite.message_types.'
'DateTimeMessage'),
variant=messages.Variant.MESSAGE),
}
def __init__(self, client_info, names, description, root_package_dir,
base_files_package, protorpc_package):
self.__names = names
self.__client_info = client_info
self.__package = client_info.package
self.__description = util.CleanDescription(description)
self.__root_package_dir = root_package_dir
self.__base_files_package = base_files_package
self.__protorpc_package = protorpc_package
self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
package=self.__package, description=self.__description)
# Add required imports
self.__file_descriptor.additional_imports = [
'from %s import messages as _messages' % self.__protorpc_package,
]
# Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
self.__message_registry = collections.OrderedDict()
# A set of types that we're currently adding (for cycle detection).
self.__nascent_types = set()
# A set of types for which we've seen a reference but no
# definition; if this set is nonempty, validation fails.
self.__unknown_types = set()
# Used for tracking paths during message creation
self.__current_path = []
# Where to register created messages
self.__current_env = self.__file_descriptor
# TODO(user): Add a `Finalize` method.
@property
def file_descriptor(self):
self.Validate()
return self.__file_descriptor
def WriteProtoFile(self, printer):
"""Write the messages file to out as proto."""
self.Validate()
extended_descriptor.WriteMessagesFile(
self.__file_descriptor, self.__package, self.__client_info.version,
printer)
def WriteFile(self, printer):
"""Write the messages file to out."""
self.Validate()
extended_descriptor.WritePythonFile(
self.__file_descriptor, self.__package, self.__client_info.version,
printer)
def Validate(self):
mysteries = self.__nascent_types or self.__unknown_types
if mysteries:
raise ValueError('Malformed MessageRegistry: %s' % mysteries)
def __ComputeFullName(self, name):
return '.'.join(map(six.text_type, self.__current_path[:] + [name]))
def __AddImport(self, new_import):
if new_import not in self.__file_descriptor.additional_imports:
self.__file_descriptor.additional_imports.append(new_import)
def __DeclareDescriptor(self, name):
self.__nascent_types.add(self.__ComputeFullName(name))
def __RegisterDescriptor(self, new_descriptor):
"""Register the given descriptor in this registry."""
if not isinstance(new_descriptor, (
extended_descriptor.ExtendedMessageDescriptor,
extended_descriptor.ExtendedEnumDescriptor)):
raise ValueError('Cannot add descriptor of type %s' % (
type(new_descriptor),))
full_name = self.__ComputeFullName(new_descriptor.name)
if full_name in self.__message_registry:
raise ValueError(
'Attempt to re-register descriptor %s' % full_name)
if full_name not in self.__nascent_types:
raise ValueError('Directly adding types is not supported')
new_descriptor.full_name = full_name
self.__message_registry[full_name] = new_descriptor
if isinstance(new_descriptor,
extended_descriptor.ExtendedMessageDescriptor):
self.__current_env.message_types.append(new_descriptor)
elif isinstance(new_descriptor,
extended_descriptor.ExtendedEnumDescriptor):
self.__current_env.enum_types.append(new_descriptor)
self.__unknown_types.discard(full_name)
self.__nascent_types.remove(full_name)
def LookupDescriptor(self, name):
return self.__GetDescriptorByName(name)
def LookupDescriptorOrDie(self, name):
message_descriptor = self.LookupDescriptor(name)
if message_descriptor is None:
raise ValueError('No message descriptor named "%s"' % name)
return message_descriptor
def __GetDescriptor(self, name):
return self.__GetDescriptorByName(self.__ComputeFullName(name))
def __GetDescriptorByName(self, name):
if name in self.__message_registry:
return self.__message_registry[name]
if name in self.__nascent_types:
raise ValueError(
'Cannot retrieve type currently being created: %s' % name)
return None
@contextlib.contextmanager
def __DescriptorEnv(self, message_descriptor):
# TODO(user): Typecheck?
previous_env = self.__current_env
self.__current_path.append(message_descriptor.name)
self.__current_env = message_descriptor
yield
self.__current_path.pop()
self.__current_env = previous_env
def AddEnumDescriptor(self, name, description,
enum_values, enum_descriptions):
"""Add a new EnumDescriptor named name with the given enum values."""
message = extended_descriptor.ExtendedEnumDescriptor()
message.name = self.__names.ClassName(name)
message.description = util.CleanDescription(description)
self.__DeclareDescriptor(message.name)
for index, (enum_name, enum_description) in enumerate(
zip(enum_values, enum_descriptions)):
enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
enum_value.name = self.__names.NormalizeEnumName(enum_name)
if enum_value.name != enum_name:
message.enum_mappings.append(
extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
python_name=enum_value.name, json_name=enum_name))
self.__AddImport('from %s import encoding' %
self.__base_files_package)
enum_value.number = index
enum_value.description = util.CleanDescription(
enum_description or '<no description>')
message.values.append(enum_value)
self.__RegisterDescriptor(message)
def __DeclareMessageAlias(self, schema, alias_for):
"""Declare schema as an alias for alias_for."""
# TODO(user): This is a hack. Remove it.
message = extended_descriptor.ExtendedMessageDescriptor()
message.name = self.__names.ClassName(schema['id'])
message.alias_for = alias_for
self.__DeclareDescriptor(message.name)
self.__AddImport('from %s import extra_types' %
self.__base_files_package)
self.__RegisterDescriptor(message)
def __AddAdditionalProperties(self, message, schema, properties):
"""Add an additionalProperties field to message."""
additional_properties_info = schema['additionalProperties']
entries_type_name = self.__AddAdditionalPropertyType(
message.name, additional_properties_info)
description = util.CleanDescription(
additional_properties_info.get('description'))
if description is None:
description = 'Additional properties of type %s' % message.name
attrs = {
'items': {
'$ref': entries_type_name,
},
'description': description,
'type': 'array',
}
field_name = 'additionalProperties'
message.fields.append(self.__FieldDescriptorFromProperties(
field_name, len(properties) + 1, attrs))
self.__AddImport('from %s import encoding' % self.__base_files_package)
message.decorators.append(
'encoding.MapUnrecognizedFields(%r)' % field_name)
def AddDescriptorFromSchema(self, schema_name, schema):
"""Add a new MessageDescriptor named schema_name based on schema."""
# TODO(user): Is schema_name redundant?
if self.__GetDescriptor(schema_name):
return
if schema.get('enum'):
self.__DeclareEnum(schema_name, schema)
return
if schema.get('type') == 'any':
self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
return
if schema.get('type') != 'object':
raise ValueError('Cannot create message descriptors for type %s' %
schema.get('type'))
message = extended_descriptor.ExtendedMessageDescriptor()
message.name = self.__names.ClassName(schema['id'])
message.description = util.CleanDescription(schema.get(
'description', 'A %s object.' % message.name))
self.__DeclareDescriptor(message.name)
with self.__DescriptorEnv(message):
properties = schema.get('properties', {})
for index, (name, attrs) in enumerate(sorted(properties.items())):
field = self.__FieldDescriptorFromProperties(
name, index + 1, attrs)
message.fields.append(field)
if field.name != name:
message.field_mappings.append(
type(message).JsonFieldMapping(
python_name=field.name, json_name=name))
self.__AddImport(
'from %s import encoding' % self.__base_files_package)
if 'additionalProperties' in schema:
self.__AddAdditionalProperties(message, schema, properties)
self.__RegisterDescriptor(message)
def __AddAdditionalPropertyType(self, name, property_schema):
"""Add a new nested AdditionalProperty message."""
new_type_name = 'AdditionalProperty'
property_schema = dict(property_schema)
# We drop the description here on purpose, so the resulting
# messages are less repetitive.
property_schema.pop('description', None)
description = 'An additional property for a %s object.' % name
schema = {
'id': new_type_name,
'type': 'object',
'description': description,
'properties': {
'key': {
'type': 'string',
'description': 'Name of the additional property.',
},
'value': property_schema,
},
}
self.AddDescriptorFromSchema(new_type_name, schema)
return new_type_name
def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
"""Add a type for a list entry."""
entry_schema.pop('description', None)
description = 'Single entry in a %s.' % parent_name
schema = {
'id': entry_type_name,
'type': 'object',
'description': description,
'properties': {
'entry': {
'type': 'array',
'items': entry_schema,
},
},
}
self.AddDescriptorFromSchema(entry_type_name, schema)
return entry_type_name
def __FieldDescriptorFromProperties(self, name, index, attrs):
"""Create a field descriptor for these attrs."""
field = descriptor.FieldDescriptor()
field.name = self.__names.CleanName(name)
field.number = index
field.label = self.__ComputeLabel(attrs)
new_type_name_hint = self.__names.ClassName(
'%sValue' % self.__names.ClassName(name))
type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
field.type_name = type_info.type_name
field.variant = type_info.variant
if 'default' in attrs:
# TODO(user): Correctly handle non-primitive default values.
default = attrs['default']
if not (field.type_name == 'string' or
field.variant == messages.Variant.ENUM):
default = str(json.loads(default))
if field.variant == messages.Variant.ENUM:
default = self.__names.NormalizeEnumName(default)
field.default_value = default
extended_field = extended_descriptor.ExtendedFieldDescriptor()
extended_field.name = field.name
extended_field.description = util.CleanDescription(
attrs.get('description', 'A %s attribute.' % field.type_name))
extended_field.field_descriptor = field
return extended_field
@staticmethod
def __ComputeLabel(attrs):
if attrs.get('required', False):
return descriptor.FieldDescriptor.Label.REQUIRED
elif attrs.get('type') == 'array':
return descriptor.FieldDescriptor.Label.REPEATED
elif attrs.get('repeated'):
return descriptor.FieldDescriptor.Label.REPEATED
return descriptor.FieldDescriptor.Label.OPTIONAL
def __DeclareEnum(self, enum_name, attrs):
description = util.CleanDescription(attrs.get('description', ''))
enum_values = attrs['enum']
enum_descriptions = attrs.get(
'enumDescriptions', [''] * len(enum_values))
self.AddEnumDescriptor(enum_name, description,
enum_values, enum_descriptions)
self.__AddIfUnknown(enum_name)
return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)
def __AddIfUnknown(self, type_name):
type_name = self.__names.ClassName(type_name)
full_type_name = self.__ComputeFullName(type_name)
if (full_type_name not in self.__message_registry.keys() and
type_name not in self.__message_registry.keys()):
self.__unknown_types.add(type_name)
def __GetTypeInfo(self, attrs, name_hint):
"""Return a TypeInfo object for attrs, creating one if needed."""
type_ref = self.__names.ClassName(attrs.get('$ref'))
type_name = attrs.get('type')
if not (type_ref or type_name):
raise ValueError('No type found for %s' % attrs)
if type_ref:
self.__AddIfUnknown(type_ref)
# We don't actually know this is a message -- it might be an
# enum. However, we can't check that until we've created all the
# types, so we come back and fix this up later.
return TypeInfo(
type_name=type_ref, variant=messages.Variant.MESSAGE)
if 'enum' in attrs:
enum_name = '%sValuesEnum' % name_hint
return self.__DeclareEnum(enum_name, attrs)
if 'format' in attrs:
type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
# NOTE: If we don't recognize the format, the spec says we fall back
# to just using the type name.
if type_info is not None:
if type_info.type_name.startswith((
'apitools.base.protorpclite.message_types.',
'message_types.')):
self.__AddImport(
'from %s import message_types as _message_types' %
self.__protorpc_package)
if type_info.type_name.startswith('extra_types.'):
self.__AddImport(
'from %s import extra_types' % self.__base_files_package)
return type_info
if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
if type_info.type_name.startswith('extra_types.'):
self.__AddImport(
'from %s import extra_types' % self.__base_files_package)
return type_info
if type_name == 'array':
items = attrs.get('items')
if not items:
raise ValueError('Array type with no item type: %s' % attrs)
entry_name_hint = self.__names.ClassName(
items.get('title') or '%sListEntry' % name_hint)
entry_label = self.__ComputeLabel(items)
if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
parent_name = self.__names.ClassName(
items.get('title') or name_hint)
entry_type_name = self.__AddEntryType(
entry_name_hint, items.get('items'), parent_name)
return TypeInfo(type_name=entry_type_name,
variant=messages.Variant.MESSAGE)
return self.__GetTypeInfo(items, entry_name_hint)
elif type_name == 'any':
self.__AddImport('from %s import extra_types' %
self.__base_files_package)
return self.PRIMITIVE_TYPE_INFO_MAP['any']
elif type_name == 'object':
# TODO(user): Think of a better way to come up with names.
if not name_hint:
raise ValueError(
'Cannot create subtype without some name hint')
schema = dict(attrs)
schema['id'] = name_hint
self.AddDescriptorFromSchema(name_hint, schema)
self.__AddIfUnknown(name_hint)
return TypeInfo(
type_name=name_hint, variant=messages.Variant.MESSAGE)
raise ValueError('Unknown type: %s' % type_name)
def FixupMessageFields(self):
for message_type in self.file_descriptor.message_types:
self._FixupMessage(message_type)
def _FixupMessage(self, message_type):
with self.__DescriptorEnv(message_type):
for field in message_type.fields:
if field.field_descriptor.variant == messages.Variant.MESSAGE:
field_type_name = field.field_descriptor.type_name
field_type = self.LookupDescriptor(field_type_name)
if isinstance(field_type,
extended_descriptor.ExtendedEnumDescriptor):
field.field_descriptor.variant = messages.Variant.ENUM
for submessage_type in message_type.message_types:
self._FixupMessage(submessage_type)

View File

@@ -0,0 +1,488 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Service registry for apitools."""
import collections
import logging
import re
import textwrap
from apitools.base.py import base_api
from apitools.gen import util
# We're a code generator. I don't care.
# pylint:disable=too-many-statements
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')
class ServiceRegistry(object):
"""Registry for service types."""
def __init__(self, client_info, message_registry,
names, root_package, base_files_package,
unelidable_request_methods):
self.__client_info = client_info
self.__package = client_info.package
self.__names = names
self.__service_method_info_map = collections.OrderedDict()
self.__message_registry = message_registry
self.__root_package = root_package
self.__base_files_package = base_files_package
self.__unelidable_request_methods = unelidable_request_methods
self.__all_scopes = set(self.__client_info.scopes)
def Validate(self):
self.__message_registry.Validate()
@property
def scopes(self):
return sorted(list(self.__all_scopes))
def __GetServiceClassName(self, service_name):
return self.__names.ClassName(
'%sService' % self.__names.ClassName(service_name))
def __PrintDocstring(self, printer, method_info, method_name, name):
"""Print a docstring for a service method."""
if method_info.description:
description = util.CleanDescription(method_info.description)
first_line, newline, remaining = method_info.description.partition(
'\n')
if not first_line.endswith('.'):
first_line = '%s.' % first_line
description = '%s%s%s' % (first_line, newline, remaining)
else:
description = '%s method for the %s service.' % (method_name, name)
with printer.CommentContext():
printer('r"""%s' % description)
printer()
printer('Args:')
printer(' request: (%s) input message', method_info.request_type_name)
printer(' global_params: (StandardQueryParameters, default: None) '
'global arguments')
if method_info.upload_config:
printer(' upload: (Upload, default: None) If present, upload')
printer(' this stream with the request.')
if method_info.supports_download:
printer(
' download: (Download, default: None) If present, download')
printer(' data from the request via this stream.')
printer('Returns:')
printer(' (%s) The response message.', method_info.response_type_name)
printer('"""')
def __WriteSingleService(
self, printer, name, method_info_map, client_class_name):
printer()
class_name = self.__GetServiceClassName(name)
printer('class %s(base_api.BaseApiService):', class_name)
with printer.Indent():
printer('"""Service class for the %s resource."""', name)
printer()
printer('_NAME = %s', repr(name))
# Print the configs for the methods first.
printer()
printer('def __init__(self, client):')
with printer.Indent():
printer('super(%s.%s, self).__init__(client)',
client_class_name, class_name)
printer('self._upload_configs = {')
with printer.Indent(indent=' '):
for method_name, method_info in method_info_map.items():
upload_config = method_info.upload_config
if upload_config is not None:
printer(
"'%s': base_api.ApiUploadInfo(", method_name)
with printer.Indent(indent=' '):
attrs = sorted(
x.name for x in upload_config.all_fields())
for attr in attrs:
printer('%s=%r,',
attr, getattr(upload_config, attr))
printer('),')
printer('}')
# Now write each method in turn.
for method_name, method_info in method_info_map.items():
printer()
params = ['self', 'request', 'global_params=None']
if method_info.upload_config:
params.append('upload=None')
if method_info.supports_download:
params.append('download=None')
printer('def %s(%s):', method_name, ', '.join(params))
with printer.Indent():
self.__PrintDocstring(
printer, method_info, method_name, name)
printer("config = self.GetMethodConfig('%s')", method_name)
upload_config = method_info.upload_config
if upload_config is not None:
printer("upload_config = self.GetUploadConfig('%s')",
method_name)
arg_lines = [
'config, request, global_params=global_params']
if method_info.upload_config:
arg_lines.append(
'upload=upload, upload_config=upload_config')
if method_info.supports_download:
arg_lines.append('download=download')
printer('return self._RunMethod(')
with printer.Indent(indent=' '):
for line in arg_lines[:-1]:
printer('%s,', line)
printer('%s)', arg_lines[-1])
printer()
printer('{0}.method_config = lambda: base_api.ApiMethodInfo('
.format(method_name))
with printer.Indent(indent=' '):
method_info = method_info_map[method_name]
attrs = sorted(
x.name for x in method_info.all_fields())
for attr in attrs:
if attr in ('upload_config', 'description'):
continue
value = getattr(method_info, attr)
if value is not None:
printer('%s=%r,', attr, value)
printer(')')
def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
"""Write a single service declaration to a proto file."""
printer()
printer('service %s {', self.__GetServiceClassName(name))
with printer.Indent():
for method_name, method_info in method_info_map.items():
for line in textwrap.wrap(method_info.description,
printer.CalculateWidth() - 3):
printer('// %s', line)
printer('rpc %s (%s) returns (%s);',
method_name,
method_info.request_type_name,
method_info.response_type_name)
printer('}')
def WriteProtoFile(self, printer):
"""Write the services in this registry to out as proto."""
self.Validate()
client_info = self.__client_info
printer('// Generated services for %s version %s.',
client_info.package, client_info.version)
printer()
printer('syntax = "proto2";')
printer('package %s;', self.__package)
printer('import "%s";', client_info.messages_proto_file_name)
printer()
for name, method_info_map in self.__service_method_info_map.items():
self.__WriteProtoServiceDeclaration(printer, name, method_info_map)
def WriteFile(self, printer):
"""Write the services in this registry to out."""
self.Validate()
client_info = self.__client_info
printer('"""Generated client library for %s version %s."""',
client_info.package, client_info.version)
printer('# NOTE: This file is autogenerated and should not be edited '
'by hand.')
printer()
printer('from __future__ import absolute_import')
printer()
printer('from %s import base_api', self.__base_files_package)
if self.__root_package:
import_prefix = 'from {0} '.format(self.__root_package)
else:
import_prefix = ''
printer('%simport %s as messages', import_prefix,
client_info.messages_rule_name)
printer()
printer()
printer('class %s(base_api.BaseApiClient):',
client_info.client_class_name)
with printer.Indent():
printer(
'"""Generated client library for service %s version %s."""',
client_info.package, client_info.version)
printer()
printer('MESSAGES_MODULE = messages')
printer('BASE_URL = {0!r}'.format(client_info.base_url))
printer('MTLS_BASE_URL = {0!r}'.format(client_info.mtls_base_url))
printer()
printer('_PACKAGE = {0!r}'.format(client_info.package))
printer('_SCOPES = {0!r}'.format(
client_info.scopes or
['https://www.googleapis.com/auth/userinfo.email']))
printer('_VERSION = {0!r}'.format(client_info.version))
printer('_CLIENT_ID = {0!r}'.format(client_info.client_id))
printer('_CLIENT_SECRET = {0!r}'.format(client_info.client_secret))
printer('_USER_AGENT = {0!r}'.format(client_info.user_agent))
printer('_CLIENT_CLASS_NAME = {0!r}'.format(
client_info.client_class_name))
printer('_URL_VERSION = {0!r}'.format(client_info.url_version))
printer('_API_KEY = {0!r}'.format(client_info.api_key))
printer()
printer("def __init__(self, url='', credentials=None,")
with printer.Indent(indent=' '):
printer('get_credentials=True, http=None, model=None,')
printer('log_request=False, log_response=False,')
printer('credentials_args=None, default_global_params=None,')
printer('additional_http_headers=None, '
'response_encoding=None):')
with printer.Indent():
printer('"""Create a new %s handle."""', client_info.package)
printer('url = url or self.BASE_URL')
printer(
'super(%s, self).__init__(', client_info.client_class_name)
printer(' url, credentials=credentials,')
printer(' get_credentials=get_credentials, http=http, '
'model=model,')
printer(' log_request=log_request, '
'log_response=log_response,')
printer(' credentials_args=credentials_args,')
printer(' default_global_params=default_global_params,')
printer(' additional_http_headers=additional_http_headers,')
printer(' response_encoding=response_encoding)')
for name in self.__service_method_info_map.keys():
printer('self.%s = self.%s(self)',
name, self.__GetServiceClassName(name))
for name, method_info in self.__service_method_info_map.items():
self.__WriteSingleService(
printer, name, method_info, client_info.client_class_name)
def __RegisterService(self, service_name, method_info_map):
if service_name in self.__service_method_info_map:
raise ValueError(
'Attempt to re-register descriptor %s' % service_name)
self.__service_method_info_map[service_name] = method_info_map
def __CreateRequestType(self, method_description, body_type=None):
"""Create a request type for this method."""
schema = {}
schema['id'] = self.__names.ClassName('%sRequest' % (
self.__names.ClassName(method_description['id'], separator='.'),))
schema['type'] = 'object'
schema['properties'] = collections.OrderedDict()
if 'parameterOrder' not in method_description:
ordered_parameters = list(method_description.get('parameters', []))
else:
ordered_parameters = method_description['parameterOrder'][:]
for k in method_description['parameters']:
if k not in ordered_parameters:
ordered_parameters.append(k)
for parameter_name in ordered_parameters:
field = dict(method_description['parameters'][parameter_name])
if 'type' not in field:
raise ValueError('No type found in parameter %s' % field)
schema['properties'][parameter_name] = field
if body_type is not None:
body_field_name = self.__GetRequestField(
method_description, body_type)
if body_field_name in schema['properties']:
raise ValueError('Failed to normalize request resource name')
if 'description' not in body_type:
body_type['description'] = (
'A %s resource to be passed as the request body.' % (
self.__GetRequestType(body_type),))
schema['properties'][body_field_name] = body_type
self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
return schema['id']
def __CreateVoidResponseType(self, method_description):
"""Create an empty response type."""
schema = {}
method_name = self.__names.ClassName(
method_description['id'], separator='.')
schema['id'] = self.__names.ClassName('%sResponse' % method_name)
schema['type'] = 'object'
schema['description'] = 'An empty %s response.' % method_name
self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
return schema['id']
def __NeedRequestType(self, method_description, request_type):
"""Determine if this method needs a new request type created."""
if not request_type:
return True
method_id = method_description.get('id', '')
if method_id in self.__unelidable_request_methods:
return True
message = self.__message_registry.LookupDescriptorOrDie(request_type)
if message is None:
return True
field_names = [x.name for x in message.fields]
parameters = method_description.get('parameters', {})
for param_name, param_info in parameters.items():
if (param_info.get('location') != 'path' or
self.__names.CleanName(param_name) not in field_names):
break
else:
return False
return True
def __MaxSizeToInt(self, max_size):
"""Convert max_size to an int."""
size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size)
if size_groups is None:
raise ValueError('Could not parse maxSize')
size, unit = size_groups.group('size', 'unit')
shift = 0
if unit is not None:
unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
shift = unit_dict.get(unit.upper())
if shift is None:
raise ValueError('Unknown unit %s' % unit)
return int(size) * (1 << shift)
def __ComputeUploadConfig(self, media_upload_config, method_id):
"""Fill out the upload config for this method."""
config = base_api.ApiUploadInfo()
if 'maxSize' in media_upload_config:
config.max_size = self.__MaxSizeToInt(
media_upload_config['maxSize'])
if 'accept' not in media_upload_config:
logging.warning(
'No accept types found for upload configuration in '
'method %s, using */*', method_id)
config.accept.extend([
str(a) for a in media_upload_config.get('accept', '*/*')])
for accept_pattern in config.accept:
if not _MIME_PATTERN_RE.match(accept_pattern):
logging.warning('Unexpected MIME type: %s', accept_pattern)
protocols = media_upload_config.get('protocols', {})
for protocol in ('simple', 'resumable'):
media = protocols.get(protocol, {})
for attr in ('multipart', 'path'):
if attr in media:
setattr(config, '%s_%s' % (protocol, attr), media[attr])
return config
def __ComputeMethodInfo(self, method_description, request, response,
request_field):
"""Compute the base_api.ApiMethodInfo for this method."""
relative_path = self.__names.NormalizeRelativePath(
''.join((self.__client_info.base_path,
method_description['path'])))
method_id = method_description['id']
ordered_params = []
for param_name in method_description.get('parameterOrder', []):
param_info = method_description['parameters'][param_name]
if param_info.get('required', False):
ordered_params.append(param_name)
method_info = base_api.ApiMethodInfo(
relative_path=relative_path,
method_id=method_id,
http_method=method_description['httpMethod'],
description=util.CleanDescription(
method_description.get('description', '')),
query_params=[],
path_params=[],
ordered_params=ordered_params,
request_type_name=self.__names.ClassName(request),
response_type_name=self.__names.ClassName(response),
request_field=request_field,
)
flat_path = method_description.get('flatPath', None)
if flat_path is not None:
flat_path = self.__names.NormalizeRelativePath(
self.__client_info.base_path + flat_path)
if flat_path != relative_path:
method_info.flat_path = flat_path
if method_description.get('supportsMediaUpload', False):
method_info.upload_config = self.__ComputeUploadConfig(
method_description.get('mediaUpload'), method_id)
method_info.supports_download = method_description.get(
'supportsMediaDownload', False)
if method_description.get('apiVersion'):
method_info.api_version_param = method_description.get('apiVersion')
self.__all_scopes.update(method_description.get('scopes', ()))
for param, desc in method_description.get('parameters', {}).items():
param = self.__names.CleanName(param)
location = desc['location']
if location == 'query':
method_info.query_params.append(param)
elif location == 'path':
method_info.path_params.append(param)
else:
raise ValueError(
'Unknown parameter location %s for parameter %s' % (
location, param))
method_info.path_params.sort()
method_info.query_params.sort()
return method_info
def __BodyFieldName(self, body_type):
if body_type is None:
return ''
return self.__names.FieldName(body_type['$ref'])
def __GetRequestType(self, body_type):
return self.__names.ClassName(body_type.get('$ref'))
def __GetRequestField(self, method_description, body_type):
"""Determine the request field for this method."""
body_field_name = self.__BodyFieldName(body_type)
if body_field_name in method_description.get('parameters', {}):
body_field_name = self.__names.FieldName(
'%s_resource' % body_field_name)
# It's exceedingly unlikely that we'd get two name collisions, which
# means it's bound to happen at some point.
while body_field_name in method_description.get('parameters', {}):
body_field_name = self.__names.FieldName(
'%s_body' % body_field_name)
return body_field_name
def AddServiceFromResource(self, service_name, methods):
"""Add a new service named service_name with the given methods."""
service_name = self.__names.CleanName(service_name)
method_descriptions = methods.get('methods', {})
method_info_map = collections.OrderedDict()
items = sorted(method_descriptions.items())
for method_name, method_description in items:
method_name = self.__names.MethodName(method_name)
# NOTE: According to the discovery document, if the request or
# response is present, it will simply contain a `$ref`.
body_type = method_description.get('request')
if body_type is None:
request_type = None
else:
request_type = self.__GetRequestType(body_type)
if self.__NeedRequestType(method_description, request_type):
request = self.__CreateRequestType(
method_description, body_type=body_type)
request_field = self.__GetRequestField(
method_description, body_type)
else:
request = request_type
request_field = base_api.REQUEST_IS_BODY
if 'response' in method_description:
response = method_description['response']['$ref']
else:
response = self.__CreateVoidResponseType(method_description)
method_info_map[method_name] = self.__ComputeMethodInfo(
method_description, request, response, request_field)
nested_services = methods.get('resources', {})
services = sorted(nested_services.items())
for subservice_name, submethods in services:
new_service_name = '%s_%s' % (service_name, subservice_name)
self.AddServiceFromResource(new_service_name, submethods)
self.__RegisterService(service_name, method_info_map)

View File

@@ -0,0 +1,55 @@
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various utilities used in tests."""
import contextlib
import os
import shutil
import sys
import tempfile
import unittest
import six
SkipOnWindows = unittest.skipIf(
os.name == 'nt', 'Does not run on windows')
@contextlib.contextmanager
def TempDir(change_to=False):
if change_to:
original_dir = os.getcwd()
path = tempfile.mkdtemp()
try:
if change_to:
os.chdir(path)
yield path
finally:
if change_to:
os.chdir(original_dir)
shutil.rmtree(path)
@contextlib.contextmanager
def CaptureOutput():
new_stdout, new_stderr = six.StringIO(), six.StringIO()
old_stdout, old_stderr = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_stdout, new_stderr
yield new_stdout, new_stderr
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr

View File

@@ -0,0 +1,439 @@
#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assorted utilities shared between parts of apitools."""
from __future__ import print_function
from __future__ import unicode_literals
import collections
import contextlib
import gzip
import json
import keyword
import logging
import os
import re
import tempfile
import six
from six.moves import urllib_parse
import six.moves.urllib.error as urllib_error
import six.moves.urllib.request as urllib_request
class Error(Exception):
"""Base error for apitools generation."""
class CommunicationError(Error):
"""Error in network communication."""
def _SortLengthFirstKey(a):
return -len(a), a
class Names(object):
"""Utility class for cleaning and normalizing names in a fixed style."""
DEFAULT_NAME_CONVENTION = 'LOWER_CAMEL'
NAME_CONVENTIONS = ['LOWER_CAMEL', 'LOWER_WITH_UNDER', 'NONE']
def __init__(self, strip_prefixes,
name_convention=None,
capitalize_enums=False):
self.__strip_prefixes = sorted(strip_prefixes, key=_SortLengthFirstKey)
self.__name_convention = (
name_convention or self.DEFAULT_NAME_CONVENTION)
self.__capitalize_enums = capitalize_enums
@staticmethod
def __FromCamel(name, separator='_'):
name = re.sub(r'([a-z0-9])([A-Z])', r'\1%s\2' % separator, name)
return name.lower()
@staticmethod
def __ToCamel(name, separator='_'):
# TODO(user): Consider what to do about leading or trailing
# underscores (such as `_refValue` in discovery).
return ''.join(s[0:1].upper() + s[1:] for s in name.split(separator))
@staticmethod
def __ToLowerCamel(name, separator='_'):
name = Names.__ToCamel(name, separator=separator)
return name[0].lower() + name[1:]
def __StripName(self, name):
"""Strip strip_prefix entries from name."""
if not name:
return name
for prefix in self.__strip_prefixes:
if name.startswith(prefix):
return name[len(prefix):]
return name
@staticmethod
def CleanName(name):
"""Perform generic name cleaning."""
name = re.sub('[^_A-Za-z0-9]', '_', name)
if name[0].isdigit():
name = '_%s' % name
while keyword.iskeyword(name) or name == 'exec':
name = '%s_' % name
# If we end up with __ as a prefix, we'll run afoul of python
# field renaming, so we manually correct for it.
if name.startswith('__'):
name = 'f%s' % name
return name
@staticmethod
def NormalizeRelativePath(path):
"""Normalize camelCase entries in path."""
path_components = path.split('/')
normalized_components = []
for component in path_components:
if re.match(r'{[A-Za-z0-9_]+}$', component):
normalized_components.append(
'{%s}' % Names.CleanName(component[1:-1]))
else:
normalized_components.append(component)
return '/'.join(normalized_components)
def NormalizeEnumName(self, enum_name):
if self.__capitalize_enums:
enum_name = enum_name.upper()
return self.CleanName(enum_name)
def ClassName(self, name, separator='_'):
"""Generate a valid class name from name."""
# TODO(user): Get rid of this case here and in MethodName.
if name is None:
return name
# TODO(user): This is a hack to handle the case of specific
# protorpc class names; clean this up.
if name.startswith(('protorpc.', 'message_types.',
'apitools.base.protorpclite.',
'apitools.base.protorpclite.message_types.')):
return name
name = self.__StripName(name)
name = self.__ToCamel(name, separator=separator)
return self.CleanName(name)
def MethodName(self, name, separator='_'):
"""Generate a valid method name from name."""
if name is None:
return None
name = Names.__ToCamel(name, separator=separator)
return Names.CleanName(name)
def FieldName(self, name):
"""Generate a valid field name from name."""
# TODO(user): We shouldn't need to strip this name, but some
# of the service names here are excessive. Fix the API and then
# remove this.
name = self.__StripName(name)
if self.__name_convention == 'LOWER_CAMEL':
name = Names.__ToLowerCamel(name)
elif self.__name_convention == 'LOWER_WITH_UNDER':
name = Names.__FromCamel(name)
return Names.CleanName(name)
@contextlib.contextmanager
def Chdir(dirname, create=True):
if not os.path.exists(dirname):
if not create:
raise OSError('Cannot find directory %s' % dirname)
else:
os.mkdir(dirname)
previous_directory = os.getcwd()
try:
os.chdir(dirname)
yield
finally:
os.chdir(previous_directory)
def NormalizeVersion(version):
# Currently, '.' is the only character that might cause us trouble.
return version.replace('.', '_')
def _ComputePaths(package, version, root_url, service_path):
"""Compute the base url and base path.
Attributes:
package: name field of the discovery, i.e. 'storage' for storage service.
version: version of the service, i.e. 'v1'.
root_url: root url of the service, i.e. 'https://www.googleapis.com/'.
service_path: path of the service under the rool url, i.e. 'storage/v1/'.
Returns:
base url: string, base url of the service,
'https://www.googleapis.com/storage/v1/' for the storage service.
base path: string, common prefix of service endpoints after the base url.
"""
full_path = urllib_parse.urljoin(root_url, service_path)
api_path_component = '/'.join((package, version, ''))
if api_path_component not in full_path:
return full_path, ''
prefix, _, suffix = full_path.rpartition(api_path_component)
return prefix + api_path_component, suffix
class ClientInfo(collections.namedtuple('ClientInfo', (
'package', 'scopes', 'version', 'client_id', 'client_secret',
'user_agent', 'client_class_name', 'url_version', 'api_key',
'base_url', 'base_path', 'mtls_base_url'))):
"""Container for client-related info and names."""
@classmethod
def Create(cls, discovery_doc,
scope_ls, client_id, client_secret, user_agent, names, api_key,
version_identifier):
"""Create a new ClientInfo object from a discovery document."""
scopes = set(
discovery_doc.get('auth', {}).get('oauth2', {}).get('scopes', {}))
scopes.update(scope_ls)
package = discovery_doc['name']
version = (
version_identifier or NormalizeVersion(discovery_doc['version']))
url_version = discovery_doc['version']
base_url, base_path = _ComputePaths(package, url_version,
discovery_doc['rootUrl'],
discovery_doc['servicePath'])
mtls_root_url = discovery_doc.get('mtlsRootUrl', '')
mtls_base_url = ''
if mtls_root_url:
mtls_base_url, _ = _ComputePaths(package, url_version,
mtls_root_url,
discovery_doc['servicePath'])
client_info = {
'package': package,
'version': version,
'url_version': url_version,
'scopes': sorted(list(scopes)),
'client_id': client_id,
'client_secret': client_secret,
'user_agent': user_agent,
'api_key': api_key,
'base_url': base_url,
'base_path': base_path,
'mtls_base_url': mtls_base_url,
}
client_class_name = '%s%s' % (
names.ClassName(client_info['package']),
names.ClassName(client_info['version']))
client_info['client_class_name'] = client_class_name
return cls(**client_info)
@property
def default_directory(self):
return self.package
@property
def client_rule_name(self):
return '%s_%s_client' % (self.package, self.version)
@property
def client_file_name(self):
return '%s.py' % self.client_rule_name
@property
def messages_rule_name(self):
return '%s_%s_messages' % (self.package, self.version)
@property
def services_rule_name(self):
return '%s_%s_services' % (self.package, self.version)
@property
def messages_file_name(self):
return '%s.py' % self.messages_rule_name
@property
def messages_proto_file_name(self):
return '%s.proto' % self.messages_rule_name
@property
def services_proto_file_name(self):
return '%s.proto' % self.services_rule_name
def ReplaceHomoglyphs(s):
"""Returns s with unicode homoglyphs replaced by ascii equivalents."""
homoglyphs = {
'\xa0': ' ', # &nbsp; ?
'\u00e3': '', # TODO(user) drop after .proto spurious char elided
'\u00a0': ' ', # &nbsp; ?
'\u00a9': '(C)', # COPYRIGHT SIGN (would you believe "asciiglyph"?)
'\u00ae': '(R)', # REGISTERED SIGN (would you believe "asciiglyph"?)
'\u2014': '-', # EM DASH
'\u2018': "'", # LEFT SINGLE QUOTATION MARK
'\u2019': "'", # RIGHT SINGLE QUOTATION MARK
'\u201c': '"', # LEFT DOUBLE QUOTATION MARK
'\u201d': '"', # RIGHT DOUBLE QUOTATION MARK
'\u2026': '...', # HORIZONTAL ELLIPSIS
'\u2e3a': '-', # TWO-EM DASH
}
def _ReplaceOne(c):
"""Returns the homoglyph or escaped replacement for c."""
equiv = homoglyphs.get(c)
if equiv is not None:
return equiv
try:
c.encode('ascii')
return c
except UnicodeError:
pass
try:
return c.encode('unicode-escape').decode('ascii')
except UnicodeError:
return '?'
return ''.join([_ReplaceOne(c) for c in s])
def CleanDescription(description):
"""Return a version of description safe for printing in a docstring."""
if not isinstance(description, six.string_types):
return description
if six.PY3:
# https://docs.python.org/3/reference/lexical_analysis.html#index-18
description = description.replace('\\N', '\\\\N')
description = description.replace('\\u', '\\\\u')
description = description.replace('\\U', '\\\\U')
description = ReplaceHomoglyphs(description)
return description.replace('"""', '" " "')
class SimplePrettyPrinter(object):
"""Simple pretty-printer that supports an indent contextmanager."""
def __init__(self, out):
self.__out = out
self.__indent = ''
self.__skip = False
self.__comment_context = False
@property
def indent(self):
return self.__indent
def CalculateWidth(self, max_width=78):
return max_width - len(self.indent)
@contextlib.contextmanager
def Indent(self, indent=' '):
previous_indent = self.__indent
self.__indent = '%s%s' % (previous_indent, indent)
yield
self.__indent = previous_indent
@contextlib.contextmanager
def CommentContext(self):
"""Print without any argument formatting."""
old_context = self.__comment_context
self.__comment_context = True
yield
self.__comment_context = old_context
def __call__(self, *args):
if self.__comment_context and args[1:]:
raise Error('Cannot do string interpolation in comment context')
if args and args[0]:
if not self.__comment_context:
line = (args[0] % args[1:]).rstrip()
else:
line = args[0].rstrip()
line = ReplaceHomoglyphs(line)
try:
print('%s%s' % (self.__indent, line), file=self.__out)
except UnicodeEncodeError:
line = line.encode('ascii', 'backslashreplace').decode('ascii')
print('%s%s' % (self.__indent, line), file=self.__out)
else:
print('', file=self.__out)
def _NormalizeDiscoveryUrls(discovery_url):
"""Expands a few abbreviations into full discovery urls."""
if discovery_url.startswith('http'):
return [discovery_url]
elif '.' not in discovery_url:
raise ValueError('Unrecognized value "%s" for discovery url')
api_name, _, api_version = discovery_url.partition('.')
return [
'https://www.googleapis.com/discovery/v1/apis/%s/%s/rest' % (
api_name, api_version),
'https://%s.googleapis.com/$discovery/rest?version=%s' % (
api_name, api_version),
]
def _Gunzip(gzipped_content):
"""Returns gunzipped content from gzipped contents."""
f = tempfile.NamedTemporaryFile(suffix='gz', mode='w+b', delete=False)
try:
f.write(gzipped_content)
f.close() # force file synchronization
with gzip.open(f.name, 'rb') as h:
decompressed_content = h.read()
return decompressed_content
finally:
os.unlink(f.name)
def _GetURLContent(url):
"""Download and return the content of URL."""
response = urllib_request.urlopen(url)
encoding = response.info().get('Content-Encoding')
if encoding == 'gzip':
content = _Gunzip(response.read())
else:
content = response.read()
return content
def FetchDiscoveryDoc(discovery_url, retries=5):
"""Fetch the discovery document at the given url."""
discovery_urls = _NormalizeDiscoveryUrls(discovery_url)
discovery_doc = None
last_exception = None
for url in discovery_urls:
for _ in range(retries):
try:
content = _GetURLContent(url)
if isinstance(content, bytes):
content = content.decode('utf8')
discovery_doc = json.loads(content)
if discovery_doc:
return discovery_doc
except (urllib_error.HTTPError, urllib_error.URLError) as e:
logging.info(
'Attempting to fetch discovery doc again after "%s"', e)
last_exception = e
if discovery_doc is None:
raise CommunicationError(
'Could not find discovery doc at any of %s: %s' % (
discovery_urls, last_exception))