1000 lines
40 KiB
Python
1000 lines
40 KiB
Python
# -*- coding: utf-8 -*- #
|
|
# Copyright 2020 Google LLC. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Cloud Datastream connection profiles API."""
|
|
|
|
|
|
from apitools.base.py import list_pager
|
|
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
|
|
from googlecloudsdk.api_lib.datastream import util
|
|
from googlecloudsdk.calliope import base
|
|
from googlecloudsdk.calliope import exceptions
|
|
from googlecloudsdk.calliope.arg_parsers import HostPort
|
|
from googlecloudsdk.command_lib.util.args import labels_util
|
|
from googlecloudsdk.core import resources
|
|
from googlecloudsdk.core import yaml
|
|
from googlecloudsdk.core.console import console_io
|
|
|
|
|
|
def GetConnectionProfileURI(resource):
  """Returns the self-link URI of a connection profile.

  Args:
    resource: an object whose `name` attribute holds the connection profile's
      relative resource name.

  Returns:
    str, the value of SelfLink() on the parsed resource reference.
  """
  collection_name = 'datastream.projects.locations.connectionProfiles'
  profile_ref = resources.REGISTRY.ParseRelativeName(
      resource.name, collection=collection_name)
  return profile_ref.SelfLink()
|
|
|
|
|
|
class ConnectionProfilesClient:
|
|
"""Client for connection profiles service in the API."""
|
|
|
|
def __init__(self, client=None, messages=None):
  """Creates a client for the connection profiles API service.

  Args:
    client: optional API client; util.GetClientInstance() is used when it is
      falsy.
    messages: optional API messages module; util.GetMessagesModule() is used
      when it is falsy.
  """
  self._client = client if client else util.GetClientInstance()
  self._messages = messages if messages else util.GetMessagesModule()
  # The connection profiles service of the (possibly default) client.
  self._service = self._client.projects_locations_connectionProfiles
  self._resource_parser = util.GetResourceParser()
|
|
|
|
def _ValidateArgs(self, args):
|
|
self._ValidateSslConfigArgs(args)
|
|
|
|
def _ValidateSslConfigArgs(self, args):
|
|
"""Validates Format of all SSL config args."""
|
|
self._ValidateCertificateFormat(args.ca_certificate, 'CA certificate')
|
|
self._ValidateCertificateFormat(args.client_certificate,
|
|
'client certificate')
|
|
self._ValidateCertificateFormat(args.client_key, 'client key')
|
|
|
|
# Validation for all Postgresql SSL config fields.
|
|
self._ValidateCertificateFormat(
|
|
args.postgresql_ca_certificate, 'Postgresql CA certificate'
|
|
)
|
|
self._ValidateCertificateFormat(
|
|
args.postgresql_client_certificate, 'Postgresql client certificate'
|
|
)
|
|
self._ValidateCertificateFormat(
|
|
args.postgresql_client_key, 'Postgresql client private key'
|
|
)
|
|
|
|
# Validation for Oracle SSL config fields.
|
|
self._ValidateCertificateFormat(
|
|
args.oracle_ca_certificate, 'Oracle CA certificate'
|
|
)
|
|
|
|
def _ValidateCertificateFormat(self, certificate, name):
|
|
if not certificate:
|
|
return True
|
|
cert = certificate.strip()
|
|
cert_lines = cert.split('\n')
|
|
if (not cert_lines[0].startswith('-----') or
|
|
not cert_lines[-1].startswith('-----')):
|
|
raise exceptions.InvalidArgumentException(
|
|
name,
|
|
'The certificate does not appear to be in PEM format: \n{0}'.format(
|
|
cert))
|
|
|
|
def _GetSslConfig(self, args):
|
|
return self._messages.MysqlSslConfig(
|
|
clientKey=args.client_key,
|
|
clientCertificate=args.client_certificate,
|
|
caCertificate=args.ca_certificate)
|
|
|
|
def _GetMySqlProfile(self, args):
|
|
ssl_config = self._GetSslConfig(args)
|
|
return self._messages.MysqlProfile(
|
|
hostname=args.mysql_hostname,
|
|
port=args.mysql_port,
|
|
username=args.mysql_username,
|
|
password=args.mysql_password,
|
|
secretManagerStoredPassword=args.mysql_secret_manager_stored_password,
|
|
sslConfig=ssl_config)
|
|
|
|
def _GetOracleProfile(self, args):
|
|
ssl_config = self._GetOracleSslConfig(args)
|
|
return self._messages.OracleProfile(
|
|
hostname=args.oracle_hostname,
|
|
port=args.oracle_port,
|
|
username=args.oracle_username,
|
|
password=args.oracle_password,
|
|
secretManagerStoredPassword=args.oracle_secret_manager_stored_password,
|
|
databaseService=args.database_service,
|
|
oracleSslConfig=ssl_config)
|
|
|
|
def _GetOracleSslConfig(self, args):
|
|
"""Returns a OracleSslConfig message based on the given args."""
|
|
return self._messages.OracleSslConfig(
|
|
caCertificate=args.oracle_ca_certificate,
|
|
serverCertificateDistinguishedName=args.oracle_server_certificate_distinguished_name,
|
|
)
|
|
|
|
def _GetPostgresqlSslConfig(self, args):
|
|
"""Returns a PostgresqlSslConfig message based on the given args."""
|
|
if args.postgresql_client_certificate or args.postgresql_client_key:
|
|
return self._messages.PostgresqlSslConfig(
|
|
serverAndClientVerification=self._messages.ServerAndClientVerification(
|
|
clientCertificate=args.postgresql_client_certificate,
|
|
clientKey=args.postgresql_client_key,
|
|
caCertificate=args.postgresql_ca_certificate,
|
|
serverCertificateHostname=args.postgresql_server_certificate_hostname,
|
|
)
|
|
)
|
|
|
|
if args.postgresql_ca_certificate:
|
|
return self._messages.PostgresqlSslConfig(
|
|
serverVerification=self._messages.ServerVerification(
|
|
caCertificate=args.postgresql_ca_certificate,
|
|
serverCertificateHostname=args.postgresql_server_certificate_hostname,
|
|
)
|
|
)
|
|
|
|
return None
|
|
|
|
def _GetPostgresqlProfile(self, args):
|
|
ssl_config = self._GetPostgresqlSslConfig(args)
|
|
return self._messages.PostgresqlProfile(
|
|
hostname=args.postgresql_hostname,
|
|
port=args.postgresql_port,
|
|
username=args.postgresql_username,
|
|
password=args.postgresql_password,
|
|
secretManagerStoredPassword=args.postgresql_secret_manager_stored_password,
|
|
database=args.postgresql_database,
|
|
sslConfig=ssl_config)
|
|
|
|
def _GetSqlServerProfile(self, args):
|
|
return self._messages.SqlServerProfile(
|
|
hostname=args.sqlserver_hostname,
|
|
port=args.sqlserver_port,
|
|
username=args.sqlserver_username,
|
|
password=args.sqlserver_password,
|
|
secretManagerStoredPassword=args.sqlserver_secret_manager_stored_password,
|
|
database=args.sqlserver_database,
|
|
)
|
|
|
|
def _GetSalesforceProfile(self, args):
|
|
if args.salesforce_oauth2_client_id:
|
|
return self._messages.SalesforceProfile(
|
|
domain=args.salesforce_domain,
|
|
oauth2ClientCredentials=self._messages.Oauth2ClientCredentials(
|
|
clientId=args.salesforce_oauth2_client_id,
|
|
clientSecret=args.salesforce_oauth2_client_secret,
|
|
secretManagerStoredClientSecret=args.salesforce_secret_manager_stored_oauth2_client_secret,
|
|
),
|
|
)
|
|
else:
|
|
return self._messages.SalesforceProfile(
|
|
domain=args.salesforce_domain,
|
|
userCredentials=self._messages.UserCredentials(
|
|
username=args.salesforce_username,
|
|
password=args.salesforce_password,
|
|
secretManagerStoredPassword=args.salesforce_secret_manager_stored_password,
|
|
securityToken=args.salesforce_security_token,
|
|
secretManagerStoredSecurityToken=args.salesforce_secret_manager_stored_security_token,
|
|
),
|
|
)
|
|
|
|
def _GetGCSProfile(self, args, release_track):
  """Builds a GcsProfile message.

  Args:
    args: argparse.Namespace; the bucket comes from args.bucket_name on the
      BETA track and args.bucket otherwise. args.root_path defaults to '/'.
    release_track: base.ReleaseTrack of the invoking command.
  """
  # TODO(b/207467120): remove bucket_name arg check.
  is_beta = release_track == base.ReleaseTrack.BETA
  bucket = args.bucket_name if is_beta else args.bucket
  gcs_profile = self._messages.GcsProfile(bucket=bucket)
  gcs_profile.rootPath = args.root_path or '/'
  return gcs_profile
