feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying custom jobs in AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class CustomJobsClient(object):
  """Client used for interacting with the CustomJob endpoint.

  Wraps the versioned (GA or beta) AI Platform apitools client and exposes
  helpers for creating, listing, getting, and cancelling custom jobs.
  """

  def __init__(self, version=constants.GA_VERSION):
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._version = version
    self._service = client.projects_locations_customJobs
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the API message class by name, or None when absent.

    Args:
      message_name: str, the unprefixed message name, e.g. 'CustomJob'.

    Returns:
      The message class named '<version prefix><message_name>' from the
      versioned messages module, or None if it is not defined there.
    """
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def CustomJobMessage(self):
    """Returns the CustomJob resource message class."""
    return self.GetMessage('CustomJob')

  def Create(self,
             parent,
             job_spec,
             display_name=None,
             kms_key_name=None,
             labels=None):
    """Constructs a request and sends it to the endpoint to create a custom job instance.

    Args:
      parent: str, The project resource path of the custom job to create.
      job_spec: The CustomJobSpec message instance for the job creation
        request.
      display_name: str, The display name of the custom job to create.
      kms_key_name: A customer-managed encryption key to use for the custom
        job.
      labels: LabelValues, map-like user-defined metadata to organize the
        custom jobs.

    Returns:
      A CustomJob message instance created.
    """
    custom_job = self.CustomJobMessage()(
        displayName=display_name, jobSpec=job_spec)
    if kms_key_name is not None:
      custom_job.encryptionSpec = self.GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      custom_job.labels = labels
    # The request field carrying the job payload is version-specific, so the
    # branch below selects the matching keyword argument.
    if self._version == constants.BETA_VERSION:
      return self._service.Create(
          self._messages.AiplatformProjectsLocationsCustomJobsCreateRequest(
              parent=parent, googleCloudAiplatformV1beta1CustomJob=custom_job))
    else:
      return self._service.Create(
          self._messages.AiplatformProjectsLocationsCustomJobsCreateRequest(
              parent=parent, googleCloudAiplatformV1CustomJob=custom_job))

  def List(self, limit=None, region=None):
    """Yields custom jobs under the given region, up to limit results."""
    return list_pager.YieldFromList(
        self._service,
        self._messages.AiplatformProjectsLocationsCustomJobsListRequest(
            parent=region),
        field='customJobs',
        batch_size_attribute='pageSize',
        limit=limit)

  def Get(self, name):
    """Fetches a single custom job by its full resource name."""
    request = self._messages.AiplatformProjectsLocationsCustomJobsGetRequest(
        name=name)
    return self._service.Get(request)

  def Cancel(self, name):
    """Requests cancellation of the custom job with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsCustomJobsCancelRequest(
        name=name)
    return self._service.Cancel(request)

  def CheckJobComplete(self, name):
    """Returns a function to decide if log fetcher should continue polling.

    Args:
      name: String id of job.

    Returns:
      A one-argument function decides if log fetcher should continue.
    """
    # NOTE: the job state is fetched once here; the returned closure evaluates
    # endTime from that snapshot rather than re-querying the service.
    request = self._messages.AiplatformProjectsLocationsCustomJobsGetRequest(
        name=name)
    response = self._service.Get(request)

    def ShouldContinue(periods_without_logs):
      # Always poll at least twice before consulting job completion.
      if periods_without_logs <= 1:
        return True
      return response.endTime is None

    return ShouldContinue

  def ImportResourceMessage(self, yaml_file, message_name):
    """Import a messages class instance typed by name from a YAML file."""
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    message_type = self.GetMessage(message_name)
    return export_util.Import(message_type=message_type, stream=data)

