1210 lines
46 KiB
Python
1210 lines
46 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2020 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for dealing with AI Platform endpoints API."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
from apitools.base.py import encoding
|
|
from apitools.base.py import exceptions as apitools_exceptions
|
|
from apitools.base.py import extra_types
|
|
from apitools.base.py import list_pager
|
|
from googlecloudsdk.api_lib.ai import util as api_util
|
|
from googlecloudsdk.api_lib.ai.models import client as model_client
|
|
from googlecloudsdk.api_lib.util import apis
|
|
from googlecloudsdk.command_lib.ai import constants
|
|
from googlecloudsdk.command_lib.ai import errors
|
|
from googlecloudsdk.command_lib.ai import flags
|
|
from googlecloudsdk.core import exceptions as core_exceptions
|
|
from googlecloudsdk.core import properties
|
|
from googlecloudsdk.core import resources
|
|
from googlecloudsdk.core.credentials import requests
|
|
from six.moves import http_client
|
|
|
|
|
|
def _ParseModel(model_id, location_id):
  """Builds a model resource reference from a model ID and location.

  Args:
    model_id: str, the model ID (or full resource name) to parse.
    location_id: str, the location the model lives in.

  Returns:
    A resource reference in the aiplatform.projects.locations.models
    collection; the project falls back to the active gcloud project.
  """
  parse_params = {
      'locationsId': location_id,
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      model_id,
      params=parse_params,
      collection='aiplatform.projects.locations.models',
  )
|
|
|
|
|
|
def _ConvertPyListToMessageList(message_type, values):
  """Converts each plain Python value in `values` to a `message_type` message."""
  def _ToMessage(value):
    return encoding.PyValueToMessage(message_type, value)

  return list(map(_ToMessage, values))
|
|
|
|
|
|
def _GetModelDeploymentResourceType(
    model_ref, client, shared_resources_ref=None
):
  """Gets the deployment resource type of a model.

  Args:
    model_ref: a model resource object.
    client: an apis.GetClientInstance object.
    shared_resources_ref: str, the shared deployment resource pool the model
      should use, formatted as the full URI

  Returns:
    A string which value must be 'DEDICATED_RESOURCES', 'AUTOMATIC_RESOURCES'
    or 'SHARED_RESOURCES'

  Raises:
    ArgumentError: if the model resource object is not found, if shared
      resources were requested but the model does not support them, or if the
      model supports no dedicated/automatic resource type and no shared
      resource pool was requested.
  """
  try:
    model_msg = model_client.ModelsClient(client=client).Get(model_ref)
  except apitools_exceptions.HttpError:
    raise errors.ArgumentError((
        'There is an error while getting the model information. '
        'Please make sure the model %r exists.'
        % model_ref.RelativeName()
    ))
  model_resource = encoding.MessageToPyValue(model_msg)

  # The resource values returned in the list could be multiple.
  supported_deployment_resources_types = model_resource[
      'supportedDeploymentResourcesTypes'
  ]
  if shared_resources_ref is not None:
    if 'SHARED_RESOURCES' not in supported_deployment_resources_types:
      raise errors.ArgumentError(
          'Shared resources not supported for model {}.'.format(
              model_ref.RelativeName()
          )
      )
    return 'SHARED_RESOURCES'

  # No shared resource pool was requested: pick the first supported type that
  # is not SHARED_RESOURCES, preserving the order the API returned.
  # (The original implementation list.remove()'d 'SHARED_RESOURCES' and
  # indexed [0], which raised a bare IndexError when the model supported
  # *only* shared resources; surface that case as a clear ArgumentError.)
  non_shared_types = [
      resource_type
      for resource_type in supported_deployment_resources_types
      if resource_type != 'SHARED_RESOURCES'
  ]
  if not non_shared_types:
    raise errors.ArgumentError(
        'Model {} only supports shared resources deployment; please '
        'specify a shared resource pool.'.format(model_ref.RelativeName())
    )
  return non_shared_types[0]
|
|
|
|
|
|
def _DoHttpPost(url, headers, body):
  """Issues a single HTTP POST and returns (status code, headers, content)."""
  session = requests.GetSession()
  response = session.request('POST', url, data=body, headers=headers)
  return response.status_code, response.headers, response.content
|
|
|
|
|
|
def _DoStreamHttpPost(url, headers, body):
  """Issues a streaming HTTP POST, yielding one response line at a time."""
  session = requests.GetSession()
  response = session.request(
      'POST', url, data=body, headers=headers, stream=True
  )
  # Ensure the connection is released once the stream is exhausted.
  with response:
    yield from response.iter_lines()
