# -*- coding: utf-8 -*- # # Copyright 2020 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Cloud Datastream API utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals import uuid from apitools.base.py import encoding as api_encoding from googlecloudsdk.api_lib.dataproc import exceptions from googlecloudsdk.api_lib.datastream import camel_case_utils from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions from googlecloudsdk.api_lib.util import apis from googlecloudsdk.calliope import base from googlecloudsdk.command_lib.export import util as export_util from googlecloudsdk.core import resources from googlecloudsdk.core import yaml from googlecloudsdk.core.console import console_io import six _DEFAULT_API_VERSION = 'v1' _DEFAULT_API_NAME = 'datastream' CAMEL_CASE_CONVERSION_EVENT = _DEFAULT_API_NAME + '_camel_case_conversion' # TODO(b/207467120): remove translation after BETA deprecation. 
_UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1 = {
    'allowlist': 'include_objects',
    'rejectlist': 'exclude_objects',
    'source_connection_profile_name': 'source_connection_profile',
    'destination_connection_profile_name': 'destination_connection_profile',
    'vpc_name': 'vpc',
}

# Maps a v1 (GA) RDBMS field name to the name that field carries in each
# release track's file format (BETA still uses the legacy names).
RDBMS_FIELD_NAME_BY_RELEASE_TRACK = {
    'schema': {
        base.ReleaseTrack.BETA: 'schema_name',
        base.ReleaseTrack.GA: 'schema'
    },
    'database': {
        base.ReleaseTrack.BETA: 'database_name',
        base.ReleaseTrack.GA: 'database'
    },
    'table': {
        base.ReleaseTrack.BETA: 'table_name',
        base.ReleaseTrack.GA: 'table'
    },
    'column': {
        base.ReleaseTrack.BETA: 'column_name',
        base.ReleaseTrack.GA: 'column'
    },
    'include_objects': {
        base.ReleaseTrack.BETA: 'allowlist',
        base.ReleaseTrack.GA: 'include_objects'
    },
    'exclude_objects': {
        base.ReleaseTrack.BETA: 'rejectlist',
        base.ReleaseTrack.GA: 'exclude_objects'
    }
}


def _GetSchemaPath(schema_name):
  """Returns the path of schema `schema_name` for the default API version."""
  return export_util.GetSchemaPath(
      _DEFAULT_API_NAME, _DEFAULT_API_VERSION, schema_name, for_help=False
  )


def ParseJsonAndValidateSchema(parsed_json, schema_name, message_type):
  """Parses a config message from a parsed JSON and validates its schema.

  Args:
    parsed_json: dict, the already-parsed JSON content.
    schema_name: str, name of the schema to validate against.
    message_type: the Datastream message class to build.

  Returns:
    An instance of `message_type`.
  """
  return CreateMessageWithCamelCaseConversion(
      message_type=message_type,
      parsed_yaml=parsed_json,
      schema_path=_GetSchemaPath(schema_name),
  )


def ParseMessageAndValidateSchema(config_file_path, schema_name, message_type):
  """Parses a config message from a file and validates its schema.

  Args:
    config_file_path: str, path of the YAML/JSON file (or '-' style stdin
      input, as accepted by console_io.ReadFromFileOrStdin).
    schema_name: str, name of the schema to validate against.
    message_type: the Datastream message class to build.

  Returns:
    An instance of `message_type`.
  """
  schema_path = _GetSchemaPath(schema_name)
  # NOMUTANTS -- not necessary here.
  data = console_io.ReadFromFileOrStdin(config_file_path, binary=False)
  parsed_yaml = yaml.load(data)
  return CreateMessageWithCamelCaseConversion(
      message_type=message_type,
      parsed_yaml=parsed_yaml,
      schema_path=schema_path,
  )


def GetClientInstance(api_version=_DEFAULT_API_VERSION, no_http=False):
  """Returns the Datastream API client for `api_version`."""
  return apis.GetClientInstance('datastream', api_version, no_http=no_http)


def GetMessagesModule(api_version=_DEFAULT_API_VERSION):
  """Returns the Datastream messages module for `api_version`."""
  return apis.GetMessagesModule('datastream', api_version)