View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform deployment resource pools API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags
class DeploymentResourcePoolsClient(object):
  """High-level client for the AI Platform deployment resource pools surface."""

  def __init__(self, client=None, messages=None, version=None):
    # NOTE(review): when no client is injected, `version` is used as a key
    # into constants.AI_PLATFORM_API_VERSION, so a None version would fail
    # there — presumably callers always pass one; confirm at call sites.
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version]
    )
    self.messages = messages or self.client.MESSAGES_MODULE

  def CreateBeta(
      self,
      location_ref,
      deployment_resource_pool_id,
      autoscaling_metric_specs=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      reservation_affinity=None,
      spot=False,
      required_replica_count=None,
  ):
    """Creates a new deployment resource pool using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to create a deployment
        resource pool.
      deployment_resource_pool_id: str, The ID to use for the
        DeploymentResourcePool, which will become the final component of the
        DeploymentResourcePool's resource name.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      accelerator_dict: dict or None, the accelerator attached to the
        deployment resource pool from args.
      min_replica_count: int or None, The minimum number of machine replicas
        this deployment resource pool will be always deployed on. This value
        must be greater than or equal to 1.
      max_replica_count: int or None, The maximum number of replicas this
        deployment resource pool may be deployed on when the traffic against
        it increases.
      machine_type: str or None, Immutable. The type of the machine.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica
        for multihost GPU deployments.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model
        can use.
      spot: bool, whether or not deploy the model on spot resources.
      required_replica_count: int or None, The required number of replicas
        this deployment resource pool will be considered successfully
        deployed. This value must be greater than or equal to 1 and less than
        or equal to min_replica_count.

    Returns:
      A long-running operation for Create.
    """
    # Build the machine spec field by field; fields left unset keep their
    # proto defaults.
    machine_spec = self.messages.GoogleCloudAiplatformV1beta1MachineSpec()
    if machine_type is not None:
      machine_spec.machineType = machine_type
    if tpu_topology is not None:
      machine_spec.tpuTopology = tpu_topology
    if multihost_gpu_node_count is not None:
      machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
    # The flag parser accepts None and may itself return None, in which case
    # no accelerator fields are set on the machine spec.
    accelerator = flags.ParseAcceleratorFlag(
        accelerator_dict, constants.BETA_VERSION
    )
    if accelerator is not None:
      machine_spec.acceleratorType = accelerator.acceleratorType
      machine_spec.acceleratorCount = accelerator.acceleratorCount
    if reservation_affinity is not None:
      machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
          reservation_affinity, constants.BETA_VERSION
      )
    dedicated = self.messages.GoogleCloudAiplatformV1beta1DedicatedResources(
        machineSpec=machine_spec, spot=spot
    )
    # Per the Args contract above, min replicas must be >= 1; default when
    # the caller leaves it unset (or passes 0).
    dedicated.minReplicaCount = min_replica_count or 1
    if max_replica_count is not None:
      dedicated.maxReplicaCount = max_replica_count
    if required_replica_count is not None:
      dedicated.requiredReplicaCount = required_replica_count
    if autoscaling_metric_specs is not None:
      autoscaling_metric_specs_list = []
      # Sort by metric name for a deterministic spec order regardless of
      # input dict ordering.
      for name, target in sorted(autoscaling_metric_specs.items()):
        autoscaling_metric_specs_list.append(
            self.messages.GoogleCloudAiplatformV1beta1AutoscalingMetricSpec(
                metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[name],
                target=target
            )
        )
      dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list
    pool = self.messages.GoogleCloudAiplatformV1beta1DeploymentResourcePool(
        dedicatedResources=dedicated
    )
    pool_request = self.messages.GoogleCloudAiplatformV1beta1CreateDeploymentResourcePoolRequest(
        deploymentResourcePool=pool,
        deploymentResourcePoolId=deployment_resource_pool_id
    )
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1CreateDeploymentResourcePoolRequest=pool_request
    )
    operation = self.client.projects_locations_deploymentResourcePools.Create(
        req
    )
    return operation

  def DeleteBeta(self, deployment_resource_pool_ref):
    """Deletes a deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, The deployment resource pool to
        delete.

    Returns:
      A GoogleProtobufEmpty response message for delete.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsDeleteRequest(
        name=deployment_resource_pool_ref.RelativeName()
    )
    operation = self.client.projects_locations_deploymentResourcePools.Delete(
        req
    )
    return operation

  def DescribeBeta(self, deployment_resource_pool_ref):
    """Describes a deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, Deployment resource pool to describe.

    Returns:
      GoogleCloudAiplatformV1beta1DeploymentResourcePool response message.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsGetRequest(
        name=deployment_resource_pool_ref.RelativeName()
    )
    response = self.client.projects_locations_deploymentResourcePools.Get(req)
    return response

  def ListBeta(self, location_ref):
    """Lists deployment resource pools using v1beta1 API.

    Args:
      location_ref: Resource, the parsed location to list deployment resource
        pools.

    Returns:
      Nested attribute containing list of deployment resource pools.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsListRequest(
        parent=location_ref.RelativeName()
    )
    return list_pager.YieldFromList(
        self.client.projects_locations_deploymentResourcePools,
        req,
        field='deploymentResourcePools',
        batch_size_attribute='pageSize'
    )

  def QueryDeployedModelsBeta(self, deployment_resource_pool_ref):
    """Queries deployed models sharing a specified deployment resource pool using v1beta1 API.

    Args:
      deployment_resource_pool_ref: str, Deployment resource pool to query.

    Returns:
      GoogleCloudAiplatformV1beta1QueryDeployedModelsResponse message.
    """
    req = self.messages.AiplatformProjectsLocationsDeploymentResourcePoolsQueryDeployedModelsRequest(
        deploymentResourcePool=deployment_resource_pool_ref.RelativeName()
    )
    response = self.client.projects_locations_deploymentResourcePools.QueryDeployedModels(
        req
    )
    return response

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library for streaming prediction results from the Vertex AI PredictionService API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
from googlecloudsdk.api_lib.util import apis
class PredictionStreamer(object):
  """Streams prediction responses from the PredictionService using gRPC."""

  def __init__(self, version):
    # A GAPIC (gRPC) client is required here; streaming RPCs are not exposed
    # through the apitools surface.
    self.client = apis.GetGapicClientInstance('aiplatform', version)

  def StreamDirectPredict(
      self,
      endpoint,
      inputs,
      parameters,
  ):
    """Streams prediction results from the Cloud Vertex AI PredictionService API.

    Args:
      endpoint: The name of the endpoint to stream predictions from.
      inputs: The inputs to send to the endpoint.
      parameters: The parameters to send to the endpoint.

    Yields:
      Streamed prediction results.
    """
    types = self.client.types
    request = types.StreamDirectPredictRequest(endpoint=endpoint)
    # Tensors are constructed by round-tripping each value through JSON.
    for tensor_input in inputs:
      request.inputs.append(types.Tensor.from_json(json.dumps(tensor_input)))
    request.parameters = types.Tensor.from_json(json.dumps(parameters))
    yield from self.client.prediction.stream_direct_predict(iter([request]))

  def StreamDirectRawPredict(
      self,
      endpoint,
      method_name,
      input,
  ):
    """Streams prediction results from the Cloud Vertex AI PredictionService API.

    Args:
      endpoint: The name of the endpoint to stream predictions from.
      method_name: The name of the method to call.
      input: The input bytes to send to the endpoint.

    Yields:
      Streamed prediction results.
    """
    request = self.client.types.StreamDirectRawPredictRequest(
        endpoint=endpoint, method_name=method_name, input=input
    )
    yield from self.client.prediction.stream_direct_raw_predict(
        iter([request])
    )

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying hptuning-jobs in AI platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.core import yaml
def GetAlgorithmEnum(version=constants.BETA_VERSION):
  """Returns the StudySpec algorithm enum for the given API version."""
  messages = apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
  # Select the version-specific StudySpec message; both expose the same
  # nested AlgorithmValueValuesEnum.
  study_spec = (
      messages.GoogleCloudAiplatformV1StudySpec
      if version == constants.GA_VERSION
      else messages.GoogleCloudAiplatformV1beta1StudySpec)
  return study_spec.AlgorithmValueValuesEnum
class HpTuningJobsClient(object):
  """Client used for interacting with HyperparameterTuningJob endpoint.

  Wraps the versioned AI Platform apitools client and exposes helpers for
  creating, listing, getting, and cancelling hyperparameter tuning jobs.
  """

  def __init__(self, version):
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._service = client.projects_locations_hyperparameterTuningJobs
    self.version = version
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def _GetMessage(self, message_name):
    """Returns the API messages class by name, or None when absent.

    Args:
      message_name: str, the unprefixed message name, e.g.
        'HyperparameterTuningJob'.

    Returns:
      The message class named '<version prefix><message_name>' from the
      versioned messages module, or None if it is not defined there.
    """
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def HyperparameterTuningJobMessage(self):
    """Returns the HyperparameterTuningJob resource message."""
    return self._GetMessage('HyperparameterTuningJob')

  def AlgorithmEnum(self):
    """Returns enum message representing Algorithm."""
    return self._GetMessage('StudySpec').AlgorithmValueValuesEnum

  def Create(
      self,
      config_path,
      display_name,
      parent=None,
      max_trial_count=None,
      parallel_trial_count=None,
      algorithm=None,
      kms_key_name=None,
      network=None,
      service_account=None,
      enable_web_access=False,
      enable_dashboard_access=False,
      labels=None):
    """Creates a hyperparameter tuning job with given parameters.

    Args:
      config_path: str, the file path of the hyperparameter tuning job
        configuration.
      display_name: str, the display name of the created hyperparameter
        tuning job.
      parent: str, parent of the created hyperparameter tuning job. e.g.
        /projects/xxx/locations/xxx/
      max_trial_count: int, the desired total number of Trials. The default
        value is 1.
      parallel_trial_count: int, the desired number of Trials to run in
        parallel. The default value is 1.
      algorithm: AlgorithmValueValuesEnum, the search algorithm specified for
        the Study.
      kms_key_name: str, A customer-managed encryption key to use for the
        hyperparameter tuning job.
      network: str, user network to which the job should be peered with
        (overrides yaml file)
      service_account: str, A service account (email address string) to use
        for the job.
      enable_web_access: bool, Whether to enable the interactive shell for
        the job.
      enable_dashboard_access: bool, Whether to enable the dashboard defined
        for the job.
      labels: LabelsValues, map-like user-defined metadata to organize the
        hp-tuning jobs.

    Returns:
      Created hyperparameter tuning job.
    """
    # Start from the YAML config (when given); flags below override it.
    job_spec = self.HyperparameterTuningJobMessage()
    if config_path:
      data = yaml.load_path(config_path)
      if data:
        job_spec = messages_util.DictToMessageWithErrorCheck(
            data, self.HyperparameterTuningJobMessage())
    # Default trial counts to 1 when neither the config nor the flag set them.
    if not job_spec.maxTrialCount and not max_trial_count:
      job_spec.maxTrialCount = 1
    elif max_trial_count:
      job_spec.maxTrialCount = max_trial_count
    if not job_spec.parallelTrialCount and not parallel_trial_count:
      job_spec.parallelTrialCount = 1
    elif parallel_trial_count:
      job_spec.parallelTrialCount = parallel_trial_count
    # NOTE(review): the writes below assume trialJobSpec was populated from
    # the YAML config — confirm upstream validation guarantees this.
    if network:
      job_spec.trialJobSpec.network = network
    if service_account:
      job_spec.trialJobSpec.serviceAccount = service_account
    if enable_web_access:
      job_spec.trialJobSpec.enableWebAccess = enable_web_access
    if enable_dashboard_access:
      job_spec.trialJobSpec.enableDashboardAccess = enable_dashboard_access
    if display_name:
      job_spec.displayName = display_name
    # The algorithm flag only applies when the config defined a study spec.
    if algorithm and job_spec.studySpec:
      job_spec.studySpec.algorithm = algorithm
    if kms_key_name is not None:
      job_spec.encryptionSpec = self._GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      job_spec.labels = labels
    # The request field carrying the job payload is version-specific.
    if self.version == constants.GA_VERSION:
      request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCreateRequest(
          parent=parent,
          googleCloudAiplatformV1HyperparameterTuningJob=job_spec)
    else:
      request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCreateRequest(
          parent=parent,
          googleCloudAiplatformV1beta1HyperparameterTuningJob=job_spec)
    return self._service.Create(request)

  def Get(self, name=None):
    """Fetches a single hp-tuning job by its full resource name."""
    request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsGetRequest(
        name=name)
    return self._service.Get(request)

  def Cancel(self, name=None):
    """Requests cancellation of the hp-tuning job with the given name."""
    request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsCancelRequest(
        name=name)
    return self._service.Cancel(request)

  def List(self, limit=None, region=None):
    """Yields hp-tuning jobs under the given region, up to limit results."""
    return list_pager.YieldFromList(
        self._service,
        self._messages
        .AiplatformProjectsLocationsHyperparameterTuningJobsListRequest(
            parent=region),
        field='hyperparameterTuningJobs',
        batch_size_attribute='pageSize',
        limit=limit)

  def CheckJobComplete(self, name):
    """Returns a function to decide if log fetcher should continue polling.

    Args:
      name: String id of job.

    Returns:
      A one-argument function decides if log fetcher should continue.
    """
    # NOTE: the job state is fetched once here; the returned closure evaluates
    # endTime from that snapshot rather than re-querying the service.
    request = self._messages.AiplatformProjectsLocationsHyperparameterTuningJobsGetRequest(
        name=name)
    response = self._service.Get(request)

    def ShouldContinue(periods_without_logs):
      # Always poll at least twice before consulting job completion.
      if periods_without_logs <= 1:
        return True
      return response.endTime is None

    return ShouldContinue

View File

@@ -0,0 +1,518 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform index endpoints API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def _ParseIndex(index_id, location_id):
  """Parses an index ID into an index resource object."""
  # The project is resolved lazily from the active gcloud configuration.
  parse_params = {
      'locationsId': location_id,
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      index_id,
      params=parse_params,
      collection='aiplatform.projects.locations.indexes')
class IndexEndpointsClient(object):
  """High-level client for the AI Platform index endpoints surface."""

  def __init__(self, client=None, messages=None, version=constants.GA_VERSION):
    # Fall back to the shared apitools client for the requested API version
    # when no client/messages are injected (injection is used for testing).
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_indexEndpoints
def CreateBeta(self, location_ref, args):
  """Creates a new v1beta1 index endpoint.

  Args:
    location_ref: Resource, the parsed location in which to create the index
      endpoint.
    args: argparse.Namespace, the parsed create flags.

  Returns:
    The response message from the Create call.
  """
  labels = labels_util.ParseCreateArgs(
      args,
      self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint.LabelsValue)
  encryption_spec = None
  if args.encryption_kms_key_name is not None:
    encryption_spec = (
        self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
            kmsKeyName=args.encryption_kms_key_name))
  private_service_connect_config = None
  if args.enable_private_service_connect:
    private_service_connect_config = (
        self.messages.GoogleCloudAiplatformV1beta1PrivateServiceConnectConfig(
            enablePrivateServiceConnect=args.enable_private_service_connect,
            projectAllowlist=(args.project_allowlist
                              if args.project_allowlist else [])
        )
    )
  # BUGFIX(review): the source had an `elif` with no preceding `if`; the
  # guard below was restored so the three mutually exclusive endpoint shapes
  # (explicitly-requested public endpoint / VPC network peering / default
  # public endpoint) are selected as the branch bodies indicate.
  if args.public_endpoint_enabled:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            publicEndpointEnabled=args.public_endpoint_enabled,
            labels=labels,
            encryptionSpec=encryption_spec,
            privateServiceConnectConfig=private_service_connect_config,
        ),
    )
  elif args.network is not None:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            network=args.network,
            labels=labels,
            encryptionSpec=encryption_spec,
        ),
    )
  else:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1IndexEndpoint=self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            publicEndpointEnabled=True,
            labels=labels,
            encryptionSpec=encryption_spec,
            privateServiceConnectConfig=private_service_connect_config,
        ),
    )
  return self._service.Create(req)
def Create(self, location_ref, args):
  """Creates a new v1 index endpoint.

  Args:
    location_ref: Resource, the parsed location in which to create the index
      endpoint.
    args: argparse.Namespace, the parsed create flags.

  Returns:
    The response message from the Create call.
  """
  labels = labels_util.ParseCreateArgs(
      args, self.messages.GoogleCloudAiplatformV1IndexEndpoint.LabelsValue)
  encryption_spec = None
  if args.encryption_kms_key_name is not None:
    encryption_spec = (
        self.messages.GoogleCloudAiplatformV1EncryptionSpec(
            kmsKeyName=args.encryption_kms_key_name))
  private_service_connect_config = None
  if args.enable_private_service_connect:
    private_service_connect_config = (
        self.messages.GoogleCloudAiplatformV1PrivateServiceConnectConfig(
            enablePrivateServiceConnect=args.enable_private_service_connect,
            projectAllowlist=(args.project_allowlist
                              if args.project_allowlist else []),
        )
    )
  # BUGFIX(review): the source had an `elif` with no preceding `if`; the
  # guard below was restored to mirror the v1beta1 CreateBeta flow
  # (explicitly-requested public endpoint / VPC network peering / default
  # public endpoint).
  if args.public_endpoint_enabled:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            publicEndpointEnabled=args.public_endpoint_enabled,
            labels=labels,
            encryptionSpec=encryption_spec,
            privateServiceConnectConfig=private_service_connect_config,
        ),
    )
  elif args.network is not None:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            network=args.network,
            labels=labels,
            encryptionSpec=encryption_spec,
        ),
    )
  else:
    req = self.messages.AiplatformProjectsLocationsIndexEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1IndexEndpoint=self.messages.GoogleCloudAiplatformV1IndexEndpoint(
            displayName=args.display_name,
            description=args.description,
            publicEndpointEnabled=True,
            labels=labels,
            encryptionSpec=encryption_spec,
            privateServiceConnectConfig=private_service_connect_config,
        ),
    )
  return self._service.Create(req)
def PatchBeta(self, index_endpoint_ref, args):
  """Updates a v1beta1 index endpoint from the given patch flags."""
  endpoint_msg = self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint()
  mask_paths = []
  if args.display_name is not None:
    endpoint_msg.displayName = args.display_name
    mask_paths.append('display_name')
  if args.description is not None:
    endpoint_msg.description = args.description
    mask_paths.append('description')

  def _ExistingLabels():
    # Only fetched lazily when a label update actually needs the current set.
    return self.Get(index_endpoint_ref).labels

  labels_diff = labels_util.ProcessUpdateArgsLazy(
      args,
      self.messages.GoogleCloudAiplatformV1beta1IndexEndpoint.LabelsValue,
      _ExistingLabels)
  if labels_diff.needs_update:
    endpoint_msg.labels = labels_diff.labels
    mask_paths.append('labels')
  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')
  return self._service.Patch(
      self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
          name=index_endpoint_ref.RelativeName(),
          googleCloudAiplatformV1beta1IndexEndpoint=endpoint_msg,
          updateMask=','.join(mask_paths)))
def Patch(self, index_endpoint_ref, args):
  """Updates a v1 index endpoint from the given patch flags."""
  endpoint_msg = self.messages.GoogleCloudAiplatformV1IndexEndpoint()
  mask_paths = []
  if args.display_name is not None:
    endpoint_msg.displayName = args.display_name
    mask_paths.append('display_name')
  if args.description is not None:
    endpoint_msg.description = args.description
    mask_paths.append('description')

  def _ExistingLabels():
    # Only fetched lazily when a label update actually needs the current set.
    return self.Get(index_endpoint_ref).labels

  labels_diff = labels_util.ProcessUpdateArgsLazy(
      args, self.messages.GoogleCloudAiplatformV1IndexEndpoint.LabelsValue,
      _ExistingLabels)
  if labels_diff.needs_update:
    endpoint_msg.labels = labels_diff.labels
    mask_paths.append('labels')
  if not mask_paths:
    raise errors.NoFieldsSpecifiedError('No updates requested.')
  return self._service.Patch(
      self.messages.AiplatformProjectsLocationsIndexEndpointsPatchRequest(
          name=index_endpoint_ref.RelativeName(),
          googleCloudAiplatformV1IndexEndpoint=endpoint_msg,
          updateMask=','.join(mask_paths)))
def DeployIndexBeta(self, index_endpoint_ref, args):
  """Deploy an index to an index endpoint (v1beta1).

  Args:
    index_endpoint_ref: Resource, the parsed index endpoint to deploy to.
    args: argparse.Namespace, the parsed deploy-index flags.

  Returns:
    The response message from the DeployIndex call.
  """
  index_ref = _ParseIndex(args.index, args.region)
  deployed_index = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex(
      displayName=args.display_name,
      id=args.deployed_index_id,
      index=index_ref.RelativeName(),
  )
  if args.reserved_ip_ranges is not None:
    deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
  if args.deployment_group is not None:
    deployed_index.deploymentGroup = args.deployment_group
  if args.deployment_tier:
    # The enum value is resolved from the upper-cased flag string.
    deployed_index.deploymentTier = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex.DeploymentTierValueValuesEnum(
        args.deployment_tier.upper())
  if args.enable_access_logging is not None:
    deployed_index.enableAccessLogging = args.enable_access_logging
  # JWT auth config is attached only when both audiences and issuers are set.
  if args.audiences is not None and args.allowed_issuers is not None:
    auth_provider = self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfigAuthProvider()
    auth_provider.audiences.extend(args.audiences)
    auth_provider.allowedIssuers.extend(args.allowed_issuers)
    deployed_index.deployedIndexAuthConfig = (
        self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfig(
            authProvider=auth_provider))
  # A machine type selects dedicated resources; otherwise replica bounds are
  # applied to automatic resources.
  if args.machine_type is not None:
    dedicated_resources = (
        self.messages.GoogleCloudAiplatformV1beta1DedicatedResources()
    )
    dedicated_resources.machineSpec = (
        self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
            machineType=args.machine_type
        )
    )
    if args.min_replica_count is not None:
      dedicated_resources.minReplicaCount = args.min_replica_count
    if args.max_replica_count is not None:
      dedicated_resources.maxReplicaCount = args.max_replica_count
    deployed_index.dedicatedResources = dedicated_resources
  else:
    automatic_resources = (
        self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
    )
    if args.min_replica_count is not None:
      automatic_resources.minReplicaCount = args.min_replica_count
    if args.max_replica_count is not None:
      automatic_resources.maxReplicaCount = args.max_replica_count
    deployed_index.automaticResources = automatic_resources
  deploy_index_req = self.messages.GoogleCloudAiplatformV1beta1DeployIndexRequest(
      deployedIndex=deployed_index)
  request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeployIndexRequest(
      indexEndpoint=index_endpoint_ref.RelativeName(),
      googleCloudAiplatformV1beta1DeployIndexRequest=deploy_index_req)
  return self._service.DeployIndex(request)
def DeployIndex(self, index_endpoint_ref, args):
"""Deploy an v1 index to an index endpoint."""
index_ref = _ParseIndex(args.index, args.region)
deployed_index = self.messages.GoogleCloudAiplatformV1DeployedIndex(
displayName=args.display_name,
id=args.deployed_index_id,
index=index_ref.RelativeName(),
enableAccessLogging=args.enable_access_logging
)
if args.reserved_ip_ranges is not None:
deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
if args.deployment_group is not None:
deployed_index.deploymentGroup = args.deployment_group
if args.deployment_tier:
deployed_index.deploymentTier = self.messages.GoogleCloudAiplatformV1DeployedIndex.DeploymentTierValueValuesEnum(
args.deployment_tier.upper())
# JWT Authentication config
if args.audiences is not None and args.allowed_issuers is not None:
auth_provider = self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfigAuthProvider()
auth_provider.audiences.extend(args.audiences)
auth_provider.allowedIssuers.extend(args.allowed_issuers)
deployed_index.deployedIndexAuthConfig = (
self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfig(
authProvider=auth_provider))
# PSC automation configs
if args.psc_automation_configs is not None:
deployed_index.pscAutomationConfigs = []
for psc_automation_config in args.psc_automation_configs:
deployed_index.pscAutomationConfigs.append(
self.messages.GoogleCloudAiplatformV1PSCAutomationConfig(
projectId=psc_automation_config['project-id'],
network=psc_automation_config['network'],
)
)
if args.machine_type is not None:
dedicated_resources = (
self.messages.GoogleCloudAiplatformV1DedicatedResources()
)
dedicated_resources.machineSpec = (
self.messages.GoogleCloudAiplatformV1MachineSpec(
machineType=args.machine_type
)
)
if args.min_replica_count is not None:
dedicated_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
dedicated_resources.maxReplicaCount = args.max_replica_count
deployed_index.dedicatedResources = dedicated_resources
else:
automatic_resources = (
self.messages.GoogleCloudAiplatformV1AutomaticResources()
)
if args.min_replica_count is not None:
automatic_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
automatic_resources.maxReplicaCount = args.max_replica_count
deployed_index.automaticResources = automatic_resources
deploy_index_req = self.messages.GoogleCloudAiplatformV1DeployIndexRequest(
deployedIndex=deployed_index)
request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeployIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1DeployIndexRequest=deploy_index_req)
return self._service.DeployIndex(request)
def UndeployIndexBeta(self, index_endpoint_ref, args):
"""Undeploy an index to an index endpoint."""
undeploy_index_req = self.messages.GoogleCloudAiplatformV1beta1UndeployIndexRequest(
deployedIndexId=args.deployed_index_id)
request = self.messages.AiplatformProjectsLocationsIndexEndpointsUndeployIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1beta1UndeployIndexRequest=undeploy_index_req)
return self._service.UndeployIndex(request)
def UndeployIndex(self, index_endpoint_ref, args):
"""Undeploy an v1 index to an index endpoint."""
undeploy_index_req = self.messages.GoogleCloudAiplatformV1UndeployIndexRequest(
deployedIndexId=args.deployed_index_id)
request = self.messages.AiplatformProjectsLocationsIndexEndpointsUndeployIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1UndeployIndexRequest=undeploy_index_req)
return self._service.UndeployIndex(request)
def MutateDeployedIndexBeta(self, index_endpoint_ref, args):
"""Mutate a deployed index from an index endpoint."""
deployed_index = self.messages.GoogleCloudAiplatformV1beta1DeployedIndex(
id=args.deployed_index_id,
enableAccessLogging=args.enable_access_logging,
)
if args.machine_type is not None:
deployed_index.dedicatedResources = self._GetDedicatedResourcesBeta(args)
else:
deployed_index.automaticResources = self._GetAutomaticResourcesBeta(args)
if args.reserved_ip_ranges is not None:
deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
if args.deployment_group is not None:
deployed_index.deploymentGroup = args.deployment_group
if args.audiences is not None and args.allowed_issuers is not None:
auth_provider = self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfigAuthProvider()
auth_provider.audiences.extend(args.audiences)
auth_provider.allowedIssuers.extend(args.allowed_issuers)
deployed_index.deployedIndexAuthConfig = (
self.messages.GoogleCloudAiplatformV1beta1DeployedIndexAuthConfig(
authProvider=auth_provider))
request = self.messages.AiplatformProjectsLocationsIndexEndpointsMutateDeployedIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1beta1DeployedIndex=deployed_index)
return self._service.MutateDeployedIndex(request)
def MutateDeployedIndex(self, index_endpoint_ref, args):
"""Mutate a deployed index from an index endpoint."""
deployed_index = self.messages.GoogleCloudAiplatformV1DeployedIndex(
id=args.deployed_index_id,
enableAccessLogging=args.enable_access_logging,
)
if args.machine_type is not None:
deployed_index.dedicatedResources = self._GetDedicatedResources(args)
else:
deployed_index.automaticResources = self._GetAutomaticResources(args)
if args.reserved_ip_ranges is not None:
deployed_index.reservedIpRanges.extend(args.reserved_ip_ranges)
if args.deployment_group is not None:
deployed_index.deploymentGroup = args.deployment_group
if args.audiences is not None and args.allowed_issuers is not None:
auth_provider = self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfigAuthProvider()
auth_provider.audiences.extend(args.audiences)
auth_provider.allowedIssuers.extend(args.allowed_issuers)
deployed_index.deployedIndexAuthConfig = (
self.messages.GoogleCloudAiplatformV1DeployedIndexAuthConfig(
authProvider=auth_provider))
request = self.messages.AiplatformProjectsLocationsIndexEndpointsMutateDeployedIndexRequest(
indexEndpoint=index_endpoint_ref.RelativeName(),
googleCloudAiplatformV1DeployedIndex=deployed_index)
return self._service.MutateDeployedIndex(request)
def Get(self, index_endpoint_ref):
request = self.messages.AiplatformProjectsLocationsIndexEndpointsGetRequest(
name=index_endpoint_ref.RelativeName())
return self._service.Get(request)
def List(self, limit=None, region_ref=None):
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsIndexEndpointsListRequest(
parent=region_ref.RelativeName()),
field='indexEndpoints',
batch_size_attribute='pageSize',
limit=limit)
def Delete(self, index_endpoint_ref):
request = self.messages.AiplatformProjectsLocationsIndexEndpointsDeleteRequest(
name=index_endpoint_ref.RelativeName())
return self._service.Delete(request)
def _GetDedicatedResourcesBeta(self, args):
"""Construct dedicated resources for beta API."""
dedicated_resources = (
self.messages.GoogleCloudAiplatformV1beta1DedicatedResources()
)
dedicated_resources.machineSpec = (
self.messages.GoogleCloudAiplatformV1beta1MachineSpec(
machineType=args.machine_type
)
)
if args.min_replica_count is not None:
dedicated_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
dedicated_resources.maxReplicaCount = args.max_replica_count
return dedicated_resources
def _GetAutomaticResourcesBeta(self, args):
"""Construct automatic resources for beta API."""
automatic_resources = (
self.messages.GoogleCloudAiplatformV1beta1AutomaticResources()
)
if args.min_replica_count is not None:
automatic_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
automatic_resources.maxReplicaCount = args.max_replica_count
return automatic_resources
def _GetDedicatedResources(self, args):
"""Construct dedicated resources for GA API."""
dedicated_resources = (
self.messages.GoogleCloudAiplatformV1DedicatedResources()
)
dedicated_resources.machineSpec = (
self.messages.GoogleCloudAiplatformV1MachineSpec(
machineType=args.machine_type
)
)
if args.min_replica_count is not None:
dedicated_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
dedicated_resources.maxReplicaCount = args.max_replica_count
return dedicated_resources
def _GetAutomaticResources(self, args):
"""Construct automatic resources for GA API."""
automatic_resources = (
self.messages.GoogleCloudAiplatformV1AutomaticResources()
)
if args.min_replica_count is not None:
automatic_resources.minReplicaCount = args.min_replica_count
if args.max_replica_count is not None:
automatic_resources.maxReplicaCount = args.max_replica_count
return automatic_resources

View File

@@ -0,0 +1,313 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform indexes API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.calliope import exceptions as gcloud_exceptions
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import yaml
class IndexesClient(object):
  """High-level client for the AI Platform indexes surface."""

  def __init__(self, client=None, messages=None, version=None):
    """Initializes the indexes client.

    Args:
      client: An apitools client to use; a new one is created when None.
      messages: The messages module to use; defaults to the client's module.
      version: Release-track key used to select the API version when a new
        client has to be created.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_indexes

  def _ReadIndexMetadata(self, metadata_file):
    """Parses the JSON metadata file into a JsonValue message.

    Args:
      metadata_file: Path to the index metadata file.

    Returns:
      An extra_types.JsonValue message, or None if the file parsed to an
      empty document.

    Raises:
      gcloud_exceptions.BadArgumentException: If no file path was given.
    """
    if not metadata_file:
      raise gcloud_exceptions.BadArgumentException(
          '--metadata-file', 'Index metadata file must be specified.')
    index_metadata = None
    # Yaml is a superset of json, so parse json file as yaml.
    data = yaml.load_path(metadata_file)
    if data:
      index_metadata = messages_util.DictToMessageWithErrorCheck(
          data, extra_types.JsonValue)
    return index_metadata

  def Get(self, index_ref):
    """Fetches a single index."""
    request = self.messages.AiplatformProjectsLocationsIndexesGetRequest(
        name=index_ref.RelativeName())
    return self._service.Get(request)

  def List(self, limit=None, region_ref=None):
    """Yields indexes in the given region, up to `limit` items."""
    return list_pager.YieldFromList(
        self._service,
        self.messages.AiplatformProjectsLocationsIndexesListRequest(
            parent=region_ref.RelativeName()),
        field='indexes',
        batch_size_attribute='pageSize',
        limit=limit)

  def _GetIndexUpdateMethod(self, index_message_class, index_update_method):
    """Maps the flag value to the enum on the given Index message class.

    Args:
      index_message_class: The versioned Index message class whose
        IndexUpdateMethodValueValuesEnum is used.
      index_update_method: Flag value, 'stream-update' or 'batch-update', or a
        falsy value for "unset".

    Returns:
      The enum value, or None when no method was requested.

    Raises:
      gcloud_exceptions.BadArgumentException: On an unrecognized value.
    """
    if not index_update_method:
      return None
    enum = index_message_class.IndexUpdateMethodValueValuesEnum
    if index_update_method == 'stream-update':
      return enum.STREAM_UPDATE
    if index_update_method == 'batch-update':
      return enum.BATCH_UPDATE
    raise gcloud_exceptions.BadArgumentException(
        '--index-update-method',
        'Invalid index update method: {}'.format(index_update_method),
    )

  def CreateBeta(self, location_ref, args):
    """Creates a new v1beta1 index."""
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1beta1Index.LabelsValue)
    index_update_method = self._GetIndexUpdateMethod(
        self.messages.GoogleCloudAiplatformV1beta1Index,
        args.index_update_method)
    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
              kmsKeyName=args.encryption_kms_key_name))
    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=self.messages
        .GoogleCloudAiplatformV1beta1Index(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec
        ))
    return self._service.Create(req)

  def Create(self, location_ref, args):
    """Creates a new v1 index."""
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1Index.LabelsValue)
    index_update_method = self._GetIndexUpdateMethod(
        self.messages.GoogleCloudAiplatformV1Index, args.index_update_method)
    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1EncryptionSpec(
              kmsKeyName=args.encryption_kms_key_name))
    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1Index=self.messages.GoogleCloudAiplatformV1Index(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec
        ))
    return self._service.Create(req)

  def PatchBeta(self, index_ref, args):
    """Updates a v1beta1 index.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was provided.
    """
    index = self.messages.GoogleCloudAiplatformV1beta1Index()
    update_mask = []
    # A metadata update replaces the whole index content; display name and
    # description updates are only honored when no metadata file is given.
    if args.metadata_file is not None:
      index.metadata = self._ReadIndexMetadata(args.metadata_file)
      update_mask.append('metadata')
    else:
      if args.display_name is not None:
        index.displayName = args.display_name
        update_mask.append('display_name')
      if args.description is not None:
        index.description = args.description
        update_mask.append('description')

      def GetLabels():
        return self.Get(index_ref).labels

      labels_update = labels_util.ProcessUpdateArgsLazy(
          args, self.messages.GoogleCloudAiplatformV1beta1Index.LabelsValue,
          GetLabels)
      if labels_update.needs_update:
        index.labels = labels_update.labels
        update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

  def Patch(self, index_ref, args):
    """Updates a v1 index.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was provided.
    """
    index = self.messages.GoogleCloudAiplatformV1Index()
    update_mask = []
    # A metadata update replaces the whole index content; display name and
    # description updates are only honored when no metadata file is given.
    if args.metadata_file is not None:
      index.metadata = self._ReadIndexMetadata(args.metadata_file)
      update_mask.append('metadata')
    else:
      if args.display_name is not None:
        index.displayName = args.display_name
        update_mask.append('display_name')
      if args.description is not None:
        index.description = args.description
        update_mask.append('description')

      def GetLabels():
        return self.Get(index_ref).labels

      labels_update = labels_util.ProcessUpdateArgsLazy(
          args, self.messages.GoogleCloudAiplatformV1Index.LabelsValue,
          GetLabels)
      if labels_update.needs_update:
        index.labels = labels_update.labels
        update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

  def Delete(self, index_ref):
    """Deletes an index."""
    request = self.messages.AiplatformProjectsLocationsIndexesDeleteRequest(
        name=index_ref.RelativeName())
    return self._service.Delete(request)

  def RemoveDatapointsBeta(self, index_ref, args):
    """Removes data points from a v1beta1 index.

    Raises:
      errors.ArgumentError: If both or neither of datapoint_ids and
        datapoints_from_file were specified.
    """
    if args.datapoint_ids and args.datapoints_from_file:
      raise errors.ArgumentError(
          'datapoint_ids and datapoints_from_file can not be set'
          ' at the same time.'
      )
    if args.datapoint_ids:
      datapoint_ids = args.datapoint_ids
    elif args.datapoints_from_file:
      # The file contains a YAML/JSON list of datapoint ids.
      datapoint_ids = yaml.load_path(args.datapoints_from_file)
    else:
      # Previously this fell through to an UnboundLocalError on `req`.
      raise errors.ArgumentError(
          'Either datapoint_ids or datapoints_from_file must be specified.'
      )
    req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1RemoveDatapointsRequest=self.messages
        .GoogleCloudAiplatformV1beta1RemoveDatapointsRequest(
            datapointIds=datapoint_ids))
    return self._service.RemoveDatapoints(req)

  def RemoveDatapoints(self, index_ref, args):
    """Removes data points from a v1 index.

    Raises:
      errors.ArgumentError: If both or neither of datapoint_ids and
        datapoints_from_file were specified.
    """
    if args.datapoint_ids and args.datapoints_from_file:
      raise errors.ArgumentError(
          '`--datapoint_ids` and `--datapoints_from_file` can not be set at the'
          ' same time.'
      )
    if args.datapoint_ids:
      datapoint_ids = args.datapoint_ids
    elif args.datapoints_from_file:
      # The file contains a YAML/JSON list of datapoint ids.
      datapoint_ids = yaml.load_path(args.datapoints_from_file)
    else:
      # Previously this fell through to an UnboundLocalError on `req`.
      raise errors.ArgumentError(
          'Either `--datapoint_ids` or `--datapoints_from_file` must be'
          ' specified.'
      )
    req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1RemoveDatapointsRequest=self.messages
        .GoogleCloudAiplatformV1RemoveDatapointsRequest(
            datapointIds=datapoint_ids))
    return self._service.RemoveDatapoints(req)

  def UpsertDatapointsBeta(self, index_ref, args):
    """Upserts data points into a v1beta1 index."""
    datapoints = []
    if args.datapoints_from_file:
      data = yaml.load_path(args.datapoints_from_file)
      for datapoint_json in data:
        datapoint = messages_util.DictToMessageWithErrorCheck(
            datapoint_json,
            self.messages.GoogleCloudAiplatformV1beta1IndexDatapoint)
        datapoints.append(datapoint)
    update_mask = None
    if args.update_mask:
      update_mask = ','.join(args.update_mask)
    req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1UpsertDatapointsRequest=self.messages
        .GoogleCloudAiplatformV1beta1UpsertDatapointsRequest(
            datapoints=datapoints,
            updateMask=update_mask))
    return self._service.UpsertDatapoints(req)

  def UpsertDatapoints(self, index_ref, args):
    """Upserts data points into a v1 index."""
    datapoints = []
    if args.datapoints_from_file:
      data = yaml.load_path(args.datapoints_from_file)
      for datapoint_json in data:
        datapoint = messages_util.DictToMessageWithErrorCheck(
            datapoint_json,
            self.messages.GoogleCloudAiplatformV1IndexDatapoint)
        datapoints.append(datapoint)
    update_mask = None
    if args.update_mask:
      update_mask = ','.join(args.update_mask)
    req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1UpsertDatapointsRequest=self.messages
        .GoogleCloudAiplatformV1UpsertDatapointsRequest(
            datapoints=datapoints,
            updateMask=update_mask))
    return self._service.UpsertDatapoints(req)

