feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*- #
# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner cli library functions and utilities for the spanner binary."""
import copy
import os
from googlecloudsdk.command_lib.util.anthos import binary_operations
from googlecloudsdk.core import exceptions as c_except
from googlecloudsdk.core import execution_utils
# default base command is sql
_BASE_COMMAND = "sql"
_SPANNER_CLI_BINARY = "spanner-cli"
def GetEnvArgsForCommand(extra_vars=None, exclude_vars=None):
  """Return an env dict to be passed on command invocation.

  Args:
    extra_vars: dict, additional environment variables to set on top of the
      current process environment.
    exclude_vars: iterable of str, names of variables to drop from the result.

  Returns:
    A copy of os.environ with extra_vars applied and exclude_vars removed.
  """
  env = copy.deepcopy(os.environ)
  if extra_vars:
    env.update(extra_vars)
  if exclude_vars:
    for key in exclude_vars:
      # Use a default so excluding a variable that is not currently set is a
      # no-op instead of raising KeyError.
      env.pop(key, None)
  return env
class SpannerCliException(c_except.Error):
  """Base exception for errors raised by the gcloud spanner cli surface."""
class SpannerCliWrapper(binary_operations.BinaryBackedOperation):
  """Wrapper for spanner cli commands which calls the spanner binary."""

  def __init__(self, **kwargs):
    super(SpannerCliWrapper, self).__init__(
        binary=_SPANNER_CLI_BINARY,
        install_if_missing=True,
        **kwargs,
    )

  def _ParseArgsForCommand(
      self,
      project=None,
      database=None,
      instance=None,
      database_role=None,
      host=None,
      port=None,
      api_endpoint=None,
      idle_transaction_timeout=None,
      skip_column_names=False,
      skip_system_command=False,
      system_command="OFF",
      prompt=None,
      delimiter=None,
      table=False,
      html=False,
      xml=False,
      execute=None,
      source=None,
      tee=None,
      init_command=None,
      init_command_add=None,
      verbose=False,
      directed_read=None,
      proto_descriptor_file=None,
      **kwargs,
  ):
    """Map gcloud flag values onto spanner-cli command-line arguments.

    Returns:
      A tuple of binary arguments, starting with the base command.
    """
    del kwargs  # Unused.
    args = [_BASE_COMMAND]
    if project:
      args.append(f"--project={project}")
    if database:
      args.append(f"--database={database}")
    if instance:
      args.append(f"--instance={instance}")
    if database_role:
      args.append(f"--role={database_role}")
    # An explicit host:port pair takes precedence over the API endpoint.
    if port and host:
      args.append(f"--deployment-endpoint={host}:{port}")
    elif api_endpoint:
      args.append(f"--deployment-endpoint={api_endpoint}")
    if idle_transaction_timeout:
      args.append(f"--idle-transaction-timeout={idle_transaction_timeout}")
    if skip_column_names:
      args.append("--skip-column-names")
    # System commands are disabled either explicitly or via the default
    # system_command value of "OFF".
    if skip_system_command or system_command == "OFF":
      args.append("--skip-system-command")
    if prompt:
      args.append(f"--prompt={prompt}")
    if delimiter:
      args.append(f"--delimiter={delimiter}")
    if table:
      args.append("--table")
    if html:
      args.append("--html")
    if xml:
      args.append("--xml")
    if execute:
      args.append(f"--execute={execute}")
    if source:
      args.append(f"--source={source}")
    if tee:
      args.append(f"--tee={tee}")
    if init_command:
      args.append(f"--init-command={init_command}")
    if init_command_add:
      args.append(f"--init-command-add={init_command_add}")
    if verbose:
      args.append("--verbose")
    if directed_read:
      args.append(f"--directed-read={directed_read}")
    if proto_descriptor_file:
      args.append(f"--proto-descriptor-file={proto_descriptor_file}")
    return tuple(args)

  def _Execute(self, cmd, stdin=None, env=None, **kwargs):
    """Call the spanner cli binary with the given arguments.

    NOTE(review): stdin, env and kwargs are accepted for interface
    compatibility but are not forwarded to the binary — confirm this is
    intentional (the binary appears to run interactively on the terminal).
    """
    del stdin, env, kwargs  # Unused.
    execution_utils.Exec(cmd)

View File

@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides ddl preprocessing for the Spanner ddl."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import logging
from googlecloudsdk.core import exceptions
class DDLSyntaxError(exceptions.Error):
  """Raised when a DDL string cannot be parsed, e.g. an unclosed literal."""
class DDLParser:
  """Parser for splitting ddl statements preserving GoogleSQL string literals.

  DDLParser has a list of modes. If a mode is selected, control is given to
  that mode. If no mode is selected, the parser tries to enter the first mode
  it could enter. The parser itself handles splitting statements upon ';'.

  During parsing, a DDL has the following parts:
    * parts that have been processed: emitted or skipped;
    * followed by a buffer that has been matched by the current mode, which
      could be emitted or skipped by a mode. The start index of which is
      mode_start_index_;
    * followed by the next character indexed by next_index_, which could
      direct the parser to enter or exit a mode;
    * followed by the unprocessed characters.

  DDLParser:
    * acts as a default mode;
    * provides utilities used by DDLParserMode to drive the parsing.
  """

  def __init__(self, ddl):
    # The full DDL text being parsed.
    self.ddl_ = ddl
    # Index of the current character to process.
    self.next_index_ = 0
    # Mode the parser is in now; None means the default (top-level) mode.
    self.mode_ = None
    # Start index of the buffer that has been matched by a mode or the parser.
    self.mode_start_index_ = 0
    # List of modes. The first one that the parser could enter wins in case of
    # conflict (so the triple-quote modes must precede the single-quote ones).
    self.modes_ = [
        # Line comment: skipped up to (not including) the end-of-line.
        self.SkippingMode('--', ['\n', '\r']),
        # For all the string modes below, we need to escape \\. If we don't,
        # \\" will trigger mode exiting.
        # Triple double quote.
        # We need escape \", or \""" will be treated as triggering mode
        # exiting.
        self.PreservingMode('"""', ['"""'], ['\\"', '\\\\']),
        # Triple single quote.
        # We need escape \', or \''' will be treated as triggering mode
        # exiting.
        self.PreservingMode("'''", ["'''"], ["\\'", '\\\\']),
        # Single double quote.
        self.PreservingMode('"', ['"'], ['\\"', '\\\\']),
        # Single single quote.
        self.PreservingMode("'", ["'"], ["\\'", '\\\\']),
        # Single back quote.
        self.PreservingMode('`', ['`'], ['\\`', '\\\\']),
    ]
    # A list of statements. A statement is a list of ddl fragments.
    self.statements_ = []
    self.StartNewStatement()
    self.logger_ = logging.getLogger('SpannerDDLParser')

  def SkippingMode(self, enter_seq, exit_seqs):
    # A skipping mode drops its matched text (used for comments); it has no
    # escape sequences.
    return DDLParserMode(self, enter_seq, exit_seqs, None, True)

  def PreservingMode(self, enter_seq, exit_seqs, escape_sequences):
    # A preserving mode emits its matched text verbatim (used for literals).
    return DDLParserMode(self, enter_seq, exit_seqs, escape_sequences, False)

  def IsEof(self):
    # True once every character of the input has been processed.
    return self.next_index_ == len(self.ddl_)

  def Advance(self, l):
    # Consume l characters into the current buffer.
    self.next_index_ += l

  def StartNewStatement(self):
    # Open a fresh fragment list; subsequent emits go to the new statement.
    self.ddl_parts_ = []
    self.statements_.append(self.ddl_parts_)

  def EmitBuffer(self):
    """Append the pending buffer to the current statement, then clear it."""
    if self.mode_start_index_ >= self.next_index_:
      # Buffer is empty.
      return
    self.ddl_parts_.append(
        self.ddl_[self.mode_start_index_:self.next_index_])
    self.SkipBuffer()
    self.logger_.debug('emitted: %s', self.ddl_parts_[-1])

  def SkipBuffer(self):
    # Discard the pending buffer by moving its start up to the cursor.
    self.mode_start_index_ = self.next_index_

  def EnterMode(self, mode):
    self.logger_.debug('enter mode: %s at index: %d',
                       mode.enter_seq_, self.next_index_)
    self.mode_ = mode

  def ExitMode(self):
    self.logger_.debug('exit mode: %s at index: %d',
                       self.mode_.enter_seq_, self.next_index_)
    self.mode_ = None

  def StartsWith(self, s):
    # Does the unprocessed input begin with s?
    return self.ddl_[self.next_index_:].startswith(s)

  def Process(self):
    """Process the DDL.

    Returns:
      A list of statement strings, with comments stripped and string
      literals preserved. Empty statements are dropped; a trailing
      whitespace-only statement is removed.

    Raises:
      DDLSyntaxError: if a string literal is left unclosed at end of input.
    """
    while not self.IsEof():
      # Delegate to active mode if we have any.
      if self.mode_:
        self.mode_.Process()
        continue
      # Check statement break.
      if self.ddl_[self.next_index_] == ';':
        self.EmitBuffer()
        self.StartNewStatement()
        # Skip over the ';' itself: it is neither emitted nor kept buffered.
        self.mode_start_index_ += 1
        self.Advance(1)
        continue
      # If we could enter any mode.
      for m in self.modes_:
        if m.TryEnter():
          self.EnterMode(m)
          break
      # No mode is found, consume the character.
      if not self.mode_:
        self.Advance(1)
    # At the end of parsing, we close the unclosed mode. An unterminated
    # comment is tolerated (skipped); an unterminated literal is an error.
    if self.mode_ is not None:
      m = self.mode_
      if not m.is_to_skip_:
        raise DDLSyntaxError(
            'Unclosed %s start at index: %d, %s' %
            (m.enter_seq_, self.mode_start_index_, self.ddl_))
      self.mode_.Exit()
    else:
      self.EmitBuffer()
    self.logger_.debug('ddls: %s', self.statements_)
    res = [''.join(frags) for frags in self.statements_ if frags]
    # See https://stackoverflow.com/q/67857941
    if res and res[-1].isspace():
      return res[:-1]
    return res
class DDLParserMode:
  """A mode in DDLParser.

  A mode has one entering sequence, a list of exit sequences and an optional
  list of escape sequences. A mode could be:
    * skipping (e.g. comments), which skips the matched text;
    * non-skipping (e.g. string literals), which emits the matched text.
  """

  def __init__(self, parser, enter_seq, exit_seqs, escape_sequences,
               is_to_skip):
    # The owning DDLParser; the mode drives it via Advance/EmitBuffer/etc.
    self.parser_ = parser
    self.enter_seq_ = enter_seq
    self.exit_seqs_ = exit_seqs
    # Optional list of sequences (e.g. ['\\"', '\\\\']) that must be consumed
    # atomically so their second character cannot be mistaken for an exit.
    self.escape_sequences_ = escape_sequences
    self.is_to_skip_ = is_to_skip

  def TryEnter(self):
    """Tries to enter into the mode; on success the pending buffer is emitted."""
    res = self.parser_.StartsWith(self.enter_seq_)
    if res:
      self.parser_.EmitBuffer()
      self.parser_.Advance(len(self.enter_seq_))
    return res

  def Exit(self):
    # Flush or drop the buffered text depending on the mode kind, then hand
    # control back to the parser.
    if self.is_to_skip_:
      self.parser_.SkipBuffer()
    else:
      self.parser_.EmitBuffer()
    self.parser_.ExitMode()

  def FindExitSeqence(self):
    """Finds a matching exit sequence, or None if the mode stays active."""
    for s in self.exit_seqs_:
      if self.parser_.StartsWith(s):
        return s
    return None

  def Process(self):
    """Process the ddl at the current parser index."""
    # Put escape sequence into buffer.
    if self.escape_sequences_:
      for seq in self.escape_sequences_:
        if self.parser_.StartsWith(seq):
          # Advance past the matched escape sequence itself. Bug fix: this
          # previously advanced by len(self.escape_sequences_) — the NUMBER of
          # escape sequences — which only worked by coincidence while every
          # mode had exactly two 2-character escape sequences.
          self.parser_.Advance(len(seq))
          return
    # Check if we should exit the current mode.
    exit_seq = self.FindExitSeqence()
    if not exit_seq:
      self.parser_.Advance(1)
      return
    # Before exit, put exit_seq into buffer for non skipping mode. Skipping
    # modes leave the exit sequence (e.g. the newline ending a comment) to be
    # processed by the parser's default mode.
    if not self.is_to_skip_:
      self.parser_.Advance(len(exit_seq))
    self.Exit()
