# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database sessions API helper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import extra_types
from apitools.base.py import http_wrapper
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.spanner.sql import QueryHasDml


def CheckResponse(response):
  """Wrap http_wrapper.CheckResponse to skip retry on 501."""
  if response.status_code == 501:
    raise apitools_exceptions.HttpError.FromResponse(response)
  return http_wrapper.CheckResponse(response)


def Create(database_ref, creator_role=None):
  """Create a database session.

  Args:
    database_ref: String, The database in which the new session is created.
    creator_role: String, The database role which created this session.

  Returns:
    Newly created session.
  """
  client = _GetClientInstance('spanner', 'v1', None)
  msgs = apis.GetMessagesModule('spanner', 'v1')
  if creator_role is None:
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCreateRequest(
        database=database_ref.RelativeName())
  else:
    # Only send a CreateSessionRequest body when a creator role is given.
    create_session_request = msgs.CreateSessionRequest(
        session=msgs.Session(creatorRole=creator_role))
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCreateRequest(
        createSessionRequest=create_session_request,
        database=database_ref.RelativeName())
  return client.projects_instances_databases_sessions.Create(req)


def List(database_ref, server_filter=None):
  """Lists all active sessions on the given database.

  Args:
    database_ref: Reference to the database whose sessions are listed.
    server_filter: Optional filter that is passed through in the list request.

  Returns:
    A generator over the 'sessions' field of the paged list responses.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  req = msgs.SpannerProjectsInstancesDatabasesSessionsListRequest(
      database=database_ref.RelativeName(), filter=server_filter)
  return list_pager.YieldFromList(
      client.projects_instances_databases_sessions,
      req,
      # There is a batch_size_attribute ('pageSize') but we want to yield as
      # many results as possible per request.
      batch_size_attribute=None,
      field='sessions')


def Delete(session_ref):
  """Delete a database session.

  Args:
    session_ref: Reference to the session to delete.

  Returns:
    The response of the sessions Delete call.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  req = msgs.SpannerProjectsInstancesDatabasesSessionsDeleteRequest(
      name=session_ref.RelativeName())
  return client.projects_instances_databases_sessions.Delete(req)


def _GetClientInstance(api_name, api_version, http_timeout_sec=None):
  """Returns an API client whose response check is this module's CheckResponse.

  Args:
    api_name: String, the name of the API.
    api_version: String, the version of the API.
    http_timeout_sec: int, optional HTTP timeout passed to the client factory.

  Returns:
    The client instance with check_response_func set to CheckResponse.
  """
  client = apis.GetClientInstance(
      api_name, api_version, http_timeout_sec=http_timeout_sec)
  client.check_response_func = CheckResponse
  return client


def ExecuteSql(sql,
               query_mode,
               session_ref,
               read_only_options=None,
               request_options=None,
               enable_partitioned_dml=False,
               http_timeout_sec=None):
  """Execute an SQL command.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Session, the session through which the statement is executed.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.
    http_timeout_sec: int, Maximum time in seconds to wait for the SQL query to
      complete.

  Returns:
    The response of the sessions ExecuteSql call.
  """
  client = _GetClientInstance('spanner', 'v1', http_timeout_sec)
  msgs = apis.GetMessagesModule('spanner', 'v1')
  _RegisterCustomMessageCodec(msgs)

  execute_sql_request = _GetQueryRequest(
      sql,
      query_mode,
      session_ref,
      read_only_options,
      request_options,
      enable_partitioned_dml,
  )
  req = msgs.SpannerProjectsInstancesDatabasesSessionsExecuteSqlRequest(
      session=session_ref.RelativeName(),
      executeSqlRequest=execute_sql_request)
  resp = client.projects_instances_databases_sessions.ExecuteSql(req)

  # A non-partitioned DML statement runs in a read-write transaction that the
  # request itself begins (see _GetQueryRequest), so it has to be committed
  # here. Truthiness is used rather than `is False` so that a falsy non-bool
  # flag (e.g. None) takes the same branch as in _GetQueryRequest and the
  # begun transaction is never left uncommitted.
  if QueryHasDml(sql) and not enable_partitioned_dml:
    Commit(session_ref, [], resp.metadata.transaction.id)
  return resp


def _RegisterCustomMessageCodec(msgs):
  """Register custom message codec.

  Args:
    msgs: Spanner v1 messages.
  """
  # TODO(b/33482229): remove this workaround
  def _ToJson(msg):
    return extra_types.JsonProtoEncoder(
        extra_types.JsonArray(entries=msg.entry))

  def _FromJson(data):
    return msgs.ResultSet.RowsValueListEntry(
        entry=extra_types.JsonProtoDecoder(data).entries)

  encoding.RegisterCustomMessageCodec(
      encoder=_ToJson, decoder=_FromJson)(
          msgs.ResultSet.RowsValueListEntry)


