# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform endpoints API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.ai.models import client as model_client
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import flags
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.credentials import requests
from six.moves import http_client


def _ParseModel(model_id, location_id):
  """Parses a model ID into a model resource object.

  Args:
    model_id: str, the model ID (or resource name) to parse.
    location_id: str, the region used as the locationsId param. The project is
      taken from the active gcloud configuration.

  Returns:
    A resource object for collection
    'aiplatform.projects.locations.models'.
  """
  return resources.REGISTRY.Parse(
      model_id,
      params={
          'locationsId': location_id,
          'projectsId': properties.VALUES.core.project.GetOrFail,
      },
      collection='aiplatform.projects.locations.models',
  )


def _ConvertPyListToMessageList(message_type, values):
  """Converts each plain Python value in values to a message_type message."""
  return [encoding.PyValueToMessage(message_type, v) for v in values]


def _GetModelDeploymentResourceType(
    model_ref, client, shared_resources_ref=None
):
  """Gets the deployment resource type of a model.

  Args:
    model_ref: a model resource object.
    client: an apis.GetClientInstance object.
    shared_resources_ref: str, the shared deployment resource pool the model
      should use, formatted as the full URI

  Returns:
    A string which value must be 'DEDICATED_RESOURCES', 'AUTOMATIC_RESOURCES'
    or 'SHARED_RESOURCES'

  Raises:
    ArgumentError: if the model resource object is not found, or if shared
      resources were requested but the model does not support them.
  """
  try:
    model_msg = model_client.ModelsClient(client=client).Get(model_ref)
  except apitools_exceptions.HttpError:
    raise errors.ArgumentError((
        'There is an error while getting the model information. '
        'Please make sure the model %r exists.' % model_ref.RelativeName()
    ))
  model_resource = encoding.MessageToPyValue(model_msg)
  # The resource values returned in the list could be multiple.
  supported_deployment_resources_types = model_resource[
      'supportedDeploymentResourcesTypes'
  ]
  if shared_resources_ref is not None:
    if 'SHARED_RESOURCES' not in supported_deployment_resources_types:
      raise errors.ArgumentError(
          'Shared resources not supported for model {}.'.format(
              model_ref.RelativeName()
          )
      )
    else:
      return 'SHARED_RESOURCES'
  # No shared pool requested: prefer the first non-shared resource type.
  try:
    supported_deployment_resources_types.remove('SHARED_RESOURCES')
    return supported_deployment_resources_types[0]
  # Throws value error if dedicated/automatic resources was the only supported
  # resource found in list
  except ValueError:
    return model_resource['supportedDeploymentResourcesTypes'][0]


def _DoHttpPost(url, headers, body):
  """Makes an http POST request.

  Returns:
    A (status_code, response_headers, response_content) tuple.
  """
  response = requests.GetSession().request(
      'POST', url, data=body, headers=headers
  )
  return response.status_code, response.headers, response.content


def _DoStreamHttpPost(url, headers, body):
  """Makes a streaming http POST request.

  Yields:
    The response body one line at a time as the server streams it.
  """
  with requests.GetSession().request(
      'POST', url, data=body, headers=headers, stream=True
  ) as resp:
    for line in resp.iter_lines():
      yield line


# NOTE(review): module-level helper that receives the EndpointsClient instance
# as `self` and calls its Get() method — presumably kept outside the class
# deliberately; confirm before moving it onto the class.
def _CheckIsGdcGgsModel(self, endpoint_ref):
  """GDC GGS model is only supported for GDC
endpoints."""
  endpoint = self.Get(endpoint_ref)
  endpoint_resource = encoding.MessageToPyValue(endpoint)
  # Truthy only when the endpoint carries a gdcConfig with a non-empty zone.
  return (
      endpoint_resource is not None
      and 'gdcConfig' in endpoint_resource
      and 'zone' in endpoint_resource['gdcConfig']
      and endpoint_resource['gdcConfig']['zone']
  )