View File

@@ -0,0 +1,515 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for Vertex AI Model Garden APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import flags
# Filter expressions for Model Garden publisher-model queries.
# NOTE(review): presumably passed as list filters to the Model Garden API —
# confirm with the callers outside this chunk.
_HF_WILDCARD_FILTER = 'is_hf_wildcard(true)'
_NATIVE_MODEL_FILTER = 'is_hf_wildcard(false)'
# Matches models labeled as having a Model-Garden-verified deployment config.
_VERIFIED_DEPLOYMENT_FILTER = (
    'labels.VERIFIED_DEPLOYMENT_CONFIG=VERIFIED_DEPLOYMENT_SUCCEED'
)
def IsHuggingFaceModel(model_name: str) -> bool:
  """Returns True when ``model_name`` looks like a Hugging Face model ID.

  A Hugging Face model ID has the shape ``organization/model``: exactly one
  ``/`` separator and no ``@`` revision suffix in the model part.
  """
  hf_model_id_pattern = r'^[^/]+/[^/@]+$'
  return re.match(hf_model_id_pattern, model_name) is not None
def IsCustomWeightsModel(model: str) -> bool:
  """Returns True when ``model`` is a GCS URI (custom-weights model)."""
  # An anchored regex match on 'gs://' is just a prefix test.
  return model.startswith('gs://')
def DeployCustomWeightsModel(
    messages,
    projects_locations_service,
    model,
    machine_type,
    accelerator_type,
    accelerator_count,
    project,
    location,
):
  """Deploys a custom-weights model stored at a GCS URI.

  Args:
    messages: The versioned messages module for the AI Platform API.
    projects_locations_service: The projects.locations service to call.
    model: GCS URI of the custom model weights.
    machine_type: Machine type for dedicated resources; when falsy, no deploy
      config is attached and the service picks defaults.
    accelerator_type: Accelerator type for the machine spec.
    accelerator_count: Number of accelerators per machine.
    project: Destination project ID.
    location: Destination region.

  Returns:
    The response from the Deploy call.
  """
  custom_model = messages.GoogleCloudAiplatformV1beta1DeployRequestCustomModel(
      gcsUri=model)
  deploy_request = messages.GoogleCloudAiplatformV1beta1DeployRequest(
      customModel=custom_model)
  if machine_type:
    machine_spec = messages.GoogleCloudAiplatformV1beta1MachineSpec(
        machineType=machine_type,
        acceleratorType=accelerator_type,
        acceleratorCount=accelerator_count,
    )
    dedicated_resources = (
        messages.GoogleCloudAiplatformV1beta1DedicatedResources(
            machineSpec=machine_spec,
            minReplicaCount=1,
        ))
    deploy_request.deployConfig = (
        messages.GoogleCloudAiplatformV1beta1DeployRequestDeployConfig(
            dedicatedResources=dedicated_resources))
  destination = f'projects/{project}/locations/{location}'
  request = messages.AiplatformProjectsLocationsDeployRequest(
      destination=destination,
      googleCloudAiplatformV1beta1DeployRequest=deploy_request,
  )
  return projects_locations_service.Deploy(request)