def PreprocessDDLWithParser(ddl_text):
  """Split ddl_text into individual statements, preserving string literals."""
  parser = DDLParser(ddl_text)
  return parser.Process()

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides helper methods for dealing with JSON files for Spanner IAM."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.spanner import databases
from googlecloudsdk.api_lib.spanner import instances
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.iam import iam_util
def AddInstanceIamPolicyBinding(instance_ref, member, role):
  """Adds a policy binding to an instance IAM policy."""
  messages = apis.GetMessagesModule('spanner', 'v1')
  current_policy = instances.GetIamPolicy(instance_ref)
  iam_util.AddBindingToIamPolicy(
      messages.Binding, current_policy, member, role)
  return instances.SetPolicy(instance_ref, current_policy)
def SetInstanceIamPolicy(instance_ref, policy):
  """Sets the IAM policy on an instance from a policy file."""
  messages = apis.GetMessagesModule('spanner', 'v1')
  parsed_policy, update_mask = iam_util.ParsePolicyFileWithUpdateMask(
      policy, messages.Policy)
  return instances.SetPolicy(instance_ref, parsed_policy, update_mask)
def RemoveInstanceIamPolicyBinding(instance_ref, member, role):
  """Removes a policy binding from an instance IAM policy."""
  current_policy = instances.GetIamPolicy(instance_ref)
  iam_util.RemoveBindingFromIamPolicy(current_policy, member, role)
  return instances.SetPolicy(instance_ref, current_policy)
def AddDatabaseIamPolicyBinding(database_ref, member, role):
  """Adds a policy binding to a database IAM policy."""
  messages = apis.GetMessagesModule('spanner', 'v1')
  current_policy = databases.GetIamPolicy(database_ref)
  iam_util.AddBindingToIamPolicy(
      messages.Binding, current_policy, member, role)
  return databases.SetPolicy(database_ref, current_policy)
def SetDatabaseIamPolicy(database_ref, policy):
  """Sets the IAM policy on a database from a policy file."""
  messages = apis.GetMessagesModule('spanner', 'v1')
  parsed_policy = iam_util.ParsePolicyFile(policy, messages.Policy)
  return databases.SetPolicy(database_ref, parsed_policy)
def RemoveDatabaseIamPolicyBinding(database_ref, member, role):
  """Removes a policy binding from a database IAM policy."""
  current_policy = databases.GetIamPolicy(database_ref)
  iam_util.RemoveBindingFromIamPolicy(current_policy, member, role)
  return databases.SetPolicy(database_ref, current_policy)

View File

@@ -0,0 +1,266 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner migration library functions and utilities for the spanner-migration-tool binary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import copy
import os
from googlecloudsdk.command_lib.util.anthos import binary_operations
from googlecloudsdk.core import exceptions as c_except
def GetEnvArgsForCommand(extra_vars=None, exclude_vars=None):
  """Return an env dict to be passed on command invocation.

  Args:
    extra_vars: dict, additional environment variables to set on top of the
      current process environment.
    exclude_vars: iterable of str, names of variables to drop from the result.

  Returns:
    A copy of os.environ with extra_vars applied and exclude_vars removed.
  """
  env = copy.deepcopy(os.environ)
  if extra_vars:
    env.update(extra_vars)
  if exclude_vars:
    for k in exclude_vars:
      # Use a default so excluding a variable that is not currently set is a
      # no-op instead of raising KeyError.
      env.pop(k, None)
  return env
class SpannerMigrationException(c_except.Error):
  """Base exception for errors raised by the gcloud spanner migration surface."""
class SpannerMigrationWrapper(binary_operations.StreamingBinaryBackedOperation):
  """Binary operation wrapper for spanner-migration-tool commands."""

  def __init__(self, **kwargs):
    super(SpannerMigrationWrapper, self).__init__(
        binary='spanner-migration-tool', install_if_missing=True, **kwargs)

  def _ParseSchemaArgs(self,
                       source,
                       prefix=None,
                       source_profile=None,
                       target=None,
                       target_profile=None,
                       dry_run=False,
                       log_level=None,
                       project=None,
                       **kwargs):
    """Parse args for the schema command."""
    del kwargs  # Unused.
    exec_args = ['schema']
    if source:
      exec_args.extend(['--source', source])
    if prefix:
      exec_args.extend(['--prefix', prefix])
    if source_profile:
      exec_args.extend(['--source-profile', source_profile])
    if target:
      exec_args.extend(['--target', target])
    if target_profile:
      exec_args.extend(['--target-profile', target_profile])
    if dry_run:
      exec_args.append('--dry-run')
    if log_level:
      exec_args.extend(['--log-level', log_level])
    if project:
      exec_args.extend(['--project', project])
    return exec_args

  def _ParseDataArgs(self,
                     source,
                     session,
                     prefix=None,
                     skip_foreign_keys=False,
                     source_profile=None,
                     target=None,
                     target_profile=None,
                     write_limit=None,
                     dry_run=False,
                     log_level=None,
                     project=None,
                     dataflow_template=None,
                     **kwargs):
    """Parse args for the data command."""
    del kwargs  # Unused.
    exec_args = ['data']
    if source:
      exec_args.extend(['--source', source])
    if session:
      exec_args.extend(['--session', session])
    if prefix:
      exec_args.extend(['--prefix', prefix])
    if skip_foreign_keys:
      exec_args.append('--skip-foreign-keys')
    if source_profile:
      exec_args.extend(['--source-profile', source_profile])
    if target:
      exec_args.extend(['--target', target])
    if target_profile:
      exec_args.extend(['--target-profile', target_profile])
    if write_limit:
      exec_args.extend(['--write-limit', write_limit])
    if dry_run:
      exec_args.append('--dry-run')
    if log_level:
      exec_args.extend(['--log-level', log_level])
    if project:
      exec_args.extend(['--project', project])
    if dataflow_template:
      exec_args.extend(['--dataflow-template', dataflow_template])
    return exec_args

  def _ParseSchemaAndDataArgs(self,
                              source,
                              prefix=None,
                              skip_foreign_keys=False,
                              source_profile=None,
                              target=None,
                              target_profile=None,
                              write_limit=None,
                              dry_run=False,
                              log_level=None,
                              project=None,
                              dataflow_template=None,
                              **kwargs):
    """Parse args for the schema-and-data command."""
    del kwargs  # Unused.
    exec_args = ['schema-and-data']
    if source:
      exec_args.extend(['--source', source])
    if prefix:
      exec_args.extend(['--prefix', prefix])
    if skip_foreign_keys:
      exec_args.append('--skip-foreign-keys')
    if source_profile:
      exec_args.extend(['--source-profile', source_profile])
    if target:
      exec_args.extend(['--target', target])
    if target_profile:
      exec_args.extend(['--target-profile', target_profile])
    if write_limit:
      exec_args.extend(['--write-limit', write_limit])
    if dry_run:
      exec_args.append('--dry-run')
    if log_level:
      exec_args.extend(['--log-level', log_level])
    if project:
      exec_args.extend(['--project', project])
    if dataflow_template:
      exec_args.extend(['--dataflow-template', dataflow_template])
    return exec_args

  def _ParseWebArgs(self,
                    open_flag=False,
                    port=None,
                    log_level=None,
                    dataflow_template=None,
                    **kwargs):
    """Parse args for the web command."""
    del kwargs  # Unused.
    exec_args = ['web']
    if open_flag:
      exec_args.append('--open')
    if port:
      exec_args.extend(['--port', port])
    if log_level:
      exec_args.extend(['--log-level', log_level])
    if dataflow_template:
      exec_args.extend(['--dataflow-template', dataflow_template])
    return exec_args

  def ParseCleanupArgs(self,
                       job_id,
                       data_shard_ids=None,
                       target_profile=None,
                       datastream=False,
                       dataflow=False,
                       pub_sub=False,
                       monitoring=False,
                       log_level=None,
                       **kwargs):
    """Parse args for the cleanup command."""
    del kwargs  # Unused.
    exec_args = ['cleanup']
    if job_id:
      exec_args.extend(['--jobId', job_id])
    if data_shard_ids:
      exec_args.extend(['--dataShardIds', data_shard_ids])
    if target_profile:
      exec_args.extend(['--target-profile', target_profile])
    if datastream:
      exec_args.append('--datastream')
    if dataflow:
      exec_args.append('--dataflow')
    if pub_sub:
      exec_args.append('--pubsub')
    if monitoring:
      exec_args.append('--monitoring')
    if log_level:
      # Bug fix: previously only the '--log-level' flag name was appended,
      # dropping the log_level value. Now matches the other parsers.
      exec_args.extend(['--log-level', log_level])
    return exec_args

  def ParseImportArgs(self,
                      instance,
                      database,
                      source_uri,
                      source_format,
                      table_name=None,
                      project=None,
                      schema_uri=None,
                      csv_line_delimiter=None,
                      csv_field_delimiter=None,
                      database_dialect=None,
                      **kwargs):
    """Parse args for the import command."""
    del kwargs  # Unused.
    exec_args = ['import']
    if instance:
      exec_args.extend(['--instance', instance])
    if database:
      exec_args.extend(['--database', database])
    if table_name:
      exec_args.extend(['--table-name', table_name])
    if source_uri:
      exec_args.extend(['--source-uri', source_uri])
    if source_format:
      exec_args.extend(['--source-format', source_format])
    if schema_uri:
      exec_args.extend(['--schema-uri', schema_uri])
    if csv_line_delimiter:
      exec_args.extend(['--csv-line-delimiter', csv_line_delimiter])
    if csv_field_delimiter:
      exec_args.extend(['--csv-field-delimiter', csv_field_delimiter])
    if project:
      exec_args.extend(['--project', project])
    if database_dialect:
      exec_args.extend(['--database-dialect', database_dialect])
    return exec_args

  def _ParseArgsForCommand(self, command, **kwargs):
    """Call the parser corresponding to the command.

    Args:
      command: str, one of the supported spanner-migration-tool subcommands.
      **kwargs: forwarded to the subcommand-specific parser.

    Returns:
      The list of binary arguments for the subcommand.

    Raises:
      binary_operations.InvalidOperationForBinary: for unknown commands.
    """
    if command == 'schema':
      return self._ParseSchemaArgs(**kwargs)
    elif command == 'data':
      return self._ParseDataArgs(**kwargs)
    elif command == 'schema-and-data':
      return self._ParseSchemaAndDataArgs(**kwargs)
    elif command == 'web':
      return self._ParseWebArgs(**kwargs)
    elif command == 'cleanup':
      return self.ParseCleanupArgs(**kwargs)
    elif command == 'import':
      return self.ParseImportArgs(**kwargs)
    else:
      raise binary_operations.InvalidOperationForBinary(
          'Invalid Operation [{}] for spanner-migration-tool'.format(command))