|
|
|
|
|
|
def _CheckIsGdcGgsModel(self, endpoint_ref):
  """GDC GGS model is only supported for GDC endpoints.

  NOTE(review): despite the `self` parameter, this is a module-level function,
  not a method of EndpointsClient. `self` must be an object exposing
  Get(endpoint_ref) (e.g. an EndpointsClient instance) — consider moving this
  into the class.

  Args:
    self: an EndpointsClient-like object used to fetch the endpoint.
    endpoint_ref: Resource, the parsed endpoint to inspect.

  Returns:
    True if the fetched endpoint carries a 'gdcConfig' whose 'zone' field is
    present and non-empty; False otherwise.
  """
  endpoint = self.Get(endpoint_ref)
  endpoint_resource = encoding.MessageToPyValue(endpoint)
  # All four conditions must hold; the final one rejects an empty zone string.
  return (
      endpoint_resource is not None
      and 'gdcConfig' in endpoint_resource
      and 'zone' in endpoint_resource['gdcConfig']
      and endpoint_resource['gdcConfig']['zone']
  )
|
|
|
|
|
|
class EndpointsClient(object):
|
|
"""High-level client for the AI Platform endpoints surface."""
|
|
|
|
def __init__(self, client=None, messages=None, version=None):
|
|
self.client = client or apis.GetClientInstance(
|
|
constants.AI_PLATFORM_API_NAME,
|
|
constants.AI_PLATFORM_API_VERSION[version],
|
|
)
|
|
self.messages = messages or self.client.MESSAGES_MODULE
|
|
|
|
  def Create(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates a new endpoint using v1 API.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: list, the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier of
        the customer managed encryption key used to protect a resource.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.

    Returns:
      A long-running operation for Create.
    """
    # Only attach a CMEK encryption spec when a key was supplied; otherwise
    # the endpoint uses Google-managed encryption.
    encryption_spec = None
    if encryption_kms_key_name:
      encryption_spec = self.messages.GoogleCloudAiplatformV1EncryptionSpec(
          kmsKeyName=encryption_kms_key_name
      )

    endpoint = api_util.GetMessage('Endpoint', constants.GA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=encryption_spec,
    )
    # Request-response logging is configured only when a BigQuery destination
    # table was given; a falsy sampling rate falls back to 0.0.
    if request_response_logging_table is not None:
      endpoint.predictRequestResponseLoggingConfig = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.GA_VERSION
      )(
          enabled=True,
          samplingRate=request_response_logging_rate
          if request_response_logging_rate
          else 0.0,
          bigqueryDestination=api_util.GetMessage(
              'BigQueryDestination', constants.GA_VERSION
          )(outputUri=request_response_logging_table),
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        endpointId=endpoint_id,
        googleCloudAiplatformV1Endpoint=endpoint,
    )
    return self.client.projects_locations_endpoints.Create(req)
|
|
|
|
  def CreateBeta(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      gdce_zone=None,
      gdc_zone=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates a new endpoint using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: list, the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier of
        the customer managed encryption key used to protect a resource.
      gdce_zone: str or None, the name of the GDCE zone.
      gdc_zone: str or None, the name of the GDC zone.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.

    Returns:
      A long-running operation for Create.
    """
    # Only attach a CMEK encryption spec when a key was supplied.
    encryption_spec = None
    if encryption_kms_key_name:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
              kmsKeyName=encryption_kms_key_name
          )
      )

    # GDCE / GDC configs are attached only when their zone is provided.
    gdce_config = None
    if gdce_zone:
      gdce_config = self.messages.GoogleCloudAiplatformV1beta1GdceConfig(
          zone=gdce_zone
      )

    gdc_config = None
    if gdc_zone:
      gdc_config = self.messages.GoogleCloudAiplatformV1beta1GdcConfig(
          zone=gdc_zone
      )

    endpoint = api_util.GetMessage('Endpoint', constants.BETA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=encryption_spec,
        gdceConfig=gdce_config,
        gdcConfig=gdc_config,
    )
    # Request-response logging is configured only when a BigQuery destination
    # table was given; a falsy sampling rate falls back to 0.0.
    if request_response_logging_table is not None:
      endpoint.predictRequestResponseLoggingConfig = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
      )(
          enabled=True,
          samplingRate=request_response_logging_rate
          if request_response_logging_rate
          else 0.0,
          bigqueryDestination=api_util.GetMessage(
              'BigQueryDestination', constants.BETA_VERSION
          )(outputUri=request_response_logging_table),
      )
    req = self.messages.AiplatformProjectsLocationsEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        endpointId=endpoint_id,
        googleCloudAiplatformV1beta1Endpoint=endpoint,
    )
    return self.client.projects_locations_endpoints.Create(req)
