feat: Add new gcloud commands, API clients, and third-party libraries across various services.

This commit is contained in:
2026-01-01 20:26:35 +01:00
parent 5e23cbece0
commit a19e592eb7
25221 changed files with 8324611 additions and 0 deletions

View File

@@ -0,0 +1,738 @@
#!/usr/bin/env python
"""BigqueryClient class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from http import client as http_client_lib
import json
import logging
import tempfile
import time
import traceback
from typing import Callable, List, Optional, Union
import urllib
# To configure apiclient logging.
from absl import flags
import certifi
import googleapiclient
from googleapiclient import discovery
import httplib2
from typing_extensions import TypeAlias
import bq_flags
import bq_utils
import credential_loader
from auth import main_credential_loader
from clients import bigquery_http
from clients import utils as bq_client_utils
from clients import wait_printer
from discovery_documents import discovery_document_cache
from discovery_documents import discovery_document_loader
from utils import bq_api_utils
from utils import bq_error
from utils import bq_logging
# TODO(b/388312723): Review if we can remove this try/except block.
try:
from google.auth import credentials as google_credentials # pylint: disable=g-import-not-at-top
_HAS_GOOGLE_AUTH = True
except ImportError:
_HAS_GOOGLE_AUTH = False
# TODO(b/388312723): Review if we can remove this try/except block.
try:
import google_auth_httplib2 # pylint: disable=g-import-not-at-top
_HAS_GOOGLE_AUTH_HTTPLIB2 = True
except ImportError:
_HAS_GOOGLE_AUTH_HTTPLIB2 = False
# A unique non-None default, for use in kwargs that need to
# distinguish default from None.
_DEFAULT = object()
# Credentials accepted by BigqueryClient: either google-auth based types or
# the legacy loader's flag-derived credential types.
LegacyAndGoogleAuthCredentialsUnionType = Union[
    main_credential_loader.GoogleAuthCredentialsUnionType,
    credential_loader.CredentialsFromFlagsUnionType,
]
# Short alias for the TPC service enum used throughout this module.
Service = bq_api_utils.Service
# Transport type before credentials are attached.
Http: TypeAlias = Union[
    httplib2.Http,
]
# Transport type after authorization; the google-auth wrapper is referenced
# by name (string) because its import above may have failed.
AuthorizedHttp: TypeAlias = Union[
    httplib2.Http,
    'google_auth_httplib2.AuthorizedHttp',
]
class BigqueryClient:
  """Class encapsulating interaction with the BigQuery service."""

  class JobCreationMode(str, enum.Enum):
    """Enum of job creation mode.

    Inherits from `str` so that members serialize directly as their string
    values in API request payloads.
    """

    # A job entity must be created for the request.
    JOB_CREATION_REQUIRED = 'JOB_CREATION_REQUIRED'
    # Job creation is optional for the request.
    JOB_CREATION_OPTIONAL = 'JOB_CREATION_OPTIONAL'
def __init__(
    self,
    *,
    api: str,
    api_version: str,
    project_id: Optional[str] = '',
    dataset_id: Optional[str] = '',
    discovery_document: Union[bytes, object, None] = _DEFAULT,
    job_property: str = '',
    trace: Optional[str] = None,
    sync: bool = True,
    wait_printer_factory: Optional[
        Callable[[], wait_printer.WaitPrinter]
    ] = wait_printer.TransitionWaitPrinter,
    job_id_generator: bq_client_utils.JobIdGenerator = bq_client_utils.JobIdGeneratorIncrementing(
        bq_client_utils.JobIdGeneratorRandom()
    ),
    max_rows_per_request: Optional[int] = None,
    quota_project_id: Optional[str] = None,
    use_google_auth: bool = False,
    credentials: Optional[LegacyAndGoogleAuthCredentialsUnionType] = None,
    enable_resumable_uploads: bool = True,
    **kwds,
):
  """Initializes BigqueryClient.

  Required keywords:
    api: the api to connect to, for example "bigquery".
    api_version: the version of the api to connect to, for example "v2".

  Optional keywords:
    project_id: a default project id to use. While not required for
      initialization, a project_id is required when calling any
      method that creates a job on the server. Methods that have
      this requirement pass through **kwds, and will raise
      bq_error.BigqueryClientConfigurationError if no project_id can be
      found.
    dataset_id: a default dataset id to use.
    discovery_document: the discovery document to use. If None, one
      will be retrieved from the discovery api. If not specified,
      the built-in discovery document will be used.
    job_property: a list of "key=value" strings defining properties
      to apply to all job operations.
    trace: a tracing header to include in all bigquery api requests.
    sync: boolean, when inserting jobs, whether to wait for them to
      complete before returning from the insert request.
    wait_printer_factory: a function that returns a WaitPrinter.
      This will be called for each job that we wait on. See WaitJob().
    job_id_generator: generator used to produce job ids when the caller does
      not supply one. NOTE: the default is a single shared instance created
      at import time.
    max_rows_per_request: presumably a page-size cap for read requests; it is
      not used within this module — confirm against callers.
    quota_project_id: project to bill for quota; forwarded to the model so it
      is sent as the x-goog-user-project header (see BuildApiClient).
    use_google_auth: whether `credentials` are google.auth credentials rather
      than legacy (oauth2client-style) credentials.
    credentials: credentials used to authorize outgoing http requests.
    enable_resumable_uploads: presumably toggles resumable media uploads; not
      used within this module — confirm against callers.
    **kwds: arbitrary extra attributes set directly on the instance (legacy
      behavior, see TODO below).

  Raises:
    ValueError: if keywords are missing or incorrectly specified.
  """
  super().__init__()
  self.api = api
  self.api_version = api_version
  self.project_id = project_id
  self.dataset_id = dataset_id
  self.discovery_document = discovery_document
  self.job_property = job_property
  self.trace = trace
  self.sync = sync
  self.wait_printer_factory = wait_printer_factory
  self.job_id_generator = job_id_generator
  self.max_rows_per_request = max_rows_per_request
  self.quota_project_id = quota_project_id
  self.use_google_auth = use_google_auth
  self.credentials = credentials
  self.enable_resumable_uploads = enable_resumable_uploads
  # TODO(b/324243535): Delete this block to make attributes explicit.
  for key, value in kwds.items():
    setattr(self, key, value)
  # Lazily-built API client singletons, populated by the accessors below.
  self._apiclient = None
  self._routines_apiclient = None
  self._row_access_policies_apiclient = None
  self._op_transfer_client = None
  self._op_reservation_client = None
  self._op_bi_reservation_client = None
  self._models_apiclient = None
  self._op_connection_service_client = None
  self._iam_policy_apiclient = None
  # Backfill flag-like attributes that callers expect to exist unless they
  # were already injected via **kwds above.
  default_flag_values = {
      'iam_policy_discovery_document': _DEFAULT,
  }
  for flagname, default in default_flag_values.items():
    if not hasattr(self, flagname):
      setattr(self, flagname, default)
# Transfer-run fields retained when formatting transfer-run output.
columns_to_include_for_transfer_run = [
    'updateTime',
    'schedule',
    'runTime',
    'scheduleTime',
    'params',
    'endTime',
    'dataSourceId',
    'destinationDatasetId',
    'state',
    'startTime',
    'name',
]
# These columns appear to be empty when scheduling a new transfer run
# so they are listed as excluded from the transfer run output.
columns_excluded_for_make_transfer_run = ['schedule', 'endTime', 'startTime']
def GetHttp(
    self,
) -> AuthorizedHttp:
  """Returns the httplib2 Http to use.

  Honors the --proxy_*, --ca_certificates_file, --disable_ssl_validation and
  --mtls flags.

  Returns:
    An (unauthorized) httplib2.Http transport.

  Raises:
    ValueError: If --proxy_port is not parseable as an integer.
  """
  # Default is the httplib2 helper function itself (a callable), which
  # httplib2.Http accepts and invokes to read proxy settings from the
  # environment.
  proxy_info = httplib2.proxy_info_from_environment
  if flags.FLAGS.proxy_address and flags.FLAGS.proxy_port:
    try:
      port = int(flags.FLAGS.proxy_port)
    except ValueError as e:
      raise ValueError(
          'Invalid value for proxy_port: {}'.format(flags.FLAGS.proxy_port)
      ) from e
    # proxy_type=3 — presumably PROXY_TYPE_HTTP in httplib2's socks module;
    # confirm against the httplib2 version in use.
    proxy_info = httplib2.ProxyInfo(
        proxy_type=3,
        proxy_host=flags.FLAGS.proxy_address,
        proxy_port=port,
        proxy_user=flags.FLAGS.proxy_username or None,
        proxy_pass=flags.FLAGS.proxy_password or None,
    )
  http = httplib2.Http(
      proxy_info=proxy_info,
      ca_certs=flags.FLAGS.ca_certificates_file or certifi.where(),
      disable_ssl_certificate_validation=flags.FLAGS.disable_ssl_validation,
  )
  # Stop httplib2 from auto-following 308 responses — presumably so that
  # resumable-upload "308 Resume Incomplete" responses reach the caller;
  # confirm.
  if hasattr(http, 'redirect_codes'):
    http.redirect_codes = set(http.redirect_codes) - {308}
  if flags.FLAGS.mtls:
    # NOTE(review): the temp files created here are never explicitly removed;
    # they persist for the life of the process.
    _, self._cert_file = tempfile.mkstemp()
    _, self._key_file = tempfile.mkstemp()
    discovery.add_mtls_creds(
        http, discovery.get_client_options(), self._cert_file, self._key_file
    )
  return http
def GetDiscoveryUrl(
    self,
    service: Service,
    api_version: str,
    domain_root: Optional[str] = None,
    labels: Optional[str] = None,
) -> str:
  """Returns the url to the discovery document for bigquery.

  Args:
    service: The BigQuery service being targeted.
    api_version: The API version to request, e.g. 'v2'.
    domain_root: Optional root URL override; when absent, the TPC root URL is
      derived from the current flag values.
    labels: Unused; retained for interface compatibility with callers.

  Returns:
    The fully formed discovery document URL.
  """
  del labels  # Unused, kept for backward compatibility.
  # Cleanup: the original assigned `discovery_url = None` and guarded the
  # computation behind `if not discovery_url:`, which was always true.
  root_url = domain_root or bq_api_utils.get_tpc_root_url_from_flags(
      service=service, inputted_flags=bq_flags
  )
  return bq_api_utils.get_discovery_url_from_root_url(
      root_url, api_version=api_version
  )
def GetAuthorizedHttp(
    self,
    credentials: LegacyAndGoogleAuthCredentialsUnionType,
    http: Http,
) -> AuthorizedHttp:
  """Returns an http client that is authorized with the given credentials.

  Args:
    credentials: Either google.auth credentials or legacy (oauth2client
      style) credentials exposing an `authorize` method.
    http: The transport to wrap.

  Returns:
    An http object that attaches the credentials to outgoing requests.

  Raises:
    ValueError: If google.auth credentials are supplied but
      google-auth-httplib2 is not installed.
    TypeError: If the credential type is unsupported.
  """
  if self.use_google_auth:
    if not _HAS_GOOGLE_AUTH:
      logging.error(
          'System is set to use `google.auth`, but it did not load.'
      )
    if not isinstance(credentials, google_credentials.Credentials):
      logging.error(
          'The system is using `google.auth` but the parsed credentials are'
          ' of an incorrect type: %s',
          type(credentials),
      )
  else:
    logging.debug('System is set to not use `google.auth`.')
  # LINT.IfChange(http_authorization)
  if _HAS_GOOGLE_AUTH and isinstance(
      credentials, google_credentials.Credentials
  ):
    if not _HAS_GOOGLE_AUTH_HTTPLIB2:
      raise ValueError(
          'Credentials from google.auth specified, but '
          'google-api-python-client is unable to use these credentials '
          'unless google-auth-httplib2 is installed. Please install '
          'google-auth-httplib2.'
      )
    return google_auth_httplib2.AuthorizedHttp(credentials, http=http)
  # Note: This block simplified adding typing and should be removable when
  # legacy credentials are removed.
  if hasattr(credentials, 'authorize'):
    return credentials.authorize(http)
  else:
    # Bug fix: the original message was a plain string literal, so the
    # '{type(credentials)}' placeholder was never interpolated. It is now an
    # f-string that reports the actual credential type.
    raise TypeError(f'Unsupported credential type: {type(credentials)}')
  # LINT.ThenChange(
  # //depot/google3/cloud/helix/testing/e2e/python_api_client/api_client_lib.py:http_authorization,
  # //depot/google3/cloud/helix/testing/e2e/python_api_client/api_client_util.py:http_authorization,
  # )
def _LoadDiscoveryDocumentLocal(
    self,
    service: Service,
    discovery_url: Optional[str],
    api_version: str,
) -> Optional[Union[str, bytes, object]]:
  """Loads the local discovery document for the given service.

  Args:
    service: The BigQuery service being used.
    discovery_url: The URL to load the discovery doc from.
    api_version: The API version for the targeted discovery doc.

  Returns:
    The discovery document, or None when no local load was performed.
  """
  # An explicit document passed to the constructor always wins.
  if self.discovery_document != _DEFAULT:
    doc = self.discovery_document
    logging.info(
        'Skipping local "%s" discovery document load since discovery_document'
        ' has a value: %s',
        service,
        doc,
    )
    return doc
  # A caller-provided URL means the document will be fetched remotely.
  if discovery_url is not None:
    logging.info(
        'Skipping the local "%s" discovery document load since discovery_url'
        ' has a value',
        service,
    )
    return None
  # An explicit API key likewise forces a remote fetch.
  if bq_flags.BIGQUERY_DISCOVERY_API_KEY_FLAG.present:
    logging.info(
        'Skipping local "%s" discovery document load since the'
        ' bigquery_discovery_api_key flag was used',
        service,
    )
    return None
  # Otherwise fall back to the api description bundled with the client,
  # when one exists and is supported.
  doc = None
  try:
    doc = discovery_document_loader.load_local_discovery_doc_from_service(
        service=service,
        api=self.api,
        api_version=api_version,
    )
  except FileNotFoundError as e:
    logging.warning(
        'Failed to load the "%s" discovery doc from local files: %s',
        service,
        e,
    )
  if doc:
    logging.info('The "%s" discovery doc is already loaded', service)
  return doc
def _LoadDiscoveryDocumentUrl(
    self,
    service: Service,
    http: AuthorizedHttp,
    discovery_url: str,
) -> Optional[Union[str, bytes, object]]:
  """Loads the discovery document from the provided URL.

  Retries up to three times with exponential backoff (2s, 4s) on transient
  transport errors; a >=400 response or an unknown API name fails fast.

  Args:
    service: The BigQuery service being used.
    http: Http object to be used to execute request.
    discovery_url: The URL to load the discovery doc from.

  Returns:
    discovery_document The loaded discovery document.

  Raises:
    bq_error.BigqueryClientError: If the request to load the discovery
      document fails.
  """
  discovery_document = None
  # Attempt to retrieve discovery doc with retry logic for transient,
  # retry-able errors.
  max_retries = 3
  iterations = 0
  # 'use_uber_mint' is not a declared attribute; it may be injected via
  # **kwds in __init__ — hence the hasattr guard.
  headers = (
      {'X-ESF-Use-Cloud-UberMint-If-Enabled': '1'}
      if hasattr(self, 'use_uber_mint') and self.use_uber_mint
      else None
  )
  while iterations < max_retries and discovery_document is None:
    if iterations > 0:
      # Wait briefly before retrying with exponentially increasing wait.
      time.sleep(2**iterations)
    iterations += 1
    try:
      logging.info(
          'Requesting "%s" discovery document from %s',
          service,
          discovery_url,
      )
      if headers:
        response_metadata, discovery_document = http.request(
            discovery_url, headers=headers
        )
      else:
        response_metadata, discovery_document = http.request(discovery_url)
      discovery_document = discovery_document.decode('utf-8')
      if int(response_metadata.get('status')) >= 400:
        msg = 'Got %s response from discovery url: %s' % (
            response_metadata.get('status'),
            discovery_url,
        )
        logging.error('%s:\n%s', msg, discovery_document)
        # BigqueryCommunicationError is not in the except clauses below, so
        # this propagates immediately without further retries.
        raise bq_error.BigqueryCommunicationError(msg)
    except (
        httplib2.HttpLib2Error,
        googleapiclient.errors.HttpError,
        http_client_lib.HTTPException,
    ) as e:
      # We can't find the specified server. This can be thrown for
      # multiple reasons, so inspect the error.
      if hasattr(e, 'content'):
        if iterations == max_retries:
          content = ''
          # NOTE(review): this inner hasattr check is redundant — the outer
          # branch already guarantees `e.content` exists.
          if hasattr(e, 'content'):
            content = e.content
          raise bq_error.BigqueryCommunicationError(
              'Cannot contact server. Please try again.\nError: %r'
              '\nContent: %s' % (e, content)
          )
      else:
        if iterations == max_retries:
          raise bq_error.BigqueryCommunicationError(
              'Cannot contact server. Please try again.\nTraceback: %s'
              % (traceback.format_exc(),)
          )
    except IOError as e:
      if iterations == max_retries:
        raise bq_error.BigqueryCommunicationError(
            'Cannot contact server. Please try again.\nError: %r' % (e,)
        )
    except googleapiclient.errors.UnknownApiNameOrVersion as e:
      # We can't resolve the discovery url for the given server.
      # Don't retry in this case.
      raise bq_error.BigqueryCommunicationError(
          'Invalid API name or version: %s' % (str(e),)
      )
  return discovery_document
def BuildApiClient(
    self,
    service: Service,
    discovery_url: Optional[str] = None,
    discovery_root_url: Optional[str] = None,
    api_version: Optional[str] = None,
    domain_root: Optional[str] = None,
    labels: Optional[str] = None,
) -> discovery.Resource:
  """Build and return BigQuery Dynamic client from discovery document.

  Args:
    service: The TPC service the built client will talk to.
    discovery_url: Explicit URL to fetch the discovery doc from; when None,
      a URL is derived from flags unless a local document is found first.
    discovery_root_url: Root URL forwarded to OverrideEndpoint when
      rewriting the document's routing fields.
    api_version: API version of the targeted discovery doc; defaults to
      self.api_version.
    domain_root: Root URL override used when deriving the discovery URL.
    labels: Forwarded to GetDiscoveryUrl.

  Returns:
    The googleapiclient Resource built from the discovery document.
  """
  logging.info(
      'BuildApiClient discovery_url: %s, discovery_root_url: %s',
      discovery_url,
      discovery_root_url,
  )
  if api_version is None:
    api_version = self.api_version
  # If self.credentials is of type google.auth, it has to be cleared of the
  # _quota_project_id value later on in this function for discovery requests.
  # bigquery_model has to be built with the quota project retained, so in this
  # version of the implementation, it's built before discovery requests take
  # place.
  bigquery_model = bigquery_http.BigqueryModel(
      trace=self.trace,
      quota_project_id=bq_utils.GetEffectiveQuotaProjectIDForHTTPHeader(
          quota_project_id=self.quota_project_id,
          project_id=self.project_id,
          use_google_auth=self.use_google_auth,
          credentials=self.credentials,
      ),
  )
  bq_request_builder = bigquery_http.BigqueryHttp.Factory(
      bigquery_model,
  )
  # Clean up quota project ID from Google Auth credentials.
  # This is specifically needed to construct a http object used for discovery
  # requests below as quota project ID shouldn't participate in discovery
  # document retrieval, otherwise the discovery request would result in a
  # permission error seen in b/321286043.
  if self.use_google_auth and hasattr(self.credentials, '_quota_project_id'):
    self.credentials._quota_project_id = None  # pylint: disable=protected-access
  # Cleanup: the original assigned `http = None` and rebuilt it behind an
  # always-true `if not http:` guard.
  http = self.GetAuthorizedHttp(self.credentials, self.GetHttp())
  # First, try to load the discovery document from the local package.
  discovery_document = self._LoadDiscoveryDocumentLocal(
      service=service,
      discovery_url=discovery_url,
      api_version=api_version,
  )
  # If the document was not loaded from the local package and discovery_url
  # was not provided, generate the url to fetch from the server.
  if discovery_document is None and discovery_url is None:
    discovery_url = self.GetDiscoveryUrl(
        service=service,
        api_version=api_version,
        domain_root=domain_root,
        labels=labels,
    )
  # If discovery_document is still not loaded, fetch it from the server.
  if not discovery_document:
    discovery_document = self._LoadDiscoveryDocumentUrl(
        service=service,
        http=http,
        discovery_url=discovery_url,
    )
  discovery_document_to_build_client = self.OverrideEndpoint(
      discovery_document=discovery_document,
      service=service,
      discovery_root_url=discovery_root_url,
  )
  bq_logging.SaveStringToLogDirectoryIfAvailable(
      file_prefix='discovery_document',
      content=discovery_document_to_build_client,
      apilog=bq_flags.APILOG.value,
  )
  try:
    # If the underlying credentials object used for authentication is of type
    # google.auth, its quota project ID will have been removed earlier in this
    # function if one was provided explicitly. This specific http object
    # created from that modified credentials object must be the one used for
    # the discovery requests, otherwise they would result in a permission
    # error as seen in b/321286043.
    built_client = discovery.build_from_document(
        discovery_document_to_build_client,
        http=http,
        model=bigquery_model,
        requestBuilder=bq_request_builder,
    )
  except Exception:
    logging.error(
        'Error building from the "%s" discovery document: %s',
        service,
        discovery_document,
    )
    raise
  return built_client
@property
def apiclient(self) -> discovery.Resource:
  """Returns a singleton ApiClient built for the BigQuery core API."""
  if not self._apiclient:
    # First access: build and cache the core BigQuery client.
    self._apiclient = self.BuildApiClient(service=Service.BIGQUERY)
  else:
    logging.info('Using the cached BigQuery API client')
  return self._apiclient
def GetModelsApiClient(self) -> discovery.Resource:
  """Returns a lazily-built singleton client for BigQuery model resources."""
  cached = self._models_apiclient
  if cached is not None:
    return cached
  client = self.BuildApiClient(service=Service.BIGQUERY)
  self._models_apiclient = client
  return client
def GetRoutinesApiClient(self) -> discovery.Resource:
  """Returns a lazily-built singleton client for BigQuery routine resources."""
  cached = self._routines_apiclient
  if cached is not None:
    return cached
  client = self.BuildApiClient(service=Service.BIGQUERY)
  self._routines_apiclient = client
  return client
def GetRowAccessPoliciesApiClient(self) -> discovery.Resource:
  """Returns a lazily-built singleton client for row access policy calls."""
  cached = self._row_access_policies_apiclient
  if cached is not None:
    return cached
  client = self.BuildApiClient(service=Service.BIGQUERY)
  self._row_access_policies_apiclient = client
  return client
def GetIAMPolicyApiClient(self) -> discovery.Resource:
  """Returns a lazily-built singleton client for the BigQuery IAM service."""
  cached = self._iam_policy_apiclient
  if cached is not None:
    return cached
  client = self.BuildApiClient(service=Service.BQ_IAM)
  self._iam_policy_apiclient = client
  return client
def GetInsertApiClient(self) -> discovery.Resource:
  """Return the apiclient that supports insert operation.

  Cleanup: the original computed an always-None `discovery_url` and guarded
  a BuildApiClient call behind `if discovery_url:`; that branch was
  unreachable, so the method reduces to returning the core client.

  Returns:
    The singleton core BigQuery API client.
  """
  return self.apiclient
def GetTransferV1ApiClient(
    self, transferserver_address: Optional[str] = None
) -> discovery.Resource:
  """Return the apiclient that supports Transfer v1 operation."""
  logging.info(
      'GetTransferV1ApiClient transferserver_address: %s',
      transferserver_address,
  )
  if not self._op_transfer_client:
    # Derive the DTS root URL from the explicit address or the flag values,
    # then build and cache the client.
    root_url = transferserver_address or bq_api_utils.get_tpc_root_url_from_flags(
        service=Service.DTS, inputted_flags=bq_flags
    )
    self._op_transfer_client = self.BuildApiClient(
        domain_root=root_url,
        discovery_root_url=root_url,
        api_version='v1',
        service=Service.DTS,
    )
  else:
    logging.info('Using the cached Transfer API client')
  return self._op_transfer_client
def GetReservationApiClient(
    self, reservationserver_address: Optional[str] = None
) -> discovery.Resource:
  """Return the apiclient that supports reservation operations."""
  cached = self._op_reservation_client
  if cached:
    logging.info('Using the cached Reservations API client')
    return cached
  # Derive the Reservations root URL from the explicit address or the flags,
  # then build and cache the client.
  root_url = (
      reservationserver_address
      or bq_api_utils.get_tpc_root_url_from_flags(
          service=Service.RESERVATIONS,
          inputted_flags=bq_flags,
      )
  )
  self._op_reservation_client = self.BuildApiClient(
      service=Service.RESERVATIONS,
      domain_root=root_url,
      discovery_root_url=root_url,
      api_version='v1',
      labels=None,
  )
  return self._op_reservation_client
def GetConnectionV1ApiClient(
    self, connection_service_address: Optional[str] = None
) -> discovery.Resource:
  """Return the apiclient that supports connections operations.

  Args:
    connection_service_address: Optional root URL override for the
      Connections service; when absent it is derived from flags.

  Returns:
    A cached (singleton) Connections v1 API client.
  """
  if self._op_connection_service_client:
    logging.info('Using the cached Connections API client')
  else:
    path = (
        connection_service_address
        or bq_api_utils.get_tpc_root_url_from_flags(
            service=Service.CONNECTIONS,
            inputted_flags=bq_flags,
        )
    )
    # Unlike the other singleton getters, the discovery URL is pre-computed
    # here so that an API key can be appended before the fetch.
    discovery_url = bq_api_utils.get_discovery_url_from_root_url(
        path, api_version='v1'
    )
    discovery_url = bq_api_utils.add_api_key_to_discovery_url(
        discovery_url=discovery_url,
        universe_domain=bq_flags.UNIVERSE_DOMAIN.value,
        inputted_flags=bq_flags,
    )
    self._op_connection_service_client = self.BuildApiClient(
        discovery_url=discovery_url,
        discovery_root_url=path,
        service=Service.CONNECTIONS,
        api_version='v1',
    )
  return self._op_connection_service_client
def OverrideEndpoint(
    self,
    discovery_document: Union[str, bytes],
    service: Service,
    discovery_root_url: Optional[str] = None,
) -> Optional[str]:
  """Override rootUrl for regional endpoints.

  Args:
    discovery_document: BigQuery discovery document.
    service: The BigQuery service being used.
    discovery_root_url: The root URL to use for the discovery document.

  Returns:
    discovery_document updated discovery document, serialized as a JSON
    string, or None when None was passed in.

  Raises:
    bq_error.BigqueryClientError: if location is not set and
      use_regional_endpoints is.
  """
  if discovery_document is None:
    return discovery_document
  discovery_document = bq_api_utils.parse_discovery_doc(discovery_document)
  logging.info(
      'Discovery doc routing values being considered for updates: rootUrl:'
      ' (%s), basePath: (%s), baseUrl: (%s)',
      discovery_document['rootUrl'],
      discovery_document['basePath'],
      discovery_document['baseUrl'],
  )
  # NOTE(review): `is_prod` is hard-coded True, so the rewrite below is
  # unconditional — presumably a build-time toggle; confirm.
  is_prod = True
  # NOTE(review): `original_root_url` is captured but never used here.
  original_root_url = discovery_document['rootUrl']
  if is_prod:
    # Re-point the document at the TPC root URL derived from flags and
    # recompute baseUrl accordingly.
    discovery_document['rootUrl'] = bq_api_utils.get_tpc_root_url_from_flags(
        service=service, inputted_flags=bq_flags
    )
    discovery_document['baseUrl'] = urllib.parse.urljoin(
        discovery_document['rootUrl'], discovery_document['servicePath']
    )
  logging.info(
      'Discovery doc routing values post updates: rootUrl: (%s), basePath:'
      ' (%s), baseUrl: (%s)',
      discovery_document['rootUrl'],
      discovery_document['basePath'],
      discovery_document['baseUrl'],
  )
  return json.dumps(discovery_document)

View File

@@ -0,0 +1,7 @@
#!/usr/bin/env python
from clients import bigquery_client
class BigqueryClientExtended(bigquery_client.BigqueryClient):
  """Class extending BigqueryClient to add resource specific functionality.

  No members are defined in this visible portion; the class currently serves
  as the extension point for resource-specific helpers layered on top of the
  transport-level BigqueryClient.
  """

View File

@@ -0,0 +1,268 @@
#!/usr/bin/env python
# pylint: disable=g-unknown-interpreter
# Copyright 2012 Google Inc. All Rights Reserved.
"""Bigquery Client library for Python."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
# To configure apiclient logging.
from absl import flags
import googleapiclient
from googleapiclient import http as http_request
from googleapiclient import model
import httplib2
import bq_flags
import bq_utils
from clients import utils as bq_client_utils
# Default retry budget used by BigqueryHttp.execute for server-side errors.
_NUM_RETRIES_FOR_SERVER_SIDE_ERRORS = 3
# Keep a reference to the library's private retry helper so that the patched
# version installed below can delegate to the original behavior.
# pylint: disable=protected-access
_ORIGINAL_GOOGLEAPI_CLIENT_RETRY_REQUEST = http_request._retry_request
# Note: All the `Optional` added here is to support tests.
def _RetryRequest(
    http: Optional[httplib2.Http],
    num_retries: int,
    req_type: Optional[str],
    sleep: Optional[Callable[[float], None]],
    rand: Optional[Callable[[int], float]],
    uri: Optional[str],
    method: Optional[str],
    *args,
    **kwargs,
):
  """Conditionally retries an HTTP request.

  If the original request fails with a specific permission error, retry it once
  without the x-goog-user-project header.

  Args:
    http: Http object to be used to execute request.
    num_retries: Maximum number of retries.
    req_type: Type of the request (used for logging retries).
    sleep: Function to sleep for random time between retries.
    rand: Function to sleep for random time between retries.
    uri: URI to be requested.
    method: HTTP method to be used.
    *args: Additional arguments passed to http.request.
    **kwargs: Additional arguments passed to http.request.

  Returns:
    resp, content - Response from the http request (may be HTTP 5xx).
  """
  # Call the original http_request._retry_request first to get the original
  # response.
  resp, content = _ORIGINAL_GOOGLEAPI_CLIENT_RETRY_REQUEST(
      http, num_retries, req_type, sleep, rand, uri, method, *args, **kwargs
  )
  if int(resp.status) != 403:
    return resp, content
  try:
    data = json.loads(content.decode('utf-8'))
  except (ValueError, UnicodeDecodeError):
    # Robustness fix: a 403 with a non-JSON (or non-UTF-8) body previously
    # raised out of this helper; surface the response unchanged instead.
    return resp, content
  # Robustness fix: the original indexed data['error'] without checking that
  # the 'error' key exists, raising KeyError for 403 payloads lacking it.
  error = data.get('error') if isinstance(data, dict) else None
  if isinstance(error, dict) and 'message' in error:
    err_message = error['message']
    if 'roles/serviceusage.serviceUsageConsumer' in err_message:
      headers = kwargs.get('headers')
      if headers and 'x-goog-user-project' in headers:
        del headers['x-goog-user-project']
        logging.info(
            'Retrying request without the x-goog-user-project header'
        )
        resp, content = _ORIGINAL_GOOGLEAPI_CLIENT_RETRY_REQUEST(
            http,
            num_retries,
            req_type,
            sleep,
            rand,
            uri,
            method,
            *args,
            **kwargs,
        )
  return resp, content