def GetResourceParser(api_version=_DEFAULT_API_VERSION):
  """Returns a new resource registry with the Datastream API registered."""
  resource_parser = resources.Registry()
  resource_parser.RegisterApiByName('datastream', api_version)
  return resource_parser


def ParentRef(project, location):
  """Get the resource name of the parent collection.

  Args:
    project: the project of the parent collection.
    location: the GCP region of the parent collection.

  Returns:
    the resource name of the parent collection in the format of
    `projects/{project}/locations/{location}`.
  """
  return 'projects/{}/locations/{}'.format(project, location)


def GenerateRequestId():
  """Generates a UUID to use as the request ID.

  Returns:
    string, the 36-character canonical form of a random (version 4) UUID,
    e.g. '123e4567-e89b-42d3-a456-426614174000'.
  """
  return six.text_type(uuid.uuid4())


def ParseMysqlRdbmsFile(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message.

  Non-BETA tracks use the schema-validated parser; BETA uses the legacy one.
  """
  if release_track != base.ReleaseTrack.BETA:
    return ParseMessageAndValidateSchema(
        mysql_rdbms_file, 'MysqlRdbms', messages.MysqlRdbms
    )
  return ParseMysqlRdbmsFileBeta(messages, mysql_rdbms_file, release_track)


def ParseOracleRdbmsFile(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses an oracle_rdbms_file into the OracleRdbms message.

  Non-BETA tracks use the schema-validated parser; BETA uses the legacy one.
  """
  if release_track != base.ReleaseTrack.BETA:
    return ParseMessageAndValidateSchema(
        oracle_rdbms_file, 'OracleRdbms', messages.OracleRdbms
    )
  return ParseOracleRdbmsFileBeta(messages, oracle_rdbms_file, release_track)


def ParsePostgresqlRdbmsFile(messages, postgresql_rdbms_file):
  """Parses a postgresql_rdbms_file into the PostgresqlRdbms message."""
  return ParseMessageAndValidateSchema(
      postgresql_rdbms_file, 'PostgresqlRdbms', messages.PostgresqlRdbms
  )


def ParseSqlServerRdbmsFile(messages, sqlserver_rdbms_file):
  """Parses a sqlserver_rdbms_file into the SqlServerRdbms message."""
  return ParseMessageAndValidateSchema(
      sqlserver_rdbms_file, 'SqlServerRdbms', messages.SqlServerRdbms
  )


def ParseSalesforceOrgFile(messages, salesforce_org_file):
  """Parses a salesforce_org_file into the SalesforceOrg message."""
  return ParseMessageAndValidateSchema(
      salesforce_org_file, 'SalesforceOrg', messages.SalesforceOrg
  )


def ParseMongodbFile(messages, mongodb_file):
  """Parses a mongodb_file into the MongodbCluster message."""
  return ParseMessageAndValidateSchema(
      mongodb_file, 'MongodbCluster', messages.MongodbCluster
  )


def CreateMessageWithCamelCaseConversion(
    message_type, parsed_yaml, schema_path=None
):
  """Create a message from a yaml dict.

  Similar to export_util.Import (since we convert to camel case before)

  Args:
    message_type: a Datastream message type to create.
    parsed_yaml: dict
    schema_path: str, path to the message schema to validate against.

  Returns:
    a Datastream message.

  Raises:
    exceptions.ParseError: (the dataproc exceptions module imported by this
      file) if the converted YAML cannot be turned into `message_type`.
  """
  converted_yaml = camel_case_utils.ConvertYamlToCamelCase(parsed_yaml)
  if schema_path:
    # If a schema is provided, validate against it.
    export_util.ValidateYAML(converted_yaml, schema_path)

  try:
    message = api_encoding.PyValueToMessage(message_type, converted_yaml)
  except Exception as e:
    raise exceptions.ParseError('Cannot parse YAML: [{0}]'.format(e))

  return message


# TODO(b/207467120): deprecate BETA client.
def GetRDBMSV1alpha1ToV1FieldName(field, release_track):
  """Returns the name `field` has in `release_track`'s format.

  Args:
    field: str, the v1 (GA) field name, e.g. 'schema'.
    release_track: base.ReleaseTrack to translate for.

  Returns:
    The translated name, or `field` unchanged when no translation is known.
  """
  return RDBMS_FIELD_NAME_BY_RELEASE_TRACK.get(field, {}).get(
      release_track, field
  )


def _GetRDBMSFieldName(field, release_track):
  """Private alias of GetRDBMSV1alpha1ToV1FieldName (kept for callers)."""
  return GetRDBMSV1alpha1ToV1FieldName(field, release_track)


# Deprecated BETA methods - TODO(b/207467120).
# remove after full BETA deprecation.

# Optional column attributes copied as-is from the raw YAML dict onto the
# column message when present: (yaml_key, message_attribute).
_MYSQL_COLUMN_OPTIONAL_FIELDS = (
    ('data_type', 'dataType'),
    ('collation', 'collation'),
    ('length', 'length'),
    ('nullable', 'nullable'),
    ('ordinal_position', 'ordinalPosition'),
    ('primary_key', 'primaryKey'),
)

_ORACLE_COLUMN_OPTIONAL_FIELDS = (
    ('data_type', 'dataType'),
    ('encoding', 'encoding'),
    ('length', 'length'),
    ('nullable', 'nullable'),
    ('ordinal_position', 'ordinalPosition'),
    ('precision', 'precision'),
    ('primary_key', 'primaryKey'),
    ('scale', 'scale'),
)

_POSTGRESQL_COLUMN_OPTIONAL_FIELDS = (
    ('data_type', 'dataType'),
    ('length', 'length'),
    ('precision', 'precision'),
    ('scale', 'scale'),
    ('primary_key', 'primaryKey'),
    ('nullable', 'nullable'),
    ('ordinal_position', 'ordinalPosition'),
)


def _CopyOptionalFields(raw_object, message, field_pairs):
  """Sets each non-None raw_object[yaml_key] onto message.<message_attr>.

  Args:
    raw_object: dict, the raw YAML/JSON object.
    message: the message instance to populate.
    field_pairs: iterable of (yaml_key, message_attr) pairs.

  Returns:
    `message`, for convenience.
  """
  for yaml_key, message_attr in field_pairs:
    value = raw_object.get(yaml_key)
    if value is not None:
      setattr(message, message_attr, value)
  return message


def _GetList(raw_object, key):
  """Returns raw_object[key], treating a missing or explicit-null key as [].

  A YAML key with no value (e.g. `mysql_columns:`) loads as None, which
  `.get(key, [])` would pass through and then fail to iterate.
  """
  return raw_object.get(key) or []


def ParseMysqlColumn(messages, mysql_column_object, release_track):
  """Parses a raw mysql column json/yaml into the MysqlColumn message."""
  message = messages.MysqlColumn(
      column=mysql_column_object.get(
          _GetRDBMSFieldName('column', release_track), ''))
  return _CopyOptionalFields(
      mysql_column_object, message, _MYSQL_COLUMN_OPTIONAL_FIELDS)


def ParseMysqlTable(messages, mysql_table_object, release_track):
  """Parses a raw mysql table json/yaml into the MysqlTable message.

  Raises:
    ds_exceptions.ParseError: if the table name key is missing or empty.
  """
  mysql_column_msg_list = [
      ParseMysqlColumn(messages, column, release_track)
      for column in _GetList(mysql_table_object, 'mysql_columns')
  ]
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = mysql_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   table_key)
  return messages.MysqlTable(
      table=table_name, mysqlColumns=mysql_column_msg_list)