|
|
|
|
def _GetMongodbProfile(self, args):
|
|
"""Returns the MongoDB profile message based on the given args."""
|
|
addresses = []
|
|
for host_address in args.mongodb_host_addresses:
|
|
if args.mongodb_srv_connection_format:
|
|
addresses.append(
|
|
self._messages.HostAddress(hostname=host_address)
|
|
)
|
|
else:
|
|
hostport = HostPort.Parse(host_address)
|
|
addresses.append(
|
|
self._messages.HostAddress(
|
|
hostname=hostport.host, port=int(hostport.port)
|
|
)
|
|
)
|
|
profile = self._messages.MongodbProfile(
|
|
hostAddresses=addresses,
|
|
username=args.mongodb_username,
|
|
replicaSet=args.mongodb_replica_set,
|
|
password=args.mongodb_password,
|
|
secretManagerStoredPassword=args.mongodb_secret_manager_stored_password,
|
|
)
|
|
if (
|
|
args.mongodb_direct_connection
|
|
and not args.mongodb_standard_connection_format
|
|
):
|
|
raise exceptions.InvalidArgumentException(
|
|
'mongodb-direct-connection',
|
|
'mongodb direct connection can only be used with the standard'
|
|
' connection format.',
|
|
)
|
|
if args.mongodb_srv_connection_format:
|
|
profile.srvConnectionFormat = {}
|
|
if args.mongodb_standard_connection_format:
|
|
profile.standardConnectionFormat = (
|
|
self._messages.StandardConnectionFormat(
|
|
directConnection=args.mongodb_direct_connection
|
|
)
|
|
)
|
|
if args.mongodb_tls:
|
|
profile.sslConfig = {}
|
|
if args.mongodb_ca_certificate:
|
|
profile.sslConfig.caCertificate = args.mongodb_ca_certificate
|
|
return profile
|
|
|
|
def _ParseSslConfig(self, data):
|
|
return self._messages.MysqlSslConfig(
|
|
clientKey=data.get('client_key'),
|
|
clientCertificate=data.get('client_certificate'),
|
|
caCertificate=data.get('ca_certificate'))
|
|
|
|
def _ParseMySqlProfile(self, data):
|
|
if not data:
|
|
return {}
|
|
ssl_config = self._ParseSslConfig(data)
|
|
return self._messages.MysqlProfile(
|
|
hostname=data.get('hostname'),
|
|
port=data.get('port'),
|
|
username=data.get('username'),
|
|
password=data.get('password'),
|
|
sslConfig=ssl_config)
|
|
|
|
def _ParseOracleProfile(self, data):
|
|
if not data:
|
|
return {}
|
|
return self._messages.OracleProfile(
|
|
hostname=data.get('hostname'),
|
|
port=data.get('port'),
|
|
username=data.get('username'),
|
|
password=data.get('password'),
|
|
databaseService=data.get('database_service'))
|
|
|
|
def _ParsePostgresqlProfile(self, data):
|
|
if not data:
|
|
return {}
|
|
return self._messages.PostgresqlProfile(
|
|
hostname=data.get('hostname'),
|
|
port=data.get('port'),
|
|
username=data.get('username'),
|
|
password=data.get('password'),
|
|
database=data.get('database'))
|
|
|
|
def _ParseSqlServerProfile(self, data):
|
|
if not data:
|
|
return {}
|
|
return self._messages.SqlServerProfile(
|
|
hostname=data.get('hostname'),
|
|
port=data.get('port'),
|
|
username=data.get('username'),
|
|
password=data.get('password'),
|
|
database=data.get('database'),
|
|
)
|
|
|
|
def _ParseGCSProfile(self, data):
|
|
if not data:
|
|
return {}
|
|
return self._messages.GcsProfile(
|
|
bucket=data.get('bucket_name'), rootPath=data.get('root_path'))
|
|
|
|
def _GetForwardSshTunnelConnectivity(self, args):
|
|
return self._messages.ForwardSshTunnelConnectivity(
|
|
hostname=args.forward_ssh_hostname,
|
|
port=args.forward_ssh_port,
|
|
username=args.forward_ssh_username,
|
|
privateKey=args.forward_ssh_private_key,
|
|
password=args.forward_ssh_password)
|
|
|
|
def _GetConnectionProfile(self, cp_type, connection_profile_id, args,
                          release_track):
  """Returns a connection profile according to type.

  Builds a new ConnectionProfile message for a create request. It sets the
  labels and display name, then the type-specific profile, then at most one
  connectivity method. Connectivity precedence is private connection, then
  forward SSH tunnel, then static service IPs.

  Args:
    cp_type: str, one of 'MYSQL', 'ORACLE', 'POSTGRESQL', 'SQLSERVER',
      'GOOGLE-CLOUD-STORAGE', 'BIGQUERY', 'SALESFORCE' or 'MONGODB'.
    connection_profile_id: str, stored as the message's `name`.
    args: argparse.Namespace, the parsed command arguments.
    release_track: base.ReleaseTrack of the invoking command.

  Returns:
    The populated ConnectionProfile message.

  Raises:
    exceptions.InvalidArgumentException: if cp_type is not supported.
  """
  labels = labels_util.ParseCreateArgs(
      args, self._messages.ConnectionProfile.LabelsValue)
  profile = self._messages.ConnectionProfile(
      name=connection_profile_id, labels=labels,
      displayName=args.display_name)

  # cp_type -> (ConnectionProfile field, zero-argument builder). Builders are
  # lambdas so only the selected type's message is looked up and built.
  profile_builders = {
      'MYSQL': ('mysqlProfile', lambda: self._GetMySqlProfile(args)),
      'ORACLE': ('oracleProfile', lambda: self._GetOracleProfile(args)),
      'POSTGRESQL': ('postgresqlProfile',
                     lambda: self._GetPostgresqlProfile(args)),
      'SQLSERVER': ('sqlServerProfile',
                    lambda: self._GetSqlServerProfile(args)),
      'GOOGLE-CLOUD-STORAGE': (
          'gcsProfile', lambda: self._GetGCSProfile(args, release_track)),
      'BIGQUERY': ('bigqueryProfile',
                   lambda: self._messages.BigQueryProfile()),
      'SALESFORCE': ('salesforceProfile',
                     lambda: self._GetSalesforceProfile(args)),
      'MONGODB': ('mongodbProfile', lambda: self._GetMongodbProfile(args)),
  }
  if cp_type not in profile_builders:
    raise exceptions.InvalidArgumentException(
        cp_type,
        'The connection profile type {0} is either unknown or not supported'
        ' yet.'.format(cp_type),
    )
  field_name, build_profile = profile_builders[cp_type]
  setattr(profile, field_name, build_profile())

  # TODO(b/207467120): deprecate BETA client.
  if release_track == base.ReleaseTrack.BETA:
    private_connection_concept = args.CONCEPTS.private_connection_name
  else:
    private_connection_concept = args.CONCEPTS.private_connection
  private_connectivity_ref = private_connection_concept.Parse()

  if private_connectivity_ref:
    profile.privateConnectivity = self._messages.PrivateConnectivity(
        privateConnection=private_connectivity_ref.RelativeName())
  elif args.forward_ssh_hostname:
    profile.forwardSshConnectivity = (
        self._GetForwardSshTunnelConnectivity(args))
  elif args.static_ip_connectivity:
    profile.staticServiceIpConnectivity = {}

  return profile