class ModelGardenClient(object):
"""Client used for interacting with Model Garden APIs."""
def __init__(self, version=constants.BETA_VERSION):
client = apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version],
)
self._messages = client.MESSAGES_MODULE
self._publishers_models_service = client.publishers_models
self._projects_locations_service = client.projects_locations
def GetPublisherModel(
self,
model_name,
is_hugging_face_model=False,
include_equivalent_model_garden_model_deployment_configs=True,
hugging_face_token=None,
):
"""Get a publisher model.
Args:
model_name: The name of the model to get. The format should be
publishers/{publisher}/models/{model}
is_hugging_face_model: Whether the model is a hugging face model.
include_equivalent_model_garden_model_deployment_configs: Whether to
include equivalent Model Garden model deployment configs for Hugging
Face models.
hugging_face_token: The Hugging Face access token to access the model
artifacts for gated models unverified by Model Garden.
Returns:
A publisher model.
"""
request = self._messages.AiplatformPublishersModelsGetRequest(
name=model_name,
isHuggingFaceModel=is_hugging_face_model,
includeEquivalentModelGardenModelDeploymentConfigs=include_equivalent_model_garden_model_deployment_configs,
huggingFaceToken=hugging_face_token,
)
return self._publishers_models_service.Get(request)
def Deploy(
self,
project,
location,
model,
accept_eula,
accelerator_type,
accelerator_count,
machine_type,
endpoint_display_name,
hugging_face_access_token,
spot,
reservation_affinity,
use_dedicated_endpoint,
enable_fast_tryout,
container_image_uri=None,
container_command=None,
container_args=None,
container_env_vars=None,
container_ports=None,
container_grpc_ports=None,
container_predict_route=None,
container_health_route=None,
container_deployment_timeout_seconds=None,
container_shared_memory_size_mb=None,
container_startup_probe_exec=None,
container_startup_probe_period_seconds=None,
container_startup_probe_timeout_seconds=None,
container_health_probe_exec=None,
container_health_probe_period_seconds=None,
container_health_probe_timeout_seconds=None,
):
"""Deploy an open weight model.
Args:
project: The project to deploy the model to.
location: The location to deploy the model to.
model: The name of the model to deploy or its gcs uri for custom weights.
accept_eula: Whether to accept the end-user license agreement.
accelerator_type: The type of accelerator to use.
accelerator_count: The number of accelerators to use.
machine_type: The type of machine to use.
endpoint_display_name: The display name of the endpoint.
hugging_face_access_token: The Hugging Face access token.
spot: Whether to deploy the model on Spot VMs.
reservation_affinity: The reservation affinity to use.
use_dedicated_endpoint: Whether to use a dedicated endpoint.
enable_fast_tryout: Whether to enable fast tryout.
container_image_uri: Immutable. URI of the Docker image to be used as the
custom container for serving predictions. This URI must identify an
image in Artifact Registry or Container Registry. Learn more about the
[container publishing requirements](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#publishing), including
permissions requirements for the Vertex AI Service Agent. The container
image is ingested upon ModelService.UploadModel, stored internally, and
this original path is afterwards not used. To learn about the
requirements for the Docker image itself, see [Custom container
requirements](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#). You can use the URI
to one of Vertex AI's [pre-built container images for
prediction](https://cloud.google.com/vertex-ai/docs/predictions/pre-
built-containers) in this field.
container_command: Specifies the command that runs when the container
starts. This overrides the container's [ENTRYPOINT](https://docs.docker.
com/engine/reference/builder/#entrypoint). Specify this field as an
array of executable and arguments, similar to a Docker `ENTRYPOINT`'s
"exec" form, not its "shell" form. If you do not specify this field,
then the container's `ENTRYPOINT` runs, in conjunction with the args
field or the container's
[`CMD`](https://docs.docker.com/engine/reference/builder/#cmd), if
either exists. If this field is not specified and the container does not
have an `ENTRYPOINT`, then refer to the Docker documentation about [how
`CMD` and `ENTRYPOINT`
interact](https://docs.docker.com/engine/reference/builder/#understand-
how-cmd-and-entrypoint-interact). If you specify this field, then you
can also specify the `args` field to provide additional arguments for
this command. However, if you specify this field, then the container's
`CMD` is ignored. See the [Kubernetes documentation about how the
`command` and `args` fields interact with a container's `ENTRYPOINT` and
`CMD`](https://kubernetes.io/docs/tasks/inject-data-application/define-
command-argument-container/#notes). In this field, you can reference
[environment variables set by Vertex
AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables) and environment variables set in
the env field. You cannot reference environment variables set in the
Docker image. In order for environment variables to be expanded,
reference them by using the following syntax: $( VARIABLE_NAME) Note
that this differs from Bash variable expansion, which does not use
parentheses. If a variable cannot be resolved, the reference in the
input string is used unchanged. To avoid variable expansion, you can
escape this syntax with `$$`; for example: $$(VARIABLE_NAME) This field
corresponds to the `command` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_args: Specifies arguments for the command that runs when the
container starts. This overrides the container's
[`CMD`](https://docs.docker.com/engine/reference/builder/#cmd). Specify
this field as an array of executable and arguments, similar to a Docker
`CMD`'s "default parameters" form. If you don't specify this field but
do specify the command field, then the command from the `command` field
runs without any additional arguments. See the [Kubernetes documentation
about how the `command` and `args` fields interact with a container's
`ENTRYPOINT` and `CMD`](https://kubernetes.io/docs/tasks/inject-data-
application/define-command-argument-container/#notes). If you don't
specify this field and don't specify the `command` field, then the
container's
[`ENTRYPOINT`](https://docs.docker.com/engine/reference/builder/#cmd)
and `CMD` determine what runs based on their default behavior. See the
Docker documentation about [how `CMD` and `ENTRYPOINT`
interact](https://docs.docker.com/engine/reference/builder/#understand-
how-cmd-and-entrypoint-interact). In this field, you can reference
[environment variables set by Vertex
AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables) and environment variables set in
the env field. You cannot reference environment variables set in the
Docker image. In order for environment variables to be expanded,
reference them by using the following syntax: $( VARIABLE_NAME) Note
that this differs from Bash variable expansion, which does not use
parentheses. If a variable cannot be resolved, the reference in the
input string is used unchanged. To avoid variable expansion, you can
escape this syntax with `$$`; for example: $$(VARIABLE_NAME) This field
corresponds to the `args` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core)..
container_env_vars: List of environment variables to set in the container.
After the container starts running, code running in the container can
read these environment variables. Additionally, the command and args
fields can reference these variables. Later entries in this list can
also reference earlier entries. For example, the following example sets
the variable `VAR_2` to have the value `foo bar`: ```json [ { "name":
"VAR_1", "value": "foo" }, { "name": "VAR_2", "value": "$(VAR_1) bar" }
] ``` If you switch the order of the variables in the example, then the
expansion does not occur. This field corresponds to the `env` field of
the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_ports: List of ports to expose from the container. Vertex AI
sends any http prediction requests that it receives to the first port on
this list. Vertex AI also sends [liveness and health
checks](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#liveness) to this port. If you do not specify
this field, it defaults to following value: ```json [ { "containerPort":
8080 } ] ``` Vertex AI does not use ports other than the first one
listed. This field corresponds to the `ports` field of the Kubernetes
Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_grpc_ports: List of ports to expose from the container. Vertex
AI sends any grpc prediction requests that it receives to the first port
on this list. Vertex AI also sends [liveness and health
checks](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#liveness) to this port. If you do not specify
this field, gRPC requests to the container will be disabled. Vertex AI
does not use ports other than the first one listed. This field
corresponds to the `ports` field of the Kubernetes Containers [v1 core
API](https://kubernetes.io/docs/reference/generated/kubernetes-
api/v1.23/#container-v1-core).
container_predict_route: HTTP path on the container to send prediction
requests to. Vertex AI forwards requests sent using
projects.locations.endpoints.predict to this path on the container's IP
address and port. Vertex AI then returns the container's response in the
API response. For example, if you set this field to `/foo`, then when
Vertex AI receives a prediction request, it forwards the request body in
a POST request to the `/foo` path on the port of your container
specified by the first value of this `ModelContainerSpec`'s ports field.
If you don't specify this field, it defaults to the following value when
you deploy this Model to an Endpoint:
/v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict The
placeholders in this value are replaced as follows: * ENDPOINT: The last
      segment (following `endpoints/`) of the Endpoint.name field of the
Endpoint where this Model has been deployed. (Vertex AI makes this value
available to your container code as the [`AIP_ENDPOINT_ID` environment
variable](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables).) * DEPLOYED_MODEL:
DeployedModel.id of the `DeployedModel`. (Vertex AI makes this value
available to your container code as the [`AIP_DEPLOYED_MODEL_ID`
environment variable](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#aip-variables).)
container_health_route: HTTP path on the container to send health checks
to. Vertex AI intermittently sends GET requests to this path on the
container's IP address and port to check that the container is healthy.
Read more about [health checks](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#health). For example,
if you set this field to `/bar`, then Vertex AI intermittently sends a
GET request to the `/bar` path on the port of your container specified
by the first value of this `ModelContainerSpec`'s ports field. If you
don't specify this field, it defaults to the following value when you
deploy this Model to an Endpoint: /v1/endpoints/ENDPOINT/deployedModels/
DEPLOYED_MODEL:predict The placeholders in this value are replaced as
      follows: * ENDPOINT: The last segment (following `endpoints/`) of the
      Endpoint.name field of the Endpoint where this Model has been
deployed. (Vertex AI makes this value available to your container code
as the [`AIP_ENDPOINT_ID` environment
variable](https://cloud.google.com/vertex-ai/docs/predictions/custom-
container-requirements#aip-variables).) * DEPLOYED_MODEL:
DeployedModel.id of the `DeployedModel`. (Vertex AI makes this value
available to your container code as the [`AIP_DEPLOYED_MODEL_ID`
environment variable](https://cloud.google.com/vertex-
ai/docs/predictions/custom-container-requirements#aip-variables).)
container_deployment_timeout_seconds (int): Deployment timeout in seconds.
container_shared_memory_size_mb (int): The amount of the VM memory to
reserve as the shared memory for the model in megabytes.
container_startup_probe_exec (Sequence[str]): Exec specifies the action to
take. Used by startup probe. An example of this argument would be
["cat", "/tmp/healthy"]
container_startup_probe_period_seconds (int): How often (in seconds) to
perform the startup probe. Default to 10 seconds. Minimum value is 1.
container_startup_probe_timeout_seconds (int): Number of seconds after
which the startup probe times out. Defaults to 1 second. Minimum value
is 1.
container_health_probe_exec (Sequence[str]): Exec specifies the action to
take. Used by health probe. An example of this argument would be ["cat",
"/tmp/healthy"]
container_health_probe_period_seconds (int): How often (in seconds) to
perform the health probe. Default to 10 seconds. Minimum value is 1.
container_health_probe_timeout_seconds (int): Number of seconds after
which the health probe times out. Defaults to 1 second. Minimum value is
1.
Returns:
The deploy long-running operation.
"""
container_spec = None
if container_image_uri:
container_spec = (
self._messages.GoogleCloudAiplatformV1beta1ModelContainerSpec(
healthRoute=container_health_route,
imageUri=container_image_uri,
predictRoute=container_predict_route,
)
)
if container_command:
container_spec.command = container_command
if container_args:
container_spec.args = container_args
if container_env_vars:
container_spec.env = [
self._messages.GoogleCloudAiplatformV1beta1EnvVar(
name=k, value=container_env_vars[k]
)
for k in container_env_vars
]
if container_ports:
container_spec.ports = [
self._messages.GoogleCloudAiplatformV1beta1Port(containerPort=port)
for port in container_ports
]
if container_grpc_ports:
container_spec.grpcPorts = [
self._messages.GoogleCloudAiplatformV1beta1Port(containerPort=port)
for port in container_grpc_ports
]
if container_deployment_timeout_seconds:
container_spec.deploymentTimeout = (
str(container_deployment_timeout_seconds) + 's'
)
if container_shared_memory_size_mb:
container_spec.sharedMemorySizeMb = container_shared_memory_size_mb
if (
container_startup_probe_exec
or container_startup_probe_period_seconds
or container_startup_probe_timeout_seconds
):
startup_probe_exec = None
if container_startup_probe_exec:
startup_probe_exec = (
self._messages.GoogleCloudAiplatformV1beta1ProbeExecAction(
command=container_startup_probe_exec
)
)
container_spec.startupProbe = (
self._messages.GoogleCloudAiplatformV1beta1Probe(
exec_=startup_probe_exec,
periodSeconds=container_startup_probe_period_seconds,
timeoutSeconds=container_startup_probe_timeout_seconds,
)
)
if (
container_health_probe_exec
or container_health_probe_period_seconds
or container_health_probe_timeout_seconds
):
health_probe_exec = None
if container_health_probe_exec:
health_probe_exec = (
self._messages.GoogleCloudAiplatformV1beta1ProbeExecAction(
command=container_health_probe_exec
)
)
container_spec.healthProbe = (
self._messages.GoogleCloudAiplatformV1beta1Probe(
exec_=health_probe_exec,
periodSeconds=container_health_probe_period_seconds,
timeoutSeconds=container_health_probe_timeout_seconds,
)
)
if IsCustomWeightsModel(model):
return DeployCustomWeightsModel(
self._messages,
self._projects_locations_service,
model,
machine_type,
accelerator_type,
accelerator_count,
project,
location,
)
elif IsHuggingFaceModel(model):
deploy_request = self._messages.GoogleCloudAiplatformV1beta1DeployRequest(
huggingFaceModelId=model
)
else:
deploy_request = self._messages.GoogleCloudAiplatformV1beta1DeployRequest(
publisherModelName=model
)
deploy_request.modelConfig = (
self._messages.GoogleCloudAiplatformV1beta1DeployRequestModelConfig(
huggingFaceAccessToken=hugging_face_access_token,
acceptEula=accept_eula,
containerSpec=container_spec,
)
)
deploy_request.endpointConfig = (
self._messages.GoogleCloudAiplatformV1beta1DeployRequestEndpointConfig(
endpointDisplayName=endpoint_display_name,
dedicatedEndpointEnabled=use_dedicated_endpoint,
)
)
deploy_request.deployConfig = self._messages.GoogleCloudAiplatformV1beta1DeployRequestDeployConfig(
dedicatedResources=self._messages.GoogleCloudAiplatformV1beta1DedicatedResources(
machineSpec=self._messages.GoogleCloudAiplatformV1beta1MachineSpec(
machineType=machine_type,
acceleratorType=accelerator_type,
acceleratorCount=accelerator_count,
reservationAffinity=flags.ParseReservationAffinityFlag(
reservation_affinity, constants.BETA_VERSION
),
),
minReplicaCount=1,
spot=spot,
),
fastTryoutEnabled=enable_fast_tryout,
)
request = self._messages.AiplatformProjectsLocationsDeployRequest(
destination=f'projects/{project}/locations/{location}',
googleCloudAiplatformV1beta1DeployRequest=deploy_request,
)
return self._projects_locations_service.Deploy(request)
def ListPublisherModels(
self,
limit=None,
batch_size=100,
list_hf_models=False,
model_filter=None,
):
"""List publisher models in Model Garden.
Args:
limit: The maximum number of items to list. None if all available records
should be yielded.
batch_size: The number of items to list per page.
list_hf_models: Whether to only list Hugging Face models.
model_filter: The filter on model name to apply on server-side.
Returns:
The list of publisher models in Model Garden..
"""
filter_str = _NATIVE_MODEL_FILTER
if list_hf_models:
filter_str = ' AND '.join(
[_HF_WILDCARD_FILTER, _VERIFIED_DEPLOYMENT_FILTER]
)
if model_filter:
filter_str = (
f'{filter_str} AND (model_user_id=~"(?i).*{model_filter}.*" OR'
f' display_name=~"(?i).*{model_filter}.*")'
)
return list_pager.YieldFromList(
self._publishers_models_service,
self._messages.AiplatformPublishersModelsListRequest(
parent='publishers/*',
listAllVersions=True,
filter=filter_str,
),
field='publisherModels',
batch_size_attribute='pageSize',
batch_size=batch_size,
limit=limit,
)

View File

@@ -0,0 +1,528 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform model monitoring jobs API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
from apitools.base.py import encoding
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import messages as messages_util
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import model_monitoring_jobs_util
from googlecloudsdk.command_lib.ai import validation as common_validation
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
import six
def _ParseEndpoint(endpoint_id, region_ref):
  """Converts an endpoint ID into an endpoint resource reference.

  Args:
    endpoint_id: The endpoint ID (or full resource name) to parse.
    region_ref: A location resource reference supplying the region.

  Returns:
    The parsed `aiplatform.projects.locations.endpoints` resource object.
  """
  parse_params = {
      'locationsId': region_ref.AsDict()['locationsId'],
      # Resolved lazily by the registry; fails if no project is configured.
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      endpoint_id,
      params=parse_params,
      collection='aiplatform.projects.locations.endpoints')
def _ParseDataset(dataset_id, region_ref):
  """Converts a dataset ID into a dataset resource reference.

  Args:
    dataset_id: The dataset ID (or full resource name) to parse.
    region_ref: A location resource reference supplying the region.

  Returns:
    The parsed `aiplatform.projects.locations.datasets` resource object.
  """
  parse_params = {
      'locationsId': region_ref.AsDict()['locationsId'],
      # Resolved lazily by the registry; fails if no project is configured.
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      dataset_id,
      params=parse_params,
      collection='aiplatform.projects.locations.datasets')
