feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def Await(operation, message):
  """Wait for the specified backup operation to complete.

  Args:
    operation: The Operation message to poll.
    message: String, the progress message shown while polling.

  Returns:
    The response embedded in the completed Operation.
  """
  op_ref = resources.REGISTRY.ParseRelativeName(
      operation.name,
      collection='spanner.projects.instances.backups.operations')
  api_client = apis.GetClientInstance('spanner', 'v1')
  response_poller = EmbeddedResponsePoller(
      api_client.projects_instances_backups_operations)
  return waiter.WaitFor(response_poller, op_ref, message)
def Cancel(instance, backup, operation):
  """Cancel the specified backup operation.

  Args:
    instance: String, the instance ID.
    backup: String, the backup ID.
    operation: String, the operation ID.

  Returns:
    The response from the cancel call.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'backupsId': backup,
      },
      collection='spanner.projects.instances.backups.operations')
  cancel_req = messages.SpannerProjectsInstancesBackupsOperationsCancelRequest(
      name=op_ref.RelativeName())
  return api_client.projects_instances_backups_operations.Cancel(cancel_req)
def Get(instance, backup, operation):
  """Get the specified backup operation.

  Args:
    instance: String, the instance ID.
    backup: String, the backup ID.
    operation: String, the operation ID.

  Returns:
    The Operation message.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'backupsId': backup,
      },
      collection='spanner.projects.instances.backups.operations')
  get_req = messages.SpannerProjectsInstancesBackupsOperationsGetRequest(
      name=op_ref.RelativeName())
  return api_client.projects_instances_backups_operations.Get(get_req)
def BuildDatabaseFilter(instance, database):
  """Build a filter expression matching backup operations for a database.

  Args:
    instance: String, the instance ID.
    database: String, the database ID.

  Returns:
    String, a server-side filter on metadata.database.
  """
  database_ref = resources.REGISTRY.Parse(
      database,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
      },
      collection='spanner.projects.instances.databases')
  return 'metadata.database:"{}"'.format(database_ref.RelativeName())
def List(instance, op_filter=None):
  """List operations on backups in the instance.

  Args:
    instance: String, the instance ID.
    op_filter: String, an optional server-side filter expression.

  Returns:
    A generator of Operation messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances')
  list_req = messages.SpannerProjectsInstancesBackupOperationsListRequest(
      parent=instance_ref.RelativeName(), filter=op_filter)
  return list_pager.YieldFromList(
      api_client.projects_instances_backupOperations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize')
def ListGeneric(instance, backup):
  """List operations on the backup with the generic LRO API.

  Args:
    instance: String, the instance ID.
    backup: String, the backup ID.

  Returns:
    A generator of Operation messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances')
  operations_name = '{}/backups/{}/operations'.format(
      instance_ref.RelativeName(), backup)
  list_req = messages.SpannerProjectsInstancesBackupsOperationsListRequest(
      name=operations_name)
  return list_pager.YieldFromList(
      api_client.projects_instances_backups_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize')
class EmbeddedResponsePoller(waiter.CloudOperationPoller):
  """As CloudOperationPoller for polling, but uses the Operation.response."""

  def __init__(self, operation_service):
    # Intentionally does not call super().__init__: no result_service is
    # needed because GetResult returns the response embedded in the
    # Operation instead of re-fetching the target resource.
    self.operation_service = operation_service

  def GetResult(self, operation):
    # The finished Operation already carries its result payload.
    return operation.response

View File

@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Spanner backup schedules API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.spanner.resource_args import CloudKmsKeyName
from googlecloudsdk.core.util import times
def ParseAndFormatRetentionDuration(retention_duration):
  """Parse a duration flag value and return it as a JSON duration string."""
  return times.FormatDurationForJson(times.ParseDuration(retention_duration))
def CreateBackupScheduleMessage(
    backup_schedule_ref,
    args,
    msgs,
    encryption_type=None,
    kms_key: CloudKmsKeyName = None,
):
  """Assemble a BackupSchedule message from command arguments.

  Args:
    backup_schedule_ref: resource argument for a cloud spanner backup schedule.
    args: an argparse namespace. All the arguments that were provided to
      command invocation.
    msgs: contains the definitions of messages for the spanner v1 API.
    encryption_type: encryption type for the backup encryption.
    kms_key: contains the encryption keys for the backup encryption.

  Returns:
    BackupSchedule message.
  """
  schedule = msgs.BackupSchedule(name=backup_schedule_ref.RelativeName())
  if args.retention_duration:
    schedule.retentionDuration = ParseAndFormatRetentionDuration(
        args.retention_duration)
  if encryption_type or kms_key:
    encryption = msgs.CreateBackupEncryptionConfig()
    if encryption_type:
      encryption.encryptionType = encryption_type
    if kms_key:
      # A single key and a key list are mutually exclusive; prefer the single
      # key when both are present, matching flag semantics.
      if kms_key.kms_key_name:
        encryption.kmsKeyName = kms_key.kms_key_name
      elif kms_key.kms_key_names:
        encryption.kmsKeyNames = kms_key.kms_key_names
    schedule.encryptionConfig = encryption
  if args.cron:
    schedule.spec = msgs.BackupScheduleSpec(
        cronSpec=msgs.CrontabSpec(text=args.cron))
  return schedule
def Get(backup_schedule_ref):
  """Fetch a single backup schedule by resource reference."""
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  get_req = messages.SpannerProjectsInstancesDatabasesBackupSchedulesGetRequest(
      name=backup_schedule_ref.RelativeName())
  return api_client.projects_instances_databases_backupSchedules.Get(get_req)
def Delete(backup_schedule_ref):
  """Delete a backup schedule by resource reference."""
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  delete_req = (
      messages.SpannerProjectsInstancesDatabasesBackupSchedulesDeleteRequest(
          name=backup_schedule_ref.RelativeName()))
  return api_client.projects_instances_databases_backupSchedules.Delete(
      delete_req)
def Create(
    backup_schedule_ref,
    args,
    encryption_type,
    kms_key,
):
  """Create a new backup schedule.

  Args:
    backup_schedule_ref: resource reference for the new backup schedule.
    args: argparse namespace with the command's flag values.
    encryption_type: encryption type for backups made by the schedule.
    kms_key: CloudKmsKeyName with encryption keys for backups made by the
      schedule.

  Returns:
    The created BackupSchedule resource.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  create_req = (
      messages.SpannerProjectsInstancesDatabasesBackupSchedulesCreateRequest(
          parent=backup_schedule_ref.Parent().RelativeName()))
  create_req.backupSchedule = CreateBackupScheduleMessage(
      backup_schedule_ref, args, messages, encryption_type, kms_key)
  backup_type = args.backup_type[0]
  if backup_type == 'full-backup':
    create_req.backupSchedule.fullBackupSpec = messages.FullBackupSpec()
  elif backup_type == 'incremental-backup':
    create_req.backupSchedule.incrementalBackupSpec = (
        messages.IncrementalBackupSpec())
  create_req.backupScheduleId = backup_schedule_ref.Name()
  return api_client.projects_instances_databases_backupSchedules.Create(
      create_req)
def List(database_ref):
  """List backup schedules in the database.

  Args:
    database_ref: resource reference to the database.

  Returns:
    A generator of BackupSchedule messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  list_req = (
      messages.SpannerProjectsInstancesDatabasesBackupSchedulesListRequest(
          parent=database_ref.RelativeName()))
  return list_pager.YieldFromList(
      api_client.projects_instances_databases_backupSchedules,
      list_req,
      field='backupSchedules',
      batch_size_attribute='pageSize')
def Update(
    backup_schedule_ref,
    args,
    encryption_type,
    kms_key,
):
  """Update a backup schedule.

  Args:
    backup_schedule_ref: resource reference to the backup schedule.
    args: argparse namespace with the command's flag values.
    encryption_type: encryption type for backups made by the schedule.
    kms_key: CloudKmsKeyName with encryption keys for backups made by the
      schedule.

  Returns:
    The patched BackupSchedule resource.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  patch_req = (
      messages.SpannerProjectsInstancesDatabasesBackupSchedulesPatchRequest(
          name=backup_schedule_ref.RelativeName()))
  patch_req.backupSchedule = CreateBackupScheduleMessage(
      backup_schedule_ref, args, messages, encryption_type, kms_key)
  # Patch only the fields the user actually set.
  mask_paths = []
  if args.cron:
    mask_paths.append('spec.cron_spec.text')
  if args.retention_duration:
    mask_paths.append('retention_duration')
  if encryption_type or kms_key:
    mask_paths.append('encryption_config')
  patch_req.updateMask = ','.join(mask_paths)
  return api_client.projects_instances_databases_backupSchedules.Patch(
      patch_req)

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*- #
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Spanner backups API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import exceptions as c_exceptions
from googlecloudsdk.command_lib.spanner.resource_args import CloudKmsKeyName
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core.credentials import requests
from googlecloudsdk.core.util import times
from six.moves import http_client as httplib
from six.moves import urllib
class HttpRequestFailedError(core_exceptions.Error):
  """Indicates that the http request failed in some way."""
# General Utils
def ParseExpireTime(expiration_value):
  """Convert a relative duration flag value into an absolute expire time.

  Args:
    expiration_value: String, a duration (e.g. '2w') relative to now.

  Returns:
    String, the absolute expire time formatted for the API.
  """
  absolute_time = times.ParseDuration(expiration_value).GetRelativeDateTime(
      times.Now(times.UTC))
  return times.FormatDateTime(
      absolute_time, '%Y-%m-%dT%H:%M:%S.%6f%Ez', tzinfo=times.UTC)
def CheckAndGetExpireTime(args):
  """Check if fields for expireTime are correctly specified and parse value."""
  has_date = args.IsSpecified('expiration_date')
  has_period = args.IsSpecified('retention_period')
  # Exactly one of the two flags must be provided: reject both-or-neither.
  if has_date == has_period:
    raise c_exceptions.InvalidArgumentException(
        '--expiration-date or --retention-period',
        'Must specify either --expiration-date or --retention-period.')
  if args.expiration_date:
    return args.expiration_date
  elif args.retention_period:
    return ParseExpireTime(args.retention_period)
def GetBackup(backup_ref):
  """Fetch a single backup by resource reference."""
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  get_req = messages.SpannerProjectsInstancesBackupsGetRequest(
      name=backup_ref.RelativeName())
  return api_client.projects_instances_backups.Get(get_req)
def CreateBackup(
    backup_ref, args, encryption_type=None, kms_key: CloudKmsKeyName = None
):
  """Create a new backup via a hand-built HTTP POST.

  Args:
    backup_ref: resource reference for the new backup.
    args: argparse namespace; supplies `backup`, `database`, the expiration
      flags, and optionally `version_time`.
    encryption_type: optional encryption type for the backup.
    kms_key: optional CloudKmsKeyName holding the encryption key(s).

  Returns:
    The long-running Operation message for the backup creation.

  Raises:
    HttpRequestFailedError: if the server returns a non-200 status.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  # Encryption settings travel as query parameters, not in the body.
  query_params = {'alt': 'json', 'backupId': args.backup}
  if encryption_type:
    query_params['encryptionConfig.encryptionType'] = encryption_type
  if kms_key:
    if kms_key.kms_key_name:
      query_params['encryptionConfig.kmsKeyName'] = kms_key.kms_key_name
    elif kms_key.kms_key_names:
      query_params['encryptionConfig.kmsKeyNames'] = kms_key.kms_key_names
  parent = backup_ref.Parent().RelativeName()
  url = '{}v1/{}/backups?{}'.format(
      client.url, parent, urllib.parse.urlencode(query_params, doseq=True)
  )
  backup = msgs.Backup(
      database=parent + '/databases/' + args.database,
      expireTime=CheckAndGetExpireTime(args))
  if args.IsSpecified('version_time'):
    backup.versionTime = args.version_time
  # We are not using `SpannerProjectsInstancesBackupsCreateRequest` from
  # `spanner_v1_messages.py` because `apitools` does not generate nested proto
  # messages correctly, b/31244944. Here, an `EncryptionConfig` should be a
  # nested proto, rather than `EncryptionConfig_kmsKeyName` being a
  # field(http://shortn/_gHieB9ir83). Thus, this workaround is necessary and
  # will be here to stay since `apitools` is not under active development and
  # gcloud will continue to support `apitools` http://shortn/_BJJCZbnCFp.
  # Make an http request directly instead of using the apitools client which
  # does not support '.' characters in query parameters (b/31244944).
  http_client = requests.GetSession()
  # Workaround since gcloud cannot handle HttpBody properly (b/31403673).
  http_client.encoding = 'utf-8'
  response = http_client.request(
      'POST', url, data=client.SerializeMessage(backup)
  )
  if int(response.status_code) != httplib.OK:
    raise HttpRequestFailedError(
        'HTTP request failed. Response: ' + response.text
    )
  message_type = getattr(msgs, 'Operation')
  return client.DeserializeMessage(message_type, response.content)
def CopyBackup(source_backup_ref,
               destination_backup_ref,
               args,
               encryption_type=None,
               kms_key=None):
  """Copy a backup.

  Args:
    source_backup_ref: resource reference to the backup being copied.
    destination_backup_ref: resource reference for the new copy.
    args: argparse namespace; supplies the expiration flags.
    encryption_type: optional encryption type for the destination backup.
    kms_key: optional CloudKmsKeyName with the encryption key(s).

  Returns:
    The long-running Operation for the copy.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  copy_request = messages.CopyBackupRequest(
      backupId=destination_backup_ref.Name(),
      sourceBackup=source_backup_ref.RelativeName())
  copy_request.expireTime = CheckAndGetExpireTime(args)
  if kms_key:
    copy_request.encryptionConfig = messages.CopyBackupEncryptionConfig(
        encryptionType=encryption_type,
        kmsKeyName=kms_key.kms_key_name,
        kmsKeyNames=kms_key.kms_key_names)
  elif encryption_type:
    copy_request.encryptionConfig = messages.CopyBackupEncryptionConfig(
        encryptionType=encryption_type)
  return api_client.projects_instances_backups.Copy(
      messages.SpannerProjectsInstancesBackupsCopyRequest(
          parent=destination_backup_ref.Parent().RelativeName(),
          copyBackupRequest=copy_request))
def ModifyUpdateMetadataRequest(backup_ref, args, req):
  """Fill in the update-backup request from the parsed arguments."""
  req.backup.name = '{}/backups/{}'.format(
      backup_ref.Parent().RelativeName(), args.backup)
  req.backup.expireTime = CheckAndGetExpireTime(args)
  # Only the expiration may be changed through this request.
  req.updateMask = 'expire_time'
  return req
def ModifyListRequest(instance_ref, args, req):
  """Fill in the list-backups request, filtering by database when given."""
  req.parent = instance_ref.RelativeName()
  if args.database:
    database_name = '{}/databases/{}'.format(req.parent, args.database)
    req.filter = 'database="{}"'.format(database_name)
  return req
def CheckBackupExists(backup_ref, _, req):
  """Verify the backup exists, then return the request unchanged.

  The delete API returns 200 whether or not the backup exists. To give users
  feedback on a mistyped backup name we issue a Get first; if the backup is
  missing, the error from backups.Get is surfaced instead.
  """
  GetBackup(backup_ref)
  return req
def FormatListBackups(backup_refs, _):
  """Apply display formatting to each backup in a list response.

  Args:
    backup_refs: A list of backups.

  Returns:
    The list of backups with the new formatting.
  """
  return list(map(_FormatBackup, backup_refs))
def _FormatBackup(backup_ref):
  """Reformat a single backup, in place, for the list response.

  Schedule URIs are shortened to bare schedule names and instance-partition
  URIs to instance-partition names.

  Args:
    backup_ref: The backup to format (modified in place).

  Returns:
    The same backup with the new formatting.
  """
  backup_ref.backupSchedules = [
      _ExtractScheduleNameFromScheduleUri(uri)
      for uri in backup_ref.backupSchedules
  ]
  backup_ref.instancePartitions = [
      _ExtractInstancePartitionNameFromInstancePartitionUri(
          partition.instancePartition)
      for partition in backup_ref.instancePartitions
  ]
  return backup_ref
def _ExtractScheduleNameFromScheduleUri(schedule_uri):
"""Converts a schedule URI to an schedule name.
Args:
schedule_uri: The URI of the schedule, e.g.,
"projects/test-project/instances/test-instance/databases/test-database/backupSchedules/test-backup-schedule".
Returns:
The name of the schedule ("test-backup-schedule" in the example above).
"""
return schedule_uri.split('/')[-1]
def _ExtractInstancePartitionNameFromInstancePartitionUri(
instance_partition_uri,
):
"""Converts an instance partition URI to an instance partition name.
Args:
instance_partition_uri: The URI of an instance partition, e.g.,
"projects/test-project/instances/test-instance/instancePartitions/test-instance-partition".
Returns:
The name of the instance partition ("test-instance-partition" in the
example above).
"""
return {'instancePartition': instance_partition_uri.split('/')[-1]}

View File

@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def Await(operation, message):
  """Wait for the specified database operation to complete.

  Args:
    operation: The Operation message to poll.
    message: String, the progress message shown while polling.

  Returns:
    The response embedded in the completed Operation.
  """
  op_ref = resources.REGISTRY.ParseRelativeName(
      operation.name,
      collection='spanner.projects.instances.databases.operations')
  api_client = apis.GetClientInstance('spanner', 'v1')
  response_poller = EmbeddedResponsePoller(
      api_client.projects_instances_databases_operations)
  return waiter.WaitFor(response_poller, op_ref, message)
def Cancel(instance, database, operation):
  """Cancel the specified database operation.

  Args:
    instance: String, the instance ID.
    database: String, the database ID.
    operation: String, the operation ID.

  Returns:
    The response from the cancel call.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'databasesId': database,
      },
      collection='spanner.projects.instances.databases.operations')
  cancel_req = messages.SpannerProjectsInstancesDatabasesOperationsCancelRequest(
      name=op_ref.RelativeName())
  return api_client.projects_instances_databases_operations.Cancel(cancel_req)
