622 lines
21 KiB
Python
622 lines
21 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2016 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Utilities for the gcloud meta apis surface."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
from apitools.base.protorpclite import messages
|
|
from apitools.base.py import exceptions as apitools_exc
|
|
from apitools.base.py import list_pager
|
|
|
|
from googlecloudsdk.api_lib.util import apis
|
|
from googlecloudsdk.api_lib.util import apis_internal
|
|
from googlecloudsdk.api_lib.util import resource
|
|
from googlecloudsdk.command_lib.util.apis import arg_utils
|
|
from googlecloudsdk.core import exceptions
|
|
from googlecloudsdk.core import log
|
|
from googlecloudsdk.generated_clients.apis import apis_map
|
|
import six
|
|
|
|
NAME_SEPARATOR = '.'
|
|
|
|
|
|
class Error(exceptions.Error):
|
|
pass
|
|
|
|
|
|
class UnknownAPIError(Error):
|
|
|
|
def __init__(self, api_name):
|
|
super(UnknownAPIError, self).__init__(
|
|
'API [{api}] does not exist or is not registered.'
|
|
.format(api=api_name)
|
|
)
|
|
|
|
|
|
class UnknownAPIVersionError(Error):
|
|
|
|
def __init__(self, api_name, version):
|
|
super(UnknownAPIVersionError, self).__init__(
|
|
'Version [{version}] does not exist for API [{api}].'
|
|
.format(version=version, api=api_name)
|
|
)
|
|
|
|
|
|
class NoDefaultVersionError(Error):
|
|
|
|
def __init__(self, api_name):
|
|
super(NoDefaultVersionError, self).__init__(
|
|
'API [{api}] does not have a default version. You must specify which '
|
|
'version to use.'.format(api=api_name)
|
|
)
|
|
|
|
|
|
class UnknownCollectionError(Error):
|
|
|
|
def __init__(self, api_name, api_version, collection):
|
|
super(UnknownCollectionError, self).__init__(
|
|
'Collection [{collection}] does not exist for [{api}] [{version}].'
|
|
.format(collection=collection, api=api_name, version=api_version)
|
|
)
|
|
|
|
|
|
class UnknownMethodError(Error):
|
|
|
|
def __init__(self, method, collection):
|
|
super(UnknownMethodError, self).__init__(
|
|
'Method [{method}] does not exist for collection [{collection}].'
|
|
.format(method=method, collection=collection)
|
|
)
|
|
|
|
|
|
class APICallError(Error):
|
|
pass
|
|
|
|
|
|
class API(object):
|
|
"""A data holder for returning API data for display."""
|
|
|
|
def __init__(self, name, version, is_default, client, base_url):
|
|
self.name = name
|
|
self.version = version
|
|
self.is_default = is_default
|
|
self._client = client
|
|
self.base_url = base_url
|
|
|
|
def GetMessagesModule(self):
|
|
return self._client.MESSAGES_MODULE
|
|
|
|
|
|
class APICollection(object):
|
|
"""A data holder for collection information for an API."""
|
|
|
|
def __init__(self, collection_info):
|
|
self.api_name = collection_info.api_name
|
|
self.api_version = collection_info.api_version
|
|
self.base_url = collection_info.base_url
|
|
self.docs_url = collection_info.docs_url
|
|
self.name = collection_info.name
|
|
self.full_name = collection_info.full_name
|
|
self.detailed_path = collection_info.GetPath('')
|
|
self.detailed_params = collection_info.GetParams('')
|
|
self.path = collection_info.path
|
|
self.params = collection_info.params
|
|
self.enable_uri_parsing = collection_info.enable_uri_parsing
|
|
|
|
|
|
class APIMethod(object):
|
|
"""A data holder for method information for an API collection."""
|
|
|
|
def __init__(self, service, name, api_collection, method_config,
|
|
disable_pagination=False):
|
|
self._service = service
|
|
self._method_name = name
|
|
self._disable_pagination = disable_pagination
|
|
|
|
self.collection = api_collection
|
|
|
|
self.name = method_config.method_id
|
|
dotted_path = self.collection.full_name + NAME_SEPARATOR
|
|
if self.name.startswith(dotted_path):
|
|
self.name = self.name[len(dotted_path):]
|
|
|
|
self.path = _RemoveVersionPrefix(
|
|
self.collection.api_version, method_config.relative_path)
|
|
self.params = method_config.ordered_params
|
|
if method_config.flat_path:
|
|
self.detailed_path = _RemoveVersionPrefix(
|
|
self.collection.api_version, method_config.flat_path)
|
|
self.detailed_params = resource.GetParamsFromPath(method_config.flat_path)
|
|
else:
|
|
self.detailed_path = self.path
|
|
self.detailed_params = self.params
|
|
|
|
self.http_method = method_config.http_method
|
|
self.request_field = method_config.request_field
|
|
self.request_type = method_config.request_type_name
|
|
self.response_type = method_config.response_type_name
|
|
|
|
self._request_collection = self._RequestCollection()
|
|
# Keep track of method query parameters
|
|
self.query_params = method_config.query_params
|
|
|
|
@property
|
|
def resource_argument_collection(self):
|
|
"""Gets the collection that should be used to represent the resource.
|
|
|
|
Most of the time this is the same as request_collection because all methods
|
|
in a collection operate on the same resource and so the API method takes
|
|
the same parameters that make up the resource.
|
|
|
|
One exception is List methods where the API parameters are for the parent
|
|
collection. Because people don't specify the resource directly for list
|
|
commands this also returns the parent collection for parsing purposes.
|
|
|
|
The other exception is Create methods. They reference the parent collection
|
|
list Like, but the difference is that we *do* want to specify the actual
|
|
resource on the command line, so the original resource collection is
|
|
returned here instead of the one that matches the API methods. When
|
|
generating the request, you must figure out how to generate the message
|
|
correctly from the parsed resource (as you cannot simply pass the reference
|
|
to the API).
|
|
|
|
Returns:
|
|
APICollection: The collection.
|
|
"""
|
|
if self.IsList():
|
|
return self._request_collection
|
|
return self.collection
|
|
|
|
@property
|
|
def request_collection(self):
|
|
"""Gets the API collection that matches the parameters of the API method."""
|
|
return self._request_collection
|
|
|
|
def GetRequestType(self):
|
|
"""Gets the apitools request class for this method."""
|
|
return self._service.GetRequestType(self._method_name)
|
|
|
|
def GetResponseType(self):
|
|
"""Gets the apitools response class for this method."""
|
|
return self._service.GetResponseType(self._method_name)
|
|
|
|
def GetEffectiveResponseType(self):
|
|
"""Gets the effective apitools response class for this method.
|
|
|
|
This will be different from GetResponseType for List methods if we are
|
|
extracting the list of response items from the overall response. This will
|
|
always match the type of response that Call() returns.
|
|
|
|
Returns:
|
|
The apitools Message object.
|
|
"""
|
|
if (item_field := self.ListItemField()) and self.HasTokenizedRequest():
|
|
return arg_utils.GetFieldFromMessage(
|
|
self.GetResponseType(), item_field).type
|
|
else:
|
|
return self.GetResponseType()
|
|
|
|
def GetMessageByName(self, name):
|
|
"""Gets a arbitrary apitools message class by name.
|
|
|
|
This method can be used to get arbitrary apitools messages from the
|
|
underlying service. Examples:
|
|
|
|
policy_type = method.GetMessageByName('Policy')
|
|
status_type = method.GetMessageByName('Status')
|
|
|
|
Args:
|
|
name: str, the name of the message to return.
|
|
Returns:
|
|
The apitools Message object.
|
|
"""
|
|
msgs = self._service.client.MESSAGES_MODULE
|
|
return getattr(msgs, name, None)
|
|
|
|
def IsList(self):
|
|
"""Determines whether this is a List method."""
|
|
return self._method_name == 'List'
|
|
|
|
def HasTokenizedRequest(self):
|
|
"""Determines whether this is a method that supports paging."""
|
|
return (not self._disable_pagination
|
|
and 'pageToken' in self._RequestFieldNames()
|
|
and 'nextPageToken' in self._ResponseFieldNames())
|
|
|
|
def BatchPageSizeField(self):
|
|
"""Gets the name of the page size field in the request if it exists."""
|
|
request_fields = self._RequestFieldNames()
|
|
if 'maxResults' in request_fields:
|
|
return 'maxResults'
|
|
if 'pageSize' in request_fields:
|
|
return 'pageSize'
|
|
return None
|
|
|
|
def ListItemField(self):
|
|
"""Gets the name of the field that contains the items in paginated response.
|
|
|
|
This will return None if the method is not a paginated or if a single
|
|
repeated field of items could not be found in the response type.
|
|
|
|
Returns:
|
|
str, The name of the field or None.
|
|
"""
|
|
if self._disable_pagination:
|
|
return None
|
|
|
|
response = self.GetResponseType()
|
|
found = [f for f in response.all_fields()
|
|
if f.variant == messages.Variant.MESSAGE and f.repeated]
|
|
if len(found) == 1:
|
|
return found[0].name
|
|
else:
|
|
return None
|
|
|
|
def _RequestCollection(self):
|
|
"""Gets the collection that matches the API parameters of this method.
|
|
|
|
Methods apply to elements of a collection. The resource argument is always
|
|
of the type of that collection. List is an exception where you are listing
|
|
items of that collection so the argument to be provided is that of the
|
|
parent collection. This method returns the collection that should be used
|
|
to parse the resource for this specific method.
|
|
|
|
Returns:
|
|
APICollection, The collection to use or None if no parent collection could
|
|
be found.
|
|
"""
|
|
if self.detailed_params == self.collection.detailed_params:
|
|
return self.collection
|
|
collections = GetAPICollections(
|
|
self.collection.api_name, self.collection.api_version)
|
|
for c in collections:
|
|
if (self.detailed_params == c.detailed_params
|
|
and c.detailed_path in self.detailed_path):
|
|
return c
|
|
# Fallback to collection that matches params only
|
|
for c in collections:
|
|
if self.detailed_params == c.detailed_params:
|
|
return c
|
|
return None
|
|
|
|
def _RequestFieldNames(self):
|
|
"""Gets the fields that are actually a part of the request message.
|
|
|
|
For APIs that use atomic names, this will only be the single name parameter
|
|
(and any other message fields) but not the detailed parameters.
|
|
|
|
Returns:
|
|
[str], The field names.
|
|
"""
|
|
return [f.name for f in self.GetRequestType().all_fields()]
|
|
|
|
def _ResponseFieldNames(self):
|
|
"""Gets the fields that are actually a part of the response message.
|
|
|
|
Returns:
|
|
[str], The field names.
|
|
"""
|
|
return [f.name for f in self.GetResponseType().all_fields()]
|
|
|
|
def Call(self, request, client=None, global_params=None, raw=False,
|
|
limit=None, page_size=None):
|
|
"""Executes this method with the given arguments.
|
|
|
|
Args:
|
|
request: The apitools request object to send.
|
|
client: base_api.BaseApiClient, An API client to use for making requests.
|
|
global_params: {str: str}, A dictionary of global parameters to send with
|
|
the request.
|
|
raw: bool, True to not do any processing of the response, False to maybe
|
|
do processing for List results.
|
|
limit: int, The max number of items to return if this is a List method.
|
|
page_size: int, The max number of items to return in a page if this API
|
|
supports paging.
|
|
|
|
Returns:
|
|
The response from the API.
|
|
"""
|
|
if client is None:
|
|
client = apis.GetClientInstance(
|
|
self.collection.api_name, self.collection.api_version)
|
|
service = _GetService(client, self.collection.name)
|
|
request_func = self._GetRequestFunc(
|
|
service, request, raw=raw, limit=limit, page_size=page_size)
|
|
try:
|
|
return request_func(global_params=global_params)
|
|
except apitools_exc.InvalidUserInputError as e:
|
|
log.debug('', exc_info=True)
|
|
raise APICallError(str(e))
|
|
|
|
def _GetRequestFunc(self, service, request, raw=False,
|
|
limit=None, page_size=None):
|
|
"""Gets a request function to call and process the results.
|
|
|
|
If this is a method with paginated response, it may flatten the response
|
|
depending on if the List Pager can be used.
|
|
|
|
Args:
|
|
service: The apitools service that will be making the request.
|
|
request: The apitools request object to send.
|
|
raw: bool, True to not do any processing of the response, False to maybe
|
|
do processing for List results.
|
|
limit: int, The max number of items to return if this is a List method.
|
|
page_size: int, The max number of items to return in a page if this API
|
|
supports paging.
|
|
|
|
Returns:
|
|
A function to make the request.
|
|
"""
|
|
if raw or self._disable_pagination:
|
|
return self._NormalRequest(service, request)
|
|
|
|
item_field = self.ListItemField()
|
|
if not item_field:
|
|
if self.IsList():
|
|
log.debug(
|
|
'Unable to flatten list response, raw results being returned.')
|
|
return self._NormalRequest(service, request)
|
|
|
|
if not self.HasTokenizedRequest():
|
|
# API doesn't do paging.
|
|
if self.IsList():
|
|
return self._FlatNonPagedRequest(service, request, item_field)
|
|
else:
|
|
return self._NormalRequest(service, request)
|
|
|
|
def RequestFunc(global_params=None):
|
|
return list_pager.YieldFromList(
|
|
service, request, method=self._method_name, field=item_field,
|
|
global_params=global_params, limit=limit,
|
|
current_token_attribute='pageToken',
|
|
next_token_attribute='nextPageToken',
|
|
batch_size_attribute=self.BatchPageSizeField(),
|
|
batch_size=page_size)
|
|
return RequestFunc
|
|
|
|
def _NormalRequest(self, service, request):
|
|
"""Generates a basic request function for the method.
|
|
|
|
Args:
|
|
service: The apitools service that will be making the request.
|
|
request: The apitools request object to send.
|
|
|
|
Returns:
|
|
A function to make the request.
|
|
"""
|
|
def RequestFunc(global_params=None):
|
|
method = getattr(service, self._method_name)
|
|
return method(request, global_params=global_params)
|
|
return RequestFunc
|
|
|
|
def _FlatNonPagedRequest(self, service, request, item_field):
|
|
"""Generates a request function for the method that extracts an item list.
|
|
|
|
List responses usually have a single repeated field that represents the
|
|
actual items being listed. This request function returns only those items
|
|
not the entire response.
|
|
|
|
Args:
|
|
service: The apitools service that will be making the request.
|
|
request: The apitools request object to send.
|
|
item_field: str, The name of the field that the list of items can be found
|
|
in.
|
|
|
|
Returns:
|
|
A function to make the request.
|
|
"""
|
|
def RequestFunc(global_params=None):
|
|
response = self._NormalRequest(service, request)(
|
|
global_params=global_params)
|
|
return getattr(response, item_field)
|
|
return RequestFunc
|
|
|
|
|
|
def _RemoveVersionPrefix(api_version, path):
|
|
"""Trims the version number off the front of a URL path if present."""
|
|
if not path:
|
|
return None
|
|
if path.startswith(api_version):
|
|
return path[len(api_version) + 1:]
|
|
return path
|
|
|
|
|
|
def _ValidateAndGetDefaultVersion(api_name, api_version):
|
|
"""Validates the API exists and gets the default version if not given."""
|
|
# pylint:disable=protected-access
|
|
api_name, _ = apis_internal._GetApiNameAndAlias(api_name)
|
|
api_vers = apis_map.MAP.get(api_name, {})
|
|
if not api_vers:
|
|
# No versions, this API is not registered.
|
|
raise UnknownAPIError(api_name)
|
|
if api_version:
|
|
if api_version not in api_vers:
|
|
raise UnknownAPIVersionError(api_name, api_version)
|
|
return api_version
|
|
|
|
for version, api_def in six.iteritems(api_vers):
|
|
if api_def.default_version:
|
|
return version
|
|
raise NoDefaultVersionError(api_name)
|
|
|
|
|
|
def GetAPI(api_name, api_version=None):
|
|
"""Get a specific API definition.
|
|
|
|
Args:
|
|
api_name: str, The name of the API.
|
|
api_version: str, The version string of the API.
|
|
|
|
Returns:
|
|
API, The API definition.
|
|
"""
|
|
api_version = _ValidateAndGetDefaultVersion(api_name, api_version)
|
|
# pylint: disable=protected-access
|
|
api_def = apis_internal.GetApiDef(api_name, api_version)
|
|
if api_def.apitools:
|
|
api_client = apis_internal._GetClientClassFromDef(api_def)
|
|
else:
|
|
api_client = apis_internal._GetGapicClientClass(api_name, api_version)
|
|
|
|
if hasattr(api_client, 'BASE_URL'):
|
|
base_url = api_client.BASE_URL
|
|
else:
|
|
try:
|
|
base_url = apis_internal._GetResourceModule(
|
|
api_name, api_version
|
|
).BASE_URL
|
|
except ImportError:
|
|
base_url = 'https://{}.googleapis.com/{}'.format(api_name, api_version)
|
|
return API(
|
|
api_name, api_version, api_def.default_version, api_client, base_url
|
|
)
|
|
|
|
|
|
def GetAllAPIs():
|
|
"""Gets all registered APIs.
|
|
|
|
Returns:
|
|
[API], A list of API definitions.
|
|
"""
|
|
all_apis = []
|
|
for api_name, versions in six.iteritems(apis_map.MAP):
|
|
for api_version, _ in six.iteritems(versions):
|
|
all_apis.append(GetAPI(api_name, api_version))
|
|
return all_apis
|
|
|
|
|
|
def _SplitFullCollectionName(collection):
|
|
return tuple(collection.split(NAME_SEPARATOR, 1))
|
|
|
|
|
|
def GetAPICollections(api_name=None, api_version=None):
|
|
"""Gets the registered collections for the given API version.
|
|
|
|
Args:
|
|
api_name: str, The name of the API or None for all apis.
|
|
api_version: str, The version string of the API or None to use the default
|
|
version.
|
|
|
|
Returns:
|
|
[APICollection], A list of the registered collections.
|
|
"""
|
|
if api_name:
|
|
all_apis = {api_name: _ValidateAndGetDefaultVersion(api_name, api_version)}
|
|
else:
|
|
all_apis = {x.name: x.version for x in GetAllAPIs() if x.is_default}
|
|
|
|
collections = []
|
|
for n, v in six.iteritems(all_apis):
|
|
# pylint:disable=protected-access
|
|
collections.extend(
|
|
[APICollection(c) for c in apis_internal._GetApiCollections(n, v)])
|
|
return collections
|
|
|
|
|
|
def GetAPICollection(full_collection_name, api_version=None):
|
|
"""Gets the given collection for the given API version.
|
|
|
|
Args:
|
|
full_collection_name: str, The collection to get including the api name.
|
|
api_version: str, The version string of the API or None to use the default
|
|
for this API.
|
|
|
|
Returns:
|
|
APICollection, The requested API collection.
|
|
|
|
Raises:
|
|
UnknownCollectionError: If the collection does not exist for the given API
|
|
and version.
|
|
"""
|
|
api_name, collection = _SplitFullCollectionName(full_collection_name)
|
|
api_version = _ValidateAndGetDefaultVersion(api_name, api_version)
|
|
collections = GetAPICollections(api_name, api_version)
|
|
for c in collections:
|
|
if c.name == collection:
|
|
return c
|
|
raise UnknownCollectionError(api_name, api_version, collection)
|
|
|
|
|
|
def GetMethod(full_collection_name, method, api_version=None,
|
|
disable_pagination=False):
|
|
"""Gets the specification for the given API method.
|
|
|
|
Args:
|
|
full_collection_name: str, The collection including the api name.
|
|
method: str, The name of the method.
|
|
api_version: str, The version string of the API or None to use the default
|
|
for this API.
|
|
disable_pagination: bool, Boolean for whether pagination should be disabled
|
|
|
|
Returns:
|
|
APIMethod, The method specification.
|
|
|
|
Raises:
|
|
UnknownMethodError: If the method does not exist on the collection.
|
|
"""
|
|
methods = GetMethods(
|
|
full_collection_name, api_version=api_version,
|
|
disable_pagination=disable_pagination)
|
|
for m in methods:
|
|
if m.name == method:
|
|
return m
|
|
raise UnknownMethodError(method, full_collection_name)
|
|
|
|
|
|
def _GetService(client, collection_name):
|
|
return getattr(client, collection_name.replace(NAME_SEPARATOR, '_'), None)
|
|
|
|
|
|
def _GetApiClient(api_name, api_version):
|
|
"""Gets the repesctive api client for the api."""
|
|
api_def = apis_internal.GetApiDef(api_name, api_version)
|
|
if api_def.apitools:
|
|
client = apis.GetClientInstance(api_name, api_version, no_http=True)
|
|
else:
|
|
client = apis.GetGapicClientInstance(api_name, api_version)
|
|
return client
|
|
|
|
|
|
def GetMethods(
|
|
full_collection_name, api_version=None, disable_pagination=False):
|
|
"""Gets all the methods available on the given collection.
|
|
|
|
Args:
|
|
full_collection_name: str, The collection including the api name.
|
|
api_version: str, The version string of the API or None to use the default
|
|
for this API.
|
|
disable_pagination: bool, Boolean for whether pagination should be disabled
|
|
|
|
Returns:
|
|
[APIMethod], The method specifications.
|
|
"""
|
|
api_collection = GetAPICollection(full_collection_name,
|
|
api_version=api_version)
|
|
|
|
client = _GetApiClient(api_collection.api_name, api_collection.api_version)
|
|
service = _GetService(client, api_collection.name)
|
|
if not service:
|
|
# This is a synthetic collection that does not actually have a backing API.
|
|
return []
|
|
|
|
method_names = service.GetMethodsList()
|
|
method_configs = [(name, service.GetMethodConfig(name))
|
|
for name in method_names]
|
|
return [APIMethod(service, name, api_collection, config, disable_pagination)
|
|
for name, config in method_configs]
|