def ParseMysqlDatabase(messages, mysql_database_object, release_track):
  """Parses a raw mysql database json/yaml into the MysqlDatabase message.

  Raises:
    ds_exceptions.ParseError: if the database name key is missing or empty.
  """
  mysql_tables_msg_list = [
      ParseMysqlTable(messages, table, release_track)
      for table in _GetList(mysql_database_object, 'mysql_tables')
  ]
  database_key = _GetRDBMSFieldName('database', release_track)
  database_name = mysql_database_object.get(database_key)
  if not database_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   database_key)
  return messages.MysqlDatabase(
      database=database_name, mysqlTables=mysql_tables_msg_list)


def ParseMysqlSchemasListToMysqlRdbmsMessage(messages,
                                             mysql_rdbms_data,
                                             release_track=base.ReleaseTrack
                                             .BETA):
  """Parses an object of type {mysql_databases: [...]} into the MysqlRdbms message."""
  mysql_database_msg_list = [
      ParseMysqlDatabase(messages, database, release_track)
      for database in _GetList(mysql_rdbms_data, 'mysql_databases')
  ]
  return messages.MysqlRdbms(mysqlDatabases=mysql_database_msg_list)


def ParseOracleColumn(messages, oracle_column_object, release_track):
  """Parses a raw oracle column json/yaml into the OracleColumn message."""
  message = messages.OracleColumn(
      column=oracle_column_object.get(
          _GetRDBMSFieldName('column', release_track), ''))
  return _CopyOptionalFields(
      oracle_column_object, message, _ORACLE_COLUMN_OPTIONAL_FIELDS)