def Get(instance, database, operation):
  """Get the specified database operation.

  Args:
    instance: String, the instance ID.
    database: String, the database ID.
    operation: String, the operation ID.

  Returns:
    The Operation message.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'databasesId': database,
      },
      collection='spanner.projects.instances.databases.operations')
  get_req = messages.SpannerProjectsInstancesDatabasesOperationsGetRequest(
      name=op_ref.RelativeName())
  return api_client.projects_instances_databases_operations.Get(get_req)
def List(instance, database, type_filter=None):
  """List operations on the database using the generic operation list API.

  Args:
    instance: String, the instance ID.
    database: String, the database ID.
    type_filter: String, optional server-side filter expression.

  Returns:
    A generator of Operation messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  database_ref = resources.REGISTRY.Parse(
      database,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
      },
      collection='spanner.projects.instances.databases')
  list_req = messages.SpannerProjectsInstancesDatabasesOperationsListRequest(
      name='{}/operations'.format(database_ref.RelativeName()),
      filter=type_filter)
  return list_pager.YieldFromList(
      api_client.projects_instances_databases_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize')
def BuildDatabaseOperationTypeFilter(op_type):
  """Builds the filter for the different database operation metadata types.

  Args:
    op_type: String, the database operation type.

  Returns:
    String filter expression; '' for the plain DATABASE type; None when the
    type is unrecognized.
  """
  if op_type == 'DATABASE':
    return ''
  prefix = ('metadata.@type:type.googleapis.com/'
            'google.spanner.admin.database.v1.')
  if op_type == 'DATABASE_RESTORE':
    # A restore shows up as two metadata types: the restore itself and the
    # post-restore optimization.
    return ('({0}OptimizeRestoredDatabaseMetadata) OR '
            '({0}RestoreDatabaseMetadata)').format(prefix)
  metadata_by_type = {
      'DATABASE_CREATE': 'CreateDatabaseMetadata',
      'DATABASE_UPDATE_DDL': 'UpdateDatabaseDdlMetadata',
      'DATABASE_CHANGE_QUORUM': 'ChangeQuorumMetadata',
  }
  if op_type in metadata_by_type:
    return prefix + metadata_by_type[op_type]
def ListDatabaseOperations(instance, database=None, type_filter=None):
  """List database operations using the Cloud Spanner specific API.

  Args:
    instance: String, the instance ID.
    database: String, optional database ID. When given, the generic list
      API is used instead, so no operations are shown from previous
      incarnations of the database.
    type_filter: String, optional server-side filter expression.

  Returns:
    A generator of Operation messages.
  """
  # Delegate before building a client or parsing the instance: the generic
  # list path does its own resource resolution, so doing it here would be
  # wasted work on that branch.
  if database:
    return List(instance, database, type_filter)
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
      },
      collection='spanner.projects.instances')
  req = msgs.SpannerProjectsInstancesDatabaseOperationsListRequest(
      parent=instance_ref.RelativeName(), filter=type_filter)
  return list_pager.YieldFromList(
      client.projects_instances_databaseOperations,
      req,
      field='operations',
      batch_size_attribute='pageSize')
class EmbeddedResponsePoller(waiter.CloudOperationPoller):
  """As CloudOperationPoller for polling, but uses the Operation.response."""

  def __init__(self, operation_service):
    # Intentionally does not call super().__init__: no result_service is
    # needed because GetResult returns the response embedded in the
    # Operation instead of re-fetching the target resource.
    self.operation_service = operation_service

  def GetResult(self, operation):
    # The finished Operation already carries its result payload.
    return operation.response

View File

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database roles API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
def List(database_ref):
  """List the database roles defined on the given database.

  Args:
    database_ref: resource reference to the database.

  Returns:
    A generator of DatabaseRole messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  list_req = (
      messages.SpannerProjectsInstancesDatabasesDatabaseRolesListRequest(
          parent=database_ref.RelativeName()))
  return list_pager.YieldFromList(
      api_client.projects_instances_databases_databaseRoles,
      list_req,
      field='databaseRoles',
      batch_size_attribute='pageSize')

View File

@@ -0,0 +1,338 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database sessions API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import extra_types
from apitools.base.py import http_wrapper
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.spanner.sql import QueryHasDml
def CheckResponse(response):
  """Wrap http_wrapper.CheckResponse to skip retry on 501."""
  # Everything except 501 goes through the standard retry policy.
  if response.status_code != 501:
    return http_wrapper.CheckResponse(response)
  # 501 Not Implemented will never succeed on retry; fail immediately.
  raise apitools_exceptions.HttpError.FromResponse(response)
def Create(database_ref, creator_role=None):
  """Create a database session.

  Args:
    database_ref: String, The database in which the new session is created.
    creator_role: String, The database role which created this session.

  Returns:
    Newly created session.
  """
  api_client = _GetClientInstance('spanner', 'v1', None)
  messages = apis.GetMessagesModule('spanner', 'v1')
  create_req = messages.SpannerProjectsInstancesDatabasesSessionsCreateRequest(
      database=database_ref.RelativeName())
  if creator_role is not None:
    create_req.createSessionRequest = messages.CreateSessionRequest(
        session=messages.Session(creatorRole=creator_role))
  return api_client.projects_instances_databases_sessions.Create(create_req)
def List(database_ref, server_filter=None):
  """Lists all active sessions on the given database.

  Args:
    database_ref: resource reference to the database.
    server_filter: String, optional server-side filter expression.

  Returns:
    A generator of Session messages.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  list_req = messages.SpannerProjectsInstancesDatabasesSessionsListRequest(
      database=database_ref.RelativeName(), filter=server_filter)
  # batch_size_attribute is deliberately None ('pageSize' exists) so each
  # request yields as many results as the server will return.
  return list_pager.YieldFromList(
      api_client.projects_instances_databases_sessions,
      list_req,
      batch_size_attribute=None,
      field='sessions')
def Delete(session_ref):
  """Delete the given database session."""
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  delete_req = messages.SpannerProjectsInstancesDatabasesSessionsDeleteRequest(
      name=session_ref.RelativeName())
  return api_client.projects_instances_databases_sessions.Delete(delete_req)
def _GetClientInstance(api_name, api_version, http_timeout_sec=None):
  """Return an API client whose retry logic does not retry 501 responses."""
  api_client = apis.GetClientInstance(
      api_name, api_version, http_timeout_sec=http_timeout_sec)
  # Install the custom checker so 501s raise immediately instead of retrying.
  api_client.check_response_func = CheckResponse
  return api_client
def ExecuteSql(sql, query_mode, session_ref, read_only_options=None,
               request_options=None, enable_partitioned_dml=False,
               http_timeout_sec=None):
  """Execute an SQL command.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Reference to the session through which the SQL is executed.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.
    http_timeout_sec: int, Maximum time in seconds to wait for the SQL query
      to complete.

  Returns:
    The ResultSet response from the ExecuteSql call.
  """
  client = _GetClientInstance('spanner', 'v1', http_timeout_sec)
  msgs = apis.GetMessagesModule('spanner', 'v1')
  _RegisterCustomMessageCodec(msgs)
  execute_sql_request = _GetQueryRequest(
      sql,
      query_mode,
      session_ref,
      read_only_options,
      request_options,
      enable_partitioned_dml,
  )
  req = msgs.SpannerProjectsInstancesDatabasesSessionsExecuteSqlRequest(
      session=session_ref.RelativeName(), executeSqlRequest=execute_sql_request)
  resp = client.projects_instances_databases_sessions.ExecuteSql(req)
  # Non-partitioned DML runs in a read-write transaction that ExecuteSql
  # only begins; commit it here (with no extra mutations) so the statement
  # actually takes effect.
  if QueryHasDml(sql) and enable_partitioned_dml is False:
    result_set = msgs.ResultSet(metadata=resp.metadata)
    Commit(session_ref, [], result_set.metadata.transaction.id)
  return resp
def _RegisterCustomMessageCodec(msgs):
  """Register custom message codec.

  Installs JSON encode/decode hooks so ResultSet row entries round-trip
  through apitools' extra_types JSON handling.

  Args:
    msgs: Spanner v1 messages.
  """
  # TODO(b/33482229): remove this workaround
  def _ToJson(msg):
    # Encode a row entry (list of JSON values) as a JSON array.
    return extra_types.JsonProtoEncoder(
        extra_types.JsonArray(entries=msg.entry))
  def _FromJson(data):
    # Decode a JSON array back into a row entry list.
    return msgs.ResultSet.RowsValueListEntry(
        entry=extra_types.JsonProtoDecoder(data).entries)
  encoding.RegisterCustomMessageCodec(
      encoder=_ToJson, decoder=_FromJson)(
          msgs.ResultSet.RowsValueListEntry)