|
|
|
|
def _ParseConnectionProfileObjectFile(
    self, connection_profile_object_file, release_track
):
  """Parses a connection-profile-file into the ConnectionProfile message.

  On non-BETA tracks parsing is delegated to
  util.ParseMessageAndValidateSchema. On the BETA track the file is read as
  YAML with snake_case keys, and the known keys are mapped onto the message
  by hand.

  Args:
    connection_profile_object_file: str, the file argument. It is passed to
      console_io.ReadFromFileOrStdin on BETA and to
      util.ParseMessageAndValidateSchema on other tracks.
    release_track: base.ReleaseTrack of the invoking command.

  Returns:
    A ConnectionProfile message.

  Raises:
    ds_exceptions.ParseError: on the BETA track, in any of these cases:
      - the content is not valid YAML;
      - the content is not a mapping (e.g. an empty file);
      - no connectivity method is named.
  """
  if release_track != base.ReleaseTrack.BETA:
    return util.ParseMessageAndValidateSchema(
        connection_profile_object_file,
        'ConnectionProfile',
        self._messages.ConnectionProfile,
    )

  data = console_io.ReadFromFileOrStdin(
      connection_profile_object_file, binary=False)
  try:
    connection_profile_data = yaml.load(data)
  except Exception as e:  # pylint: disable=broad-except
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))

  # An empty document loads as None, and a scalar or list has no .get();
  # report these as parse errors instead of failing with AttributeError.
  if not isinstance(connection_profile_data, dict):
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: expected a mapping of connection profile fields.'
    )

  display_name = connection_profile_data.get('display_name')
  labels = connection_profile_data.get('labels')
  connection_profile_msg = self._messages.ConnectionProfile(
      displayName=display_name,
      labels=labels)

  # Each _Parse*Profile helper returns {} (falsy) when its section is absent.
  oracle_profile = self._ParseOracleProfile(
      connection_profile_data.get('oracle_profile', {}))
  mysql_profile = self._ParseMySqlProfile(
      connection_profile_data.get('mysql_profile', {}))
  postgresql_profile = self._ParsePostgresqlProfile(
      connection_profile_data.get('postgresql_profile', {}))
  sqlserver_profile = self._ParseSqlServerProfile(
      connection_profile_data.get('sqlserver_profile', {})
  )
  gcs_profile = self._ParseGCSProfile(
      connection_profile_data.get('gcs_profile', {}))
  # Only the first present profile, in this order, is used.
  if oracle_profile:
    connection_profile_msg.oracleProfile = oracle_profile
  elif mysql_profile:
    connection_profile_msg.mysqlProfile = mysql_profile
  elif postgresql_profile:
    connection_profile_msg.postgresqlProfile = postgresql_profile
  elif sqlserver_profile:
    connection_profile_msg.sqlServerProfile = sqlserver_profile
  elif gcs_profile:
    connection_profile_msg.gcsProfile = gcs_profile

  if 'static_service_ip_connectivity' in connection_profile_data:
    connection_profile_msg.staticServiceIpConnectivity = (
        connection_profile_data.get('static_service_ip_connectivity')
    )
  elif 'forward_ssh_connectivity' in connection_profile_data:
    connection_profile_msg.forwardSshConnectivity = (
        connection_profile_data.get('forward_ssh_connectivity')
    )
  elif 'private_connectivity' in connection_profile_data:
    connection_profile_msg.privateConnectivity = connection_profile_data.get(
        'private_connectivity'
    )
  else:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing connectivity method.'
    )

  return connection_profile_msg