|
|
|
|
def Delete(self, endpoint_ref):
|
|
"""Deletes an existing endpoint."""
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsDeleteRequest(
|
|
name=endpoint_ref.RelativeName()
|
|
)
|
|
return self.client.projects_locations_endpoints.Delete(req)
|
|
|
|
def Get(self, endpoint_ref):
|
|
"""Gets details about an endpoint."""
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsGetRequest(
|
|
name=endpoint_ref.RelativeName()
|
|
)
|
|
return self.client.projects_locations_endpoints.Get(req)
|
|
|
|
def List(self, location_ref, filter_str=None, gdc_zone=None):
|
|
"""Lists endpoints in the project."""
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsListRequest(
|
|
parent=location_ref.RelativeName(),
|
|
filter=filter_str,
|
|
gdcZone=gdc_zone,
|
|
)
|
|
return list_pager.YieldFromList(
|
|
self.client.projects_locations_endpoints,
|
|
req,
|
|
field='endpoints',
|
|
batch_size_attribute='pageSize',
|
|
)
|
|
|
|
  def Patch(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint using v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    # Build a sparse Endpoint message; update_mask records exactly which
    # fields the server should modify.
    endpoint = api_util.GetMessage('Endpoint', constants.GA_VERSION)()
    update_mask = []

    if labels_update.needs_update:
      endpoint.labels = labels_update.labels
      update_mask.append('labels')

    if display_name is not None:
      endpoint.displayName = display_name
      update_mask.append('display_name')

    if traffic_split is not None:
      # Sorted iteration keeps the encoded message deterministic.
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            endpoint.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      endpoint.trafficSplit = endpoint.TrafficSplitValue(
          additionalProperties=additional_properties
      )
      update_mask.append('traffic_split')

    if clear_traffic_split:
      # NOTE(review): this overwrites any trafficSplit set just above, and if
      # both flags are passed 'traffic_split' appears twice in the mask.
      endpoint.trafficSplit = None
      update_mask.append('traffic_split')

    if description is not None:
      endpoint.description = description
      update_mask.append('description')

    if (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    ):
      # Start from the endpoint's current logging config so the field not
      # being updated (table vs. rate) keeps its existing value.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not request_response_logging_config:
        request_response_logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.GA_VERSION
        )()
      request_response_logging_config.enabled = True
      if request_response_logging_table is not None:
        request_response_logging_config.bigqueryDestination = (
            api_util.GetMessage('BigQueryDestination', constants.GA_VERSION)(
                outputUri=request_response_logging_table
            )
        )
      if request_response_logging_rate is not None:
        request_response_logging_config.samplingRate = (
            request_response_logging_rate
        )
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if disable_request_response_logging:
      # Flip enabled off on the existing config (if any) rather than clearing
      # the whole config, so the destination/rate settings are preserved.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if request_response_logging_config:
        request_response_logging_config.enabled = False
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1Endpoint=endpoint,
        updateMask=','.join(update_mask),
    )
    return self.client.projects_locations_endpoints.Patch(req)
|
|
|
|
  def PatchBeta(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint using v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    # Build a sparse Endpoint message; update_mask records exactly which
    # fields the server should modify.
    endpoint = self.messages.GoogleCloudAiplatformV1beta1Endpoint()
    update_mask = []

    if labels_update.needs_update:
      endpoint.labels = labels_update.labels
      update_mask.append('labels')

    if display_name is not None:
      endpoint.displayName = display_name
      update_mask.append('display_name')

    if traffic_split is not None:
      # Sorted iteration keeps the encoded message deterministic.
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            endpoint.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      endpoint.trafficSplit = endpoint.TrafficSplitValue(
          additionalProperties=additional_properties
      )
      update_mask.append('traffic_split')

    if clear_traffic_split:
      # NOTE(review): this overwrites any trafficSplit set just above, and if
      # both flags are passed 'traffic_split' appears twice in the mask.
      endpoint.trafficSplit = None
      update_mask.append('traffic_split')

    if description is not None:
      endpoint.description = description
      update_mask.append('description')

    if (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    ):
      # Start from the endpoint's current logging config so the field not
      # being updated (table vs. rate) keeps its existing value.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not request_response_logging_config:
        request_response_logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
        )()
      request_response_logging_config.enabled = True
      if request_response_logging_table is not None:
        request_response_logging_config.bigqueryDestination = (
            api_util.GetMessage('BigQueryDestination', constants.BETA_VERSION)(
                outputUri=request_response_logging_table
            )
        )
      if request_response_logging_rate is not None:
        request_response_logging_config.samplingRate = (
            request_response_logging_rate
        )
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if disable_request_response_logging:
      # Flip enabled off on the existing config (if any) rather than clearing
      # the whole config, so the destination/rate settings are preserved.
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if request_response_logging_config:
        request_response_logging_config.enabled = False
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1Endpoint=endpoint,
        updateMask=','.join(update_mask),
    )
    return self.client.projects_locations_endpoints.Patch(req)