# Install the patched helper so all requests issued through googleapiclient
# go through the conditional retry above.
http_request._retry_request = _RetryRequest
# pylint: enable=protected-access
class BigqueryModel(model.JsonModel):
  """Adds optional global parameters to all requests."""

  def __init__(
      self,
      trace: Optional[str] = None,
      quota_project_id: Optional[str] = None,
      **kwds,
  ):
    super().__init__(**kwds)
    # Tracing token; when set, sent as a 'cookie' header (see request()).
    self.trace = trace
    # Project billed for quota; sent as the x-goog-user-project header.
    self.quota_project_id = quota_project_id

  # pylint: disable=g-bad-name
  def request(
      self,
      headers: Dict[str, str],
      path_params: Dict[str, str],
      query_params: Dict[str, Any],  # TODO(b/338466958): This seems incorrect.
      body_value: object,
  ) -> Tuple[Dict[str, str], Dict[str, str], str, str]:
    """Updates outgoing request.

    Headers updated here will be applied to only requests of API methods having
    JSON-type responses. For API methods with non-JSON-type responses, headers
    need to be set in BigqueryHttp.Factory._Construct.

    Args:
      headers: dict, request headers
      path_params: dict, parameters that appear in the request path
      query_params: dict, parameters that appear in the query
      body_value: object, the request body as a Python object, which must be
        serializable.

    Returns:
      A tuple of (headers, path_params, query, body)
        headers: dict, request headers
        path_params: dict, parameters that appear in the request path
        query: string, query part of the request URI
        body: string, the body serialized in the desired wire format.
    """
    # The trace token is transported as a cookie unless the caller already
    # requested tracing via a query parameter.
    if 'trace' not in query_params and self.trace:
      headers['cookie'] = self.trace
    # Prepend the BQ CLI user agent to whatever agent is already set.
    if 'user-agent' not in headers:
      headers['user-agent'] = ''
    user_agent = ' '.join([bq_utils.GetUserAgent(), headers['user-agent']])
    headers['user-agent'] = user_agent.strip()
    if self.quota_project_id:
      headers['x-goog-user-project'] = self.quota_project_id
    if bq_flags.REQUEST_REASON.value:
      headers['x-goog-request-reason'] = bq_flags.REQUEST_REASON.value
    return super().request(headers, path_params, query_params, body_value)

  # pylint: enable=g-bad-name

  # pylint: disable=g-bad-name
  def response(self, resp: httplib2.Response, content: str) -> object:
    """Convert the response wire format into a Python object.

    Args:
      resp: httplib2.Response, the HTTP response headers and status
      content: string, the body of the HTTP response

    Returns:
      The body de-serialized as a Python object.

    Raises:
      googleapiclient.errors.HttpError if a non 2xx response is received.
    """
    # Log only the status; body deserialization is delegated to the parent.
    logging.info('Response from server with status code: %s', resp['status'])
    return super().response(resp, content)

  # pylint: enable=g-bad-name
class BigqueryHttp(http_request.HttpRequest):
  """Converts errors into Bigquery errors."""

  def __init__(
      self,
      bigquery_model: BigqueryModel,
      *args,
      **kwds,
  ):
    super().__init__(*args, **kwds)
    # NOTE(review): assumes the uri is always the third positional arg and
    # `method` is always passed as a keyword — matches how discovery-built
    # clients invoke the request builder; confirm before changing call sites.
    logging.info(
        'URL being requested from BQ client: %s %s', kwds['method'], args[2]
    )
    # Kept so execute() can reuse the model's response logging on errors.
    self._model = bigquery_model

  @staticmethod
  def Factory(
      bigquery_model: BigqueryModel,
  ) -> Callable[..., 'BigqueryHttp']:
    """Returns a function that creates a BigqueryHttp with the given model."""

    def _Construct(*args, **kwds):
      # Headers set here will be applied to all requests made through this
      # BigqueryHttp object. Headers set in BigqueryModel.request will be
      # applied to only methods that expect a JSON-type response.
      if 'headers' not in kwds:
        kwds['headers'] = {}
      # Set user-agent if not already set in BigqueryModel.request, e.g. for
      # DELETE requests. The case-insensitive containment check avoids
      # appending the BQ agent twice.
      user_agent = kwds['headers'].get('user-agent', '')
      bq_user_agent = bq_utils.GetUserAgent()
      if str.lower(bq_user_agent) not in str.lower(user_agent):
        user_agent = ' '.join([bq_user_agent, user_agent])
      kwds['headers']['user-agent'] = user_agent.strip()
      if (
          'x-goog-user-project' not in kwds['headers']
          and bigquery_model.quota_project_id
      ):
        logging.info(
            'Setting x-goog-user-project header to: %s',
            bigquery_model.quota_project_id,
        )
        kwds['headers']['x-goog-user-project'] = bigquery_model.quota_project_id
      if (
          'x-goog-request-reason' not in kwds['headers']
          and bq_flags.REQUEST_REASON.value
      ):
        logging.info(
            'Setting x-goog-request-reason header to: %s',
            bq_flags.REQUEST_REASON.value,
        )
        kwds['headers']['x-goog-request-reason'] = bq_flags.REQUEST_REASON.value
      captured_model = bigquery_model
      return BigqueryHttp(
          captured_model,
          *args,
          **kwds,
      )

    return _Construct

  # This function is mostly usually called without any parameters from a client
  # like the `client_dataset` code calling:
  # `apiclient.datasets().insert(body=body, **args).execute()`
  # pylint: disable=g-bad-name
  def execute(
      self,
      http: Optional[httplib2.Http] = None,
      num_retries: Optional[int] = None,
  ):
    # pylint: enable=g-bad-name
    """Executes the request, translating failures into BigQuery errors.

    Args:
      http: Optional transport override.
      num_retries: Retry budget; defaults to
        _NUM_RETRIES_FOR_SERVER_SIDE_ERRORS when None.

    Returns:
      The deserialized response body.
    """
    try:
      if num_retries is None:
        num_retries = _NUM_RETRIES_FOR_SERVER_SIDE_ERRORS
      return super().execute(
          http=http,
          num_retries=num_retries,
      )
    except googleapiclient.errors.HttpError as e:
      # TODO(user): Remove this when apiclient supports logging
      # of error responses.
      self._model._log_response(e.resp, e.content)  # pylint: disable=protected-access
      bq_client_utils.RaiseErrorFromHttpError(e)
    except (httplib2.HttpLib2Error, IOError) as e:
      bq_client_utils.RaiseErrorFromNonHttpError(e)

