#!/usr/bin/env python
|
|
"""The BigQuery CLI connection client library."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import re
|
|
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
from googleapiclient import discovery
|
|
import inflection
|
|
|
|
from utils import bq_api_utils
|
|
from utils import bq_error
|
|
from utils import bq_id_utils
|
|
from utils import bq_processor_utils
|
|
|
|
# Module-level alias for bq_api_utils.Service.
Service = bq_api_utils.Service
|
|
|
|
# Keys used for Data Transfer Service authorization info.
AUTHORIZATION_CODE = 'authorization_code'
VERSION_INFO = 'version_info'

# Pattern that a name must fully match to count as a valid proto field name.
_VALID_FIELD_NAME_REGEXP = r'[0-9A-Za-z_]+'

# Field mask paths under a connection whose children are map keys rather than
# proto fields.
_MAP_KEY_PATHS = [
    'configuration.parameters',
    'configuration.authentication.parameters',
]

# Field mask paths for the connector authentication block and its profile id.
_AUTH_PATH = 'configuration.authentication'
_AUTH_PROFILE_ID_PATH = 'configuration.authentication.profile_id'
|
|
|
|
|
|
def GetConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Fetches the connection identified by the given reference.

  Args:
    client: the client used to make the request.
    reference: reference to the connection to fetch; its path() is sent as the
      resource name.

  Returns:
    The result of executing the connections `get` request.
  """
  connections_api = client.projects().locations().connections()
  request = connections_api.get(name=reference.path())
  return request.execute()