def _GetQueryRequest(sql,
                     query_mode,
                     session_ref=None,
                     read_only_options=None,
                     request_options=None,
                     enable_partitioned_dml=False):
  """Build the ExecuteSqlRequest, choosing a transaction by statement kind.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Reference to the session.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.

  Returns:
    ExecuteSqlRequest parameters
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  if enable_partitioned_dml is True:
    selector = _GetPartitionedDmlTransaction(session_ref)
  elif QueryHasDml(sql):
    # DML begins a read-write transaction; the caller commits it afterwards.
    selector = msgs.TransactionSelector(
        begin=msgs.TransactionOptions(readWrite=msgs.ReadWrite()))
  else:
    # Plain queries use a single-use read-only transaction.
    selector = msgs.TransactionSelector(
        singleUse=msgs.TransactionOptions(readOnly=read_only_options))
  return msgs.ExecuteSqlRequest(
      sql=sql,
      requestOptions=request_options,
      queryMode=msgs.ExecuteSqlRequest.QueryModeValueValuesEnum(query_mode),
      transaction=selector)
def _GetPartitionedDmlTransaction(session_ref):
  """Begin a Partitioned DML transaction and return its selector.

  Args:
    session_ref: Reference to the session.

  Returns:
    TransactionSelector with the id property.
  """
  api_client = apis.GetClientInstance('spanner', 'v1')
  messages = apis.GetMessagesModule('spanner', 'v1')
  begin_req = (
      messages.SpannerProjectsInstancesDatabasesSessionsBeginTransactionRequest(
          beginTransactionRequest=messages.BeginTransactionRequest(
              options=messages.TransactionOptions(
                  partitionedDml=messages.PartitionedDml())),
          session=session_ref.RelativeName()))
  resp = api_client.projects_instances_databases_sessions.BeginTransaction(
      begin_req)
  return messages.TransactionSelector(id=resp.id)
def Commit(session_ref, mutations, transaction_id=None):
  """Commits mutations to Cloud Spanner through a session.

  In Cloud Spanner, each session can have at most one active transaction at a
  time. To avoid accidentally retrying aborted transactions, the mutations are
  applied through a temporary single-use transaction unless an explicit
  transaction id is supplied.

  Note: this commit is non-idempotent.

  Args:
    session_ref: Session, through which the transaction would be committed.
    mutations: A list of mutations, each representing a modification to one or
      more Cloud Spanner rows.
    transaction_id: An optional string for the transaction id.

  Returns:
    The Cloud Spanner timestamp at which the transaction committed.
  """
  messages = apis.GetMessagesModule('spanner', 'v1')
  if transaction_id is None:
    # No caller-managed transaction: commit atomically via single-use txn.
    commit_request = messages.CommitRequest(
        mutations=mutations,
        singleUseTransaction=messages.TransactionOptions(
            readWrite=messages.ReadWrite()))
  else:
    commit_request = messages.CommitRequest(
        mutations=mutations, transactionId=transaction_id)
  request = messages.SpannerProjectsInstancesDatabasesSessionsCommitRequest(
      session=session_ref.RelativeName(), commitRequest=commit_request)
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases_sessions.Commit(request)
class MutationFactory(object):
  """Factory that creates and returns a mutation object in Cloud Spanner.

  A Mutation represents a sequence of inserts, updates and deletes that can be
  applied to rows and tables in a Cloud Spanner database.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')

  @classmethod
  def Insert(cls, table, data):
    """Constructs an INSERT mutation, which inserts a new row in a table.

    Args:
      table: String, the name of the table.
      data: A collections.OrderedDict, the keys of which are the column names
        and values are the column values to be inserted.

    Returns:
      An insert mutation operation.
    """
    return cls.msgs.Mutation(insert=cls._GetWrite(table, data))

  @classmethod
  def Update(cls, table, data):
    """Constructs an UPDATE mutation, which updates a row in a table.

    Args:
      table: String, the name of the table.
      data: An ordered dictionary where the keys are the column names and
        values are the column values to be updated.

    Returns:
      An update mutation operation.
    """
    return cls.msgs.Mutation(update=cls._GetWrite(table, data))

  @classmethod
  def Delete(cls, table, keys):
    """Constructs a DELETE mutation, which deletes a row in a table.

    Args:
      table: String, the name of the table.
      keys: String list, the primary key values of the row to delete.

    Returns:
      A delete mutation operation.
    """
    return cls.msgs.Mutation(delete=cls._GetDelete(table, keys))

  @classmethod
  def _RegisterJsonCodec(cls, message_type):
    """Registers a JSON-array encoder on the given list-entry message type.

    Shared by _GetWrite and _GetDelete, which previously duplicated this
    registration boilerplate.
    """
    # TODO(b/33482229): a workaround to handle JSON serialization
    def _ToJson(msg):
      return extra_types.JsonProtoEncoder(
          extra_types.JsonArray(entries=msg.entry))
    encoding.RegisterCustomMessageCodec(
        encoder=_ToJson, decoder=None)(message_type)

  @classmethod
  def _GetWrite(cls, table, data):
    """Constructs Write object, which is needed for insert/update operations."""
    cls._RegisterJsonCodec(cls.msgs.Write.ValuesValueListEntry)
    json_columns = table.GetJsonData(data)
    json_column_names = [col.col_name for col in json_columns]
    json_column_values = [col.col_value for col in json_columns]
    return cls.msgs.Write(
        columns=json_column_names,
        table=table.name,
        values=[cls.msgs.Write.ValuesValueListEntry(entry=json_column_values)])

  @classmethod
  def _GetDelete(cls, table, keys):
    """Constructs Delete object, which is needed for delete operation."""
    cls._RegisterJsonCodec(cls.msgs.KeySet.KeysValueListEntry)
    key_set = cls.msgs.KeySet(keys=[
        cls.msgs.KeySet.KeysValueListEntry(entry=table.GetJsonKeys(keys))
    ])
    return cls.msgs.Delete(table=table.name, keySet=key_set)

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database splits helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import encoding
from googlecloudsdk.api_lib.spanner import database_sessions
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import resources
def AddSplitPoints(database_ref, split_points, initiator_string):
  """Adds user-supplied split points to a database.

  Args:
    database_ref: Resource reference of the target database.
    split_points: The split points to add.
    initiator_string: Optional initiator tag recorded with the split points.

  Returns:
    The AddSplitPoints API response.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  add_req = msgs.AddSplitPointsRequest(splitPoints=split_points)
  if initiator_string:
    add_req.initiator = initiator_string
  request = msgs.SpannerProjectsInstancesDatabasesAddSplitPointsRequest(
      database=database_ref.RelativeName(),
      addSplitPointsRequest=add_req,
  )
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.AddSplitPoints(request)
def ListSplitPoints(database_ref):
  """Lists the user-added split points of a database.

  Creates a temporary session, queries SPANNER_SYS.USER_SPLIT_POINTS on it,
  and always deletes the session afterwards.

  Args:
    database_ref: Resource reference of the database.

  Returns:
    A list of dicts, one per split point.
  """
  # TODO(b/362149997): Check this for both dialects.
  session_name = database_sessions.Create(database_ref, None)
  session = resources.REGISTRY.ParseRelativeName(
      relative_name=session_name.name,
      collection='spanner.projects.instances.databases.sessions',
  )
  query = (
      'SELECT TABLE_NAME, INDEX_NAME, INITIATOR, SPLIT_KEY, EXPIRE_TIME'
      ' FROM SPANNER_SYS.USER_SPLIT_POINTS'
  )
  try:
    result = database_sessions.ExecuteSql(query, 'NORMAL', session)
    return _TransformToSplitResult(result)
  finally:
    database_sessions.Delete(session)
def _TransformToSplitResult(result):
  """Converts the raw query result rows into split-point dictionaries."""
  split_points = []
  for row in result.rows:
    entries = row.entry
    split_points.append({
        'TABLE_NAME': encoding.MessageToPyValue(entries[0]),
        'INDEX_NAME': encoding.MessageToPyValue(entries[1]),
        'INITIATOR': encoding.MessageToPyValue(entries[2]),
        'SPLIT_KEY': encoding.MessageToPyValue(entries[3]),
        'EXPIRE_TIME': encoding.MessageToPyValue(entries[4]),
    })
  return split_points

View File

@@ -0,0 +1,248 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from cloudsdk.google.protobuf import descriptor_pb2
from cloudsdk.google.protobuf import text_format
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.command_lib.spanner.resource_args import CloudKmsKeyName
# The list of pre-defined IAM roles in Spanner.
KNOWN_ROLES = [
'roles/spanner.admin', 'roles/spanner.databaseAdmin',
'roles/spanner.databaseReader', 'roles/spanner.databaseUser',
'roles/spanner.viewer'
]
# The available options of the SQL dialect for a Cloud Spanner database.
DATABASE_DIALECT_GOOGLESQL = 'GOOGLE_STANDARD_SQL'
DATABASE_DIALECT_POSTGRESQL = 'POSTGRESQL'
def Create(
    instance_ref,
    database,
    ddl,
    proto_descriptors=None,
    kms_key: CloudKmsKeyName = None,
    database_dialect=None,
):
  """Creates a new Cloud Spanner database.

  Args:
    instance_ref: Resource reference of the owning instance.
    database: String, the new database's ID.
    ddl: List of extra DDL statements applied after creation.
    proto_descriptors: Optional serialized FileDescriptorSet bytes.
    kms_key: Optional CloudKmsKeyName for CMEK encryption.
    database_dialect: Optional dialect string (GoogleSQL or PostgreSQL).

  Returns:
    The long-running create operation.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  dialect_enum = msgs.CreateDatabaseRequest.DatabaseDialectValueValuesEnum
  create_request = msgs.CreateDatabaseRequest(
      createStatement='CREATE DATABASE `{}`'.format(database),
      extraStatements=ddl,
  )
  if proto_descriptors:
    create_request.protoDescriptors = proto_descriptors
  if database_dialect:
    if database_dialect.upper() == DATABASE_DIALECT_POSTGRESQL:
      # PostgreSQL quotes identifiers with double quotes, not backticks.
      create_request.createStatement = 'CREATE DATABASE "{}"'.format(database)
      create_request.databaseDialect = dialect_enum.POSTGRESQL
    else:
      create_request.databaseDialect = dialect_enum.GOOGLE_STANDARD_SQL
  if kms_key:
    create_request.encryptionConfig = msgs.EncryptionConfig(
        kmsKeyName=kms_key.kms_key_name, kmsKeyNames=kms_key.kms_key_names)
  request = msgs.SpannerProjectsInstancesDatabasesCreateRequest(
      parent=instance_ref.RelativeName(),
      createDatabaseRequest=create_request)
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.Create(request)
def SetPolicy(database_ref, policy):
  """Replaces the database's IAM policy with the given policy."""
  policy.version = iam_util.MAX_LIBRARY_IAM_SUPPORTED_VERSION
  msgs = apis.GetMessagesModule('spanner', 'v1')
  request = msgs.SpannerProjectsInstancesDatabasesSetIamPolicyRequest(
      resource=database_ref.RelativeName(),
      setIamPolicyRequest=msgs.SetIamPolicyRequest(policy=policy))
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.SetIamPolicy(request)
def Delete(database_ref):
  """Permanently drops the given database."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  drop_req = msgs.SpannerProjectsInstancesDatabasesDropDatabaseRequest(
      database=database_ref.RelativeName())
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.DropDatabase(drop_req)
def GetIamPolicy(database_ref):
  """Fetches the IAM policy attached to a database."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  policy_options = msgs.GetPolicyOptions(
      requestedPolicyVersion=iam_util.MAX_LIBRARY_IAM_SUPPORTED_VERSION)
  request = msgs.SpannerProjectsInstancesDatabasesGetIamPolicyRequest(
      resource=database_ref.RelativeName(),
      getIamPolicyRequest=msgs.GetIamPolicyRequest(options=policy_options))
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.GetIamPolicy(request)
def Get(database_ref):
  """Fetches a database resource by name."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  request = msgs.SpannerProjectsInstancesDatabasesGetRequest(
      name=database_ref.RelativeName())
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.Get(request)
def GetDdl(database_ref):
  """Returns the list of DDL statements that define the database schema."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  request = msgs.SpannerProjectsInstancesDatabasesGetDdlRequest(
      database=database_ref.RelativeName())
  response = apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.GetDdl(request)
  return response.statements
def GetDdlWithDescriptors(database_ref, args):
  """Returns a database's DDL, optionally with its proto bundle descriptors.

  Args:
    database_ref: Resource reference of the database.
    args: Parsed args; args.include_proto_descriptors controls whether the
      text-format descriptors are appended to the output.

  Returns:
    The list of DDL statements when descriptors are not requested, otherwise a
    single string of DDL followed by the text-format proto descriptors.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  request = msgs.SpannerProjectsInstancesDatabasesGetDdlRequest(
      database=database_ref.RelativeName())
  response = apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.GetDdl(request)
  if not args.include_proto_descriptors:
    return response.statements
  descriptor_set = descriptor_pb2.FileDescriptorSet.FromString(
      response.protoDescriptors)
  ddl_text = ';\n\n'.join(response.statements) + ';\n\n'
  return (ddl_text + 'Proto Bundle Descriptors:\n'
          + text_format.MessageToString(descriptor_set))
def List(instance_ref):
  """Yields the databases that belong to the given instance."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  list_req = msgs.SpannerProjectsInstancesDatabasesListRequest(
      parent=instance_ref.RelativeName())
  return list_pager.YieldFromList(
      apis.GetClientInstance('spanner', 'v1').projects_instances_databases,
      list_req,
      field='databases',
      batch_size_attribute='pageSize')