View File

@@ -0,0 +1,472 @@
#!/usr/bin/env python
"""The BigQuery CLI connection client library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from typing import Any, Dict, List, Mapping, Optional
from googleapiclient import discovery
import inflection
from utils import bq_api_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
# Short alias for the TPC service enum.
Service = bq_api_utils.Service
# Data Transfer Service Authorization Info
AUTHORIZATION_CODE = 'authorization_code'
VERSION_INFO = 'version_info'
# Valid proto field name regex.
_VALID_FIELD_NAME_REGEXP = r'[0-9A-Za-z_]+'
# Connection field mask paths pointing to map keys.
_MAP_KEY_PATHS = [
    'configuration.parameters',
    'configuration.authentication.parameters',
]
# Field-mask paths for connector authentication fields.
_AUTH_PROFILE_ID_PATH = 'configuration.authentication.profile_id'
_AUTH_PATH = 'configuration.authentication'
def GetConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Gets connection with the given connection reference.

  Arguments:
    client: the client used to make the request.
    reference: Connection to get.

  Returns:
    Connection object with the given id.
  """
  connections = client.projects().locations().connections()
  request = connections.get(name=reference.path())
  return request.execute()
def CreateConnection(
    client: discovery.Resource,
    project_id: str,
    location: str,
    connection_type: str,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: str,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    connection_id: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Create a connection with the given connection reference.

  Arguments:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location of connection.
    connection_type: Type of connection, allowed values: ['CLOUD_SQL']
    properties: Connection properties in JSON format.
    connection_credential: Connection credentials in JSON format.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    connection_id: Optional connection ID.
    kms_key_name: Optional KMS key name.
    connector_configuration: Optional configuration for connector.

  Returns:
    Connection object that was created.
  """
  connection = {}
  # Optional top-level metadata; only non-empty values are included.
  for field, value in (
      ('friendlyName', display_name),
      ('description', description),
      ('kmsKeyName', kms_key_name),
  ):
    if value:
      connection[field] = value
  property_name = bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP.get(
      connection_type
  )
  if property_name:
    # Known connection type: parse type-specific properties, and optionally
    # attach credentials under them.
    connection[property_name] = bq_processor_utils.ParseJson(properties)
    if connection_credential:
      if not isinstance(connection[property_name], Mapping):
        raise ValueError('The `properties` were not a dictionary.')
      connection[property_name]['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
  elif connector_configuration:
    connection['configuration'] = bq_processor_utils.ParseJson(
        connector_configuration
    )
  else:
    raise ValueError(
        'connection_type %s is unsupported or connector_configuration is not'
        ' specified' % connection_type
    )
  parent = 'projects/%s/locations/%s' % (project_id, location)
  request = (
      client.projects()
      .locations()
      .connections()
      .create(parent=parent, connectionId=connection_id, body=connection)
  )
  return request.execute()
def UpdateConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    connection_type: Optional[
        str
    ] = None,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: Optional[str] = None,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Updates a connection via PATCH, sending only the provided fields.

  Builds a partial connection resource together with a matching update mask,
  so fields that are not specified here are left unchanged on the server.

  Arguments:
    client: the client used to make the request.
    reference: Connection to update.
    connection_type: Type of connection; a CONNECTION_TYPE_TO_PROPERTY_MAP key
      such as 'CLOUD_SQL', 'AWS', 'Azure', 'SQL_DATA_SOURCE', 'CLOUD_SPANNER'
      or 'SPARK'.
    properties: Connection properties in JSON format.
    connection_credential: Connection credentials in JSON format.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    kms_key_name: Optional KMS key name; an empty string clears the key.
    connector_configuration: Optional configuration for connector.

  Raises:
    bq_error.BigqueryClientError: The connection type is not defined
      when updating connection_credential or properties.

  Returns:
    The updated connection resource as returned by the API.
  """
  if (connection_credential or properties) and not connection_type:
    raise bq_error.BigqueryClientError(
        'connection_type is required when updating connection_credential or'
        ' properties'
    )
  connection = {}
  update_mask = []
  if display_name:
    connection['friendlyName'] = display_name
    update_mask.append('friendlyName')
  if description:
    connection['description'] = description
    update_mask.append('description')
  # The mask entry is added whenever kms_key_name was given (even ''), so an
  # empty string explicitly clears the key on the server while only a
  # non-empty value is written to the resource body.
  if kms_key_name is not None:
    update_mask.append('kms_key_name')
  if kms_key_name:
    connection['kmsKeyName'] = kms_key_name
  if connection_type == 'CLOUD_SQL':
    if properties:
      cloudsql_properties = bq_processor_utils.ParseJson(properties)
      connection['cloudSql'] = cloudsql_properties
      update_mask.extend(
          _GetUpdateMask(connection_type.lower(), cloudsql_properties)
      )
    else:
      connection['cloudSql'] = {}
    if connection_credential:
      connection['cloudSql']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      update_mask.append('cloudSql.credential')
  elif connection_type == 'AWS':
    if properties:
      aws_properties = bq_processor_utils.ParseJson(properties)
      connection['aws'] = aws_properties
      # NOTE(review): mask paths below mix camelCase and snake_case
      # ('aws.crossAccountRole.iamRoleId' vs 'aws.access_role.iam_role_id');
      # presumably intentional for backend compatibility — confirm before
      # normalizing.
      if aws_properties.get('crossAccountRole') and aws_properties[
          'crossAccountRole'
      ].get('iamRoleId'):
        update_mask.append('aws.crossAccountRole.iamRoleId')
      if aws_properties.get('accessRole') and aws_properties['accessRole'].get(
          'iamRoleId'
      ):
        update_mask.append('aws.access_role.iam_role_id')
    else:
      connection['aws'] = {}
    if connection_credential:
      connection['aws']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      update_mask.append('aws.credential')
  elif connection_type == 'Azure':
    if properties:
      azure_properties = bq_processor_utils.ParseJson(properties)
      connection['azure'] = azure_properties
      if azure_properties.get('customerTenantId'):
        update_mask.append('azure.customer_tenant_id')
      if azure_properties.get('federatedApplicationClientId'):
        update_mask.append('azure.federated_application_client_id')
  elif connection_type == 'SQL_DATA_SOURCE':
    if properties:
      sql_data_source_properties = bq_processor_utils.ParseJson(properties)
      connection['sqlDataSource'] = sql_data_source_properties
      update_mask.extend(
          _GetUpdateMask(connection_type.lower(), sql_data_source_properties)
      )
    else:
      connection['sqlDataSource'] = {}
    if connection_credential:
      connection['sqlDataSource']['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      update_mask.append('sqlDataSource.credential')
  elif connection_type == 'CLOUD_SPANNER':
    if properties:
      cloudspanner_properties = bq_processor_utils.ParseJson(properties)
      connection['cloudSpanner'] = cloudspanner_properties
      update_mask.extend(
          _GetUpdateMask(connection_type.lower(), cloudspanner_properties)
      )
    else:
      connection['cloudSpanner'] = {}
  elif connection_type == 'SPARK':
    if properties:
      spark_properties = bq_processor_utils.ParseJson(properties)
      connection['spark'] = spark_properties
      if 'sparkHistoryServerConfig' in spark_properties:
        update_mask.append('spark.spark_history_server_config')
      if 'metastoreServiceConfig' in spark_properties:
        update_mask.append('spark.metastore_service_config')
    else:
      connection['spark'] = {}
  elif connector_configuration:
    connection['configuration'] = bq_processor_utils.ParseJson(
        connector_configuration
    )
    update_mask.extend(
        _GetUpdateMaskRecursively('configuration', connection['configuration'])
    )
  # NOTE(review): presumably ensures the parent auth message path accompanies
  # the auth profile id path in the mask — confirm against the _AUTH_PATH /
  # _AUTH_PROFILE_ID_PATH definitions.
  if _AUTH_PROFILE_ID_PATH in update_mask and _AUTH_PATH not in update_mask:
    update_mask.append(_AUTH_PATH)
  return (
      client.projects()
      .locations()
      .connections()
      .patch(
          name=reference.path(),
          updateMask=','.join(update_mask),
          body=connection,
      )
      .execute()
  )
def _GetUpdateMask(
    base_path: str, json_properties: Dict[str, Any]
) -> List[str]:
  """Builds a field-mask path list from the keys of json_properties.

  Arguments:
    base_path: prefix for every path, e.g. 'cloud_sql'.
    json_properties: parsed properties, e.g. {'host': ..., 'instanceId': ...}.

  Returns:
    Snake-case mask paths, e.g. ['cloud_sql.host', 'cloud_sql.instance_id'].
  """
  mask = []
  for key in json_properties:
    mask.append(base_path + '.' + inflection.underscore(key))
  return mask
def _EscapeIfRequired(prefix: str, name: str) -> str:
  """Escapes a map key if needed, otherwise converts the name to snake case.

  When the prefix identifies a map field, the key is preserved as-is and only
  wrapped in backticks when it is not a valid proto field name.

  Args:
    prefix: field mask prefix, used to detect whether name is a map key.
    name: name of the field.

  Returns:
    The escaped (or snake-cased) name.
  """
  if prefix not in _MAP_KEY_PATHS:
    # Regular proto field: convert to snake case.
    return inflection.underscore(name)
  if re.fullmatch(_VALID_FIELD_NAME_REGEXP, name):
    return name
  return '`' + name + '`'
def _GetUpdateMaskRecursively(
    prefix: str, json_value: Dict[str, Any]
) -> List[str]:
  """Walks json_value depth-first and collects every set field-mask path.

  Args:
    prefix: current prefix of the json value.
    json_value: value to traverse.

  Returns:
    A field mask containing all the set paths in the json value.
  """
  # A non-dict value (or an empty dict) terminates the recursion: the current
  # prefix itself is the path.
  is_leaf = not isinstance(json_value, dict) or not json_value
  if is_leaf:
    return [prefix]
  paths = []
  for key, value in json_value.items():
    child_prefix = prefix + '.' + _EscapeIfRequired(prefix, key)
    paths.extend(_GetUpdateMaskRecursively(child_prefix, value))
  return paths
def DeleteConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Deletes the connection identified by the given reference.

  Arguments:
    client: the client used to make the request.
    reference: Connection to delete.
  """
  request = (
      client.projects()
      .locations()
      .connections()
      .delete(name=reference.path())
  )
  request.execute()
def ListConnections(
    client: discovery.Resource,
    project_id: str,
    location: str,
    max_results: int,
    page_token: Optional[str],
):
  """Lists one page of connections in the given project and location.

  Arguments:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location.
    max_results: Number of results to show.
    page_token: Token to retrieve the next page of results.

  Returns:
    The API list response containing connection objects.
  """
  parent = 'projects/%s/locations/%s' % (project_id, location)
  connections_service = client.projects().locations().connections()
  request = connections_service.list(
      parent=parent, pageToken=page_token, pageSize=max_results
  )
  return request.execute()
def SetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    policy: str,
):
  """Replaces the IAM policy on the given connection resource.

  Arguments:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.
    policy: The policy string in JSON format.

  Returns:
    The updated IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='SetConnectionIAMPolicy',
  )
  connections_service = client.projects().locations().connections()
  request = connections_service.setIamPolicy(
      resource=reference.path(), body={'policy': policy}
  )
  return request.execute()
def GetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Retrieves the IAM policy of the given connection resource.

  Arguments:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.

  Returns:
    The IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='GetConnectionIAMPolicy',
  )
  connections_service = client.projects().locations().connections()
  return connections_service.getIamPolicy(resource=reference.path()).execute()

View File

@@ -0,0 +1,739 @@
#!/usr/bin/env python
"""The BigQuery CLI data transfer client library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import logging
from typing import Any, Dict, NamedTuple, Optional
from googleapiclient import discovery
from clients import utils as bq_client_utils
from utils import bq_api_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
Service = bq_api_utils.Service
# Data Transfer Service Authorization Info
# Keys expected in the auth_info dict passed to create_transfer_config /
# update_transfer_config, mapped through to the API's authorizationCode and
# versionInfo request parameters.
AUTHORIZATION_CODE = 'authorization_code'
VERSION_INFO = 'version_info'
class TransferScheduleArgs:
  """Arguments that customize a data transfer's schedule.

  At most one scheduling mode may be selected: automatic scheduling disabled,
  a time-based schedule (cron-like schedule plus optional start/end times), or
  an event-driven schedule (a Pub/Sub subscription given as JSON).
  """

  def __init__(
      self,
      schedule: Optional[str] = None,
      start_time: Optional[str] = None,
      end_time: Optional[str] = None,
      disable_auto_scheduling: Optional[bool] = False,
      event_driven_schedule: Optional[str] = None,
  ):
    self.schedule = schedule
    self.start_time = start_time
    self.end_time = end_time
    self.disable_auto_scheduling = disable_auto_scheduling
    self.event_driven_schedule = event_driven_schedule

  def to_schedule_options_v2_payload(
      self, options_to_copy: Optional[Dict[str, Any]] = None
  ) -> Dict[str, Any]:
    """Builds the scheduleOptionsV2 dictionary for the transfer API.

    Args:
      options_to_copy: Existing options to copy from.

    Returns:
      A dictionary of schedule options v2 expected by the
      bigquery.transfers.create and bigquery.transfers.update API methods.

    Raises:
      bq_error.BigqueryError: If schedule options conflict.
    """
    self._validate_schedule_options()
    if self.event_driven_schedule:
      return {
          'eventDrivenSchedule': self._process_event_driven_schedule(
              self.event_driven_schedule
          )
      }
    if self.disable_auto_scheduling:
      return {'manualSchedule': {}}
    # Time-based schedule: start from any existing options, then overlay only
    # the fields that were explicitly provided.
    time_based = {}
    if options_to_copy and 'timeBasedSchedule' in options_to_copy:
      time_based = dict(options_to_copy['timeBasedSchedule'])
    if self.schedule:
      time_based['schedule'] = self.schedule
    if self.start_time:
      time_based['startTime'] = self._time_or_infitity(self.start_time)
    if self.end_time:
      time_based['endTime'] = self._time_or_infitity(self.end_time)
    return {'timeBasedSchedule': time_based}

  def to_schedule_options_payload(
      self, options_to_copy: Optional[Dict[str, str]] = None
  ) -> Dict[str, Any]:
    """Builds the (v1) schedule options dictionary for the transfer API.

    Args:
      options_to_copy: Existing options to be copied.

    Returns:
      A dictionary of schedule options expected by the
      bigquery.transfers.create and bigquery.transfers.update API methods.
    """
    # Start from a copy of the current options (or an empty dict).
    options = {} if not options_to_copy else dict(options_to_copy)
    if self.start_time is not None:
      options['startTime'] = self._time_or_infitity(self.start_time)
    if self.end_time is not None:
      options['endTime'] = self._time_or_infitity(self.end_time)
    options['disableAutoScheduling'] = self.disable_auto_scheduling
    return options

  def _time_or_infitity(self, time_str: str):
    """Maps the empty string to None, which the API reads as "infinity"."""
    if not time_str:
      return None
    return time_str

  def _validate_schedule_options(self):
    """Raises bq_error.BigqueryError when scheduling modes conflict."""
    is_time_based_schedule = any(
        [self.schedule, self.start_time, self.end_time]
    )
    is_event_driven_schedule = self.event_driven_schedule is not None
    selected_modes = sum([
        self.disable_auto_scheduling,
        is_time_based_schedule,
        is_event_driven_schedule,
    ])
    if selected_modes > 1:
      raise bq_error.BigqueryError(
          'The provided scheduling options conflict. Please specify one of'
          ' no_auto_scheduling, time-based schedule or event-driven schedule.'
      )

  def _process_event_driven_schedule(
      self,
      event_driven_schedule: str,
  ) -> Dict[str, str]:
    """Parses and validates the JSON event-driven schedule string.

    Args:
      event_driven_schedule: The user specified event driven schedule. This
        should be in JSON format given as a string. Ex:
        --event_driven_schedule='{"pubsub_subscription":"subscription"}'.

    Returns:
      The parsed event driven schedule.

    Raises:
      bq_error.BigqueryError: If there is an error with the given params.
    """
    try:
      parsed = json.loads(event_driven_schedule)
    except Exception as e:
      raise bq_error.BigqueryError(
          'Event driven schedule should be specified in JSON format.'
      ) from e
    if 'pubsub_subscription' not in parsed:
      raise bq_error.BigqueryError(
          'Must specify pubsub_subscription in --event_driven_schedule.'
      )
    return parsed
def get_transfer_config(transfer_client: discovery.Resource, transfer_id: str):
  """Fetches a single transfer config resource by its full resource name."""
  configs_service = transfer_client.projects().locations().transferConfigs()
  return configs_service.get(name=transfer_id).execute()
def get_transfer_run(transfer_client: discovery.Resource, identifier: str):
  """Fetches a single transfer run resource by its full resource name."""
  runs_service = (
      transfer_client.projects().locations().transferConfigs().runs()
  )
  return runs_service.get(name=identifier).execute()
def list_transfer_configs(
    transfer_client: discovery.Resource,
    reference: Optional[bq_id_utils.ApiClientHelper.ProjectReference] = None,
    location: Optional[str] = None,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
    data_source_ids: Optional[str] = None,
):
  """Return a list of transfer configurations.

  Args:
    transfer_client: the transfer client to use.
    reference: The ProjectReference to list transfer configurations for.
    location: The location id, e.g. 'us' or 'eu'.
    page_size: The maximum number of transfer configurations to return.
    page_token: Current page token (optional).
    data_source_ids: The dataSourceIds to display transfer configurations for.

  Returns:
    A tuple (configs,) or (configs, next_page_token) when more pages remain;
    None when a list request could not be prepared.
  """
  results = None
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ProjectReference,
      method='list_transfer_configs',
  )
  if page_size is not None:
    # Clamp to the maximum page size the backend accepts.
    if page_size > bq_processor_utils.MAX_RESULTS:
      page_size = bq_processor_utils.MAX_RESULTS
  request = bq_processor_utils.PrepareTransferListRequest(
      reference, location, page_size, page_token, data_source_ids
  )
  if request:
    bq_processor_utils.ApplyParameters(request)
    result = (
        transfer_client.projects()
        .locations()
        .transferConfigs()
        .list(**request)
        .execute()
    )
    results = result.get('transferConfigs', [])
    if page_size is not None:
      # Keep fetching pages until page_size entries have been accumulated or
      # the server reports no further pages.
      while 'nextPageToken' in result and len(results) < page_size:
        request = bq_processor_utils.PrepareTransferListRequest(
            reference,
            location,
            page_size - len(results),
            result['nextPageToken'],
            data_source_ids,
        )
        if request:
          bq_processor_utils.ApplyParameters(request)
          result = (
              transfer_client.projects()
              .locations()
              .transferConfigs()
              .list(**request)
              .execute()
          )
          # Bug fix: extend with the fetched configs. Previously this read
          # result.get('nextPageToken', []), which appended the characters of
          # the page-token string instead of the transfer configs.
          results.extend(result.get('transferConfigs', []))
        else:
          return
    if len(results) < 1:
      logging.info('There are no transfer configurations to be shown.')
    if result.get('nextPageToken'):
      return (results, result.get('nextPageToken'))
  return (results,)
def list_transfer_runs(
    transfer_client: discovery.Resource,
    reference: Optional[bq_id_utils.ApiClientHelper.TransferConfigReference],
    run_attempt: Optional[str],
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
    states: Optional[str] = None,
):
  """Return a list of transfer runs, paging until max_results when set.

  Args:
    transfer_client: the transfer client to use.
    reference: The TransferConfigReference to list transfer runs for.
    run_attempt: Which runs should be pulled. The default value is 'LATEST',
      which only returns the latest run per day. To return all runs, please
      specify 'RUN_ATTEMPT_UNSPECIFIED'.
    max_results: The maximum number of transfer runs to return (optional).
    page_token: Current page token (optional).
    states: States to filter transfer runs (optional).

  Returns:
    A tuple (runs,) or (runs, next_page_token) when more pages remain.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TransferConfigReference,
      method='list_transfer_runs',
  )
  reference = str(reference)

  def _list_page(limit, token):
    # One runs.list round trip.
    request = bq_processor_utils.PrepareTransferRunListRequest(
        reference, run_attempt, limit, token, states
    )
    return (
        transfer_client.projects()
        .locations()
        .transferConfigs()
        .runs()
        .list(**request)
        .execute()
    )

  response = _list_page(max_results, page_token)
  transfer_runs = response.get('transferRuns', [])
  if max_results is not None:
    while 'nextPageToken' in response and len(transfer_runs) < max_results:
      page_token = response.get('nextPageToken')
      max_results -= len(transfer_runs)
      response = _list_page(max_results, page_token)
      transfer_runs.extend(response.get('transferRuns', []))
  next_token = response.get('nextPageToken')
  if next_token:
    return (transfer_runs, next_token)
  return (transfer_runs,)