class ModelMonitoringJobsClient(object):
  """High-level client for the AI Platform model deployment monitoring jobs surface."""

  def __init__(self, client=None, messages=None, version=None):
    # API client; falls back to the aiplatform client for the given version.
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    # Messages module used to build all requests/resources below.
    self.messages = messages or self.client.MESSAGES_MODULE
    # Service for the modelDeploymentMonitoringJobs collection.
    self._service = self.client.projects_locations_modelDeploymentMonitoringJobs
    # Version key (e.g. GA/Beta) used to select message classes.
    self._version = version

  def _ConstructDriftThresholds(self, feature_thresholds,
                                feature_attribution_thresholds):
    """Construct drift thresholds from user input.

    Args:
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key:feature_name, value:
        attribution score thresholds.

    Returns:
      PredictionDriftDetectionConfig
    """
    prediction_drift_detection = api_util.GetMessage(
        'ModelMonitoringObjectiveConfigPredictionDriftDetectionConfig',
        self._version)()
    additional_properties = []
    attribution_additional_properties = []
    if feature_thresholds:
      for key, value in feature_thresholds.items():
        # Features without an explicit threshold fall back to 0.3.
        threshold = 0.3 if not value else float(value)
        additional_properties.append(prediction_drift_detection
                                     .DriftThresholdsValue().AdditionalProperty(
                                         key=key,
                                         value=api_util.GetMessage(
                                             'ThresholdConfig',
                                             self._version)(value=threshold)))
      prediction_drift_detection.driftThresholds = prediction_drift_detection.DriftThresholdsValue(
          additionalProperties=additional_properties)
    if feature_attribution_thresholds:
      for key, value in feature_attribution_thresholds.items():
        # Same 0.3 default for attribution-score drift thresholds.
        threshold = 0.3 if not value else float(value)
        attribution_additional_properties.append(
            prediction_drift_detection.AttributionScoreDriftThresholdsValue(
            ).AdditionalProperty(
                key=key,
                value=api_util.GetMessage('ThresholdConfig',
                                          self._version)(value=threshold)))
      prediction_drift_detection.attributionScoreDriftThresholds = prediction_drift_detection.AttributionScoreDriftThresholdsValue(
          additionalProperties=attribution_additional_properties)
    return prediction_drift_detection

  def _ConstructSkewThresholds(self, feature_thresholds,
                               feature_attribution_thresholds):
    """Construct skew thresholds from user input.

    Args:
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key:feature_name, value:
        attribution score thresholds.

    Returns:
      TrainingPredictionSkewDetectionConfig
    """
    training_prediction_skew_detection = api_util.GetMessage(
        'ModelMonitoringObjectiveConfigTrainingPredictionSkewDetectionConfig',
        self._version)()
    additional_properties = []
    attribution_additional_properties = []
    if feature_thresholds:
      for key, value in feature_thresholds.items():
        # Features without an explicit threshold fall back to 0.3.
        threshold = 0.3 if not value else float(value)
        additional_properties.append(training_prediction_skew_detection
                                     .SkewThresholdsValue().AdditionalProperty(
                                         key=key,
                                         value=api_util.GetMessage(
                                             'ThresholdConfig',
                                             self._version)(value=threshold)))
      training_prediction_skew_detection.skewThresholds = training_prediction_skew_detection.SkewThresholdsValue(
          additionalProperties=additional_properties)
    if feature_attribution_thresholds:
      for key, value in feature_attribution_thresholds.items():
        # Same 0.3 default for attribution-score skew thresholds.
        threshold = 0.3 if not value else float(value)
        attribution_additional_properties.append(
            training_prediction_skew_detection
            .AttributionScoreSkewThresholdsValue().AdditionalProperty(
                key=key,
                value=api_util.GetMessage('ThresholdConfig',
                                          self._version)(value=threshold)))
      training_prediction_skew_detection.attributionScoreSkewThresholds = training_prediction_skew_detection.AttributionScoreSkewThresholdsValue(
          additionalProperties=attribution_additional_properties)
    return training_prediction_skew_detection

  def _ConstructObjectiveConfigForUpdate(self, existing_monitoring_job,
                                         feature_thresholds,
                                         feature_attribution_thresholds):
    """Construct monitoring objective config.

    Update the feature thresholds for skew/drift detection to all the existing
    deployed models under the job.

    Args:
      existing_monitoring_job: Existing monitoring job.
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key: feature_name, value:
        attribution score thresholds.

    Returns:
      A list of model monitoring objective config.
    """
    prediction_drift_detection = self._ConstructDriftThresholds(
        feature_thresholds, feature_attribution_thresholds)
    training_prediction_skew_detection = self._ConstructSkewThresholds(
        feature_thresholds, feature_attribution_thresholds)
    objective_configs = []
    # Only overwrite thresholds on objectives the job already monitors;
    # objectives without an existing skew/drift config are left untouched.
    for objective_config in existing_monitoring_job.modelDeploymentMonitoringObjectiveConfigs:
      if objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig:
        if training_prediction_skew_detection.skewThresholds:
          objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig.skewThresholds = training_prediction_skew_detection.skewThresholds
        if training_prediction_skew_detection.attributionScoreSkewThresholds:
          objective_config.objectiveConfig.trainingPredictionSkewDetectionConfig.attributionScoreSkewThresholds = training_prediction_skew_detection.attributionScoreSkewThresholds
      if objective_config.objectiveConfig.predictionDriftDetectionConfig:
        if prediction_drift_detection.driftThresholds:
          objective_config.objectiveConfig.predictionDriftDetectionConfig.driftThresholds = prediction_drift_detection.driftThresholds
        if prediction_drift_detection.attributionScoreDriftThresholds:
          objective_config.objectiveConfig.predictionDriftDetectionConfig.attributionScoreDriftThresholds = prediction_drift_detection.attributionScoreDriftThresholds
      # Attribution-score thresholds require feature attributions to be
      # enabled via the explanation config.
      if training_prediction_skew_detection.attributionScoreSkewThresholds or prediction_drift_detection.attributionScoreDriftThresholds:
        objective_config.objectiveConfig.explanationConfig = api_util.GetMessage(
            'ModelMonitoringObjectiveConfigExplanationConfig', self._version)(
                enableFeatureAttributes=True)
      objective_configs.append(objective_config)
    return objective_configs

  def _ConstructObjectiveConfigForCreate(self, location_ref, endpoint_name,
                                         feature_thresholds,
                                         feature_attribution_thresholds,
                                         dataset, bigquery_uri, data_format,
                                         gcs_uris, target_field,
                                         training_sampling_rate):
    """Construct monitoring objective config.

    Apply the feature thresholds for skew or drift detection to all the deployed
    models under the endpoint.

    Args:
      location_ref: Location reference.
      endpoint_name: Endpoint resource name.
      feature_thresholds: Dict or None, key: feature_name, value: thresholds.
      feature_attribution_thresholds: Dict or None, key: feature_name, value:
        attribution score thresholds.
      dataset: Vertex AI Dataset Id.
      bigquery_uri: The BigQuery table of the unmanaged Dataset used to train
        this Model.
      data_format: Google Cloud Storage format, supported format: csv,
        tf-record.
      gcs_uris: The Google Cloud Storage uri of the unmanaged Dataset used to
        train this Model.
      target_field: The target field name the model is to predict.
      training_sampling_rate: Training Dataset sampling rate.

    Returns:
      A list of model monitoring objective config.

    Raises:
      errors.ArgumentError: If a training dataset is given without a target
        field, or the Cloud Storage dataset arguments are inconsistent.
    """
    objective_config_template = api_util.GetMessage(
        'ModelDeploymentMonitoringObjectiveConfig', self._version)()
    if feature_thresholds or feature_attribution_thresholds:
      # A training dataset source selects skew detection; otherwise drift
      # detection (prediction-only) is configured.
      if dataset or bigquery_uri or gcs_uris or data_format:
        training_dataset = api_util.GetMessage(
            'ModelMonitoringObjectiveConfigTrainingDataset', self._version)()
        if target_field is None:
          raise errors.ArgumentError(
              "Target field must be provided if you'd like to do training-prediction skew detection."
          )
        training_dataset.targetField = target_field
        training_dataset.loggingSamplingStrategy = api_util.GetMessage(
            'SamplingStrategy', self._version)(
                randomSampleConfig=api_util.GetMessage(
                    'SamplingStrategyRandomSampleConfig', self._version)(
                        sampleRate=training_sampling_rate))
        # Exactly one dataset source is used, in priority order:
        # managed dataset > BigQuery table > Cloud Storage uris.
        if dataset:
          training_dataset.dataset = _ParseDataset(dataset,
                                                   location_ref).RelativeName()
        elif bigquery_uri:
          training_dataset.bigquerySource = api_util.GetMessage(
              'BigQuerySource', self._version)(
                  inputUri=bigquery_uri)
        elif gcs_uris or data_format:
          if gcs_uris is None:
            raise errors.ArgumentError(
                'Data format is defined but no Google Cloud Storage uris are provided. Please use --gcs-uris to provide training datasets.'
            )
          if data_format is None:
            raise errors.ArgumentError(
                'No Data format is defined for Google Cloud Storage training dataset. Please use --data-format to define the Data format.'
            )
          training_dataset.dataFormat = data_format
          training_dataset.gcsSource = api_util.GetMessage(
              'GcsSource', self._version)(
                  uris=gcs_uris)
        training_prediction_skew_detection = self._ConstructSkewThresholds(
            feature_thresholds, feature_attribution_thresholds)
        objective_config_template.objectiveConfig = api_util.GetMessage(
            'ModelMonitoringObjectiveConfig', self._version
        )(trainingDataset=training_dataset,
          trainingPredictionSkewDetectionConfig=training_prediction_skew_detection
         )
      else:
        prediction_drift_detection = self._ConstructDriftThresholds(
            feature_thresholds, feature_attribution_thresholds)
        objective_config_template.objectiveConfig = api_util.GetMessage(
            'ModelMonitoringObjectiveConfig', self._version)(
                predictionDriftDetectionConfig=prediction_drift_detection)
      # Attribution thresholds require feature attributions to be enabled.
      if feature_attribution_thresholds:
        objective_config_template.objectiveConfig.explanationConfig = api_util.GetMessage(
            'ModelMonitoringObjectiveConfigExplanationConfig', self._version)(
                enableFeatureAttributes=True)
    get_endpoint_req = self.messages.AiplatformProjectsLocationsEndpointsGetRequest(
        name=endpoint_name)
    endpoint = self.client.projects_locations_endpoints.Get(get_endpoint_req)
    # Fan out one objective config per model deployed on the endpoint;
    # deepcopy keeps the per-model configs independent of the template.
    objective_configs = []
    for deployed_model in endpoint.deployedModels:
      objective_config = copy.deepcopy(objective_config_template)
      objective_config.deployedModelId = deployed_model.id
      objective_configs.append(objective_config)
    return objective_configs

  def _ParseCreateLabels(self, args):
    """Parses create labels."""
    return labels_util.ParseCreateArgs(
        args,
        api_util.GetMessage('ModelDeploymentMonitoringJob',
                            self._version)().LabelsValue)

  def _ParseUpdateLabels(self, model_monitoring_job_ref, args):
    """Parses update labels."""
    # Lazy getter: the current labels are only fetched if an update needs them.
    def GetLabels():
      return self.Get(model_monitoring_job_ref).labels
    return labels_util.ProcessUpdateArgsLazy(
        args,
        api_util.GetMessage('ModelDeploymentMonitoringJob',
                            self._version)().LabelsValue, GetLabels)

  def Create(self, location_ref, args):
    """Creates a model deployment monitoring job.

    Args:
      location_ref: The resource reference of the location to create the job
        in.
      args: The parsed command-line arguments describing the job.

    Returns:
      The created model deployment monitoring job.
    """
    endpoint_ref = _ParseEndpoint(args.endpoint, location_ref)
    job_spec = api_util.GetMessage('ModelDeploymentMonitoringJob',
                                   self._version)()
    kms_key_name = common_validation.GetAndValidateKmsKey(args)
    if kms_key_name is not None:
      job_spec.encryptionSpec = api_util.GetMessage('EncryptionSpec',
                                                    self._version)(
                                                        kmsKeyName=kms_key_name)
    # A config file, when given, replaces the whole job spec (the encryption
    # spec above included); otherwise objectives are built from flags.
    if args.monitoring_config_from_file:
      data = yaml.load_path(args.monitoring_config_from_file)
      if data:
        job_spec = messages_util.DictToMessageWithErrorCheck(
            data,
            api_util.GetMessage('ModelDeploymentMonitoringJob', self._version))
    else:
      job_spec.modelDeploymentMonitoringObjectiveConfigs = self._ConstructObjectiveConfigForCreate(
          location_ref, endpoint_ref.RelativeName(), args.feature_thresholds,
          args.feature_attribution_thresholds, args.dataset, args.bigquery_uri,
          args.data_format, args.gcs_uris, args.target_field,
          args.training_sampling_rate)
    job_spec.endpoint = endpoint_ref.RelativeName()
    job_spec.displayName = args.display_name
    job_spec.labels = self._ParseCreateLabels(args)
    # Anomaly logging defaults to disabled when the flag is unset.
    enable_anomaly_cloud_logging = False if args.anomaly_cloud_logging is None else args.anomaly_cloud_logging
    job_spec.modelMonitoringAlertConfig = api_util.GetMessage(
        'ModelMonitoringAlertConfig', self._version)(
            enableLogging=enable_anomaly_cloud_logging,
            emailAlertConfig=api_util.GetMessage(
                'ModelMonitoringAlertConfigEmailAlertConfig',
                self._version)(userEmails=args.emails),
            notificationChannels=args.notification_channels)
    job_spec.loggingSamplingStrategy = api_util.GetMessage(
        'SamplingStrategy', self._version)(
            randomSampleConfig=api_util.GetMessage(
                'SamplingStrategyRandomSampleConfig', self._version)(
                    sampleRate=args.prediction_sampling_rate))
    # Monitoring frequency is given in hours; the API expects seconds.
    job_spec.modelDeploymentMonitoringScheduleConfig = api_util.GetMessage(
        'ModelDeploymentMonitoringScheduleConfig', self._version)(
            monitorInterval='{}s'.format(
                six.text_type(3600 * int(args.monitoring_frequency))))
    if args.predict_instance_schema:
      job_spec.predictInstanceSchemaUri = args.predict_instance_schema
    if args.analysis_instance_schema:
      job_spec.analysisInstanceSchemaUri = args.analysis_instance_schema
    # Log TTL is given in days; the API expects seconds.
    if args.log_ttl:
      job_spec.logTtl = '{}s'.format(six.text_type(86400 * int(args.log_ttl)))
    if args.sample_predict_request:
      instance_json = model_monitoring_jobs_util.ReadInstanceFromArgs(
          args.sample_predict_request)
      job_spec.samplePredictInstance = encoding.PyValueToMessage(
          extra_types.JsonValue, instance_json)
    # The request field name is version-specific (v1beta1 vs v1).
    if self._version == constants.BETA_VERSION:
      return self._service.Create(
          self.messages.
          AiplatformProjectsLocationsModelDeploymentMonitoringJobsCreateRequest(
              parent=location_ref.RelativeName(),
              googleCloudAiplatformV1beta1ModelDeploymentMonitoringJob=job_spec
          ))
    else:
      return self._service.Create(
          self.messages.
          AiplatformProjectsLocationsModelDeploymentMonitoringJobsCreateRequest(
              parent=location_ref.RelativeName(),
              googleCloudAiplatformV1ModelDeploymentMonitoringJob=job_spec))

  def Patch(self, model_monitoring_job_ref, args):
    """Update a model deployment monitoring job.

    Args:
      model_monitoring_job_ref: The resource reference of the job to update.
      args: The parsed command-line arguments carrying the updated fields.

    Returns:
      The updated model deployment monitoring job.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable field was provided.
    """
    model_monitoring_job_to_update = api_util.GetMessage(
        'ModelDeploymentMonitoringJob', self._version)()
    # Field paths accumulated here become the request's update mask.
    update_mask = []
    job_spec = api_util.GetMessage('ModelDeploymentMonitoringJob',
                                   self._version)()
    if args.monitoring_config_from_file:
      data = yaml.load_path(args.monitoring_config_from_file)
      if data:
        job_spec = messages_util.DictToMessageWithErrorCheck(
            data,
            api_util.GetMessage('ModelDeploymentMonitoringJob', self._version))
        model_monitoring_job_to_update.modelDeploymentMonitoringObjectiveConfigs = job_spec.modelDeploymentMonitoringObjectiveConfigs
        update_mask.append('model_deployment_monitoring_objective_configs')
    if args.feature_thresholds or args.feature_attribution_thresholds:
      # Threshold updates are merged into the job's existing objectives.
      get_monitoring_job_req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsGetRequest(
          name=model_monitoring_job_ref.RelativeName())
      model_monitoring_job = self._service.Get(get_monitoring_job_req)
      model_monitoring_job_to_update.modelDeploymentMonitoringObjectiveConfigs = self._ConstructObjectiveConfigForUpdate(
          model_monitoring_job, args.feature_thresholds,
          args.feature_attribution_thresholds)
      update_mask.append('model_deployment_monitoring_objective_configs')
    if args.display_name:
      model_monitoring_job_to_update.displayName = args.display_name
      update_mask.append('display_name')
    if args.emails:
      model_monitoring_job_to_update.modelMonitoringAlertConfig = (
          api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
              emailAlertConfig=api_util.GetMessage(
                  'ModelMonitoringAlertConfigEmailAlertConfig', self._version
              )(userEmails=args.emails)
          )
      )
      update_mask.append('model_monitoring_alert_config.email_alert_config')
    # The alert config message is shared by the three alert flags, so later
    # flags reuse the message a previous flag may already have created.
    if args.anomaly_cloud_logging is not None:
      if args.emails:
        model_monitoring_job_to_update.modelMonitoringAlertConfig.enableLogging = (
            args.anomaly_cloud_logging
        )
      else:
        model_monitoring_job_to_update.modelMonitoringAlertConfig = (
            api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
                enableLogging=args.anomaly_cloud_logging
            )
        )
      update_mask.append('model_monitoring_alert_config.enable_logging')
    if args.notification_channels:
      if args.emails or args.anomaly_cloud_logging is not None:
        model_monitoring_job_to_update.modelMonitoringAlertConfig.notificationChannels = (
            args.notification_channels
        )
      else:
        model_monitoring_job_to_update.modelMonitoringAlertConfig = (
            api_util.GetMessage('ModelMonitoringAlertConfig', self._version)(
                notificationChannels=args.notification_channels
            )
        )
      update_mask.append('model_monitoring_alert_config.notification_channels')
    # sampling rate
    if args.prediction_sampling_rate:
      model_monitoring_job_to_update.loggingSamplingStrategy = api_util.GetMessage(
          'SamplingStrategy', self._version)(
              randomSampleConfig=api_util.GetMessage(
                  'SamplingStrategyRandomSampleConfig', self._version)(
                      sampleRate=args.prediction_sampling_rate))
      update_mask.append('logging_sampling_strategy')
    # schedule (hours -> seconds)
    if args.monitoring_frequency:
      model_monitoring_job_to_update.modelDeploymentMonitoringScheduleConfig = api_util.GetMessage(
          'ModelDeploymentMonitoringScheduleConfig', self._version)(
              monitorInterval='{}s'.format(
                  six.text_type(3600 * int(args.monitoring_frequency))))
      update_mask.append('model_deployment_monitoring_schedule_config')
    if args.analysis_instance_schema:
      model_monitoring_job_to_update.analysisInstanceSchemaUri = args.analysis_instance_schema
      update_mask.append('analysis_instance_schema_uri')
    # log ttl (days -> seconds)
    if args.log_ttl:
      model_monitoring_job_to_update.logTtl = '{}s'.format(
          six.text_type(86400 * int(args.log_ttl)))
      update_mask.append('log_ttl')
    labels_update = self._ParseUpdateLabels(model_monitoring_job_ref, args)
    if labels_update.needs_update:
      model_monitoring_job_to_update.labels = labels_update.labels
      update_mask.append('labels')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    # The request field name is version-specific (v1beta1 vs v1).
    if self._version == constants.BETA_VERSION:
      req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsPatchRequest(
          name=model_monitoring_job_ref.RelativeName(),
          googleCloudAiplatformV1beta1ModelDeploymentMonitoringJob=model_monitoring_job_to_update,
          updateMask=','.join(update_mask))
    else:
      req = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsPatchRequest(
          name=model_monitoring_job_ref.RelativeName(),
          googleCloudAiplatformV1ModelDeploymentMonitoringJob=model_monitoring_job_to_update,
          updateMask=','.join(update_mask))
    return self._service.Patch(req)

  def Get(self, model_monitoring_job_ref):
    """Gets the given model deployment monitoring job."""
    request = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsGetRequest(
        name=model_monitoring_job_ref.RelativeName())
    return self._service.Get(request)

  def List(self, limit=None, region_ref=None):
    """Yields model deployment monitoring jobs in the given region.

    Args:
      limit: int or None, the maximum number of jobs to yield.
      region_ref: The resource reference of the region to list jobs in.

    Returns:
      A generator of model deployment monitoring jobs.
    """
    return list_pager.YieldFromList(
        self._service,
        self.messages
        .AiplatformProjectsLocationsModelDeploymentMonitoringJobsListRequest(
            parent=region_ref.RelativeName()),
        field='modelDeploymentMonitoringJobs',
        batch_size_attribute='pageSize',
        limit=limit)

  def Delete(self, model_monitoring_job_ref):
    """Deletes the given model deployment monitoring job."""
    request = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsDeleteRequest(
        name=model_monitoring_job_ref.RelativeName())
    return self._service.Delete(request)

  def Pause(self, model_monitoring_job_ref):
    """Pauses the given model deployment monitoring job."""
    request = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsPauseRequest(
        name=model_monitoring_job_ref.RelativeName())
    return self._service.Pause(request)

  def Resume(self, model_monitoring_job_ref):
    """Resumes the given (paused) model deployment monitoring job."""
    request = self.messages.AiplatformProjectsLocationsModelDeploymentMonitoringJobsResumeRequest(
        name=model_monitoring_job_ref.RelativeName())
    return self._service.Resume(request)