|
|
|
|
def _UpdateForwardSshTunnelConnectivity(
|
|
self, connection_profile, args, update_fields
|
|
):
|
|
"""Updates Forward SSH tunnel connectivity config."""
|
|
if args.IsSpecified('forward_ssh_hostname'):
|
|
connection_profile.forwardSshConnectivity.hostname = (
|
|
args.forward_ssh_hostname
|
|
)
|
|
update_fields.append('forwardSshConnectivity.hostname')
|
|
if args.IsSpecified('forward_ssh_port'):
|
|
connection_profile.forwardSshConnectivity.port = args.forward_ssh_port
|
|
update_fields.append('forwardSshConnectivity.port')
|
|
if args.IsSpecified('forward_ssh_username'):
|
|
connection_profile.forwardSshConnectivity.username = (
|
|
args.forward_ssh_username
|
|
)
|
|
update_fields.append('forwardSshConnectivity.username')
|
|
if args.IsSpecified('forward_ssh_private_key'):
|
|
connection_profile.forwardSshConnectivity.privateKey = (
|
|
args.forward_ssh_private_key
|
|
)
|
|
update_fields.append('forwardSshConnectivity.privateKey')
|
|
if args.IsSpecified('forward_ssh_password'):
|
|
connection_profile.forwardSshConnectivity.privateKey = (
|
|
args.forward_ssh_password
|
|
)
|
|
update_fields.append('forwardSshConnectivity.password')
|
|
|
|
def _UpdateGCSProfile(
    self, connection_profile, release_track, args, update_fields
):
  """Updates GOOGLE CLOUD STORAGE connection profile.

  The bucket argument is `bucket_name` on the BETA track and `bucket` on GA;
  other tracks do not update the bucket. Specified values are written to
  connection_profile.gcsProfile and their mask paths appended to
  update_fields.
  """
  # TODO(b/207467120): remove bucket_name arg check.
  bucket_arg = None
  if release_track == base.ReleaseTrack.BETA:
    bucket_arg = 'bucket_name'
  elif release_track == base.ReleaseTrack.GA:
    bucket_arg = 'bucket'
  if bucket_arg is not None and args.IsSpecified(bucket_arg):
    connection_profile.gcsProfile.bucket = getattr(args, bucket_arg)
    update_fields.append('gcsProfile.bucket')

  if args.IsSpecified('root_path'):
    connection_profile.gcsProfile.rootPath = args.root_path
    update_fields.append('gcsProfile.rootPath')