class EndpointsClient(object):
  """High-level client for the AI Platform endpoints surface."""

  def __init__(self, client=None, messages=None, version=None):
    # `version` selects the API release track via the
    # AI_PLATFORM_API_VERSION mapping (e.g. GA vs beta).
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version],
    )
    self.messages = messages or self.client.MESSAGES_MODULE

  def Create(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates a new endpoint using v1 API.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: list, the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier
        of the customer managed encryption key used to protect a resource.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.

    Returns:
      A long-running operation for Create.
    """
    encryption_spec = None
    if encryption_kms_key_name:
      encryption_spec = self.messages.GoogleCloudAiplatformV1EncryptionSpec(
          kmsKeyName=encryption_kms_key_name
      )
    endpoint = api_util.GetMessage('Endpoint', constants.GA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=encryption_spec,
    )
    if request_response_logging_table is not None:
      endpoint.predictRequestResponseLoggingConfig = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.GA_VERSION
      )(
          enabled=True,
          # Sampling rate falls back to 0.0 when unset or zero.
          samplingRate=request_response_logging_rate
          if request_response_logging_rate
          else 0.0,
          bigqueryDestination=api_util.GetMessage(
              'BigQueryDestination', constants.GA_VERSION
          )(outputUri=request_response_logging_table),
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        endpointId=endpoint_id,
        googleCloudAiplatformV1Endpoint=endpoint,
    )
    return self.client.projects_locations_endpoints.Create(req)

  def CreateBeta(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      gdce_zone=None,
      gdc_zone=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates a new endpoint using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: list, the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier
        of the customer managed encryption key used to protect a resource.
      gdce_zone: str or None, the name of the GDCE zone.
      gdc_zone: str or None, the name of the GDC zone.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.

    Returns:
      A long-running operation for Create.
    """
    encryption_spec = None
    if encryption_kms_key_name:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
              kmsKeyName=encryption_kms_key_name
          )
      )
    gdce_config = None
    if gdce_zone:
      gdce_config = self.messages.GoogleCloudAiplatformV1beta1GdceConfig(
          zone=gdce_zone
      )
    gdc_config = None
    if gdc_zone:
      gdc_config = self.messages.GoogleCloudAiplatformV1beta1GdcConfig(
          zone=gdc_zone
      )
    endpoint = api_util.GetMessage('Endpoint', constants.BETA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=encryption_spec,
        gdceConfig=gdce_config,
        gdcConfig=gdc_config,
    )
    if request_response_logging_table is not None:
      endpoint.predictRequestResponseLoggingConfig = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
      )(
          enabled=True,
          # Sampling rate falls back to 0.0 when unset or zero.
          samplingRate=request_response_logging_rate
          if request_response_logging_rate
          else 0.0,
          bigqueryDestination=api_util.GetMessage(
              'BigQueryDestination', constants.BETA_VERSION
          )(outputUri=request_response_logging_table),
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        endpointId=endpoint_id,
        googleCloudAiplatformV1beta1Endpoint=endpoint,
    )
    return self.client.projects_locations_endpoints.Create(req)

  def Delete(self, endpoint_ref):
    """Deletes an existing endpoint."""
    req = self.messages.AiplatformProjectsLocationsEndpointsDeleteRequest(
        name=endpoint_ref.RelativeName()
    )
    return self.client.projects_locations_endpoints.Delete(req)

  def Get(self, endpoint_ref):
    """Gets details about an endpoint."""
    req = self.messages.AiplatformProjectsLocationsEndpointsGetRequest(
        name=endpoint_ref.RelativeName()
    )
    return self.client.projects_locations_endpoints.Get(req)

  def List(self, location_ref, filter_str=None, gdc_zone=None):
    """Lists endpoints in the project."""
    req = self.messages.AiplatformProjectsLocationsEndpointsListRequest(
        parent=location_ref.RelativeName(),
        filter=filter_str,
        gdcZone=gdc_zone,
    )
    return list_pager.YieldFromList(
        self.client.projects_locations_endpoints,
        req,
        field='endpoints',
        batch_size_attribute='pageSize',
    )

  def Patch(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint using v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    endpoint = api_util.GetMessage('Endpoint', constants.GA_VERSION)()
    update_mask = []
    if labels_update.needs_update:
      endpoint.labels = labels_update.labels
      update_mask.append('labels')
    if display_name is not None:
      endpoint.displayName = display_name
      update_mask.append('display_name')
    if traffic_split is not None:
      # Sort for a deterministic wire representation of the split map.
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            endpoint.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      endpoint.trafficSplit = endpoint.TrafficSplitValue(
          additionalProperties=additional_properties
      )
      update_mask.append('traffic_split')
    if clear_traffic_split:
      endpoint.trafficSplit = None
      update_mask.append('traffic_split')
    if description is not None:
      endpoint.description = description
      update_mask.append('description')
    if (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    ):
      # Read the current config so an existing table/rate is preserved when
      # only one of the two values is being changed.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not request_response_logging_config:
        request_response_logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.GA_VERSION
        )()
      request_response_logging_config.enabled = True
      if request_response_logging_table is not None:
        request_response_logging_config.bigqueryDestination = (
            api_util.GetMessage('BigQueryDestination', constants.GA_VERSION)(
                outputUri=request_response_logging_table
            )
        )
      if request_response_logging_rate is not None:
        request_response_logging_config.samplingRate = (
            request_response_logging_rate
        )
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')
    if disable_request_response_logging:
      # Keep the existing config (table/rate) and only flip `enabled` off.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if request_response_logging_config:
        request_response_logging_config.enabled = False
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1Endpoint=endpoint,
        updateMask=','.join(update_mask),
    )
    return self.client.projects_locations_endpoints.Patch(req)

  def PatchBeta(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint using v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    endpoint = self.messages.GoogleCloudAiplatformV1beta1Endpoint()
    update_mask = []
    if labels_update.needs_update:
      endpoint.labels = labels_update.labels
      update_mask.append('labels')
    if display_name is not None:
      endpoint.displayName = display_name
      update_mask.append('display_name')
    if traffic_split is not None:
      # Sort for a deterministic wire representation of the split map.
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            endpoint.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      endpoint.trafficSplit = endpoint.TrafficSplitValue(
          additionalProperties=additional_properties
      )
      update_mask.append('traffic_split')
    if clear_traffic_split:
      endpoint.trafficSplit = None
      update_mask.append('traffic_split')
    if description is not None:
      endpoint.description = description
      update_mask.append('description')
    if (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    ):
      # Read the current config so an existing table/rate is preserved when
      # only one of the two values is being changed.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not request_response_logging_config:
        request_response_logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
        )()
      request_response_logging_config.enabled = True
      if request_response_logging_table is not None:
        request_response_logging_config.bigqueryDestination = (
            api_util.GetMessage('BigQueryDestination', constants.BETA_VERSION)(
                outputUri=request_response_logging_table
            )
        )
      if request_response_logging_rate is not None:
        request_response_logging_config.samplingRate = (
            request_response_logging_rate
        )
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')
    if disable_request_response_logging:
      # Keep the existing config (table/rate) and only flip `enabled` off.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if request_response_logging_config:
        request_response_logging_config.enabled = False
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1Endpoint=endpoint,
        updateMask=','.join(update_mask),
    )
    return self.client.projects_locations_endpoints.Patch(req)

  def Predict(self, endpoint_ref, instances_json):
    """Sends online prediction request to an endpoint using v1 API."""
    predict_request = self.messages.GoogleCloudAiplatformV1PredictRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    if 'parameters' in instances_json:
      predict_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1PredictRequest=predict_request,
    )
    return self.client.projects_locations_endpoints.Predict(req)

  def PredictBeta(self, endpoint_ref, instances_json):
    """Sends online prediction request to an endpoint using v1beta1 API."""
    predict_request = self.messages.GoogleCloudAiplatformV1beta1PredictRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    if 'parameters' in instances_json:
      predict_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1PredictRequest=predict_request,
    )
    return self.client.projects_locations_endpoints.Predict(req)

  def RawPredict(self, endpoint_ref, headers, request):
    """Sends online raw prediction request to an endpoint."""
    # rawPredict is not exposed via the generated client, so the request is
    # made over plain HTTP against the service URL.
    url = '{}{}/{}:rawPredict'.format(
        self.client.url,
        getattr(self.client, '_VERSION'),
        endpoint_ref.RelativeName(),
    )
    status, response_headers, response = _DoHttpPost(url,
        headers, request)
    if status != http_client.OK:
      raise core_exceptions.Error(
          'HTTP request failed. Response:\n' + response.decode()
      )
    return response_headers, response

  def StreamRawPredict(self, endpoint_ref, headers, request):
    """Sends online raw prediction request to an endpoint."""
    # streamRawPredict is not exposed via the generated client, so the request
    # is streamed over plain HTTP against the service URL.
    url = '{}{}/{}:streamRawPredict'.format(
        self.client.url,
        getattr(self.client, '_VERSION'),
        endpoint_ref.RelativeName(),
    )
    for resp in _DoStreamHttpPost(url, headers, request):
      yield resp

  def DirectPredict(self, endpoint_ref, inputs_json):
    """Sends online direct prediction request to an endpoint using v1 API."""
    direct_predict_request = (
        self.messages.GoogleCloudAiplatformV1DirectPredictRequest(
            inputs=_ConvertPyListToMessageList(
                self.messages.GoogleCloudAiplatformV1Tensor,
                inputs_json['inputs'],
            )
        )
    )
    if 'parameters' in inputs_json:
      direct_predict_request.parameters = encoding.PyValueToMessage(
          self.messages.GoogleCloudAiplatformV1Tensor, inputs_json['parameters']
      )
    req = (
        self.messages.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1DirectPredictRequest=direct_predict_request,
        )
    )
    return self.client.projects_locations_endpoints.DirectPredict(req)

  def DirectPredictBeta(self, endpoint_ref, inputs_json):
    """Sends online direct prediction request to an endpoint using v1beta1 API."""
    direct_predict_request = (
        self.messages.GoogleCloudAiplatformV1beta1DirectPredictRequest(
            inputs=_ConvertPyListToMessageList(
                self.messages.GoogleCloudAiplatformV1beta1Tensor,
                inputs_json['inputs'],
            )
        )
    )
    if 'parameters' in inputs_json:
      direct_predict_request.parameters = encoding.PyValueToMessage(
          self.messages.GoogleCloudAiplatformV1beta1Tensor,
          inputs_json['parameters'],
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DirectPredictRequest=direct_predict_request,
    )
    return self.client.projects_locations_endpoints.DirectPredict(req)

  def DirectRawPredict(self, endpoint_ref, input_json):
    """Sends online direct raw prediction request to an endpoint using v1 API."""
    direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1DirectRawPredictRequest(
        input=bytes(input_json['input'], 'utf-8'),
        # Method name can be "methodName" or "method_name"
        methodName=input_json.get('methodName', input_json.get('method_name')),
    )
    req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DirectRawPredictRequest=direct_raw_predict_request,
    )
    return self.client.projects_locations_endpoints.DirectRawPredict(req)

  def DirectRawPredictBeta(self, endpoint_ref, input_json):
    """Sends online direct raw prediction request to an endpoint using v1beta1 API."""
    direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1beta1DirectRawPredictRequest(
        input=bytes(input_json['input'], 'utf-8'),
        # Method name can be "methodName" or "method_name"
        methodName=input_json.get('methodName', input_json.get('method_name')),
    )
    req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DirectRawPredictRequest=direct_raw_predict_request,
    )
    return self.client.projects_locations_endpoints.DirectRawPredict(req)

  def Explain(self, endpoint_ref, instances_json, args):
    """Sends online explanation request to an endpoint using v1 API."""
    explain_request = self.messages.GoogleCloudAiplatformV1ExplainRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    if 'parameters' in instances_json:
      explain_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    if args.deployed_model_id is not None:
      explain_request.deployedModelId = args.deployed_model_id
    req = self.messages.AiplatformProjectsLocationsEndpointsExplainRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1ExplainRequest=explain_request,
    )
    return self.client.projects_locations_endpoints.Explain(req)

  def ExplainBeta(self, endpoint_ref, instances_json, args):
    """Sends online explanation request to an endpoint using v1beta1 API."""
    explain_request = self.messages.GoogleCloudAiplatformV1beta1ExplainRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    if 'parameters' in instances_json:
      explain_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    if 'explanation_spec_override' in instances_json:
      explain_request.explanationSpecOverride = encoding.PyValueToMessage(
          self.messages.GoogleCloudAiplatformV1beta1ExplanationSpecOverride,
          instances_json['explanation_spec_override'],
      )
    if args.deployed_model_id is not None:
      explain_request.deployedModelId = args.deployed_model_id
    req = self.messages.AiplatformProjectsLocationsEndpointsExplainRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1ExplainRequest=explain_request,
    )
    return self.client.projects_locations_endpoints.Explain(req)

  def DeployModel(
      self,
      endpoint_ref,
      model,
      region,
      display_name,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      required_replica_count=None,
      reservation_affinity=None,
      autoscaling_metric_specs=None,
      spot=False,
      enable_access_logging=False,
      disable_container_logging=False,
      service_account=None,
      traffic_split=None,
      deployed_model_id=None,
      gpu_partition_size=None,
  ):
    """Deploys a model to an existing endpoint using v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint that the model is deployed
        to.
      model: str, Id of the uploaded model to be deployed.
      region: str, the location of the endpoint and the model.
      display_name: str, the display name of the new deployed model.
      machine_type: str or None, the type of the machine to serve the model.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica
        for multihost GPU deployments.
      accelerator_dict: dict or None, the accelerator attached to the deployed
        model from args.
      min_replica_count: int or None, the minimum number of replicas the
        deployed model will be always deployed on.
      max_replica_count: int or None, the maximum number of replicas the
        deployed model may be deployed on.
      required_replica_count: int or None, the required number of replicas the
        deployed model will be considered successfully deployed.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model
        can use.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      spot: bool, whether or not deploy the model on spot resources.
      enable_access_logging: bool, whether or not enable access logs.
      disable_container_logging: bool, whether or not disable container
        logging.
      service_account: str or None, the service account that the deployed
        model runs as.
      traffic_split: dict or None, the new traffic split of the endpoint.
      deployed_model_id: str or None, id of the deployed model.
      gpu_partition_size: str or None, the partition size of the GPU
        accelerator.

    Returns:
      A long-running operation for DeployModel.
    """
    model_ref = _ParseModel(model, region)
    resource_type = _GetModelDeploymentResourceType(model_ref, self.client)
    if resource_type == 'DEDICATED_RESOURCES':
      # dedicated resources
      machine_spec = self.messages.GoogleCloudAiplatformV1MachineSpec()
      if machine_type is not None:
        machine_spec.machineType = machine_type
      if tpu_topology is not None:
        machine_spec.tpuTopology = tpu_topology
      if multihost_gpu_node_count is not None:
        machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
      accelerator = flags.ParseAcceleratorFlag(
          accelerator_dict, constants.GA_VERSION
      )
      if accelerator is not None:
        machine_spec.acceleratorType = accelerator.acceleratorType
        machine_spec.acceleratorCount = accelerator.acceleratorCount
      if reservation_affinity is not None:
        machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
            reservation_affinity, constants.GA_VERSION
        )
      if gpu_partition_size is not None:
        machine_spec.gpuPartitionSize = gpu_partition_size
      dedicated = self.messages.GoogleCloudAiplatformV1DedicatedResources(
          machineSpec=machine_spec, spot=spot
      )
      # min-replica-count is required and must be >= 1 if models use dedicated
      # resources. Default to 1 if not specified.
      dedicated.minReplicaCount = min_replica_count or 1
      if max_replica_count is not None:
        dedicated.maxReplicaCount = max_replica_count
      if required_replica_count is not None:
        dedicated.requiredReplicaCount = required_replica_count
      if autoscaling_metric_specs is not None:
        # Sorted for a deterministic ordering of the metric specs.
        autoscaling_metric_specs_list = []
        for name, target in sorted(autoscaling_metric_specs.items()):
          autoscaling_metric_specs_list.append(
              self.messages.GoogleCloudAiplatformV1AutoscalingMetricSpec(
                  metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[name],
                  target=target,
              )
          )
        dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list
      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          dedicatedResources=dedicated,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )
    else:
      # automatic resources
      automatic = self.messages.GoogleCloudAiplatformV1AutomaticResources()
      if min_replica_count is not None:
        automatic.minReplicaCount = min_replica_count
      if max_replica_count is not None:
        automatic.maxReplicaCount = max_replica_count
      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          automaticResources=automatic,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )
    deployed_model.enableAccessLogging = enable_access_logging
    deployed_model.disableContainerLogging = disable_container_logging
    if service_account is not None:
      deployed_model.serviceAccount = service_account
    if deployed_model_id is not None:
      deployed_model.id = deployed_model_id
    deployed_model_req = (
        self.messages.GoogleCloudAiplatformV1DeployModelRequest(
            deployedModel=deployed_model
        )
    )
    if traffic_split is not None:
      # Sort for a deterministic wire representation of the split map.
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            deployed_model_req.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      deployed_model_req.trafficSplit = deployed_model_req.TrafficSplitValue(
          additionalProperties=additional_properties
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsDeployModelRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DeployModelRequest=deployed_model_req,
    )
    return self.client.projects_locations_endpoints.DeployModel(req)

  def DeployModelBeta(
      self,
      endpoint_ref,
      model,
      region,
      display_name,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      required_replica_count=None,
      reservation_affinity=None,
      autoscaling_metric_specs=None,
      spot=False,
      enable_access_logging=False,
      enable_container_logging=False,
      service_account=None,
      traffic_split=None,
      deployed_model_id=None,
      shared_resources_ref=None,
      min_scaleup_period=None,
      idle_scaledown_period=None,
      initial_replica_count=None,
      gpu_partition_size=None,
  ):
    """Deploys a model to an existing endpoint using v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint that the model is deployed
        to.
      model: str, Id of the uploaded model to be deployed.
      region: str, the location of the endpoint and the model.
      display_name: str, the display name of the new deployed model.
      machine_type: str or None, the type of the machine to serve the model.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica
        for multihost GPU deployments.
      accelerator_dict: dict or None, the accelerator attached to the deployed
        model from args.
      min_replica_count: int or None, the minimum number of replicas the
        deployed model will be always deployed on.
      max_replica_count: int or None, the maximum number of replicas the
        deployed model may be deployed on.
      required_replica_count: int or None, the required number of replicas the
        deployed model will be considered successfully deployed.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model
        can use.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      spot: bool, whether or not deploy the model on spot resources.
      enable_access_logging: bool, whether or not enable access logs.
      enable_container_logging: bool, whether or not enable container logging.
      service_account: str or None, the service account that the deployed
        model runs as.
      traffic_split: dict or None, the new traffic split of the endpoint.
      deployed_model_id: str or None, id of the deployed model.
      shared_resources_ref: str or None, the shared deployment resource pool
        the model should use
      min_scaleup_period: str or None, the minimum duration (in seconds) that
        a deployment will be scaled up before traffic is evaluated for
        potential scale-down. Defaults to 1 hour if min replica count is 0.
      idle_scaledown_period: str or None, the duration after which the
        deployment is scaled down if no traffic is received. This only
        applies to deployments enrolled in scale-to-zero.
      initial_replica_count: int or None, the initial number of replicas the
        deployment will be scaled up to. This only applies to deployments
        enrolled in scale-to-zero.
      gpu_partition_size: str or None, the partition size of the GPU
        accelerator.

    Returns:
      A long-running operation for DeployModel.
    """
    is_gdc_ggs_model = _CheckIsGdcGgsModel(self, endpoint_ref)
    if is_gdc_ggs_model:
      # Send pseudo dedicated resources for the GDC GGS model; the fixed
      # machine/accelerator values below are placeholders for GDC endpoints.
      machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
          machineType='n1-standard-2',
          acceleratorType=self.messages.GoogleCloudAiplatformV1beta1MachineSpec.AcceleratorTypeValueValuesEnum.NVIDIA_TESLA_T4,
          acceleratorCount=1,
      )
      dedicated = self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
          machineSpec=machine_spec, minReplicaCount=1, maxReplicaCount=1
      )
      deployed_model = self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
          dedicatedResources=dedicated,
          displayName=display_name,
          gdcConnectedModel=model,
      )
    else:
      model_ref = _ParseModel(model, region)
      resource_type = _GetModelDeploymentResourceType(
          model_ref, self.client, shared_resources_ref
      )
      if resource_type == 'DEDICATED_RESOURCES':
        # dedicated resources
        machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec()
        if machine_type is not None:
          machine_spec.machineType = machine_type
        if tpu_topology is not None:
          machine_spec.tpuTopology = tpu_topology
        if multihost_gpu_node_count is not None:
          machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
        accelerator = flags.ParseAcceleratorFlag(
            accelerator_dict, constants.BETA_VERSION
        )
        if accelerator is not None:
          machine_spec.acceleratorType = accelerator.acceleratorType
          machine_spec.acceleratorCount = accelerator.acceleratorCount
        if reservation_affinity is not None:
          machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
              reservation_affinity, constants.BETA_VERSION
          )
        if gpu_partition_size is not None:
          machine_spec.gpuPartitionSize = gpu_partition_size
        dedicated = (
            self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
                machineSpec=machine_spec, spot=spot
            )
        )
        # min-replica-count is required and must be >= 0 if models use dedicated
        # resources. If value is 0, the deployment will be enrolled in the
        # scale-to-zero feature. Default to 1 if not specified.
        dedicated.minReplicaCount = (
            1 if min_replica_count is None else min_replica_count
        )
        # if not specified and min-replica-count is 0, default to 1.
        if max_replica_count is None and dedicated.minReplicaCount == 0:
          dedicated.maxReplicaCount = 1
        else:
          if max_replica_count is not None:
            dedicated.maxReplicaCount = max_replica_count
        if required_replica_count is not None:
          dedicated.requiredReplicaCount = required_replica_count
        # if not specified and min-replica-count is 0, default to 1.
        if initial_replica_count is None and dedicated.minReplicaCount == 0:
          dedicated.initialReplicaCount = 1
        else:
          if initial_replica_count is not None:
            dedicated.initialReplicaCount = initial_replica_count
        if autoscaling_metric_specs is not None:
          # Sorted for a deterministic ordering of the metric specs.
          autoscaling_metric_specs_list = []
          for name, target in sorted(autoscaling_metric_specs.items()):
            autoscaling_metric_specs_list.append(
                self.messages.GoogleCloudAiplatformV1beta1AutoscalingMetricSpec(
                    metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[
                        name
                    ],
                    target=target,
                )
            )
          dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list
        # Only attach a scale-to-zero spec when at least one of its fields
        # was provided, so the message is omitted from the request otherwise.
        stz_spec = (
            self.messages.GoogleCloudAiplatformV1beta1DedicatedResourcesScaleToZeroSpec()
        )
        stz_spec_modified = False
        if min_scaleup_period is not None:
          stz_spec.minScaleupPeriod = '{}s'.format(min_scaleup_period)
          stz_spec_modified = True
        if idle_scaledown_period is not None:
          stz_spec.idleScaledownPeriod = '{}s'.format(idle_scaledown_period)
          stz_spec_modified = True
        if stz_spec_modified:
          dedicated.scaleToZeroSpec = stz_spec
        deployed_model = (
            self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
                dedicatedResources=dedicated,
                displayName=display_name,
                model=model_ref.RelativeName(),
            )
        )
      elif resource_type == 'AUTOMATIC_RESOURCES':
        # automatic resources
        automatic = (
            self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
        )
        if min_replica_count is not None:
          automatic.minReplicaCount = min_replica_count
        if max_replica_count is not None:
          automatic.maxReplicaCount = max_replica_count
        deployed_model = (
            self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
                automaticResources=automatic,
                displayName=display_name,
                model=model_ref.RelativeName(),
            )
        )
      # NOTE(review): method continues beyond this chunk (remaining branches
      # and the DeployModel request) — truncated in this view.