|
|
|
|
def Predict(self, endpoint_ref, instances_json):
|
|
"""Sends online prediction request to an endpoint using v1 API."""
|
|
predict_request = self.messages.GoogleCloudAiplatformV1PredictRequest(
|
|
instances=_ConvertPyListToMessageList(
|
|
extra_types.JsonValue, instances_json['instances']
|
|
)
|
|
)
|
|
if 'parameters' in instances_json:
|
|
predict_request.parameters = encoding.PyValueToMessage(
|
|
extra_types.JsonValue, instances_json['parameters']
|
|
)
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1PredictRequest=predict_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.Predict(req)
|
|
|
|
def PredictBeta(self, endpoint_ref, instances_json):
|
|
"""Sends online prediction request to an endpoint using v1beta1 API."""
|
|
predict_request = self.messages.GoogleCloudAiplatformV1beta1PredictRequest(
|
|
instances=_ConvertPyListToMessageList(
|
|
extra_types.JsonValue, instances_json['instances']
|
|
)
|
|
)
|
|
if 'parameters' in instances_json:
|
|
predict_request.parameters = encoding.PyValueToMessage(
|
|
extra_types.JsonValue, instances_json['parameters']
|
|
)
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1beta1PredictRequest=predict_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.Predict(req)
|
|
|
|
def RawPredict(self, endpoint_ref, headers, request):
|
|
"""Sends online raw prediction request to an endpoint."""
|
|
url = '{}{}/{}:rawPredict'.format(
|
|
self.client.url,
|
|
getattr(self.client, '_VERSION'),
|
|
endpoint_ref.RelativeName(),
|
|
)
|
|
|
|
status, response_headers, response = _DoHttpPost(url, headers, request)
|
|
if status != http_client.OK:
|
|
raise core_exceptions.Error(
|
|
'HTTP request failed. Response:\n' + response.decode()
|
|
)
|
|
|
|
return response_headers, response
|
|
|
|
def StreamRawPredict(self, endpoint_ref, headers, request):
|
|
"""Sends online raw prediction request to an endpoint."""
|
|
url = '{}{}/{}:streamRawPredict'.format(
|
|
self.client.url,
|
|
getattr(self.client, '_VERSION'),
|
|
endpoint_ref.RelativeName(),
|
|
)
|
|
|
|
for resp in _DoStreamHttpPost(url, headers, request):
|
|
yield resp
|
|
|
|
def DirectPredict(self, endpoint_ref, inputs_json):
|
|
"""Sends online direct prediction request to an endpoint using v1 API."""
|
|
direct_predict_request = (
|
|
self.messages.GoogleCloudAiplatformV1DirectPredictRequest(
|
|
inputs=_ConvertPyListToMessageList(
|
|
self.messages.GoogleCloudAiplatformV1Tensor,
|
|
inputs_json['inputs'],
|
|
)
|
|
)
|
|
)
|
|
if 'parameters' in inputs_json:
|
|
direct_predict_request.parameters = encoding.PyValueToMessage(
|
|
self.messages.GoogleCloudAiplatformV1Tensor, inputs_json['parameters']
|
|
)
|
|
|
|
req = (
|
|
self.messages.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1DirectPredictRequest=direct_predict_request,
|
|
)
|
|
)
|
|
return self.client.projects_locations_endpoints.DirectPredict(req)
|
|
|
|
def DirectPredictBeta(self, endpoint_ref, inputs_json):
|
|
"""Sends online direct prediction request to an endpoint using v1beta1 API."""
|
|
direct_predict_request = (
|
|
self.messages.GoogleCloudAiplatformV1beta1DirectPredictRequest(
|
|
inputs=_ConvertPyListToMessageList(
|
|
self.messages.GoogleCloudAiplatformV1beta1Tensor,
|
|
inputs_json['inputs'],
|
|
)
|
|
)
|
|
)
|
|
if 'parameters' in inputs_json:
|
|
direct_predict_request.parameters = encoding.PyValueToMessage(
|
|
self.messages.GoogleCloudAiplatformV1beta1Tensor,
|
|
inputs_json['parameters'],
|
|
)
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1beta1DirectPredictRequest=direct_predict_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.DirectPredict(req)
|
|
|
|
def DirectRawPredict(self, endpoint_ref, input_json):
|
|
"""Sends online direct raw prediction request to an endpoint using v1 API."""
|
|
direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1DirectRawPredictRequest(
|
|
input=bytes(input_json['input'], 'utf-8'),
|
|
# Method name can be "methodName" or "method_name"
|
|
methodName=input_json.get('methodName', input_json.get('method_name')),
|
|
)
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1DirectRawPredictRequest=direct_raw_predict_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.DirectRawPredict(req)
|
|
|
|
def DirectRawPredictBeta(self, endpoint_ref, input_json):
|
|
"""Sends online direct raw prediction request to an endpoint using v1beta1 API."""
|
|
direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1beta1DirectRawPredictRequest(
|
|
input=bytes(input_json['input'], 'utf-8'),
|
|
# Method name can be "methodName" or "method_name"
|
|
methodName=input_json.get('methodName', input_json.get('method_name')),
|
|
)
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1beta1DirectRawPredictRequest=direct_raw_predict_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.DirectRawPredict(req)
|
|
|
|
  def Explain(self, endpoint_ref, instances_json, args):
    """Sends online explanation request to an endpoint using v1 API.

    (The previous docstring said "v1beta1", but this method builds v1
    messages — GoogleCloudAiplatformV1ExplainRequest — so it is the v1 path.)

    Args:
      endpoint_ref: Resource, the parsed endpoint to query.
      instances_json: dict with an 'instances' list and optional 'parameters'.
      args: parsed command args; only deployed_model_id is read here.

    Returns:
      The Explain response message.
    """
    explain_request = self.messages.GoogleCloudAiplatformV1ExplainRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    # 'parameters' is optional; attach it only when the caller supplied it.
    if 'parameters' in instances_json:
      explain_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    if args.deployed_model_id is not None:
      explain_request.deployedModelId = args.deployed_model_id

    req = self.messages.AiplatformProjectsLocationsEndpointsExplainRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1ExplainRequest=explain_request,
    )
    return self.client.projects_locations_endpoints.Explain(req)