View File

@@ -0,0 +1,657 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared resource flags for Cloud Spanner commands."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.calliope.concepts import concepts
from googlecloudsdk.calliope.concepts import deps
from googlecloudsdk.command_lib.util.apis import arg_utils
from googlecloudsdk.command_lib.util.concepts import concept_parsers
from googlecloudsdk.command_lib.util.concepts import presentation_specs
from googlecloudsdk.core import properties
# Property fallthroughs used as defaults by the resource attribute configs
# below (e.g. `gcloud config set spanner/instance ...`).
_PROJECT = properties.VALUES.core.project
_INSTANCE = properties.VALUES.spanner.instance
# Maps the --encryption-type choice values onto the backup-create request enum.
_CREATE_BACKUP_ENCRYPTION_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--encryption-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).SpannerProjectsInstancesBackupsCreateRequest.EncryptionConfigEncryptionTypeValueValuesEnum,
    help_str='The encryption type of the backup.',
    required=False,
    custom_mappings={
        'USE_DATABASE_ENCRYPTION': (
            'use-database-encryption',
            'Use the same encryption configuration as the database.',
        ),
        'GOOGLE_DEFAULT_ENCRYPTION': (
            'google-default-encryption',
            'Use Google default encryption.',
        ),
        'CUSTOMER_MANAGED_ENCRYPTION': (
            'customer-managed-encryption',
            # Fixed help text: the old concatenation rendered as
            # "...encryption.If this option is selected..." (missing space).
            'Use the provided Cloud KMS key for encryption. If this option'
            ' is selected, kms-key must be set.',
        ),
    },
)
# Maps the --encryption-type choice values onto the CreateBackupEncryptionConfig
# enum (used by the backup schedule / copy-style create message).
_CREATE_BACKUP_ENCRYPTION_CONFIG_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--encryption-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).CreateBackupEncryptionConfig.EncryptionTypeValueValuesEnum,
    help_str='The encryption type of the backup.',
    required=False,
    custom_mappings={
        'USE_DATABASE_ENCRYPTION': (
            'use-database-encryption',
            'Use the same encryption configuration as the database.',
        ),
        'GOOGLE_DEFAULT_ENCRYPTION': (
            'google-default-encryption',
            'Use Google default encryption.',
        ),
        'CUSTOMER_MANAGED_ENCRYPTION': (
            'customer-managed-encryption',
            (
                'Use the provided Cloud KMS key for encryption. If this option'
                ' is selected, kms-key must be set.'
            ),
        ),
    },
)
# Maps the --encryption-type choice values onto the backup-copy enum.
_COPY_BACKUP_ENCRYPTION_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--encryption-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).CopyBackupEncryptionConfig.EncryptionTypeValueValuesEnum,
    help_str='The encryption type of the copied backup.',
    required=False,
    custom_mappings={
        'USE_CONFIG_DEFAULT_OR_BACKUP_ENCRYPTION': (
            'use-config-default-or-backup-encryption',
            (
                # Fixed punctuation: "exists. otherwise" -> "exists,
                # otherwise", consistent with the restore-database mapper.
                'Use the default encryption configuration if one exists,'
                ' otherwise use the same encryption configuration as the source'
                ' backup.'
            ),
        ),
        'GOOGLE_DEFAULT_ENCRYPTION': (
            'google-default-encryption',
            'Use Google default encryption.',
        ),
        'CUSTOMER_MANAGED_ENCRYPTION': (
            'customer-managed-encryption',
            (
                'Use the provided Cloud KMS key for encryption. If this option'
                ' is selected, kms-key must be set.'
            ),
        ),
    },
)
# Maps the --encryption-type choice values onto the restore-database enum.
_RESTORE_DB_ENCRYPTION_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--encryption-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).RestoreDatabaseEncryptionConfig.EncryptionTypeValueValuesEnum,
    help_str='The encryption type of the restored database.',
    required=False,
    custom_mappings={
        'USE_CONFIG_DEFAULT_OR_BACKUP_ENCRYPTION': (
            'use-config-default-or-backup-encryption',
            (
                'Use the default encryption configuration if one exists, '
                'otherwise use the same encryption configuration as the backup.'
            ),
        ),
        'GOOGLE_DEFAULT_ENCRYPTION': (
            'google-default-encryption',
            'Use Google default encryption.',
        ),
        'CUSTOMER_MANAGED_ENCRYPTION': (
            'customer-managed-encryption',
            (
                'Use the provided Cloud KMS key for encryption. If this option'
                ' is selected, kms-key must be set.'
            ),
        ),
    },
)
# Maps the --instance-type choice values onto the Instance enum
# (provisioned vs. free trial).
_INSTANCE_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--instance-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).Instance.InstanceTypeValueValuesEnum,
    help_str='Specifies the type for this instance.',
    required=False,
    custom_mappings={
        'PROVISIONED': (
            'provisioned',
            (
                'Provisioned instances have dedicated resources, standard usage'
                ' limits, and support.'
            ),
        ),
        'FREE_INSTANCE': (
            'free-instance',
            (
                'Free trial instances provide no guarantees for dedicated '
                'resources, both node_count and processing_units should be 0. '
                'They come with stricter usage limits and limited support.'
            ),
        ),
    },
)
# Maps the --default-storage-type choice values onto the Instance enum.
# Hidden: not exposed in help output.
_DEFAULT_STORAGE_TYPE_MAPPER = arg_utils.ChoiceEnumMapper(
    '--default-storage-type',
    apis.GetMessagesModule(
        'spanner', 'v1'
    ).Instance.DefaultStorageTypeValueValuesEnum,
    help_str='Specifies the default storage type for this instance.',
    required=False,
    hidden=True,
    custom_mappings={
        'SSD': ('ssd', 'Use ssd as default storage type for this instance'),
        'HDD': ('hdd', 'Use hdd as default storage type for this instance'),
    },
)
# Maps the --expire-behavior choice values onto the FreeInstanceMetadata enum
# (what happens when a free trial instance expires).
_EXPIRE_BEHAVIOR_MAPPER = arg_utils.ChoiceEnumMapper(
    '--expire-behavior',
    apis.GetMessagesModule(
        'spanner', 'v1').FreeInstanceMetadata.ExpireBehaviorValueValuesEnum,
    help_str='The expire behavior of a free trial instance.',
    required=False,
    custom_mappings={
        'FREE_TO_PROVISIONED':
            ('free-to-provisioned',
             ('When the free trial instance expires, upgrade the instance to a '
              'provisioned instance.')),
        'REMOVE_AFTER_GRACE_PERIOD':
            ('remove-after-grace-period',
             ('When the free trial instance expires, disable the instance, '
              'and delete it after the grace period passes if it has not been '
              'upgraded to a provisioned instance.')),
    })
def InstanceAttributeConfig():
  """Get instance resource attribute with default value."""
  # Falls through to the spanner/instance property when the flag is omitted.
  return concepts.ResourceParameterAttributeConfig(
      name='instance',
      help_text='The Cloud Spanner instance for the {resource}.',
      fallthroughs=[deps.PropertyFallthrough(_INSTANCE)])


def InstancePartitionAttributeConfig():
  """Get instance partition resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='instance partition',
      help_text='The Spanner instance partition for the {resource}.',
  )


def DatabaseAttributeConfig():
  """Get database resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='database',
      help_text='The Cloud Spanner database for the {resource}.')


def BackupAttributeConfig():
  """Get backup resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='backup',
      help_text='The Cloud Spanner backup for the {resource}.')


def BackupScheduleAttributeConfig():
  """Get backup schedule resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='backup-schedule',
      help_text='The Cloud Spanner backup schedule for the {resource}.')


def SessionAttributeConfig():
  """Get session resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='session', help_text='The Cloud Spanner session for the {resource}.')


def KmsKeyAttributeConfig():
  """Get KMS key resource attribute."""
  # For anchor attribute, help text is generated automatically.
  return concepts.ResourceParameterAttributeConfig(name='kms-key')


def KmsKeyringAttributeConfig():
  """Get KMS keyring resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='kms-keyring', help_text='KMS keyring id of the {resource}.')


