feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for camel case/snake case conversions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import re
from googlecloudsdk.core import metrics
import six
# API name used to namespace metric identifiers emitted by this module.
_DEFAULT_API_NAME = 'datastream'
# Metrics event name recorded each time a snake_case key is converted.
CAMEL_CASE_CONVERSION_EVENT = _DEFAULT_API_NAME + '_camel_case_conversion'
def ConvertYamlToCamelCase(yaml_dict):
  """Recursively converts all dict keys from snake_case to camelCase.

  Adapted from:
  https://stackoverflow.com/questions/11700705/how-to-recursively-replace-character-in-keys-of-a-nested-dictionary.

  Args:
    yaml_dict: dict of loaded yaml

  Returns:
    A converted dict with camelCase keys
  """
  # NOMUTANTS -- not necessary here.
  # Scalars are returned untouched.
  if isinstance(yaml_dict, (str, int, float)):
    return yaml_dict
  if isinstance(yaml_dict, dict):
    # Preserve the concrete mapping type (e.g. OrderedDict) of the input.
    converted = yaml_dict.__class__()
    for key, value in yaml_dict.items():
      converted[SnakeToCamelCase(key)] = ConvertYamlToCamelCase(value)
    return converted
  if isinstance(yaml_dict, (list, set, tuple)):
    # Preserve the concrete sequence type; only values are converted.
    return yaml_dict.__class__(
        ConvertYamlToCamelCase(item) for item in yaml_dict)
  # Any other type is passed through unchanged.
  return yaml_dict
def SnakeToCamelCase(value):
  """Converts value from snake_case to camelCase."""
  # Only values that actually look like snake_case (letters_letters) are
  # converted; anything else is returned unchanged.
  if not re.match(r'[a-zA-Z]+_[a-zA-Z]+', value):
    return value
  # Lowercase everything, then drop a single leading '-', '_' or '.'.
  lowered = re.sub(r'^[\-_\.]', '', six.text_type(value.lower()))
  if not lowered:
    return lowered
  # Record snake to camel case conversion (for tracking purposes)
  metrics.CustomTimedEvent(CAMEL_CASE_CONVERSION_EVENT)

  def _Capitalize(matched):
    return matched.group(1).upper()

  # Remove each separator and uppercase the letter that follows it.
  return lowered[0].lower() + re.sub(
      r'[\-_\.\s]([a-z])', _Capitalize, lowered[1:])

View File