def UpdateDdl(database_ref, ddl, proto_descriptors=None):
  """Applies DDL statements to an existing database.

  Args:
    database_ref: Resource reference of the database to update.
    ddl: List of DDL statements to apply.
    proto_descriptors: Optional serialized FileDescriptorSet bytes.

  Returns:
    The long-running schema-update operation.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  ddl_request = msgs.UpdateDatabaseDdlRequest(statements=ddl)
  if proto_descriptors:
    ddl_request.protoDescriptors = proto_descriptors
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.UpdateDdl(
          msgs.SpannerProjectsInstancesDatabasesUpdateDdlRequest(
              database=database_ref.RelativeName(),
              updateDatabaseDdlRequest=ddl_request))
def Restore(database_ref, backup_ref, encryption_type=None, kms_key=None):
  """Restores a database from a backup.

  Args:
    database_ref: Resource reference for the database to create.
    backup_ref: Resource reference of the backup to restore from.
    encryption_type: Optional encryption type enum for the restored database.
    kms_key: Optional CloudKmsKeyName; implies an encryption config even when
      encryption_type is not set.

  Returns:
    The long-running restore operation.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  restore_req = msgs.RestoreDatabaseRequest(
      backup=backup_ref.RelativeName(), databaseId=database_ref.Name())
  if kms_key:
    restore_req.encryptionConfig = msgs.RestoreDatabaseEncryptionConfig(
        encryptionType=encryption_type,
        kmsKeyName=kms_key.kms_key_name,
        kmsKeyNames=kms_key.kms_key_names)
  elif encryption_type:
    restore_req.encryptionConfig = msgs.RestoreDatabaseEncryptionConfig(
        encryptionType=encryption_type)
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.Restore(
          msgs.SpannerProjectsInstancesDatabasesRestoreRequest(
              parent=database_ref.Parent().RelativeName(),
              restoreDatabaseRequest=restore_req))
def Update(database_ref, enable_drop_protection, kms_keys=None):
  """Updates a database's drop protection or encryption configuration.

  Exactly one of the two mutually exclusive updates must be requested.

  Args:
    database_ref: Resource reference of the database to patch.
    enable_drop_protection: Boolean or None; when not None, toggles drop
      protection.
    kms_keys: Optional list of KMS key names for the encryption config.

  Returns:
    The long-running patch operation.

  Raises:
    errors.NoFieldsSpecifiedError: If both updates or neither update was
      requested.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  # Bug fix: the original truthiness check (`enable_drop_protection and
  # kms_keys`) missed the False + kms_keys combination, silently ignoring
  # --kms-keys when --no-enable-drop-protection was also passed. Compare
  # against None so both flags always conflict, as the message promises.
  if enable_drop_protection is not None and kms_keys is not None:
    raise errors.NoFieldsSpecifiedError(
        'Multiple updates requested. Both flag --[no-]enable-drop-protection'
        ' and --kms-keys were specified. Please specify only one flag.'
    )
  if enable_drop_protection is None and kms_keys is None:
    raise errors.NoFieldsSpecifiedError(
        'No updates requested. Need to specify either flag '
        '--[no-]enable-drop-protection OR --kms-keys.'
    )
  update_mask = []
  if enable_drop_protection is not None:
    update_mask.append('enable_drop_protection')
    database_obj = msgs.Database(
        name=database_ref.RelativeName(),
        enableDropProtection=enable_drop_protection,
    )
  else:
    update_mask.append('encryption_config')
    database_obj = msgs.Database(
        name=database_ref.RelativeName(),
        encryptionConfig=msgs.EncryptionConfig(kmsKeyNames=kms_keys),
    )
  req = msgs.SpannerProjectsInstancesDatabasesPatchRequest(
      database=database_obj,
      name=database_ref.RelativeName(),
      updateMask=','.join(update_mask),
  )
  return client.projects_instances_databases.Patch(req)
def ChangeQuorum(database_ref, quorum_type, etag=None):
  """Initiates a quorum change on the given database."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  change_req = msgs.ChangeQuorumRequest(
      name=database_ref.RelativeName(), quorumType=quorum_type, etag=etag)
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_databases.Changequorum(change_req)

View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instance config operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def Get(config, operation):
  """Fetches a single operation on an instance config.

  Args:
    config: String, the instance config ID.
    operation: String, the operation ID.

  Returns:
    The Operation resource.
  """
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': config,
      },
      collection='spanner.projects.instanceConfigs.operations')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instanceConfigs_operations.Get(
          msgs.SpannerProjectsInstanceConfigsOperationsGetRequest(
              name=op_ref.RelativeName()))
def List(config, type_filter=None):
  """Lists operations on an instance config via the generic operation API.

  Args:
    config: String, the instance config ID.
    type_filter: Optional filter expression restricting operation types.

  Returns:
    A generator of Operation resources.
  """
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  list_req = msgs.SpannerProjectsInstanceConfigsOperationsListRequest(
      name=config_ref.RelativeName() + '/operations', filter=type_filter)
  return list_pager.YieldFromList(
      apis.GetClientInstance(
          'spanner', 'v1').projects_instanceConfigs_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize')
def Cancel(config, operation):
  """Cancels a running operation on an instance config.

  Args:
    config: String, the instance config ID.
    operation: String, the operation ID.

  Returns:
    The Cancel API response.
  """
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': config,
      },
      collection='spanner.projects.instanceConfigs.operations')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instanceConfigs_operations.Cancel(
          msgs.SpannerProjectsInstanceConfigsOperationsCancelRequest(
              name=op_ref.RelativeName()))
def Await(operation, message):
  """Blocks until the given instance-config operation completes.

  Args:
    operation: The Operation message to poll.
    message: String shown to the user while waiting.

  Returns:
    The unwrapped operation result.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  op_ref = resources.REGISTRY.ParseRelativeName(
      operation.name, collection='spanner.projects.instanceConfigs.operations')
  poller = waiter.CloudOperationPoller(
      client.projects_instanceConfigs,
      client.projects_instanceConfigs_operations)
  return waiter.WaitFor(poller, op_ref, message)
def BuildInstanceConfigOperationTypeFilter(op_type):
  """Builds the filter for the different instance config operation metadata types.

  Args:
    op_type: Operation type string ('INSTANCE_CONFIG_CREATE' or
      'INSTANCE_CONFIG_UPDATE'), or None.

  Returns:
    '' when op_type is None, the metadata-type filter string for a known op
    type, and None for anything else.
  """
  if op_type is None:
    return ''
  metadata_by_op_type = {
      'INSTANCE_CONFIG_CREATE': 'CreateInstanceConfigMetadata',
      'INSTANCE_CONFIG_UPDATE': 'UpdateInstanceConfigMetadata',
  }
  metadata_type = metadata_by_op_type.get(op_type)
  if metadata_type is None:
    return None
  return ('metadata.@type:type.googleapis.com/'
          'google.spanner.admin.instance.v1.' + metadata_type)

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instanceConfigs API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
import six
class MissingReplicaError(core_exceptions.Error):
  """Raised when a replica to be skipped is absent from the source config."""

  def __init__(self, replica_location, replica_type):
    message = (
        'The replica {0} of type {1} is not in the source config\'s replicas'
        .format(replica_location, replica_type))
    super(MissingReplicaError, self).__init__(message)
def Get(config):
  """Fetches the given instance config resource."""
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instanceConfigs.Get(
          msgs.SpannerProjectsInstanceConfigsGetRequest(
              name=config_ref.RelativeName()))
def List():
  """Yields every instance config visible in the current project."""
  msgs = apis.GetMessagesModule('spanner', 'v1')
  parent = 'projects/' + properties.VALUES.core.project.GetOrFail()
  return list_pager.YieldFromList(
      apis.GetClientInstance('spanner', 'v1').projects_instanceConfigs,
      msgs.SpannerProjectsInstanceConfigsListRequest(parent=parent),
      field='instanceConfigs',
      batch_size_attribute='pageSize')
def Delete(config, etag=None, validate_only=False):
  """Deletes an instance config.

  Args:
    config: String, the instance config ID.
    etag: Optional etag for optimistic concurrency control.
    validate_only: Boolean; when True, only validates without deleting.

  Returns:
    The Delete API response.
  """
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  delete_req = msgs.SpannerProjectsInstanceConfigsDeleteRequest(
      name=config_ref.RelativeName(), etag=etag, validateOnly=validate_only)
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instanceConfigs.Delete(delete_req)
def CreateUsingExistingConfig(args, config):
  """Creates a new CMMR instance config cloned from an existing config.

  User-provided values in args override the corresponding fields of the
  existing config; replicas can be skipped from or added to its replica list.

  Args:
    args: Parsed command-line arguments.
    config: An existing InstanceConfig message to clone from.

  Returns:
    The Create API response.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  display_name = args.display_name or config.displayName
  labels = args.labels or config.labels
  # baseConfig is only populated on user-managed configs; fall back to the
  # config's own name otherwise.
  base_config = config.baseConfig or config.name
  replicas = config.replicas
  if args.skip_replicas:
    _SkipReplicas(msgs, args.skip_replicas, replicas)
  if args.add_replicas:
    _AppendReplicas(msgs, args.add_replicas, replicas)
  return _Create(msgs, args.config, display_name, base_config, replicas,
                 labels, args.validate_only, args.etag)
def CreateUsingReplicas(config,
                        display_name,
                        base_config,
                        replicas_arg,
                        validate_only,
                        labels=None,
                        etag=None):
  """Creates a new instance config from an explicit list of replicas.

  Args:
    config: String, the new instance config ID.
    display_name: String, the human-readable name.
    base_config: String, the base config to derive from.
    replicas_arg: List of replica dicts ('location'/'type').
    validate_only: Boolean; when True, only validates the request.
    labels: Optional dict of labels.
    etag: Optional etag.

  Returns:
    The Create API response.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')
  base_config_ref = resources.REGISTRY.Parse(
      base_config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  replicas = []
  _AppendReplicas(msgs, replicas_arg, replicas)
  labels_message = {}
  if labels is not None:
    additional = [
        msgs.InstanceConfig.LabelsValue.AdditionalProperty(
            key=key, value=value)
        for key, value in six.iteritems(labels)
    ]
    labels_message = msgs.InstanceConfig.LabelsValue(
        additionalProperties=additional)
  return _Create(msgs, config, display_name, base_config_ref.RelativeName(),
                 replicas, labels_message, validate_only, etag)
def _Create(msgs,
            config,
            display_name,
            base_config,
            replica_info_list,
            labels,
            validate_only,
            etag=None):
  """Issues the CreateInstanceConfig request for the current project."""
  project_ref = resources.REGISTRY.Create(
      'spanner.projects', projectsId=properties.VALUES.core.project.GetOrFail)
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  instance_config = msgs.InstanceConfig(
      name=config_ref.RelativeName(),
      displayName=display_name,
      baseConfig=base_config,
      labels=labels,
      replicas=replica_info_list)
  if etag:
    instance_config.etag = etag
  create_req = msgs.SpannerProjectsInstanceConfigsCreateRequest(
      parent=project_ref.RelativeName(),
      createInstanceConfigRequest=msgs.CreateInstanceConfigRequest(
          instanceConfigId=config,
          instanceConfig=instance_config,
          validateOnly=validate_only))
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instanceConfigs.Create(create_req)
def _AppendReplicas(msgs, add_replicas_arg, replica_info_list):
  """Appends a ReplicaInfo for each requested replica to the given list."""
  for entry in add_replicas_arg:
    replica_info_list.append(
        msgs.ReplicaInfo(
            location=entry['location'],
            type=arg_utils.ChoiceToEnum(
                entry['type'], msgs.ReplicaInfo.TypeValueValuesEnum)))
def _SkipReplicas(msgs, skip_replicas_arg, replica_info_list):
  """Removes each replica in skip_replicas_arg from the given ReplicaInfo list.

  Args:
    msgs: The Spanner messages module.
    skip_replicas_arg: List of dicts with 'location' and 'type' keys.
    replica_info_list: List of ReplicaInfo messages, modified in place.

  Raises:
    MissingReplicaError: If a requested replica is not present in the list.
  """
  for replica_to_skip in skip_replicas_arg:
    replica_type = arg_utils.ChoiceToEnum(replica_to_skip['type'],
                                          msgs.ReplicaInfo.TypeValueValuesEnum)
    index_to_delete = None
    for index, replica in enumerate(replica_info_list):
      if (replica.location == replica_to_skip['location'] and
          replica.type == replica_type):
        index_to_delete = index
        # Only skip the first found matching replica. Bug fix: the original
        # loop used `pass` here and kept scanning, so it actually removed the
        # *last* match, contradicting the stated intent.
        break
    if index_to_delete is None:
      raise MissingReplicaError(replica_to_skip['location'], replica_type)
    replica_info_list.pop(index_to_delete)
def Patch(args):
  """Updates an instance config's display name and/or labels.

  Args:
    args: Parsed command-line arguments.

  Returns:
    The Patch API response.

  Raises:
    errors.NoFieldsSpecifiedError: If no updatable field was provided.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  config_ref = resources.REGISTRY.Parse(
      args.config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs')
  instance_config = msgs.InstanceConfig(name=config_ref.RelativeName())
  update_mask = []
  if args.display_name is not None:
    instance_config.displayName = args.display_name
    update_mask.append('display_name')
  if args.etag is not None:
    instance_config.etag = args.etag

  def _CurrentLabels():
    # Fetched lazily: only called when the labels diff needs the old values.
    get_req = msgs.SpannerProjectsInstanceConfigsGetRequest(
        name=config_ref.RelativeName())
    return client.projects_instanceConfigs.Get(get_req).labels

  labels_update = labels_util.ProcessUpdateArgsLazy(
      args, msgs.InstanceConfig.LabelsValue, _CurrentLabels)
  if labels_update.needs_update:
    instance_config.labels = labels_update.labels
    update_mask.append('labels')
  if not update_mask:
    raise errors.NoFieldsSpecifiedError('No updates requested.')
  return client.projects_instanceConfigs.Patch(
      msgs.SpannerProjectsInstanceConfigsPatchRequest(
          name=config_ref.RelativeName(),
          updateInstanceConfigRequest=msgs.UpdateInstanceConfigRequest(
              instanceConfig=instance_config,
              updateMask=','.join(update_mask),
              validateOnly=args.validate_only)))

View File

@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instance operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
def Await(operation, message):
  """Waits for the given instance operation to finish.

  Args:
    operation: The Operation message to poll.
    message: String displayed to the user while polling.

  Returns:
    The unwrapped operation result.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  op_ref = resources.REGISTRY.ParseRelativeName(
      operation.name, collection='spanner.projects.instances.operations')
  poller = waiter.CloudOperationPoller(client.projects_instances,
                                       client.projects_instances_operations)
  return waiter.WaitFor(poller, op_ref, message)
def Cancel(instance, operation):
  """Cancels a running operation on an instance.

  Args:
    instance: String, the instance ID that owns the operation.
    operation: String, the operation ID.

  Returns:
    The Cancel API response.
  """
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance
      },
      collection='spanner.projects.instances.operations')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_operations.Cancel(
          msgs.SpannerProjectsInstancesOperationsCancelRequest(
              name=op_ref.RelativeName()))
def Get(instance, operation):
  """Fetches a single operation on an instance.

  Args:
    instance: String, the instance ID that owns the operation.
    operation: String, the operation ID.

  Returns:
    The Operation resource.
  """
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance
      },
      collection='spanner.projects.instances.operations')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  return apis.GetClientInstance(
      'spanner', 'v1').projects_instances_operations.Get(
          msgs.SpannerProjectsInstancesOperationsGetRequest(
              name=op_ref.RelativeName()))
def List(instance):
  """Yields the operations that exist on the given instance."""
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  list_req = msgs.SpannerProjectsInstancesOperationsListRequest(
      name=instance_ref.RelativeName() + '/operations')
  return list_pager.YieldFromList(
      apis.GetClientInstance('spanner', 'v1').projects_instances_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize')

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instance partition operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
from apitools.base.py import list_pager
from cloudsdk.google.protobuf import timestamp_pb2
from googlecloudsdk.api_lib.spanner import response_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.api_lib.util import waiter
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
# Timeout to use in ListInstancePartitionOperations for unreachable instance
# partitions.
UNREACHABLE_INSTANCE_PARTITION_TIMEOUT = datetime.timedelta(seconds=20)
_API_NAME = 'spanner'
_API_VERSION = 'v1'
def Await(operation, message):
  """Waits for the given instance-partition operation to finish.

  Args:
    operation: The Operation message to poll.
    message: String displayed to the user while polling.

  Returns:
    The unwrapped operation result.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  op_ref = resources.REGISTRY.ParseRelativeName(
      operation.name,
      collection='spanner.projects.instances.instancePartitions.operations',
  )
  poller = waiter.CloudOperationPoller(
      client.projects_instances_instancePartitions,
      client.projects_instances_instancePartitions_operations,
  )
  return waiter.WaitFor(poller, op_ref, message)
def ListGeneric(instance, instance_partition):
  """Lists operations on one instance partition via the generic LRO API.

  Args:
    instance: String, the instance ID.
    instance_partition: String, the instance partition ID.

  Returns:
    A generator of Operation resources.
  """
  partition_ref = resources.REGISTRY.Parse(
      instance_partition,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
      },
      collection='spanner.projects.instances.instancePartitions',
  )
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  list_req = (
      msgs.SpannerProjectsInstancesInstancePartitionsOperationsListRequest(
          name=partition_ref.RelativeName() + '/operations'))
  return list_pager.YieldFromList(
      apis.GetClientInstance(
          _API_NAME,
          _API_VERSION).projects_instances_instancePartitions_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize',
  )