def list_transfer_logs(
    transfer_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TransferRunReference,
    message_type: Optional[str] = None,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
):
  """Return a list of transfer run log messages, paging when max_results set.

  Args:
    transfer_client: the transfer client to use.
    reference: The TransferRunReference to list transfer run logs for.
    message_type: Message types to return.
    max_results: The maximum number of transfer run logs to return.
    page_token: Current page token (optional).

  Returns:
    A tuple (logs,) or (logs, next_page_token) when more pages remain.
  """
  reference = str(reference)

  def _list_page(limit, token):
    # One transferLogs.list round trip.
    request = bq_processor_utils.PrepareListTransferLogRequest(
        reference,
        max_results=limit,
        page_token=token,
        message_type=message_type,
    )
    return (
        transfer_client.projects()
        .locations()
        .transferConfigs()
        .runs()
        .transferLogs()
        .list(**request)
        .execute()
    )

  response = _list_page(max_results, page_token)
  transfer_logs = response.get('transferMessages', [])
  if max_results is not None:
    while 'nextPageToken' in response and len(transfer_logs) < max_results:
      page_token = response['nextPageToken']
      max_results -= len(transfer_logs)
      response = _list_page(max_results, page_token)
      transfer_logs.extend(response.get('transferMessages', []))
  next_token = response.get('nextPageToken')
  if next_token:
    return (transfer_logs, next_token)
  return (transfer_logs,)
def start_manual_transfer_runs(
    transfer_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TransferConfigReference,
    start_time: Optional[str],
    end_time: Optional[str],
    run_time: Optional[str],
):
  """Starts manual transfer runs for the given transfer config.

  Args:
    transfer_client: the transfer client to use.
    reference: Transfer configuration name for the run.
    start_time: Start time of the range of transfer runs.
    end_time: End time of the range of transfer runs.
    run_time: Specific time for a transfer run; when given, it takes
      precedence over the start/end range.

  Returns:
    The list of started transfer runs.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TransferConfigReference,
      method='start_manual_transfer_runs',
  )
  if run_time:
    body = {'requestedRunTime': run_time}
  else:
    body = {
        'requestedTimeRange': {'startTime': start_time, 'endTime': end_time}
    }
  response = (
      transfer_client.projects()
      .locations()
      .transferConfigs()
      .startManualRuns(parent=str(reference), body=body)
      .execute()
  )
  return response.get('runs')
def transfer_exists(
    transfer_client: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.TransferConfigReference',
) -> bool:
  """Returns True if the transfer config can be fetched, False if not found."""
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TransferConfigReference,
      method='transfer_exists',
  )
  try:
    transfer_client.projects().locations().transferConfigs().get(
        name=reference.transferConfigName
    ).execute()
  except bq_error.BigqueryNotFoundError:
    return False
  return True
def _fetch_data_source(
    transfer_client: discovery.Resource,
    project_reference: str,
    data_source_id: str,
):
  """Fetches a data source resource, using the '-' wildcard location."""
  name = '%s/locations/-/dataSources/%s' % (project_reference, data_source_id)
  data_sources_service = (
      transfer_client.projects().locations().dataSources()
  )
  return data_sources_service.get(name=name).execute()
def update_transfer_config(
    transfer_client: discovery.Resource,
    id_fallbacks: NamedTuple(
        'IDS',
        [
            ('project_id', Optional[str]),
        ],
    ),
    reference: bq_id_utils.ApiClientHelper.TransferConfigReference,
    target_dataset: Optional[str] = None,
    display_name: Optional[str] = None,
    refresh_window_days: Optional[str] = None,
    params: Optional[str] = None,
    auth_info: Optional[Dict[str, str]] = None,
    service_account_name: Optional[str] = None,
    destination_kms_key: Optional[str] = None,
    notification_pubsub_topic: Optional[str] = None,
    schedule_args: Optional[TransferScheduleArgs] = None,
):
  """Updates a transfer config via PATCH with a matching update mask.

  Only fields that were explicitly provided are placed into the request body
  and named in the update mask, so other fields keep their current values.

  Args:
    transfer_client: the transfer client to use.
    id_fallbacks: IDs to use when they have not been explicitly specified.
    reference: the TransferConfigReference to update.
    target_dataset: Optional updated target dataset.
    display_name: Optional change to the display name.
    refresh_window_days: Optional update to the refresh window days. Some data
      sources do not support this.
    params: Optional parameters to update.
    auth_info: A dict contains authorization info which can be either an
      authorization_code or a version_info that the user input if they want to
      update credentials.
    service_account_name: The service account that the user could act as and
      used as the credential to create transfer runs from the transfer config.
    destination_kms_key: Optional KMS key for encryption.
    notification_pubsub_topic: The Pub/Sub topic where notifications will be
      sent after transfer runs associated with this transfer config finish.
    schedule_args: Optional parameters to customize data transfer schedule.

  Raises:
    BigqueryTypeError: if reference is not a TransferConfigReference.
    bq_error.BigqueryError: required field not given.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TransferConfigReference,
      method='update_transfer_config',
  )
  project_reference = 'projects/' + (
      bq_client_utils.GetProjectReference(id_fallbacks=id_fallbacks).projectId
  )
  # The current config is fetched so that its dataSourceId can be carried over
  # and existing schedule options can be merged into the new payload.
  current_config = get_transfer_config(
      transfer_client, reference.transferConfigName
  )
  update_mask = []
  update_items = {}
  update_items['dataSourceId'] = current_config['dataSourceId']
  if target_dataset:
    update_items['destinationDatasetId'] = target_dataset
    update_mask.append('transfer_config.destination_dataset_id')
  if display_name:
    update_mask.append('transfer_config.display_name')
    update_items['displayName'] = display_name
  if params:
    update_items = bq_processor_utils.ProcessParamsFlag(params, update_items)
    update_mask.append('transfer_config.params')
  # If a refresh window was provided, check that the data source supports it.
  if refresh_window_days:
    data_source_info = _fetch_data_source(
        transfer_client, project_reference, current_config['dataSourceId']
    )
    update_items = bq_processor_utils.ProcessRefreshWindowDaysFlag(
        refresh_window_days,
        data_source_info,
        update_items,
        current_config['dataSourceId'],
    )
    update_mask.append('transfer_config.data_refresh_window_days')
  if schedule_args:
    update_items['scheduleOptionsV2'] = (
        schedule_args.to_schedule_options_v2_payload(
            current_config.get('scheduleOptionsV2')
        )
    )
    # NOTE(review): this mask path is camelCase while the others are
    # snake_case — presumably accepted by the backend; confirm before
    # normalizing.
    update_mask.append('transfer_config.scheduleOptionsV2')
  if notification_pubsub_topic:
    update_items['notification_pubsub_topic'] = notification_pubsub_topic
    update_mask.append('transfer_config.notification_pubsub_topic')
  # Auth-related entries go into the mask only; the values themselves are
  # passed as query parameters on the patch call below.
  if auth_info is not None and AUTHORIZATION_CODE in auth_info:
    update_mask.append(AUTHORIZATION_CODE)
  if auth_info is not None and VERSION_INFO in auth_info:
    update_mask.append(VERSION_INFO)
  if service_account_name:
    update_mask.append('service_account_name')
  if destination_kms_key:
    update_items['encryption_configuration'] = {
        'kms_key_name': {'value': destination_kms_key}
    }
    update_mask.append('encryption_configuration.kms_key_name')
  try:
    transfer_client.projects().locations().transferConfigs().patch(
        body=update_items,
        name=reference.transferConfigName,
        updateMask=','.join(update_mask),
        authorizationCode=(
            None if auth_info is None else auth_info.get(AUTHORIZATION_CODE)
        ),
        versionInfo=None if auth_info is None else auth_info.get(VERSION_INFO),
        serviceAccountName=service_account_name,
        x__xgafv='2',
    ).execute()
  except bq_error.BigqueryInterfaceError as e:
    # Re-raise a clearer "not found" error when the target dataset is missing.
    if target_dataset and 'Not found: Dataset' in str(e):
      dataset_reference = bq_client_utils.GetDatasetReference(
          id_fallbacks=id_fallbacks, identifier=target_dataset
      )
      raise bq_error.BigqueryNotFoundError(
          'Not found: %r' % (dataset_reference,), {'reason': 'notFound'}, []
      ) from e
def create_transfer_config(
    transfer_client: discovery.Resource,
    reference: str,
    data_source: str,
    target_dataset: Optional[str] = None,
    display_name: Optional[str] = None,
    refresh_window_days: Optional[str] = None,
    params: Optional[str] = None,
    auth_info: Optional[Dict[str, str]] = None,
    service_account_name: Optional[str] = None,
    notification_pubsub_topic: Optional[str] = None,
    schedule_args: Optional[TransferScheduleArgs] = None,
    destination_kms_key: Optional[str] = None,
    location: Optional[str] = None,
):
  """Create a transfer config corresponding to TransferConfigReference.

  Args:
    transfer_client: the transfer client to use.
    reference: the project reference string (used as the parent path prefix,
      e.g. 'projects/<id>') under which the transfer config is created.
    data_source: The data source for the transfer config.
    target_dataset: The dataset where the new transfer config will exist.
    display_name: A display name for the transfer config. Required.
    refresh_window_days: Refresh window days for the transfer config.
    params: Parameters for the created transfer config. The parameters should
      be in JSON format given as a string. Ex: --params="{'param':'value'}".
      The params should be the required values needed for each data source and
      will vary. Required.
    auth_info: A dict contains authorization info which can be either an
      authorization_code or a version_info that the user input if they need
      credentials.
    service_account_name: The service account that the user could act as and
      used as the credential to create transfer runs from the transfer config.
    notification_pubsub_topic: The Pub/Sub topic where notifications will be
      sent after transfer runs associated with this transfer config finish.
    schedule_args: Optional parameters to customize data transfer schedule.
    destination_kms_key: Optional KMS key for encryption.
    location: The location where the new transfer config will run.

  Raises:
    BigqueryNotFoundError: if a requested item is not found.
    bq_error.BigqueryError: if a required field isn't provided.

  Returns:
    The generated transfer configuration name.
  """
  create_items = {}
  # The backend will check if the dataset exists.
  if target_dataset:
    create_items['destinationDatasetId'] = target_dataset
  if display_name:
    create_items['displayName'] = display_name
  else:
    raise bq_error.BigqueryError('A display name must be provided.')
  create_items['dataSourceId'] = data_source
  # If a refresh window was provided, check that the data source supports it.
  if refresh_window_days:
    data_source_info = _fetch_data_source(
        transfer_client, reference, data_source
    )
    create_items = bq_processor_utils.ProcessRefreshWindowDaysFlag(
        refresh_window_days, data_source_info, create_items, data_source
    )
  # Checks that all required params are given;
  # if a param that isn't required is provided, it is ignored.
  if params:
    create_items = bq_processor_utils.ProcessParamsFlag(params, create_items)
  else:
    raise bq_error.BigqueryError('Parameters must be provided.')
  if location:
    parent = reference + '/locations/' + location
  else:
    # The location is inferred by the data transfer service from the
    # dataset location.
    parent = reference + '/locations/-'
  if schedule_args:
    create_items['scheduleOptionsV2'] = (
        schedule_args.to_schedule_options_v2_payload()
    )
  if notification_pubsub_topic:
    create_items['notification_pubsub_topic'] = notification_pubsub_topic
  if destination_kms_key:
    create_items['encryption_configuration'] = {
        'kms_key_name': {'value': destination_kms_key}
    }
  new_transfer_config = (
      transfer_client.projects()
      .locations()
      .transferConfigs()
      .create(
          parent=parent,
          body=create_items,
          authorizationCode=(
              None if auth_info is None else auth_info.get(AUTHORIZATION_CODE)
          ),
          versionInfo=None
          if auth_info is None
          else auth_info.get(VERSION_INFO),
          serviceAccountName=service_account_name,
      )
      .execute()
  )
  return new_transfer_config['name']
def delete_transfer_config(
    transfer_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TransferConfigReference,
    ignore_not_found: bool = False,
):
  """Deletes the transfer config identified by the given reference.

  Args:
    transfer_client: the transfer client to use.
    reference: the TransferConfigReference to delete.
    ignore_not_found: Whether to ignore "not found" errors.

  Raises:
    BigqueryTypeError: if reference is not a TransferConfigReference.
    bq_error.BigqueryNotFoundError: if reference does not exist and
      ignore_not_found is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TransferConfigReference,
      method='delete_transfer_config',
  )
  try:
    transfer_client.projects().locations().transferConfigs().delete(
        name=reference.transferConfigName
    ).execute()
  except bq_error.BigqueryNotFoundError as e:
    if ignore_not_found:
      return
    raise bq_error.BigqueryNotFoundError(
        'Not found: %r' % (reference,), {'reason': 'notFound'}, []
    ) from e

View File

@@ -0,0 +1,590 @@
#!/usr/bin/env python
"""The BigQuery CLI dataset client library."""
import datetime
from typing import Dict, List, NamedTuple, Optional
from googleapiclient import discovery
from clients import utils as bq_client_utils
from frontend import utils as frontend_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
# Field name under which external catalog dataset options appear in a dataset
# resource returned by the BigQuery API.
EXTERNAL_CATALOG_DATASET_OPTIONS_FIELD_NAME = 'externalCatalogDatasetOptions'
def GetDataset(apiclient: discovery.Resource, reference, dataset_view=None):
  """Fetches a dataset, optionally restricting which parts are returned.

  Always requests the maximum IAM policy version the client supports; when
  dataset_view is given it is forwarded as the datasetView parameter.
  """
  params = dict(reference)
  params['accessPolicyVersion'] = (
      bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
  )
  if dataset_view is not None:
    params['datasetView'] = dataset_view
  return apiclient.datasets().get(**params).execute()
def ListDatasets(
    apiclient: discovery.Resource,
    id_fallbacks: NamedTuple(
        'IDS',
        [
            ('project_id', Optional[str]),
        ],
    ),
    reference: Optional[bq_id_utils.ApiClientHelper.ProjectReference] = None,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
    list_all: Optional[bool] = None,
    filter_expression: Optional[str] = None,
):
  """Lists the datasets for a reference, returning only the dataset entries."""
  response = ListDatasetsWithTokenAndUnreachable(
      apiclient,
      id_fallbacks,
      reference,
      max_results,
      page_token,
      list_all,
      filter_expression,
  )
  return response['datasets']
def ListDatasetsWithTokenAndUnreachable(
    apiclient: discovery.Resource,
    id_fallbacks: NamedTuple(
        'IDS',
        [
            ('project_id', Optional[str]),
        ],
    ),
    reference: Optional[bq_id_utils.ApiClientHelper.ProjectReference] = None,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
    list_all: Optional[bool] = None,
    filter_expression: Optional[str] = None,
):
  """Lists datasets plus the next page token and unreachable locations.

  Returns:
    A dict that always contains 'datasets'; 'token' is added when more pages
    remain and 'unreachable' when some locations could not be reached.
  """
  reference = bq_client_utils.NormalizeProjectReference(
      id_fallbacks=id_fallbacks, reference=reference
  )
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ProjectReference,
      method='ListDatasets',
  )
  request = bq_processor_utils.PrepareListRequest(
      reference, max_results, page_token, filter_expression
  )
  if list_all is not None:
    request['all'] = list_all
  result = apiclient.datasets().list(**request).execute()
  datasets = result.get('datasets', [])
  unreachable = set(result.get('unreachable', []))
  token = result.get('nextPageToken', None)
  if max_results is not None:
    # Keep paging until max_results entries are collected or pages run out.
    while 'nextPageToken' in result and len(datasets) < max_results:
      request['maxResults'] = max_results - len(datasets)
      request['pageToken'] = result['nextPageToken']
      result = apiclient.datasets().list(**request).execute()
      datasets.extend(result.get('datasets', []))
      unreachable.update(result.get('unreachable', []))
      token = result.get('nextPageToken', None)
  response = {'datasets': datasets}
  if token:
    response['token'] = token
  if unreachable:
    response['unreachable'] = list(unreachable)
  return response
def GetDatasetIAMPolicy(apiclient, reference):
  """Gets IAM policy for the given dataset resource.

  Arguments:
    apiclient: the apiclient used to make the request.
    reference: the DatasetReference for the dataset resource.

  Returns:
    The IAM policy attached to the given dataset resource.

  Raises:
    BigqueryTypeError: if reference is not a DatasetReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='GetDatasetIAMPolicy',
  )
  resource_path = 'projects/%s/datasets/%s' % (
      reference.projectId,
      reference.datasetId,
  )
  # Ask for the highest policy version the CLI understands so newer policy
  # features are returned rather than dropped by the server.
  options = {
      'requestedPolicyVersion': (
          bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
      )
  }
  get_request = apiclient.datasets().getIamPolicy(
      resource=resource_path,
      body={'options': options},
  )
  return get_request.execute()