def KmsLocationAttributeConfig():
  """Get KMS location resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='kms-location', help_text='Cloud location for the {resource}.')


def KmsProjectAttributeConfig():
  """Get KMS project resource attribute."""
  return concepts.ResourceParameterAttributeConfig(
      name='kms-project', help_text='Cloud project id for the {resource}.')
def GetInstanceResourceSpec():
  """Resource spec for a Cloud Spanner instance."""
  return concepts.ResourceSpec(
      'spanner.projects.instances',
      resource_name='instance',
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetInstancePartitionResourceSpec():
  """Resource spec for a Spanner instance partition."""
  return concepts.ResourceSpec(
      'spanner.projects.instances.instancePartitions',
      resource_name='instance partition',
      instancePartitionsId=InstancePartitionAttributeConfig(),
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG,
  )


def GetDatabaseResourceSpec():
  """Resource spec for a Cloud Spanner database."""
  return concepts.ResourceSpec(
      'spanner.projects.instances.databases',
      resource_name='database',
      databasesId=DatabaseAttributeConfig(),
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetKmsKeyResourceSpec():
  """Resource spec for a Cloud KMS crypto key."""
  return concepts.ResourceSpec(
      'cloudkms.projects.locations.keyRings.cryptoKeys',
      resource_name='key',
      cryptoKeysId=KmsKeyAttributeConfig(),
      keyRingsId=KmsKeyringAttributeConfig(),
      locationsId=KmsLocationAttributeConfig(),
      projectsId=KmsProjectAttributeConfig())


def GetBackupResourceSpec():
  """Resource spec for a Cloud Spanner backup."""
  return concepts.ResourceSpec(
      'spanner.projects.instances.backups',
      resource_name='backup',
      backupsId=BackupAttributeConfig(),
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetBackupScheduleResourceSpec():
  """Resource spec for a Cloud Spanner backup schedule."""
  return concepts.ResourceSpec(
      'spanner.projects.instances.databases.backupSchedules',
      resource_name='backup-schedule',
      backupSchedulesId=BackupScheduleAttributeConfig(),
      databasesId=DatabaseAttributeConfig(),
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)


def GetSessionResourceSpec():
  """Resource spec for a Cloud Spanner session."""
  return concepts.ResourceSpec(
      'spanner.projects.instances.databases.sessions',
      resource_name='session',
      sessionsId=SessionAttributeConfig(),
      databasesId=DatabaseAttributeConfig(),
      instancesId=InstanceAttributeConfig(),
      projectsId=concepts.DEFAULT_PROJECT_ATTRIBUTE_CONFIG)
def AddInstanceResourceArg(parser, verb, positional=True):
"""Add a resource argument for a Cloud Spanner instance.
NOTE: Must be used only if it's the only resource arg in the command.
Args:
parser: the argparse parser for the command.
verb: str, the verb to describe the resource, such as 'to update'.
positional: bool, if True, means that the instance ID is a positional rather
than a flag.
"""
name = 'instance' if positional else '--instance'
concept_parsers.ConceptParser.ForResource(
name,
GetInstanceResourceSpec(),
'The Cloud Spanner instance {}.'.format(verb),
required=True).AddToParser(parser)
def AddInstancePartitionResourceArg(parser, verb, positional=True):
  """Add a resource argument for a Spanner instance partition.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the argparse parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    positional: bool, if True, means that the instance partition ID is a
      positional rather than a flag.
  """
  if positional:
    arg_name = 'instance_partition'
  else:
    arg_name = '--instance-partition'
  resource_arg = concept_parsers.ConceptParser.ForResource(
      arg_name,
      GetInstancePartitionResourceSpec(),
      'The Spanner instance partition {}.'.format(verb),
      required=True,
  )
  resource_arg.AddToParser(parser)
def AddDatabaseResourceArg(parser, verb, positional=True):
  """Adds a resource argument for a Cloud Spanner database.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the argparse parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    positional: bool, if True, means that the database ID is a positional
      rather than a flag.
  """
  if positional:
    arg_name = 'database'
  else:
    arg_name = '--database'
  help_text = 'The Cloud Spanner database {}.'.format(verb)
  resource_arg = concept_parsers.ConceptParser.ForResource(
      arg_name, GetDatabaseResourceSpec(), help_text, required=True)
  resource_arg.AddToParser(parser)
def AddKmsKeyResourceArg(parser, verb, positional=False):
  """Add a resource argument for a KMS Key used to create a CMEK database.

  Registers two mutually exclusive presentation specs: a singular key
  (`--kms-key`, optionally assembled from --kms-project/--kms-location/
  --kms-keyring pieces) and a plural key list (`--kms-keys`).

  Args:
    parser: argparser, the parser for the command.
    verb: str, the verb used to describe the resource, such as 'to create'.
    positional: bool, optional. True if the resource arg is positional rather
      than a flag.
  """
  kms_key_name = 'kms-key' if positional else '--kms-key'
  kms_key_names = 'kms-keys' if positional else '--kms-keys'
  # mutex=True: the user may give either the single key or the key list.
  group = parser.add_group('KMS key name group', mutex=True)
  concept_parsers.ConceptParser([
      presentation_specs.ResourcePresentationSpec(
          kms_key_name,
          GetKmsKeyResourceSpec(),
          'Cloud KMS key to be used {}.'.format(verb),
          required=False,
          group=group,
      ),
      presentation_specs.ResourcePresentationSpec(
          kms_key_names,
          GetKmsKeyResourceSpec(),
          'Cloud KMS key(s) to be used {}.'.format(verb),
          required=False,
          prefixes=True,
          plural=True,
          group=group,
          # Empty overrides suppress the per-piece flags for the plural spec;
          # presumably --kms-keys must be fully qualified names — confirm
          # against concept_parsers flag_name_overrides semantics.
          flag_name_overrides={
              'kms-location': '',
              'kms-keyring': '',
              'kms-project': '',
          },
      ),
  ]).AddToParser(parser)
def AddSessionResourceArg(parser, verb, positional=True):
  """Adds a resource argument for a Cloud Spanner session.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    positional: bool, if True, means that the session ID is a positional
      rather than a flag.
  """
  if positional:
    arg_name = 'session'
  else:
    arg_name = '--session'
  help_text = 'The Cloud Spanner session {}.'.format(verb)
  resource_arg = concept_parsers.ConceptParser.ForResource(
      arg_name, GetSessionResourceSpec(), help_text, required=True)
  resource_arg.AddToParser(parser)
def AddBackupResourceArg(parser, verb, positional=True):
  """Adds a resource argument for a Cloud Spanner backup.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the argparse parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    positional: bool, if True, means that the backup ID is a positional
      rather than a flag.
  """
  if positional:
    arg_name = 'backup'
  else:
    arg_name = '--backup'
  help_text = 'The Cloud Spanner backup {}.'.format(verb)
  resource_arg = concept_parsers.ConceptParser.ForResource(
      arg_name, GetBackupResourceSpec(), help_text, required=True)
  resource_arg.AddToParser(parser)
def AddBackupScheduleResourceArg(parser, verb, positional=True):
  """Adds a resource argument for a Cloud Spanner backup schedule.

  NOTE: Must be used only if it's the only resource arg in the command.

  Args:
    parser: the argparse parser for the command.
    verb: str, the verb to describe the resource, such as 'to update'.
    positional: bool, if True, means that the backup schedules ID is a
      positional rather than a flag.
  """
  if positional:
    arg_name = 'backup_schedule'
  else:
    arg_name = '--backup-schedule'
  help_text = 'The Cloud Spanner backup schedule {}.'.format(verb)
  resource_arg = concept_parsers.ConceptParser.ForResource(
      arg_name, GetBackupScheduleResourceSpec(), help_text, required=True)
  resource_arg.AddToParser(parser)
def AddCreateBackupEncryptionTypeArg(parser):
  # Registers the backup-creation --encryption-type choice flag.
  return _CREATE_BACKUP_ENCRYPTION_TYPE_MAPPER.choice_arg.AddToParser(parser)
def GetCreateBackupEncryptionType(args):
  # Maps the parsed --encryption-type choice string to its API enum value.
  return _CREATE_BACKUP_ENCRYPTION_TYPE_MAPPER.GetEnumForChoice(
      args.encryption_type)
def AddCreateBackupEncryptionConfigTypeArg(parser):
  # Registers the backup encryption-config --encryption-type choice flag.
  return _CREATE_BACKUP_ENCRYPTION_CONFIG_TYPE_MAPPER.choice_arg.AddToParser(
      parser
  )
def GetCreateBackupEncryptionConfigType(args):
  # Maps the parsed encryption-config choice string to its API enum value.
  return _CREATE_BACKUP_ENCRYPTION_CONFIG_TYPE_MAPPER.GetEnumForChoice(
      args.encryption_type
  )
def AddCopyBackupResourceArgs(parser):
  """Add backup resource args (source, destination) for copy command."""
  source_spec = presentation_specs.ResourcePresentationSpec(
      '--source',
      GetBackupResourceSpec(),
      'TEXT',
      required=True,
      flag_name_overrides={
          'instance': '--source-instance',
          'backup': '--source-backup'
      })
  destination_spec = presentation_specs.ResourcePresentationSpec(
      '--destination',
      GetBackupResourceSpec(),
      'TEXT',
      required=True,
      flag_name_overrides={
          'instance': '--destination-instance',
          'backup': '--destination-backup',
      })
  concept_parsers.ConceptParser(
      [source_spec, destination_spec]).AddToParser(parser)
def AddCopyBackupEncryptionTypeArg(parser):
  # Registers the backup-copy --encryption-type choice flag.
  return _COPY_BACKUP_ENCRYPTION_TYPE_MAPPER.choice_arg.AddToParser(parser)
def GetCopyBackupEncryptionType(args):
  # Maps the parsed --encryption-type choice string to its API enum value.
  return _COPY_BACKUP_ENCRYPTION_TYPE_MAPPER.GetEnumForChoice(
      args.encryption_type)
def AddRestoreResourceArgs(parser):
  """Add backup resource args (source, destination) for restore command."""
  source_spec = presentation_specs.ResourcePresentationSpec(
      '--source',
      GetBackupResourceSpec(),
      'TEXT',
      required=True,
      flag_name_overrides={
          'instance': '--source-instance',
          'backup': '--source-backup'
      })
  destination_spec = presentation_specs.ResourcePresentationSpec(
      '--destination',
      GetDatabaseResourceSpec(),
      'TEXT',
      required=True,
      flag_name_overrides={
          'instance': '--destination-instance',
          'database': '--destination-database',
      })
  concept_parsers.ConceptParser(
      [source_spec, destination_spec]).AddToParser(parser)
def AddRestoreDbEncryptionTypeArg(parser):
  # Registers the restore-database --encryption-type choice flag.
  return _RESTORE_DB_ENCRYPTION_TYPE_MAPPER.choice_arg.AddToParser(parser)
def GetRestoreDbEncryptionType(args):
  # Maps the parsed --encryption-type choice string to its API enum value.
  return _RESTORE_DB_ENCRYPTION_TYPE_MAPPER.GetEnumForChoice(
      args.encryption_type)
class CloudKmsKeyName:
  """Encapsulates the `kmsKeyName` and `kmsKeyNames` CMEK fields.

  A single `kmsKeyName` and a repeated `kmsKeyNames` field are extracted from
  user input and later used in `EncryptionConfig` to pass to the Spanner
  backend.
  """

  def __init__(self, kms_key_name=None, kms_key_names=None):
    self.kms_key_name = kms_key_name
    # Keep the caller's list object (no copy) to preserve aliasing behavior.
    self.kms_key_names = [] if kms_key_names is None else kms_key_names
def GetAndValidateKmsKeyName(args) -> CloudKmsKeyName:
  """Parse the KMS key resource arg, make sure the key format is correct.

  Args:
    args: calliope framework gcloud args

  Returns:
    CloudKmsKeyName: if CMEK.
    None: if non-CMEK.

  Raises:
    exceptions.InvalidArgumentException: if KMS flags were given but could not
      be parsed into a fully qualified key name.
  """
  kms_key_name = args.CONCEPTS.kms_key.Parse()
  kms_key_names = args.CONCEPTS.kms_keys.Parse()
  cloud_kms_key_name = CloudKmsKeyName()
  if kms_key_name:
    # Singular --kms-key (or its piece flags) wins when both parse.
    cloud_kms_key_name.kms_key_name = kms_key_name.RelativeName()
  elif kms_key_names:
    cloud_kms_key_name.kms_key_names = [
        kms_key_name.RelativeName() for kms_key_name in kms_key_names
    ]
  else:
    # If parsing failed but args were specified, raise error
    for keyword in [
        'kms-key',
        'kms-keyring',
        'kms-location',
        'kms-project',
        'kms-keys',
    ]:
      # argparse stores flag values under underscore-separated attributes.
      if getattr(args, keyword.replace('-', '_'), None):
        raise exceptions.InvalidArgumentException(
            '--kms-project --kms-location --kms-keyring --kms-key or'
            ' --kms-keys',
            'For a single KMS key, specify fully qualified KMS key ID with'
            ' --kms-key, or use combination of --kms-project, --kms-location,'
            ' --kms-keyring and '
            + '--kms-key to specify the key ID in pieces. Or specify fully'
            ' qualified KMS key ID with --kms-keys.',
        )
    return None  # User didn't specify KMS key
  return cloud_kms_key_name
def AddInstanceTypeArg(parser):
  # Registers the --instance-type choice flag.
  return _INSTANCE_TYPE_MAPPER.choice_arg.AddToParser(parser)
def GetInstanceType(args):
  # Maps the parsed --instance-type choice string to its API enum value.
  return _INSTANCE_TYPE_MAPPER.GetEnumForChoice(args.instance_type)
def AddDefaultStorageTypeArg(parser):
  # Registers the --default-storage-type choice flag.
  return _DEFAULT_STORAGE_TYPE_MAPPER.choice_arg.AddToParser(parser)
def GetDefaultStorageTypeArg(args):
  # Maps the parsed --default-storage-type choice to its API enum value.
  return _DEFAULT_STORAGE_TYPE_MAPPER.GetEnumForChoice(
      args.default_storage_type
  )
def AddExpireBehaviorArg(parser):
  # Registers the --expire-behavior choice flag.
  return _EXPIRE_BEHAVIOR_MAPPER.choice_arg.AddToParser(parser)
def GetExpireBehavior(args):
  # Maps the parsed --expire-behavior choice string to its API enum value.
  return _EXPIRE_BEHAVIOR_MAPPER.GetEnumForChoice(args.expire_behavior)

View File

@@ -0,0 +1,97 @@
project:
name: project
collection: spanner.projects
attributes:
- &project
parameter_name: projectsId
attribute_name: project
help: The project name.
property: core/project
instance:
name: instance
collection: spanner.projects.instances
attributes:
- *project
- &instance
parameter_name: instancesId
attribute_name: instance
help: |
The name of the Cloud Spanner instance.
property: spanner/instance
disable_auto_completers: false
instancePartition:
name: instancePartition
collection: spanner.projects.instances.instancePartitions
attributes:
- *project
- *instance
- parameter_name: instancePartitionsId
attribute_name: instancePartition
help: |
The name of the Spanner instance partition.
disable_auto_completers: false
database:
name: database
collection: spanner.projects.instances.databases
attributes:
- *project
- *instance
- &database
parameter_name: databasesId
attribute_name: database
help: |
The name of the Cloud Spanner database.
disable_auto_completers: false
backup:
name: backup
collection: spanner.projects.instances.backups
attributes:
- *project
- *instance
- parameter_name: backupsId
attribute_name: backup
help: |
The name of the Cloud Spanner backup.
disable_auto_completers: false
backupSchedule:
name: backupSchedule
collection: spanner.projects.instances.databases.backupSchedules
attributes:
- *project
- *instance
- *database
- parameter_name: backupSchedulesId
attribute_name: backup_schedule
help: |
The name of the Cloud Spanner backup schedule.
disable_auto_completers: false
backupOperation:
name: backupOperation
collection: spanner.projects.instances.backups.operations
attributes:
- *project
- *instance
  - parameter_name: operationsId
attribute_name: operation
help: |
The name of the Cloud Spanner backup operation.
disable_auto_completers: false
operation:
name: operation
collection: spanner.projects.instances.operations
attributes:
- *project
- *instance
- parameter_name: operationsId
attribute_name: operation
help: |
The name of the Cloud Spanner operation.
disable_auto_completers: false

View File

@@ -0,0 +1,342 @@
# -*- coding: utf-8 -*- #
# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner samples API helper."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import collections
import os
from googlecloudsdk.api_lib.spanner import databases
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import log
from googlecloudsdk.core.util import files
# TODO(b/230344467): Better default samples dir
_SAMPLES_DEFAULT_DIR_NAME = '.gcloud-spanner-samples'
_SAMPLES_DEFAULT_DIR_PATH = os.path.join(
os.path.expanduser('~'), _SAMPLES_DEFAULT_DIR_NAME)
SAMPLES_DIR_PATH = os.getenv('GCLOUD_SPANNER_SAMPLES_HOME',
_SAMPLES_DEFAULT_DIR_PATH)
_BIN_RELPATH = 'bin'
SAMPLES_BIN_PATH = os.path.join(SAMPLES_DIR_PATH, _BIN_RELPATH)
_LOG_RELPATH = 'log'
SAMPLES_LOG_PATH = os.path.join(SAMPLES_DIR_PATH, _LOG_RELPATH)
_ETC_RELPATH = 'etc'
SAMPLES_ETC_PATH = os.path.join(SAMPLES_DIR_PATH, _ETC_RELPATH)
_DATA_INSERT_RELPATH = 'data-insert-statements'
SAMPLES_DATA_INSERT_PATH = os.path.join(SAMPLES_DIR_PATH, _DATA_INSERT_RELPATH)
# TODO(b/228633873): Replace with prod bucket
GCS_BUCKET = 'gs://cloud-spanner-samples'
NOT_EXIST = 'not-exist'
BANKING_APP_NAME = 'banking'
FINANCE_APP_NAME = 'finance'
FINANCE_GRAPH_APP_NAME = 'finance-graph'
FINANCE_PG_APP_NAME = 'finance-pg'
GAMING_APP_NAME = 'gaming'
AppAttrs = collections.namedtuple(
'AppAttrs',
[
'db_id', # Name of the sample app DB
'bin_path', # Relative path for sample app bin files
'etc_path', # Relative path for schema, data, and other files
'gcs_prefix', # Prefix for sample app files in GCS_BUCKET
'schema_file', # Schema filename (in GCS and locally)
'backend_bin', # Backend/server bin filename
'workload_bin', # Workload bin filename
'database_dialect', # The database dialect used in this sample
'data_insert_statements_path', # Relative path for INSERT stmt files.
],
)
APPS = {
BANKING_APP_NAME: AppAttrs(
db_id='banking-db',
bin_path='banking',
etc_path='banking',
data_insert_statements_path='banking',
schema_file='banking-schema.sdl',
gcs_prefix='banking',
backend_bin=NOT_EXIST,
workload_bin=NOT_EXIST,
database_dialect=databases.DATABASE_DIALECT_GOOGLESQL,
),
FINANCE_APP_NAME: AppAttrs(
db_id='finance-db',
bin_path='finance',
etc_path='finance',
schema_file='finance-schema.sdl',
gcs_prefix='finance',
data_insert_statements_path=NOT_EXIST,
backend_bin='server-1.0-SNAPSHOT-jar-with-dependencies.jar',
workload_bin='workload-1.0-SNAPSHOT-jar-with-dependencies.jar',
database_dialect=databases.DATABASE_DIALECT_GOOGLESQL,
),
FINANCE_GRAPH_APP_NAME: AppAttrs(
db_id='finance-graph-db',
bin_path='finance-graph',
etc_path='finance-graph',
data_insert_statements_path='finance-graph',
schema_file='finance-graph-schema.sdl',
gcs_prefix='finance-graph',
backend_bin=NOT_EXIST,
workload_bin=NOT_EXIST,
database_dialect=databases.DATABASE_DIALECT_GOOGLESQL,
),
FINANCE_PG_APP_NAME: AppAttrs(
db_id='finance-pg-db',
bin_path='finance-pg',
etc_path='finance-pg',
schema_file='finance-schema-pg.sdl',
gcs_prefix='finance',
data_insert_statements_path=NOT_EXIST,
backend_bin='server-1.0-SNAPSHOT-jar-with-dependencies.jar',
workload_bin='workload-1.0-SNAPSHOT-jar-with-dependencies.jar',
database_dialect=databases.DATABASE_DIALECT_POSTGRESQL,
),
GAMING_APP_NAME: AppAttrs(
db_id='gaming-db',
bin_path='gaming',
etc_path='gaming',
data_insert_statements_path='gaming',
schema_file='gaming-schema.sdl',
gcs_prefix='gaming',
backend_bin=NOT_EXIST,
workload_bin=NOT_EXIST,
database_dialect=databases.DATABASE_DIALECT_GOOGLESQL,
),
}
_GCS_BIN_PREFIX = 'bin'
_GCS_SCHEMA_PREFIX = 'schema'
_GCS_DATA_INSERT_STATEMENTS_PREFIX = 'data-insert-statements'
class SpannerSamplesError(exceptions.Error):
  """User error running Cloud Spanner sample app commands.

  Subclasses the core gcloud `exceptions.Error` so it is reported as a
  user-facing error rather than a crash.
  """
def check_appname(appname):
  """Raise if the given sample app doesn't exist.

  Args:
    appname: str, Name of the sample app.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  if appname in APPS:
    return
  raise ValueError("Unknown sample app '{}'".format(appname))