def ParseOracleTable(messages, oracle_table_object, release_track):
  """Parses a raw oracle table json/yaml into the OracleTable message.

  Raises:
    ds_exceptions.ParseError: if the table name key is missing or empty.
  """
  oracle_columns_msg_list = [
      ParseOracleColumn(messages, column, release_track)
      for column in _GetList(oracle_table_object, 'oracle_columns')
  ]
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = oracle_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   table_key)
  return messages.OracleTable(
      table=table_name, oracleColumns=oracle_columns_msg_list)


def ParseOracleSchema(messages, oracle_schema_object, release_track):
  """Parses a raw oracle schema json/yaml into the OracleSchema message.

  Raises:
    ds_exceptions.ParseError: if the schema name key is missing or empty.
  """
  oracle_tables_msg_list = [
      ParseOracleTable(messages, table, release_track)
      for table in _GetList(oracle_schema_object, 'oracle_tables')
  ]
  schema_key = _GetRDBMSFieldName('schema', release_track)
  schema_name = oracle_schema_object.get(schema_key)
  if not schema_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   schema_key)
  return messages.OracleSchema(
      schema=schema_name, oracleTables=oracle_tables_msg_list)


def ParseOracleSchemasListToOracleRdbmsMessage(messages,
                                               oracle_rdbms_data,
                                               release_track=base.ReleaseTrack
                                               .BETA):
  """Parses an object of type {oracle_schemas: [...]} into the OracleRdbms message."""
  oracle_schema_msg_list = [
      ParseOracleSchema(messages, schema, release_track)
      for schema in _GetList(oracle_rdbms_data, 'oracle_schemas')
  ]
  return messages.OracleRdbms(oracleSchemas=oracle_schema_msg_list)


def ParsePostgresqlColumn(messages, postgresql_column_object):
  """Parses a raw postgresql column json/yaml into the PostgresqlColumn message."""
  message = messages.PostgresqlColumn(
      column=postgresql_column_object.get('column', ''))
  return _CopyOptionalFields(
      postgresql_column_object, message, _POSTGRESQL_COLUMN_OPTIONAL_FIELDS)


def ParsePostgresqlTable(messages, postgresql_table_object):
  """Parses a raw postgresql table json/yaml into the PostgresqlTable message.

  Raises:
    ds_exceptions.ParseError: if the "table" key is missing or empty.
  """
  postgresql_columns_msg_list = [
      ParsePostgresqlColumn(messages, column)
      for column in _GetList(postgresql_table_object, 'postgresql_columns')
  ]
  table_name = postgresql_table_object.get('table')
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "table".')
  return messages.PostgresqlTable(
      table=table_name, postgresqlColumns=postgresql_columns_msg_list)


def ParsePostgresqlSchema(messages, postgresql_schema_object):
  """Parses a raw postgresql schema json/yaml into the PostgresqlSchema message.

  Raises:
    ds_exceptions.ParseError: if the "schema" key is missing or empty.
  """
  postgresql_tables_msg_list = [
      ParsePostgresqlTable(messages, table)
      for table in _GetList(postgresql_schema_object, 'postgresql_tables')
  ]
  schema_name = postgresql_schema_object.get('schema')
  if not schema_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "schema".')
  return messages.PostgresqlSchema(
      schema=schema_name, postgresqlTables=postgresql_tables_msg_list)


def ParsePostgresqlSchemasListToPostgresqlRdbmsMessage(messages,
                                                       postgresql_rdbms_data):
  """Parses an object of type {postgresql_schemas: [...]} into the PostgresqlRdbms message."""
  postgresql_schema_msg_list = [
      ParsePostgresqlSchema(messages, schema)
      for schema in _GetList(postgresql_rdbms_data, 'postgresql_schemas')
  ]
  return messages.PostgresqlRdbms(postgresqlSchemas=postgresql_schema_msg_list)