|
|
|
|
def _UpdateOracleProfile(self,
|
|
connection_profile,
|
|
args,
|
|
update_fields):
|
|
"""Updates Oracle connection profile."""
|
|
if args.IsSpecified('oracle_hostname'):
|
|
connection_profile.oracleProfile.hostname = args.oracle_hostname
|
|
update_fields.append('oracleProfile.hostname')
|
|
if args.IsSpecified('oracle_port'):
|
|
connection_profile.oracleProfile.port = args.oracle_port
|
|
update_fields.append('oracleProfile.port')
|
|
if args.IsSpecified('oracle_username'):
|
|
connection_profile.oracleProfile.username = args.oracle_username
|
|
update_fields.append('oracleProfile.username')
|
|
if args.IsSpecified('oracle_password') or args.IsSpecified(
|
|
'oracle_secret_manager_stored_password'
|
|
):
|
|
connection_profile.oracleProfile.password = args.oracle_password
|
|
connection_profile.oracleProfile.secretManagerStoredPassword = (
|
|
args.oracle_secret_manager_stored_password
|
|
)
|
|
update_fields.append('oracleProfile.password')
|
|
update_fields.append('oracleProfile.secretManagerStoredPassword')
|
|
if args.IsSpecified('database_service'):
|
|
connection_profile.oracleProfile.databaseService = args.database_service
|
|
update_fields.append('oracleProfile.databaseService')
|
|
|
|
def _UpdateMysqlSslConfig(self, connection_profile, args, update_fields):
|
|
"""Updates Mysql SSL config."""
|
|
if args.IsSpecified('client_key'):
|
|
connection_profile.mysqlProfile.sslConfig.clientKey = args.client_key
|
|
update_fields.append('mysqlProfile.sslConfig.clientKey')
|
|
if args.IsSpecified('client_certificate'):
|
|
connection_profile.mysqlProfile.sslConfig.clientCertificate = (
|
|
args.client_certificate
|
|
)
|
|
update_fields.append('mysqlProfile.sslConfig.clientCertificate')
|
|
if args.IsSpecified('ca_certificate'):
|
|
connection_profile.mysqlProfile.sslConfig.caCertificate = (
|
|
args.ca_certificate
|
|
)
|
|
update_fields.append('mysqlProfile.sslConfig.caCertificate')
|
|
|
|
def _UpdateMySqlProfile(self, connection_profile, args, update_fields):
|
|
"""Updates MySQL connection profile."""
|
|
if args.IsSpecified('mysql_hostname'):
|
|
connection_profile.mysqlProfile.hostname = args.mysql_hostname
|
|
update_fields.append('mysqlProfile.hostname')
|
|
if args.IsSpecified('mysql_port'):
|
|
connection_profile.mysqlProfile.port = args.mysql_port
|
|
update_fields.append('mysqlProfile.port')
|
|
if args.IsSpecified('mysql_username'):
|
|
connection_profile.mysqlProfile.username = args.mysql_username
|
|
update_fields.append('mysqlProfile.username')
|
|
if args.IsSpecified('mysql_password') or args.IsSpecified(
|
|
'mysql_secret_manager_stored_password'
|
|
):
|
|
connection_profile.mysqlProfile.password = args.mysql_password
|
|
connection_profile.mysqlProfile.secretManagerStoredPassword = (
|
|
args.mysql_secret_manager_stored_password
|
|
)
|
|
update_fields.append('mysqlProfile.password')
|
|
update_fields.append('mysqlProfile.secretManagerStoredPassword')
|
|
|
|
self._UpdateMysqlSslConfig(connection_profile, args, update_fields)
|
|
|
|
def _UpdatePostgresqlSslConfig(self, connection_profile, args, update_fields):
|
|
"""Updates Postgresql SSL config."""
|
|
if args.IsSpecified('postgresql_client_certificate'):
|
|
connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate = (
|
|
args.postgresql_client_certificate
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate'
|
|
)
|
|
|
|
if args.IsSpecified('postgresql_client_key'):
|
|
connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientKey = (
|
|
args.postgresql_client_key
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverAndClientVerification.clientKey'
|
|
)
|
|
|
|
if args.IsSpecified('postgresql_ca_certificate'):
|
|
if connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification:
|
|
connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate = (
|
|
args.postgresql_ca_certificate
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate'
|
|
)
|
|
else:
|
|
connection_profile.postgresqlProfile.sslConfig.serverVerification.caCertificate = (
|
|
args.postgresql_ca_certificate
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverVerification.caCertificate'
|
|
)
|
|
if args.IsSpecified('postgresql_server_certificate_hostname'):
|
|
if (
|
|
connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification
|
|
):
|
|
connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname = (
|
|
args.postgresql_server_certificate_hostname
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname'
|
|
)
|
|
else:
|
|
connection_profile.postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname = (
|
|
args.postgresql_server_certificate_hostname
|
|
)
|
|
update_fields.append(
|
|
'postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname'
|
|
)
|
|
|
|
def _UpdatePostgresqlProfile(self, connection_profile, args, update_fields):
|
|
"""Updates Postgresql connection profile."""
|
|
if args.IsSpecified('postgresql_hostname'):
|
|
connection_profile.postgresqlProfile.hostname = args.postgresql_hostname
|
|
update_fields.append('postgresqlProfile.hostname')
|
|
if args.IsSpecified('postgresql_port'):
|
|
connection_profile.postgresqlProfile.port = args.postgresql_port
|
|
update_fields.append('postgresqlProfile.port')
|
|
if args.IsSpecified('postgresql_username'):
|
|
connection_profile.postgresqlProfile.username = args.postgresql_username
|
|
update_fields.append('postgresqlProfile.username')
|
|
if args.IsSpecified('postgresql_password') or args.IsSpecified(
|
|
'postgresql_secret_manager_stored_password'
|
|
):
|
|
connection_profile.postgresqlProfile.password = args.postgresql_password
|
|
connection_profile.postgresqlProfile.secretManagerStoredPassword = (
|
|
args.postgresql_secret_manager_stored_password
|
|
)
|
|
update_fields.append('postgresqlProfile.password')
|
|
update_fields.append('postgresqlProfile.secretManagerStoredPassword')
|
|
if args.IsSpecified('postgresql_database'):
|
|
connection_profile.postgresqlProfile.database = args.postgresql_database
|
|
update_fields.append('postgresqlProfile.database')
|
|
|
|
self._UpdatePostgresqlSslConfig(connection_profile, args, update_fields)
|
|
|
|
def _UpdateSqlServerProfile(self, connection_profile, args, update_fields):
|
|
"""Updates SqlServer connection profile."""
|
|
if args.IsSpecified('sqlserver_hostname'):
|
|
connection_profile.sqlServerProfile.hostname = args.sqlserver_hostname
|
|
update_fields.append('sqlServerProfile.hostname')
|
|
if args.IsSpecified('sqlserver_port'):
|
|
connection_profile.sqlServerProfile.port = args.sqlserver_port
|
|
update_fields.append('sqlServerProfile.port')
|
|
if args.IsSpecified('sqlserver_username'):
|
|
connection_profile.sqlServerProfile.username = args.sqlserver_username
|
|
update_fields.append('sqlServerProfile.username')
|
|
if args.IsSpecified('sqlserver_password') or args.IsSpecified(
|
|
'sqlserver_secret_manager_stored_password'
|
|
):
|
|
connection_profile.sqlServerProfile.password = args.sqlserver_password
|
|
connection_profile.sqlServerProfile.secretManagerStoredPassword = (
|
|
args.sqlserver_secret_manager_stored_password
|
|
)
|
|
update_fields.append('sqlServerProfile.password')
|
|
update_fields.append('sqlServerProfile.secretManagerStoredPassword')
|
|
if args.IsSpecified('sqlserver_database'):
|
|
connection_profile.sqlServerProfile.database = args.sqlserver_database
|
|
update_fields.append('sqlServerProfile.database')
|
|
|
|
def _UpdateSalesforceProfile(self, connection_profile, args, update_fields):
|
|
"""Updates Salesforce connection profile."""
|
|
if args.IsSpecified('salesforce_domain'):
|
|
connection_profile.salesforceProfile.domain = args.salesforce_domain
|
|
update_fields.append('salesforceProfile.domain')
|
|
if args.IsSpecified('salesforce_username'):
|
|
connection_profile.salesforceProfile.userCredentials.username = (
|
|
args.salesforce_username
|
|
)
|
|
update_fields.append('salesforceProfile.userCredentials.username')
|
|
if args.IsSpecified('salesforce_password') or args.IsSpecified(
|
|
'salesforce_secret_manager_stored_password'
|
|
):
|
|
connection_profile.salesforceProfile.userCredentials.password = (
|
|
args.salesforce_password
|
|
)
|
|
connection_profile.salesforceProfile.userCredentials.secretManagerStoredPassword = (
|
|
args.salesforce_secret_manager_stored_password
|
|
)
|
|
update_fields.append('salesforceProfile.userCredentials.password')
|
|
update_fields.append(
|
|
'salesforceProfile.userCredentials.secretManagerStoredPassword'
|
|
)
|
|
|
|
if args.IsSpecified('salesforce_security_token') or args.IsSpecified(
|
|
'salesforce_secret_manager_stored_security_token'
|
|
):
|
|
connection_profile.salesforceProfile.userCredentials.securityToken = (
|
|
args.salesforce_security_token
|
|
)
|
|
connection_profile.salesforceProfile.userCredentials.secretManagerStoredSecurityToken = (
|
|
args.salesforce_secret_manager_stored_security_token
|
|
)
|
|
update_fields.append('salesforceProfile.userCredentials.securityToken')
|
|
update_fields.append(
|
|
'salesforceProfile.userCredentials.secretManagerStoredSecurityToken'
|
|
)
|
|
|
|
if args.IsSpecified('salesforce_oauth2_client_id'):
|
|
connection_profile.salesforceProfile.oauth2ClientCredentials.clientId = (
|
|
args.salesforce_oauth2_client_id
|
|
)
|
|
update_fields.append('salesforceProfile.oauth2ClientCredentials.clientId')
|
|
if args.IsSpecified('salesforce_oauth2_client_secret') or args.IsSpecified(
|
|
'salesforce_secret_manager_stored_oauth2_client_secret'
|
|
):
|
|
connection_profile.salesforceProfile.oauth2ClientCredentials.clientSecret = (
|
|
args.salesforce_oauth2_client_secret
|
|
)
|
|
connection_profile.salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret = (
|
|
args.salesforce_secret_manager_stored_oauth2_client_secret
|
|
)
|
|
update_fields.append(
|
|
'salesforceProfile.oauth2ClientCredentials.clientSecret'
|
|
)
|
|
update_fields.append(
|
|
'salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret'
|
|
)
|
|
|
|
def _UpdateMongodbProfile(self, connection_profile, args, update_fields):
|
|
"""Updates MongoDB connection profile."""
|
|
if args.IsSpecified('mongodb_host_addresses'):
|
|
addresses = []
|
|
for host_address in args.mongodb_host_addresses:
|
|
if args.mongodb_srv_connection_format:
|
|
addresses.append(
|
|
self._messages.HostAddress(hostname=host_address)
|
|
)
|
|
else:
|
|
hostname, port = host_address.split(':')
|
|
addresses.append(
|
|
self._messages.HostAddress(hostname=hostname, port=int(port))
|
|
)
|
|
connection_profile.mongodbProfile.hostAddresses = addresses
|
|
update_fields.append('monogodbProfile.hostAddresses')
|
|
if args.IsSpecified('mongodb_replica_set'):
|
|
connection_profile.mongodbProfile.replicaSet = args.mongodb_replica_set
|
|
update_fields.append('mongodbProfile.replicaSet')
|
|
if args.IsSpecified('mongodb_srv_connection_format') or args.IsSpecified(
|
|
'mongodb_standard_connection_format'
|
|
):
|
|
if args.mongodb_srv_connection_format:
|
|
connection_profile.mongodbProfile.srvConnectionFormat = {}
|
|
if args.mongodb_standard_connection_format:
|
|
connection_profile.mongodbProfile.standardConnectionFormat = {}
|
|
update_fields.append('mongodbProfile.srvConnectionFormat')
|
|
update_fields.append('mongodbProfile.standardConnectionFormat')
|
|
if args.IsSpecified('mongodb_username'):
|
|
connection_profile.mongodbProfile.username = args.mongodb_username
|
|
update_fields.append('mongodbProfile.username')
|
|
if args.IsSpecified('mongodb_password') or args.IsSpecified(
|
|
'mongodb_secret_manager_stored_password'
|
|
):
|
|
connection_profile.mongodbProfile.password = args.mongodb_password
|
|
connection_profile.mongodbProfile.secretManagerStoredPassword = (
|
|
args.mongodb_secret_manager_stored_password
|
|
)
|
|
update_fields.append('mongodbProfile.password')
|
|
update_fields.append('mongodbProfile.secretManagerStoredPassword')
|
|
|
|
def _GetExistingConnectionProfile(self, name):
|
|
get_req = (
|
|
self._messages.DatastreamProjectsLocationsConnectionProfilesGetRequest(
|
|
name=name
|
|
)
|
|
)
|
|
return self._service.Get(get_req)
|
|
|
|
def _UpdateLabels(self, connection_profile, args):
  """Updates labels of the connection profile.

  Applies the label additions, removals and args.clear_labels from the
  command arguments to connection_profile.labels. The field is replaced only
  when the computed diff reports that an update is needed.
  """
  labels_value_cls = self._messages.ConnectionProfile.LabelsValue
  labels_diff = labels_util.Diff(
      additions=labels_util.GetUpdateLabelsDictFromArgs(args),
      subtractions=labels_util.GetRemoveLabelsListFromArgs(args),
      clear=args.clear_labels,
  )
  diff_result = labels_diff.Apply(labels_value_cls, connection_profile.labels)
  if diff_result.needs_update:
    connection_profile.labels = diff_result.labels