def get_db_id_for_app(appname):
  """Get the database ID for the given sample app.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The database ID, e.g. "finance-db".

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  app_attrs = APPS[appname]
  return app_attrs.db_id
def get_local_schema_path(appname):
  """Get the local path of the schema file for the given sample app.

  Note that the file and parent dirs may not exist.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The local path of the schema file.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  attrs = APPS[appname]
  return os.path.join(SAMPLES_ETC_PATH, attrs.etc_path, attrs.schema_file)
def get_local_bin_path(appname):
  """Get the local path to binaries for the given sample app.

  This typically includes server and workload binaries and any required
  dependencies. Note that the path may not exist.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The local path of the sample app binaries.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  bin_relpath = APPS[appname].bin_path
  return os.path.join(SAMPLES_BIN_PATH, bin_relpath)
def get_local_data_insert_statements_path(appname):
  """Get the local path to data insert statements for the given sample app.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The local path of the sample app data insert statements.

  Raises:
    ValueError: if the given sample app or the data_insert_statements_path
      don't exist.
  """
  check_appname(appname)
  relpath = APPS[appname].data_insert_statements_path
  if relpath == NOT_EXIST:
    raise ValueError(
        "Unknown sample app data insert statements '{}'".format(appname)
    )
  return os.path.join(SAMPLES_DATA_INSERT_PATH, relpath)
def get_gcs_schema_name(appname):
  """Get the GCS file path for the schema for the given sample app.

  Doesn't include the bucket name. Use to download the sample app schema file
  from GCS.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The sample app schema GCS file path.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  attrs = APPS[appname]
  parts = [attrs.gcs_prefix, _GCS_SCHEMA_PREFIX, attrs.schema_file]
  return '/'.join(parts)
def get_gcs_bin_prefix(appname):
  """Get the GCS prefix for binaries for the given sample app.

  Doesn't include the bucket name. Different sample apps have different
  numbers and types of binaries, list the bucket contents before downloading.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The sample app binaries GCS prefix.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  prefix_parts = [APPS[appname].gcs_prefix, _GCS_BIN_PREFIX, '']
  return '/'.join(prefix_parts)
def get_gcs_data_insert_statements_prefix(appname):
  """Get the GCS prefix for data insert statements for the given sample app.

  Doesn't include the bucket name.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The sample app data insert statements GCS prefix.

  Raises:
    ValueError: if the given sample app or the gcs_prefix don't exist.
  """
  check_appname(appname)
  if APPS[appname].gcs_prefix == NOT_EXIST:
    raise ValueError(
        "Unknown sample app data insert statements '{}'".format(appname)
    )
  return '/'.join(
      [APPS[appname].gcs_prefix, _GCS_DATA_INSERT_STATEMENTS_PREFIX, '']
  )
def get_database_dialect(appname):
  """Get the database dialect for the given sample app.

  Args:
    appname: str, Name of the sample app.

  Returns:
    str, The database dialect.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  app_attrs = APPS[appname]
  return app_attrs.database_dialect
def has_sample_data_statements(appname):
  """Check if the sample app has both gcs_prefix and data_insert_statements_path.

  Args:
    appname: str, Name of the sample app.

  Returns:
    bool, both gcs_prefix and data_insert_statements_path exist.

  Raises:
    ValueError: if the given sample app doesn't exist.
  """
  check_appname(appname)
  attrs = APPS[appname]
  has_gcs_prefix = attrs.gcs_prefix != NOT_EXIST
  has_local_path = attrs.data_insert_statements_path != NOT_EXIST
  return has_gcs_prefix and has_local_path
def run_proc(args, capture_logs_fn=None):
  """Wrapper for execution_utils.Subprocess that optionally captures logs.

  Args:
    args: [str], The arguments to execute. The first argument is the command.
    capture_logs_fn: str, If set, save logs to the specified filename.

  Returns:
    subprocess.Popen or execution_utils.SubprocessTimeoutWrapper, The running
    subprocess.
  """
  if capture_logs_fn:
    # Appends so repeated runs accumulate in one log file; create_path makes
    # the parent directories if needed.
    logfile = files.FileWriter(capture_logs_fn, append=True, create_path=True)
    log.status.Print('Writing logs to {}'.format(capture_logs_fn))
    # NOTE(review): the writer is handed to the subprocess and never closed
    # here — presumably it must outlive this call while the child writes to
    # it; confirm cleanup responsibility lies with the caller/OS.
    popen_args = dict(stdout=logfile, stderr=logfile)
  else:
    popen_args = {}
  return execution_utils.Subprocess(args, **popen_args)