def SetDatasetIAMPolicy(apiclient: discovery.Resource, reference, policy):
  """Sets IAM policy for the given dataset resource.

  Arguments:
    apiclient: the apiclient used to make the request.
    reference: the DatasetReference for the dataset resource.
    policy: The policy string in JSON format.

  Returns:
    The updated IAM policy attached to the given dataset resource.

  Raises:
    BigqueryTypeError: if reference is not a DatasetReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='SetDatasetIAMPolicy',
  )
  resource_path = 'projects/%s/datasets/%s' % (
      reference.projectId,
      reference.datasetId,
  )
  set_request = apiclient.datasets().setIamPolicy(
      body={'policy': policy}, resource=resource_path
  )
  return set_request.execute()
def DatasetExists(
    apiclient: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.DatasetReference',
) -> bool:
  """Returns true if a dataset exists."""
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='DatasetExists',
  )
  # A successful GET is the cheapest existence probe the API offers; a
  # not-found error means the dataset is absent.
  try:
    apiclient.datasets().get(**dict(reference)).execute()
  except bq_error.BigqueryNotFoundError:
    return False
  return True
def GetDatasetRegion(
    apiclient: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.DatasetReference',
) -> Optional[str]:
  """Returns the region of a dataset as a string, or None if not found."""
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='GetDatasetRegion',
  )
  try:
    dataset = apiclient.datasets().get(**dict(reference)).execute()
  except bq_error.BigqueryNotFoundError:
    return None
  return dataset['location']
# TODO(b/191712821): add tags modification here. For the Preview Tags are not
# modifiable using BigQuery UI/Cli, only using ResourceManager.
def CreateDataset(
    apiclient: discovery.Resource,
    reference,
    ignore_existing=False,
    description=None,
    display_name=None,
    acl=None,
    default_table_expiration_ms=None,
    default_partition_expiration_ms=None,
    data_location=None,
    labels=None,
    default_kms_key=None,
    source_dataset_reference=None,
    external_source=None,
    connection_id=None,
    external_catalog_dataset_options=None,
    max_time_travel_hours=None,
    storage_billing_model=None,
    resource_tags=None,
):
  """Create a dataset corresponding to DatasetReference.

  Args:
    apiclient: The apiclient used to make the request.
    reference: The DatasetReference to create.
    ignore_existing: (boolean, default False) If False, raise an exception if
      the dataset already exists.
    description: An optional dataset description.
    display_name: An optional friendly name for the dataset.
    acl: An optional ACL for the dataset, as a list of dicts.
    default_table_expiration_ms: Default expiration time to apply to new tables
      in this dataset.
    default_partition_expiration_ms: Default partition expiration time to apply
      to new partitioned tables in this dataset.
    data_location: Location where the data in this dataset should be stored.
      Must be either 'EU' or 'US'. If specified, the project that owns the
      dataset must be enabled for data location.
    labels: An optional dict of labels.
    default_kms_key: An optional kms dey that will apply to all newly created
      tables in the dataset, if no explicit key is supplied in the creating
      request.
    source_dataset_reference: An optional ApiClientHelper.DatasetReference that
      will be the source of this linked dataset. #
    external_source: External source that backs this dataset.
    connection_id: Connection used for accessing the external_source.
    external_catalog_dataset_options: An optional JSON string or file path
      containing the external catalog dataset options to create.
    max_time_travel_hours: Optional. Define the max time travel in hours. The
      value can be from 48 to 168 hours (2 to 7 days). The default value is 168
      hours if this is not set.
    storage_billing_model: Optional. Sets the storage billing model for the
      dataset.
    resource_tags: An optional dict of tags to attach to the dataset.

  Raises:
    BigqueryTypeError: If reference is not an ApiClientHelper.DatasetReference
      or if source_dataset_reference is provided but is not an
      bq_id_utils.ApiClientHelper.DatasetReference.
      or if both external_dataset_reference and source_dataset_reference
      are provided or if not all required arguments for external database is
      provided.
    BigqueryDuplicateError: if reference exists and ignore_existing
      is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='CreateDataset',
  )
  # Start from the canonical object skeleton for this reference and layer the
  # optional fields onto it; each field is only sent if explicitly provided.
  body = bq_processor_utils.ConstructObjectInfo(reference)
  if display_name is not None:
    body['friendlyName'] = display_name
  if description is not None:
    body['description'] = description
  if acl is not None:
    body['access'] = acl
  if default_table_expiration_ms is not None:
    body['defaultTableExpirationMs'] = default_table_expiration_ms
  if default_partition_expiration_ms is not None:
    body['defaultPartitionExpirationMs'] = default_partition_expiration_ms
  if default_kms_key is not None:
    body['defaultEncryptionConfiguration'] = {'kmsKeyName': default_kms_key}
  if data_location is not None:
    body['location'] = data_location
  if labels:
    body['labels'] = {}
    for label_key, label_value in labels.items():
      body['labels'][label_key] = label_value
  if source_dataset_reference is not None:
    # Creating a *linked* dataset: the source must itself be a valid
    # DatasetReference.
    bq_id_utils.typecheck(
        source_dataset_reference,
        bq_id_utils.ApiClientHelper.DatasetReference,
        method='CreateDataset',
    )
    body['linkedDatasetSource'] = {
        'sourceDataset': bq_processor_utils.ConstructObjectInfo(
            source_dataset_reference
        )['datasetReference']
    }
  # externalDatasetReference can only be specified in case of externals
  # datasets. This option cannot be used in case of regular dataset or linked
  # datasets.
  # So we only set this if an external_source is specified.
  if external_source:
    body['externalDatasetReference'] = {
        'externalSource': external_source,
        'connection': connection_id,
    }
  if external_catalog_dataset_options is not None:
    # Accepts either inline JSON or a path to a JSON file; GetJson handles
    # both forms.
    body[EXTERNAL_CATALOG_DATASET_OPTIONS_FIELD_NAME] = frontend_utils.GetJson(
        external_catalog_dataset_options
    )
  if max_time_travel_hours is not None:
    body['maxTimeTravelHours'] = max_time_travel_hours
  if storage_billing_model is not None:
    body['storageBillingModel'] = storage_billing_model
  if resource_tags is not None:
    body['resourceTags'] = resource_tags
  args = dict(reference.GetProjectReference())
  # Send the highest IAM policy version the CLI supports so newer ACL
  # features in the request are accepted by the server.
  args['accessPolicyVersion'] = bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
  try:
    apiclient.datasets().insert(body=body, **args).execute()
  except bq_error.BigqueryDuplicateError:
    # An existing dataset is only an error when the caller asked for strict
    # creation semantics.
    if not ignore_existing:
      raise
def UpdateDataset(
    apiclient: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.DatasetReference',
    description: Optional[str] = None,
    display_name: Optional[str] = None,
    acl=None,
    default_table_expiration_ms=None,
    default_partition_expiration_ms=None,
    labels_to_set=None,
    label_keys_to_remove=None,
    etag=None,
    default_kms_key=None,
    max_time_travel_hours=None,
    storage_billing_model=None,
    tags_to_attach: Optional[Dict[str, str]] = None,
    tags_to_remove: Optional[List[str]] = None,
    clear_all_tags: Optional[bool] = False,
    external_catalog_dataset_options: Optional[str] = None,
    update_mode: Optional[bq_client_utils.UpdateMode] = None,
):
  """Updates a dataset.

  Reads the current dataset resource, merges the requested changes into it,
  and writes it back with a conditional (ETag-guarded) patch.

  Args:
    apiclient: The apiclient used to make the request.
    reference: The DatasetReference to update.
    description: An optional dataset description.
    display_name: An optional friendly name for the dataset.
    acl: An optional ACL for the dataset, as a list of dicts.
    default_table_expiration_ms: Optional number of milliseconds for the default
      expiration duration for new tables created in this dataset.
    default_partition_expiration_ms: Optional number of milliseconds for the
      default partition expiration duration for new partitioned tables created
      in this dataset.
    labels_to_set: An optional dict of labels to set on this dataset.
    label_keys_to_remove: An optional list of label keys to remove from this
      dataset.
    etag: If set, checks that etag in the existing dataset matches.
    default_kms_key: An optional kms dey that will apply to all newly created
      tables in the dataset, if no explicit key is supplied in the creating
      request.
    max_time_travel_hours: Optional. Define the max time travel in hours. The
      value can be from 48 to 168 hours (2 to 7 days). The default value is 168
      hours if this is not set.
    storage_billing_model: Optional. Sets the storage billing model for the
      dataset.
    tags_to_attach: An optional dict of tags to attach to the dataset
    tags_to_remove: An optional list of tag keys to remove from the dataset
    clear_all_tags: If set, clears all the tags attached to the dataset
    external_catalog_dataset_options: An optional JSON string or file path
      containing the external catalog dataset options to update.
    update_mode: An optional flag indicating which datasets fields to update,
      either metadata fields only, ACL fields only, or both metadata and ACL
      fields.

  Raises:
    BigqueryTypeError: If reference is not a DatasetReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='UpdateDataset',
  )

  # Get the existing dataset and associated ETag.
  dataset = _ExecuteGetDatasetRequest(apiclient, reference, etag)

  # Merge in the changes.
  if display_name is not None:
    dataset['friendlyName'] = display_name
  if description is not None:
    dataset['description'] = description
  if acl is not None:
    dataset['access'] = acl
  if default_table_expiration_ms is not None:
    dataset['defaultTableExpirationMs'] = default_table_expiration_ms
  if default_partition_expiration_ms is not None:
    if default_partition_expiration_ms == 0:
      # Zero is the sentinel for "clear the default partition expiration";
      # sending None removes the field server-side.
      dataset['defaultPartitionExpirationMs'] = None
    else:
      dataset['defaultPartitionExpirationMs'] = default_partition_expiration_ms
  if default_kms_key is not None:
    dataset['defaultEncryptionConfiguration'] = {'kmsKeyName': default_kms_key}
  if 'labels' not in dataset:
    dataset['labels'] = {}
  if labels_to_set:
    for label_key, label_value in labels_to_set.items():
      dataset['labels'][label_key] = label_value
  if label_keys_to_remove:
    for label_key in label_keys_to_remove:
      # A None value in the patch body deletes that label key.
      dataset['labels'][label_key] = None
  if max_time_travel_hours is not None:
    dataset['maxTimeTravelHours'] = max_time_travel_hours
  if storage_billing_model is not None:
    dataset['storageBillingModel'] = storage_billing_model

  resource_tags = {}
  if clear_all_tags and 'resourceTags' in dataset:
    # Map every currently-attached tag to None so the patch removes them all.
    for tag in dataset['resourceTags']:
      resource_tags[tag] = None
  else:
    for tag in tags_to_remove or []:
      resource_tags[tag] = None
  for tag in tags_to_attach or {}:
    resource_tags[tag] = tags_to_attach[tag]
  # resourceTags is used to add a new tag binding, update value of existing
  # tag and also to remove a tag binding
  dataset['resourceTags'] = resource_tags

  if external_catalog_dataset_options is not None:
    dataset.setdefault(EXTERNAL_CATALOG_DATASET_OPTIONS_FIELD_NAME, {})
    current_options = dataset[EXTERNAL_CATALOG_DATASET_OPTIONS_FIELD_NAME]
    # Merge the new options into whatever options the dataset already has.
    dataset[EXTERNAL_CATALOG_DATASET_OPTIONS_FIELD_NAME] = (
        frontend_utils.UpdateExternalCatalogDatasetOptions(
            current_options, external_catalog_dataset_options
        )
    )

  _ExecutePatchDatasetRequest(
      apiclient,
      reference,
      dataset,
      etag,
      update_mode,
  )
def _ExecuteGetDatasetRequest(
    apiclient: discovery.Resource,
    reference,
    etag: Optional[str] = None,
):
  """Executes request to get dataset.

  Args:
    apiclient: the apiclient used to make the request.
    reference: the DatasetReference to get.
    etag: if set, checks that etag in the existing dataset matches.

  Returns:
    The result of executing the request, if it succeeds.
  """
  parameters = dict(reference)
  parameters['accessPolicyVersion'] = (
      bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
  )
  get_request = apiclient.datasets().get(**parameters)
  # An If-Match header makes the read conditional: it fails if the dataset's
  # current ETag no longer matches the caller-supplied one.
  if etag:
    get_request.headers['If-Match'] = etag
  return get_request.execute()
def _ExecutePatchDatasetRequest(
    apiclient: discovery.Resource,
    reference,
    dataset,
    etag: Optional[str] = None,
    update_mode: Optional[bq_client_utils.UpdateMode] = None,
):
  """Executes request to patch dataset.

  Args:
    apiclient: the apiclient used to make the request.
    reference: the DatasetReference to patch.
    dataset: the body of request
    etag: if set, checks that etag in the existing dataset matches.
    update_mode: a flag indicating which datasets fields to update.
  """
  parameters = dict(reference)
  parameters['accessPolicyVersion'] = (
      bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
  )
  if update_mode is not None:
    parameters['updateMode'] = update_mode.value
  request = apiclient.datasets().patch(body=dataset, **parameters)
  # Perform a conditional update to protect against concurrent
  # modifications to this dataset. By placing the ETag returned in
  # the get operation into the If-Match header, the API server will
  # make sure the dataset hasn't changed. If there is a conflicting
  # change, this update will fail with a "Precondition failed"
  # error.
  # Fixed: use dict.get() so a dataset body without an 'etag' field (e.g. one
  # assembled by hand rather than returned by datasets.get) no longer raises
  # KeyError here.
  effective_etag = etag or dataset.get('etag')
  if effective_etag:
    request.headers['If-Match'] = effective_etag
  request.execute()
def DeleteDataset(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.DatasetReference,
    ignore_not_found: bool = False,
    delete_contents: Optional[bool] = None,
) -> None:
  """Deletes DatasetReference reference.

  Args:
    apiclient: the api client to make the request with.
    reference: the DatasetReference to delete.
    ignore_not_found: Whether to ignore "not found" errors.
    delete_contents: [Boolean] Whether to delete the contents of non-empty
      datasets. If not specified and the dataset has tables in it, the delete
      will fail. If not specified, the server default applies.

  Raises:
    BigqueryTypeError: if reference is not a DatasetReference.
    bq_error.BigqueryNotFoundError: if reference does not exist and
      ignore_not_found is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='DeleteDataset',
  )
  delete_args = dict(reference)
  # Only forward deleteContents when the caller chose a value; omitting it
  # lets the server default apply.
  if delete_contents is not None:
    delete_args['deleteContents'] = delete_contents
  try:
    apiclient.datasets().delete(**delete_args).execute()
  except bq_error.BigqueryNotFoundError:
    if not ignore_not_found:
      raise