View File

@@ -0,0 +1,895 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform models API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
class ModelsClient(object):
"""High-level client for the AI Platform models surface.
Attributes:
client: An instance of the given client, or the API client aiplatform of
Beta version.
messages: The messages module for the given client, or the API client
aiplatform of Beta version.
"""
def __init__(self, client=None, messages=None):
self.client = client or apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[constants.BETA_VERSION])
self.messages = messages or self.client.MESSAGES_MODULE
self._service = self.client.projects_locations_models
def UploadV1Beta1(
    self,
    region_ref=None,
    display_name=None,
    description=None,
    version_description=None,
    artifact_uri=None,
    container_image_uri=None,
    container_command=None,
    container_args=None,
    container_env_vars=None,
    container_ports=None,
    container_grpc_ports=None,
    container_predict_route=None,
    container_health_route=None,
    container_deployment_timeout_seconds=None,
    container_shared_memory_size_mb=None,
    container_startup_probe_exec=None,
    container_startup_probe_period_seconds=None,
    container_startup_probe_timeout_seconds=None,
    container_health_probe_exec=None,
    container_health_probe_period_seconds=None,
    container_health_probe_timeout_seconds=None,
    explanation_spec=None,
    parent_model=None,
    model_id=None,
    version_aliases=None,
    labels=None,
    base_model_source=None,
):
  """Constructs, sends an UploadModel request and returns the LRO to be done.

  Args:
    region_ref: The resource reference for the region to upload the Model
      into. None if the region reference is not provided.
    display_name: The display name of the Model; up to 128 UTF-8 characters.
    description: The description of the Model.
    version_description: The description of the Model version.
    artifact_uri: Path to the directory containing the Model artifact and any
      of its supporting files. Not present for AutoML Models.
    container_image_uri: Immutable. URI of the Docker image used as the
      custom serving container; must identify an image in Artifact Registry
      or Container Registry (or a Vertex AI pre-built prediction container).
    container_command: Command that runs when the container starts; overrides
      the image `ENTRYPOINT`. Docker `ENTRYPOINT` "exec" form (list of
      executable and arguments).
    container_args: Arguments for the container command; overrides the image
      `CMD`. Docker `CMD` "default parameters" form.
    container_env_vars: Dict of environment variables to set in the
      container; readable by code in the container and referenceable from
      the command/args fields via $(VAR_NAME) syntax.
    container_ports: HTTP ports to expose from the container. Vertex AI sends
      prediction requests and liveness/health checks to the first port
      listed; defaults to [8080] when unset.
    container_grpc_ports: gRPC ports to expose from the container. gRPC
      requests are disabled when unset; only the first port is used.
    container_predict_route: HTTP path on the container that receives
      prediction requests forwarded from
      projects.locations.endpoints.predict.
    container_health_route: HTTP path on the container that receives
      intermittent GET health checks.
    container_deployment_timeout_seconds (int): Deployment timeout in
      seconds.
    container_shared_memory_size_mb (int): Amount of VM memory to reserve as
      shared memory for the model, in megabytes.
    container_startup_probe_exec (Sequence[str]): Exec command for the
      startup probe, e.g. ["cat", "/tmp/healthy"].
    container_startup_probe_period_seconds (int): How often (seconds) to run
      the startup probe. Default 10; minimum 1.
    container_startup_probe_timeout_seconds (int): Seconds after which the
      startup probe times out. Default 1; minimum 1.
    container_health_probe_exec (Sequence[str]): Exec command for the health
      probe, e.g. ["cat", "/tmp/healthy"].
    container_health_probe_period_seconds (int): How often (seconds) to run
      the health probe. Default 10; minimum 1.
    container_health_probe_timeout_seconds (int): Seconds after which the
      health probe times out. Default 1; minimum 1.
    explanation_spec: Default explanation specification for this Model; may
      be overridden per DeployModelRequest.deployed_model or
      BatchPredictionJob.
    parent_model: Resource name of the model into which to upload the
      version. Only specify this field when uploading a new version.
    model_id: ID for the uploaded Model; becomes the final component of the
      model resource name. Up to 63 characters of `[a-z0-9_-]`, not starting
      with a number or hyphen.
    version_aliases: User-provided version aliases so a model version can be
      referenced as .../models/{model_id}@{version_alias} instead of the
      auto-generated version id. Format a-z{0,126}[a-z0-9].
    labels: User-defined metadata labels to organize your Models; keys and
      values up to 64 Unicode codepoints of lowercase letters, digits,
      underscores and dashes. See https://goo.gl/xmQnxf.
    base_model_source: A GoogleCloudAiplatformV1beta1ModelBaseModelSource
      indicating the source of the model (Model Garden or Generative AI
      Studio models only).

  Returns:
    Response from calling upload model with given request arguments.
  """
  msgs = self.messages
  container_spec = msgs.GoogleCloudAiplatformV1beta1ModelContainerSpec(
      healthRoute=container_health_route,
      imageUri=container_image_uri,
      predictRoute=container_predict_route)
  if container_command:
    container_spec.command = container_command
  if container_args:
    container_spec.args = container_args
  if container_env_vars:
    # Iterate (name, value) pairs directly rather than indexing by key.
    container_spec.env = [
        msgs.GoogleCloudAiplatformV1beta1EnvVar(name=name, value=value)
        for name, value in container_env_vars.items()
    ]
  if container_ports:
    container_spec.ports = [
        msgs.GoogleCloudAiplatformV1beta1Port(containerPort=port)
        for port in container_ports
    ]
  if container_grpc_ports:
    container_spec.grpcPorts = [
        msgs.GoogleCloudAiplatformV1beta1Port(containerPort=port)
        for port in container_grpc_ports
    ]
  if container_deployment_timeout_seconds:
    # The API expects a duration string, e.g. '600s'.
    container_spec.deploymentTimeout = (
        str(container_deployment_timeout_seconds) + 's')
  if container_shared_memory_size_mb:
    container_spec.sharedMemorySizeMb = container_shared_memory_size_mb
  if (container_startup_probe_exec
      or container_startup_probe_period_seconds
      or container_startup_probe_timeout_seconds):
    startup_probe_exec = None
    if container_startup_probe_exec:
      startup_probe_exec = msgs.GoogleCloudAiplatformV1beta1ProbeExecAction(
          command=container_startup_probe_exec)
    container_spec.startupProbe = msgs.GoogleCloudAiplatformV1beta1Probe(
        exec_=startup_probe_exec,
        periodSeconds=container_startup_probe_period_seconds,
        timeoutSeconds=container_startup_probe_timeout_seconds)
  if (container_health_probe_exec
      or container_health_probe_period_seconds
      or container_health_probe_timeout_seconds):
    health_probe_exec = None
    if container_health_probe_exec:
      health_probe_exec = msgs.GoogleCloudAiplatformV1beta1ProbeExecAction(
          command=container_health_probe_exec)
    container_spec.healthProbe = msgs.GoogleCloudAiplatformV1beta1Probe(
        exec_=health_probe_exec,
        periodSeconds=container_health_probe_period_seconds,
        timeoutSeconds=container_health_probe_timeout_seconds)
  model = msgs.GoogleCloudAiplatformV1beta1Model(
      artifactUri=artifact_uri,
      containerSpec=container_spec,
      description=description,
      versionDescription=version_description,
      displayName=display_name,
      explanationSpec=explanation_spec,
      baseModelSource=base_model_source,
  )
  if version_aliases:
    model.versionAliases = version_aliases
  if labels:
    # Sort for a deterministic wire order; reach the nested
    # AdditionalProperty class directly instead of building a throwaway
    # LabelsValue instance first.
    model.labels = model.LabelsValue(additionalProperties=[
        model.LabelsValue.AdditionalProperty(key=key, value=value)
        for key, value in sorted(labels.items())
    ])
  return self._service.Upload(
      msgs.AiplatformProjectsLocationsModelsUploadRequest(
          parent=region_ref.RelativeName(),
          googleCloudAiplatformV1beta1UploadModelRequest=msgs
          .GoogleCloudAiplatformV1beta1UploadModelRequest(
              model=model,
              parentModel=parent_model,
              modelId=model_id)))
def UploadV1(self,
             region_ref=None,
             display_name=None,
             description=None,
             version_description=None,
             artifact_uri=None,
             container_image_uri=None,
             container_command=None,
             container_args=None,
             container_env_vars=None,
             container_ports=None,
             container_grpc_ports=None,
             container_predict_route=None,
             container_health_route=None,
             container_deployment_timeout_seconds=None,
             container_shared_memory_size_mb=None,
             container_startup_probe_exec=None,
             container_startup_probe_period_seconds=None,
             container_startup_probe_timeout_seconds=None,
             container_health_probe_exec=None,
             container_health_probe_period_seconds=None,
             container_health_probe_timeout_seconds=None,
             explanation_spec=None,
             parent_model=None,
             model_id=None,
             version_aliases=None,
             labels=None):
  """Constructs, sends an UploadModel request and returns the LRO to be done.

  Args:
    region_ref: The resource reference for the region to upload the Model
      into. None if the region reference is not provided.
    display_name: The display name of the Model; up to 128 UTF-8 characters.
    description: The description of the Model.
    version_description: The description of the Model version.
    artifact_uri: Path to the directory containing the Model artifact and any
      of its supporting files. Not present for AutoML Models.
    container_image_uri: Immutable. URI of the Docker image used as the
      custom serving container; must identify an image in Artifact Registry
      or Container Registry (or a Vertex AI pre-built prediction container).
    container_command: Command that runs when the container starts; overrides
      the image `ENTRYPOINT`. Docker `ENTRYPOINT` "exec" form (list of
      executable and arguments).
    container_args: Arguments for the container command; overrides the image
      `CMD`. Docker `CMD` "default parameters" form.
    container_env_vars: Dict of environment variables to set in the
      container; readable by code in the container and referenceable from
      the command/args fields via $(VAR_NAME) syntax.
    container_ports: HTTP ports to expose from the container. Vertex AI sends
      prediction requests and liveness/health checks to the first port
      listed; defaults to [8080] when unset.
    container_grpc_ports: gRPC ports to expose from the container. gRPC
      requests are disabled when unset; only the first port is used.
    container_predict_route: HTTP path on the container that receives
      prediction requests forwarded from
      projects.locations.endpoints.predict.
    container_health_route: HTTP path on the container that receives
      intermittent GET health checks.
    container_deployment_timeout_seconds (int): Deployment timeout in
      seconds.
    container_shared_memory_size_mb (int): Amount of VM memory to reserve as
      shared memory for the model, in megabytes.
    container_startup_probe_exec (Sequence[str]): Exec command for the
      startup probe, e.g. ["cat", "/tmp/healthy"].
    container_startup_probe_period_seconds (int): How often (seconds) to run
      the startup probe. Default 10; minimum 1.
    container_startup_probe_timeout_seconds (int): Seconds after which the
      startup probe times out. Default 1; minimum 1.
    container_health_probe_exec (Sequence[str]): Exec command for the health
      probe, e.g. ["cat", "/tmp/healthy"].
    container_health_probe_period_seconds (int): How often (seconds) to run
      the health probe. Default 10; minimum 1.
    container_health_probe_timeout_seconds (int): Seconds after which the
      health probe times out. Default 1; minimum 1.
    explanation_spec: Default explanation specification for this Model; may
      be overridden per DeployModelRequest.deployed_model or
      BatchPredictionJob.
    parent_model: Resource name of the model into which to upload the
      version. Only specify this field when uploading a new version.
    model_id: ID for the uploaded Model; becomes the final component of the
      model resource name. Up to 63 characters of `[a-z0-9_-]`, not starting
      with a number or hyphen.
    version_aliases: User-provided version aliases so a model version can be
      referenced as .../models/{model_id}@{version_alias} instead of the
      auto-generated version id. Format a-z{0,126}[a-z0-9].
    labels: User-defined metadata labels to organize your Models; keys and
      values up to 64 Unicode codepoints of lowercase letters, digits,
      underscores and dashes. See https://goo.gl/xmQnxf.

  Returns:
    Response from calling upload model with given request arguments.
  """
  msgs = self.messages
  container_spec = msgs.GoogleCloudAiplatformV1ModelContainerSpec(
      healthRoute=container_health_route,
      imageUri=container_image_uri,
      predictRoute=container_predict_route)
  if container_command:
    container_spec.command = container_command
  if container_args:
    container_spec.args = container_args
  if container_env_vars:
    # Iterate (name, value) pairs directly rather than indexing by key.
    container_spec.env = [
        msgs.GoogleCloudAiplatformV1EnvVar(name=name, value=value)
        for name, value in container_env_vars.items()
    ]
  if container_ports:
    container_spec.ports = [
        msgs.GoogleCloudAiplatformV1Port(containerPort=port)
        for port in container_ports
    ]
  if container_grpc_ports:
    container_spec.grpcPorts = [
        msgs.GoogleCloudAiplatformV1Port(containerPort=port)
        for port in container_grpc_ports
    ]
  if container_deployment_timeout_seconds:
    # The API expects a duration string, e.g. '600s'.
    container_spec.deploymentTimeout = (
        str(container_deployment_timeout_seconds) + 's')
  if container_shared_memory_size_mb:
    container_spec.sharedMemorySizeMb = container_shared_memory_size_mb
  if (container_startup_probe_exec
      or container_startup_probe_period_seconds
      or container_startup_probe_timeout_seconds):
    startup_probe_exec = None
    if container_startup_probe_exec:
      startup_probe_exec = msgs.GoogleCloudAiplatformV1ProbeExecAction(
          command=container_startup_probe_exec)
    container_spec.startupProbe = msgs.GoogleCloudAiplatformV1Probe(
        exec_=startup_probe_exec,
        periodSeconds=container_startup_probe_period_seconds,
        timeoutSeconds=container_startup_probe_timeout_seconds)
  if (container_health_probe_exec
      or container_health_probe_period_seconds
      or container_health_probe_timeout_seconds):
    health_probe_exec = None
    if container_health_probe_exec:
      health_probe_exec = msgs.GoogleCloudAiplatformV1ProbeExecAction(
          command=container_health_probe_exec)
    container_spec.healthProbe = msgs.GoogleCloudAiplatformV1Probe(
        exec_=health_probe_exec,
        periodSeconds=container_health_probe_period_seconds,
        timeoutSeconds=container_health_probe_timeout_seconds)
  model = msgs.GoogleCloudAiplatformV1Model(
      artifactUri=artifact_uri,
      containerSpec=container_spec,
      description=description,
      versionDescription=version_description,
      displayName=display_name,
      explanationSpec=explanation_spec)
  if version_aliases:
    model.versionAliases = version_aliases
  if labels:
    # Sort for a deterministic wire order; reach the nested
    # AdditionalProperty class directly instead of building a throwaway
    # LabelsValue instance first.
    model.labels = model.LabelsValue(additionalProperties=[
        model.LabelsValue.AdditionalProperty(key=key, value=value)
        for key, value in sorted(labels.items())
    ])
  return self._service.Upload(
      msgs.AiplatformProjectsLocationsModelsUploadRequest(
          parent=region_ref.RelativeName(),
          googleCloudAiplatformV1UploadModelRequest=msgs
          .GoogleCloudAiplatformV1UploadModelRequest(
              model=model,
              parentModel=parent_model,
              modelId=model_id)))