View File

@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*- #
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides split file preprocessing for adding splits to a database."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import csv
import io
import re
from apitools.base.py import extra_types
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import exceptions as c_exceptions
from googlecloudsdk.core.util import files
class SplitFileParser:
  r"""Parses a split file into a list of split points.

  The split file is expected to be in the format of:
  <ObjectType>[space]<ObjectName>[space](<Split Value>)
  <ObjectType>[space]<ObjectName>[space](<Split Value>)
  ...
  where ObjectType can be TABLE or INDEX.
  Each split point must be in a new line.
  Split value is expected to be a comma separated list of key parts.
  Split values should be surrounded by parenthesis like ()
  String values should be supplied in single quotes:'splitKeyPart'
  Boolean values should be one of: true/false
  INT64 and NUMERIC spanner datatype values should be supplied within
  single quotes values like string format: '123',
  '999999999999999999999999999.99'
  Other number values should be supplied without quotes: 1.287
  Timestamp values should be provided in the following format in single quote
  values: '2020-06-18T17:24:53Z'
  If the split value needs to have a comma, then that should be escaped by
  backslash.
  Examples:
  TABLE Singers ('c32ca57a-786c-2268-09d4-95182a9930be')
  INDEX Order (4.2)
  TABLE TableD (0,'7ef9db22-d0e5-6041-8937-4bc6a7ef9db2')
  INDEX IndexXYZ ('8762203435012030000',NULL,NULL)
  INDEX IndexABC (0, '2020-06-18T17:24:53Z') TableKey (123,'ab\,c')
  -- note that the above split value has a delimieter (comma) in it,
  hence escaped by a backslash.
  """

  def __init__(self, splits_file, split_expiration_date):
    """Initializes the parser.

    Args:
      splits_file: str, path of the file holding split point definitions.
      split_expiration_date: expiration time applied to every parsed split
        point, or a falsy value to leave it unset.
    """
    self.splits_file = splits_file
    self.split_expiration_date = split_expiration_date
    # "<ObjectType> <ObjectName> <split value>"; the value may contain spaces.
    self.split_line_pattern = re.compile(r'(\S+)\s+(\S+)\s+(.+)')
    # Detect lines that (incorrectly) define a second TABLE/INDEX split.
    self.incorrect_split_with_table_key_pattern = re.compile(
        r'\((.*?)\) TABLE (\S+)\s+\((.*?)\)$'
    )
    self.incorrect_split_with_index_key_pattern = re.compile(
        r'\((.*?)\) INDEX (\S+)\s+\((.*?)\)$'
    )
    # "(index key parts) TableKey (table key parts)" for full index keys.
    self.index_full_key_pattern = re.compile(r'\((.*?)\) TableKey \((.*?)\)$')
    self.single_key_pattern = re.compile(r'\((.*?)\)$')

  def Process(self):
    """Gets the split points from the input file.

    Returns:
      list of spanner_v1_messages.SplitPoints, one per input line.

    Raises:
      InvalidArgumentException: if any line is malformed.
    """
    msgs = apis.GetMessagesModule('spanner', 'v1')
    split_points_list = []
    # Renamed from `file` to avoid shadowing the builtin name.
    with files.FileReader(self.splits_file) as splits_io:
      for single_split_string in splits_io.read().splitlines():
        single_split = self.ParseSplitPointString(single_split_string)
        if (
            not single_split
            or not single_split['SplitValue']
            or not single_split['ObjectName']
            or not single_split['ObjectType']
            or single_split['ObjectType'].upper() not in ['TABLE', 'INDEX']
        ):
          raise c_exceptions.InvalidArgumentException(
              '--splits-file',
              'Invalid split point string: {}. Each split point must be in the'
              ' format of <ObjectType> <ObjectName> (<Split Value>) where'
              ' ObjectType can be TABLE or INDEX'.format(single_split_string),
          )
        split = msgs.SplitPoints()
        if single_split['ObjectType'].upper() == 'TABLE':
          split.table = single_split['ObjectName']
        elif single_split['ObjectType'].upper() == 'INDEX':
          split.index = single_split['ObjectName']
        if single_split['SplitValue']:
          split.keys = self.ParseSplitValue(single_split['SplitValue'])
        if self.split_expiration_date:
          split.expireTime = self.split_expiration_date
        split_points_list.append(split)
    return split_points_list

  def ParseSplitPointString(self, input_string):
    """Parses a string in the format "<ObjectType> <ObjectName> (<Split Value>)".

    Args:
      input_string: The string to parse.

    Returns:
      A dictionary with keys "ObjectType", "ObjectName", and "SplitValue".

    Raises:
      InvalidArgumentException: if the input string is not in the expected
        format.
    """
    # Matches three groups of non-whitespace characters separated by spaces
    match = self.split_line_pattern.match(input_string)
    if not match:
      raise c_exceptions.InvalidArgumentException(
          '--splits-file',
          'Invalid split point string: {}. Each split point must be in the'
          ' format of <ObjectType> <ObjectName> (<Split Value>) where'
          ' ObjectType can be TABLE or INDEX'.format(input_string),
      )
    return {
        'ObjectType': match.group(1),
        'ObjectName': match.group(2),
        'SplitValue': match.group(3)
    }

  def ParseSplitValue(self, input_string):
    """Parses "(KeyParts)" or "(KeyParts) TableKey (KeyParts)" into keys.

    Args:
      input_string: The string to parse.

    Returns:
      A list of spanner_v1_messages.Key, the split point key(s).

    Raises:
      InvalidArgumentException: on malformed split values.
    """
    msgs = apis.GetMessagesModule('spanner', 'v1')
    keys_all = []
    input_string = input_string.strip()
    # Catches the case when single line contains multiple split points.
    if self.incorrect_split_with_table_key_pattern.match(
        input_string
    ) or self.incorrect_split_with_index_key_pattern.match(input_string):
      raise c_exceptions.InvalidArgumentException(
          '--splits-file',
          'Invalid split point string: {}. Each line must contain a single'
          ' split point for a table or index.'.format(input_string),
      )
    all_keys_strings = []
    match = self.index_full_key_pattern.match(input_string)
    if match:
      # Index split with full key: index key parts, then table key parts.
      all_keys_strings.append(match.group(1))
      all_keys_strings.append(match.group(2))
    else:
      match = self.single_key_pattern.match(input_string)
      if match:
        all_keys_strings.append(match.group(1))
      else:
        raise c_exceptions.InvalidArgumentException(
            '--splits-file',
            'The split value must be surrounded by parenthesis.',
        )
    for input_string_per_key in all_keys_strings:
      input_string_per_key = input_string_per_key.strip()
      input_string_per_key = input_string_per_key.strip('()')
      single_key = msgs.Key()
      for split_token in self.TokenizeWithCsv(input_string_per_key):
        key_parts = extra_types.JsonValue()
        if split_token == 'NULL':
          key_parts.is_null = True
        elif split_token in ('true', 'false', 'TRUE', 'FALSE'):
          # BUG FIX: the previous code used bool(split_token.lower()), which
          # is True for ANY non-empty string — 'false' parsed as True.
          # Compare against 'true' explicitly instead.
          key_parts.boolean_value = split_token.lower() == 'true'
        elif split_token.find('\'') == -1:
          # Unquoted, non-boolean tokens are treated as numbers.
          key_parts.double_value = float(split_token)
        else:
          # Quoted tokens are strings (also used for INT64/NUMERIC/timestamp).
          key_parts.string_value = split_token.strip('\'')
        single_key.keyParts.append(key_parts)
      keys_all.append(single_key)
    return keys_all

  def TokenizeWithCsv(self, text):
    """Tokenizes text on commas, ignoring backslash-escaped commas.

    Args:
      text: The text to tokenize.

    Returns:
      A list of tokens.
    """
    reader = csv.reader(
        io.StringIO(text),
        quotechar="'",
        skipinitialspace=True, quoting=csv.QUOTE_NONE,
        escapechar='\\'
    )
    return next(reader)
def ParseSplitPoints(args):
  """Gets the split points from the input file.

  Args:
    args: calliope args providing `splits_file` (path to the split definition
      file) and `split_expiration_date` (applied to every split point).

  Returns:
    list of spanner_v1_messages.SplitPoints parsed from the file.
  """
  return SplitFileParser(args.splits_file, args.split_expiration_date).Process()

View File

@@ -0,0 +1,389 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common methods to display parts of SQL query results."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from functools import partial
from apitools.base.py import encoding
from googlecloudsdk.core.resource import resource_printer
from googlecloudsdk.core.util import text
from sqlparse import lexer
from sqlparse import tokens as T
def _GetAdditionalProperty(properties, property_key, not_found_value='Unknown'):
"""Gets the value for the given key in a list of properties.
Looks through a list of properties and tries to find the value for the given
key. If it's not found, not_found_value is returned.
Args:
properties: A dictionary of key string, value string pairs.
property_key: The key string for which we want to get the value.
not_found_value: The string value to return if the key is not found.
Returns:
A string containing the value for the given key, or `not_found_value` if
the key is not found.
"""
for prop in properties:
if prop.key == property_key:
if hasattr(prop, 'value'):
return prop.value
break
return not_found_value
def _ConvertToTree(plan_nodes):
  """Creates tree of Node objects from the plan_nodes in server response.

  Args:
    plan_nodes (spanner_v1_messages.PlanNode[]): The plan_nodes from the server
      response. Plan nodes are topologically sorted.

  Returns:
    A Node, root of a tree built from `plan_nodes`.
  """
  # plan_nodes is a topologically sorted list, with the root node first.
  return _BuildSubTree(plan_nodes, plan_nodes[0])
def _BuildSubTree(plan_nodes, node):
  """Helper for building the subtree of a query plan node.

  Args:
    plan_nodes (spanner_v1_messages.PlanNode[]): The plan_nodes from the server
      response. Plan nodes are topologically sorted.
    node (spanner_v1_messages.PlanNode): The root node of the subtree to be
      built.

  Returns:
    A Node object.
  """
  # Leaf nodes carry children=None (not an empty list), matching Node's API.
  if not node.childLinks:
    return Node(node, None)
  subtrees = []
  for link in node.childLinks:
    child = plan_nodes[link.childIndex]
    subtrees.append(_BuildSubTree(plan_nodes, child))
  return Node(node, subtrees)
def _ConvertToStringValue(prop):
"""Converts the prop to a string if it exists.
Args:
prop (object_value): The value returned from _GetAdditionalProperty.
Returns:
A string value for the given prop, or the `not_found_value` if the prop does
not exist.
"""
return getattr(prop, 'string_value', prop)
def _DisplayNumberOfRowsModified(row_count, is_exact_count, out):
"""Prints number of rows modified by a DML statement.
Args:
row_count: Either the exact number of rows modified by statement or the
lower bound of rows modified by a Partitioned DML statement.
is_exact_count: Boolean stating whether the number is the exact count.
out: Output stream to which we print.
"""
if is_exact_count:
output_str = 'Statement modified {} {}'
else:
output_str = 'Statement modified a lower bound of {} {}'
if row_count == 1:
out.Print(output_str.format(row_count, 'row'))
else:
out.Print(output_str.format(row_count, 'rows'))
def QueryHasDml(sql):
  """Determines if the sql string contains a DML query.

  Args:
    sql (string): The sql string entered by the user.

  Returns:
    A boolean.
  """
  normalized = sql.lstrip().lower()
  # sqlparse's lexer yields (token_type, value) pairs; compare against the
  # three DML statement keywords.
  dml_markers = (
      (T.Keyword.DML, 'insert'),
      (T.Keyword.DML, 'update'),
      (T.Keyword.DML, 'delete'),
  )
  return any(token in dml_markers for token in lexer.tokenize(normalized))