|
|
|
|
def ExplainBeta(self, endpoint_ref, instances_json, args):
|
|
"""Sends online explanation request to an endpoint using v1beta1 API."""
|
|
explain_request = self.messages.GoogleCloudAiplatformV1beta1ExplainRequest(
|
|
instances=_ConvertPyListToMessageList(
|
|
extra_types.JsonValue, instances_json['instances']
|
|
)
|
|
)
|
|
if 'parameters' in instances_json:
|
|
explain_request.parameters = encoding.PyValueToMessage(
|
|
extra_types.JsonValue, instances_json['parameters']
|
|
)
|
|
if 'explanation_spec_override' in instances_json:
|
|
explain_request.explanationSpecOverride = encoding.PyValueToMessage(
|
|
self.messages.GoogleCloudAiplatformV1beta1ExplanationSpecOverride,
|
|
instances_json['explanation_spec_override'],
|
|
)
|
|
if args.deployed_model_id is not None:
|
|
explain_request.deployedModelId = args.deployed_model_id
|
|
|
|
req = self.messages.AiplatformProjectsLocationsEndpointsExplainRequest(
|
|
endpoint=endpoint_ref.RelativeName(),
|
|
googleCloudAiplatformV1beta1ExplainRequest=explain_request,
|
|
)
|
|
return self.client.projects_locations_endpoints.Explain(req)
|
|
|
|
  def DeployModel(
      self,
      endpoint_ref,
      model,
      region,
      display_name,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      required_replica_count=None,
      reservation_affinity=None,
      autoscaling_metric_specs=None,
      spot=False,
      enable_access_logging=False,
      disable_container_logging=False,
      service_account=None,
      traffic_split=None,
      deployed_model_id=None,
      gpu_partition_size=None,
  ):
    """Deploys a model to an existing endpoint using v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint that the model is deployed to.
      model: str, Id of the uploaded model to be deployed.
      region: str, the location of the endpoint and the model.
      display_name: str, the display name of the new deployed model.
      machine_type: str or None, the type of the machine to serve the model.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica for
        multihost GPU deployments.
      accelerator_dict: dict or None, the accelerator attached to the deployed
        model from args.
      min_replica_count: int or None, the minimum number of replicas the
        deployed model will be always deployed on.
      max_replica_count: int or None, the maximum number of replicas the
        deployed model may be deployed on.
      required_replica_count: int or None, the required number of replicas the
        deployed model will be considered successfully deployed.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model can
        use.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      spot: bool, whether or not deploy the model on spot resources.
      enable_access_logging: bool, whether or not enable access logs.
      disable_container_logging: bool, whether or not disable container logging.
      service_account: str or None, the service account that the deployed model
        runs as.
      traffic_split: dict or None, the new traffic split of the endpoint.
      deployed_model_id: str or None, id of the deployed model.
      gpu_partition_size: str or None, the partition size of the GPU
        accelerator.

    Returns:
      A long-running operation for DeployModel.
    """
    model_ref = _ParseModel(model, region)

    # The model declares which deployment resource classes it supports; that
    # choice decides whether we build dedicated or automatic resources below.
    resource_type = _GetModelDeploymentResourceType(model_ref, self.client)
    if resource_type == 'DEDICATED_RESOURCES':
      # dedicated resources
      machine_spec = self.messages.GoogleCloudAiplatformV1MachineSpec()
      if machine_type is not None:
        machine_spec.machineType = machine_type
      if tpu_topology is not None:
        machine_spec.tpuTopology = tpu_topology
      if multihost_gpu_node_count is not None:
        machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
      accelerator = flags.ParseAcceleratorFlag(
          accelerator_dict, constants.GA_VERSION
      )
      if accelerator is not None:
        machine_spec.acceleratorType = accelerator.acceleratorType
        machine_spec.acceleratorCount = accelerator.acceleratorCount
      if reservation_affinity is not None:
        machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
            reservation_affinity, constants.GA_VERSION
        )
      if gpu_partition_size is not None:
        machine_spec.gpuPartitionSize = gpu_partition_size

      dedicated = self.messages.GoogleCloudAiplatformV1DedicatedResources(
          machineSpec=machine_spec, spot=spot
      )
      # min-replica-count is required and must be >= 1 if models use dedicated
      # resources. Default to 1 if not specified.
      dedicated.minReplicaCount = min_replica_count or 1
      if max_replica_count is not None:
        dedicated.maxReplicaCount = max_replica_count
      if required_replica_count is not None:
        dedicated.requiredReplicaCount = required_replica_count

      if autoscaling_metric_specs is not None:
        # Sorted iteration keeps the encoded message deterministic; metric
        # names are mapped from flag spellings to API spellings.
        autoscaling_metric_specs_list = []
        for name, target in sorted(autoscaling_metric_specs.items()):
          autoscaling_metric_specs_list.append(
              self.messages.GoogleCloudAiplatformV1AutoscalingMetricSpec(
                  metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[name],
                  target=target,
              )
          )
        dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list

      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          dedicatedResources=dedicated,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )
    else:
      # automatic resources
      automatic = self.messages.GoogleCloudAiplatformV1AutomaticResources()
      if min_replica_count is not None:
        automatic.minReplicaCount = min_replica_count
      if max_replica_count is not None:
        automatic.maxReplicaCount = max_replica_count

      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          automaticResources=automatic,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )

    deployed_model.enableAccessLogging = enable_access_logging
    deployed_model.disableContainerLogging = disable_container_logging

    if service_account is not None:
      deployed_model.serviceAccount = service_account

    if deployed_model_id is not None:
      deployed_model.id = deployed_model_id

    deployed_model_req = (
        self.messages.GoogleCloudAiplatformV1DeployModelRequest(
            deployedModel=deployed_model
        )
    )

    # The traffic split lives on the DeployModelRequest (not the endpoint),
    # and is only sent when the caller provided one.
    if traffic_split is not None:
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            deployed_model_req.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      deployed_model_req.trafficSplit = deployed_model_req.TrafficSplitValue(
          additionalProperties=additional_properties
      )

    req = self.messages.AiplatformProjectsLocationsEndpointsDeployModelRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DeployModelRequest=deployed_model_req,
    )
    return self.client.projects_locations_endpoints.DeployModel(req)