def List(instance):
  """List operations on all instance partitions under the given instance.

  Unreachable instance partitions are given a short deadline
  (UNREACHABLE_INSTANCE_PARTITION_TIMEOUT) so the request does not hang on
  them; they are surfaced through the response and logged by get_field_func.

  Args:
    instance: Id of the parent instance.

  Returns:
    A generator of Operation messages.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  deadline_proto = timestamp_pb2.Timestamp()
  deadline_proto.FromDatetime(
      datetime.datetime.now(tz=datetime.timezone.utc)
      + UNREACHABLE_INSTANCE_PARTITION_TIMEOUT
  )
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances',
  )
  list_req = (
      messages.SpannerProjectsInstancesInstancePartitionOperationsListRequest(
          parent=instance_ref.RelativeName(),
          instancePartitionDeadline=deadline_proto.ToJsonString(),
      )
  )
  return list_pager.YieldFromList(
      spanner_client.projects_instances_instancePartitionOperations,
      list_req,
      batch_size_attribute='pageSize',
      field='operations',
      get_field_func=response_util.GetFieldAndLogUnreachableInstancePartitions,
  )
def Cancel(instance, instance_partition, operation):
  """Request cancellation of the specified instance-partition operation.

  Args:
    instance: Id of the instance.
    instance_partition: Id of the instance partition.
    operation: Id of the operation to cancel.

  Returns:
    The (empty) response of the Cancel call.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'instancePartitionsId': instance_partition,
      },
      collection='spanner.projects.instances.instancePartitions.operations',
  )
  cancel_req = (
      messages.SpannerProjectsInstancesInstancePartitionsOperationsCancelRequest(
          name=op_ref.RelativeName()
      )
  )
  return spanner_client.projects_instances_instancePartitions_operations.Cancel(
      cancel_req
  )
def Get(instance, instance_partition, operation):
  """Fetch the specified instance-partition operation.

  Args:
    instance: Id of the instance.
    instance_partition: Id of the instance partition.
    operation: Id of the operation to fetch.

  Returns:
    The Operation message.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instancesId': instance,
          'instancePartitionsId': instance_partition,
      },
      collection='spanner.projects.instances.instancePartitions.operations',
  )
  get_req = (
      messages.SpannerProjectsInstancesInstancePartitionsOperationsGetRequest(
          name=op_ref.RelativeName()
      )
  )
  return spanner_client.projects_instances_instancePartitions_operations.Get(
      get_req
  )

View File

@@ -0,0 +1,242 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instance partition API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import datetime
from apitools.base.py import list_pager
from cloudsdk.google.protobuf import timestamp_pb2
from googlecloudsdk.api_lib.spanner import response_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
# The list of pre-defined IAM roles in Spanner.
KNOWN_ROLES = [
'roles/spanner.admin',
'roles/spanner.databaseAdmin',
'roles/spanner.databaseReader',
'roles/spanner.databaseUser',
'roles/spanner.viewer',
]
# Timeout to use in ListInstancePartitions for unreachable instance partitions.
UNREACHABLE_INSTANCE_PARTITION_TIMEOUT = datetime.timedelta(seconds=20)
_API_NAME = 'spanner'
_API_VERSION = 'v1'
def Create(
    instance_ref,
    instance_partition,
    config,
    description,
    nodes,
    processing_units=None,
    autoscaling_min_nodes=None,
    autoscaling_max_nodes=None,
    autoscaling_min_processing_units=None,
    autoscaling_max_processing_units=None,
    autoscaling_high_priority_cpu_target=None,
    autoscaling_total_cpu_target=None,
    autoscaling_storage_target=None,
):
  """Create a new instance partition.

  Exactly one capacity mode is applied, in priority order: manual node count,
  manual processing units, then autoscaling (if any autoscaling flag is set).

  Args:
    instance_ref: Resource reference of the parent instance.
    instance_partition: Id of the new instance partition.
    config: Instance config id for the partition.
    description: Display name of the partition.
    nodes: Manual node count, if set.
    processing_units: Manual processing-unit count, if set.
    autoscaling_min_nodes: Autoscaling lower bound in nodes.
    autoscaling_max_nodes: Autoscaling upper bound in nodes.
    autoscaling_min_processing_units: Autoscaling lower bound in processing
      units.
    autoscaling_max_processing_units: Autoscaling upper bound in processing
      units.
    autoscaling_high_priority_cpu_target: High-priority CPU utilization target
      percent (zero is a valid value).
    autoscaling_total_cpu_target: Total CPU utilization target percent (zero
      is a valid value).
    autoscaling_storage_target: Storage utilization target percent.

  Returns:
    The long-running operation for the Create call.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs',
  )
  partition = messages.InstancePartition(
      config=config_ref.RelativeName(), displayName=description
  )
  # Note: the CPU targets must be tested against None because 0 is valid.
  autoscaling_requested = any((
      autoscaling_min_nodes,
      autoscaling_max_nodes,
      autoscaling_min_processing_units,
      autoscaling_max_processing_units,
      autoscaling_high_priority_cpu_target is not None,
      autoscaling_total_cpu_target is not None,
      autoscaling_storage_target,
  ))
  if nodes:
    partition.nodeCount = nodes
  elif processing_units:
    partition.processingUnits = processing_units
  elif autoscaling_requested:
    partition.autoscalingConfig = messages.AutoscalingConfig(
        autoscalingLimits=messages.AutoscalingLimits(
            minNodes=autoscaling_min_nodes,
            maxNodes=autoscaling_max_nodes,
            minProcessingUnits=autoscaling_min_processing_units,
            maxProcessingUnits=autoscaling_max_processing_units,
        ),
        autoscalingTargets=messages.AutoscalingTargets(
            highPriorityCpuUtilizationPercent=autoscaling_high_priority_cpu_target,
            totalCpuUtilizationPercent=autoscaling_total_cpu_target,
            storageUtilizationPercent=autoscaling_storage_target,
        ),
    )
  create_req = messages.SpannerProjectsInstancesInstancePartitionsCreateRequest(
      parent=instance_ref.RelativeName(),
      createInstancePartitionRequest=messages.CreateInstancePartitionRequest(
          instancePartitionId=instance_partition,
          instancePartition=partition,
      ),
  )
  return spanner_client.projects_instances_instancePartitions.Create(create_req)
def Get(instance_partition_ref):
  """Fetch a single instance partition.

  Args:
    instance_partition_ref: Resource reference of the instance partition.

  Returns:
    The InstancePartition message.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  get_req = messages.SpannerProjectsInstancesInstancePartitionsGetRequest(
      name=instance_partition_ref.RelativeName()
  )
  return spanner_client.projects_instances_instancePartitions.Get(get_req)
def Patch(
    instance_partition_ref,
    description=None,
    nodes=None,
    processing_units=None,
    autoscaling_min_nodes=None,
    autoscaling_max_nodes=None,
    autoscaling_min_processing_units=None,
    autoscaling_max_processing_units=None,
    autoscaling_high_priority_cpu_target=None,
    autoscaling_total_cpu_target=None,
    autoscaling_storage_target=None,
):
  """Update an instance partition.

  Builds a field mask from the supplied flags, then issues a Patch request.
  When a complete autoscaling config is supplied (both limits of at least one
  limit pair, at least one CPU target, and the storage target), the whole
  `autoscalingConfig` is replaced; otherwise only the individual sub-fields
  that were set are patched.

  Args:
    instance_partition_ref: Resource reference of the instance partition.
    description: New display name, if set.
    nodes: Manual node count, if set.
    processing_units: Manual processing-unit count, if set.
    autoscaling_min_nodes: Autoscaling lower bound in nodes.
    autoscaling_max_nodes: Autoscaling upper bound in nodes.
    autoscaling_min_processing_units: Autoscaling lower bound in processing
      units.
    autoscaling_max_processing_units: Autoscaling upper bound in processing
      units.
    autoscaling_high_priority_cpu_target: High-priority CPU utilization target
      percent (zero is valid, hence the `is not None` checks below).
    autoscaling_total_cpu_target: Total CPU utilization target percent (zero
      is valid).
    autoscaling_storage_target: Storage utilization target percent.

  Returns:
    The long-running operation for the Patch call.
  """
  fields = []
  if description is not None:
    fields.append('displayName')
  if nodes is not None:
    fields.append('nodeCount')
  if processing_units is not None:
    fields.append('processingUnits')
  # A "complete" autoscaling config (a full limit pair, a CPU target, and the
  # storage target) replaces autoscalingConfig wholesale; anything less only
  # patches the individual sub-paths that were provided.
  if (
      (autoscaling_min_nodes and autoscaling_max_nodes)
      or (autoscaling_min_processing_units and autoscaling_max_processing_units)
  ) and (
      (autoscaling_high_priority_cpu_target is not None or
       autoscaling_total_cpu_target is not None)
      and autoscaling_storage_target
  ):
    fields.append('autoscalingConfig')
  else:
    if autoscaling_min_nodes:
      fields.append('autoscalingConfig.autoscalingLimits.minNodes')
    if autoscaling_max_nodes:
      fields.append('autoscalingConfig.autoscalingLimits.maxNodes')
    if autoscaling_min_processing_units:
      fields.append('autoscalingConfig.autoscalingLimits.minProcessingUnits')
    if autoscaling_max_processing_units:
      fields.append('autoscalingConfig.autoscalingLimits.maxProcessingUnits')
    if autoscaling_high_priority_cpu_target is not None:
      fields.append(
          'autoscalingConfig.autoscalingTargets.highPriorityCpuUtilizationPercent'
      )
    if autoscaling_total_cpu_target is not None:
      fields.append(
          'autoscalingConfig.autoscalingTargets.totalCpuUtilizationPercent'
      )
    if autoscaling_storage_target:
      fields.append(
          'autoscalingConfig.autoscalingTargets.storageUtilizationPercent'
      )
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  instance_partition_obj = msgs.InstancePartition(displayName=description)
  # Only one capacity mode is populated on the message; note the priority
  # order here (processing units first) differs from Create (nodes first).
  if processing_units:
    instance_partition_obj.processingUnits = processing_units
  elif nodes:
    instance_partition_obj.nodeCount = nodes
  elif (
      autoscaling_min_nodes
      or autoscaling_max_nodes
      or autoscaling_min_processing_units
      or autoscaling_max_processing_units
      or autoscaling_high_priority_cpu_target is not None
      or autoscaling_total_cpu_target is not None
      or autoscaling_storage_target
  ):
    instance_partition_obj.autoscalingConfig = msgs.AutoscalingConfig(
        autoscalingLimits=msgs.AutoscalingLimits(
            minNodes=autoscaling_min_nodes,
            maxNodes=autoscaling_max_nodes,
            minProcessingUnits=autoscaling_min_processing_units,
            maxProcessingUnits=autoscaling_max_processing_units,
        ),
        autoscalingTargets=msgs.AutoscalingTargets(
            highPriorityCpuUtilizationPercent=autoscaling_high_priority_cpu_target,
            totalCpuUtilizationPercent=autoscaling_total_cpu_target,
            storageUtilizationPercent=autoscaling_storage_target,
        ),
    )
  req = msgs.SpannerProjectsInstancesInstancePartitionsPatchRequest(
      name=instance_partition_ref.RelativeName(),
      updateInstancePartitionRequest=msgs.UpdateInstancePartitionRequest(
          fieldMask=','.join(fields), instancePartition=instance_partition_obj
      ),
  )
  return client.projects_instances_instancePartitions.Patch(req)
def List(instance_ref):
  """List instance partitions under the given instance.

  Unreachable instance partitions are given a short deadline
  (UNREACHABLE_INSTANCE_PARTITION_TIMEOUT); they are surfaced through the
  response and logged by get_field_func instead of failing the request.

  Args:
    instance_ref: Resource reference of the parent instance.

  Returns:
    A generator of InstancePartition messages.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  deadline_proto = timestamp_pb2.Timestamp()
  deadline_proto.FromDatetime(
      datetime.datetime.now(tz=datetime.timezone.utc)
      + UNREACHABLE_INSTANCE_PARTITION_TIMEOUT
  )
  list_req = messages.SpannerProjectsInstancesInstancePartitionsListRequest(
      parent=instance_ref.RelativeName(),
      instancePartitionDeadline=deadline_proto.ToJsonString(),
  )
  return list_pager.YieldFromList(
      spanner_client.projects_instances_instancePartitions,
      list_req,
      batch_size_attribute='pageSize',
      field='instancePartitions',
      get_field_func=response_util.GetFieldAndLogUnreachableInstancePartitions,
  )