if resource type is SHARED_RESOURCES else: deployed_model = ( self.messages.GoogleCloudAiplatformV1beta1DeployedModel( displayName=display_name, model=model_ref.RelativeName(), sharedResources=shared_resources_ref.RelativeName(), ) ) deployed_model.enableAccessLogging = enable_access_logging deployed_model.enableContainerLogging = enable_container_logging if service_account is not None: deployed_model.serviceAccount = service_account if deployed_model_id is not None: deployed_model.id = deployed_model_id deployed_model_req = ( self.messages.GoogleCloudAiplatformV1beta1DeployModelRequest( deployedModel=deployed_model ) ) if traffic_split is not None: additional_properties = [] for key, value in sorted(traffic_split.items()): additional_properties.append( deployed_model_req.TrafficSplitValue().AdditionalProperty( key=key, value=value ) ) deployed_model_req.trafficSplit = deployed_model_req.TrafficSplitValue( additionalProperties=additional_properties ) req = self.messages.AiplatformProjectsLocationsEndpointsDeployModelRequest( endpoint=endpoint_ref.RelativeName(), googleCloudAiplatformV1beta1DeployModelRequest=deployed_model_req, ) return self.client.projects_locations_endpoints.DeployModel(req) def UndeployModel(self, endpoint_ref, deployed_model_id, traffic_split=None): """Undeploys a model from an endpoint using v1 API. Args: endpoint_ref: Resource, the parsed endpoint that the model is undeployed from. deployed_model_id: str, Id of the deployed model to be undeployed. traffic_split: dict or None, the new traffic split of the endpoint. Returns: A long-running operation for UndeployModel. 
""" undeployed_model_req = ( self.messages.GoogleCloudAiplatformV1UndeployModelRequest( deployedModelId=deployed_model_id ) ) if traffic_split is not None: additional_properties = [] for key, value in sorted(traffic_split.items()): additional_properties.append( undeployed_model_req.TrafficSplitValue().AdditionalProperty( key=key, value=value ) ) undeployed_model_req.trafficSplit = ( undeployed_model_req.TrafficSplitValue( additionalProperties=additional_properties ) ) req = ( self.messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest( endpoint=endpoint_ref.RelativeName(), googleCloudAiplatformV1UndeployModelRequest=undeployed_model_req, ) ) return self.client.projects_locations_endpoints.UndeployModel(req) def UndeployModelBeta( self, endpoint_ref, deployed_model_id, traffic_split=None ): """Undeploys a model from an endpoint using v1beta1 API. Args: endpoint_ref: Resource, the parsed endpoint that the model is undeployed from. deployed_model_id: str, Id of the deployed model to be undeployed. traffic_split: dict or None, the new traffic split of the endpoint. Returns: A long-running operation for UndeployModel. """ undeployed_model_req = ( self.messages.GoogleCloudAiplatformV1beta1UndeployModelRequest( deployedModelId=deployed_model_id ) ) if traffic_split is not None: additional_properties = [] for key, value in sorted(traffic_split.items()): additional_properties.append( undeployed_model_req.TrafficSplitValue().AdditionalProperty( key=key, value=value ) ) undeployed_model_req.trafficSplit = ( undeployed_model_req.TrafficSplitValue( additionalProperties=additional_properties ) ) req = self.messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest( endpoint=endpoint_ref.RelativeName(), googleCloudAiplatformV1beta1UndeployModelRequest=undeployed_model_req, ) return self.client.projects_locations_endpoints.UndeployModel(req)