|
|
|
|
def _GetUpdatedConnectionProfile(self, connection_profile, cp_type,
                                 release_track, args):
  """Returns updated connection profile and list of updated fields.

  Applies the explicitly specified arguments to `connection_profile` in
  place and collects an update-mask path for every field that was changed.

  Args:
    connection_profile: the existing ConnectionProfile message; mutated.
    cp_type: str, one of 'MYSQL', 'ORACLE', 'POSTGRESQL', 'SQLSERVER',
      'SALESFORCE', 'GOOGLE-CLOUD-STORAGE', 'BIGQUERY' or 'MONGODB'.
    release_track: base.ReleaseTrack of the invoking command.
    args: argparse.Namespace, the parsed command arguments.

  Returns:
    A (connection_profile, update_fields) tuple; update_fields is a list of
    str field paths.

  Raises:
    exceptions.InvalidArgumentException: if cp_type is not supported.
  """
  update_fields = []
  if args.IsSpecified('display_name'):
    connection_profile.displayName = args.display_name
    update_fields.append('displayName')

  if cp_type == 'MYSQL':
    self._UpdateMySqlProfile(
        connection_profile, args, update_fields)
  elif cp_type == 'ORACLE':
    self._UpdateOracleProfile(connection_profile, args, update_fields)
  elif cp_type == 'POSTGRESQL':
    self._UpdatePostgresqlProfile(connection_profile, args, update_fields)
  elif cp_type == 'SQLSERVER':
    self._UpdateSqlServerProfile(connection_profile, args, update_fields)
  elif cp_type == 'SALESFORCE':
    self._UpdateSalesforceProfile(connection_profile, args, update_fields)
  elif cp_type == 'GOOGLE-CLOUD-STORAGE':
    self._UpdateGCSProfile(
        connection_profile, release_track, args, update_fields
    )
  elif cp_type == 'BIGQUERY':
    # There are currently no parameters that can be updated in a bigquery CP.
    pass
  elif cp_type == 'MONGODB':
    self._UpdateMongodbProfile(connection_profile, args, update_fields)
  else:
    raise exceptions.InvalidArgumentException(
        cp_type,
        'The connection profile type {0} is either unknown or not supported'
        ' yet.'.format(cp_type),
    )

  # TODO(b/207467120): deprecate BETA client.
  if release_track == base.ReleaseTrack.BETA:
    private_connectivity_ref = args.CONCEPTS.private_connection_name.Parse()
  else:
    private_connectivity_ref = args.CONCEPTS.private_connection.Parse()

  if private_connectivity_ref:
    # Use the same PrivateConnectivity field as _GetConnectionProfile
    # (`privateConnection`); `privateConnectionName` is not that field.
    connection_profile.privateConnectivity = (
        self._messages.PrivateConnectivity(
            privateConnection=private_connectivity_ref.RelativeName()
        )
    )
    update_fields.append('privateConnectivity')
  elif args.forward_ssh_hostname:
    self._UpdateForwardSshTunnelConnectivity(
        connection_profile, args, update_fields
    )
  elif args.static_ip_connectivity:
    connection_profile.staticServiceIpConnectivity = {}
    update_fields.append('staticServiceIpConnectivity')

  # NOTE(review): _UpdateLabels does not add 'labels' to update_fields here;
  # confirm the caller includes label changes in the update mask.
  self._UpdateLabels(connection_profile, args)
  return connection_profile, update_fields