def UndeleteDataset(
    apiclient: discovery.Resource,
    dataset_reference: bq_id_utils.ApiClientHelper.DatasetReference,
    timestamp: Optional[datetime.datetime] = None,
) -> bool:
  """Undeletes a dataset.

  Args:
    apiclient: The api client to make the request with.
    dataset_reference: [Type:
      bq_id_utils.ApiClientHelper.DatasetReference]DatasetReference of the
      dataset to be undeleted
    timestamp: [Type: Optional[datetime.datetime]]Timestamp for which dataset
      version is to be undeleted

  Returns:
    The server's undelete response (the restored dataset resource).

  Raises:
    BigqueryDuplicateError: when the dataset to be undeleted already exists.
  """
  args = dict(dataset_reference)
  if timestamp:
    # The API expects an RFC 3339 timestamp without a '+00:00' UTC suffix, so
    # strip the suffix produced by FormatRfc3339.
    args['body'] = {
        'deletionTime': frontend_utils.FormatRfc3339(timestamp).replace(
            '+00:00', ''
        )
    }
  # Fixed: the previous `except BigqueryDuplicateError as e: raise e` wrapper
  # was a no-op that only added a redundant traceback frame; the exception now
  # propagates naturally, as documented in Raises.
  return apiclient.datasets().undelete(**args).execute()

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python
"""Legacy code that isn't split up into resource based clients."""
from collections.abc import Callable
import sys
from googleapiclient import discovery
from typing_extensions import TypeAlias
from clients import client_project
from clients import utils as bq_client_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
# This Callable annotation would cause a type error before Python 3.9.2, see
# https://docs.python.org/3/whatsnew/3.9.html#notable-changes-in-python-3-9-2.
if sys.version_info >= (3, 9, 2):
GetApiClienFunction: TypeAlias = Callable[[], discovery.Resource]
else:
GetApiClienFunction: TypeAlias = Callable
def get_object_info(
    apiclient: discovery.Resource,
    get_routines_api_client: GetApiClienFunction,
    get_models_api_client: GetApiClienFunction,
    reference,
):
  """Get all data returned by the server about a specific object.

  Args:
    apiclient: the main BigQuery api client.
    get_routines_api_client: zero-arg callable returning the routines client.
    get_models_api_client: zero-arg callable returning the models client.
    reference: a Project/Job/Dataset/Table/Model/Routine reference.

  Returns:
    The server's representation of the referenced object.

  Raises:
    bq_error.BigqueryError: if too many projects exist to scan.
    bq_error.BigqueryNotFoundError: if a project reference cannot be resolved.
    bq_error.BigqueryTypeError: if reference has an unsupported type.
  """
  # Projects are handled separately, because we only have
  # bigquery.projects.list.
  if isinstance(reference, bq_id_utils.ApiClientHelper.ProjectReference):
    max_project_results = 1000
    projects = client_project.list_projects(
        apiclient=apiclient, max_results=max_project_results
    )
    for project in projects:
      if bq_processor_utils.ConstructObjectReference(project) == reference:
        project['kind'] = 'bigquery#project'
        return project
    # If the listing was truncated we cannot tell whether the project exists.
    if len(projects) >= max_project_results:
      raise bq_error.BigqueryError(
          'Number of projects found exceeded limit, please instead run'
          ' gcloud projects describe %s' % (reference,),
      )
    raise bq_error.BigqueryNotFoundError(
        'Unknown %r' % (reference,), {'reason': 'notFound'}, []
    )
  if isinstance(reference, bq_id_utils.ApiClientHelper.JobReference):
    return apiclient.jobs().get(**dict(reference)).execute()
  elif isinstance(reference, bq_id_utils.ApiClientHelper.DatasetReference):
    request = dict(reference)
    request['accessPolicyVersion'] = (
        bq_client_utils.MAX_SUPPORTED_IAM_POLICY_VERSION
    )
    return apiclient.datasets().get(**request).execute()
  elif isinstance(reference, bq_id_utils.ApiClientHelper.TableReference):
    return apiclient.tables().get(**dict(reference)).execute()
  elif isinstance(reference, bq_id_utils.ApiClientHelper.ModelReference):
    return (
        get_models_api_client()
        .models()
        .get(
            projectId=reference.projectId,
            datasetId=reference.datasetId,
            modelId=reference.modelId,
        )
        .execute()
    )
  elif isinstance(reference, bq_id_utils.ApiClientHelper.RoutineReference):
    return (
        get_routines_api_client()
        .routines()
        .get(
            projectId=reference.projectId,
            datasetId=reference.datasetId,
            routineId=reference.routineId,
        )
        .execute()
    )
  else:
    # Fixed: the message previously omitted ModelReference and
    # RoutineReference even though both are supported branches above.
    raise bq_error.BigqueryTypeError(
        'Type of reference must be one of: ProjectReference, '
        'JobReference, DatasetReference, TableReference, '
        'ModelReference, or RoutineReference'
    )

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python
"""The BigQuery CLI model client library."""
from typing import Dict, List, Optional
from googleapiclient import discovery
from utils import bq_error
from utils import bq_id_utils
def list_models(
    model_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.DatasetReference,
    max_results: Optional[int],
    page_token: Optional[str],
):
  """Lists models for the given dataset reference.

  Arguments:
    model_client: The apiclient used to make the request.
    reference: Reference to the dataset.
    max_results: Number of results to return.
    page_token: Token to retrieve the next page of results.

  Returns:
    A dict that contains entries:
      'results': a list of models
      'token': nextPageToken for the last page, if present.
  """
  list_request = model_client.models().list(
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      maxResults=max_results,
      pageToken=page_token,
  )
  return list_request.execute()
def model_exists(
    model_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ModelReference,
) -> bool:
  """Returns true if the model exists.

  NOTE: on success the (truthy) model resource itself is returned rather than
  the literal True; callers using the result as a boolean are unaffected.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ModelReference,
      method='model_exists',
  )
  get_request = model_client.models().get(
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      modelId=reference.modelId,
  )
  try:
    return get_request.execute()
  except bq_error.BigqueryNotFoundError:
    return False
def update_model(
    model_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ModelReference,
    description: Optional[str] = None,
    expiration: Optional[int] = None,
    labels_to_set: Optional[Dict[str, str]] = None,
    label_keys_to_remove: Optional[List[str]] = None,
    vertex_ai_model_id: Optional[str] = None,
    etag: Optional[str] = None,
):
  """Updates a Model.

  Args:
    model_client: The apiclient used to make the request.
    reference: the ModelReference to update.
    description: an optional description for model.
    expiration: optional expiration time in milliseconds since the epoch.
      Specifying 0 clears the expiration time for the model.
    labels_to_set: an optional dict of labels to set on this model.
    label_keys_to_remove: an optional list of label keys to remove from this
      model.
    vertex_ai_model_id: an optional string as Vertex AI model ID to register.
    etag: if set, checks that etag in the existing model matches.

  Raises:
    BigqueryTypeError: if reference is not a ModelReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ModelReference,
      method='update_model',
  )
  updated_model = {}
  if description is not None:
    updated_model['description'] = description
  if expiration is not None:
    # `expiration or None`: an explicit 0 clears the expiration server-side.
    updated_model['expirationTime'] = expiration or None
  # NOTE(review): the patch body always carries a 'labels' map (possibly
  # empty), preserving the original behavior. updated_model is freshly built
  # above, so the old `if 'labels' not in updated_model` guard was always
  # true and has been removed.
  updated_model['labels'] = {}
  if labels_to_set:
    for label_key, label_value in labels_to_set.items():
      updated_model['labels'][label_key] = label_value
  if label_keys_to_remove:
    for label_key in label_keys_to_remove:
      # A None value in the patch body deletes that label key.
      updated_model['labels'][label_key] = None
  if vertex_ai_model_id is not None:
    updated_model['trainingRuns'] = [{'vertex_ai_model_id': vertex_ai_model_id}]
  request = model_client.models().patch(
      body=updated_model,
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      modelId=reference.modelId,
  )
  # Perform a conditional update to protect against concurrent
  # modifications to this model. If there is a conflicting
  # change, this update will fail with a "Precondition failed"
  # error.
  # Fixed: the previous `etag if etag else updated_model['etag']` fallback was
  # dead code guarded by `if etag:` and would have raised KeyError if ever
  # reached, since the freshly-built patch body has no 'etag' entry.
  if etag:
    request.headers['If-Match'] = etag
  request.execute()
def delete_model(
    model_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ModelReference,
    ignore_not_found: bool = False,
):
  """Deletes ModelReference reference.

  Args:
    model_client: The apiclient used to make the request.
    reference: the ModelReference to delete.
    ignore_not_found: Whether to ignore "not found" errors.

  Raises:
    BigqueryTypeError: if reference is not a ModelReference.
    bq_error.BigqueryNotFoundError: if reference does not exist and
      ignore_not_found is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ModelReference,
      method='delete_model',
  )
  delete_request = model_client.models().delete(
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      modelId=reference.modelId,
  )
  try:
    delete_request.execute()
  except bq_error.BigqueryNotFoundError:
    # A missing model is only an error when the caller asked for strict
    # delete semantics.
    if not ignore_not_found:
      raise

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python
"""The BigQuery CLI project client library."""
from typing import Optional
from googleapiclient import discovery
from utils import bq_processor_utils
def list_project_refs(apiclient: discovery.Resource, **kwds):
  """List the project references this user has access to."""
  projects = list_projects(apiclient, **kwds)
  return [
      bq_processor_utils.ConstructObjectReference(project)
      for project in projects
  ]
def list_projects(
    apiclient: discovery.Resource,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
):
  """List the projects this user has access to.

  Args:
    apiclient: the apiclient used to make the requests.
    max_results: if set, the total number of projects to fetch; additional
      pages are requested until this many have been collected.
    page_token: token to resume listing from a prior response.

  Returns:
    A list of project resources, sorted by project id.
  """
  request = bq_processor_utils.PrepareListRequest({}, max_results, page_token)
  result = _execute_list_projects_request(apiclient, request)
  results = result.get('projects', [])
  # Pagination only happens when the caller bounded the result count;
  # otherwise just the first page is returned (matching prior behavior).
  if max_results is not None:
    while 'nextPageToken' in result and len(results) < max_results:
      # Fixed: shrink the page size to the number of rows still needed so the
      # combined result never exceeds max_results (previously a full-size
      # final page could overshoot the requested limit).
      request['maxResults'] = max_results - len(results)
      request['pageToken'] = result['nextPageToken']
      result = _execute_list_projects_request(apiclient, request)
      results.extend(result.get('projects', []))
  results.sort(key=lambda x: x['id'])
  return results
def _execute_list_projects_request(apiclient, request):
return apiclient.projects().list(**request).execute()

View File

@@ -0,0 +1,155 @@
#!/usr/bin/env python
"""The BigQuery CLI routine client library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional
from googleapiclient import discovery
from utils import bq_error
from utils import bq_id_utils
def ListRoutines(
    routines_api_client: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.DatasetReference',
    max_results: int,
    page_token: str,
    filter_expression: str,
):
  """Lists routines for the given dataset reference.

  Arguments:
    routines_api_client: the api client used to make the request.
    reference: Reference to the dataset.
    max_results: Number of results to return.
    page_token: Token to retrieve the next page of results.
    filter_expression: An expression for filtering routines.

  Returns:
    A dict that contains entries:
      'routines': a list of routines.
      'token': nextPageToken for the last page, if present.
  """
  list_request = routines_api_client.routines().list(
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      maxResults=max_results,
      pageToken=page_token,
      filter=filter_expression,
  )
  return list_request.execute()
def RoutineExists(
    routines_api_client: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.RoutineReference',
):
  """Returns true if the routine exists.

  NOTE: on success the (truthy) routine resource itself is returned rather
  than the literal True; callers using the result as a boolean are unaffected.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.RoutineReference,
      method='RoutineExists',
  )
  get_request = routines_api_client.routines().get(
      projectId=reference.projectId,
      datasetId=reference.datasetId,
      routineId=reference.routineId,
  )
  try:
    return get_request.execute()
  except bq_error.BigqueryNotFoundError:
    return False
def DeleteRoutine(
    routines_api_client: discovery.Resource,
    reference: 'bq_id_utils.ApiClientHelper.RoutineReference',
    ignore_not_found: Optional[bool] = False,
) -> None:
  """Deletes RoutineReference reference.

  Args:
    routines_api_client: the api client used to make the request.
    reference: the RoutineReference to delete.
    ignore_not_found: Whether to ignore "not found" errors.

  Raises:
    BigqueryTypeError: if reference is not a RoutineReference.
    bq_error.BigqueryNotFoundError: if reference does not exist and
      ignore_not_found is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.RoutineReference,
      method='DeleteRoutine',
  )
  delete_request = routines_api_client.routines().delete(**dict(reference))
  try:
    delete_request.execute()
  except bq_error.BigqueryNotFoundError:
    # A missing routine is only an error when strict semantics were requested.
    if not ignore_not_found:
      raise
def SetRoutineIAMPolicy(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.RoutineReference,
    policy: str,
) -> ...:
  """Sets IAM policy for the given routine resource.

  Arguments:
    apiclient: the apiclient used to make the request.
    reference: the RoutineReference for the routine resource.
    policy: The policy string in JSON format.

  Returns:
    The updated IAM policy attached to the given routine resource.

  Raises:
    BigqueryTypeError: if reference is not a RoutineReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.RoutineReference,
      method='SetRoutineIAMPolicy',
  )
  set_request = apiclient.routines().setIamPolicy(
      body={'policy': policy}, resource=reference.path()
  )
  return set_request.execute()
def GetRoutineIAMPolicy(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.RoutineReference,
) -> ...:
  """Gets IAM policy for the given routine resource.

  Arguments:
    apiclient: the apiclient used to make the request.
    reference: the RoutineReference for the routine resource.

  Returns:
    The IAM policy attached to the given routine resource.

  Raises:
    BigqueryTypeError: if reference is not a RoutineReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.RoutineReference,
      method='GetRoutineIAMPolicy',
  )
  get_request = apiclient.routines().getIamPolicy(resource=reference.path())
  return get_request.execute()

View File

@@ -0,0 +1,287 @@
#!/usr/bin/env python
"""The BigQuery CLI row access policy client library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, Dict, List
# To configure apiclient logging.
from google.api_core import iam
from clients import bigquery_client
from utils import bq_id_utils
# IAM role name that represents being a grantee on a row access policy.
_FILTERED_DATA_VIEWER_ROLE = 'roles/bigquery.filteredDataViewer'
def create_row_access_policy(
    bqclient: bigquery_client.BigqueryClient,
    policy_reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
    grantees: List[str],
    filter_predicate: str,
):
  """Create a row access policy on the given table reference.

  Arguments:
    bqclient: BigQuery client to use for the request.
    policy_reference: Reference to the row access policy to create.
    grantees: Users or groups that can access rows protected by the row access
      policy.
    filter_predicate: A SQL boolean expression that needs to be true for a row to
      be included in the result.

  Returns:
    rowAccessPolicy: The created row access policy defined in
    google3/google/cloud/bigquery/v2/row_access_policy.proto;l=235;rcl=642795091
  """
  body = {
      'rowAccessPolicyReference': {
          'projectId': policy_reference.projectId,
          'datasetId': policy_reference.datasetId,
          'tableId': policy_reference.tableId,
          'policyId': policy_reference.policyId,
      },
      'filterPredicate': filter_predicate,
      'grantees': grantees,
  }
  policies_service = (
      bqclient.GetRowAccessPoliciesApiClient().rowAccessPolicies()
  )
  insert_request = policies_service.insert(
      projectId=policy_reference.projectId,
      datasetId=policy_reference.datasetId,
      tableId=policy_reference.tableId,
      body=body,
  )
  return insert_request.execute()
def update_row_access_policy(
    bqclient: bigquery_client.BigqueryClient,
    policy_reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
    grantees: List[str],
    filter_predicate: str,
):
  """Update a row access policy on the given table reference.

  Arguments:
    bqclient: BigQuery client to use for the request.
    policy_reference: Reference to the row access policy to update.
    grantees: Users or groups that can access rows protected by the row access
      policy.
    filter_predicate: A SQL boolean expression that needs to be true for a row to
      be included in the result.

  Returns:
    rowAccessPolicy: The updated row access policy defined in
    google3/google/cloud/bigquery/v2/row_access_policy.proto;l=235;rcl=642795091
  """
  body = {
      'rowAccessPolicyReference': {
          'projectId': policy_reference.projectId,
          'datasetId': policy_reference.datasetId,
          'tableId': policy_reference.tableId,
          'policyId': policy_reference.policyId,
      },
      'filterPredicate': filter_predicate,
      'grantees': grantees,
  }
  policies_service = (
      bqclient.GetRowAccessPoliciesApiClient().rowAccessPolicies()
  )
  update_request = policies_service.update(
      projectId=policy_reference.projectId,
      datasetId=policy_reference.datasetId,
      tableId=policy_reference.tableId,
      policyId=policy_reference.policyId,
      body=body,
  )
  return update_request.execute()
def get_row_access_policy(
    bqclient: bigquery_client.BigqueryClient,
    policy_reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
):
  """Get a row access policy on the given table reference."""
  policy = _get_row_access_policy_reference(bqclient, policy_reference)
  # Grantees are resolved (via IAM) only when the response actually contains
  # a policy reference.
  if 'rowAccessPolicyReference' in policy:
    _set_row_access_policy_grantees(bqclient, policy)
  return policy
def _get_row_access_policy_reference(
    bqclient: bigquery_client.BigqueryClient,
    policy_reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
) -> Dict[str, Any]:
  """Fetches the row access policy resource identified by policy_reference."""
  policies_service = (
      bqclient.GetRowAccessPoliciesApiClient().rowAccessPolicies()
  )
  get_request = policies_service.get(
      projectId=policy_reference.projectId,
      datasetId=policy_reference.datasetId,
      tableId=policy_reference.tableId,
      policyId=policy_reference.policyId,
  )
  return get_request.execute()
def delete_row_access_policy(
    bqclient: bigquery_client.BigqueryClient,
    policy_reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
    force: bool = False,
):
  """Delete a row access policy on the given table reference."""
  policies_service = (
      bqclient.GetRowAccessPoliciesApiClient().rowAccessPolicies()
  )
  delete_request = policies_service.delete(
      projectId=policy_reference.projectId,
      datasetId=policy_reference.datasetId,
      tableId=policy_reference.tableId,
      policyId=policy_reference.policyId,
      force=force,
  )
  return delete_request.execute()
def _list_row_access_policies(
    bqclient: bigquery_client.BigqueryClient,
    table_reference: 'bq_id_utils.ApiClientHelper.TableReference',
    page_size: int,
    page_token: str,
) -> Dict[str, List[Any]]:
  """Returns one page of row access policies for the given table."""
  client = bqclient.GetRowAccessPoliciesApiClient()
  request = client.rowAccessPolicies().list(
      projectId=table_reference.projectId,
      datasetId=table_reference.datasetId,
      tableId=table_reference.tableId,
      pageSize=page_size,
      pageToken=page_token,
  )
  return request.execute()
def list_row_access_policies_with_grantees(
    bqclient: bigquery_client.BigqueryClient,
    table_reference: 'bq_id_utils.ApiClientHelper.TableReference',
    page_size: int,
    page_token: str,
    max_concurrent_iam_calls: int = 1,
) -> Dict[str, List[Any]]:
  """Lists row access policies, annotating each with its grantees.

  Args:
    bqclient: BigQuery client to use for the request.
    table_reference: Reference to the table.
    page_size: Number of results to return.
    page_token: Token to retrieve the next page of results.
    max_concurrent_iam_calls: Accepted for interface compatibility; the
      getIamPolicy lookups below are currently issued sequentially, so this
      value is not consulted.

  Returns:
    A dict that contains entries:
      'rowAccessPolicies': a list of row access policies, each with an
        additional 'grantees' field holding the policy's grantees.
      'nextPageToken': nextPageToken for the next page, if present.
  """
  response = _list_row_access_policies(
      bqclient=bqclient,
      table_reference=table_reference,
      page_size=page_size,
      page_token=page_token,
  )
  for policy in response.get('rowAccessPolicies', []):
    _set_row_access_policy_grantees(
        bqclient=bqclient,
        row_access_policy=policy,
    )
  return response
def _set_row_access_policy_grantees(
    bqclient: bigquery_client.BigqueryClient, row_access_policy
):
  """Looks up and attaches the 'grantees' field on row_access_policy."""
  reference = bq_id_utils.ApiClientHelper.RowAccessPolicyReference.Create(
      **row_access_policy['rowAccessPolicyReference']
  )
  policy = get_row_access_policy_iam_policy(
      bqclient=bqclient, reference=reference
  )
  row_access_policy['grantees'] = (
      _get_grantees_from_row_access_policy_iam_policy(policy)
  )
def _get_grantees_from_row_access_policy_iam_policy(iam_policy):
"""Returns the filtered data viewer members of the given IAM policy."""
bindings = iam_policy.get('bindings')
if not bindings:
return []
filtered_data_viewer_binding = next(
(
binding
for binding in bindings
if binding.get('role') == _FILTERED_DATA_VIEWER_ROLE
),
None,
)
if not filtered_data_viewer_binding:
return []
return filtered_data_viewer_binding.get('members', [])
def get_row_access_policy_iam_policy(
    bqclient: bigquery_client.BigqueryClient,
    reference: 'bq_id_utils.ApiClientHelper.RowAccessPolicyReference',
) -> iam.Policy:
  """Gets IAM policy for the given row access policy resource.

  Args:
    bqclient: BigQuery client to use for the request.
    reference: the RowAccessPolicyReference for the row access policy resource.

  Returns:
    The IAM policy attached to the given row access policy resource.

  Raises:
    BigqueryTypeError: if reference is not a RowAccessPolicyReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.RowAccessPolicyReference,
      method='get_row_access_policy_iam_policy',
  )
  resource_path = (
      f'projects/{reference.projectId}'
      f'/datasets/{reference.datasetId}'
      f'/tables/{reference.tableId}'
      f'/rowAccessPolicies/{reference.policyId}'
  )
  client = bqclient.GetIAMPolicyApiClient()
  return (
      client.rowAccessPolicies()
      .getIamPolicy(resource=resource_path)
      .execute()
  )