|
|
|
|
def DeployModelBeta(
    self,
    endpoint_ref,
    model,
    region,
    display_name,
    machine_type=None,
    tpu_topology=None,
    multihost_gpu_node_count=None,
    accelerator_dict=None,
    min_replica_count=None,
    max_replica_count=None,
    required_replica_count=None,
    reservation_affinity=None,
    autoscaling_metric_specs=None,
    spot=False,
    enable_access_logging=False,
    enable_container_logging=False,
    service_account=None,
    traffic_split=None,
    deployed_model_id=None,
    shared_resources_ref=None,
    min_scaleup_period=None,
    idle_scaledown_period=None,
    initial_replica_count=None,
    gpu_partition_size=None,
):
  """Deploys a model to an existing endpoint using v1beta1 API.

  Args:
    endpoint_ref: Resource, the parsed endpoint that the model is deployed to.
    model: str, Id of the uploaded model to be deployed.
    region: str, the location of the endpoint and the model.
    display_name: str, the display name of the new deployed model.
    machine_type: str or None, the type of the machine to serve the model.
    tpu_topology: str or None, the topology of the TPU to serve the model.
    multihost_gpu_node_count: int or None, the number of nodes per replica for
      multihost GPU deployments.
    accelerator_dict: dict or None, the accelerator attached to the deployed
      model from args.
    min_replica_count: int or None, the minimum number of replicas the
      deployed model will be always deployed on.
    max_replica_count: int or None, the maximum number of replicas the
      deployed model may be deployed on.
    required_replica_count: int or None, the required number of replicas the
      deployed model will be considered successfully deployed.
    reservation_affinity: dict or None, the reservation affinity of the
      deployed model which specifies which reservations the deployed model can
      use.
    autoscaling_metric_specs: dict or None, the metric specification that
      defines the target resource utilization for calculating the desired
      replica count.
    spot: bool, whether or not deploy the model on spot resources.
    enable_access_logging: bool, whether or not enable access logs.
    enable_container_logging: bool, whether or not enable container logging.
    service_account: str or None, the service account that the deployed model
      runs as.
    traffic_split: dict or None, the new traffic split of the endpoint.
    deployed_model_id: str or None, id of the deployed model.
    shared_resources_ref: str or None, the shared deployment resource pool the
      model should use
    min_scaleup_period: str or None, the minimum duration (in seconds) that a
      deployment will be scaled up before traffic is evaluated for potential
      scale-down. Defaults to 1 hour if min replica count is 0.
    idle_scaledown_period: str or None, the duration after which the
      deployment is scaled down if no traffic is received. This only applies
      to deployments enrolled in scale-to-zero.
    initial_replica_count: int or None, the initial number of replicas the
      deployment will be scaled up to. This only applies to deployments
      enrolled in scale-to-zero.
    gpu_partition_size: str or None, the partition size of the GPU
      accelerator.

  Returns:
    A long-running operation for DeployModel.
  """
  is_gdc_ggs_model = _CheckIsGdcGgsModel(self, endpoint_ref)
  if is_gdc_ggs_model:
    # Send pseudo dedicated resources for a GDC GGS model: the resource values
    # are fixed placeholders, and the model id goes into gdcConnectedModel
    # rather than the regular `model` field.
    machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
        machineType='n1-standard-2',
        acceleratorType=self.messages.GoogleCloudAiplatformV1beta1MachineSpec.AcceleratorTypeValueValuesEnum.NVIDIA_TESLA_T4,
        acceleratorCount=1,
    )
    dedicated = self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
        machineSpec=machine_spec, minReplicaCount=1, maxReplicaCount=1
    )
    deployed_model = self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
        dedicatedResources=dedicated,
        displayName=display_name,
        gdcConnectedModel=model,
    )
  else:
    model_ref = _ParseModel(model, region)
    # The model (or the shared resource pool) determines which of the three
    # resource flavors the DeployedModel message must carry.
    resource_type = _GetModelDeploymentResourceType(
        model_ref, self.client, shared_resources_ref
    )
    if resource_type == 'DEDICATED_RESOURCES':
      # dedicated resources
      machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec()
      if machine_type is not None:
        machine_spec.machineType = machine_type
      if tpu_topology is not None:
        machine_spec.tpuTopology = tpu_topology
      if multihost_gpu_node_count is not None:
        machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
      accelerator = flags.ParseAcceleratorFlag(
          accelerator_dict, constants.BETA_VERSION
      )
      if accelerator is not None:
        machine_spec.acceleratorType = accelerator.acceleratorType
        machine_spec.acceleratorCount = accelerator.acceleratorCount
      if reservation_affinity is not None:
        machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
            reservation_affinity, constants.BETA_VERSION
        )
      if gpu_partition_size is not None:
        machine_spec.gpuPartitionSize = gpu_partition_size

      dedicated = (
          self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
              machineSpec=machine_spec, spot=spot
          )
      )
      # min-replica-count is required and must be >= 0 if models use dedicated
      # resources. If value is 0, the deployment will be enrolled in the
      # scale-to-zero feature. Default to 1 if not specified.
      dedicated.minReplicaCount = (
          1 if min_replica_count is None else min_replica_count
      )

      # max-replica-count: if not specified and min-replica-count is 0
      # (scale-to-zero), default to 1; otherwise only set it when given.
      if max_replica_count is None and dedicated.minReplicaCount == 0:
        dedicated.maxReplicaCount = 1
      else:
        if max_replica_count is not None:
          dedicated.maxReplicaCount = max_replica_count
      if required_replica_count is not None:
        dedicated.requiredReplicaCount = required_replica_count

      # initial-replica-count: same defaulting rule as max-replica-count —
      # if not specified and min-replica-count is 0, default to 1.
      if initial_replica_count is None and dedicated.minReplicaCount == 0:
        dedicated.initialReplicaCount = 1
      else:
        if initial_replica_count is not None:
          dedicated.initialReplicaCount = initial_replica_count

      if autoscaling_metric_specs is not None:
        # Sorted for a deterministic request body; the user-facing metric
        # names are mapped to the API metric names via the constants mapper.
        autoscaling_metric_specs_list = []
        for name, target in sorted(autoscaling_metric_specs.items()):
          autoscaling_metric_specs_list.append(
              self.messages.GoogleCloudAiplatformV1beta1AutoscalingMetricSpec(
                  metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[
                      name
                  ],
                  target=target,
              )
          )
        dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list

      # Scale-to-zero tuning: only attach the spec when at least one of the
      # two periods was supplied, so an empty message is never sent.
      stz_spec = (
          self.messages.GoogleCloudAiplatformV1beta1DedicatedResourcesScaleToZeroSpec()
      )
      stz_spec_modified = False
      if min_scaleup_period is not None:
        # API expects a Duration string, e.g. '3600s'.
        stz_spec.minScaleupPeriod = '{}s'.format(min_scaleup_period)
        stz_spec_modified = True
      if idle_scaledown_period is not None:
        stz_spec.idleScaledownPeriod = '{}s'.format(idle_scaledown_period)
        stz_spec_modified = True

      if stz_spec_modified:
        dedicated.scaleToZeroSpec = stz_spec

      deployed_model = (
          self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
              dedicatedResources=dedicated,
              displayName=display_name,
              model=model_ref.RelativeName(),
          )
      )
    elif resource_type == 'AUTOMATIC_RESOURCES':
      # automatic resources
      automatic = (
          self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
      )
      if min_replica_count is not None:
        automatic.minReplicaCount = min_replica_count
      if max_replica_count is not None:
        automatic.maxReplicaCount = max_replica_count

      deployed_model = (
          self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
              automaticResources=automatic,
              displayName=display_name,
              model=model_ref.RelativeName(),
          )
      )
    # if resource type is SHARED_RESOURCES
    else:
      deployed_model = (
          self.messages.GoogleCloudAiplatformV1beta1DeployedModel(
              displayName=display_name,
              model=model_ref.RelativeName(),
              sharedResources=shared_resources_ref.RelativeName(),
          )
      )

  deployed_model.enableAccessLogging = enable_access_logging
  # NOTE(review): the v1beta1 surface exposes enableContainerLogging, whereas
  # the v1 path in this file sets disableContainerLogging — the polarity
  # difference is intentional, not a bug.
  deployed_model.enableContainerLogging = enable_container_logging

  if service_account is not None:
    deployed_model.serviceAccount = service_account

  if deployed_model_id is not None:
    deployed_model.id = deployed_model_id

  deployed_model_req = (
      self.messages.GoogleCloudAiplatformV1beta1DeployModelRequest(
          deployedModel=deployed_model
      )
  )

  if traffic_split is not None:
    # Encode the {id: percentage} dict as the proto map's
    # additionalProperties list, sorted for a deterministic request.
    additional_properties = []
    for key, value in sorted(traffic_split.items()):
      additional_properties.append(
          deployed_model_req.TrafficSplitValue().AdditionalProperty(
              key=key, value=value
          )
      )
    deployed_model_req.trafficSplit = deployed_model_req.TrafficSplitValue(
        additionalProperties=additional_properties
    )

  req = self.messages.AiplatformProjectsLocationsEndpointsDeployModelRequest(
      endpoint=endpoint_ref.RelativeName(),
      googleCloudAiplatformV1beta1DeployModelRequest=deployed_model_req,
  )
  return self.client.projects_locations_endpoints.DeployModel(req)