def QueryHasAggregateStats(result):
  """Checks if the given results have aggregate statistics.

  Args:
    result (spanner_v1_messages.ResultSetStats): The stats for a query.

  Returns:
    A boolean indicating whether 'results' contain aggregate statistics.
  """
  if not hasattr(result, 'stats'):
    return False
  return getattr(result.stats, 'queryStats', None) is not None
def DisplayQueryAggregateStats(query_stats, out):
  """Displays the aggregate stats for a Spanner SQL query.

  Looks at the queryStats portion of the query response and prints some of
  the aggregate statistics.

  Args:
    query_stats (spanner_v1_messages.ResultSetStats.QueryStatsValue): The query
      stats taken from the server response to a query.
    out: Output stream to which we print.
  """
  # Missing stats fall back to _GetAdditionalProperty's 'Unknown' default.
  get_prop = partial(_GetAdditionalProperty, query_stats.additionalProperties)
  stats = {
      'total_elapsed_time': _ConvertToStringValue(get_prop('elapsed_time')),
      'cpu_time': _ConvertToStringValue(get_prop('cpu_time')),
      'rows_returned': _ConvertToStringValue(get_prop('rows_returned')),
      'rows_scanned': _ConvertToStringValue(get_prop('rows_scanned')),
      'optimizer_version': _ConvertToStringValue(get_prop('optimizer_version')),
  }
  # One boxed table row with a fixed column order.
  resource_printer.Print(
      stats,
      'table[box](total_elapsed_time, cpu_time, rows_returned, rows_scanned, optimizer_version)',
      out=out)
def DisplayQueryPlan(result, out):
  """Displays a graphical query plan for a query.

  Args:
    result (spanner_v1_messages.ResultSet): The server response to a query.
    out: Output stream to which we print.
  """
  # Build the Node tree from the topologically sorted plan nodes, then let
  # the root render the whole plan.
  node_tree_root = _ConvertToTree(result.stats.queryPlan.planNodes)
  node_tree_root.PrettyPrint(out)
def DisplayQueryResults(result, out):
  """Prints the result rows for a query.

  Args:
    result (spanner_v1_messages.ResultSet): The server response to a query.
    out: Output stream to which we print.
  """
  # DML responses report either an exact row count or (for partitioned DML)
  # a lower bound; print whichever is present.
  exact_count = getattr(result.stats, 'rowCountExact', None)
  if exact_count is not None:
    _DisplayNumberOfRowsModified(exact_count, True, out)
  lower_bound = getattr(result.stats, 'rowCountLowerBound', None)
  if lower_bound is not None:
    _DisplayNumberOfRowsModified(lower_bound, False, out)
  if result.metadata.rowType.fields:
    # Print "(Unspecified)" for computed columns.
    headers = [
        field.name or '(Unspecified)'
        for field in result.metadata.rowType.fields
    ]
    # Build one table-column spec per field for the table layout.
    table_format = ','.join('row.slice({0}).join():label="{1}"'.format(i, f)
                            for i, f in enumerate(headers))
    rows = [{
        'row': encoding.MessageToPyValue(row.entry)
    } for row in result.rows]
    # Can't use the PrintText method because we want special formatting.
    resource_printer.Print(rows, 'table({0})'.format(table_format), out=out)
class Node(object):
  """Represents a single node in a Spanner query plan.

  Nodes form a tree through their `children` lists; PrettyPrint renders the
  subtree rooted at this node as indented ASCII art.

  Attributes:
    properties (spanner_v1_messages.PlanNode): The details about a given node
      as returned from the server.
    children: A list of children in the query plan of type Node.
  """

  def __init__(self, properties, children=None):
    # Default to a fresh list rather than sharing one mutable default
    # across instances.
    self.children = children or []
    self.properties = properties

  def _DisplayKindAndName(self, out, prepend, stub):
    """Prints the kind of the node (SCALAR or RELATIONAL) and its name."""
    kind_and_name = '{}{} {} {}'.format(prepend, stub, self.properties.kind,
                                        self.properties.displayName)
    out.Print(kind_and_name)

  def _GetNestedStatProperty(self, prop_name, nested_prop_name):
    """Gets a nested property name on this object's executionStats.

    Args:
      prop_name: A string of the key name for the outer property on
        executionStats.
      nested_prop_name: A string of the key name of the nested property.

    Returns:
      The string value of the nested property, or None if the outermost
      property or nested property don't exist.
    """
    prop = _GetAdditionalProperty(
        self.properties.executionStats.additionalProperties, prop_name, '')
    if not prop:
      return None
    nested_prop = _GetAdditionalProperty(prop.object_value.properties,
                                         nested_prop_name, '')
    if nested_prop:
      return nested_prop.string_value
    return None

  def _DisplayExecutionStats(self, out, prepend, beneath_stub):
    """Prints the relevant execution statistics for a node.

    More specifically, print out latency information and the number of
    executions. This information only exists when query is run in 'PROFILE'
    mode.

    Args:
      out: Output stream to which we print.
      prepend: String that precedes any information about this node to maintain
        a visible hierarchy.
      beneath_stub: String that preserves the indentation of the vertical lines.
    """
    if not self.properties.executionStats:
      return None
    stat_props = []
    num_executions = self._GetNestedStatProperty('execution_summary',
                                                 'num_executions')
    if num_executions:
      num_executions = int(num_executions)
      executions_str = '{} {}'.format(num_executions,
                                      text.Pluralize(num_executions,
                                                     'execution'))
      stat_props.append(executions_str)
    # Total latency and latency unit are always expected to be present when
    # latency exists. Latency exists when the query is run in PROFILE mode.
    # Prefer the mean latency when present; fall back to the total.
    mean_latency = self._GetNestedStatProperty('latency', 'mean')
    total_latency = self._GetNestedStatProperty('latency', 'total')
    unit = self._GetNestedStatProperty('latency', 'unit')
    if mean_latency:
      stat_props.append('{} {} average latency'.format(mean_latency, unit))
    elif total_latency:
      stat_props.append('{} {} total latency'.format(total_latency, unit))
    if stat_props:
      executions_stats_str = '{}{} ({})'.format(prepend, beneath_stub,
                                                ', '.join(stat_props))
      out.Print(executions_stats_str)

  def _DisplayMetadata(self, out, prepend, beneath_stub):
    """Prints the keys and values of the metadata for a node.

    Args:
      out: Output stream to which we print.
      prepend: String that precedes any information about this node to maintain
        a visible hierarchy.
      beneath_stub: String that preserves the indentation of the vertical lines.
    """
    if self.properties.metadata:
      additional_props = []
      # additionalProperties looks like: [key: {value: {string_value: str}}]
      for prop in self.properties.metadata.additionalProperties:
        additional_props.append(
            '{}: {}'.format(prop.key, prop.value.string_value))
      # Sort the key/value pairs so the output is deterministic.
      metadata = '{}{} {}'.format(prepend, beneath_stub,
                                  ', '.join(sorted(additional_props)))
      out.Print(metadata)

  def _DisplayShortRepresentation(self, out, prepend, beneath_stub):
    # A condensed representation is only present on some nodes; print it
    # beneath the node name when available.
    if self.properties.shortRepresentation:
      short_rep = '{}{} {}'.format(
          prepend, beneath_stub,
          self.properties.shortRepresentation.description)
      out.Print(short_rep)

  def _DisplayBreakLine(self, out, prepend, beneath_stub, is_root):
    """Displays an empty line between nodes for visual breathing room.

    Keeps in tact the vertical lines connecting all immediate children of a
    node to each other.

    Args:
      out: Output stream to which we print.
      prepend: String that precedes any information about this node to maintain
        a visible hierarchy.
      beneath_stub: String that preserves the indentation of the vertical lines.
      is_root: Boolean indicating whether this node is the root of the tree.
    """
    above_child = ' ' if is_root else ''
    above_child += ' |' if self.children else ''
    break_line = '{}{}{}'.format(prepend, beneath_stub, above_child)
    # It could be the case the beneath_stub adds spaces but above_child doesn't
    # add an additional vertical line, in which case we want to remove the
    # extra trailing spaces.
    out.Print(break_line.rstrip())

  def PrettyPrint(self, out, prepend=None, is_last=True, is_root=True):
    """Prints a string representation of this node in the tree.

    Args:
      out: Output stream to which we print.
      prepend: String that precedes any information about this node to maintain
        a visible hierarchy.
      is_last: Boolean indicating whether this node is the last child of its
        parent.
      is_root: Boolean indicating whether this node is the root of the tree.
    """
    prepend = prepend or ''
    # The symbol immediately before node kind to indicate that this is a child
    # of its parents. All nodes except the root get one.
    stub = '' if is_root else (r'\-' if is_last else '+-')
    # To list additional properties beneath the name, figure out how they should
    # be indented relative to the name's stub.
    beneath_stub = '' if is_root else ('  ' if is_last else '| ')
    self._DisplayKindAndName(out, prepend, stub)
    self._DisplayExecutionStats(out, prepend, beneath_stub)
    self._DisplayMetadata(out, prepend, beneath_stub)
    self._DisplayShortRepresentation(out, prepend, beneath_stub)
    self._DisplayBreakLine(out, prepend, beneath_stub, is_root)
    for idx, child in enumerate(self.children):
      is_last_child = idx == len(self.children) - 1
      # The amount each subsequent level in the tree is indented.
      indent = '   '
      # Connect all immediate children to each other with a vertical line
      # of '|'. Don't extend this line down past the last child node. It's
      # cleaner.
      child_prepend = prepend + (' ' if is_last else '|') + indent
      child.PrettyPrint(
          out, prepend=child_prepend, is_last=is_last_child, is_root=False)

View File

@@ -0,0 +1,448 @@
# -*- coding: utf-8 -*- #
# Copyright 2017 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides helper methods for dealing with Cloud Spanner Writes API.
The main reasons for adding the util functions for Writes API are as below:
- API expects column values to be extra_types.JsonValue, apitool cannot
handle it by default.
- for different data types the API expects different formats, for example:
for INT64, API expects a string value; for FLOAT64, it expects a number.
As the values user input are strings by default, the type conversion is
necessary.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import abc
from collections import OrderedDict
import re
from apitools.base.py import extra_types
from googlecloudsdk.core.exceptions import Error
import six
from six.moves import zip
class BadColumnNameError(Error):
  """Indicates a user-supplied column name that does not exist in the table."""
class BadTableNameError(Error):
  """Indicates a user-supplied table name that does not exist in the database."""
class InvalidKeysError(Error):
  """Indicates that the number of keys given does not match the table's DDL."""
class InvalidArrayInputError(Error):
  """Indicates a list was passed as a value in the data flag."""