def _GetQueryRequest(sql,
                     query_mode,
                     session_ref=None,
                     read_only_options=None,
                     request_options=None,
                     enable_partitioned_dml=False):
  """Formats the request based on whether the statement contains DML.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Reference to the session.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.

  Returns:
    ExecuteSqlRequest parameters
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')

  # Truthiness (not `is True`) keeps this branch selection in agreement with
  # the commit decision made in ExecuteSql.
  if enable_partitioned_dml:
    # Partitioned DML needs an explicitly started transaction.
    transaction = _GetPartitionedDmlTransaction(session_ref)
  elif QueryHasDml(sql):
    # Regular DML: begin a read-write transaction as part of the request.
    transaction_options = msgs.TransactionOptions(readWrite=msgs.ReadWrite())
    transaction = msgs.TransactionSelector(begin=transaction_options)
  else:
    # Plain query: single-use read-only transaction.
    transaction_options = msgs.TransactionOptions(
        readOnly=read_only_options)
    transaction = msgs.TransactionSelector(singleUse=transaction_options)
  return msgs.ExecuteSqlRequest(
      sql=sql,
      requestOptions=request_options,
      queryMode=msgs.ExecuteSqlRequest.QueryModeValueValuesEnum(query_mode),
      transaction=transaction)


def _GetPartitionedDmlTransaction(session_ref):
  """Creates a transaction for Partitioned DML.

  Args:
    session_ref: Reference to the session.

  Returns:
    TransactionSelector with the id property.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  transaction_options = msgs.TransactionOptions(
      partitionedDml=msgs.PartitionedDml())
  begin_transaction_req = msgs.BeginTransactionRequest(
      options=transaction_options)
  req = msgs.SpannerProjectsInstancesDatabasesSessionsBeginTransactionRequest(
      beginTransactionRequest=begin_transaction_req,
      session=session_ref.RelativeName())
  resp = client.projects_instances_databases_sessions.BeginTransaction(req)
  return msgs.TransactionSelector(id=resp.id)


def Commit(session_ref, mutations, transaction_id=None):
  """Commit a transaction through a session.

  In Cloud Spanner, each session can have at most one active transaction at a
  time. When no transaction id is given, this request uses a temporary single
  use read-write transaction to execute the mutations, in order to avoid
  retrying aborted transactions by accident. When a transaction id is given,
  that previously started transaction is committed instead.

  Note: this commit is non-idempotent.

  Args:
    session_ref: Session, through which the transaction would be committed.
    mutations: A list of mutations, each represents a modification to one or
      more Cloud Spanner rows.
    transaction_id: An optional string for the transaction id.

  Returns:
    The Cloud Spanner timestamp at which the transaction committed.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')

  if transaction_id is not None:
    commit_request = msgs.CommitRequest(
        mutations=mutations, transactionId=transaction_id)
  else:
    commit_request = msgs.CommitRequest(
        mutations=mutations,
        singleUseTransaction=msgs.TransactionOptions(
            readWrite=msgs.ReadWrite()))

  req = msgs.SpannerProjectsInstancesDatabasesSessionsCommitRequest(
      session=session_ref.RelativeName(), commitRequest=commit_request)
  return client.projects_instances_databases_sessions.Commit(req)


class MutationFactory(object):
  """Factory that creates and returns a mutation object in Cloud Spanner.

  A Mutation represents a sequence of inserts, updates and deletes that can be
  applied to rows and tables in a Cloud Spanner database.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')

  @classmethod
  def Insert(cls, table, data):
    """Constructs an INSERT mutation, which inserts a new row in a table.

    Args:
      table: String, the name of the table.
      data: A collections.OrderedDict, the keys of which are the column names
        and values are the column values to be inserted.

    Returns:
      An insert mutation operation.
    """
    return cls.msgs.Mutation(insert=cls._GetWrite(table, data))

  @classmethod
  def Update(cls, table, data):
    """Constructs an UPDATE mutation, which updates a row in a table.

    Args:
      table: String, the name of the table.
      data: An ordered dictionary where the keys are the column names and
        values are the column values to be updated.

    Returns:
      An update mutation operation.
    """
    return cls.msgs.Mutation(update=cls._GetWrite(table, data))

  @classmethod
  def Delete(cls, table, keys):
    """Constructs a DELETE mutation, which deletes a row in a table.

    Args:
      table: String, the name of the table.
      keys: String list, the primary key values of the row to delete.

    Returns:
      A delete mutation operation.
    """
    return cls.msgs.Mutation(delete=cls._GetDelete(table, keys))

  @classmethod
  def _GetWrite(cls, table, data):
    """Constructs Write object, which is needed for insert/update operations."""

    # TODO(b/33482229): a workaround to handle JSON serialization
    def _ToJson(msg):
      return extra_types.JsonProtoEncoder(
          extra_types.JsonArray(entries=msg.entry))

    encoding.RegisterCustomMessageCodec(
        encoder=_ToJson, decoder=None)(
            cls.msgs.Write.ValuesValueListEntry)

    json_columns = table.GetJsonData(data)
    json_column_names = [col.col_name for col in json_columns]
    json_column_values = [col.col_value for col in json_columns]

    # A single values entry is written: one row per mutation.
    return cls.msgs.Write(
        columns=json_column_names,
        table=table.name,
        values=[cls.msgs.Write.ValuesValueListEntry(entry=json_column_values)])

  @classmethod
  def _GetDelete(cls, table, keys):
    """Constructs Delete object, which is needed for delete operation."""

    # TODO(b/33482229): a workaround to handle JSON serialization
    def _ToJson(msg):
      return extra_types.JsonProtoEncoder(
          extra_types.JsonArray(entries=msg.entry))

    encoding.RegisterCustomMessageCodec(
        encoder=_ToJson, decoder=None)(
            cls.msgs.KeySet.KeysValueListEntry)

    key_set = cls.msgs.KeySet(keys=[
        cls.msgs.KeySet.KeysValueListEntry(entry=table.GetJsonKeys(keys))
    ])

    return cls.msgs.Delete(table=table.name, keySet=key_set)