|
|
|
|
|
|
def CreateConnection(
    client: discovery.Resource,
    project_id: str,
    location: str,
    connection_type: str,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: str,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    connection_id: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Creates a connection under the given project and location.

  Args:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location of the connection.
    connection_type: Key looked up in
      bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP to pick the request
      body field that receives `properties`.
    properties: Connection properties in JSON format. Only parsed when
      connection_type maps to a property name.
    connection_credential: Connection credentials in JSON format. Stored under
      the 'credential' key of the parsed properties.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    connection_id: Optional connection ID.
    kms_key_name: Optional KMS key name.
    connector_configuration: Optional connector configuration in JSON format.
      Only used when connection_type does not map to a property name.

  Returns:
    The result of executing the connections `create` request.

  Raises:
    ValueError: if connection_credential is given but the parsed properties are
      not a mapping, or if connection_type maps to no property name and no
      connector_configuration is given.
  """
  # Optional top-level fields are only sent when they have a truthy value.
  optional_fields = (
      ('friendlyName', display_name),
      ('description', description),
      ('kmsKeyName', kms_key_name),
  )
  connection = {key: value for key, value in optional_fields if value}

  property_name = bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP.get(
      connection_type
  )
  if property_name:
    parsed_properties = bq_processor_utils.ParseJson(properties)
    connection[property_name] = parsed_properties
    if connection_credential:
      # The credential is nested inside the properties, so they must be a
      # mapping to hold it.
      if not isinstance(parsed_properties, Mapping):
        raise ValueError('The `properties` were not a dictionary.')
      parsed_properties['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
  elif connector_configuration:
    connection['configuration'] = bq_processor_utils.ParseJson(
        connector_configuration
    )
  else:
    raise ValueError(
        'connection_type %s is unsupported or connector_configuration is not'
        ' specified' % connection_type
    )

  parent = 'projects/%s/locations/%s' % (project_id, location)
  connections_api = client.projects().locations().connections()
  request = connections_api.create(
      parent=parent, connectionId=connection_id, body=connection
  )
  return request.execute()
|
|
|
|
|
|
def UpdateConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    connection_type: Optional[
        str
    ] = None,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: Optional[str] = None,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Patches the connection identified by the given reference.

  Builds the request body and the matching update mask from whichever
  arguments are set, then issues a single `patch` request.

  Args:
    client: the client used to make the request.
    reference: Connection to update.
    connection_type: Type of connection. The types handled explicitly are
      'CLOUD_SQL', 'AWS', 'Azure', 'SQL_DATA_SOURCE', 'CLOUD_SPANNER' and
      'SPARK'.
    properties: Connection properties in JSON format.
    connection_credential: Connection credentials in JSON format. Only applied
      for 'CLOUD_SQL', 'AWS' and 'SQL_DATA_SOURCE'.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    kms_key_name: Optional KMS key name. Any non-None value adds
      'kms_key_name' to the update mask; the body field is only set when the
      value is non-empty.
    connector_configuration: Optional connector configuration in JSON format.
      Only consulted when connection_type is not one of the handled types.

  Raises:
    bq_error.BigqueryClientError: The connection type is not defined when
      updating connection_credential or properties.

  Returns:
    The result of executing the connections `patch` request.
  """
  if not connection_type and (connection_credential or properties):
    raise bq_error.BigqueryClientError(
        'connection_type is required when updating connection_credential or'
        ' properties'
    )

  body = {}
  mask_paths = []

  if display_name:
    body['friendlyName'] = display_name
    mask_paths.append('friendlyName')

  if description:
    body['description'] = description
    mask_paths.append('description')

  if kms_key_name is not None:
    # An empty string still goes into the mask, but no value is put in the
    # body.
    mask_paths.append('kms_key_name')
    if kms_key_name:
      body['kmsKeyName'] = kms_key_name

  if connection_type == 'CLOUD_SQL':
    if properties:
      cloud_sql = bq_processor_utils.ParseJson(properties)
      body['cloudSql'] = cloud_sql
      mask_paths.extend(_GetUpdateMask(connection_type.lower(), cloud_sql))
    else:
      body['cloudSql'] = {}

    if connection_credential:
      body['cloudSql']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('cloudSql.credential')

  elif connection_type == 'AWS':
    if properties:
      aws = bq_processor_utils.ParseJson(properties)
      body['aws'] = aws
      # Only role ids that are actually present are added to the mask.
      cross_account_role = aws.get('crossAccountRole')
      if cross_account_role and cross_account_role.get('iamRoleId'):
        mask_paths.append('aws.crossAccountRole.iamRoleId')
      access_role = aws.get('accessRole')
      if access_role and access_role.get('iamRoleId'):
        mask_paths.append('aws.access_role.iam_role_id')
    else:
      body['aws'] = {}

    if connection_credential:
      body['aws']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('aws.credential')

  elif connection_type == 'Azure':
    if properties:
      azure = bq_processor_utils.ParseJson(properties)
      body['azure'] = azure
      if azure.get('customerTenantId'):
        mask_paths.append('azure.customer_tenant_id')
      if azure.get('federatedApplicationClientId'):
        mask_paths.append('azure.federated_application_client_id')

  elif connection_type == 'SQL_DATA_SOURCE':
    if properties:
      sql_data_source = bq_processor_utils.ParseJson(properties)
      body['sqlDataSource'] = sql_data_source
      mask_paths.extend(
          _GetUpdateMask(connection_type.lower(), sql_data_source)
      )
    else:
      body['sqlDataSource'] = {}

    if connection_credential:
      body['sqlDataSource']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('sqlDataSource.credential')

  elif connection_type == 'CLOUD_SPANNER':
    if properties:
      cloud_spanner = bq_processor_utils.ParseJson(properties)
      body['cloudSpanner'] = cloud_spanner
      mask_paths.extend(
          _GetUpdateMask(connection_type.lower(), cloud_spanner)
      )
    else:
      body['cloudSpanner'] = {}

  elif connection_type == 'SPARK':
    if properties:
      spark = bq_processor_utils.ParseJson(properties)
      body['spark'] = spark
      if 'sparkHistoryServerConfig' in spark:
        mask_paths.append('spark.spark_history_server_config')
      if 'metastoreServiceConfig' in spark:
        mask_paths.append('spark.metastore_service_config')
    else:
      body['spark'] = {}

  elif connector_configuration:
    configuration = bq_processor_utils.ParseJson(connector_configuration)
    body['configuration'] = configuration
    mask_paths.extend(_GetUpdateMaskRecursively('configuration', configuration))
    # When the profile id path is in the mask, make sure the enclosing
    # authentication path is listed as well.
    if _AUTH_PROFILE_ID_PATH in mask_paths and _AUTH_PATH not in mask_paths:
      mask_paths.append(_AUTH_PATH)

  connections_api = client.projects().locations().connections()
  request = connections_api.patch(
      name=reference.path(),
      updateMask=','.join(mask_paths),
      body=body,
  )
  return request.execute()
|
|
|
|
|
|
def _GetUpdateMask(
    base_path: str, json_properties: Dict[str, Any]
) -> List[str]:
  """Builds update mask paths for the top-level keys of json_properties.

  Args:
    base_path: prefix for every path, e.g. 'cloud_sql'.
    json_properties: parsed properties, e.g. { 'host': ... , 'instanceId': ... }

  Returns:
    One path per key, with the key converted by inflection.underscore, e.g.
      ['cloud_sql.host', 'cloud_sql.instance_id']
  """
  snake_case_names = map(inflection.underscore, json_properties)
  return [base_path + '.' + field_name for field_name in snake_case_names]
|
|
|
|
|
|
def _EscapeIfRequired(prefix: str, name: str) -> str:
  """Returns the field mask segment to use for name under prefix.

  When prefix is one of _MAP_KEY_PATHS, name is a map key: it is returned
  unchanged if it fully matches _VALID_FIELD_NAME_REGEXP, and wrapped in
  backticks otherwise. For any other prefix, name is converted with
  inflection.underscore.

  Args:
    prefix: field mask prefix to check if name points to a map key.
    name: name of the field.

  Returns:
    The escaped or converted name.
  """
  if prefix not in _MAP_KEY_PATHS:
    # Regular proto field: convert name to snake case.
    return inflection.underscore(name)

  if re.fullmatch(_VALID_FIELD_NAME_REGEXP, name):
    return name
  return '`' + name + '`'
|
|
|
|
|
|
def _GetUpdateMaskRecursively(
    prefix: str, json_value: Dict[str, Any]
) -> List[str]:
  """Collects the update mask paths for every leaf under json_value.

  A value that is not a dict, or is an empty dict, is a leaf and contributes
  the prefix itself. A non-empty dict contributes the paths of each of its
  entries, in iteration order.

  Args:
    prefix: current prefix of the json value.
    json_value: value to traverse.

  Returns:
    a field mask containing all the set paths in the json value.
  """
  is_populated_dict = isinstance(json_value, dict) and bool(json_value)
  if not is_populated_dict:
    return [prefix]

  paths = []
  for key, child_value in json_value.items():
    child_prefix = '.'.join((prefix, _EscapeIfRequired(prefix, key)))
    paths += _GetUpdateMaskRecursively(child_prefix, child_value)
  return paths
|
|
|
|
|
|
def DeleteConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Deletes the connection identified by the given reference.

  The result of the request is discarded; nothing is returned.

  Args:
    client: the client used to make the request.
    reference: Connection to delete.
  """
  connections_api = client.projects().locations().connections()
  request = connections_api.delete(name=reference.path())
  request.execute()
|
|
|
|
|
|
def ListConnections(
    client: discovery.Resource,
    project_id: str,
    location: str,
    max_results: int,
    page_token: Optional[str],
):
  """Lists connections in the given project and location.

  Args:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location.
    max_results: Number of results to show; sent as the request's pageSize.
    page_token: Token to retrieve the next page of results.

  Returns:
    The result of executing the connections `list` request.
  """
  parent = 'projects/%s/locations/%s' % (project_id, location)
  connections_api = client.projects().locations().connections()
  request = connections_api.list(
      parent=parent, pageToken=page_token, pageSize=max_results
  )
  return request.execute()
|
|
|
|
|
|
def SetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    policy: str,
):
  """Sets the IAM policy on the given connection resource.

  Args:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.
    policy: The policy string in JSON format; sent under the 'policy' key of
      the request body.

  Returns:
    The updated IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='SetConnectionIAMPolicy',
  )
  request_body = {'policy': policy}
  connections_api = client.projects().locations().connections()
  request = connections_api.setIamPolicy(
      resource=reference.path(), body=request_body
  )
  return request.execute()
|
|
|
|
|
|
def GetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Fetches the IAM policy of the given connection resource.

  Args:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.

  Returns:
    The IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='GetConnectionIAMPolicy',
  )
  connections_api = client.projects().locations().connections()
  request = connections_api.getIamPolicy(resource=reference.path())
  return request.execute()
|