@@ -0,0 +1,999 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream connection profiles API."""
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.datastream import util
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.calliope.arg_parsers import HostPort
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
def GetConnectionProfileURI(resource):
  """Returns the URI (self link) of the given connection profile resource."""
  profile_ref = resources.REGISTRY.ParseRelativeName(
      resource.name,
      collection='datastream.projects.locations.connectionProfiles')
  return profile_ref.SelfLink()
class ConnectionProfilesClient:
"""Client for connection profiles service in the API."""
def __init__(self, client=None, messages=None):
self._client = client or util.GetClientInstance()
self._messages = messages or util.GetMessagesModule()
self._service = self._client.projects_locations_connectionProfiles
self._resource_parser = util.GetResourceParser()
  def _ValidateArgs(self, args):
    """Validates parsed command arguments before building API messages."""
    self._ValidateSslConfigArgs(args)
def _ValidateSslConfigArgs(self, args):
"""Validates Format of all SSL config args."""
self._ValidateCertificateFormat(args.ca_certificate, 'CA certificate')
self._ValidateCertificateFormat(args.client_certificate,
'client certificate')
self._ValidateCertificateFormat(args.client_key, 'client key')
# Validation for all Postgresql SSL config fields.
self._ValidateCertificateFormat(
args.postgresql_ca_certificate, 'Postgresql CA certificate'
)
self._ValidateCertificateFormat(
args.postgresql_client_certificate, 'Postgresql client certificate'
)
self._ValidateCertificateFormat(
args.postgresql_client_key, 'Postgresql client private key'
)
# Validation for Oracle SSL config fields.
self._ValidateCertificateFormat(
args.oracle_ca_certificate, 'Oracle CA certificate'
)
def _ValidateCertificateFormat(self, certificate, name):
if not certificate:
return True
cert = certificate.strip()
cert_lines = cert.split('\n')
if (not cert_lines[0].startswith('-----') or
not cert_lines[-1].startswith('-----')):
raise exceptions.InvalidArgumentException(
name,
'The certificate does not appear to be in PEM format: \n{0}'.format(
cert))
def _GetSslConfig(self, args):
return self._messages.MysqlSslConfig(
clientKey=args.client_key,
clientCertificate=args.client_certificate,
caCertificate=args.ca_certificate)
def _GetMySqlProfile(self, args):
ssl_config = self._GetSslConfig(args)
return self._messages.MysqlProfile(
hostname=args.mysql_hostname,
port=args.mysql_port,
username=args.mysql_username,
password=args.mysql_password,
secretManagerStoredPassword=args.mysql_secret_manager_stored_password,
sslConfig=ssl_config)
def _GetOracleProfile(self, args):
ssl_config = self._GetOracleSslConfig(args)
return self._messages.OracleProfile(
hostname=args.oracle_hostname,
port=args.oracle_port,
username=args.oracle_username,
password=args.oracle_password,
secretManagerStoredPassword=args.oracle_secret_manager_stored_password,
databaseService=args.database_service,
oracleSslConfig=ssl_config)
def _GetOracleSslConfig(self, args):
"""Returns a OracleSslConfig message based on the given args."""
return self._messages.OracleSslConfig(
caCertificate=args.oracle_ca_certificate,
serverCertificateDistinguishedName=args.oracle_server_certificate_distinguished_name,
)
def _GetPostgresqlSslConfig(self, args):
"""Returns a PostgresqlSslConfig message based on the given args."""
if args.postgresql_client_certificate or args.postgresql_client_key:
return self._messages.PostgresqlSslConfig(
serverAndClientVerification=self._messages.ServerAndClientVerification(
clientCertificate=args.postgresql_client_certificate,
clientKey=args.postgresql_client_key,
caCertificate=args.postgresql_ca_certificate,
serverCertificateHostname=args.postgresql_server_certificate_hostname,
)
)
if args.postgresql_ca_certificate:
return self._messages.PostgresqlSslConfig(
serverVerification=self._messages.ServerVerification(
caCertificate=args.postgresql_ca_certificate,
serverCertificateHostname=args.postgresql_server_certificate_hostname,
)
)
return None
def _GetPostgresqlProfile(self, args):
ssl_config = self._GetPostgresqlSslConfig(args)
return self._messages.PostgresqlProfile(
hostname=args.postgresql_hostname,
port=args.postgresql_port,
username=args.postgresql_username,
password=args.postgresql_password,
secretManagerStoredPassword=args.postgresql_secret_manager_stored_password,
database=args.postgresql_database,
sslConfig=ssl_config)
def _GetSqlServerProfile(self, args):
return self._messages.SqlServerProfile(
hostname=args.sqlserver_hostname,
port=args.sqlserver_port,
username=args.sqlserver_username,
password=args.sqlserver_password,
secretManagerStoredPassword=args.sqlserver_secret_manager_stored_password,
database=args.sqlserver_database,
)
def _GetSalesforceProfile(self, args):
if args.salesforce_oauth2_client_id:
return self._messages.SalesforceProfile(
domain=args.salesforce_domain,
oauth2ClientCredentials=self._messages.Oauth2ClientCredentials(
clientId=args.salesforce_oauth2_client_id,
clientSecret=args.salesforce_oauth2_client_secret,
secretManagerStoredClientSecret=args.salesforce_secret_manager_stored_oauth2_client_secret,
),
)
else:
return self._messages.SalesforceProfile(
domain=args.salesforce_domain,
userCredentials=self._messages.UserCredentials(
username=args.salesforce_username,
password=args.salesforce_password,
secretManagerStoredPassword=args.salesforce_secret_manager_stored_password,
securityToken=args.salesforce_security_token,
secretManagerStoredSecurityToken=args.salesforce_secret_manager_stored_security_token,
),
)
def _GetGCSProfile(self, args, release_track):
# TODO(b/207467120): remove bucket_name arg check.
if release_track == base.ReleaseTrack.BETA:
bucket = args.bucket_name
else:
bucket = args.bucket
gcs_profile = self._messages.GcsProfile(bucket=bucket)
gcs_profile.rootPath = args.root_path if args.root_path else '/'
return gcs_profile
  def _GetMongodbProfile(self, args):
    """Returns the MongoDB profile message based on the given args.

    Args:
      args: argparse.Namespace, parsed MongoDB connection flags.

    Returns:
      A MongodbProfile message.

    Raises:
      exceptions.InvalidArgumentException: if --mongodb-direct-connection is
        used without the standard connection format.
    """
    addresses = []
    for host_address in args.mongodb_host_addresses:
      if args.mongodb_srv_connection_format:
        # SRV format takes bare hostnames (no port component).
        addresses.append(
            self._messages.HostAddress(hostname=host_address)
        )
      else:
        hostport = HostPort.Parse(host_address)
        addresses.append(
            self._messages.HostAddress(
                hostname=hostport.host, port=int(hostport.port)
            )
        )
    profile = self._messages.MongodbProfile(
        hostAddresses=addresses,
        username=args.mongodb_username,
        replicaSet=args.mongodb_replica_set,
        password=args.mongodb_password,
        secretManagerStoredPassword=args.mongodb_secret_manager_stored_password,
    )
    # Direct connection is only meaningful with the standard format.
    if (
        args.mongodb_direct_connection
        and not args.mongodb_standard_connection_format
    ):
      raise exceptions.InvalidArgumentException(
          'mongodb-direct-connection',
          'mongodb direct connection can only be used with the standard'
          ' connection format.',
      )
    if args.mongodb_srv_connection_format:
      profile.srvConnectionFormat = {}
    if args.mongodb_standard_connection_format:
      profile.standardConnectionFormat = (
          self._messages.StandardConnectionFormat(
              directConnection=args.mongodb_direct_connection
          )
      )
    if args.mongodb_tls:
      profile.sslConfig = {}
    if args.mongodb_ca_certificate:
      # NOTE(review): this assumes --mongodb-tls was also set so sslConfig was
      # populated above; a CA certificate without TLS would hit a None
      # sslConfig here — confirm flag validation upstream prevents that.
      profile.sslConfig.caCertificate = args.mongodb_ca_certificate
    return profile
def _ParseSslConfig(self, data):
return self._messages.MysqlSslConfig(
clientKey=data.get('client_key'),
clientCertificate=data.get('client_certificate'),
caCertificate=data.get('ca_certificate'))
def _ParseMySqlProfile(self, data):
if not data:
return {}
ssl_config = self._ParseSslConfig(data)
return self._messages.MysqlProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
sslConfig=ssl_config)
def _ParseOracleProfile(self, data):
if not data:
return {}
return self._messages.OracleProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
databaseService=data.get('database_service'))
def _ParsePostgresqlProfile(self, data):
if not data:
return {}
return self._messages.PostgresqlProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
database=data.get('database'))
def _ParseSqlServerProfile(self, data):
if not data:
return {}
return self._messages.SqlServerProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
database=data.get('database'),
)
def _ParseGCSProfile(self, data):
if not data:
return {}
return self._messages.GcsProfile(
bucket=data.get('bucket_name'), rootPath=data.get('root_path'))
def _GetForwardSshTunnelConnectivity(self, args):
return self._messages.ForwardSshTunnelConnectivity(
hostname=args.forward_ssh_hostname,
port=args.forward_ssh_port,
username=args.forward_ssh_username,
privateKey=args.forward_ssh_private_key,
password=args.forward_ssh_password)
  def _GetConnectionProfile(self, cp_type, connection_profile_id, args,
                            release_track):
    """Returns a connection profile according to type.

    Args:
      cp_type: str, connection profile type ('MYSQL', 'ORACLE', 'POSTGRESQL',
        'SQLSERVER', 'GOOGLE-CLOUD-STORAGE', 'BIGQUERY', 'SALESFORCE' or
        'MONGODB').
      connection_profile_id: str, name for the new connection profile.
      args: argparse.Namespace, parsed command arguments.
      release_track: base.ReleaseTrack, determines which flags are read.

    Returns:
      A ConnectionProfile message with its typed profile and one connectivity
      method populated.

    Raises:
      exceptions.InvalidArgumentException: if cp_type is not supported.
    """
    labels = labels_util.ParseCreateArgs(
        args, self._messages.ConnectionProfile.LabelsValue)
    connection_profile_obj = self._messages.ConnectionProfile(
        name=connection_profile_id, labels=labels,
        displayName=args.display_name)
    # Populate exactly one typed profile field, selected by cp_type.
    if cp_type == 'MYSQL':
      connection_profile_obj.mysqlProfile = self._GetMySqlProfile(args)
    elif cp_type == 'ORACLE':
      connection_profile_obj.oracleProfile = self._GetOracleProfile(args)
    elif cp_type == 'POSTGRESQL':
      connection_profile_obj.postgresqlProfile = self._GetPostgresqlProfile(
          args)
    elif cp_type == 'SQLSERVER':
      connection_profile_obj.sqlServerProfile = self._GetSqlServerProfile(args)
    elif cp_type == 'GOOGLE-CLOUD-STORAGE':
      connection_profile_obj.gcsProfile = self._GetGCSProfile(
          args, release_track)
    elif cp_type == 'BIGQUERY':
      # BigQuery profiles carry no parameters.
      connection_profile_obj.bigqueryProfile = self._messages.BigQueryProfile()
    elif cp_type == 'SALESFORCE':
      connection_profile_obj.salesforceProfile = self._GetSalesforceProfile(
          args
      )
    elif cp_type == 'MONGODB':
      connection_profile_obj.mongodbProfile = self._GetMongodbProfile(args)
    else:
      raise exceptions.InvalidArgumentException(
          cp_type,
          'The connection profile type {0} is either unknown or not supported'
          ' yet.'.format(cp_type),
      )
    # TODO(b/207467120): deprecate BETA client.
    # BETA uses a differently-named private connection concept flag.
    if release_track == base.ReleaseTrack.BETA:
      private_connectivity_ref = args.CONCEPTS.private_connection_name.Parse()
    else:
      private_connectivity_ref = args.CONCEPTS.private_connection.Parse()
    # Connectivity precedence: private connection, then forward SSH tunnel,
    # then static service IPs.
    if private_connectivity_ref:
      connection_profile_obj.privateConnectivity = (
          self._messages.PrivateConnectivity(
              privateConnection=private_connectivity_ref.RelativeName()
          )
      )
    elif args.forward_ssh_hostname:
      connection_profile_obj.forwardSshConnectivity = (
          self._GetForwardSshTunnelConnectivity(args)
      )
    elif args.static_ip_connectivity:
      connection_profile_obj.staticServiceIpConnectivity = {}
    return connection_profile_obj
  def _ParseConnectionProfileObjectFile(
      self, connection_profile_object_file, release_track
  ):
    """Parses a connection-profile-file into the ConnectionProfile message.

    Args:
      connection_profile_object_file: str, path to the YAML file (or '-' for
        stdin) describing the connection profile.
      release_track: base.ReleaseTrack, selects the parsing implementation.

    Returns:
      A ConnectionProfile message.

    Raises:
      ds_exceptions.ParseError: if the YAML cannot be parsed or no
        connectivity method is present (BETA path only).
    """
    # Non-BETA tracks delegate to schema-validated parsing.
    if release_track != base.ReleaseTrack.BETA:
      return util.ParseMessageAndValidateSchema(
          connection_profile_object_file,
          'ConnectionProfile',
          self._messages.ConnectionProfile,
      )
    # Legacy BETA path: hand-rolled YAML parsing without schema validation.
    data = console_io.ReadFromFileOrStdin(
        connection_profile_object_file, binary=False)
    try:
      connection_profile_data = yaml.load(data)
    except Exception as e:
      raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
    display_name = connection_profile_data.get('display_name')
    labels = connection_profile_data.get('labels')
    connection_profile_msg = self._messages.ConnectionProfile(
        displayName=display_name,
        labels=labels)
    oracle_profile = self._ParseOracleProfile(
        connection_profile_data.get('oracle_profile', {}))
    mysql_profile = self._ParseMySqlProfile(
        connection_profile_data.get('mysql_profile', {}))
    postgresql_profile = self._ParsePostgresqlProfile(
        connection_profile_data.get('postgresql_profile', {}))
    sqlserver_profile = self._ParseSqlServerProfile(
        connection_profile_data.get('sqlserver_profile', {})
    )
    gcs_profile = self._ParseGCSProfile(
        connection_profile_data.get('gcs_profile', {}))
    # Attach the first non-empty profile; order defines precedence when the
    # file contains more than one profile section.
    if oracle_profile:
      connection_profile_msg.oracleProfile = oracle_profile
    elif mysql_profile:
      connection_profile_msg.mysqlProfile = mysql_profile
    elif postgresql_profile:
      connection_profile_msg.postgresqlProfile = postgresql_profile
    elif sqlserver_profile:
      connection_profile_msg.sqlServerProfile = sqlserver_profile
    elif gcs_profile:
      connection_profile_msg.gcsProfile = gcs_profile
    # At least one connectivity method must be present; the first match wins.
    if 'static_service_ip_connectivity' in connection_profile_data:
      connection_profile_msg.staticServiceIpConnectivity = (
          connection_profile_data.get('static_service_ip_connectivity')
      )
    elif 'forward_ssh_connectivity' in connection_profile_data:
      connection_profile_msg.forwardSshConnectivity = (
          connection_profile_data.get('forward_ssh_connectivity')
      )
    elif 'private_connectivity' in connection_profile_data:
      connection_profile_msg.privateConnectivity = connection_profile_data.get(
          'private_connectivity'
      )
    else:
      raise ds_exceptions.ParseError(
          'Cannot parse YAML: missing connectivity method.'
      )
    return connection_profile_msg
def _UpdateForwardSshTunnelConnectivity(
self, connection_profile, args, update_fields
):
"""Updates Forward SSH tunnel connectivity config."""
if args.IsSpecified('forward_ssh_hostname'):
connection_profile.forwardSshConnectivity.hostname = (
args.forward_ssh_hostname
)
update_fields.append('forwardSshConnectivity.hostname')
if args.IsSpecified('forward_ssh_port'):
connection_profile.forwardSshConnectivity.port = args.forward_ssh_port
update_fields.append('forwardSshConnectivity.port')
if args.IsSpecified('forward_ssh_username'):
connection_profile.forwardSshConnectivity.username = (
args.forward_ssh_username
)
update_fields.append('forwardSshConnectivity.username')
if args.IsSpecified('forward_ssh_private_key'):
connection_profile.forwardSshConnectivity.privateKey = (
args.forward_ssh_private_key
)
update_fields.append('forwardSshConnectivity.privateKey')
if args.IsSpecified('forward_ssh_password'):
connection_profile.forwardSshConnectivity.privateKey = (
args.forward_ssh_password
)
update_fields.append('forwardSshConnectivity.password')
def _UpdateGCSProfile(
self, connection_profile, release_track, args, update_fields
):
"""Updates GOOGLE CLOUD STORAGE connection profile."""
# TODO(b/207467120): remove bucket_name arg check.
if release_track == base.ReleaseTrack.BETA and args.IsSpecified(
'bucket_name'
):
connection_profile.gcsProfile.bucket = args.bucket_name
update_fields.append('gcsProfile.bucket')
if release_track == base.ReleaseTrack.GA and args.IsSpecified('bucket'):
connection_profile.gcsProfile.bucket = args.bucket
update_fields.append('gcsProfile.bucket')
if args.IsSpecified('root_path'):
connection_profile.gcsProfile.rootPath = args.root_path
update_fields.append('gcsProfile.rootPath')
def _UpdateOracleProfile(self,
connection_profile,
args,
update_fields):
"""Updates Oracle connection profile."""
if args.IsSpecified('oracle_hostname'):
connection_profile.oracleProfile.hostname = args.oracle_hostname
update_fields.append('oracleProfile.hostname')
if args.IsSpecified('oracle_port'):
connection_profile.oracleProfile.port = args.oracle_port
update_fields.append('oracleProfile.port')
if args.IsSpecified('oracle_username'):
connection_profile.oracleProfile.username = args.oracle_username
update_fields.append('oracleProfile.username')
if args.IsSpecified('oracle_password') or args.IsSpecified(
'oracle_secret_manager_stored_password'
):
connection_profile.oracleProfile.password = args.oracle_password
connection_profile.oracleProfile.secretManagerStoredPassword = (
args.oracle_secret_manager_stored_password
)
update_fields.append('oracleProfile.password')
update_fields.append('oracleProfile.secretManagerStoredPassword')
if args.IsSpecified('database_service'):
connection_profile.oracleProfile.databaseService = args.database_service
update_fields.append('oracleProfile.databaseService')
def _UpdateMysqlSslConfig(self, connection_profile, args, update_fields):
"""Updates Mysql SSL config."""
if args.IsSpecified('client_key'):
connection_profile.mysqlProfile.sslConfig.clientKey = args.client_key
update_fields.append('mysqlProfile.sslConfig.clientKey')
if args.IsSpecified('client_certificate'):
connection_profile.mysqlProfile.sslConfig.clientCertificate = (
args.client_certificate
)
update_fields.append('mysqlProfile.sslConfig.clientCertificate')
if args.IsSpecified('ca_certificate'):
connection_profile.mysqlProfile.sslConfig.caCertificate = (
args.ca_certificate
)
update_fields.append('mysqlProfile.sslConfig.caCertificate')
def _UpdateMySqlProfile(self, connection_profile, args, update_fields):
"""Updates MySQL connection profile."""
if args.IsSpecified('mysql_hostname'):
connection_profile.mysqlProfile.hostname = args.mysql_hostname
update_fields.append('mysqlProfile.hostname')
if args.IsSpecified('mysql_port'):
connection_profile.mysqlProfile.port = args.mysql_port
update_fields.append('mysqlProfile.port')
if args.IsSpecified('mysql_username'):
connection_profile.mysqlProfile.username = args.mysql_username
update_fields.append('mysqlProfile.username')
if args.IsSpecified('mysql_password') or args.IsSpecified(
'mysql_secret_manager_stored_password'
):
connection_profile.mysqlProfile.password = args.mysql_password
connection_profile.mysqlProfile.secretManagerStoredPassword = (
args.mysql_secret_manager_stored_password
)
update_fields.append('mysqlProfile.password')
update_fields.append('mysqlProfile.secretManagerStoredPassword')
self._UpdateMysqlSslConfig(connection_profile, args, update_fields)
  def _UpdatePostgresqlSslConfig(self, connection_profile, args, update_fields):
    """Updates Postgresql SSL config.

    Mutates connection_profile in place and appends the touched field-mask
    paths to update_fields. CA certificate and server certificate hostname
    updates are routed to serverAndClientVerification when that sub-message
    is already populated on the profile, and to serverVerification otherwise.

    NOTE(review): assumes sslConfig (and the relevant verification
    sub-message) already exists on the fetched profile — confirm; setting an
    attribute on a missing sub-message would fail here.
    """
    if args.IsSpecified('postgresql_client_certificate'):
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate = (
          args.postgresql_client_certificate
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate'
      )
    if args.IsSpecified('postgresql_client_key'):
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientKey = (
          args.postgresql_client_key
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.clientKey'
      )
    if args.IsSpecified('postgresql_ca_certificate'):
      # Route to whichever verification mode the existing profile uses.
      if connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification:
        connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate = (
            args.postgresql_ca_certificate
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate'
        )
      else:
        connection_profile.postgresqlProfile.sslConfig.serverVerification.caCertificate = (
            args.postgresql_ca_certificate
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverVerification.caCertificate'
        )
    if args.IsSpecified('postgresql_server_certificate_hostname'):
      # Same routing as the CA certificate above.
      if (
          connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification
      ):
        connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname = (
            args.postgresql_server_certificate_hostname
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname'
        )
      else:
        connection_profile.postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname = (
            args.postgresql_server_certificate_hostname
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname'
        )
def _UpdatePostgresqlProfile(self, connection_profile, args, update_fields):
"""Updates Postgresql connection profile."""
if args.IsSpecified('postgresql_hostname'):
connection_profile.postgresqlProfile.hostname = args.postgresql_hostname
update_fields.append('postgresqlProfile.hostname')
if args.IsSpecified('postgresql_port'):
connection_profile.postgresqlProfile.port = args.postgresql_port
update_fields.append('postgresqlProfile.port')
if args.IsSpecified('postgresql_username'):
connection_profile.postgresqlProfile.username = args.postgresql_username
update_fields.append('postgresqlProfile.username')
if args.IsSpecified('postgresql_password') or args.IsSpecified(
'postgresql_secret_manager_stored_password'
):
connection_profile.postgresqlProfile.password = args.postgresql_password
connection_profile.postgresqlProfile.secretManagerStoredPassword = (
args.postgresql_secret_manager_stored_password
)
update_fields.append('postgresqlProfile.password')
update_fields.append('postgresqlProfile.secretManagerStoredPassword')
if args.IsSpecified('postgresql_database'):
connection_profile.postgresqlProfile.database = args.postgresql_database
update_fields.append('postgresqlProfile.database')
self._UpdatePostgresqlSslConfig(connection_profile, args, update_fields)
def _UpdateSqlServerProfile(self, connection_profile, args, update_fields):
"""Updates SqlServer connection profile."""
if args.IsSpecified('sqlserver_hostname'):
connection_profile.sqlServerProfile.hostname = args.sqlserver_hostname
update_fields.append('sqlServerProfile.hostname')
if args.IsSpecified('sqlserver_port'):
connection_profile.sqlServerProfile.port = args.sqlserver_port
update_fields.append('sqlServerProfile.port')
if args.IsSpecified('sqlserver_username'):
connection_profile.sqlServerProfile.username = args.sqlserver_username
update_fields.append('sqlServerProfile.username')
if args.IsSpecified('sqlserver_password') or args.IsSpecified(
'sqlserver_secret_manager_stored_password'
):
connection_profile.sqlServerProfile.password = args.sqlserver_password
connection_profile.sqlServerProfile.secretManagerStoredPassword = (
args.sqlserver_secret_manager_stored_password
)
update_fields.append('sqlServerProfile.password')
update_fields.append('sqlServerProfile.secretManagerStoredPassword')
if args.IsSpecified('sqlserver_database'):
connection_profile.sqlServerProfile.database = args.sqlserver_database
update_fields.append('sqlServerProfile.database')
def _UpdateSalesforceProfile(self, connection_profile, args, update_fields):
"""Updates Salesforce connection profile."""
if args.IsSpecified('salesforce_domain'):
connection_profile.salesforceProfile.domain = args.salesforce_domain
update_fields.append('salesforceProfile.domain')
if args.IsSpecified('salesforce_username'):
connection_profile.salesforceProfile.userCredentials.username = (
args.salesforce_username
)
update_fields.append('salesforceProfile.userCredentials.username')
if args.IsSpecified('salesforce_password') or args.IsSpecified(
'salesforce_secret_manager_stored_password'
):
connection_profile.salesforceProfile.userCredentials.password = (
args.salesforce_password
)
connection_profile.salesforceProfile.userCredentials.secretManagerStoredPassword = (
args.salesforce_secret_manager_stored_password
)
update_fields.append('salesforceProfile.userCredentials.password')
update_fields.append(
'salesforceProfile.userCredentials.secretManagerStoredPassword'
)
if args.IsSpecified('salesforce_security_token') or args.IsSpecified(
'salesforce_secret_manager_stored_security_token'
):
connection_profile.salesforceProfile.userCredentials.securityToken = (
args.salesforce_security_token
)
connection_profile.salesforceProfile.userCredentials.secretManagerStoredSecurityToken = (
args.salesforce_secret_manager_stored_security_token
)
update_fields.append('salesforceProfile.userCredentials.securityToken')
update_fields.append(
'salesforceProfile.userCredentials.secretManagerStoredSecurityToken'
)
if args.IsSpecified('salesforce_oauth2_client_id'):
connection_profile.salesforceProfile.oauth2ClientCredentials.clientId = (
args.salesforce_oauth2_client_id
)
update_fields.append('salesforceProfile.oauth2ClientCredentials.clientId')
if args.IsSpecified('salesforce_oauth2_client_secret') or args.IsSpecified(
'salesforce_secret_manager_stored_oauth2_client_secret'
):
connection_profile.salesforceProfile.oauth2ClientCredentials.clientSecret = (
args.salesforce_oauth2_client_secret
)
connection_profile.salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret = (
args.salesforce_secret_manager_stored_oauth2_client_secret
)
update_fields.append(
'salesforceProfile.oauth2ClientCredentials.clientSecret'
)
update_fields.append(
'salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret'
)
def _UpdateMongodbProfile(self, connection_profile, args, update_fields):
"""Updates MongoDB connection profile."""
if args.IsSpecified('mongodb_host_addresses'):
addresses = []
for host_address in args.mongodb_host_addresses:
if args.mongodb_srv_connection_format:
addresses.append(
self._messages.HostAddress(hostname=host_address)
)
else:
hostname, port = host_address.split(':')
addresses.append(
self._messages.HostAddress(hostname=hostname, port=int(port))
)
connection_profile.mongodbProfile.hostAddresses = addresses
update_fields.append('monogodbProfile.hostAddresses')
if args.IsSpecified('mongodb_replica_set'):
connection_profile.mongodbProfile.replicaSet = args.mongodb_replica_set
update_fields.append('mongodbProfile.replicaSet')
if args.IsSpecified('mongodb_srv_connection_format') or args.IsSpecified(
'mongodb_standard_connection_format'
):
if args.mongodb_srv_connection_format:
connection_profile.mongodbProfile.srvConnectionFormat = {}
if args.mongodb_standard_connection_format:
connection_profile.mongodbProfile.standardConnectionFormat = {}
update_fields.append('mongodbProfile.srvConnectionFormat')
update_fields.append('mongodbProfile.standardConnectionFormat')
if args.IsSpecified('mongodb_username'):
connection_profile.mongodbProfile.username = args.mongodb_username
update_fields.append('mongodbProfile.username')
if args.IsSpecified('mongodb_password') or args.IsSpecified(
'mongodb_secret_manager_stored_password'
):
connection_profile.mongodbProfile.password = args.mongodb_password
connection_profile.mongodbProfile.secretManagerStoredPassword = (
args.mongodb_secret_manager_stored_password
)
update_fields.append('mongodbProfile.password')
update_fields.append('mongodbProfile.secretManagerStoredPassword')
def _GetExistingConnectionProfile(self, name):
get_req = (
self._messages.DatastreamProjectsLocationsConnectionProfilesGetRequest(
name=name
)
)
return self._service.Get(get_req)
def _UpdateLabels(self, connection_profile, args):
"""Updates labels of the connection profile."""
add_labels = labels_util.GetUpdateLabelsDictFromArgs(args)
remove_labels = labels_util.GetRemoveLabelsListFromArgs(args)
value_type = self._messages.ConnectionProfile.LabelsValue
update_result = labels_util.Diff(
additions=add_labels,
subtractions=remove_labels,
clear=args.clear_labels
).Apply(value_type, connection_profile.labels)
if update_result.needs_update:
connection_profile.labels = update_result.labels
  def _GetUpdatedConnectionProfile(self, connection_profile, cp_type,
                                   release_track, args):
    """Returns updated connection profile and list of updated fields.

    Args:
      connection_profile: ConnectionProfile message, the existing profile,
        mutated in place.
      cp_type: str, connection profile type (e.g. 'MYSQL', 'ORACLE').
      release_track: base.ReleaseTrack, determines which flags are read.
      args: argparse.Namespace, parsed command arguments.

    Returns:
      A (ConnectionProfile, list of str) tuple: the mutated profile and the
      field-mask paths that were updated.

    Raises:
      exceptions.InvalidArgumentException: if cp_type is not supported.
    """
    update_fields = []
    if args.IsSpecified('display_name'):
      connection_profile.displayName = args.display_name
      update_fields.append('displayName')
    # Dispatch to the per-type updater.
    if cp_type == 'MYSQL':
      self._UpdateMySqlProfile(
          connection_profile, args, update_fields)
    elif cp_type == 'ORACLE':
      self._UpdateOracleProfile(connection_profile, args, update_fields)
    elif cp_type == 'POSTGRESQL':
      self._UpdatePostgresqlProfile(connection_profile, args, update_fields)
    elif cp_type == 'SQLSERVER':
      self._UpdateSqlServerProfile(connection_profile, args, update_fields)
    elif cp_type == 'SALESFORCE':
      self._UpdateSalesforceProfile(connection_profile, args, update_fields)
    elif cp_type == 'GOOGLE-CLOUD-STORAGE':
      self._UpdateGCSProfile(
          connection_profile, release_track, args, update_fields
      )
    elif cp_type == 'BIGQUERY':
      # There are currently no parameters that can be updated in a bigquery CP.
      pass
    elif cp_type == 'MONGODB':
      self._UpdateMongodbProfile(connection_profile, args, update_fields)
    else:
      raise exceptions.InvalidArgumentException(
          cp_type,
          'The connection profile type {0} is either unknown or not supported'
          ' yet.'.format(cp_type),
      )
    # TODO(b/207467120): deprecate BETA client.
    # BETA uses a differently-named private connection concept flag.
    if release_track == base.ReleaseTrack.BETA:
      private_connectivity_ref = args.CONCEPTS.private_connection_name.Parse()
    else:
      private_connectivity_ref = args.CONCEPTS.private_connection.Parse()
    # Connectivity precedence: private connection, then forward SSH tunnel,
    # then static service IPs.
    if private_connectivity_ref:
      # NOTE(review): the create path populates PrivateConnectivity with
      # privateConnection=...; here privateConnectionName is used — confirm
      # which field name the target API version expects.
      connection_profile.privateConnectivity = (
          self._messages.PrivateConnectivity(
              privateConnectionName=private_connectivity_ref.RelativeName()
          )
      )
      update_fields.append('privateConnectivity')
    elif args.forward_ssh_hostname:
      self._UpdateForwardSshTunnelConnectivity(
          connection_profile, args, update_fields
      )
    elif args.static_ip_connectivity:
      connection_profile.staticServiceIpConnectivity = {}
      update_fields.append('staticServiceIpConnectivity')
    self._UpdateLabels(connection_profile, args)
    return connection_profile, update_fields
  def Create(self,
             parent_ref,
             connection_profile_id,
             cp_type,
             release_track,
             args=None):
    """Creates a connection profile.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.locations
        resource for this connection profile.
      connection_profile_id: str, the name of the resource to create.
      cp_type: str, the type of the connection profile (e.g. 'MYSQL',
        'ORACLE').
      release_track: Some arguments are added based on the command release
        track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for creating the connection profile.
    """
    self._ValidateArgs(args)
    connection_profile = self._GetConnectionProfile(cp_type,
                                                    connection_profile_id, args,
                                                    release_track)
    # TODO(b/207467120): only use flags from args.
    # BETA always forces creation; other tracks honor the --force flag.
    force = False
    if release_track == base.ReleaseTrack.BETA or args.force:
      force = True
    request_id = util.GenerateRequestId()
    create_req_type = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesCreateRequest
    )
    create_req = create_req_type(
        connectionProfile=connection_profile,
        connectionProfileId=connection_profile.name,
        parent=parent_ref,
        requestId=request_id,
        force=force)
    return self._service.Create(create_req)
def Update(self, name, cp_type, release_track, args=None):
"""Updates a connection profile.
Args:
name: str, the reference of the connection profile to
update.
cp_type: str, the type of the connection profile ('MYSQL', 'ORACLE')
release_track: Some arguments are added based on the command release
track.
args: argparse.Namespace, The arguments that this command was
invoked with.
Returns:
Operation: the operation for updating the connection profile.
"""
self._ValidateArgs(args)
current_cp = self._GetExistingConnectionProfile(name)
updated_cp, update_fields = self._GetUpdatedConnectionProfile(
current_cp, cp_type, release_track, args)
# TODO(b/207467120): only use flags from args.
force = False
if release_track == base.ReleaseTrack.BETA or args.force:
force = True
request_id = util.GenerateRequestId()
update_req_type = (
self._messages.DatastreamProjectsLocationsConnectionProfilesPatchRequest
)
update_req = update_req_type(
connectionProfile=updated_cp,
name=updated_cp.name,
updateMask=','.join(update_fields),
requestId=request_id,
force=force,
)
return self._service.Patch(update_req)
def List(self, project_id, args):
"""Get the list of connection profiles in a project.
Args:
project_id: The project ID to retrieve
args: parsed command line arguments
Returns:
An iterator over all the matching connection profiles.
"""
location_ref = self._resource_parser.Create(
'datastream.projects.locations',
projectsId=project_id,
locationsId=args.location,
)
list_req_type = (
self._messages.DatastreamProjectsLocationsConnectionProfilesListRequest
)
list_req = list_req_type(
parent=location_ref.RelativeName(),
filter=args.filter,
orderBy=','.join(args.sort_by) if args.sort_by else None,
)
return list_pager.YieldFromList(
service=self._client.projects_locations_connectionProfiles,
request=list_req,
limit=args.limit,
batch_size=args.page_size,
field='connectionProfiles',
batch_size_attribute='pageSize')
  def Discover(self, parent_ref, release_track, args):
    """Discover a connection profile.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.locations
        resource for this connection profile.
      release_track: Some arguments are added based on the command release
        track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for discovering the connection profile.
    """
    request = self._messages.DiscoverConnectionProfileRequest()
    # The profile to discover is given either by reference to an existing
    # connection profile or inline as a full profile definition file.
    if args.connection_profile_name:
      connection_profile_ref = args.CONCEPTS.connection_profile_name.Parse()
      request.connectionProfileName = connection_profile_ref.RelativeName()
    elif args.connection_profile_object_file:
      request.connectionProfile = self._ParseConnectionProfileObjectFile(
          args.connection_profile_object_file, release_track
      )
    # Discovery depth: the full hierarchy, a bounded depth (legacy
    # --recursive-depth or --hierarchy-depth), or the top level only.
    if args.recursive or args.full_hierarchy:
      request.fullHierarchy = True
    elif args.recursive_depth:
      request.hierarchyDepth = (int)(args.recursive_depth)
    elif args.hierarchy_depth:
      request.hierarchyDepth = (int)(args.hierarchy_depth)
    else:
      request.fullHierarchy = False
    # Optional RDBMS filter file; at most one source type is consulted.
    if args.mysql_rdbms_file:
      request.mysqlRdbms = util.ParseMysqlRdbmsFile(self._messages,
                                                    args.mysql_rdbms_file,
                                                    release_track)
    elif args.oracle_rdbms_file:
      request.oracleRdbms = util.ParseOracleRdbmsFile(self._messages,
                                                      args.oracle_rdbms_file,
                                                      release_track)
    elif args.postgresql_rdbms_file:
      request.postgresqlRdbms = util.ParsePostgresqlRdbmsFile(
          self._messages, args.postgresql_rdbms_file)
    elif args.sqlserver_rdbms_file:
      request.sqlServerRdbms = util.ParseSqlServerRdbmsFile(
          self._messages, args.sqlserver_rdbms_file
      )
    discover_req_type = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesDiscoverRequest
    )
    discover_req = discover_req_type(
        discoverConnectionProfileRequest=request, parent=parent_ref)
    return self._service.Discover(discover_req)
def GetUri(self, name):
"""Get the URL string for a connection profile.
Args:
name: connection profile's full name.
Returns:
URL of the connection profile resource
"""
uri = self._resource_parser.ParseRelativeName(
name, collection='datastream.projects.locations.connectionProfiles')
return uri.SelfLink()

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for user-visible error exceptions to raise in the CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
class ParseError(Exception):
  """Raised when a user-supplied file cannot be parsed."""

View File

@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream private connections API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.datastream import util
from googlecloudsdk.calliope import base
from googlecloudsdk.core import resources
def GetPrivateConnectionURI(resource):
  """Returns the self-link URI for a private connection resource."""
  ref = resources.REGISTRY.ParseRelativeName(
      resource.name,
      collection='datastream.projects.locations.privateConnections')
  return ref.SelfLink()
class PrivateConnectionsClient:
  """Client for private connections service in the API."""
  def __init__(self, client=None, messages=None):
    # Use the injected client/messages when given, otherwise the defaults.
    self._client = client or util.GetClientInstance()
    self._messages = messages or util.GetMessagesModule()
    self._service = self._client.projects_locations_privateConnections
    self._resource_parser = util.GetResourceParser()
  def _GetPrivateConnection(self, private_connection_id, release_track, args):
    """Builds a PrivateConnection message from the command-line arguments.

    Args:
      private_connection_id: str, the name for the new private connection.
      release_track: the command release track; BETA uses legacy flag names.
      args: argparse.Namespace with display_name and connectivity flags.

    Returns:
      A PrivateConnection message with either a PSC-interface or a
      VPC-peering connectivity configuration.
    """
    private_connection_obj = self._messages.PrivateConnection(
        name=private_connection_id, labels={}, displayName=args.display_name)
    # --network-attachment selects PSC-interface connectivity; otherwise a
    # VPC peering configuration is built from the vpc/subnet flags.
    if hasattr(args, 'network_attachment') and args.network_attachment:
      private_connection_obj.pscInterfaceConfig = (
          self._messages.PscInterfaceConfig(
              networkAttachment=args.network_attachment
          )
      )
    else:
      # TODO(b/207467120): use only vpc flag.
      # BETA still uses the legacy --vpc-name concept.
      if release_track == base.ReleaseTrack.BETA:
        vpc_peering_ref = args.CONCEPTS.vpc_name.Parse()
      else:
        vpc_peering_ref = args.CONCEPTS.vpc.Parse()
      private_connection_obj.vpcPeeringConfig = self._messages.VpcPeeringConfig(
          vpc=vpc_peering_ref.RelativeName(), subnet=args.subnet
      )
    return private_connection_obj
  def Create(self, parent_ref, private_connection_id, release_track, args=None):
    """Creates a private connection.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.locations
        resource for this private connection.
      private_connection_id: str, the name of the resource to create.
      release_track: Some arguments are added based on the command release
        track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for creating the private connection.
    """
    private_connection = self._GetPrivateConnection(private_connection_id,
                                                    release_track, args)
    # Request IDs make the create request idempotent on retries.
    request_id = util.GenerateRequestId()
    create_req_type = (
        self._messages.DatastreamProjectsLocationsPrivateConnectionsCreateRequest
    )
    create_req = create_req_type(
        privateConnection=private_connection,
        privateConnectionId=private_connection.name,
        parent=parent_ref,
        requestId=request_id,
        validateOnly=args.validate_only,
    )
    return self._service.Create(create_req)

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream private connections API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from googlecloudsdk.api_lib.datastream import util
class RoutesClient:
  """Client for private connections routes service in the API."""

  def __init__(self, client=None, messages=None):
    self._client = client or util.GetClientInstance()
    self._messages = messages or util.GetMessagesModule()
    self._service = self._client.projects_locations_privateConnections_routes
    self._resource_parser = util.GetResourceParser()

  def _GetRoute(self, route_id, args):
    """Builds a Route message from the parsed command-line arguments."""
    return self._messages.Route(
        name=route_id,
        labels={},
        displayName=args.display_name,
        destinationAddress=args.destination_address,
        destinationPort=args.destination_port)

  def Create(self, parent_ref, route_id, args=None):
    """Creates a route.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.
        locations.privateConnections resource for this route.
      route_id: str, the name of the resource to create.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for creating the private connection.
    """
    route = self._GetRoute(route_id, args)
    req_type = (
        self._messages
        .DatastreamProjectsLocationsPrivateConnectionsRoutesCreateRequest)
    create_req = req_type(
        route=route,
        routeId=route.name,
        parent=parent_ref,
        requestId=util.GenerateRequestId())
    return self._service.Create(create_req)

View File

@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream stream objects API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.datastream import util
class StreamObjectsClient:
  """Client for stream objects service in the API."""
  def __init__(self, client=None, messages=None):
    # Use the injected client/messages when given, otherwise the defaults.
    self._client = client or util.GetClientInstance()
    self._messages = messages or util.GetMessagesModule()
    self._service = self._client.projects_locations_streams_objects
    self._resource_parser = util.GetResourceParser()
  def List(self, project_id, stream, args):
    """Get the list of objects in a stream.

    Args:
      project_id: The project ID to retrieve
      stream: The stream name to retrieve
      args: parsed command line arguments

    Returns:
      An iterator over all the matching stream objects.
    """
    stream_ref = self._resource_parser.Create(
        'datastream.projects.locations.streams',
        projectsId=project_id,
        streamsId=stream,
        locationsId=args.location)
    list_req_type = self._messages.DatastreamProjectsLocationsStreamsObjectsListRequest
    list_req = list_req_type(parent=stream_ref.RelativeName())
    # YieldFromList transparently handles pagination.
    return list_pager.YieldFromList(
        service=self._service,
        request=list_req,
        limit=args.limit,
        batch_size=args.page_size,
        field='streamObjects',
        batch_size_attribute='pageSize')
  def Lookup(self, project_id, stream_id, args):
    """Lookup a stream object.

    Args:
      project_id: str, the project ID of the stream.
      stream_id: str, the ID of the stream containing the object.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      StreamObject: the looked up stream object.
    """
    # Exactly one source-specific identifier is populated, chosen by which
    # source flag group was supplied on the command line.
    object_identifier = self._messages.SourceObjectIdentifier()
    if args.oracle_schema:
      object_identifier.oracleIdentifier = (
          self._messages.OracleObjectIdentifier(
              schema=args.oracle_schema, table=args.oracle_table
          )
      )
    elif args.mysql_database:
      object_identifier.mysqlIdentifier = self._messages.MysqlObjectIdentifier(
          database=args.mysql_database, table=args.mysql_table
      )
    elif args.postgresql_schema:
      object_identifier.postgresqlIdentifier = (
          self._messages.PostgresqlObjectIdentifier(
              schema=args.postgresql_schema, table=args.postgresql_table
          )
      )
    elif args.sqlserver_schema:
      object_identifier.sqlServerIdentifier = (
          self._messages.SqlServerObjectIdentifier(
              schema=args.sqlserver_schema, table=args.sqlserver_table
          )
      )
    elif args.salesforce_object_name:
      object_identifier.salesforceIdentifier = (
          self._messages.SalesforceObjectIdentifier(
              objectName=args.salesforce_object_name
          )
      )
    elif args.mongodb_database:
      object_identifier.mongodbIdentifier = (
          self._messages.MongodbObjectIdentifier(
              database=args.mongodb_database, collection=args.mongodb_collection
          )
      )
    stream_ref = self._resource_parser.Create(
        'datastream.projects.locations.streams',
        projectsId=project_id,
        streamsId=stream_id,
        locationsId=args.location,
    )
    lookup_req_type = (
        self._messages.DatastreamProjectsLocationsStreamsObjectsLookupRequest
    )
    lookup_req = lookup_req_type(
        lookupStreamObjectRequest=self._messages.LookupStreamObjectRequest(
            sourceObjectIdentifier=object_identifier
        ),
        parent=stream_ref.RelativeName(),
    )
    return self._service.Lookup(lookup_req)
  def GetUri(self, name):
    """Get the URL string for a stream object.

    Args:
      name: stream object's full name.

    Returns:
      URL of the stream object resource
    """
    uri = self._resource_parser.ParseRelativeName(
        name, collection='datastream.projects.locations.streams.objects')
    return uri.SelfLink()

View File

@@ -0,0 +1,603 @@
# -*- coding: utf-8 -*- #
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream connection profiles API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import json
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.datastream import util
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
_DEFAULT_API_VERSION = 'v1'


def GetStreamURI(resource):
  """Returns the self-link URI for a stream resource."""
  stream_ref = resources.REGISTRY.ParseRelativeName(
      resource.name,
      collection='datastream.projects.locations.streams')
  return stream_ref.SelfLink()
class StreamsClient:
"""Client for streams service in the API."""
  def __init__(self, client=None, messages=None):
    # Use the injected client/messages when given, otherwise the defaults.
    self._client = client or util.GetClientInstance()
    self._messages = messages or util.GetMessagesModule()
    self._service = self._client.projects_locations_streams
    self._resource_parser = util.GetResourceParser()
  def _GetBackfillAllStrategy(self, release_track, args):
    """Gets BackfillAllStrategy message based on Stream objects source type.

    At most one of the per-source excluded-objects file flags is consulted;
    if none is set, an empty BackfillAllStrategy is returned (no exclusions).

    Args:
      release_track: the command release track; forwarded to the RDBMS file
        parsers that accept it.
      amount args: argparse.Namespace holding the *_excluded_objects flags.

    Returns:
      A BackfillAllStrategy message.
    """
    if args.oracle_excluded_objects:
      return self._messages.BackfillAllStrategy(
          oracleExcludedObjects=util.ParseOracleRdbmsFile(
              self._messages, args.oracle_excluded_objects, release_track))
    elif args.mysql_excluded_objects:
      return self._messages.BackfillAllStrategy(
          mysqlExcludedObjects=util.ParseMysqlRdbmsFile(
              self._messages, args.mysql_excluded_objects, release_track))
    elif args.postgresql_excluded_objects:
      return self._messages.BackfillAllStrategy(
          postgresqlExcludedObjects=util.ParsePostgresqlRdbmsFile(
              self._messages, args.postgresql_excluded_objects))
    elif args.sqlserver_excluded_objects:
      return self._messages.BackfillAllStrategy(
          sqlServerExcludedObjects=util.ParseSqlServerRdbmsFile(
              self._messages, args.sqlserver_excluded_objects
          )
      )
    elif args.salesforce_excluded_objects:
      return self._messages.BackfillAllStrategy(
          salesforceExcludedObjects=util.ParseSalesforceOrgFile(
              self._messages, args.salesforce_excluded_objects
          )
      )
    elif args.mongodb_excluded_objects:
      return self._messages.BackfillAllStrategy(
          mongodbExcludedObjects=util.ParseMongodbFile(
              self._messages, args.mongodb_excluded_objects
          )
      )
    # No exclusions specified: backfill everything.
    return self._messages.BackfillAllStrategy()
def _ParseOracleSourceConfig(self, oracle_source_config_file, release_track):
"""Parses a oracle_sorce_config into the OracleSourceConfig message."""
if release_track == base.ReleaseTrack.BETA:
return self._ParseOracleSourceConfigBeta(
oracle_source_config_file, release_track
)
return util.ParseMessageAndValidateSchema(
oracle_source_config_file,
'OracleSourceConfig',
self._messages.OracleSourceConfig,
)
  def _ParseOracleSourceConfigBeta(
      self, oracle_source_config_file, release_track
  ):
    """Parses a legacy oracle_source_config file into an OracleSourceConfig.

    Args:
      oracle_source_config_file: str, path to a YAML file (or '-' for stdin).
      release_track: the command release track; used to map legacy v1alpha1
        field names to their v1 equivalents.

    Returns:
      An OracleSourceConfig message.

    Raises:
      ds_exceptions.ParseError: if the file content is not valid YAML.
    """
    data = console_io.ReadFromFileOrStdin(
        oracle_source_config_file, binary=False)
    try:
      oracle_source_config_head_data = yaml.load(data)
    except yaml.YAMLParseError as e:
      raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
    # The config may be nested under an 'oracle_source_config' key or be the
    # top-level document itself.
    oracle_sorce_config_data_object = oracle_source_config_head_data.get(
        'oracle_source_config'
    )
    oracle_source_config = (
        oracle_sorce_config_data_object
        if oracle_sorce_config_data_object
        else oracle_source_config_head_data
    )
    # Legacy files may use v1alpha1 field names for the object lists.
    include_objects_raw = oracle_source_config.get(
        util.GetRDBMSV1alpha1ToV1FieldName('include_objects', release_track),
        {})
    include_objects_data = util.ParseOracleSchemasListToOracleRdbmsMessage(
        self._messages, include_objects_raw, release_track)
    exclude_objects_raw = oracle_source_config.get(
        util.GetRDBMSV1alpha1ToV1FieldName('exclude_objects', release_track),
        {})
    exclude_objects_data = util.ParseOracleSchemasListToOracleRdbmsMessage(
        self._messages, exclude_objects_raw, release_track)
    oracle_source_config_msg = self._messages.OracleSourceConfig(
        includeObjects=include_objects_data,
        excludeObjects=exclude_objects_data,
    )
    # Optional field; only set when present and truthy in the file.
    if oracle_source_config.get('max_concurrent_cdc_tasks'):
      oracle_source_config_msg.maxConcurrentCdcTasks = oracle_source_config.get(
          'max_concurrent_cdc_tasks')
    return oracle_source_config_msg
def _ParseMysqlSourceConfig(self, mysql_source_config_file, release_track):
"""Parses a mysql_sorce_config into the MysqlSourceConfig message."""
if release_track == base.ReleaseTrack.BETA:
return self._ParseMysqlSourceConfigBeta(
mysql_source_config_file, release_track
)
return util.ParseMessageAndValidateSchema(
mysql_source_config_file,
'MysqlSourceConfig',
self._messages.MysqlSourceConfig,
)
  def _ParseMysqlSourceConfigBeta(
      self, mysql_source_config_file, release_track
  ):
    """Parses a legacy mysql_source_config file into a MysqlSourceConfig.

    Args:
      mysql_source_config_file: str, path to a YAML file (or '-' for stdin).
      release_track: the command release track; used to map legacy v1alpha1
        field names to their v1 equivalents.

    Returns:
      A MysqlSourceConfig message.

    Raises:
      ds_exceptions.ParseError: if the file content is not valid YAML.
    """
    data = console_io.ReadFromFileOrStdin(
        mysql_source_config_file, binary=False)
    try:
      mysql_sorce_config_head_data = yaml.load(data)
    except yaml.YAMLParseError as e:
      raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
    # The config may be nested under a 'mysql_source_config' key or be the
    # top-level document itself.
    mysql_sorce_config_data_object = mysql_sorce_config_head_data.get(
        'mysql_source_config'
    )
    mysql_source_config = (
        mysql_sorce_config_data_object
        if mysql_sorce_config_data_object
        else mysql_sorce_config_head_data
    )
    # Legacy files may use v1alpha1 field names for the object lists.
    include_objects_raw = mysql_source_config.get(
        util.GetRDBMSV1alpha1ToV1FieldName('include_objects', release_track),
        {})
    include_objects_data = util.ParseMysqlSchemasListToMysqlRdbmsMessage(
        self._messages, include_objects_raw, release_track)
    exclude_objects_raw = mysql_source_config.get(
        util.GetRDBMSV1alpha1ToV1FieldName('exclude_objects', release_track),
        {})
    exclude_objects_data = util.ParseMysqlSchemasListToMysqlRdbmsMessage(
        self._messages, exclude_objects_raw, release_track)
    mysql_sourec_config_msg = self._messages.MysqlSourceConfig(
        includeObjects=include_objects_data,
        excludeObjects=exclude_objects_data,
    )
    # Optional field; only set when present and truthy in the file.
    if mysql_source_config.get('max_concurrent_cdc_tasks'):
      mysql_sourec_config_msg.maxConcurrentCdcTasks = mysql_source_config.get(
          'max_concurrent_cdc_tasks')
    return mysql_sourec_config_msg
def _ParsePostgresqlSourceConfig(self, postgresql_source_config_file):
"""Parses a postgresql_sorce_config into the PostgresqlSourceConfig message."""
return util.ParseMessageAndValidateSchema(
postgresql_source_config_file,
'PostgresqlSourceConfig',
self._messages.PostgresqlSourceConfig,
)
def _ParseSqlServerSourceConfig(self, sqlserver_source_config_file):
"""Parses a sqlserver_sorce_config into the SqlServerSourceConfig message."""
return util.ParseMessageAndValidateSchema(
sqlserver_source_config_file,
'SqlServerSourceConfig',
self._messages.SqlServerSourceConfig,
)
def _ParseSalesforceSourceConfig(self, salesforce_source_config_file):
"""Parses a salesforce_sorce_config into the SalesforceSourceConfig message."""
return util.ParseMessageAndValidateSchema(
salesforce_source_config_file,
'SalesforceSourceConfig',
self._messages.SalesforceSourceConfig,
)
def _ParseMongodbSourceConfig(self, mongodb_source_config_file):
"""Parses a mongodb_source_config into the MongodbSourceConfig message."""
return util.ParseMessageAndValidateSchema(
mongodb_source_config_file,
'MongodbSourceConfig',
self._messages.MongodbSourceConfig,
)
def _ParseGcsDestinationConfig(
self, gcs_destination_config_file, release_track
):
"""Parses a GcsDestinationConfig into the GcsDestinationConfig message."""
if release_track == base.ReleaseTrack.BETA:
return self._ParseGcsDestinationConfigBeta(gcs_destination_config_file)
return util.ParseMessageAndValidateSchema(
gcs_destination_config_file,
'GcsDestinationConfig',
self._messages.GcsDestinationConfig,
)
  def _ParseGcsDestinationConfigBeta(self, gcs_destination_config_file):
    """Parses a legacy gcs_destination_config file into a GcsDestinationConfig.

    Args:
      gcs_destination_config_file: str, path to a YAML file (or '-' for
        stdin).

    Returns:
      A GcsDestinationConfig message.

    Raises:
      ds_exceptions.ParseError: if the content is not valid YAML or names no
        file format.
    """
    data = console_io.ReadFromFileOrStdin(
        gcs_destination_config_file, binary=False)
    try:
      gcs_destination_head_config_data = yaml.load(data)
    except yaml.YAMLParseError as e:
      raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
    # The config may be nested under a 'gcs_destination_config' key or be the
    # top-level document itself.
    gcs_destination_config_data_object = gcs_destination_head_config_data.get(
        'gcs_destination_config'
    )
    gcs_destination_config_data = (
        gcs_destination_config_data_object
        if gcs_destination_config_data_object
        else gcs_destination_head_config_data
    )
    path = gcs_destination_config_data.get('path', '')
    file_rotation_mb = gcs_destination_config_data.get('file_rotation_mb', {})
    file_rotation_interval = gcs_destination_config_data.get(
        'file_rotation_interval', {})
    gcs_dest_config_msg = self._messages.GcsDestinationConfig(
        path=path, fileRotationMb=file_rotation_mb,
        fileRotationInterval=file_rotation_interval)
    # Exactly one output file format must be present in the file.
    if 'avro_file_format' in gcs_destination_config_data:
      gcs_dest_config_msg.avroFileFormat = self._messages.AvroFileFormat()
    elif 'json_file_format' in gcs_destination_config_data:
      json_file_format_data = gcs_destination_config_data.get(
          'json_file_format')
      gcs_dest_config_msg.jsonFileFormat = self._messages.JsonFileFormat(
          compression=json_file_format_data.get('compression'),
          schemaFileFormat=json_file_format_data.get('schema_file_format'))
    else:
      raise ds_exceptions.ParseError(
          'Cannot parse YAML: missing file format.')
    return gcs_dest_config_msg
def _ParseRuleSets(self, rule_sets_file):
"""Parses a list of RuleSets from a JSON file."""
data = util.console_io.ReadFromFileOrStdin(rule_sets_file, binary=False)
try:
parsed_rule_sets = json.loads(data)
except json.JSONDecodeError as e:
raise ds_exceptions.ParseError('Cannot parse JSON:[{0}]'.format(e))
rule_sets_list = []
for rule_set in parsed_rule_sets:
rule_sets_list.append(
util.ParseJsonAndValidateSchema(
rule_set, 'RuleSet', self._messages.RuleSet
)
)
return rule_sets_list
def _ParseBigqueryDestinationConfig(self, config_file):
"""Parses a BigQueryDestinationConfig into the BigQueryDestinationConfig message."""
return util.ParseMessageAndValidateSchema(
config_file,
'BigQueryDestinationConfig',
self._messages.BigQueryDestinationConfig,
)
  def _GetStream(self, stream_id, release_track, args):
    """Builds a Stream message from the command-line arguments.

    Args:
      stream_id: str, the name for the new stream.
      release_track: the command release track; BETA uses legacy flag names.
      args: argparse.Namespace with source/destination/backfill flags.

    Returns:
      A Stream message ready for a create request.
    """
    labels = labels_util.ParseCreateArgs(
        args, self._messages.Stream.LabelsValue)
    stream_obj = self._messages.Stream(
        name=stream_id, labels=labels, displayName=args.display_name)
    # TODO(b/207467120): use CONCEPTS.source only.
    # BETA still uses the legacy --source-name concept.
    if release_track == base.ReleaseTrack.BETA:
      source_connection_profile_ref = args.CONCEPTS.source_name.Parse()
    else:
      source_connection_profile_ref = args.CONCEPTS.source.Parse()
    stream_source_config = self._messages.SourceConfig()
    stream_source_config.sourceConnectionProfile = (
        source_connection_profile_ref.RelativeName())
    # At most one source-config file flag is consulted.
    if args.oracle_source_config:
      stream_source_config.oracleSourceConfig = self._ParseOracleSourceConfig(
          args.oracle_source_config, release_track)
    elif args.mysql_source_config:
      stream_source_config.mysqlSourceConfig = self._ParseMysqlSourceConfig(
          args.mysql_source_config, release_track)
    elif args.postgresql_source_config:
      stream_source_config.postgresqlSourceConfig = (
          self._ParsePostgresqlSourceConfig(args.postgresql_source_config)
      )
    elif args.sqlserver_source_config:
      stream_source_config.sqlServerSourceConfig = (
          self._ParseSqlServerSourceConfig(args.sqlserver_source_config)
      )
    elif args.salesforce_source_config:
      stream_source_config.salesforceSourceConfig = (
          self._ParseSalesforceSourceConfig(args.salesforce_source_config)
      )
    elif args.mongodb_source_config:
      stream_source_config.mongodbSourceConfig = self._ParseMongodbSourceConfig(
          args.mongodb_source_config
      )
    stream_obj.sourceConfig = stream_source_config
    # TODO(b/207467120): use CONCEPTS.destination only.
    # BETA still uses the legacy --destination-name concept.
    if release_track == base.ReleaseTrack.BETA:
      destination_connection_profile_ref = args.CONCEPTS.destination_name.Parse(
      )
    else:
      destination_connection_profile_ref = args.CONCEPTS.destination.Parse()
    stream_destination_config = self._messages.DestinationConfig()
    stream_destination_config.destinationConnectionProfile = (
        destination_connection_profile_ref.RelativeName())
    # At most one destination-config file flag is consulted.
    if args.gcs_destination_config:
      stream_destination_config.gcsDestinationConfig = (
          self._ParseGcsDestinationConfig(
              args.gcs_destination_config, release_track
          )
      )
    elif args.bigquery_destination_config:
      stream_destination_config.bigqueryDestinationConfig = (
          self._ParseBigqueryDestinationConfig(
              args.bigquery_destination_config))
    stream_obj.destinationConfig = stream_destination_config
    # Backfill strategy: none, or all (optionally with excluded objects).
    if args.backfill_none:
      stream_obj.backfillNone = self._messages.BackfillNoneStrategy()
    elif args.backfill_all:
      backfill_all_strategy = self._GetBackfillAllStrategy(release_track, args)
      stream_obj.backfillAll = backfill_all_strategy
    if args.rule_sets:
      stream_obj.ruleSets = self._ParseRuleSets(args.rule_sets)
    return stream_obj
def _GetExistingStream(self, name):
get_req = self._messages.DatastreamProjectsLocationsStreamsGetRequest(
name=name
)
return self._service.Get(get_req)
def _UpdateLabels(self, stream, args):
"""Updates labels of the stream."""
add_labels = labels_util.GetUpdateLabelsDictFromArgs(args)
remove_labels = labels_util.GetRemoveLabelsListFromArgs(args)
value_type = self._messages.Stream.LabelsValue
update_result = labels_util.Diff(
additions=add_labels,
subtractions=remove_labels,
clear=args.clear_labels
).Apply(value_type, stream.labels)
if update_result.needs_update:
stream.labels = update_result.labels
def _UpdateListWithFieldNamePrefixes(
self, update_fields, prefix_to_check, prefix_to_add):
"""Returns an updated list of field masks with necessary prefixes."""
temp_fields = [
prefix_to_add + field
for field in update_fields
if field.startswith(prefix_to_check)
]
update_fields = [
x for x in update_fields if (not x.startswith(prefix_to_check))
]
update_fields.extend(temp_fields)
return update_fields
  def _GetUpdatedStream(self, stream, release_track, args):
    """Applies update-command flags onto an existing Stream message.

    Mutates `stream` in place (display name, source/destination configs,
    backfill strategy, state, labels) and builds the field-mask paths that
    should accompany the patch request.

    Args:
      stream: the existing Stream message to update.
      release_track: base.ReleaseTrack; BETA selects the v1alpha1 flag names
        and triggers mask-field translation to v1.
      args: argparse.Namespace, the update command's parsed flags.

    Returns:
      A (stream, update_fields) tuple: the mutated Stream message and the
      list of field-mask path strings.
    """
    # Verify command flag names align with Stream object field names.
    update_fields = []
    user_update_mask = args.update_mask or ''
    user_update_mask_list = user_update_mask.split(',')
    if release_track == base.ReleaseTrack.BETA:
      # Translate v1alpha1 (BETA) mask field names to their v1 equivalents.
      user_update_mask_list = util.UpdateV1alpha1ToV1MaskFields(
          user_update_mask_list)
    update_fields.extend(user_update_mask_list)
    if args.IsSpecified('display_name'):
      stream.displayName = args.display_name
    # TODO(b/207467120): use source field only.
    if release_track == base.ReleaseTrack.BETA:
      source_connection_profile_ref = args.CONCEPTS.source_name.Parse()
      source_field_name = 'source_name'
    else:
      source_connection_profile_ref = args.CONCEPTS.source.Parse()
      source_field_name = 'source'
    if args.IsSpecified(source_field_name):
      stream.sourceConfig.sourceConnectionProfile = (
          source_connection_profile_ref.RelativeName())
      if source_field_name in update_fields:
        # Swap the flag name in the mask for the real Stream field path.
        update_fields.remove(source_field_name)
        update_fields.append('source_config.source_connection_profile')
    # Source configs are mutually exclusive; only one branch runs.
    if args.IsSpecified('oracle_source_config'):
      stream.sourceConfig.oracleSourceConfig = self._ParseOracleSourceConfig(
          args.oracle_source_config, release_track)
      # Fix field names in update mask
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'oracle_source_config', 'source_config.')
    elif args.IsSpecified('mysql_source_config'):
      stream.sourceConfig.mysqlSourceConfig = self._ParseMysqlSourceConfig(
          args.mysql_source_config, release_track)
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'mysql_source_config', 'source_config.')
    elif args.IsSpecified('postgresql_source_config'):
      stream.sourceConfig.postgresqlSourceConfig = (
          self._ParsePostgresqlSourceConfig(args.postgresql_source_config)
      )
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'postgresql_source_config', 'source_config.')
    elif args.IsSpecified('sqlserver_source_config'):
      stream.sourceConfig.sqlServerSourceConfig = (
          self._ParseSqlServerSourceConfig(args.sqlserver_source_config)
      )
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'sqlserver_source_config', 'source_config.'
      )
    elif args.IsSpecified('salesforce_source_config'):
      stream.sourceConfig.salesforceSourceConfig = (
          self._ParseSalesforceSourceConfig(args.salesforce_source_config)
      )
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'salesforce_source_config', 'source_config.'
      )
    # TODO(b/207467120): use source field only.
    if release_track == base.ReleaseTrack.BETA:
      destination_connection_profile_ref = (
          args.CONCEPTS.destination_name.Parse())
      destination_field_name = 'destination_name'
    else:
      destination_connection_profile_ref = (args.CONCEPTS.destination.Parse())
      destination_field_name = 'destination'
    if args.IsSpecified(destination_field_name):
      stream.destinationConfig.destinationConnectionProfile = (
          destination_connection_profile_ref.RelativeName())
      if destination_field_name in update_fields:
        update_fields.remove(destination_field_name)
        update_fields.append(
            'destination_config.destination_connection_profile')
    if args.IsSpecified('gcs_destination_config'):
      stream.destinationConfig.gcsDestinationConfig = (
          self._ParseGcsDestinationConfig(
              args.gcs_destination_config, release_track
          )
      )
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'gcs_destination_config', 'destination_config.')
    elif args.IsSpecified('bigquery_destination_config'):
      stream.destinationConfig.bigqueryDestinationConfig = (
          self._ParseBigqueryDestinationConfig(
              args.bigquery_destination_config))
      update_fields = self._UpdateListWithFieldNamePrefixes(
          update_fields, 'bigquery_destination_config', 'destination_config.')
    if args.IsSpecified('backfill_none'):
      # Backfill strategies are mutually exclusive: setting one clears the
      # other (reset may not exist if the field was never set).
      stream.backfillNone = self._messages.BackfillNoneStrategy()
      # NOMUTANTS--This path has been verified by manual tests.
      try:
        stream.reset('backfillAll')
      except AttributeError:
        # Attempt to remove a backfill all
        # previous definition, but doesn't exist.
        pass
    elif args.IsSpecified('backfill_all'):
      backfill_all_strategy = self._GetBackfillAllStrategy(release_track, args)
      stream.backfillAll = backfill_all_strategy
      # NOMUTANTS--This path has been verified by manual tests.
      try:
        stream.reset('backfillNone')
      except AttributeError:
        # Attempt to remove a backfill none previous definition,
        # but it doesn't exist.
        pass
    if args.IsSpecified('state'):
      # The enum expects the upper-case form of the user-supplied state.
      stream.state = self._messages.Stream.StateValueValuesEnum(
          (args.state).upper())
    self._UpdateLabels(stream, args)
    return stream, update_fields
def Create(self, parent_ref, stream_id, release_track, args=None):
"""Creates a stream.
Args:
parent_ref: a Resource reference to a parent datastream.projects.locations
resource for this stream.
stream_id: str, the name of the resource to create.
release_track: Some arguments are added based on the command release
track.
args: argparse.Namespace, The arguments that this command was invoked
with.
Returns:
Operation: the operation for creating the stream.
"""
stream = self._GetStream(stream_id, release_track, args)
validate_only = args.validate_only
force = args.force
request_id = util.GenerateRequestId()
create_req_type = self._messages.DatastreamProjectsLocationsStreamsCreateRequest
create_req = create_req_type(
stream=stream,
streamId=stream.name,
parent=parent_ref,
requestId=request_id,
validateOnly=validate_only,
force=force)
return self._service.Create(create_req)
def Update(self, name, release_track, args=None):
"""Updates a stream.
Args:
name: str, the reference of the stream to
update.
release_track: Some arguments are added based on the command release
track.
args: argparse.Namespace, The arguments that this command was
invoked with.
Returns:
Operation: the operation for updating the connection profile.
"""
validate_only = args.validate_only
force = args.force
current_stream = self._GetExistingStream(name)
updated_stream, update_fields = self._GetUpdatedStream(
current_stream, release_track, args)
request_id = util.GenerateRequestId()
update_req_type = self._messages.DatastreamProjectsLocationsStreamsPatchRequest
update_req = update_req_type(
stream=updated_stream,
name=updated_stream.name,
requestId=request_id,
validateOnly=validate_only,
force=force
)
if args.update_mask:
update_req.updateMask = ','.join(update_fields)
return self._service.Patch(update_req)

View File

@@ -0,0 +1,555 @@
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream API utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import uuid
from apitools.base.py import encoding as api_encoding
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.datastream import camel_case_utils
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
import six
# Default Datastream API surface used when a caller does not pick a version.
_DEFAULT_API_VERSION = 'v1'
_DEFAULT_API_NAME = 'datastream'
# Metrics event name for tracking camel-case conversions of user YAML.
CAMEL_CASE_CONVERSION_EVENT = _DEFAULT_API_NAME + '_camel_case_conversion'
# TODO(b/207467120): remove translation after BETA deprecation.
# Maps v1alpha1 (BETA) update-mask field names to their v1 equivalents.
_UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1 = {
    'allowlist': 'include_objects',
    'rejectlist': 'exclude_objects',
    'source_connection_profile_name': 'source_connection_profile',
    'destination_connection_profile_name': 'destination_connection_profile',
    'vpc_name': 'vpc',
}
# Per-release-track spelling of RDBMS YAML keys: BETA (v1alpha1) used
# *_name suffixes and allowlist/rejectlist; GA (v1) uses the bare names.
RDBMS_FIELD_NAME_BY_RELEASE_TRACK = {
    'schema': {
        base.ReleaseTrack.BETA: 'schema_name',
        base.ReleaseTrack.GA: 'schema'
    },
    'database': {
        base.ReleaseTrack.BETA: 'database_name',
        base.ReleaseTrack.GA: 'database'
    },
    'table': {
        base.ReleaseTrack.BETA: 'table_name',
        base.ReleaseTrack.GA: 'table'
    },
    'column': {
        base.ReleaseTrack.BETA: 'column_name',
        base.ReleaseTrack.GA: 'column'
    },
    'include_objects': {
        base.ReleaseTrack.BETA: 'allowlist',
        base.ReleaseTrack.GA: 'include_objects'
    },
    'exclude_objects': {
        base.ReleaseTrack.BETA: 'rejectlist',
        base.ReleaseTrack.GA: 'exclude_objects'
    }
}
def ParseJsonAndValidateSchema(parsed_json, schema_name, message_type):
  """Parses a config message from a parsed JSON and validates its schema."""
  schema_path = export_util.GetSchemaPath(
      _DEFAULT_API_NAME, _DEFAULT_API_VERSION, schema_name, for_help=False
  )
  return CreateMessageWithCamelCaseConversion(
      message_type=message_type,
      parsed_yaml=parsed_json,
      schema_path=schema_path,
  )
def ParseMessageAndValidateSchema(config_file_path, schema_name, message_type):
  """Parses a config message from a file and validates its schema."""
  schema_path = export_util.GetSchemaPath(
      _DEFAULT_API_NAME, _DEFAULT_API_VERSION, schema_name, for_help=False
  )
  # NOMUTANTS -- not necessary here.
  raw_config = console_io.ReadFromFileOrStdin(config_file_path, binary=False)
  return CreateMessageWithCamelCaseConversion(
      message_type=message_type,
      parsed_yaml=yaml.load(raw_config),
      schema_path=schema_path,
  )
def GetClientInstance(api_version=_DEFAULT_API_VERSION, no_http=False):
  """Returns an apitools client instance for the Datastream API."""
  client = apis.GetClientInstance('datastream', api_version, no_http=no_http)
  return client
def GetMessagesModule(api_version=_DEFAULT_API_VERSION):
  """Returns the generated messages module for the Datastream API."""
  messages = apis.GetMessagesModule('datastream', api_version)
  return messages
def GetResourceParser(api_version=_DEFAULT_API_VERSION):
  """Returns a resource Registry with the Datastream API registered."""
  registry = resources.Registry()
  registry.RegisterApiByName('datastream', api_version)
  return registry
def ParentRef(project, location):
  """Get the resource name of the parent collection.

  Args:
    project: the project of the parent collection.
    location: the GCP region of the membership.

  Returns:
    the resource name of the parent collection in the format of
    `projects/{project}/locations/{location}`.
  """
  return 'projects/%s/locations/%s' % (project, location)
def GenerateRequestId():
  """Generates a UUID to use as the request ID.

  Returns:
    string, the 36-character canonical text form of a random (version 4)
    UUID for the request ID.
  """
  # six.text_type keeps the historical unicode-string return type.
  return six.text_type(uuid.uuid4())
def ParseMysqlRdbmsFile(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message."""
  # BETA keeps the legacy (deprecated) parsing path; other tracks use the
  # schema-validated path.
  if release_track == base.ReleaseTrack.BETA:
    return ParseMysqlRdbmsFileBeta(messages, mysql_rdbms_file, release_track)
  return ParseMessageAndValidateSchema(
      mysql_rdbms_file, 'MysqlRdbms', messages.MysqlRdbms
  )