|
|
|
|
def Create(self,
           parent_ref,
           connection_profile_id,
           cp_type,
           release_track,
           args=None):
  """Creates a connection profile.

  Args:
    parent_ref: a Resource reference to a parent datastream.projects.locations
      resource for this connection profile.
    connection_profile_id: str, the name of the resource to create.
    cp_type: str, the type of the connection profile (for example 'MYSQL' or
      'ORACLE'); it is forwarded to _GetConnectionProfile, which builds the
      connection profile message for that type.
    release_track: Some arguments are added based on the command release
      track.
    args: argparse.Namespace, The arguments that this command was invoked
      with.

  Returns:
    Operation: the operation for creating the connection profile.
  """
  self._ValidateArgs(args)

  connection_profile = self._GetConnectionProfile(cp_type,
                                                  connection_profile_id, args,
                                                  release_track)

  # TODO(b/207467120): only use flags from args.
  # The BETA track always forces the request. On other tracks the --force
  # flag decides. Thanks to short-circuiting, args.force is never read on
  # BETA.
  force = bool(release_track == base.ReleaseTrack.BETA or args.force)

  request_id = util.GenerateRequestId()
  create_req_type = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesCreateRequest
  )
  create_req = create_req_type(
      connectionProfile=connection_profile,
      connectionProfileId=connection_profile.name,
      parent=parent_ref,
      requestId=request_id,
      force=force)

  return self._service.Create(create_req)
