314 lines
12 KiB
Python
314 lines
12 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2020 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Utilities for dealing with AI Platform indexes API."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
|
|
from apitools.base.py import extra_types
|
|
from apitools.base.py import list_pager
|
|
from googlecloudsdk.api_lib.util import apis
|
|
from googlecloudsdk.api_lib.util import messages as messages_util
|
|
from googlecloudsdk.calliope import exceptions as gcloud_exceptions
|
|
from googlecloudsdk.command_lib.ai import constants
|
|
from googlecloudsdk.command_lib.ai import errors
|
|
from googlecloudsdk.command_lib.util.args import labels_util
|
|
from googlecloudsdk.core import yaml
|
|
|
|
|
|
class IndexesClient(object):
  """High-level client for the AI Platform indexes surface.

  Create, Patch, RemoveDatapoints and UpsertDatapoints each come in two
  flavors: a ``*Beta`` method that builds ``GoogleCloudAiplatformV1beta1*``
  messages and an unsuffixed method that builds ``GoogleCloudAiplatformV1*``
  messages. Get, List and Delete are shared by both.
  """

  def __init__(self, client=None, messages=None, version=None):
    """Initializes the client.

    Args:
      client: Optional pre-built API client. When omitted, one is created with
        ``apis.GetClientInstance`` for the API version that
        ``constants.AI_PLATFORM_API_VERSION`` maps ``version`` to.
      messages: Optional messages module; defaults to the client's
        ``MESSAGES_MODULE``.
      version: Key into ``constants.AI_PLATFORM_API_VERSION``. Only used when
        ``client`` is not supplied.
    """
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version])
    self.messages = messages or self.client.MESSAGES_MODULE
    self._service = self.client.projects_locations_indexes

  # ---------------------------------------------------------------------------
  # Private helpers shared by the v1 and v1beta1 code paths.
  # ---------------------------------------------------------------------------

  def _ReadIndexMetadata(self, metadata_file):
    """Parses an index metadata file into a JsonValue message.

    Args:
      metadata_file: Path to a JSON (or YAML) metadata file.

    Returns:
      An ``extra_types.JsonValue`` built from the file contents, or None when
      the file parses to an empty/falsy value.

    Raises:
      gcloud_exceptions.BadArgumentException: if ``metadata_file`` is empty.
    """
    if not metadata_file:
      raise gcloud_exceptions.BadArgumentException(
          '--metadata-file', 'Index metadata file must be specified.')

    # YAML is a superset of JSON, so a JSON file can be parsed as YAML.
    data = yaml.load_path(metadata_file)
    if not data:
      return None
    return messages_util.DictToMessageWithErrorCheck(
        data, extra_types.JsonValue)

  @staticmethod
  def _ParseIndexUpdateMethod(index_message_cls, index_update_method):
    """Maps an ``--index-update-method`` flag value to the API enum value.

    Args:
      index_message_cls: The versioned Index message class (v1 or v1beta1)
        whose ``IndexUpdateMethodValueValuesEnum`` is used.
      index_update_method: Raw flag value: 'stream-update', 'batch-update' or
        a falsy value when the flag was not given.

    Returns:
      The matching enum value, or None when no value was supplied.

    Raises:
      gcloud_exceptions.BadArgumentException: for any other value.
    """
    if not index_update_method:
      return None
    enum_cls = index_message_cls.IndexUpdateMethodValueValuesEnum
    if index_update_method == 'stream-update':
      return enum_cls.STREAM_UPDATE
    if index_update_method == 'batch-update':
      return enum_cls.BATCH_UPDATE
    raise gcloud_exceptions.BadArgumentException(
        '--index-update-method',
        'Invalid index update method: {}'.format(index_update_method),
    )

  def _PopulateIndexUpdate(self, index, index_ref, args, labels_value_cls):
    """Copies update flags onto ``index`` and returns the update mask.

    A metadata update is sent on its own. When ``args.metadata_file`` is set,
    the display name, description and label flags are not consulted.

    Args:
      index: Empty versioned Index message to populate in place.
      index_ref: Resource reference of the index being updated. It is used
        lazily to fetch current labels when label flags need them.
      args: Parsed command arguments.
      labels_value_cls: The versioned ``Index.LabelsValue`` message class.

    Returns:
      List of field paths that were set on ``index``. The list may be empty.
    """
    update_mask = []

    if args.metadata_file is not None:
      index.metadata = self._ReadIndexMetadata(args.metadata_file)
      update_mask.append('metadata')
      return update_mask

    if args.display_name is not None:
      index.displayName = args.display_name
      update_mask.append('display_name')

    if args.description is not None:
      index.description = args.description
      update_mask.append('description')

    def GetLabels():
      # Only called by labels_util when existing labels must be merged.
      return self.Get(index_ref).labels

    labels_update = labels_util.ProcessUpdateArgsLazy(
        args, labels_value_cls, GetLabels)
    if labels_update.needs_update:
      index.labels = labels_update.labels
      update_mask.append('labels')

    return update_mask

  @staticmethod
  def _DatapointIdsFromArgs(args):
    """Returns the datapoint IDs to remove, from flags or from a file.

    Args:
      args: Parsed command arguments with ``datapoint_ids`` and
        ``datapoints_from_file`` attributes.

    Returns:
      ``args.datapoint_ids`` when set, otherwise the parsed contents of
      ``args.datapoints_from_file``.

    Raises:
      errors.ArgumentError: if both sources are set, or if neither is.
    """
    if args.datapoint_ids and args.datapoints_from_file:
      raise errors.ArgumentError(
          '`--datapoint-ids` and `--datapoints-from-file` can not be set at'
          ' the same time.'
      )
    if args.datapoint_ids:
      return args.datapoint_ids
    if args.datapoints_from_file:
      return yaml.load_path(args.datapoints_from_file)
    # Previously this fell through and failed with UnboundLocalError.
    raise errors.ArgumentError(
        'One of `--datapoint-ids` or `--datapoints-from-file` must be set.'
    )

  @staticmethod
  def _ReadDatapoints(datapoints_file, datapoint_cls):
    """Parses a datapoints file into a list of IndexDatapoint messages.

    Args:
      datapoints_file: Path to a JSON/YAML file holding a list of datapoint
        dicts, or a falsy value.
      datapoint_cls: The versioned ``IndexDatapoint`` message class.

    Returns:
      A list of ``datapoint_cls`` messages. The list is empty when no file
      is given or the file is empty.
    """
    if not datapoints_file:
      return []
    # An empty file loads as a falsy value; treat it as "no datapoints"
    # instead of failing while iterating it.
    data = yaml.load_path(datapoints_file) or []
    return [
        messages_util.DictToMessageWithErrorCheck(datapoint_json, datapoint_cls)
        for datapoint_json in data
    ]

  @staticmethod
  def _JoinUpdateMask(update_mask):
    """Joins a list of field paths into a comma-separated mask, or None."""
    return ','.join(update_mask) if update_mask else None

  # ---------------------------------------------------------------------------
  # Version-agnostic operations.
  # ---------------------------------------------------------------------------

  def Get(self, index_ref):
    """Fetches the index identified by ``index_ref``."""
    request = self.messages.AiplatformProjectsLocationsIndexesGetRequest(
        name=index_ref.RelativeName())
    return self._service.Get(request)

  def List(self, limit=None, region_ref=None):
    """Yields the indexes under ``region_ref``.

    Args:
      limit: Optional maximum number of indexes to yield.
      region_ref: Location resource reference. It is required despite the
        None default, because ``RelativeName()`` is called on it.

    Returns:
      A generator produced by ``list_pager.YieldFromList``.
    """
    return list_pager.YieldFromList(
        self._service,
        self.messages.AiplatformProjectsLocationsIndexesListRequest(
            parent=region_ref.RelativeName()),
        field='indexes',
        batch_size_attribute='pageSize',
        limit=limit)

  def Delete(self, index_ref):
    """Deletes the index identified by ``index_ref``."""
    request = self.messages.AiplatformProjectsLocationsIndexesDeleteRequest(
        name=index_ref.RelativeName())
    return self._service.Delete(request)

  # ---------------------------------------------------------------------------
  # Create.
  # ---------------------------------------------------------------------------

  def CreateBeta(self, location_ref, args):
    """Creates a new v1beta1 index under ``location_ref``.

    Args:
      location_ref: Location resource reference used as the parent.
      args: Parsed command arguments (display_name, description,
        metadata_file, labels, index_update_method, encryption_kms_key_name).

    Returns:
      The response of the Create call.
    """
    index_cls = self.messages.GoogleCloudAiplatformV1beta1Index
    labels = labels_util.ParseCreateArgs(args, index_cls.LabelsValue)
    index_update_method = self._ParseIndexUpdateMethod(
        index_cls, args.index_update_method)

    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = (
          self.messages.GoogleCloudAiplatformV1beta1EncryptionSpec(
              kmsKeyName=args.encryption_kms_key_name))

    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=index_cls(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec))
    return self._service.Create(req)

  def Create(self, location_ref, args):
    """Creates a new v1 index under ``location_ref``.

    Args:
      location_ref: Location resource reference used as the parent.
      args: Parsed command arguments (display_name, description,
        metadata_file, labels, index_update_method, encryption_kms_key_name).

    Returns:
      The response of the Create call.
    """
    index_cls = self.messages.GoogleCloudAiplatformV1Index
    labels = labels_util.ParseCreateArgs(args, index_cls.LabelsValue)
    index_update_method = self._ParseIndexUpdateMethod(
        index_cls, args.index_update_method)

    encryption_spec = None
    if args.encryption_kms_key_name is not None:
      encryption_spec = self.messages.GoogleCloudAiplatformV1EncryptionSpec(
          kmsKeyName=args.encryption_kms_key_name)

    req = self.messages.AiplatformProjectsLocationsIndexesCreateRequest(
        parent=location_ref.RelativeName(),
        googleCloudAiplatformV1Index=index_cls(
            displayName=args.display_name,
            description=args.description,
            metadata=self._ReadIndexMetadata(args.metadata_file),
            labels=labels,
            indexUpdateMethod=index_update_method,
            encryptionSpec=encryption_spec))
    return self._service.Create(req)

  # ---------------------------------------------------------------------------
  # Patch.
  # ---------------------------------------------------------------------------

  def PatchBeta(self, index_ref, args):
    """Updates a v1beta1 index.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable field was requested.
    """
    index = self.messages.GoogleCloudAiplatformV1beta1Index()
    update_mask = self._PopulateIndexUpdate(
        index, index_ref, args,
        self.messages.GoogleCloudAiplatformV1beta1Index.LabelsValue)

    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

  def Patch(self, index_ref, args):
    """Updates a v1 index.

    Raises:
      errors.NoFieldsSpecifiedError: if no updatable field was requested.
    """
    index = self.messages.GoogleCloudAiplatformV1Index()
    update_mask = self._PopulateIndexUpdate(
        index, index_ref, args,
        self.messages.GoogleCloudAiplatformV1Index.LabelsValue)

    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    request = self.messages.AiplatformProjectsLocationsIndexesPatchRequest(
        name=index_ref.RelativeName(),
        googleCloudAiplatformV1Index=index,
        updateMask=','.join(update_mask))
    return self._service.Patch(request)

  # ---------------------------------------------------------------------------
  # Datapoint removal.
  # ---------------------------------------------------------------------------

  def RemoveDatapointsBeta(self, index_ref, args):
    """Removes datapoints from a v1beta1 index.

    Raises:
      errors.ArgumentError: if both or neither of ``--datapoint-ids`` and
        ``--datapoints-from-file`` are set.
    """
    datapoint_ids = self._DatapointIdsFromArgs(args)
    req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1RemoveDatapointsRequest=(
            self.messages.GoogleCloudAiplatformV1beta1RemoveDatapointsRequest(
                datapointIds=datapoint_ids)))
    return self._service.RemoveDatapoints(req)

  def RemoveDatapoints(self, index_ref, args):
    """Removes datapoints from a v1 index.

    Raises:
      errors.ArgumentError: if both or neither of ``--datapoint-ids`` and
        ``--datapoints-from-file`` are set.
    """
    datapoint_ids = self._DatapointIdsFromArgs(args)
    req = self.messages.AiplatformProjectsLocationsIndexesRemoveDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1RemoveDatapointsRequest=(
            self.messages.GoogleCloudAiplatformV1RemoveDatapointsRequest(
                datapointIds=datapoint_ids)))
    return self._service.RemoveDatapoints(req)

  # ---------------------------------------------------------------------------
  # Datapoint upsert.
  # ---------------------------------------------------------------------------

  def UpsertDatapointsBeta(self, index_ref, args):
    """Upserts datapoints into a v1beta1 index.

    Datapoints come from ``args.datapoints_from_file``; an absent or empty
    file yields an empty list. ``args.update_mask`` (a list of field paths)
    is joined with commas when given.
    """
    datapoints = self._ReadDatapoints(
        args.datapoints_from_file,
        self.messages.GoogleCloudAiplatformV1beta1IndexDatapoint)
    req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1beta1UpsertDatapointsRequest=(
            self.messages.GoogleCloudAiplatformV1beta1UpsertDatapointsRequest(
                datapoints=datapoints,
                updateMask=self._JoinUpdateMask(args.update_mask))))
    return self._service.UpsertDatapoints(req)

  def UpsertDatapoints(self, index_ref, args):
    """Upserts datapoints into a v1 index.

    Datapoints come from ``args.datapoints_from_file``; an absent or empty
    file yields an empty list. ``args.update_mask`` (a list of field paths)
    is joined with commas when given.
    """
    datapoints = self._ReadDatapoints(
        args.datapoints_from_file,
        self.messages.GoogleCloudAiplatformV1IndexDatapoint)
    req = self.messages.AiplatformProjectsLocationsIndexesUpsertDatapointsRequest(
        index=index_ref.RelativeName(),
        googleCloudAiplatformV1UpsertDatapointsRequest=(
            self.messages.GoogleCloudAiplatformV1UpsertDatapointsRequest(
                datapoints=datapoints,
                updateMask=self._JoinUpdateMask(args.update_mask))))
    return self._service.UpsertDatapoints(req)
|