# -*- coding: utf-8 -*- # # Copyright 2016 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for the gcloud meta apis surface.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from apitools.base.protorpclite import messages from apitools.base.py import exceptions as apitools_exc from apitools.base.py import list_pager from googlecloudsdk.api_lib.util import apis from googlecloudsdk.api_lib.util import apis_internal from googlecloudsdk.api_lib.util import resource from googlecloudsdk.command_lib.util.apis import arg_utils from googlecloudsdk.core import exceptions from googlecloudsdk.core import log from googlecloudsdk.generated_clients.apis import apis_map import six NAME_SEPARATOR = '.' class Error(exceptions.Error): pass class UnknownAPIError(Error): def __init__(self, api_name): super(UnknownAPIError, self).__init__( 'API [{api}] does not exist or is not registered.' .format(api=api_name) ) class UnknownAPIVersionError(Error): def __init__(self, api_name, version): super(UnknownAPIVersionError, self).__init__( 'Version [{version}] does not exist for API [{api}].' .format(version=version, api=api_name) ) class NoDefaultVersionError(Error): def __init__(self, api_name): super(NoDefaultVersionError, self).__init__( 'API [{api}] does not have a default version. You must specify which ' 'version to use.'.format(api=api_name) ) class UnknownCollectionError(Error): def __init__(self, api_name, api_version, collection): super(UnknownCollectionError, self).__init__( 'Collection [{collection}] does not exist for [{api}] [{version}].' .format(collection=collection, api=api_name, version=api_version) ) class UnknownMethodError(Error): def __init__(self, method, collection): super(UnknownMethodError, self).__init__( 'Method [{method}] does not exist for collection [{collection}].' .format(method=method, collection=collection) ) class APICallError(Error): pass class API(object): """A data holder for returning API data for display.""" def __init__(self, name, version, is_default, client, base_url): self.name = name self.version = version self.is_default = is_default self._client = client self.base_url = base_url def GetMessagesModule(self): return self._client.MESSAGES_MODULE class APICollection(object): """A data holder for collection information for an API.""" def __init__(self, collection_info): self.api_name = collection_info.api_name self.api_version = collection_info.api_version self.base_url = collection_info.base_url self.docs_url = collection_info.docs_url self.name = collection_info.name self.full_name = collection_info.full_name self.detailed_path = collection_info.GetPath('') self.detailed_params = collection_info.GetParams('') self.path = collection_info.path self.params = collection_info.params self.enable_uri_parsing = collection_info.enable_uri_parsing class APIMethod(object): """A data holder for method information for an API collection.""" def __init__(self, service, name, api_collection, method_config, disable_pagination=False): self._service = service self._method_name = name self._disable_pagination = disable_pagination self.collection = api_collection self.name = method_config.method_id dotted_path = self.collection.full_name + NAME_SEPARATOR if self.name.startswith(dotted_path): self.name = self.name[len(dotted_path):] self.path = _RemoveVersionPrefix( self.collection.api_version, method_config.relative_path) self.params = method_config.ordered_params if method_config.flat_path: self.detailed_path = _RemoveVersionPrefix( self.collection.api_version, method_config.flat_path) self.detailed_params = resource.GetParamsFromPath(method_config.flat_path) else: self.detailed_path = self.path self.detailed_params = self.params self.http_method = method_config.http_method self.request_field = method_config.request_field self.request_type = method_config.request_type_name self.response_type = method_config.response_type_name self._request_collection = self._RequestCollection() # Keep track of method query parameters self.query_params = method_config.query_params @property def resource_argument_collection(self): """Gets the collection that should be used to represent the resource. Most of the time this is the same as request_collection because all methods in a collection operate on the same resource and so the API method takes the same parameters that make up the resource. One exception is List methods where the API parameters are for the parent collection. Because people don't specify the resource directly for list commands this also returns the parent collection for parsing purposes. The other exception is Create methods. They reference the parent collection list Like, but the difference is that we *do* want to specify the actual resource on the command line, so the original resource collection is returned here instead of the one that matches the API methods. When generating the request, you must figure out how to generate the message correctly from the parsed resource (as you cannot simply pass the reference to the API). Returns: APICollection: The collection. """ if self.IsList(): return self._request_collection return self.collection @property def request_collection(self): """Gets the API collection that matches the parameters of the API method.""" return self._request_collection def GetRequestType(self): """Gets the apitools request class for this method.""" return self._service.GetRequestType(self._method_name) def GetResponseType(self): """Gets the apitools response class for this method.""" return self._service.GetResponseType(self._method_name) def GetEffectiveResponseType(self): """Gets the effective apitools response class for this method. This will be different from GetResponseType for List methods if we are extracting the list of response items from the overall response. This will always match the type of response that Call() returns. Returns: The apitools Message object. """ if (item_field := self.ListItemField()) and self.HasTokenizedRequest(): return arg_utils.GetFieldFromMessage( self.GetResponseType(), item_field).type else: return self.GetResponseType() def GetMessageByName(self, name): """Gets a arbitrary apitools message class by name. This method can be used to get arbitrary apitools messages from the underlying service. Examples: policy_type = method.GetMessageByName('Policy') status_type = method.GetMessageByName('Status') Args: name: str, the name of the message to return. Returns: The apitools Message object. """ msgs = self._service.client.MESSAGES_MODULE return getattr(msgs, name, None) def IsList(self): """Determines whether this is a List method.""" return self._method_name == 'List' def HasTokenizedRequest(self): """Determines whether this is a method that supports paging.""" return (not self._disable_pagination and 'pageToken' in self._RequestFieldNames() and 'nextPageToken' in self._ResponseFieldNames()) def BatchPageSizeField(self): """Gets the name of the page size field in the request if it exists.""" request_fields = self._RequestFieldNames() if 'maxResults' in request_fields: return 'maxResults' if 'pageSize' in request_fields: return 'pageSize' return None def ListItemField(self): """Gets the name of the field that contains the items in paginated response. This will return None if the method is not a paginated or if a single repeated field of items could not be found in the response type. Returns: str, The name of the field or None. """ if self._disable_pagination: return None response = self.GetResponseType() found = [f for f in response.all_fields() if f.variant == messages.Variant.MESSAGE and f.repeated] if len(found) == 1: return found[0].name else: return None def _RequestCollection(self): """Gets the collection that matches the API parameters of this method. Methods apply to elements of a collection. The resource argument is always of the type of that collection. List is an exception where you are listing items of that collection so the argument to be provided is that of the parent collection. This method returns the collection that should be used to parse the resource for this specific method. Returns: APICollection, The collection to use or None if no parent collection could be found. """ if self.detailed_params == self.collection.detailed_params: return self.collection collections = GetAPICollections( self.collection.api_name, self.collection.api_version) for c in collections: if (self.detailed_params == c.detailed_params and c.detailed_path in self.detailed_path): return c # Fallback to collection that matches params only for c in collections: if self.detailed_params == c.detailed_params: return c return None def _RequestFieldNames(self): """Gets the fields that are actually a part of the request message. For APIs that use atomic names, this will only be the single name parameter (and any other message fields) but not the detailed parameters. Returns: [str], The field names. """ return [f.name for f in self.GetRequestType().all_fields()] def _ResponseFieldNames(self): """Gets the fields that are actually a part of the response message. Returns: [str], The field names. """ return [f.name for f in self.GetResponseType().all_fields()] def Call(self, request, client=None, global_params=None, raw=False, limit=None, page_size=None): """Executes this method with the given arguments. Args: request: The apitools request object to send. client: base_api.BaseApiClient, An API client to use for making requests. global_params: {str: str}, A dictionary of global parameters to send with the request. raw: bool, True to not do any processing of the response, False to maybe do processing for List results. limit: int, The max number of items to return if this is a List method. page_size: int, The max number of items to return in a page if this API supports paging. Returns: The response from the API. """ if client is None: client = apis.GetClientInstance( self.collection.api_name, self.collection.api_version) service = _GetService(client, self.collection.name) request_func = self._GetRequestFunc( service, request, raw=raw, limit=limit, page_size=page_size) try: return request_func(global_params=global_params) except apitools_exc.InvalidUserInputError as e: log.debug('', exc_info=True) raise APICallError(str(e)) def _GetRequestFunc(self, service, request, raw=False, limit=None, page_size=None): """Gets a request function to call and process the results. If this is a method with paginated response, it may flatten the response depending on if the List Pager can be used. Args: service: The apitools service that will be making the request. request: The apitools request object to send. raw: bool, True to not do any processing of the response, False to maybe do processing for List results. limit: int, The max number of items to return if this is a List method. page_size: int, The max number of items to return in a page if this API supports paging. Returns: A function to make the request. """ if raw or self._disable_pagination: return self._NormalRequest(service, request) item_field = self.ListItemField() if not item_field: if self.IsList(): log.debug( 'Unable to flatten list response, raw results being returned.') return self._NormalRequest(service, request) if not self.HasTokenizedRequest(): # API doesn't do paging. if self.IsList(): return self._FlatNonPagedRequest(service, request, item_field) else: return self._NormalRequest(service, request) def RequestFunc(global_params=None): return list_pager.YieldFromList( service, request, method=self._method_name, field=item_field, global_params=global_params, limit=limit, current_token_attribute='pageToken', next_token_attribute='nextPageToken', batch_size_attribute=self.BatchPageSizeField(), batch_size=page_size) return RequestFunc def _NormalRequest(self, service, request): """Generates a basic request function for the method. Args: service: The apitools service that will be making the request. request: The apitools request object to send. Returns: A function to make the request. """ def RequestFunc(global_params=None): method = getattr(service, self._method_name) return method(request, global_params=global_params) return RequestFunc def _FlatNonPagedRequest(self, service, request, item_field): """Generates a request function for the method that extracts an item list. List responses usually have a single repeated field that represents the actual items being listed. This request function returns only those items not the entire response. Args: service: The apitools service that will be making the request. request: The apitools request object to send. item_field: str, The name of the field that the list of items can be found in. Returns: A function to make the request. """ def RequestFunc(global_params=None): response = self._NormalRequest(service, request)( global_params=global_params) return getattr(response, item_field) return RequestFunc def _RemoveVersionPrefix(api_version, path): """Trims the version number off the front of a URL path if present.""" if not path: return None if path.startswith(api_version): return path[len(api_version) + 1:] return path def _ValidateAndGetDefaultVersion(api_name, api_version): """Validates the API exists and gets the default version if not given.""" # pylint:disable=protected-access api_name, _ = apis_internal._GetApiNameAndAlias(api_name) api_vers = apis_map.MAP.get(api_name, {}) if not api_vers: # No versions, this API is not registered. raise UnknownAPIError(api_name) if api_version: if api_version not in api_vers: raise UnknownAPIVersionError(api_name, api_version) return api_version for version, api_def in six.iteritems(api_vers): if api_def.default_version: return version raise NoDefaultVersionError(api_name) def GetAPI(api_name, api_version=None): """Get a specific API definition. Args: api_name: str, The name of the API. api_version: str, The version string of the API. Returns: API, The API definition. """ api_version = _ValidateAndGetDefaultVersion(api_name, api_version) # pylint: disable=protected-access api_def = apis_internal.GetApiDef(api_name, api_version) if api_def.apitools: api_client = apis_internal._GetClientClassFromDef(api_def) else: api_client = apis_internal._GetGapicClientClass(api_name, api_version) if hasattr(api_client, 'BASE_URL'): base_url = api_client.BASE_URL else: try: base_url = apis_internal._GetResourceModule( api_name, api_version ).BASE_URL except ImportError: base_url = 'https://{}.googleapis.com/{}'.format(api_name, api_version) return API( api_name, api_version, api_def.default_version, api_client, base_url ) def GetAllAPIs(): """Gets all registered APIs. Returns: [API], A list of API definitions. """ all_apis = [] for api_name, versions in six.iteritems(apis_map.MAP): for api_version, _ in six.iteritems(versions): all_apis.append(GetAPI(api_name, api_version)) return all_apis def _SplitFullCollectionName(collection): return tuple(collection.split(NAME_SEPARATOR, 1)) def GetAPICollections(api_name=None, api_version=None): """Gets the registered collections for the given API version. Args: api_name: str, The name of the API or None for all apis. api_version: str, The version string of the API or None to use the default version. Returns: [APICollection], A list of the registered collections. """ if api_name: all_apis = {api_name: _ValidateAndGetDefaultVersion(api_name, api_version)} else: all_apis = {x.name: x.version for x in GetAllAPIs() if x.is_default} collections = [] for n, v in six.iteritems(all_apis): # pylint:disable=protected-access collections.extend( [APICollection(c) for c in apis_internal._GetApiCollections(n, v)]) return collections def GetAPICollection(full_collection_name, api_version=None): """Gets the given collection for the given API version. Args: full_collection_name: str, The collection to get including the api name. api_version: str, The version string of the API or None to use the default for this API. Returns: APICollection, The requested API collection. Raises: UnknownCollectionError: If the collection does not exist for the given API and version. """ api_name, collection = _SplitFullCollectionName(full_collection_name) api_version = _ValidateAndGetDefaultVersion(api_name, api_version) collections = GetAPICollections(api_name, api_version) for c in collections: if c.name == collection: return c raise UnknownCollectionError(api_name, api_version, collection) def GetMethod(full_collection_name, method, api_version=None, disable_pagination=False): """Gets the specification for the given API method. Args: full_collection_name: str, The collection including the api name. method: str, The name of the method. api_version: str, The version string of the API or None to use the default for this API. disable_pagination: bool, Boolean for whether pagination should be disabled Returns: APIMethod, The method specification. Raises: UnknownMethodError: If the method does not exist on the collection. """ methods = GetMethods( full_collection_name, api_version=api_version, disable_pagination=disable_pagination) for m in methods: if m.name == method: return m raise UnknownMethodError(method, full_collection_name) def _GetService(client, collection_name): return getattr(client, collection_name.replace(NAME_SEPARATOR, '_'), None) def _GetApiClient(api_name, api_version): """Gets the repesctive api client for the api.""" api_def = apis_internal.GetApiDef(api_name, api_version) if api_def.apitools: client = apis.GetClientInstance(api_name, api_version, no_http=True) else: client = apis.GetGapicClientInstance(api_name, api_version) return client def GetMethods( full_collection_name, api_version=None, disable_pagination=False): """Gets all the methods available on the given collection. Args: full_collection_name: str, The collection including the api name. api_version: str, The version string of the API or None to use the default for this API. disable_pagination: bool, Boolean for whether pagination should be disabled Returns: [APIMethod], The method specifications. """ api_collection = GetAPICollection(full_collection_name, api_version=api_version) client = _GetApiClient(api_collection.api_name, api_collection.api_version) service = _GetService(client, api_collection.name) if not service: # This is a synthetic collection that does not actually have a backing API. return [] method_names = service.GetMethodsList() method_configs = [(name, service.GetMethodConfig(name)) for name in method_names] return [APIMethod(service, name, api_collection, config, disable_pagination) for name, config in method_configs]