def Delete(instance_partition_ref):
  """Delete an instance partition.

  Args:
    instance_partition_ref: Resource reference of the instance partition.

  Returns:
    The (empty) response of the Delete call.
  """
  spanner_client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  messages = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  delete_req = messages.SpannerProjectsInstancesInstancePartitionsDeleteRequest(
      name=instance_partition_ref.RelativeName()
  )
  return spanner_client.projects_instances_instancePartitions.Delete(delete_req)

View File

@@ -0,0 +1,709 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner instance API helper."""
import datetime
import re
from apitools.base.py import list_pager
from cloudsdk.google.protobuf import timestamp_pb2
from googlecloudsdk.api_lib.spanner import response_util
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.iam import iam_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.console import console_io
# The list of pre-defined IAM roles in Spanner.
KNOWN_ROLES = [
'roles/spanner.admin', 'roles/spanner.databaseAdmin',
'roles/spanner.databaseReader', 'roles/spanner.databaseUser',
'roles/spanner.viewer'
]
# Timeout to use in ListInstances for unreachable instances.
UNREACHABLE_INSTANCE_TIMEOUT = datetime.timedelta(seconds=20)
_SPANNER_API_NAME = 'spanner'
_SPANNER_API_VERSION = 'v1'
_FIELD_MASK_AUTOSCALING_CONFIG = 'autoscalingConfig'
def MaybeGetAutoscalingOverride(msgs, asymmetric_autoscaling_option):
  """Build AutoscalingConfigOverrides from parsed flag key-value pairs.

  Args:
    msgs: The Spanner API messages module.
    asymmetric_autoscaling_option: Dict of key-value pairs parsed from the
      --asymmetric-autoscaling-option flag.

  Returns:
    An AutoscalingConfigOverrides message if any override key is present in
    the dict, otherwise None.
  """
  override_keys = (
      'min_nodes',
      'max_nodes',
      'min_processing_units',
      'max_processing_units',
      'high_priority_cpu_target',
      'total_cpu_target',
      'disable_high_priority_cpu_autoscaling',
      'disable_total_cpu_autoscaling',
  )
  option = asymmetric_autoscaling_option
  if not any(key in option for key in override_keys):
    return None
  overrides = msgs.AutoscalingConfigOverrides(
      autoscalingLimits=msgs.AutoscalingLimits()
  )
  if 'min_nodes' in option:
    overrides.autoscalingLimits.minNodes = option['min_nodes']
  if 'max_nodes' in option:
    overrides.autoscalingLimits.maxNodes = option['max_nodes']
  if 'min_processing_units' in option:
    overrides.autoscalingLimits.minProcessingUnits = option[
        'min_processing_units'
    ]
  if 'max_processing_units' in option:
    overrides.autoscalingLimits.maxProcessingUnits = option[
        'max_processing_units'
    ]
  if 'high_priority_cpu_target' in option:
    overrides.autoscalingTargetHighPriorityCpuUtilizationPercent = option[
        'high_priority_cpu_target'
    ]
  if 'total_cpu_target' in option:
    overrides.autoscalingTargetTotalCpuUtilizationPercent = option[
        'total_cpu_target'
    ]
  if 'disable_high_priority_cpu_autoscaling' in option:
    overrides.disableHighPriorityCpuAutoscaling = option[
        'disable_high_priority_cpu_autoscaling'
    ]
  if 'disable_total_cpu_autoscaling' in option:
    overrides.disableTotalCpuAutoscaling = option[
        'disable_total_cpu_autoscaling'
    ]
  return overrides
# Merges existing_overrides with new_overrides and returns the merged result.
def MergeAutoscalingConfigOverride(msgs, existing_overrides, new_overrides):
  """Merge two AutoscalingConfigOverrides objects, newest values winning.

  NOTE(review): when both arguments are non-None the merge happens in place
  on `existing_overrides`, and that mutated object is returned.

  Args:
    msgs: The messages module for the Spanner API.
    existing_overrides: The existing AutoscalingConfigOverrides object.
    new_overrides: The new AutoscalingConfigOverrides object to merge.

  Returns:
    The merged AutoscalingConfigOverrides object, or None when both inputs
    are None.
  """
  if existing_overrides is None:
    return new_overrides
  if new_overrides is None:
    return existing_overrides
  merged = existing_overrides
  # Fold in any limit overrides carried by the patch.
  new_limits = new_overrides.autoscalingLimits
  if new_limits is not None:
    if merged.autoscalingLimits is None:
      merged.autoscalingLimits = msgs.AutoscalingLimits()
    for limit_field in (
        'minNodes',
        'maxNodes',
        'minProcessingUnits',
        'maxProcessingUnits',
    ):
      new_value = getattr(new_limits, limit_field)
      if new_value is not None:
        setattr(merged.autoscalingLimits, limit_field, new_value)
  # Fold in any target overrides. Disabling an autoscaling dimension clears
  # its corresponding target override.
  if (
      new_overrides.autoscalingTargetHighPriorityCpuUtilizationPercent
      is not None
  ):
    merged.autoscalingTargetHighPriorityCpuUtilizationPercent = (
        new_overrides.autoscalingTargetHighPriorityCpuUtilizationPercent
    )
  if new_overrides.autoscalingTargetTotalCpuUtilizationPercent is not None:
    merged.autoscalingTargetTotalCpuUtilizationPercent = (
        new_overrides.autoscalingTargetTotalCpuUtilizationPercent
    )
  if new_overrides.disableHighPriorityCpuAutoscaling is not None:
    merged.disableHighPriorityCpuAutoscaling = (
        new_overrides.disableHighPriorityCpuAutoscaling
    )
  if merged.disableHighPriorityCpuAutoscaling:
    merged.autoscalingTargetHighPriorityCpuUtilizationPercent = None
  if new_overrides.disableTotalCpuAutoscaling is not None:
    merged.disableTotalCpuAutoscaling = new_overrides.disableTotalCpuAutoscaling
  if merged.disableTotalCpuAutoscaling:
    merged.autoscalingTargetTotalCpuUtilizationPercent = None
  return merged
# Set instance_obj.autoscalingConfig.asymmetricAutoscalingOptions by merging
# options found in the current_instance and the requested patch.
def PatchAsymmetricAutoscalingOptions(
    msgs, instance_obj, current_instance, asym_options_patch
):
  """Patch asymmetric autoscaling options.

  Options are keyed by replica location: a patched location already present
  on the current instance has its overrides merged onto the existing option;
  a new location is added as-is. The merged set replaces the options on
  instance_obj.autoscalingConfig (which must already be set).

  Args:
    msgs: API messages module.
    instance_obj: The instance object to patch.
    current_instance: The current instance object.
    asym_options_patch: A list of AsymmetricAutoscalingOption objects to patch.
  """
  merged_by_location = {}
  existing_config = current_instance.autoscalingConfig
  if existing_config:
    for option in existing_config.asymmetricAutoscalingOptions:
      merged_by_location[option.replicaSelection.location] = option
  for patch_option in asym_options_patch:
    loc = patch_option.replicaSelection.location
    existing = merged_by_location.get(loc)
    if existing is None:
      # Unknown location: adopt the patch option verbatim.
      merged_by_location[loc] = patch_option
    else:
      # Known location: merge the patched overrides onto the existing ones.
      existing.overrides = MergeAutoscalingConfigOverride(
          msgs, existing.overrides, patch_option.overrides
      )
  target_options = instance_obj.autoscalingConfig.asymmetricAutoscalingOptions
  target_options.clear()
  target_options.extend(merged_by_location.values())
def Create(
    instance,
    config,
    description,
    nodes,
    processing_units=None,
    autoscaling_min_nodes=None,
    autoscaling_max_nodes=None,
    autoscaling_min_processing_units=None,
    autoscaling_max_processing_units=None,
    autoscaling_high_priority_cpu_target=None,
    autoscaling_total_cpu_target=None,
    autoscaling_storage_target=None,
    asymmetric_autoscaling_options=None,
    disable_downscaling=None,
    instance_type=None,
    expire_behavior=None,
    default_storage_type=None,
    ssd_cache=None,
    edition=None,
    default_backup_schedule_type=None,
    tags=None,
):
  """Create a new instance.

  Args:
    instance: The instance to create.
    config: The instance config to use.
    description: The instance description.
    nodes: The number of nodes to use.
    processing_units: The number of processing units to use.
    autoscaling_min_nodes: The minimum number of nodes to use.
    autoscaling_max_nodes: The maximum number of nodes to use.
    autoscaling_min_processing_units: The minimum number of processing units to
      use.
    autoscaling_max_processing_units: The maximum number of processing units to
      use.
    autoscaling_high_priority_cpu_target: The high priority CPU target to use.
      Zero is a valid value.
    autoscaling_total_cpu_target: The total CPU target to use. Zero is a valid
      value.
    autoscaling_storage_target: The storage target to use.
    asymmetric_autoscaling_options: A list of ordered dict of key-value pairs
      representing the asymmetric autoscaling options.
    disable_downscaling: Whether to disable downscaling for the instance.
    instance_type: The instance type to use.
    expire_behavior: The expire behavior to use.
    default_storage_type: The default storage type to use.
    ssd_cache: The ssd cache to use.
    edition: The edition to use.
    default_backup_schedule_type: The type of default backup schedule to use.
    tags: The parsed tags value.

  Returns:
    The created instance.
  """
  client = apis.GetClientInstance(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  # Module containing the definitions of messages for the specified API.
  msgs = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs',
  )
  project_ref = resources.REGISTRY.Create(
      'spanner.projects', projectsId=properties.VALUES.core.project.GetOrFail
  )
  instance_obj = msgs.Instance(
      config=config_ref.RelativeName(), displayName=description
  )
  # Exactly one capacity mode is populated: manual node count, manual
  # processing units, or autoscaling. The CPU targets are compared against
  # None because 0 is a valid value.
  if nodes:
    instance_obj.nodeCount = nodes
  elif processing_units:
    instance_obj.processingUnits = processing_units
  elif (
      autoscaling_min_nodes
      or autoscaling_max_nodes
      or autoscaling_min_processing_units
      or autoscaling_max_processing_units
      or autoscaling_high_priority_cpu_target is not None
      or autoscaling_total_cpu_target is not None
      or autoscaling_storage_target
      or disable_downscaling is not None
  ):
    instance_obj.autoscalingConfig = msgs.AutoscalingConfig(
        autoscalingLimits=msgs.AutoscalingLimits(
            minNodes=autoscaling_min_nodes,
            maxNodes=autoscaling_max_nodes,
            minProcessingUnits=autoscaling_min_processing_units,
            maxProcessingUnits=autoscaling_max_processing_units,
        ),
        autoscalingTargets=msgs.AutoscalingTargets(
            highPriorityCpuUtilizationPercent=autoscaling_high_priority_cpu_target,
            totalCpuUtilizationPercent=autoscaling_total_cpu_target,
            storageUtilizationPercent=autoscaling_storage_target,
        ),
        disableDownscaling=disable_downscaling,
    )
  if instance_type is not None:
    instance_obj.instanceType = instance_type
  if expire_behavior is not None:
    instance_obj.freeInstanceMetadata = msgs.FreeInstanceMetadata(
        expireBehavior=expire_behavior
    )
  if default_storage_type is not None:
    instance_obj.defaultStorageType = default_storage_type
  if ssd_cache and ssd_cache.strip():
    instance_obj.ssdCache = (
        config_ref.RelativeName() + '/ssdCaches/' + ssd_cache.strip()
    )
  if edition is not None:
    instance_obj.edition = msgs.Instance.EditionValueValuesEnum(edition)
  if default_backup_schedule_type is not None:
    instance_obj.defaultBackupScheduleType = (
        msgs.Instance.DefaultBackupScheduleTypeValueValuesEnum(
            default_backup_schedule_type
        )
    )
  if tags is not None:
    instance_obj.tags = msgs.Instance.TagsValue(
        additionalProperties=[
            msgs.Instance.TagsValue.AdditionalProperty(key=key, value=value)
            for key, value in sorted(tags.items())
        ]
    )
  # Add asymmetric autoscaling options, if present.
  if asymmetric_autoscaling_options is not None:
    # The options live on the autoscaling config, which is unset when no
    # instance-level autoscaling flag was given; create it first to avoid an
    # AttributeError on None (mirrors the guard used in Patch below).
    if instance_obj.autoscalingConfig is None:
      instance_obj.autoscalingConfig = msgs.AutoscalingConfig()
    for asym_option in asymmetric_autoscaling_options:
      instance_obj.autoscalingConfig.asymmetricAutoscalingOptions.append(
          msgs.AsymmetricAutoscalingOption(
              overrides=MaybeGetAutoscalingOverride(msgs, asym_option),
              replicaSelection=msgs.InstanceReplicaSelection(
                  location=asym_option['location']
              ),
          )
      )
  req = msgs.SpannerProjectsInstancesCreateRequest(
      parent=project_ref.RelativeName(),
      createInstanceRequest=msgs.CreateInstanceRequest(
          instanceId=instance, instance=instance_obj
      ),
  )
  return client.projects_instances.Create(req)
def SetPolicy(instance_ref, policy, field_mask=None):
  """Saves the given policy on the instance, overwriting whatever exists.

  Args:
    instance_ref: Resource reference of the instance.
    policy: The IAM Policy message to store.
    field_mask: Optional update mask limiting which policy fields change.

  Returns:
    The policy as stored by the service.
  """
  spanner_client = apis.GetClientInstance(
      _SPANNER_API_NAME, _SPANNER_API_VERSION
  )
  messages = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  # Request the newest policy version this client library supports.
  policy.version = iam_util.MAX_LIBRARY_IAM_SUPPORTED_VERSION
  set_req = messages.SpannerProjectsInstancesSetIamPolicyRequest(
      resource=instance_ref.RelativeName(),
      setIamPolicyRequest=messages.SetIamPolicyRequest(
          policy=policy, updateMask=field_mask
      ),
  )
  return spanner_client.projects_instances.SetIamPolicy(set_req)
def GetIamPolicy(instance_ref):
  """Gets the IAM policy on an instance.

  Args:
    instance_ref: Resource reference of the instance.

  Returns:
    The instance's IAM Policy message.
  """
  spanner_client = apis.GetClientInstance(
      _SPANNER_API_NAME, _SPANNER_API_VERSION
  )
  messages = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  # Ask for the newest policy version this client library understands.
  policy_options = messages.GetPolicyOptions(
      requestedPolicyVersion=iam_util.MAX_LIBRARY_IAM_SUPPORTED_VERSION
  )
  get_req = messages.SpannerProjectsInstancesGetIamPolicyRequest(
      resource=instance_ref.RelativeName(),
      getIamPolicyRequest=messages.GetIamPolicyRequest(options=policy_options),
  )
  return spanner_client.projects_instances.GetIamPolicy(get_req)
def Delete(instance):
  """Delete an instance.

  Args:
    instance: Id of the instance to delete.

  Returns:
    The (empty) response of the Delete call.
  """
  spanner_client = apis.GetClientInstance(
      _SPANNER_API_NAME, _SPANNER_API_VERSION
  )
  messages = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances',
  )
  return spanner_client.projects_instances.Delete(
      messages.SpannerProjectsInstancesDeleteRequest(
          name=instance_ref.RelativeName()
      )
  )
def Get(instance):
  """Get an instance by name.

  Args:
    instance: Id of the instance to fetch.

  Returns:
    The Instance message.
  """
  spanner_client = apis.GetClientInstance(
      _SPANNER_API_NAME, _SPANNER_API_VERSION
  )
  messages = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances',
  )
  return spanner_client.projects_instances.Get(
      messages.SpannerProjectsInstancesGetRequest(
          name=instance_ref.RelativeName()
      )
  )
def List():
  """List instances in the project.

  Instances that cannot be reached within UNREACHABLE_INSTANCE_TIMEOUT are
  surfaced through the response's unreachable field and logged by
  get_field_func instead of failing the whole request.

  Returns:
    A generator of Instance messages in the current project.
  """
  client = apis.GetClientInstance(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  msgs = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  project_ref = resources.REGISTRY.Create(
      'spanner.projects', projectsId=properties.VALUES.core.project.GetOrFail)
  tp_proto = timestamp_pb2.Timestamp()
  # Use an aware UTC datetime: datetime.utcnow() is naive and deprecated
  # since Python 3.12, and the instance-partition helpers in this package
  # already use now(tz=timezone.utc). FromDatetime converts aware datetimes
  # to UTC, so the serialized timestamp is unchanged.
  tp_proto.FromDatetime(
      datetime.datetime.now(tz=datetime.timezone.utc)
      + UNREACHABLE_INSTANCE_TIMEOUT)
  req = msgs.SpannerProjectsInstancesListRequest(
      parent=project_ref.RelativeName(),
      instanceDeadline=tp_proto.ToJsonString())
  return list_pager.YieldFromList(
      client.projects_instances,
      req,
      field='instances',
      batch_size_attribute='pageSize',
      get_field_func=response_util.GetFieldAndLogUnreachable)
def Patch(
    instance,
    description=None,
    nodes=None,
    processing_units=None,
    autoscaling_min_nodes=None,
    autoscaling_max_nodes=None,
    autoscaling_min_processing_units=None,
    autoscaling_max_processing_units=None,
    autoscaling_high_priority_cpu_target=None,
    autoscaling_total_cpu_target=None,
    autoscaling_storage_target=None,
    asymmetric_autoscaling_options=None,
    disable_downscaling=None,
    clear_asymmetric_autoscaling_options=None,
    instance_type=None,
    expire_behavior=None,
    ssd_cache_id=None,
    edition=None,
    default_backup_schedule_type=None,
):
  """Update an instance.

  Builds a field mask from the supplied flags, assembles the corresponding
  Instance message, and issues a Patch request. When a complete autoscaling
  config is supplied (a full limit pair, at least one CPU target, and the
  storage target), `autoscalingConfig` is replaced wholesale; otherwise only
  the individual sub-fields that were set are patched.

  Args:
    instance: Id of the instance to update.
    description: New display name, if set.
    nodes: Manual node count, if set.
    processing_units: Manual processing-unit count, if set.
    autoscaling_min_nodes: Autoscaling lower bound in nodes.
    autoscaling_max_nodes: Autoscaling upper bound in nodes.
    autoscaling_min_processing_units: Autoscaling lower bound in processing
      units.
    autoscaling_max_processing_units: Autoscaling upper bound in processing
      units.
    autoscaling_high_priority_cpu_target: High-priority CPU target percent
      (zero is valid, hence the `is not None` checks).
    autoscaling_total_cpu_target: Total CPU target percent (zero is valid).
    autoscaling_storage_target: Storage utilization target percent.
    asymmetric_autoscaling_options: List of dicts of per-location autoscaling
      override key-value pairs to merge onto the instance.
    disable_downscaling: Whether to disable autoscaler downscaling.
    clear_asymmetric_autoscaling_options: List of replica locations whose
      asymmetric options should be removed.
    instance_type: New instance type, if set.
    expire_behavior: Free-instance expire behavior, if set.
    ssd_cache_id: SSD cache id; an empty/whitespace value clears the cache.
    edition: New instance edition, if set.
    default_backup_schedule_type: New default backup schedule type, if set.

  Returns:
    The long-running operation for the Patch call.
  """
  fields = []
  if description is not None:
    fields.append('displayName')
  # Setting a manual capacity also masks autoscalingConfig so the server
  # clears any existing autoscaling setup.
  if nodes is not None:
    fields.append('nodeCount,autoscalingConfig')
  if processing_units is not None:
    fields.append('processingUnits,autoscalingConfig')
  # Full limit pair + a CPU target + storage target = replace the whole
  # autoscalingConfig; otherwise patch only the provided sub-paths.
  if (
      (autoscaling_min_nodes and autoscaling_max_nodes)
      or (autoscaling_min_processing_units and autoscaling_max_processing_units)
  ) and ((autoscaling_high_priority_cpu_target is not None or
          autoscaling_total_cpu_target is not None)
         and autoscaling_storage_target):
    fields.append(_FIELD_MASK_AUTOSCALING_CONFIG)
  else:
    if autoscaling_min_nodes:
      fields.append('autoscalingConfig.autoscalingLimits.minNodes')
    if autoscaling_max_nodes:
      fields.append('autoscalingConfig.autoscalingLimits.maxNodes')
    if autoscaling_min_processing_units:
      fields.append('autoscalingConfig.autoscalingLimits.minProcessingUnits')
    if autoscaling_max_processing_units:
      fields.append('autoscalingConfig.autoscalingLimits.maxProcessingUnits')
    if autoscaling_high_priority_cpu_target is not None:
      fields.append(
          'autoscalingConfig.autoscalingTargets.highPriorityCpuUtilizationPercent'
      )
    if autoscaling_total_cpu_target is not None:
      fields.append(
          'autoscalingConfig.autoscalingTargets.totalCpuUtilizationPercent'
      )
    if autoscaling_storage_target:
      fields.append(
          'autoscalingConfig.autoscalingTargets.storageUtilizationPercent'
      )
    if disable_downscaling is not None:
      fields.append('autoscalingConfig.disableDownscaling')
  client = apis.GetClientInstance(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  msgs = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  instance_obj = msgs.Instance(displayName=description)
  # Only one capacity mode is populated; note processing units take priority
  # over nodes here, unlike Create.
  if processing_units:
    instance_obj.processingUnits = processing_units
  elif nodes:
    instance_obj.nodeCount = nodes
  elif (
      autoscaling_min_nodes
      or autoscaling_max_nodes
      or autoscaling_min_processing_units
      or autoscaling_max_processing_units
      or autoscaling_high_priority_cpu_target is not None
      or autoscaling_total_cpu_target is not None
      or autoscaling_storage_target
      or disable_downscaling is not None
  ):
    instance_obj.autoscalingConfig = msgs.AutoscalingConfig(
        autoscalingLimits=msgs.AutoscalingLimits(
            minNodes=autoscaling_min_nodes,
            maxNodes=autoscaling_max_nodes,
            minProcessingUnits=autoscaling_min_processing_units,
            maxProcessingUnits=autoscaling_max_processing_units,
        ),
        autoscalingTargets=msgs.AutoscalingTargets(
            highPriorityCpuUtilizationPercent=autoscaling_high_priority_cpu_target,
            totalCpuUtilizationPercent=autoscaling_total_cpu_target,
            storageUtilizationPercent=autoscaling_storage_target,
        ),
        disableDownscaling=disable_downscaling,
    )
  if asymmetric_autoscaling_options is not None:
    # The sub-path is redundant when the full autoscalingConfig mask is set.
    if _FIELD_MASK_AUTOSCALING_CONFIG not in fields:
      fields.append('autoscalingConfig.asymmetricAutoscalingOptions')
    # Merging requires the server-side state of the options.
    current_instance = Get(instance)
    asym_options_patch = []
    # Create AsymmetricAutoscalingOption objects from the flag value (key-value
    # pairs).
    for asym_option in asymmetric_autoscaling_options:
      asym_options_patch.append(
          msgs.AsymmetricAutoscalingOption(
              replicaSelection=msgs.InstanceReplicaSelection(
                  location=asym_option['location']
              ),
              overrides=MaybeGetAutoscalingOverride(msgs, asym_option),
          )
      )
    # autoscalingConfig is unset unless an autoscaling flag was also given.
    if instance_obj.autoscalingConfig is None:
      instance_obj.autoscalingConfig = msgs.AutoscalingConfig()
    PatchAsymmetricAutoscalingOptions(
        msgs, instance_obj, current_instance, asym_options_patch
    )
  if clear_asymmetric_autoscaling_options is not None:
    if _FIELD_MASK_AUTOSCALING_CONFIG not in fields:
      fields.append('autoscalingConfig.asymmetricAutoscalingOptions')
    current_instance = Get(instance)
    locations_to_remove = set(clear_asymmetric_autoscaling_options)
    if instance_obj.autoscalingConfig is None:
      instance_obj.autoscalingConfig = msgs.AutoscalingConfig()
    # Keep only the options whose location was not requested for removal.
    instance_obj.autoscalingConfig.asymmetricAutoscalingOptions = []
    for (
        asym_option
    ) in current_instance.autoscalingConfig.asymmetricAutoscalingOptions:
      if asym_option.replicaSelection.location not in locations_to_remove:
        instance_obj.autoscalingConfig.asymmetricAutoscalingOptions.append(
            asym_option
        )
  if instance_type is not None:
    fields.append('instanceType')
    instance_obj.instanceType = instance_type
  if expire_behavior is not None:
    fields.append('freeInstanceMetadata.expireBehavior')
    instance_obj.freeInstanceMetadata = msgs.FreeInstanceMetadata(
        expireBehavior=expire_behavior)
  if ssd_cache_id is not None:
    fields.append('ssdCache')
    # An empty/whitespace id clears the cache; otherwise build the full cache
    # resource name from the instance's config.
    ssd_cache = ''
    if ssd_cache_id.strip():
      instance_res = Get(instance)
      ssd_cache = instance_res.config + '/ssdCaches/' + ssd_cache_id.strip()
    instance_obj.ssdCache = ssd_cache
  if edition is not None:
    fields.append('edition')
    instance_obj.edition = msgs.Instance.EditionValueValuesEnum(edition)
  if default_backup_schedule_type is not None:
    fields.append('defaultBackupScheduleType')
    instance_obj.defaultBackupScheduleType = (
        msgs.Instance.DefaultBackupScheduleTypeValueValuesEnum(
            default_backup_schedule_type
        )
    )
  ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances')
  req = msgs.SpannerProjectsInstancesPatchRequest(
      name=ref.RelativeName(),
      updateInstanceRequest=msgs.UpdateInstanceRequest(
          fieldMask=','.join(fields), instance=instance_obj))
  return client.projects_instances.Patch(req)
def GetLocations(instance, verbose_flag):
  """Get all the replica regions for an instance.

  Args:
    instance: Id of the instance to inspect.
    verbose_flag: When true, include each replica's type alongside its
      location (keeping duplicates); otherwise return the de-duplicated set
      of locations only.

  Returns:
    A list of dicts, each with a 'location' key (plus a 'type' key in
    verbose mode).
  """
  spanner_client = apis.GetClientInstance(
      _SPANNER_API_NAME, _SPANNER_API_VERSION
  )
  messages = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  # The replica layout lives on the instance's config, so resolve that first.
  instance_res = Get(instance)
  config_res = spanner_client.projects_instanceConfigs.Get(
      messages.SpannerProjectsInstanceConfigsGetRequest(
          name=instance_res.config
      )
  )
  if verbose_flag:
    return [
        {'location': replica.location, 'type': replica.type}
        for replica in config_res.replicas
    ]
  unique_locations = {replica.location for replica in config_res.replicas}
  return [{'location': location} for location in unique_locations]
def Move(instance, target_instance_config, target_database_move_configs):
  """Moves an instance from one instance-config to another.

  Prompts the user for confirmation (cancelling the command on 'no'), then
  starts the long-running move operation and prints a `gcloud spanner
  operations describe` command for tracking its progress.

  Args:
    instance: Instance to move.
    target_instance_config: The target instance configuration to move the
      instance.
    target_database_move_configs: Configurations for databases in the
      destination instance config. The configs can be google-managed or
      user-managed. May be None when no per-database config is given.

  Ex: gcloud spanner instances move instance-to-move
    --target-config=instance-config-to-move-to
  Above example will move the instance(instance-to-move) to the following
  instance config(instance-config-to-move-to).
  """
  client = apis.GetClientInstance(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  msgs = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  config_ref = resources.REGISTRY.Parse(
      target_instance_config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs',
  )
  instance_ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances',
  )
  # Warn about service impact and require explicit confirmation before
  # issuing the move; cancel_on_no aborts the command if the user declines.
  console_io.PromptContinue(
      message=(
          'You are about to move instance {0} from {1} to {2}. This is a'
          ' long-running operation with potential service'
          ' implications:\n\n\n\t* Increased latencies: Read and write'
          ' operations may experience delays.\n\n\t* Elevated abort rate:'
          ' Transactions may have a higher chance of failing.\n\n\t* Spiked CPU'
          ' utilization: System resources will be strained, impacting'
          ' performance.\n\n\t* Additional costs: Instance moves incur extra'
          ' charges, as described in the documentation.\n\n\t* Backups: It is'
          ' important that you copy your backups before moving your instance.'
          ' Backups need to be deleted from the Instance before the move. You'
          ' cannot create a backup while the move is in progress.\n\nBefore'
          ' proceeding, and for detailed information and best practices, refer'
          ' to the documentation at'
          ' https://cloud.google.com/spanner/docs/move-instance#move-prerequisites.'
          .format(
              instance, GetInstanceConfig(instance), target_instance_config
          )
      ),
      cancel_on_no=True,
      prompt_string='Do you want to proceed',
  )
  req_args = {'targetConfig': config_ref.RelativeName()}
  if target_database_move_configs is not None:
    req_args['targetDatabaseMoveConfigs'] = []
    for target_database_move_config in target_database_move_configs:
      # Each move config is a dict-like arg; 'kms-key-names' is a
      # comma-separated list of KMS key resource names (presumably validated
      # upstream by the CLI arg parser -- a missing key would raise here).
      kms_key_names = target_database_move_config['kms-key-names'].split(',')
      encryption_config_args = {}
      encryption_config_args['kmsKeyNames'] = []
      for kms_key_name in kms_key_names:
        encryption_config_args['kmsKeyNames'].append(kms_key_name)
      encryption_config = msgs.InstanceEncryptionConfig(
          **encryption_config_args
      )
      req_args['targetDatabaseMoveConfigs'].append(
          msgs.DatabaseMoveConfig(
              databaseId=target_database_move_config['database-id'],
              encryptionConfig=encryption_config,
          )
      )
  move_req = msgs.SpannerProjectsInstancesMoveRequest(
      moveInstanceRequest=msgs.MoveInstanceRequest(**req_args),
      name=instance_ref.RelativeName(),
  )
  # The RPC returns a long-running operation; extract the bare operation ID
  # from its full resource name for the tracking-command hint below.
  move_operation_id = client.projects_instances.Move(move_req).name
  operation_id = re.search('.*/operations/(.*)', move_operation_id).group(1)
  print(
      '\nInstance move started for {0}\n\n'
      'Track progress with: gcloud spanner operations'
      ' describe {1} --instance={2}'.format(
          instance_ref.RelativeName(), operation_id, instance
      )
  )
def GetInstanceConfig(instance):
  """Get the instance config of the passed instance.

  Args:
    instance: The Cloud Spanner instance ID.

  Returns:
    The short instance-config ID (the component after 'instanceConfigs/' in
    the instance's config resource name).
  """
  client = apis.GetClientInstance(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  msgs = apis.GetMessagesModule(_SPANNER_API_NAME, _SPANNER_API_VERSION)
  ref = resources.REGISTRY.Parse(
      instance,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instances',
  )
  # Only the config field is needed, so limit the response with a field mask.
  get_req = msgs.SpannerProjectsInstancesGetRequest(
      name=ref.RelativeName(), fieldMask='config'
  )
  info = client.projects_instances.Get(get_req)
  # The API returns a full resource name; strip it down to the final ID.
  return re.search('.*/instanceConfigs/(.*)', info.config).group(1)

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper for processing API responses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.core import log
def GetFieldAndLogUnreachable(message, attribute):
  """Response callback to log unreachable while generating fields of the message."""
  unreachable = message.unreachable
  if unreachable:
    # Surface partial-results information to the user as a warning.
    log.warning(
        'The following instances were unreachable: {}.'.format(
            ', '.join(unreachable)))
  return getattr(message, attribute)
def GetFieldAndLogUnreachableInstancePartitions(message, attribute):
  """Response callback to log unreachable while generating fields of the message.

  Handles both `ListInstancePartitionsResponse` (field `unreachable`) and
  `ListInstancePartitionOperationsResponse` (field
  `unreachableInstancePartitions`).
  """
  warning_text = 'The following instance partitions were unreachable: {}.'
  unreachable = getattr(message, 'unreachable', None)
  if unreachable:
    # `message` is a `ListInstancePartitionsResponse`.
    log.warning(warning_text.format(', '.join(unreachable)))
  else:
    unreachable_partitions = getattr(
        message, 'unreachableInstancePartitions', None)
    if unreachable_partitions:
      # `message` is a `ListInstancePartitionOperationsResponse`.
      log.warning(warning_text.format(', '.join(unreachable_partitions)))
  return getattr(message, attribute)

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner SSD Cache operations API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
# API service name and version shared by every request in this module.
_API_NAME = 'spanner'
_API_VERSION = 'v1'
def List(ssd_cache, config):
  """List operations on ssdCache using the generic operation list API.

  Args:
    ssd_cache: The SSD cache ID whose operations to list.
    config: The instance config ID containing the cache.

  Returns:
    A generator over the cache's Operation messages.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  cache_ref = resources.REGISTRY.Parse(
      ssd_cache,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': config,
      },
      collection='spanner.projects.instanceConfigs.ssdCaches',
  )
  # The generic operations collection lives under the cache resource.
  list_req = msgs.SpannerProjectsInstanceConfigsSsdCachesOperationsListRequest(
      name='{}/operations'.format(cache_ref.RelativeName())
  )
  return list_pager.YieldFromList(
      client.projects_instanceConfigs_ssdCaches_operations,
      list_req,
      field='operations',
      batch_size_attribute='pageSize',
  )