def ParseMongodbField(messages, mongodb_field_object):
  """Parses a raw mongodb field json/yaml into the MongodbField message."""
  return messages.MongodbField(field=mongodb_field_object.get('field', ''))


def ParseMongodbCollection(messages, mongodb_collection_object):
  """Parses a raw mongodb collection json/yaml into the MongodbCollection message."""
  fields_msg_list = [
      ParseMongodbField(messages, field)
      for field in _GetList(mongodb_collection_object, 'fields')
  ]
  return messages.MongodbCollection(
      collection=mongodb_collection_object.get('collection', ''),
      fields=fields_msg_list,
  )


def ParseMongodbDatabase(messages, mongodb_database_object):
  """Parses a raw mongodb database json/yaml into the MongodbDatabase message."""
  collections_msg_list = [
      ParseMongodbCollection(messages, collection)
      for collection in _GetList(mongodb_database_object, 'collections')
  ]
  return messages.MongodbDatabase(collections=collections_msg_list)


def ParseMongodbCluster(messages, mongodb_cluster_object):
  """Parses a raw mongodb cluster json/yaml into the MongodbCluster message."""
  databases_msg_list = [
      ParseMongodbDatabase(messages, database)
      for database in _GetList(mongodb_cluster_object, 'databases')
  ]
  return messages.MongodbCluster(databases=databases_msg_list)


def UpdateV1alpha1ToV1MaskFields(field_mask):
  """Updates field mask paths according to the v1alpha1 > v1 Datastream API change.

  This allows for backwards compatibility with the current client field mask.

  Args:
    field_mask: List[str], list of stream fields to update

  Returns:
    updated_field_mask: List[str] field mask with fields translated
      from v1alpha1 API to v1.
  """
  updated_field_mask = []
  for path in field_mask:
    # Substring match; only the first legacy name found (in dict order) is
    # translated, and every occurrence of it in the path is replaced.
    field_to_translate = next(
        (field for field in _UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1
         if field in path),
        None)
    if field_to_translate:
      path = path.replace(
          field_to_translate,
          _UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1[field_to_translate])
    updated_field_mask.append(path)
  return updated_field_mask


def _LoadBetaRdbmsData(rdbms_file, root_key):
  """Reads a legacy (BETA) RDBMS YAML file and returns its RDBMS mapping.

  The content may either be nested under `root_key` or be at the top level.

  Args:
    rdbms_file: str, path passed to console_io.ReadFromFileOrStdin.
    root_key: str, optional wrapping key, e.g. 'mysql_rdbms'.

  Returns:
    dict, the RDBMS mapping.

  Raises:
    ds_exceptions.ParseError: if the YAML is malformed, or the file (or its
      `root_key` value) is not a mapping, e.g. an empty file.
  """
  data = console_io.ReadFromFileOrStdin(rdbms_file, binary=False)
  try:
    rdbms_data = yaml.load(data)
  except Exception as e:  # pylint: disable=broad-except
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
  if isinstance(rdbms_data, dict):
    rdbms_data = rdbms_data.get(root_key, rdbms_data)
  if not isinstance(rdbms_data, dict):
    # Empty files load as None and lists/scalars have no `.get`; report a
    # parse error instead of crashing with AttributeError downstream.
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: expected a mapping, got [{0}].'.format(
            type(rdbms_data).__name__))
  return rdbms_data


def ParseMysqlRdbmsFileBeta(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message. deprecated."""
  mysql_rdbms_data = _LoadBetaRdbmsData(mysql_rdbms_file, 'mysql_rdbms')
  return ParseMysqlSchemasListToMysqlRdbmsMessage(
      messages, mysql_rdbms_data, release_track
  )


def ParseOracleRdbmsFileBeta(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses an oracle_rdbms_file into the OracleRdbms message. deprecated."""
  oracle_rdbms_data = _LoadBetaRdbmsData(oracle_rdbms_file, 'oracle_rdbms')
  return ParseOracleSchemasListToOracleRdbmsMessage(
      messages, oracle_rdbms_data, release_track
  )