View File

@@ -0,0 +1,647 @@
#!/usr/bin/env python
"""The BigQuery CLI table client library."""
from typing import Dict, List, Optional, cast
from googleapiclient import discovery
from clients import table_reader as bq_table_reader
from frontend import utils as bq_frontend_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
_EXTERNAL_CATALOG_TABLE_OPTIONS_FIELD_NAME = 'externalCatalogTableOptions'
def get_table_schema(
    apiclient: discovery.Resource,
    table_dict: bq_id_utils.ApiClientHelper.TableReference,
):
  """Returns the schema of the referenced table, or {} if none is set."""
  response = apiclient.tables().get(**table_dict).execute()
  return response.get('schema', {})
def insert_table_rows(
    insert_client: discovery.Resource,
    table_dict: bq_id_utils.ApiClientHelper.TableReference,
    inserts: List[Optional[bq_processor_utils.InsertEntry]],
    skip_invalid_rows: Optional[bool] = None,
    ignore_unknown_values: Optional[bool] = None,
    template_suffix: Optional[int] = None,
):
  """Streams rows into a table via tabledata.insertAll.

  Args:
    insert_client: The apiclient used to make the request.
    table_dict: table reference into which rows are to be inserted.
    inserts: array of InsertEntry tuples where insert_id can be None.
    skip_invalid_rows: Optional. Attempt to insert any valid rows, even if
      invalid rows are present.
    ignore_unknown_values: Optional. Ignore any values in a row that are not
      present in the schema.
    template_suffix: Optional. The suffix used to generate the template
      table's name.

  Returns:
    result of the operation.
  """

  def _encode_insert(insert):
    # Only attach insertId when the caller supplied one.
    row = {'json': insert.record}
    if insert.insert_id:
      row['insertId'] = insert.insert_id
    return row

  request_body = {
      'skipInvalidRows': skip_invalid_rows,
      'ignoreUnknownValues': ignore_unknown_values,
      'templateSuffix': template_suffix,
      'rows': [_encode_insert(entry) for entry in inserts],
  }
  return (
      insert_client.tabledata()
      .insertAll(body=request_body, **table_dict)
      .execute()
  )
def read_schema_and_rows(
    apiclient: discovery.Resource,
    table_ref: bq_id_utils.ApiClientHelper.TableReference,
    start_row: Optional[int] = None,
    max_rows: Optional[int] = None,
    selected_fields: Optional[str] = None,
    max_rows_per_request: Optional[int] = None,
):
  """Reads the schema and a range of rows from a table.

  Args:
    apiclient: The apiclient used to make the request.
    table_ref: table reference.
    start_row: first row to read.
    max_rows: number of rows to read.
    selected_fields: a subset of fields to return.
    max_rows_per_request: the maximum number of rows to read per request.

  Returns:
    A tuple of (schema fields, rows).

  Raises:
    ValueError: if start_row or max_rows is not explicitly provided.
  """
  if start_row is None:
    raise ValueError('start_row is required')
  if max_rows is None:
    raise ValueError('max_rows is required')
  reader = bq_table_reader.TableTableReader(
      apiclient, max_rows_per_request, table_ref
  )
  return reader.ReadSchemaAndRows(
      start_row,
      max_rows,
      selected_fields=selected_fields,
  )
def list_tables(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.DatasetReference,
    max_results: Optional[int] = None,
    page_token: Optional[str] = None,
):
  """Returns the tables in the referenced dataset.

  When max_results is given, keeps following nextPageToken until either
  max_results tables are collected or no further pages remain.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.DatasetReference,
      method='list_tables',
  )
  request = bq_processor_utils.PrepareListRequest(
      reference, max_results, page_token
  )
  response = apiclient.tables().list(**request).execute()
  tables = response.get('tables', [])
  if max_results is not None:
    while 'nextPageToken' in response and len(tables) < max_results:
      request['maxResults'] = max_results - len(tables)
      request['pageToken'] = response['nextPageToken']
      response = apiclient.tables().list(**request).execute()
      tables.extend(response.get('tables', []))
  return tables
def get_table_iam_policy(
    iampolicy_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
):
  """Gets IAM policy for the given table resource.

  Args:
    iampolicy_client: The apiclient used to make the request.
    reference: the TableReference for the table resource.

  Returns:
    The IAM policy attached to the given table resource.

  Raises:
    BigqueryTypeError: if reference is not a TableReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='get_table_iam_policy',
  )
  resource_path = (
      f'projects/{reference.projectId}'
      f'/datasets/{reference.datasetId}'
      f'/tables/{reference.tableId}'
  )
  return (
      iampolicy_client.tables()
      .getIamPolicy(resource=resource_path)
      .execute()
  )
def set_table_iam_policy(
    iampolicy_client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    policy,
):
  """Sets IAM policy for the given table resource.

  Args:
    iampolicy_client: The apiclient used to make the request.
    reference: the TableReference for the table resource.
    policy: The policy string in JSON format.

  Returns:
    The updated IAM policy attached to the given table resource.

  Raises:
    BigqueryTypeError: if reference is not a TableReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='set_table_iam_policy',
  )
  resource_path = (
      f'projects/{reference.projectId}'
      f'/datasets/{reference.datasetId}'
      f'/tables/{reference.tableId}'
  )
  return (
      iampolicy_client.tables()
      .setIamPolicy(body={'policy': policy}, resource=resource_path)
      .execute()
  )
def get_table_region(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
) -> Optional[str]:
  """Returns the table's location string, or None if the table is not found."""
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='get_table_region',
  )
  try:
    table = apiclient.tables().get(**dict(reference)).execute()
  except bq_error.BigqueryNotFoundError:
    return None
  return table['location']
def table_exists(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
):
  """Returns the (truthy) table resource if it exists, or False otherwise."""
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='table_exists',
  )
  try:
    return apiclient.tables().get(**dict(reference)).execute()
  except bq_error.BigqueryNotFoundError:
    return False
def create_table(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    ignore_existing: bool = False,
    schema: Optional[str] = None,
    description: Optional[str] = None,
    display_name: Optional[str] = None,
    expiration: Optional[int] = None,
    view_query: Optional[str] = None,
    materialized_view_query: Optional[str] = None,
    enable_refresh: Optional[bool] = None,
    refresh_interval_ms: Optional[int] = None,
    max_staleness: Optional[str] = None,
    external_data_config=None,
    biglake_config=None,
    external_catalog_table_options=None,
    view_udf_resources=None,
    use_legacy_sql: Optional[bool] = None,
    labels: Optional[Dict[str, str]] = None,
    time_partitioning=None,
    clustering: Optional[Dict[str, List[str]]] = None,
    range_partitioning=None,
    require_partition_filter: Optional[bool] = None,
    destination_kms_key: Optional[str] = None,
    location: Optional[str] = None,
    table_constraints: Optional[str] = None,
    resource_tags: Optional[Dict[str, str]] = None,
):
  """Create a table corresponding to TableReference.

  Builds a tables.insert request body from the optional arguments (only
  fields whose argument is not None are included) and issues the insert.

  Args:
    apiclient: The apiclient used to make the request.
    reference: the TableReference to create.
    ignore_existing: (boolean, default False) If False, raise an exception if
      the dataset already exists.
    schema: an optional schema for tables.
    description: an optional description for tables or views.
    display_name: an optional friendly name for the table.
    expiration: optional expiration time in milliseconds since the epoch for
      tables or views.
    view_query: an optional Sql query for views.
    materialized_view_query: an optional standard SQL query for materialized
      views.
    enable_refresh: for materialized views, an optional toggle to enable /
      disable automatic refresh when the base table is updated.
    refresh_interval_ms: for materialized views, an optional maximum frequency
      for automatic refreshes.
    max_staleness: INTERVAL value that determines the maximum staleness
      allowed when querying a materialized view or an external table. By
      default no staleness is allowed.
    external_data_config: defines a set of external resources used to create
      an external table. For example, a BigQuery table backed by CSV files in
      GCS.
    biglake_config: specifies the configuration of a BigLake managed table.
    external_catalog_table_options: Specifies the configuration of an external
      catalog table.
    view_udf_resources: optional UDF resources used in a view.
    use_legacy_sql: The choice of using Legacy SQL for the query is optional.
      If not specified, the server will automatically determine the dialect
      based on query information, such as dialect prefixes. If no prefixes are
      found, it will default to Legacy SQL.
    labels: an optional dict of labels to set on the table.
    time_partitioning: if set, enables time based partitioning on the table
      and configures the partitioning.
    clustering: if set, enables and configures clustering on the table.
    range_partitioning: if set, enables range partitioning on the table and
      configures the partitioning.
    require_partition_filter: if set, partition filter is required for queries
      over this table.
    destination_kms_key: User specified KMS key for encryption.
    location: an optional location for which to create tables or views.
    table_constraints: an optional primary key and foreign key configuration
      for the table.
    resource_tags: an optional dict of tags to attach to the table.

  Raises:
    BigqueryTypeError: if reference is not a TableReference.
    BigqueryDuplicateError: if reference exists and ignore_existing
      is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='create_table',
  )
  try:
    body = bq_processor_utils.ConstructObjectInfo(reference)
    if schema is not None:
      body['schema'] = {'fields': schema}
    if display_name is not None:
      body['friendlyName'] = display_name
    if description is not None:
      body['description'] = description
    if expiration is not None:
      body['expirationTime'] = expiration
    if view_query is not None:
      view_args = {'query': view_query}
      if view_udf_resources is not None:
        view_args['userDefinedFunctionResources'] = view_udf_resources
      body['view'] = view_args
      # view_args is the same dict object as body['view'], so this mutation
      # after the assignment above is still reflected in the request body.
      if use_legacy_sql is not None:
        view_args['useLegacySql'] = use_legacy_sql
    if materialized_view_query is not None:
      materialized_view_args = {'query': materialized_view_query}
      if enable_refresh is not None:
        materialized_view_args['enableRefresh'] = enable_refresh
      if refresh_interval_ms is not None:
        materialized_view_args['refreshIntervalMs'] = refresh_interval_ms
      body['materializedView'] = materialized_view_args
    if external_data_config is not None:
      # maxStaleness is only attached when an external data configuration is
      # present; for materialized views it is not set by this function.
      if max_staleness is not None:
        body['maxStaleness'] = max_staleness
      body['externalDataConfiguration'] = external_data_config
    if biglake_config is not None:
      body['biglakeConfiguration'] = biglake_config
    if external_catalog_table_options is not None:
      body['externalCatalogTableOptions'] = bq_frontend_utils.GetJson(
          external_catalog_table_options
      )
    if labels is not None:
      body['labels'] = labels
    if time_partitioning is not None:
      body['timePartitioning'] = time_partitioning
    if clustering is not None:
      body['clustering'] = clustering
    if range_partitioning is not None:
      body['rangePartitioning'] = range_partitioning
    if require_partition_filter is not None:
      body['requirePartitionFilter'] = require_partition_filter
    if destination_kms_key is not None:
      body['encryptionConfiguration'] = {'kmsKeyName': destination_kms_key}
    if location is not None:
      body['location'] = location
    if table_constraints is not None:
      # NOTE(review): snake_case key, unlike the camelCase keys above —
      # mirrors update_table; confirm this matches the API before changing.
      body['table_constraints'] = table_constraints
    if resource_tags is not None:
      body['resourceTags'] = resource_tags
    _execute_insert_table_request(apiclient, reference, body)
  except bq_error.BigqueryDuplicateError:
    if not ignore_existing:
      raise
def update_table(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    schema=None,
    description: Optional[str] = None,
    display_name: Optional[str] = None,
    expiration: Optional[int] = None,
    view_query: Optional[str] = None,
    materialized_view_query: Optional[str] = None,
    enable_refresh: Optional[bool] = None,
    refresh_interval_ms: Optional[int] = None,
    max_staleness: Optional[str] = None,
    external_data_config=None,
    external_catalog_table_options=None,
    view_udf_resources=None,
    use_legacy_sql: Optional[bool] = None,
    labels_to_set: Optional[Dict[str, str]] = None,
    label_keys_to_remove: Optional[List[str]] = None,
    time_partitioning=None,
    range_partitioning=None,
    clustering: Optional[Dict[str, List[str]]] = None,
    require_partition_filter: Optional[bool] = None,
    etag: Optional[str] = None,
    encryption_configuration=None,
    location: Optional[str] = None,
    autodetect_schema: bool = False,
    table_constraints=None,
    tags_to_attach: Optional[Dict[str, str]] = None,
    tags_to_remove: Optional[List[str]] = None,
    clear_all_tags: bool = False,
):
  """Updates a table.

  Builds a patch body from the optional arguments and issues a tables.patch
  request (optionally guarded by an etag precondition).

  Args:
    apiclient: The apiclient used to make the request.
    reference: the TableReference to update.
    schema: an optional schema for tables.
    description: an optional description for tables or views.
    display_name: an optional friendly name for the table.
    expiration: optional expiration time in milliseconds since the epoch for
      tables or views. Specifying 0 removes expiration time.
    view_query: an optional Sql query to update a view.
    materialized_view_query: an optional Standard SQL query for materialized
      views.
    enable_refresh: for materialized views, an optional toggle to enable /
      disable automatic refresh when the base table is updated.
    refresh_interval_ms: for materialized views, an optional maximum frequency
      for automatic refreshes.
    max_staleness: INTERVAL value that determines the maximum staleness
      allowed when querying a materialized view or an external table. By
      default no staleness is allowed.
    external_data_config: defines a set of external resources used to create
      an external table. For example, a BigQuery table backed by CSV files in
      GCS.
    external_catalog_table_options: Specifies the configuration of an external
      catalog table.
    view_udf_resources: optional UDF resources used in a view.
    use_legacy_sql: The choice of using Legacy SQL for the query is optional.
      If not specified, the server will automatically determine the dialect
      based on query information, such as dialect prefixes. If no prefixes are
      found, it will default to Legacy SQL.
    labels_to_set: an optional dict of labels to set on this table.
    label_keys_to_remove: an optional list of label keys to remove from this
      table.
    time_partitioning: if set, enables time based partitioning on the table
      and configures the partitioning.
    range_partitioning: if set, enables range partitioning on the table and
      configures the partitioning.
    clustering: if set, enables clustering on the table and configures the
      clustering spec.
    require_partition_filter: if set, partition filter is required for queries
      over this table.
    etag: if set, checks that etag in the existing table matches.
    encryption_configuration: Updates the encryption configuration.
    location: an optional location for which to update tables or views.
    autodetect_schema: an optional flag to perform autodetect of file schema.
    table_constraints: an optional primary key and foreign key configuration
      for the table.
    tags_to_attach: an optional dict of tags to attach to the table
    tags_to_remove: an optional list of tag keys to remove from the table
    clear_all_tags: if set, clears all the tags attached to the table

  Raises:
    BigqueryTypeError: if reference is not a TableReference.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='update_table',
  )
  existing_table = {}
  if clear_all_tags:
    # getting existing table. This is required to clear all tags attached to
    # a table. Adding this at the start of the method as this can also be
    # used for other scenarios
    existing_table = _execute_get_table_request(
        apiclient=apiclient, reference=reference
    )
  table = bq_processor_utils.ConstructObjectInfo(reference)
  # NOTE(review): maybe_skip_schema is always False here, so the elif below
  # always sets table['schema'] = None when no schema is given — presumably a
  # remnant of stripped-out conditional logic; confirm before simplifying.
  maybe_skip_schema = False
  if schema is not None:
    table['schema'] = {'fields': schema}
  elif not maybe_skip_schema:
    table['schema'] = None
  if encryption_configuration is not None:
    table['encryptionConfiguration'] = encryption_configuration
  if display_name is not None:
    table['friendlyName'] = display_name
  if description is not None:
    table['description'] = description
  if expiration is not None:
    # An expiration of 0 explicitly clears the expiration time.
    if expiration == 0:
      table['expirationTime'] = None
    else:
      table['expirationTime'] = expiration
  if view_query is not None:
    view_args = {'query': view_query}
    if view_udf_resources is not None:
      view_args['userDefinedFunctionResources'] = view_udf_resources
    if use_legacy_sql is not None:
      view_args['useLegacySql'] = use_legacy_sql
    table['view'] = view_args
  materialized_view_args = {}
  if materialized_view_query is not None:
    materialized_view_args['query'] = materialized_view_query
  if enable_refresh is not None:
    materialized_view_args['enableRefresh'] = enable_refresh
  if refresh_interval_ms is not None:
    materialized_view_args['refreshIntervalMs'] = refresh_interval_ms
  # Only attach materializedView when at least one of its fields was set.
  if materialized_view_args:
    table['materializedView'] = materialized_view_args
  if external_data_config is not None:
    table['externalDataConfiguration'] = external_data_config
    if max_staleness is not None:
      table['maxStaleness'] = max_staleness
  if 'labels' not in table:
    table['labels'] = {}
  table_labels = cast(Dict[str, Optional[str]], table['labels'])
  if table_labels is None:
    raise ValueError('Missing labels in table.')
  if labels_to_set:
    for label_key, label_value in labels_to_set.items():
      table_labels[label_key] = label_value
  if label_keys_to_remove:
    for label_key in label_keys_to_remove:
      # A None value in the patch body removes the label.
      table_labels[label_key] = None
  if time_partitioning is not None:
    table['timePartitioning'] = time_partitioning
  if range_partitioning is not None:
    table['rangePartitioning'] = range_partitioning
  if clustering is not None:
    # An empty clustering spec clears clustering on the table.
    if clustering == {}:  # pylint: disable=g-explicit-bool-comparison
      table['clustering'] = None
    else:
      table['clustering'] = clustering
  if require_partition_filter is not None:
    table['requirePartitionFilter'] = require_partition_filter
  if location is not None:
    table['location'] = location
  if table_constraints is not None:
    # NOTE(review): snake_case key, unlike the camelCase keys above —
    # mirrors create_table; confirm this matches the API before changing.
    table['table_constraints'] = table_constraints
  resource_tags = {}
  if clear_all_tags and 'resourceTags' in existing_table:
    # Mapping every existing tag key to None removes all tag bindings.
    for tag in existing_table['resourceTags']:
      resource_tags[tag] = None
  else:
    for tag in tags_to_remove or []:
      resource_tags[tag] = None
    for tag in tags_to_attach or {}:
      resource_tags[tag] = tags_to_attach[tag]
  # resourceTags is used to add a new tag binding, update value of existing
  # tag and also to remove a tag binding
  # check go/bq-table-tags-api for details
  table['resourceTags'] = resource_tags
  if external_catalog_table_options is not None:
    existing_table = _execute_get_table_request(
        apiclient=apiclient, reference=reference
    )
    existing_table.setdefault(_EXTERNAL_CATALOG_TABLE_OPTIONS_FIELD_NAME, {})
    table[_EXTERNAL_CATALOG_TABLE_OPTIONS_FIELD_NAME] = (
        bq_frontend_utils.UpdateExternalCatalogTableOptions(
            existing_table[_EXTERNAL_CATALOG_TABLE_OPTIONS_FIELD_NAME],
            external_catalog_table_options,
        )
    )
  _execute_patch_table_request(
      apiclient=apiclient,
      reference=reference,
      table=table,
      autodetect_schema=autodetect_schema,
      etag=etag,
  )