def Get(operation, ssd_cache, config):
  """Gets the specified operation.

  Args:
    operation: The operation ID to fetch.
    ssd_cache: The SSD cache ID the operation belongs to.
    config: The instance config ID containing the cache.

  Returns:
    The Operation message for the given operation.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  op_ref = resources.REGISTRY.Parse(
      operation,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': config,
          'ssdCachesId': ssd_cache,
      },
      collection='spanner.projects.instanceConfigs.ssdCaches.operations',
  )
  get_req = msgs.SpannerProjectsInstanceConfigsSsdCachesOperationsGetRequest(
      name=op_ref.RelativeName()
  )
  return client.projects_instanceConfigs_ssdCaches_operations.Get(get_req)

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner SSD caches API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
# API service name and version shared by every request in this module.
_API_NAME = 'spanner'
_API_VERSION = 'v1'
def List(config):
  """List SSD caches in the instanceConfig.

  Args:
    config: The instance config ID whose SSD caches to list.

  Returns:
    A generator over the config's SsdCache messages.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  config_ref = resources.REGISTRY.Parse(
      config,
      params={'projectsId': properties.VALUES.core.project.GetOrFail},
      collection='spanner.projects.instanceConfigs',
  )
  list_req = msgs.SpannerProjectsInstanceConfigsSsdCachesListRequest(
      parent=config_ref.RelativeName()
  )
  return list_pager.YieldFromList(
      client.projects_instanceConfigs_ssdCaches,
      list_req,
      field='ssdCaches',
      batch_size_attribute='pageSize',
  )