def Get(self, model_ref):
"""Gets (describe) the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
Returns:
Response from calling get model with request containing given model.
"""
request = self.messages.AiplatformProjectsLocationsModelsGetRequest(
name=model_ref.RelativeName())
return self._service.Get(request)
def Delete(self, model_ref):
"""Deletes the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
Returns:
Response from calling delete model with request containing given model.
"""
request = self.messages.AiplatformProjectsLocationsModelsDeleteRequest(
name=model_ref.RelativeName())
return self._service.Delete(request)
def DeleteVersion(self, model_version_ref):
"""Deletes the given model version.
Args:
model_version_ref: The resource reference for a given model version.
Returns:
Response from calling delete version with request containing given model
version.
"""
request = (
self.messages.AiplatformProjectsLocationsModelsDeleteVersionRequest(
name=model_version_ref.RelativeName()
)
)
return self._service.DeleteVersion(request)
def List(self, limit=None, region_ref=None):
"""List all models in the given region.
Args:
limit: int, The maximum number of records to yield. None if all available
records should be yielded.
region_ref: The resource reference for a given region. None if the region
reference is not provided.
Returns:
Response from calling list models with request containing given models
and limit.
"""
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsModelsListRequest(
parent=region_ref.RelativeName()),
field='models',
batch_size_attribute='pageSize',
limit=limit)
def ListVersion(self, model_ref=None, limit=None):
"""List all model versions of the given model.
Args:
model_ref: The resource reference for a given model. None if model
resource reference is not provided.
limit: int, The maximum number of records to yield. None if all available
records should be yielded.
Returns:
Response from calling list model versions with request containing given
model and limit.
"""
return list_pager.YieldFromList(
self._service,
self.messages.AiplatformProjectsLocationsModelsListVersionsRequest(
name=model_ref.RelativeName()),
method='ListVersions',
field='models',
batch_size_attribute='pageSize',
limit=limit)
def CopyV1Beta1(self,
destination_region_ref=None,
source_model=None,
kms_key_name=None,
destination_model_id=None,
destination_parent_model=None):
"""Copies the given source model into specified location.
The source model is copied into specified location (including cross-region)
either as a new model or a new model version under given parent model.
Args:
destination_region_ref: the resource reference to the location into which
to copy the Model.
source_model: The resource name of the Model to copy.
kms_key_name: The KMS key name for specifying encryption spec.
destination_model_id: The destination model resource name to copy the
model into.
destination_parent_model: The destination parent model to copy the model
as a model version into.
Returns:
Response from calling copy model.
"""
encryption_spec = None
if kms_key_name:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
kmsKeyName=kms_key_name
)
)
request = self.messages.AiplatformProjectsLocationsModelsCopyRequest(
parent=destination_region_ref.RelativeName(),
googleCloudAiplatformV1beta1CopyModelRequest=self.messages
.GoogleCloudAiplatformV1beta1CopyModelRequest(
sourceModel=source_model,
encryptionSpec=encryption_spec,
parentModel=destination_parent_model,
modelId=destination_model_id))
return self._service.Copy(request)
def CopyV1(self,
destination_region_ref=None,
source_model=None,
kms_key_name=None,
destination_model_id=None,
destination_parent_model=None):
"""Copies the given source model into specified location.
The source model is copied into specified location (including cross-region)
either as a new model or a new model version under given parent model.
Args:
destination_region_ref: the resource reference to the location into which
to copy the Model.
source_model: The resource name of the Model to copy.
kms_key_name: The name of the KMS key to use for model encryption.
destination_model_id: Optional. Thew custom ID to be used as the resource
name of the new model. This value may be up to 63 characters, and valid
characters are `[a-z0-9_-]`. The first character cannot be a number or
hyphen.
destination_parent_model: The destination parent model to copy the model
as a model version into.
Returns:
Response from calling copy model.
"""
encryption_spec = None
if kms_key_name:
encryption_spec = (
self.messages.GoogleCloudAiplatformV1EncryptionSpec(
kmsKeyName=kms_key_name
)
)
request = self.messages.AiplatformProjectsLocationsModelsCopyRequest(
parent=destination_region_ref.RelativeName(),
googleCloudAiplatformV1CopyModelRequest=self.messages
.GoogleCloudAiplatformV1CopyModelRequest(
sourceModel=source_model,
encryptionSpec=encryption_spec,
parentModel=destination_parent_model,
modelId=destination_model_id))
return self._service.Copy(request)

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with long-running operations (simple uri)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.command_lib.ai import constants
def GetClientInstance(api_version=None, no_http=False):
  """Builds an apitools client instance for the AI Platform API."""
  api_name = constants.AI_PLATFORM_API_NAME
  return apis.GetClientInstance(api_name, api_version, no_http=no_http)
class AiPlatformOperationPoller(waiter.CloudOperationPoller):
  """Poller for AI Platform operations API.

  This is necessary because the core operations library doesn't directly support
  simple_uri.
  """

  def __init__(self, client):
    # `client` is the high-level OperationsClient defined below, not a raw
    # apitools client; its `.client` attribute is the apitools client.
    self.client = client
    # CloudOperationPoller expects (result_service, operation_service). The AI
    # Platform surface uses the same operations service for both roles.
    super(AiPlatformOperationPoller, self).__init__(
        self.client.client.projects_locations_operations,
        self.client.client.projects_locations_operations)

  def Poll(self, operation_ref):
    """Fetches the latest operation state via the high-level client."""
    return self.client.Get(operation_ref)

  def GetResult(self, operation):
    """Returns the completed operation itself instead of unpacking a response."""
    return operation
class OperationsClient(object):
  """High-level client for the AI Platform operations surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    # Defaults to the beta API surface unless a version is supplied.
    self.client = client or GetClientInstance(
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE

  def Get(self, operation_ref):
    """Fetches the operation identified by the given resource reference."""
    return self.client.projects_locations_operations.Get(
        self.messages.AiplatformProjectsLocationsOperationsGetRequest(
            name=operation_ref.RelativeName()))

  def WaitForOperation(
      self, operation, operation_ref, message=None, max_wait_ms=1800000
  ):
    """Wait until the operation is complete or times out.

    Args:
      operation: The operation resource to wait on
      operation_ref: The operation reference to the operation resource. It's the
        result by calling resources.REGISTRY.Parse
      message: str, the message to print while waiting.
      max_wait_ms: int, number of ms to wait before raising WaitException.
        Defaults to 30 minutes.

    Returns:
      The operation resource when it has completed

    Raises:
      OperationTimeoutError: when the operation polling times out
      OperationError: when the operation completed with an error
    """
    poller = AiPlatformOperationPoller(self)
    # Short-circuit: skip the waiter entirely if the operation is already done.
    if poller.IsDone(operation):
      return operation
    if message is None:
      message = 'Waiting for operation [{}]'.format(operation_ref.Name())
    return waiter.WaitFor(
        poller, operation_ref, message, max_wait_ms=max_wait_ms
    )

View File

@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying Vertex AI Persistent Resources."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class PersistentResourcesClient(object):
  """Client used for interacting with the PersistentResource endpoint."""

  def __init__(self, version=constants.GA_VERSION):
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._version = version
    self._service = client.projects_locations_persistentResources
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the API message class by name.

    Message classes are version-prefixed (e.g. GoogleCloudAiplatformV1...);
    returns None if no such message exists for this API version.
    """
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def PersistentResourceMessage(self):
    """Returns the PersistentResource message class for this version."""
    return self.GetMessage('PersistentResource')

  def Create(self,
             parent,
             resource_pools,
             persistent_resource_id,
             display_name=None,
             kms_key_name=None,
             labels=None,
             network=None,
             enable_custom_service_account=False,
             service_account=None):
    """Constructs a request and sends it to the endpoint to create a persistent resource.

    Args:
      parent: str, The project resource path of the persistent resource to
        create.
      resource_pools: The resource pool messages assigned to the persistent
        resource's `resourcePools` field in the creation request.
      persistent_resource_id: The PersistentResource id for the creation
        request.
      display_name: str, The display name of the persistent resource to create.
      kms_key_name: A customer-managed encryption key to use for the persistent
        resource.
      labels: LabelValues, map-like user-defined metadata to organize the
        resource.
      network: Network to peer with the PersistentResource
      enable_custom_service_account: Whether or not to enable this Persistent
        Resource to use a custom service account.
      service_account: A service account (email address string) to use for
        creating the Persistent Resource.

    Returns:
      A PersistentResource message instance created.
    """
    # PersistentResourceMessage() returns the message *class*; the second call
    # instantiates it.
    persistent_resource = self.PersistentResourceMessage()(
        displayName=display_name, resourcePools=resource_pools)
    if kms_key_name is not None:
      persistent_resource.encryptionSpec = self.GetMessage('EncryptionSpec')(
          kmsKeyName=kms_key_name)
    if labels:
      persistent_resource.labels = labels
    if network:
      persistent_resource.network = network
    if enable_custom_service_account:
      # service_account may be None here; the API then uses its default
      # behavior for custom service accounts.
      persistent_resource.resourceRuntimeSpec = (
          self.GetMessage('ResourceRuntimeSpec')(
              serviceAccountSpec=self.GetMessage('ServiceAccountSpec')(
                  enableCustomServiceAccount=True,
                  serviceAccount=service_account)))
    # The two branches differ only in the version-specific field name that
    # carries the PersistentResource payload (V1 vs V1beta1).
    if self._version == constants.GA_VERSION:
      return self._service.Create(
          self._messages.AiplatformProjectsLocationsPersistentResourcesCreateRequest(
              parent=parent,
              googleCloudAiplatformV1PersistentResource=persistent_resource,
              persistentResourceId=persistent_resource_id,
          )
      )
    return self._service.Create(
        self._messages.AiplatformProjectsLocationsPersistentResourcesCreateRequest(
            parent=parent,
            googleCloudAiplatformV1beta1PersistentResource=persistent_resource,
            persistentResourceId=persistent_resource_id,
        )
    )

  def List(self, limit=None, region=None):
    """Constructs a list request and sends it to the Persistent Resources endpoint.

    Args:
      limit: How many items to return in the list
      region: Which region to list resources from

    Returns:
      A Persistent Resource list response message.
    """
    return list_pager.YieldFromList(
        self._service,
        self._messages.AiplatformProjectsLocationsPersistentResourcesListRequest(
            parent=region
        ),
        field='persistentResources',
        batch_size_attribute='pageSize',
        limit=limit,
    )

  def Get(self, name):
    """Fetches the persistent resource with the given resource name."""
    request = (self._messages
               .AiplatformProjectsLocationsPersistentResourcesGetRequest(
                   name=name))
    return self._service.Get(request)

  def Delete(self, name):
    """Deletes the persistent resource with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsPersistentResourcesDeleteRequest(
        name=name
    )
    return self._service.Delete(request)

  def Reboot(self, name):
    """Reboots the persistent resource with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsPersistentResourcesRebootRequest(
        name=name
    )
    return self._service.Reboot(request)

  def ImportResourceMessage(self, yaml_file, message_name):
    """Import a messages class instance typed by name from a YAML file."""
    # Reads from stdin when yaml_file is '-'.
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    message_type = self.GetMessage(message_name)
    return export_util.Import(message_type=message_type, stream=data)

View File