def ParseOracleRdbmsFile(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a oracle_rdbms_file into the OracleRdbms message."""
  # BETA keeps the legacy (deprecated) parsing path; other tracks use the
  # schema-validated path.
  if release_track == base.ReleaseTrack.BETA:
    return ParseOracleRdbmsFileBeta(messages, oracle_rdbms_file, release_track)
  return ParseMessageAndValidateSchema(
      oracle_rdbms_file, 'OracleRdbms', messages.OracleRdbms
  )
def ParsePostgresqlRdbmsFile(messages, postgresql_rdbms_file):
  """Parses a postgresql_rdbms_file into the PostgresqlRdbms message."""
  return ParseMessageAndValidateSchema(
      config_file_path=postgresql_rdbms_file,
      schema_name='PostgresqlRdbms',
      message_type=messages.PostgresqlRdbms,
  )
def ParseSqlServerRdbmsFile(messages, sqlserver_rdbms_file):
  """Parses a sqlserver_rdbms_file into the SqlServerRdbms message."""
  return ParseMessageAndValidateSchema(
      config_file_path=sqlserver_rdbms_file,
      schema_name='SqlServerRdbms',
      message_type=messages.SqlServerRdbms,
  )
def ParseSalesforceOrgFile(messages, salesforce_org_file):
  """Parses a salesforce_org_file into the SalesforceOrg message."""
  return ParseMessageAndValidateSchema(
      config_file_path=salesforce_org_file,
      schema_name='SalesforceOrg',
      message_type=messages.SalesforceOrg,
  )
def ParseMongodbFile(messages, mongodb_file):
  """Parses a mongodb_file into the MongodbCluster message."""
  return ParseMessageAndValidateSchema(
      config_file_path=mongodb_file,
      schema_name='MongodbCluster',
      message_type=messages.MongodbCluster,
  )
def CreateMessageWithCamelCaseConversion(
    message_type, parsed_yaml, schema_path=None
):
  """Create a message from a yaml dict.

  Similar to export_util.Import (since we convert to camel case before)

  Args:
    message_type: a Datastream message type to create.
    parsed_yaml: dict
    schema_path: str, path to the message schema to validate against.

  Returns:
    a Datastream message.
  """
  converted = camel_case_utils.ConvertYamlToCamelCase(parsed_yaml)
  if schema_path:
    # If a schema is provided, validate against it.
    export_util.ValidateYAML(converted, schema_path)
  try:
    return api_encoding.PyValueToMessage(message_type, converted)
  except Exception as e:
    # NOTE(review): this raises the dataproc exceptions.ParseError, while the
    # rest of this module raises ds_exceptions.ParseError — confirm whether
    # that is intentional before unifying.
    raise exceptions.ParseError('Cannot parse YAML: [{0}]'.format(e))
# TODO(b/207467120): deprecate BETA client.
def GetRDBMSV1alpha1ToV1FieldName(field, release_track):
  """Returns the release-track-specific spelling of an RDBMS field name.

  Public alias of _GetRDBMSFieldName, kept for external callers; delegating
  removes the previous duplicated lookup logic.
  """
  return _GetRDBMSFieldName(field, release_track)
def _GetRDBMSFieldName(field, release_track):
  """Maps a canonical RDBMS field name to its release-track spelling."""
  names_by_track = RDBMS_FIELD_NAME_BY_RELEASE_TRACK.get(field, {})
  # Unknown fields or tracks fall back to the canonical name.
  return names_by_track.get(release_track, field)
# Deprecated BETA methods - TODO(b/207467120).
# remove after full BETA deprecation.
def ParseMysqlColumn(messages, mysql_column_object, release_track):
  """Parses a raw mysql column json/yaml into the MysqlColumn message."""
  column_key = _GetRDBMSFieldName('column', release_track)
  message = messages.MysqlColumn(
      column=mysql_column_object.get(column_key, ''))
  # Optional attributes are copied only when explicitly present in the YAML.
  optional_fields = (
      ('data_type', 'dataType'),
      ('collation', 'collation'),
      ('length', 'length'),
      ('nullable', 'nullable'),
      ('ordinal_position', 'ordinalPosition'),
      ('primary_key', 'primaryKey'),
  )
  for source_key, attr_name in optional_fields:
    value = mysql_column_object.get(source_key)
    if value is not None:
      setattr(message, attr_name, value)
  return message
def ParseMysqlTable(messages, mysql_table_object, release_track):
  """Parses a raw mysql table json/yaml into the MysqlTable message."""
  columns = [
      ParseMysqlColumn(messages, column, release_track)
      for column in mysql_table_object.get('mysql_columns', [])
  ]
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = mysql_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing key "%s".' % table_key)
  return messages.MysqlTable(table=table_name, mysqlColumns=columns)
def ParseMysqlDatabase(messages, mysql_database_object, release_track):
  """Parses a raw mysql database json/yaml into the MysqlDatabase message."""
  tables = [
      ParseMysqlTable(messages, table, release_track)
      for table in mysql_database_object.get('mysql_tables', [])
  ]
  database_key = _GetRDBMSFieldName('database', release_track)
  database_name = mysql_database_object.get(database_key)
  if not database_name:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing key "%s".' % database_key)
  return messages.MysqlDatabase(database=database_name, mysqlTables=tables)
def ParseMysqlSchemasListToMysqlRdbmsMessage(
    messages, mysql_rdbms_data, release_track=base.ReleaseTrack.BETA):
  """Parses an object of type {mysql_databases: [...]} into the MysqlRdbms message."""
  databases = [
      ParseMysqlDatabase(messages, database, release_track)
      for database in mysql_rdbms_data.get('mysql_databases', [])
  ]
  return messages.MysqlRdbms(mysqlDatabases=databases)
def ParseOracleColumn(messages, oracle_column_object, release_track):
  """Parses a raw oracle column json/yaml into the OracleColumn message."""
  column_key = _GetRDBMSFieldName('column', release_track)
  message = messages.OracleColumn(
      column=oracle_column_object.get(column_key, ''))
  # Optional attributes are copied only when explicitly present in the YAML.
  optional_fields = (
      ('data_type', 'dataType'),
      ('encoding', 'encoding'),
      ('length', 'length'),
      ('nullable', 'nullable'),
      ('ordinal_position', 'ordinalPosition'),
      ('precision', 'precision'),
      ('primary_key', 'primaryKey'),
      ('scale', 'scale'),
  )
  for source_key, attr_name in optional_fields:
    value = oracle_column_object.get(source_key)
    if value is not None:
      setattr(message, attr_name, value)
  return message
def ParseOracleTable(messages, oracle_table_object, release_track):
  """Parses a raw oracle table json/yaml into the OracleTable message."""
  columns = [
      ParseOracleColumn(messages, column, release_track)
      for column in oracle_table_object.get('oracle_columns', [])
  ]
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = oracle_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing key "%s".' % table_key)
  return messages.OracleTable(table=table_name, oracleColumns=columns)
def ParseOracleSchema(messages, oracle_schema_object, release_track):
  """Parses a raw oracle schema json/yaml into the OracleSchema message."""
  tables = [
      ParseOracleTable(messages, table, release_track)
      for table in oracle_schema_object.get('oracle_tables', [])
  ]
  schema_key = _GetRDBMSFieldName('schema', release_track)
  schema_name = oracle_schema_object.get(schema_key)
  if not schema_name:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing key "%s".' % schema_key)
  return messages.OracleSchema(schema=schema_name, oracleTables=tables)
def ParseOracleSchemasListToOracleRdbmsMessage(
    messages, oracle_rdbms_data, release_track=base.ReleaseTrack.BETA):
  """Parses an object of type {oracle_schemas: [...]} into the OracleRdbms message."""
  schemas = [
      ParseOracleSchema(messages, schema, release_track)
      for schema in oracle_rdbms_data.get('oracle_schemas', [])
  ]
  return messages.OracleRdbms(oracleSchemas=schemas)
def ParsePostgresqlColumn(messages, postgresql_column_object):
  """Parses a raw postgresql column json/yaml into the PostgresqlColumn message."""
  message = messages.PostgresqlColumn(
      column=postgresql_column_object.get('column', ''))
  # Optional attributes are copied only when explicitly present in the YAML.
  optional_fields = (
      ('data_type', 'dataType'),
      ('length', 'length'),
      ('precision', 'precision'),
      ('scale', 'scale'),
      ('primary_key', 'primaryKey'),
      ('nullable', 'nullable'),
      ('ordinal_position', 'ordinalPosition'),
  )
  for source_key, attr_name in optional_fields:
    value = postgresql_column_object.get(source_key)
    if value is not None:
      setattr(message, attr_name, value)
  return message
def ParsePostgresqlTable(messages, postgresql_table_object):
  """Parses a raw postgresql table json/yaml into the PostgresqlTable message."""
  columns = [
      ParsePostgresqlColumn(messages, column)
      for column in postgresql_table_object.get('postgresql_columns', [])
  ]
  table_name = postgresql_table_object.get('table')
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "table".')
  return messages.PostgresqlTable(
      table=table_name, postgresqlColumns=columns)
def ParsePostgresqlSchema(messages, postgresql_schema_object):
  """Parses a raw postgresql schema json/yaml into the PostgresqlSchema message."""
  tables = [
      ParsePostgresqlTable(messages, table)
      for table in postgresql_schema_object.get('postgresql_tables', [])
  ]
  schema_name = postgresql_schema_object.get('schema')
  if not schema_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "schema".')
  return messages.PostgresqlSchema(
      schema=schema_name, postgresqlTables=tables)
def ParsePostgresqlSchemasListToPostgresqlRdbmsMessage(messages,
                                                       postgresql_rdbms_data):
  """Parses an object of type {postgresql_schemas: [...]} into the PostgresqlRdbms message."""
  schemas = [
      ParsePostgresqlSchema(messages, schema)
      for schema in postgresql_rdbms_data.get('postgresql_schemas', [])
  ]
  return messages.PostgresqlRdbms(postgresqlSchemas=schemas)
def ParseMongodbField(messages, mongodb_field_object):
  """Parses a raw mongodb field json/yaml into the MongodbField message."""
  field_name = mongodb_field_object.get('field', '')
  return messages.MongodbField(field=field_name)
def ParseMongodbCollection(messages, mongodb_collection_object):
  """Parses a raw mongodb database json/yaml into the MongodbCollection message."""
  fields = [
      ParseMongodbField(messages, field)
      for field in mongodb_collection_object.get('fields', [])
  ]
  return messages.MongodbCollection(
      collection=mongodb_collection_object.get('collection', ''),
      fields=fields,
  )
def ParseMongodbDatabase(messages, mongodb_database_object):
  """Parses a raw mongodb database json/yaml into the MongodbDatabase message."""
  collections = [
      ParseMongodbCollection(messages, collection)
      for collection in mongodb_database_object.get('collections', [])
  ]
  return messages.MongodbDatabase(collections=collections)
def ParseMongodbCluster(messages, mongodb_cluster_object):
  """Parses a raw mongodb cluster json/yaml into the MongodbCluster message."""
  databases = [
      ParseMongodbDatabase(messages, database)
      for database in mongodb_cluster_object.get('databases', [])
  ]
  return messages.MongodbCluster(databases=databases)
def UpdateV1alpha1ToV1MaskFields(field_mask):
  """Updates field mask paths according to the v1alpha1 > v1 Datastream API change.

  This allows for backwards compatibility with the current client field
  mask.

  Args:
    field_mask: List[str], list of stream fields to update

  Returns:
    updated_field_mask: List[str] field mask with fields translated
      from v1alpha1 API to v1.
  """
  translated_mask = []
  for path in field_mask:
    new_path = path
    # Substring match: the first known v1alpha1 field found in the path is
    # replaced (all of its occurrences) with its v1 counterpart.
    for old_name, new_name in (
        _UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1.items()):
      if old_name in path:
        new_path = path.replace(old_name, new_name)
        break
    translated_mask.append(new_path)
  return translated_mask
def ParseMysqlRdbmsFileBeta(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message. deprecated."""
  raw_yaml = console_io.ReadFromFileOrStdin(mysql_rdbms_file, binary=False)
  try:
    head_data = yaml.load(raw_yaml)
  except Exception as e:
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
  # The schemas may be nested under a 'mysql_rdbms' key or sit at top level.
  rdbms_data = head_data.get('mysql_rdbms', head_data)
  return ParseMysqlSchemasListToMysqlRdbmsMessage(
      messages, rdbms_data, release_track
  )
def ParseOracleRdbmsFileBeta(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a oracle_rdbms_file into the OracleRdbms message. deprecated."""
  raw_yaml = console_io.ReadFromFileOrStdin(oracle_rdbms_file, binary=False)
  try:
    head_data = yaml.load(raw_yaml)
  except Exception as e:
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))
  # The schemas may be nested under an 'oracle_rdbms' key or sit at top level.
  rdbms_data = head_data.get('oracle_rdbms', head_data)
  return ParseOracleSchemasListToOracleRdbmsMessage(
      messages, rdbms_data, release_track
  )