class _TableColumn(object):
"""A wrapper that stores the column information.
Attributes:
name: String, the name of the table column.
col_type: _ScalarColumnType or _ArrayColumnType.
"""
_COLUMN_DDL_PATTERN = re.compile(
r"""
# A column definition has a name and a type, with some additional
# properties.
# Some examples:
# Foo INT64 NOT NULL
# Bar STRING(1024)
# Baz ARRAY<FLOAT32>
[`]?(?P<name>\w+)[`]?\s+
(?P<type>[\w<>]+)
# We don't care about "NOT NULL", and the length number after STRING
# or BYTES (e.g.STRING(MAX), BYTES(1024)).
""", re.DOTALL | re.VERBOSE)
def __init__(self, name, col_type):
self.name = name
self.col_type = col_type
def __eq__(self, other):
return self.name == other.name and self.col_type == other.col_type
@classmethod
def FromDdl(cls, column_ddl):
"""Constructs an instance of _TableColumn from a column_def DDL statement.
Args:
column_ddl: string, the parsed string contains the column name and type
information. Example: SingerId INT64 NOT NULL.
Returns:
A _TableColumn object.
Raises:
ValueError: invalid DDL, this error shouldn't happen in theory, as the API
is expected to return valid DDL statement strings.
"""
column_match = cls._COLUMN_DDL_PATTERN.search(column_ddl)
if not column_match:
raise ValueError('Invalid DDL: [{}].'.format(column_ddl))
column_name = column_match.group('name')
col_type = _ColumnType.FromDdl(column_match.group('type'))
return _TableColumn(column_name, col_type)
def GetJsonValues(self, value):
"""Convert the user input values to JSON value or JSON array value.
Args:
value: String or string list, the user input values of the column.
Returns:
extra_types.JsonArray or extra_types.JsonValue, the json value of a single
column in the format that API accepts.
"""
return self.col_type.GetJsonValue(value)
class _ColumnType(six.with_metaclass(abc.ABCMeta, object)):
  """A wrapper that stores the column type information.

  A column type can be one of the scalar types such as integers, as well as
  array. An array type is an ordered list of zero or more elements of
  scalar type.

  Attributes:
    scalar_type: String, the type of the column or the element of the column
      (if the column is an array).
  """
  # For Scalar types: there are 12 scalar types in Cloud Spanner considered as
  # valid key and column types. 'JSON', 'TOKENLIST', 'FLOAT32' however, are not
  # valid key types.
  _SCALAR_TYPES = (
      'BOOL',
      'BYTES',
      'DATE',
      'FLOAT64',
      'INT64',
      'STRING',
      'TIMESTAMP',
      'NUMERIC',
      'UUID',
      'JSON',
      'TOKENLIST',
      'FLOAT32',
  )

  def __init__(self, scalar_type):
    self.scalar_type = scalar_type

  @classmethod
  def FromDdl(cls, column_type_ddl):
    """Constructs a _ColumnType object from a partial DDL statement.

    Args:
      column_type_ddl: string, the parsed string only contains the column type
        information. Example: INT64 NOT NULL, ARRAY<STRING(MAX)> or BYTES(200).

    Returns:
      A _ArrayColumnType or a _ScalarColumnType object.

    Raises:
      ValueError: invalid DDL, this error shouldn't happen in theory, as the API
        is expected to return valid DDL statement strings.
    """
    # Find the first declared scalar type that appears in the DDL fragment.
    matched_type = next(
        (t for t in cls._SCALAR_TYPES if t in column_type_ddl), None)
    if matched_type is None:
      raise ValueError(
          'Invalid DDL: unrecognized type [{}].'.format(column_type_ddl))
    # An ARRAY<...> wrapper means the column holds a list of that scalar type.
    if column_type_ddl.startswith('ARRAY'):
      return _ArrayColumnType(matched_type)
    return _ScalarColumnType(matched_type)

  @abc.abstractmethod
  def GetJsonValue(self, value):
    raise NotImplementedError()
def ConvertJsonValueForScalarTypes(scalar_type, scalar_value):
  """Convert the user input scalar value to JSON value.

  Args:
    scalar_type: String, the scalar type of the column, e.g INT64, DATE.
    scalar_value: String, the value of the column that user inputs.

  Returns:
    An API accepts JSON value of a column or an element of an array column.
  """
  # A literal 'NULL' maps to a JSON null regardless of the column type.
  if scalar_value == 'NULL':
    return extra_types.JsonValue(is_null=True)
  if scalar_type == 'BOOL':
    # Case-insensitive: True and true are valid boolean values.
    return extra_types.JsonValue(boolean_value=scalar_value.upper() == 'TRUE')
  if scalar_type in ('FLOAT64', 'FLOAT32'):
    # NaN, +/-inf are valid float values but must be sent as strings.
    if scalar_value in ('NaN', 'Infinity', '-Infinity'):
      return extra_types.JsonValue(string_value=scalar_value)
    return extra_types.JsonValue(double_value=float(scalar_value))
  # TODO(b/73077622): add bytes conversion.
  # For other data types (INT, STRING, TIMESTAMP, DATE, NUMERIC, JSON), the
  # json format would be string.
  return extra_types.JsonValue(string_value=scalar_value)
class _ScalarColumnType(_ColumnType):
  """Column type holding a single scalar value (INT64, STRING, ...)."""

  def __init__(self, scalar_type):
    super(_ScalarColumnType, self).__init__(scalar_type)

  def __eq__(self, other):
    # Check the type first so comparing against an object without a
    # scalar_type attribute returns False instead of raising AttributeError.
    return (isinstance(other, _ScalarColumnType) and
            self.scalar_type == other.scalar_type)

  def GetJsonValue(self, value):
    """Converts a single user-entered string to an API JSON value."""
    return ConvertJsonValueForScalarTypes(self.scalar_type, value)
class _ArrayColumnType(_ColumnType):
  """Column type holding an ordered list of scalar values."""

  def __init__(self, scalar_type):
    super(_ArrayColumnType, self).__init__(scalar_type)

  def __eq__(self, other):
    # Check the type first so comparing against an object without a
    # scalar_type attribute returns False instead of raising AttributeError.
    return (isinstance(other, _ArrayColumnType) and
            self.scalar_type == other.scalar_type)

  def GetJsonValue(self, values):
    """Converts each element of the user-entered list to an API JSON value."""
    return extra_types.JsonValue(
        array_value=extra_types.JsonArray(entries=[
            ConvertJsonValueForScalarTypes(self.scalar_type, v) for v in values
        ]))
class ColumnJsonData(object):
  """Pairs a column name with its JSON-encoded value for a table write.

  Attributes:
    col_name: String, the name of the column to be written.
    col_value: extra_types.JsonArray (array column) or
      extra_types.JsonValue (scalar column), the value to be written.
  """

  def __init__(self, col_name, col_value):
    self.col_name, self.col_value = col_name, col_value
class Table(object):
  """Container for the properties of a table in Cloud Spanner database.

  Attributes:
    name: String, the name of table.
    _columns: OrderedDict, with keys are the column names and values are the
      _TableColumn objects.
    _primary_keys: String list, the names of the primary key columns in the
      order defined in the DDL statement
  """
  _TABLE_DDL_PATTERN = re.compile(
      r"""
      # Every table starts with "CREATE TABLE" followed by name and column
      # definitions, in a big set of parenthesis.
      # For example:
      #   CREATE TABLE Foos (
      #     Bar INT64 NOT NULL,
      #     Baz INT64 NOT NULL,
      #     Qux STRING(MAX),
      #   )
      CREATE\s+TABLE\s+
      (?P<name>[\w\.]+)\s+\(\s+
      (?P<columns>.*)\)\s+
      # Then, it has "PRIMARY KEY" and a list of primary keys, in parens:
      #   PRIMARY KEY ( Bar, Qux )
      PRIMARY\s+KEY\s*\(
      (?P<primary_keys>.*)\)
      # It may have extra instructions on the end (e.g. INTERLEAVE) to
      # tell Spanner how to store the data, but we don't really care.
      """, re.DOTALL | re.VERBOSE)

  def __init__(self, table_name, _columns, _primary_keys=None):
    self.name = table_name
    self._columns = _columns
    self._primary_keys = _primary_keys or []

  def GetJsonData(self, data_dict):
    """Get the column names and values to be written from data input.

    Args:
      data_dict: Dictionary where keys are the column names and values are user
        input data value, which is parsed from --data argument in the command.

    Returns:
      List of ColumnJsonData, which includes the column names and values to be
      written.
    """
    return [
        ColumnJsonData(name,
                       self._FindColumnByName(name).GetJsonValues(value))
        for name, value in six.iteritems(data_dict)
    ]

  def GetJsonKeys(self, keys_list):
    """Get the primary key values to be written from keys input.

    Args:
      keys_list: String list, the primary key values of the row to be deleted.

    Returns:
      List of extra_types.JsonValue.

    Raises:
      InvalidKeysError: the keys are invalid.
    """
    # The user must supply exactly one value per primary key column.
    if len(keys_list) != len(self._primary_keys):
      raise InvalidKeysError(
          'Invalid keys. There are {} primary key columns in the table [{}]. '
          '{} are given.'.format(
              len(self._primary_keys), self.name, len(keys_list)))
    return [
        self._FindColumnByName(key_name).GetJsonValues(key_value)
        for key_name, key_value in zip(self._primary_keys, keys_list)
    ]

  @classmethod
  def FromDdl(cls, database_ddl, table_name):
    """Constructs a Table from ddl statements.

    Args:
      database_ddl: String list, the ddl statements of the current table from
        server.
      table_name: String, the table name user inputs.

    Returns:
      Table.

    Raises:
      BadTableNameError: the table name is invalid.
      ValueError: Invalid Ddl.
    """
    # Collect every other table name seen, for the error message below.
    other_table_names = []
    for statement in database_ddl:
      match = cls._TABLE_DDL_PATTERN.search(statement)
      if not match:
        continue
      matched_name = match.group('name')
      if matched_name != table_name:
        other_table_names.append(matched_name)
        continue
      # Found the requested table; parse its column definitions in order.
      columns = OrderedDict()
      for column_ddl in match.group('columns').split(','):
        # The split can yield empty or whitespace-only fragments (e.g. after
        # a trailing comma); skip those.
        if column_ddl.strip():
          parsed_column = _TableColumn.FromDdl(column_ddl)
          columns[parsed_column.name] = parsed_column
      # Example: PRIMARY KEY ( Bar, Qux ) -> [Bar,Qux].
      primary_keys = [
          key.strip() for key in match.group('primary_keys').split(',')
      ]
      return Table(table_name, columns, primary_keys)
    raise BadTableNameError(
        'Table name [{}] is invalid. Valid table names: [{}].'.format(
            table_name, ', '.join(other_table_names)))

  def GetColumnTypes(self):
    """Maps the column name to the column type.

    Returns:
      OrderedDict of column names to types.
    """
    return OrderedDict((name, column.col_type)
                       for name, column in six.iteritems(self._columns))

  def _FindColumnByName(self, col_name):
    """Find the _TableColumn object with the given column name.

    Args:
      col_name: String, the name of the column.

    Returns:
      _TableColumn.

    Raises:
      BadColumnNameError: the column name is invalid.
    """
    try:
      return self._columns[col_name]
    except KeyError:
      valid_column_names = ', '.join(list(self._columns.keys()))
      raise BadColumnNameError(
          'Column name [{}] is invalid. Valid column names: [{}].'.format(
              col_name, valid_column_names))
def ValidateArrayInput(table, data):
  """Checks array input is valid.

  Args:
    table: Table, the table which data is being modified.
    data: OrderedDict, the data entered by the user.

  Returns:
    data (OrderedDict) the validated data.

  Raises:
    InvalidArrayInputError: if the input contains an array which is invalid.
  """
  col_to_type = table.GetColumnTypes()
  for column, value in six.iteritems(data):
    col_type = col_to_type[column]
    # An ARRAY column must receive an actual list; argparse-style strings
    # can't express arrays, so users must supply them via --flags-file.
    if isinstance(col_type, _ArrayColumnType) and not isinstance(value, list):
      raise InvalidArrayInputError(
          'Column name [{}] has an invalid array input: {}. `--flags-file` '
          'should be used to specify array values.'.format(column, value))
  return data