def _execute_get_table_request(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
):
  """Fetches the full table resource for the given reference."""
  request = apiclient.tables().get(**dict(reference))
  return request.execute()
def _execute_patch_table_request(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    table,
    autodetect_schema: bool = False,
    etag: Optional[str] = None,
):
  """Executes a tables.patch request for the given table body.

  Args:
    apiclient: The apiclient used to make the request.
    reference: the TableReference to patch.
    table: the body of the patch request.
    autodetect_schema: an optional flag to perform autodetect of file schema.
    etag: if set, checks that etag in the existing table matches.
  """
  request = apiclient.tables().patch(
      autodetect_schema=autodetect_schema, body=table, **dict(reference)
  )
  # Perform a conditional update to protect against concurrent modifications
  # to this table: with an If-Match header set, a conflicting change makes
  # the update fail with a "Precondition failed" error.
  # (Simplified from `etag if etag else table['etag']`, whose else-branch was
  # unreachable inside `if etag:`.)
  if etag:
    request.headers['If-Match'] = etag
  request.execute()
def _execute_insert_table_request(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    body,
):
  """Issues a tables.insert of `body` into the dataset containing reference."""
  dataset_kwargs = dict(reference.GetDatasetReference())
  apiclient.tables().insert(body=body, **dataset_kwargs).execute()
def delete_table(
    apiclient: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.TableReference,
    ignore_not_found: bool = False,
):
  """Deletes the referenced table.

  Args:
    apiclient: The apiclient used to make the request.
    reference: the TableReference to delete.
    ignore_not_found: Whether to ignore "not found" errors.

  Raises:
    BigqueryTypeError: if reference is not a TableReference.
    bq_error.BigqueryNotFoundError: if reference does not exist and
      ignore_not_found is False.
  """
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.TableReference,
      method='delete_table',
  )
  try:
    apiclient.tables().delete(**dict(reference)).execute()
  except bq_error.BigqueryNotFoundError:
    if ignore_not_found:
      return
    raise

View File

@@ -0,0 +1,322 @@
#!/usr/bin/env python
"""The different TableReader options for the BQ CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional
from googleapiclient import discovery
from utils import bq_error
from utils import bq_id_utils
class _TableReader:
"""Base class that defines the TableReader interface.
_TableReaders provide a way to read paginated rows and schemas from a table.
"""
def ReadRows(
self,
start_row: Optional[int] = 0,
max_rows: Optional[int] = None,
selected_fields: Optional[str] = None,
):
"""Read at most max_rows rows from a table.
Args:
start_row: first row to return.
max_rows: maximum number of rows to return.
selected_fields: a subset of fields to return.
Raises:
BigqueryInterfaceError: when bigquery returns something unexpected.
Returns:
list of rows, each of which is a list of field values.
"""
(_, rows) = self.ReadSchemaAndRows(
start_row=start_row,
max_rows=max_rows,
selected_fields=selected_fields,
)
return rows
def ReadSchemaAndRows(
self,
start_row: Optional[int],
max_rows: Optional[int],
selected_fields: Optional[str] = None,
):
"""Read at most max_rows rows from a table and the schema.
Args:
start_row: first row to read.
max_rows: maximum number of rows to return.
selected_fields: a subset of fields to return.
Raises:
BigqueryInterfaceError: when bigquery returns something unexpected.
ValueError: when start_row is None.
ValueError: when max_rows is None.
Returns:
A tuple where the first item is the list of fields and the
second item a list of rows.
"""
if start_row is None:
raise ValueError('start_row is required')
if max_rows is None:
raise ValueError('max_rows is required')
page_token = None
rows = []
schema = {}
while len(rows) < max_rows:
rows_to_read = max_rows - len(rows)
if not hasattr(self, 'max_rows_per_request'):
raise NotImplementedError(
'Subclass must have max_rows_per_request instance variable'
)
if self.max_rows_per_request:
rows_to_read = min(self.max_rows_per_request, rows_to_read)
(more_rows, page_token, current_schema) = self._ReadOnePage(
None if page_token else start_row,
max_rows=rows_to_read,
page_token=page_token,
selected_fields=selected_fields,
)
if not schema and current_schema:
schema = current_schema.get('fields', [])
for row in more_rows:
rows.append(self._ConvertFromFV(schema, row))
start_row += 1
if not page_token or not more_rows:
break
return (schema, rows)
def _ConvertFromFV(self, schema, row):
"""Converts from FV format to possibly nested lists of values."""
if not row:
return None
values = [entry.get('v', '') for entry in row.get('f', [])]
result = []
for field, v in zip(schema, values):
if 'type' not in field:
raise bq_error.BigqueryCommunicationError(
'Invalid response: missing type property'
)
if field['type'].upper() == 'RECORD':
# Nested field.
subfields = field.get('fields', [])
if field.get('mode', 'NULLABLE').upper() == 'REPEATED':
# Repeated and nested. Convert the array of v's of FV's.
result.append([
self._ConvertFromFV(subfields, subvalue.get('v', ''))
for subvalue in v
])
else:
# Nested non-repeated field. Convert the nested f from FV.
result.append(self._ConvertFromFV(subfields, v))
elif field.get('mode', 'NULLABLE').upper() == 'REPEATED':
# Repeated but not nested: an array of v's.
result.append([subvalue.get('v', '') for subvalue in v])
else:
# Normal flat field.
result.append(v)
return result
def __str__(self) -> str:
return self._GetPrintContext()
def __repr__(self) -> str:
return self._GetPrintContext()
def _GetPrintContext(self) -> str:
"""Returns context for what is being read."""
raise NotImplementedError('Subclass must implement GetPrintContext')
def _ReadOnePage(
self,
start_row: Optional[int],
max_rows: Optional[int],
page_token: Optional[str] = None,
selected_fields: Optional[str] = None,
):
"""Read one page of data, up to max_rows rows.
Assumes that the table is ready for reading. Will signal an error otherwise.
Args:
start_row: first row to read.
max_rows: maximum number of rows to return.
page_token: Optional. current page token.
selected_fields: a subset of field to return.
Returns:
tuple of:
rows: the actual rows of the table, in f,v format.
page_token: the page token of the next page of results.
schema: the schema of the table.
"""
raise NotImplementedError('Subclass must implement _ReadOnePage')
class TableTableReader(_TableReader):
  """A TableReader that reads rows directly from a table."""

  def __init__(
      self,
      local_apiclient: discovery.Resource,
      max_rows_per_request: int,
      table_ref: bq_id_utils.ApiClientHelper.TableReference,
  ):
    self.table_ref = table_ref
    self.max_rows_per_request = max_rows_per_request
    self._apiclient = local_apiclient

  def _GetPrintContext(self) -> str:
    return '%r' % (self.table_ref,)

  def _ReadOnePage(
      self,
      start_row: Optional[int],
      max_rows: Optional[int],
      page_token: Optional[str] = None,
      selected_fields: Optional[str] = None,
  ):
    """Reads one page of rows via tabledata.list plus the table schema."""
    kwds = dict(self.table_ref)
    kwds['maxResults'] = max_rows
    if page_token:
      kwds['pageToken'] = page_token
    else:
      kwds['startIndex'] = start_row
    if selected_fields is not None:
      kwds['selectedFields'] = selected_fields
    # Fix: removed a dead `data = None` / `if data is None:` guard — the
    # request was executed unconditionally.
    data = self._apiclient.tabledata().list(**kwds).execute()
    page_token = data.get('pageToken', None)
    rows = data.get('rows', [])

    # tabledata.list does not return the schema, so fetch it with a separate
    # tables.get call on the same reference.
    kwds = dict(self.table_ref)
    if selected_fields is not None:
      kwds['selectedFields'] = selected_fields
    table_info = self._apiclient.tables().get(**kwds).execute()
    schema = table_info.get('schema', {})
    return (rows, page_token, schema)
class JobTableReader(_TableReader):
  """A TableReader that reads the results of a completed job."""

  def __init__(
      self,
      local_apiclient: discovery.Resource,
      max_rows_per_request: int,
      job_ref: bq_id_utils.ApiClientHelper.JobReference,
  ):
    self.job_ref = job_ref
    self.max_rows_per_request = max_rows_per_request
    self._apiclient = local_apiclient

  def _GetPrintContext(self) -> str:
    return '%r' % (self.job_ref,)

  def _ReadOnePage(
      self,
      start_row: Optional[int],
      max_rows: Optional[int],
      page_token: Optional[str] = None,
      selected_fields: Optional[str] = None,
  ):
    """Reads one page of results via jobs.getQueryResults."""
    params = dict(self.job_ref)
    params['maxResults'] = max_rows
    # timeoutMs=0: the job is assumed already complete, so never block.
    params['timeoutMs'] = 0
    if page_token:
      params['pageToken'] = page_token
    else:
      params['startIndex'] = start_row
    data = self._apiclient.jobs().getQueryResults(**params).execute()
    if not data['jobComplete']:
      raise bq_error.BigqueryError('Job %s is not done' % (self,))
    return (
        data.get('rows', []),
        data.get('pageToken', None),
        data.get('schema', None),
    )
class QueryTableReader(_TableReader):
  """A TableReader that reads from a completed query."""
  def __init__(
      self,
      local_apiclient: discovery.Resource,
      max_rows_per_request: int,
      job_ref: bq_id_utils.ApiClientHelper.JobReference,
      results,
  ):
    # `results` is the initial query response; it may already contain the
    # rows/schema needed to answer reads without another API call.
    self.job_ref = job_ref
    self.max_rows_per_request = max_rows_per_request
    self._apiclient = local_apiclient
    self._results = results
  def _GetPrintContext(self) -> str:
    """Returns context for what is being read."""
    return '%r' % (self.job_ref,)
  def _ReadOnePage(
      self,
      start_row: Optional[int],
      max_rows: Optional[int],
      page_token: Optional[str] = None,
      selected_fields: Optional[str] = None,
  ):
    """Reads one page of rows, preferring cached query results when possible.

    Serves rows from the stored query response if it already covers the
    requested range; otherwise falls back to jobs.getQueryResults.
    """
    kwds = dict(self.job_ref) if self.job_ref else {}
    kwds['maxResults'] = max_rows
    # Sets the timeout to 0 because we assume the table is already ready.
    kwds['timeoutMs'] = 0
    if page_token:
      kwds['pageToken'] = page_token
    else:
      kwds['startIndex'] = start_row
    if not self._results['jobComplete']:
      raise bq_error.BigqueryError('Job %s is not done' % (self,))
    # DDL and DML statements return no rows, just delegate them to
    # getQueryResults.
    result_rows = self._results.get('rows', None)
    total_rows = self._results.get('totalRows', None)
    job_reference = self._results.get('jobReference', None)
    if job_reference is None and (
        total_rows is not None and int(total_rows) == 0
    ):
      # Handle the case when jobs.query requests with JOB_CREATION_OPTIONAL
      # return empty results. This will avoid a call to getQueryResults.
      schema = self._results.get('schema', None)
      rows = self._results.get('rows', [])
      page_token = None
    elif (
        total_rows is not None
        and result_rows is not None
        and start_row is not None
        and len(result_rows) >= min(int(total_rows), start_row + max_rows)
    ):
      # The cached response already contains every row in the requested
      # range, so serve the page from it without another API call.
      page_token = self._results.get('pageToken', None)
      if len(result_rows) < int(total_rows) and page_token is None:
        raise bq_error.BigqueryError(
            'Synchronous query %s did not return all rows, yet it did not'
            ' return a page token' % (self,)
        )
      schema = self._results.get('schema', None)
      rows = self._results.get('rows', [])
    else:
      # Cached results are insufficient; page through jobs.getQueryResults.
      data = self._apiclient.jobs().getQueryResults(**kwds).execute()
      if not data['jobComplete']:
        raise bq_error.BigqueryError('Job %s is not done' % (self,))
      page_token = data.get('pageToken', None)
      schema = data.get('schema', None)
      rows = data.get('rows', [])
    return (rows, page_token, schema)

View File

@@ -0,0 +1,155 @@
#!/usr/bin/env python
"""BQ CLI library for wait printers."""
import logging
import sys
import time
from typing import Optional
import googleapiclient
import httplib2
from clients import utils as bq_client_utils
def _overwrite_current_line(
s: str, previous_token: Optional[int] = None
) -> int:
"""Print string over the current terminal line, and stay on that line.
The full width of any previous output (by the token) will be wiped clean.
If multiple callers call this at the same time, it would be bad.
Args:
s: string to print. May not contain newlines.
previous_token: token returned from previous call, or None on first call.
Returns:
a token to pass into your next call to this function.
"""
# Tricks in use:
# carriage return \r brings the printhead back to the start of the line.
# sys.stdout.write() does not add a newline.
# Erase any previous, in case new string is shorter.
if previous_token is not None:
sys.stderr.write('\r' + (' ' * previous_token))
# Put new string.
sys.stderr.write('\r' + s)
# Display.
sys.stderr.flush()
return len(s)
def execute_in_chunks_with_progress(request):
  """Run an apiclient request with a resumable upload, showing progress.

  Note: the return annotation `-> None` was removed; this function returns
  the result of the completed request, as documented below.

  Args:
    request: an apiclient request having a media_body that is a
      MediaFileUpload(resumable=True).

  Returns:
    The result of executing the request, if it succeeds.

  Raises:
    BigQueryError: on a non-retriable error or too many retriable errors.
  """
  result = None
  retriable_errors = 0
  output_token = None
  status = None
  while result is None:
    try:
      # next_chunk() returns (status, None) while in progress and
      # (status, result) once the final chunk has been uploaded.
      status, result = request.next_chunk()
    except googleapiclient.errors.HttpError as e:
      logging.error(
          'HTTP Error %d during resumable media upload', e.resp.status
      )
      # Log response headers, which contain debug info for GFEs.
      for key, value in e.resp.items():
        logging.info(' %s: %s', key, value)
      if e.resp.status in [502, 503, 504]:
        # Transient server errors: exponential backoff, at most 3 retries.
        sleep_sec = 2**retriable_errors
        retriable_errors += 1
        if retriable_errors > 3:
          raise
        print('Error %d, retry #%d' % (e.resp.status, retriable_errors))
        time.sleep(sleep_sec)
        # Go around and try again.
      else:
        bq_client_utils.RaiseErrorFromHttpError(e)
    except (httplib2.HttpLib2Error, IOError) as e:
      bq_client_utils.RaiseErrorFromNonHttpError(e)
    if status:
      output_token = _overwrite_current_line(
          'Uploaded %d%%... ' % int(status.progress() * 100), output_token
      )
  _overwrite_current_line('Upload complete.', output_token)
  sys.stderr.write('\n')
  return result
class WaitPrinter:
  """Abstract interface for reporting progress while waiting on a job.

  Concrete subclasses must override both methods; the base implementations
  only raise.
  """

  def print(self, job_id: str, wait_time: float, status: str) -> None:
    """Reports the current state of the job being waited on.

    Args:
      job_id: the identifier for this job.
      wait_time: the number of seconds we have been waiting so far.
      status: the status of the job we are waiting for.
    """
    raise NotImplementedError('Subclass must implement Print')

  def done(self) -> None:
    """Signals that waiting has finished and no more print calls will come.

    Implementations must also handle the case of print never being called.
    """
    raise NotImplementedError('Subclass must implement Done')
class WaitPrinterHelper(WaitPrinter):
  """A done() implementation that is driven by the print_on_done property."""

  # Subclasses set this to True once they have produced any output.
  print_on_done = False

  def done(self) -> None:
    if not self.print_on_done:
      return
    # Finish the in-place status line with a newline.
    sys.stderr.write('\n')
class QuietWaitPrinter(WaitPrinterHelper):
  """A WaitPrinter that emits no output at all."""

  def print(
      self, unused_job_id: str, unused_wait_time: float, unused_status: str
  ):
    """Deliberately prints nothing."""
class VerboseWaitPrinter(WaitPrinterHelper):
  """A WaitPrinter that reports every single status update."""

  def __init__(self):
    # Line-width token threaded between _overwrite_current_line calls.
    self.output_token = None

  def print(self, job_id: str, wait_time: float, status: str) -> None:
    self.print_on_done = True
    message = 'Waiting on %s ... (%ds) Current status: %-7s' % (
        job_id,
        wait_time,
        status,
    )
    self.output_token = _overwrite_current_line(message, self.output_token)
class TransitionWaitPrinter(VerboseWaitPrinter):
  """A WaitPrinter that reports only when the job status changes."""

  # Last status that was printed; repeats of it are suppressed.
  _previous_status = None

  def print(self, job_id: str, wait_time: float, status: str) -> None:
    if status == self._previous_status:
      return
    self._previous_status = status
    super(TransitionWaitPrinter, self).print(job_id, wait_time, status)