def UndeployModel(self, endpoint_ref, deployed_model_id, traffic_split=None):
  """Undeploys a model from an endpoint using v1 API.

  Args:
    endpoint_ref: Resource, the parsed endpoint that the model is undeployed
      from.
    deployed_model_id: str, Id of the deployed model to be undeployed.
    traffic_split: dict or None, the new traffic split of the endpoint.

  Returns:
    A long-running operation for UndeployModel.
  """
  undeploy_req = self.messages.GoogleCloudAiplatformV1UndeployModelRequest(
      deployedModelId=deployed_model_id
  )

  if traffic_split is not None:
    # Re-encode the {deployed_model_id: percentage} mapping as the proto
    # map's additionalProperties list, in deterministic (sorted) key order.
    split_msg = undeploy_req.TrafficSplitValue
    undeploy_req.trafficSplit = split_msg(
        additionalProperties=[
            split_msg.AdditionalProperty(key=k, value=v)
            for k, v in sorted(traffic_split.items())
        ]
    )

  request = (
      self.messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest(
          endpoint=endpoint_ref.RelativeName(),
          googleCloudAiplatformV1UndeployModelRequest=undeploy_req,
      )
  )
  return self.client.projects_locations_endpoints.UndeployModel(request)
def UndeployModelBeta(
    self, endpoint_ref, deployed_model_id, traffic_split=None
):
  """Undeploys a model from an endpoint using v1beta1 API.

  Args:
    endpoint_ref: Resource, the parsed endpoint that the model is undeployed
      from.
    deployed_model_id: str, Id of the deployed model to be undeployed.
    traffic_split: dict or None, the new traffic split of the endpoint.

  Returns:
    A long-running operation for UndeployModel.
  """
  undeploy_req = (
      self.messages.GoogleCloudAiplatformV1beta1UndeployModelRequest(
          deployedModelId=deployed_model_id
      )
  )

  if traffic_split is not None:
    # Re-encode the {deployed_model_id: percentage} mapping as the proto
    # map's additionalProperties list, in deterministic (sorted) key order.
    split_msg = undeploy_req.TrafficSplitValue
    props = [
        split_msg.AdditionalProperty(key=k, value=v)
        for k, v in sorted(traffic_split.items())
    ]
    undeploy_req.trafficSplit = split_msg(additionalProperties=props)

  request = self.messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest(
      endpoint=endpoint_ref.RelativeName(),
      googleCloudAiplatformV1beta1UndeployModelRequest=undeploy_req,
  )
  return self.client.projects_locations_endpoints.UndeployModel(request)