def Get(ssd_cache, config):
  """Gets the SSD cache in the specified instance config.

  Args:
    ssd_cache: The SSD cache ID to fetch.
    config: The instance config ID containing the cache.

  Returns:
    The SsdCache message for the given cache.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  cache_ref = resources.REGISTRY.Parse(
      ssd_cache,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': config,
      },
      collection='spanner.projects.instanceConfigs.ssdCaches',
  )
  get_req = msgs.SpannerProjectsInstanceConfigsSsdCachesGetRequest(
      name=cache_ref.RelativeName()
  )
  return client.projects_instanceConfigs_ssdCaches.Get(get_req)
def Patch(args):
  """Update an SSD cache.

  Args:
    args: Parsed CLI args carrying `cache_id`, `config`, and the optional
      update fields `size_gib`, `display_name`, and labels flags.

  Returns:
    The result of the Patch RPC.

  Raises:
    errors.NoFieldsSpecifiedError: If no updatable field was provided.
  """
  client = apis.GetClientInstance(_API_NAME, _API_VERSION)
  msgs = apis.GetMessagesModule(_API_NAME, _API_VERSION)
  cache_ref = resources.REGISTRY.Parse(
      args.cache_id,
      params={
          'projectsId': properties.VALUES.core.project.GetOrFail,
          'instanceConfigsId': args.config,
      },
      collection='spanner.projects.instanceConfigs.ssdCaches',
  )
  cache_name = cache_ref.RelativeName()
  cache = msgs.SsdCache(name=cache_name)
  mask = []
  if args.size_gib is not None:
    cache.sizeGib = args.size_gib
    mask.append('size_gib')
  if args.display_name is not None:
    cache.displayName = args.display_name
    mask.append('display_name')

  def _FetchCurrentLabels():
    # Invoked lazily: the extra Get round-trip only happens when the labels
    # flags actually require the cache's current labels.
    get_req = msgs.SpannerProjectsInstanceConfigsSsdCachesGetRequest(
        name=cache_name
    )
    return client.projects_instanceConfigs_ssdCaches.Get(get_req).labels

  labels_diff = labels_util.ProcessUpdateArgsLazy(
      args, msgs.SsdCache.LabelsValue, _FetchCurrentLabels
  )
  if labels_diff.needs_update:
    cache.labels = labels_diff.labels
    mask.append('labels')
  if not mask:
    raise errors.NoFieldsSpecifiedError('No updates requested.')
  patch_req = msgs.SpannerProjectsInstanceConfigsSsdCachesPatchRequest(
      name=cache_name,
      updateSsdCacheRequest=msgs.UpdateSsdCacheRequest(
          ssdCache=cache, updateMask=','.join(mask)
      ),
  )
  return client.projects_instanceConfigs_ssdCaches.Patch(patch_req)