@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is used to get the client instance and messages module for GKE recommender."""
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
# Maps gcloud release tracks to GKE recommender API versions.
VERSION_MAP = {
    base.ReleaseTrack.ALPHA: 'v1alpha1',
    base.ReleaseTrack.GA: 'v1',
}

# Format string used to render apitools HTTP errors for display.
HTTP_ERROR_FORMAT = (
    'ResponseError: code={status_code}, message={status_message}'
)
# The messages module can also be accessed from client.MESSAGES_MODULE
def GetMessagesModule(release_track=base.ReleaseTrack.GA):
  """Returns the gkerecommender messages module for the release track."""
  return apis.GetMessagesModule(
      'gkerecommender', VERSION_MAP.get(release_track))
def GetClientInstance(release_track=base.ReleaseTrack.GA):
  """Returns the gkerecommender client instance for the release track."""
  return apis.GetClientInstance(
      'gkerecommender', VERSION_MAP.get(release_track))

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for querying serverless ray jobs in AI Platform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core.console import console_io
class ServerlessRayJobsClient(object):
  """Client used for interacting with Serverless Ray Jobs endpoint."""

  def __init__(self, version=constants.GA_VERSION):
    client = apis.GetClientInstance(constants.AI_PLATFORM_API_NAME,
                                    constants.AI_PLATFORM_API_VERSION[version])
    self._messages = client.MESSAGES_MODULE
    self._version = version
    self._service = client.projects_locations_serverlessRayJobs
    self._message_prefix = constants.AI_PLATFORM_MESSAGE_PREFIX[version]

  def GetMessage(self, message_name):
    """Returns the API message class by name.

    Message classes are version-prefixed; returns None if no such message
    exists for this API version.
    """
    return getattr(
        self._messages,
        '{prefix}{name}'.format(prefix=self._message_prefix,
                                name=message_name), None)

  def ServerlessRayJobMessage(self):
    """Returns the Serverless Ray Jobs resource message class."""
    return self.GetMessage('ServerlessRayJob')

  def Create(self,
             parent,
             job_spec,
             display_name=None,
             labels=None):
    """Constructs a request and sends it to the endpoint to create a serverless ray job instance.

    Args:
      parent: str, The project resource path of the serverless ray job to
        create.
      job_spec: The ServerlessRayJobSpec message instance for the job creation
        request.
      display_name: str, The display name of the serverless ray job to create.
      labels: LabelValues, map-like user-defined metadata to organize the
        serverless ray job.

    Returns:
      A ServerlessRayJob message instance created.
    """
    # ServerlessRayJobMessage() returns the message *class*; the second call
    # instantiates it.
    serverless_ray_job = self.ServerlessRayJobMessage()(
        displayName=display_name, jobSpec=job_spec
    )
    if labels:
      serverless_ray_job.labels = labels
    # TODO(b/390679825): Add V1 version support when Serverless Ray Jobs API is
    # GA ready.
    return self._service.Create(
        self._messages.AiplatformProjectsLocationsServerlessRayJobsCreateRequest(
            parent=parent,
            googleCloudAiplatformV1beta1ServerlessRayJob=serverless_ray_job,
        )
    )

  def List(self, limit=None, region=None):
    """Yields serverless ray jobs in the region, up to `limit` entries."""
    return list_pager.YieldFromList(
        self._service,
        self._messages.AiplatformProjectsLocationsServerlessRayJobsListRequest(
            parent=region
        ),
        field='serverlessRayJobs',
        batch_size_attribute='pageSize',
        limit=limit,
    )

  def Get(self, name):
    """Fetches the serverless ray job with the given resource name."""
    request = (
        self._messages.AiplatformProjectsLocationsServerlessRayJobsGetRequest(
            name=name
        )
    )
    return self._service.Get(request)

  def Cancel(self, name):
    """Cancels the serverless ray job with the given resource name."""
    request = self._messages.AiplatformProjectsLocationsServerlessRayJobsCancelRequest(
        name=name
    )
    return self._service.Cancel(request)

  def ImportResourceMessage(self, yaml_file, message_name):
    """Import a messages class instance typed by name from a YAML file."""
    # Reads from stdin when yaml_file is '-'.
    data = console_io.ReadFromFileOrStdin(yaml_file, binary=False)
    message_type = self.GetMessage(message_name)
    return export_util.Import(message_type=message_type, stream=data)

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard experiments API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardExperimentsClient(object):
  """High-level client for the AI Platform Tensorboard experiment surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards_experiments
    self._version = version

  def Create(self, tensorboard_ref, args):
    """Creates a Tensorboard experiment; currently delegates to the beta path."""
    return self.CreateBeta(tensorboard_ref, args)

  def CreateBeta(self, tensorboard_ref, args):
    """Create a new Tensorboard experiment."""
    labels = labels_util.ParseCreateArgs(
        args, self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment
        .LabelsValue)
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsCreateRequest(
        parent=tensorboard_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardExperiment=self.messages
        .GoogleCloudAiplatformV1beta1TensorboardExperiment(
            displayName=args.display_name,
            description=args.description,
            labels=labels),
        tensorboardExperimentId=args.tensorboard_experiment_id)
    return self._service.Create(request)

  def List(self, tensorboard_ref, limit=1000, page_size=50, sort_by=None):
    """Yields Tensorboard experiments under the given Tensorboard."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsListRequest(
        parent=tensorboard_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboardExperiments',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Get(self, tensorboard_exp_ref):
    """Fetches the given Tensorboard experiment."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsGetRequest(
        name=tensorboard_exp_ref.RelativeName())
    return self._service.Get(request)

  def Delete(self, tensorboard_exp_ref):
    """Deletes the given Tensorboard experiment."""
    request = (
        self.messages
        .AiplatformProjectsLocationsTensorboardsExperimentsDeleteRequest(
            name=tensorboard_exp_ref.RelativeName()))
    return self._service.Delete(request)

  def Patch(self, tensorboard_exp_ref, args):
    """Updates a Tensorboard experiment; currently delegates to the beta path."""
    return self.PatchBeta(tensorboard_exp_ref, args)

  def PatchBeta(self, tensorboard_exp_ref, args):
    """Update a Tensorboard experiment.

    Raises:
      errors.NoFieldsSpecifiedError: if no update flags were provided.
    """
    tensorboard_exp = (
        self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment())
    update_mask = []

    def GetLabels():
      # Fetched lazily by ProcessUpdateArgsLazy only when the current labels
      # are actually needed to compute the update.
      return self.Get(tensorboard_exp_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, self.messages.GoogleCloudAiplatformV1beta1TensorboardExperiment
        .LabelsValue, GetLabels)
    if labels_update.needs_update:
      tensorboard_exp.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard_exp.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard_exp.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsPatchRequest(
        name=tensorboard_exp_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardExperiment=tensorboard_exp,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard runs API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardRunsClient(object):
  """High-level client for the AI Platform Tensorboard run surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards_experiments_runs
    self._version = version

  def Create(self, tensorboard_exp_ref, args):
    """Creates a Tensorboard run; currently delegates to the beta path."""
    return self.CreateBeta(tensorboard_exp_ref, args)

  def CreateBeta(self, tensorboard_exp_ref, args):
    """Create a new Tensorboard run."""
    labels = labels_util.ParseCreateArgs(
        args,
        self.messages.GoogleCloudAiplatformV1beta1TensorboardRun.LabelsValue)
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsCreateRequest(
        parent=tensorboard_exp_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardRun=self.messages
        .GoogleCloudAiplatformV1beta1TensorboardRun(
            displayName=args.display_name,
            description=args.description,
            labels=labels),
        tensorboardRunId=args.tensorboard_run_id)
    return self._service.Create(request)

  def List(self, tensorboard_exp_ref, limit=1000, page_size=50, sort_by=None):
    """Yields Tensorboard runs under the given Tensorboard experiment."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsListRequest(
        parent=tensorboard_exp_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboardRuns',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Get(self, tensorboard_run_ref):
    """Fetches the given Tensorboard run."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsGetRequest(
        name=tensorboard_run_ref.RelativeName())
    return self._service.Get(request)

  def Delete(self, tensorboard_run_ref):
    """Deletes the given Tensorboard run."""
    request = (
        self.messages
        .AiplatformProjectsLocationsTensorboardsExperimentsRunsDeleteRequest(
            name=tensorboard_run_ref.RelativeName()))
    return self._service.Delete(request)

  def Patch(self, tensorboard_run_ref, args):
    """Updates a Tensorboard run; currently delegates to the beta path."""
    return self.PatchBeta(tensorboard_run_ref, args)

  def PatchBeta(self, tensorboard_run_ref, args):
    """Update a Tensorboard run.

    Raises:
      errors.NoFieldsSpecifiedError: if no update flags were provided.
    """
    tensorboard_run = self.messages.GoogleCloudAiplatformV1beta1TensorboardRun()
    update_mask = []

    def GetLabels():
      # Fetched lazily by ProcessUpdateArgsLazy only when the current labels
      # are actually needed to compute the update.
      return self.Get(tensorboard_run_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args,
        self.messages.GoogleCloudAiplatformV1beta1TensorboardRun.LabelsValue,
        GetLabels)
    if labels_update.needs_update:
      tensorboard_run.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard_run.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard_run.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsPatchRequest(
        name=tensorboard_run_ref.RelativeName(),
        googleCloudAiplatformV1beta1TensorboardRun=tensorboard_run,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboard time series API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
def GetMessagesModule(version=constants.BETA_VERSION):
  """Returns the AI Platform messages module for the given API version."""
  api_version = constants.AI_PLATFORM_API_VERSION[version]
  return apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME, api_version)
class TensorboardTimeSeriesClient(object):
"""High-level client for the AI Platform Tensorboard time series surface."""
def __init__(self,
client=None,
messages=None,
version=constants.BETA_VERSION):
self.client = client or apis.GetClientInstance(
constants.AI_PLATFORM_API_NAME,
constants.AI_PLATFORM_API_VERSION[version])
self.messages = messages or self.client.MESSAGES_MODULE
self._service = self.client.projects_locations_tensorboards_experiments_runs_timeSeries
self._version = version
def Create(self, tensorboard_run_ref, args):
return self.CreateBeta(tensorboard_run_ref, args)
def CreateBeta(self, tensorboard_run_ref, args):
"""Create a new Tensorboard time series."""
if args.type == 'scalar':
value_type = (
self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
.ValueTypeValueValuesEnum.SCALAR)
elif args.type == 'blob-sequence':
value_type = (
self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
.ValueTypeValueValuesEnum.BLOB_SEQUENCE)
else:
value_type = (
self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries
.ValueTypeValueValuesEnum.TENSOR)
if args.plugin_data is None:
plugin_data = ''
else:
plugin_data = args.plugin_data
request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesCreateRequest(
parent=tensorboard_run_ref.RelativeName(),
googleCloudAiplatformV1beta1TensorboardTimeSeries=self.messages
.GoogleCloudAiplatformV1beta1TensorboardTimeSeries(
displayName=args.display_name,
description=args.description,
valueType=value_type,
pluginName=args.plugin_name,
pluginData=bytes(plugin_data, encoding='utf8')))
return self._service.Create(request)
def List(self, tensorboard_run_ref, limit=1000, page_size=50, sort_by=None):
request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesListRequest(
parent=tensorboard_run_ref.RelativeName(),
orderBy=common_args.ParseSortByArg(sort_by))
return list_pager.YieldFromList(
self._service,
request,
field='tensorboardTimeSeries',
batch_size_attribute='pageSize',
batch_size=page_size,
limit=limit)
def Get(self, tensorboard_time_series_ref):
request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesGetRequest(
name=tensorboard_time_series_ref.RelativeName())
return self._service.Get(request)
def Delete(self, tensorboard_time_series_ref):
request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesDeleteRequest(
name=tensorboard_time_series_ref.RelativeName())
return self._service.Delete(request)
def Patch(self, tensorboard_time_series_ref, args):
return self.PatchBeta(tensorboard_time_series_ref, args)
def PatchBeta(self, tensorboard_time_series_ref, args):
"""Update a Tensorboard time series."""
tensorboard_time_series = (
self.messages.GoogleCloudAiplatformV1beta1TensorboardTimeSeries())
update_mask = []
if args.display_name is not None:
tensorboard_time_series.displayName = args.display_name
update_mask.append('display_name')
if args.description is not None:
tensorboard_time_series.description = args.description
update_mask.append('description')
if args.plugin_name is not None:
tensorboard_time_series.pluginName = args.plugin_name
update_mask.append('plugin_name')
if args.plugin_data is not None:
tensorboard_time_series.pluginData = bytes(
args.plugin_data, encoding='utf8')
update_mask.append('plugin_data')
if not update_mask:
raise errors.NoFieldsSpecifiedError('No updates requested.')
request = self.messages.AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesPatchRequest(
name=tensorboard_time_series_ref.RelativeName(),
googleCloudAiplatformV1beta1TensorboardTimeSeries=tensorboard_time_series,
updateMask=','.join(update_mask))
return self._service.Patch(request)
def Read(self, tensorboard_time_series_ref, max_data_points, data_filter):
  """Reads data points from the given Tensorboard time series.

  Args:
    tensorboard_time_series_ref: Resource reference of the time series.
    max_data_points: Cap on the number of data points requested.
    data_filter: Server-side filter expression for the read.

  Returns:
    The response of the Read RPC.
  """
  read_request = (
      self.messages
      .AiplatformProjectsLocationsTensorboardsExperimentsRunsTimeSeriesReadRequest(
          tensorboardTimeSeries=tensorboard_time_series_ref.RelativeName(),
          maxDataPoints=max_data_points,
          filter=data_filter))
  return self._service.Read(read_request)

View File

@@ -0,0 +1,155 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for AI Platform Tensorboards API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import common_args
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import validation as common_validation
from googlecloudsdk.command_lib.util.args import labels_util
class TensorboardsClient(object):
  """High-level client for the AI Platform Tensorboard surface."""

  def __init__(self,
               client=None,
               messages=None,
               version=constants.BETA_VERSION):
    """Instantiates the client.

    Args:
      client: An apitools client instance; a new one is created when None.
      messages: A messages module; defaults to the client's MESSAGES_MODULE.
      version: Version key into constants.AI_PLATFORM_API_VERSION.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_tensorboards
    self._version = version

  def Create(self, location_ref, args):
    """Creates a Tensorboard, dispatching on the client's API version."""
    if self._version == constants.GA_VERSION:
      return self.CreateGa(location_ref, args)
    else:
      return self.CreateBeta(location_ref, args)

  def _CreateTensorboard(self, location_ref, args, tensorboard_class,
                         request_field):
    """Shared implementation behind CreateGa and CreateBeta.

    Args:
      location_ref: Resource reference of the parent location.
      args: Parsed argparse namespace carrying display_name, description,
        label, and KMS key flags.
      tensorboard_class: The versioned Tensorboard message class.
      request_field: Name of the create-request field that carries the
        Tensorboard message (the field name differs between v1 and v1beta1).

    Returns:
      The response of the Create RPC.
    """
    kms_key_name = common_validation.GetAndValidateKmsKey(args)
    labels = labels_util.ParseCreateArgs(args, tensorboard_class.LabelsValue)
    tensorboard = tensorboard_class(
        displayName=args.display_name,
        description=args.description,
        labels=labels)
    if kms_key_name is not None:
      tensorboard.encryptionSpec = api_util.GetMessage(
          'EncryptionSpec', self._version)(
              kmsKeyName=kms_key_name)
    request = self.messages.AiplatformProjectsLocationsTensorboardsCreateRequest(
        parent=location_ref.RelativeName(),
        **{request_field: tensorboard})
    return self._service.Create(request)

  def CreateGa(self, location_ref, args):
    """Create a new Tensorboard via the v1 (GA) message surface."""
    return self._CreateTensorboard(
        location_ref, args,
        self.messages.GoogleCloudAiplatformV1Tensorboard,
        'googleCloudAiplatformV1Tensorboard')

  def CreateBeta(self, location_ref, args):
    """Create a new Tensorboard via the v1beta1 message surface."""
    return self._CreateTensorboard(
        location_ref, args,
        self.messages.GoogleCloudAiplatformV1beta1Tensorboard,
        'googleCloudAiplatformV1beta1Tensorboard')

  def Get(self, tensorboard_ref):
    """Gets a single Tensorboard by resource reference."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsGetRequest(
        name=tensorboard_ref.RelativeName())
    return self._service.Get(request)

  def List(self, limit=1000, page_size=50, region_ref=None, sort_by=None):
    """Lists Tensorboards under the given region.

    Args:
      limit: Maximum number of resources to yield overall.
      page_size: Number of resources requested per page.
      region_ref: Resource reference of the parent region; must not be None.
      sort_by: Optional sort order, parsed via common_args.ParseSortByArg.

    Returns:
      A generator of Tensorboard messages, paged by list_pager.
    """
    request = self.messages.AiplatformProjectsLocationsTensorboardsListRequest(
        parent=region_ref.RelativeName(),
        orderBy=common_args.ParseSortByArg(sort_by))
    return list_pager.YieldFromList(
        self._service,
        request,
        field='tensorboards',
        batch_size_attribute='pageSize',
        batch_size=page_size,
        limit=limit)

  def Delete(self, tensorboard_ref):
    """Deletes the given Tensorboard."""
    request = self.messages.AiplatformProjectsLocationsTensorboardsDeleteRequest(
        name=tensorboard_ref.RelativeName())
    return self._service.Delete(request)

  def Patch(self, tensorboard_ref, args):
    """Updates a Tensorboard's labels, display name, and/or description.

    Args:
      tensorboard_ref: Resource reference of the Tensorboard to update.
      args: Parsed argparse namespace carrying the update flags.

    Returns:
      The response of the Patch RPC.

    Raises:
      errors.NoFieldsSpecifiedError: If no updatable flag was provided.
    """
    if self._version == constants.GA_VERSION:
      tensorboard_class = self.messages.GoogleCloudAiplatformV1Tensorboard
      request_field = 'googleCloudAiplatformV1Tensorboard'
    else:
      tensorboard_class = self.messages.GoogleCloudAiplatformV1beta1Tensorboard
      request_field = 'googleCloudAiplatformV1beta1Tensorboard'
    tensorboard = tensorboard_class()
    update_mask = []

    def GetLabels():
      # Callback supplying the current labels to ProcessUpdateArgsLazy,
      # so the Get RPC only runs when labels are actually being updated.
      return self.Get(tensorboard_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, tensorboard_class.LabelsValue, GetLabels)
    if labels_update.needs_update:
      tensorboard.labels = labels_update.labels
      update_mask.append('labels')
    if args.display_name is not None:
      tensorboard.displayName = args.display_name
      update_mask.append('display_name')
    if args.description is not None:
      tensorboard.description = args.description
      update_mask.append('description')
    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')
    request = self.messages.AiplatformProjectsLocationsTensorboardsPatchRequest(
        name=tensorboard_ref.RelativeName(),
        updateMask=','.join(update_mask),
        **{request_field: tensorboard})
    return self._service.Patch(request)

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities for dealing with Vertex AI api messages."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
def GetMessagesModule(version=constants.GA_VERSION):
  """Returns the generated messages module for the given API version."""
  api_version = constants.AI_PLATFORM_API_VERSION[version]
  return apis.GetMessagesModule(constants.AI_PLATFORM_API_NAME, api_version)
def GetMessage(message_name, version=constants.GA_VERSION):
  """Returns the Vertex AI api messages class by name.

  Looks up the message on the versioned module under its version-specific
  prefix; returns None when no such message exists.
  """
  messages_module = GetMessagesModule(version)
  full_name = constants.AI_PLATFORM_MESSAGE_PREFIX[version] + message_name
  return getattr(messages_module, full_name, None)