|
|
|
|
def Update(self, name, cp_type, release_track, args=None):
  """Patches an existing connection profile using the command arguments.

  Args:
    name: str, the reference of the connection profile to update.
    cp_type: str, the connection profile type (such as 'MYSQL' or 'ORACLE');
      it selects which type-specific fields are taken from args.
    release_track: the release track of the invoking command; some arguments
      only exist on certain tracks.
    args: argparse.Namespace, the arguments the command was invoked with.

  Returns:
    Operation: the operation for updating the connection profile.
  """
  self._ValidateArgs(args)

  existing_profile = self._GetExistingConnectionProfile(name)
  patched_profile, changed_fields = self._GetUpdatedConnectionProfile(
      existing_profile, cp_type, release_track, args)

  # TODO(b/207467120): only use flags from args.
  # BETA always forces the patch; args.force is only consulted elsewhere.
  if release_track == base.ReleaseTrack.BETA:
    force_update = True
  else:
    force_update = bool(args.force)

  request_id = util.GenerateRequestId()
  patch_request_cls = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesPatchRequest)
  patch_request = patch_request_cls(
      connectionProfile=patched_profile,
      name=patched_profile.name,
      updateMask=','.join(changed_fields),
      requestId=request_id,
      force=force_update,
  )
  return self._service.Patch(patch_request)
|
|
|
|
def List(self, project_id, args):
  """Lists the connection profiles of one project location.

  Args:
    project_id: ID of the project whose connection profiles are listed.
    args: parsed command line arguments; location, filter, sort_by, limit
      and page_size are read from it.

  Returns:
    An iterator over all the matching connection profiles.
  """
  parent_name = self._resource_parser.Create(
      'datastream.projects.locations',
      projectsId=project_id,
      locationsId=args.location,
  ).RelativeName()

  # Several sort keys are sent to the API as one comma-separated string.
  order_by = ','.join(args.sort_by) if args.sort_by else None

  list_request_cls = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesListRequest)
  list_request = list_request_cls(
      parent=parent_name, filter=args.filter, orderBy=order_by)

  return list_pager.YieldFromList(
      service=self._client.projects_locations_connectionProfiles,
      request=list_request,
      field='connectionProfiles',
      limit=args.limit,
      batch_size=args.page_size,
      batch_size_attribute='pageSize',
  )
|
|
|
|
def Discover(self, parent_ref, release_track, args):
  """Sends a discover request for a connection profile.

  The profile is identified either by an existing connection profile name or
  by a connection profile object file. The hierarchy settings come from the
  recursive/full_hierarchy arguments (full hierarchy) or from the
  recursive_depth/hierarchy_depth arguments (explicit depth). At most one
  *_rdbms_file argument is parsed and attached to the request.

  Args:
    parent_ref: a Resource reference to a parent datastream.projects.locations
      resource for this connection profile.
    release_track: Some arguments are added based on the command release
      track.
    args: argparse.Namespace, The arguments that this command was invoked
      with.

  Returns:
    The result of the service's Discover call. NOTE(review): per the Datastream
    API reference this is a DiscoverConnectionProfileResponse rather than a
    long-running operation; confirm against the generated client.
  """
  request = self._messages.DiscoverConnectionProfileRequest()
  if args.connection_profile_name:
    connection_profile_ref = args.CONCEPTS.connection_profile_name.Parse()
    request.connectionProfileName = connection_profile_ref.RelativeName()
  elif args.connection_profile_object_file:
    request.connectionProfile = self._ParseConnectionProfileObjectFile(
        args.connection_profile_object_file, release_track
    )

  # A full-hierarchy request wins over an explicit depth. The recursive* and
  # *hierarchy* argument spellings are accepted alike. With none of them set,
  # fullHierarchy is explicitly False.
  if args.recursive or args.full_hierarchy:
    request.fullHierarchy = True
  elif args.recursive_depth:
    request.hierarchyDepth = int(args.recursive_depth)
  elif args.hierarchy_depth:
    request.hierarchyDepth = int(args.hierarchy_depth)
  else:
    request.fullHierarchy = False

  if args.mysql_rdbms_file:
    request.mysqlRdbms = util.ParseMysqlRdbmsFile(self._messages,
                                                  args.mysql_rdbms_file,
                                                  release_track)
  elif args.oracle_rdbms_file:
    request.oracleRdbms = util.ParseOracleRdbmsFile(self._messages,
                                                    args.oracle_rdbms_file,
                                                    release_track)
  elif args.postgresql_rdbms_file:
    request.postgresqlRdbms = util.ParsePostgresqlRdbmsFile(
        self._messages, args.postgresql_rdbms_file)
  elif args.sqlserver_rdbms_file:
    request.sqlServerRdbms = util.ParseSqlServerRdbmsFile(
        self._messages, args.sqlserver_rdbms_file
    )

  discover_req_type = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesDiscoverRequest
  )
  discover_req = discover_req_type(
      discoverConnectionProfileRequest=request, parent=parent_ref)
  return self._service.Discover(discover_req)
|
|
|
|
def GetUri(self, name):
  """Returns the self-link URL of a connection profile.

  Args:
    name: the connection profile's full relative resource name.

  Returns:
    The URL string produced by SelfLink() on the parsed resource reference.
  """
  return self._resource_parser.ParseRelativeName(
      name,
      collection='datastream.projects.locations.connectionProfiles',
  ).SelfLink()
|